# Reconstructed from a compiled-bytecode (.pyc) dump of ultralytics/engine/predictor.py: the docstrings and
# string constants below are recovered verbatim from the dump; method bodies are a best-effort sketch of the
# behavior they describe and may differ in detail from the original source.
"""
Run prediction on images, videos, directories, globs, YouTube, webcam, streams, etc.

Usage - sources:
    $ yolo mode=predict model=yolo11n.pt source=0                               # webcam
                                                img.jpg                         # image
                                                vid.mp4                         # video
                                                screen                          # screenshot
                                                path/                           # directory
                                                list.txt                        # list of images
                                                list.streams                    # list of streams
                                                'path/*.jpg'                    # glob
                                                'https://youtu.be/LNwODJXcvt4'  # YouTube
                                                'rtsp://example.com/media.mp4'  # RTSP, RTMP, HTTP, TCP stream

Usage - formats:
    $ yolo mode=predict model=yolo11n.pt                 # PyTorch
                              yolo11n.torchscript        # TorchScript
                              yolo11n.onnx               # ONNX Runtime or OpenCV DNN with dnn=True
                              yolo11n_openvino_model     # OpenVINO
                              yolo11n.engine             # TensorRT
                              yolo11n.mlpackage          # CoreML (macOS-only)
                              yolo11n_saved_model        # TensorFlow SavedModel
                              yolo11n.pb                 # TensorFlow GraphDef
                              yolo11n.tflite             # TensorFlow Lite
                              yolo11n_edgetpu.tflite     # TensorFlow Edge TPU
                              yolo11n_paddle_model       # PaddlePaddle
                              yolo11n.mnn                # MNN
                              yolo11n_ncnn_model         # NCNN
                              yolo11n_imx_model          # Sony IMX
                              yolo11n_rknn_model         # Rockchip RKNN
"""

import platform
import re
import threading
from pathlib import Path

import cv2
import numpy as np
import torch

from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.data import load_inference_source
from ultralytics.data.augment import LetterBox, classify_transforms
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, WINDOWS, callbacks, colorstr, ops
from ultralytics.utils.checks import check_imgsz, check_imshow
from ultralytics.utils.files import increment_path
from ultralytics.utils.torch_utils import select_device, smart_inference_mode

STREAM_WARNING = """
inference results will accumulate in RAM unless `stream=True` is passed, causing potential out-of-memory
errors for large sources or long-running streams and videos. See https://docs.ultralytics.com/modes/predict/ for help.

Example:
    results = model(source=..., stream=True)  # generator of Results objects
    for r in results:
        boxes = r.boxes  # Boxes object for bbox outputs
        masks = r.masks  # Masks object for segment masks outputs
        probs = r.probs  # Class probabilities for classification outputs
"""


class BasePredictor:
    """
    A base class for creating predictors.

    This class provides the foundation for prediction functionality, handling model setup, inference,
    and result processing across various input sources.

    Attributes:
        args (SimpleNamespace): Configuration for the predictor.
        save_dir (Path): Directory to save results.
        done_warmup (bool): Whether the predictor has finished setup.
        model (torch.nn.Module): Model used for prediction.
        data (dict): Data configuration.
        device (torch.device): Device used for prediction.
        dataset (Dataset): Dataset used for prediction.
        vid_writer (dict): Dictionary of {save_path: video_writer} for saving video output.
        plotted_img (numpy.ndarray): Last plotted image.
        source_type (SimpleNamespace): Type of input source.
        seen (int): Number of images processed.
        windows (list): List of window names for visualization.
        batch (tuple): Current batch data.
        results (list): Current batch results.
        transforms (callable): Image transforms for classification.
        callbacks (dict): Callback functions for different events.
        txt_path (Path): Path to save text results.
        _lock (threading.Lock): Lock for thread-safe inference.

    Methods:
        preprocess: Prepare input image before inference.
        inference: Run inference on a given image.
        postprocess: Process raw predictions into structured results.
        predict_cli: Run prediction for command line interface.
        setup_source: Set up input source and inference mode.
        stream_inference: Stream inference on input source.
        setup_model: Initialize and configure the model.
        write_results: Write inference results to files.
        save_predicted_images: Save prediction visualizations.
        show: Display results in a window.
        run_callbacks: Execute registered callbacks for an event.
        add_callback: Register a new callback function.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initialize the BasePredictor class.

        Args:
            cfg (str | dict): Path to a configuration file or a configuration dictionary.
            overrides (dict | None): Configuration overrides.
            _callbacks (dict | None): Dictionary of callback functions.
        """
        self.args = get_cfg(cfg, overrides)
        self.save_dir = get_save_dir(self.args)
        if self.args.conf is None:
            self.args.conf = 0.25  # default confidence threshold
        self.done_warmup = False
        if self.args.show:
            self.args.show = check_imshow(warn=True)

        # Usable if setup is done
        self.model = None
        self.data = self.args.data  # data_dict
        self.imgsz = None
        self.device = None
        self.dataset = None
        self.vid_writer = {}  # dict of {save_path: video_writer, ...}
        self.plotted_img = None
        self.source_type = None
        self.seen = 0
        self.windows = []
        self.batch = None
        self.results = None
        self.transforms = None
        self.callbacks = _callbacks or callbacks.get_default_callbacks()
        self.txt_path = None
        self._lock = threading.Lock()  # for automatic thread-safe inference
        callbacks.add_integration_callbacks(self)

    def preprocess(self, im):
        """
        Prepares input image before inference.

        Args:
            im (torch.Tensor | List(np.ndarray)): Images of shape (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.
        """
        not_tensor = not isinstance(im, torch.Tensor)
        if not_tensor:
            im = np.stack(self.pre_transform(im))
            if im.shape[-1] == 3:
                im = im[..., ::-1]  # BGR to RGB
            im = im.transpose((0, 3, 1, 2))  # BHWC to BCHW, (n, 3, h, w)
            im = np.ascontiguousarray(im)  # contiguous
            im = torch.from_numpy(im)

        im = im.to(self.device)
        im = im.half() if self.model.fp16 else im.float()  # uint8 to fp16/32
        if not_tensor:
            im /= 255  # 0 - 255 to 0.0 - 1.0
        return im

    def inference(self, im, *args, **kwargs):
        """Run inference on a given image using the specified model and arguments."""
        visualize = (
            increment_path(self.save_dir / Path(self.batch[0][0]).stem, mkdir=True)
            if self.args.visualize and not self.source_type.tensor
            else False
        )
        return self.model(im, *args, augment=self.args.augment, visualize=visualize, embed=self.args.embed, **kwargs)

    def pre_transform(self, im):
        """
        Pre-transform input image before inference.

        Args:
            im (List[np.ndarray]): Images of shape (N, 3, h, w) for tensor, [(h, w, 3) x N] for list.

        Returns:
            (List[np.ndarray]): A list of transformed images.
        """
        same_shapes = len({x.shape for x in im}) == 1
        letterbox = LetterBox(
            self.imgsz,
            auto=same_shapes
            and self.args.rect
            and (self.model.pt or getattr(self.model, "dynamic", False))
            and not self.model.imx,
            stride=self.model.stride,
        )
        return [letterbox(image=x) for x in im]

    def postprocess(self, preds, img, orig_imgs):
        """Post-process predictions for an image and return them."""
        return preds

    def __call__(self, source=None, model=None, stream=False, *args, **kwargs):
        """
        Perform inference on an image or stream.

        Args:
            source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor | None):
                Source for inference.
            model (str | Path | torch.nn.Module | None): Model for inference.
            stream (bool): Whether to stream the inference results. If True, returns a generator.
            *args (Any): Additional arguments for the inference method.
            **kwargs (Any): Additional keyword arguments for the inference method.

        Returns:
            (List[ultralytics.engine.results.Results] | generator): Results objects or generator of Results objects.
        """
        self.stream = stream
        if stream:
            return self.stream_inference(source, model, *args, **kwargs)
        else:
            return list(self.stream_inference(source, model, *args, **kwargs))  # merge list of Results into one

    def predict_cli(self, source=None, model=None):
        """
        Method used for Command Line Interface (CLI) prediction.

        This function is designed to run predictions using the CLI. It sets up the source and model, then processes
        the inputs in a streaming manner. This method ensures that no outputs accumulate in memory by consuming the
        generator without storing results.

        Args:
            source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor | None):
                Source for inference.
            model (str | Path | torch.nn.Module | None): Model for inference.

        Note:
            Do not modify this function or remove the generator. The generator ensures that no outputs are
            accumulated in memory, which is critical for preventing memory issues during long-running predictions.
        """
        gen = self.stream_inference(source, model)
        for _ in gen:  # consume the generator without storing results
            pass

    def setup_source(self, source):
        """
        Set up source and inference mode.

        Args:
            source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor):
                Source for inference.
        """
        self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2)  # check image size
        self.transforms = (
            getattr(self.model.model, "transforms", classify_transforms(self.imgsz[0]))
            if self.args.task == "classify"
            else None
        )
        self.dataset = load_inference_source(
            source=source,
            batch=self.args.batch,
            vid_stride=self.args.vid_stride,
            buffer=self.args.stream_buffer,
        )
        self.source_type = self.dataset.source_type
        if not getattr(self, "stream", True) and (
            self.source_type.stream
            or self.source_type.screenshot
            or len(self.dataset) > 1000  # many images
            or any(getattr(self.dataset, "video_flag", [False]))
        ):  # videos
            LOGGER.warning(STREAM_WARNING)
        self.vid_writer = {}

    @smart_inference_mode()
    def stream_inference(self, source=None, model=None, *args, **kwargs):
        """
        Stream real-time inference on camera feed and save results to file.

        Args:
            source (str | Path | List[str] | List[Path] | List[np.ndarray] | np.ndarray | torch.Tensor | None):
                Source for inference.
            model (str | Path | torch.nn.Module | None): Model for inference.
            *args (Any): Additional arguments for the inference method.
            **kwargs (Any): Additional keyword arguments for the inference method.

        Yields:
            (ultralytics.engine.results.Results): Results objects.
        """
        if self.args.verbose:
            LOGGER.info("")

        # Setup model
        if not self.model:
            self.setup_model(model)

        with self._lock:  # for thread-safe inference
            # Setup source every time predict is called
            self.setup_source(source if source is not None else self.args.source)

            # Check if save_dir/ label file exists
            if self.args.save or self.args.save_txt:
                (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)

            # Warmup model
            if not self.done_warmup:
                self.model.warmup(
                    imgsz=(
                        1 if self.model.pt or self.model.triton else self.dataset.bs,
                        getattr(self.model, "ch", 3),
                        *self.imgsz,
                    )
                )
                self.done_warmup = True

            self.seen, self.windows, self.batch = 0, [], None
            profilers = (
                ops.Profile(device=self.device),
                ops.Profile(device=self.device),
                ops.Profile(device=self.device),
            )
            self.run_callbacks("on_predict_start")
            for self.batch in self.dataset:
                self.run_callbacks("on_predict_batch_start")
                paths, im0s, s = self.batch

                # Preprocess
                with profilers[0]:
                    im = self.preprocess(im0s)

                # Inference
                with profilers[1]:
                    preds = self.inference(im, *args, **kwargs)
                    if self.args.embed:
                        yield from [preds] if isinstance(preds, torch.Tensor) else preds  # yield embedding tensors
                        continue

                # Postprocess
                with profilers[2]:
                    self.results = self.postprocess(preds, im, im0s)
                self.run_callbacks("on_predict_postprocess_end")

                # Visualize, save, write results
                n = len(im0s)
                for i in range(n):
                    self.seen += 1
                    self.results[i].speed = {
                        "preprocess": profilers[0].dt * 1e3 / n,
                        "inference": profilers[1].dt * 1e3 / n,
                        "postprocess": profilers[2].dt * 1e3 / n,
                    }
                    if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
                        s[i] += self.write_results(i, Path(paths[i]), im, s)

                # Print batch results
                if self.args.verbose:
                    LOGGER.info("\n".join(s))

                self.run_callbacks("on_predict_batch_end")
                yield from self.results

        # Release assets
        for v in self.vid_writer.values():
            if isinstance(v, cv2.VideoWriter):
                v.release()

        # Print final results
        if self.args.verbose and self.seen:
            t = tuple(x.t / self.seen * 1e3 for x in profilers)  # speeds per image
            LOGGER.info(
                f"Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape "
                f"{(min(self.args.batch, self.seen), getattr(self.model, 'ch', 3), *im.shape[2:])}" % t
            )
        if self.args.save or self.args.save_txt or self.args.save_crop:
            nl = len(list(self.save_dir.glob("labels/*.txt")))  # number of labels
            s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ""
            LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
        self.run_callbacks("on_predict_end")

    def setup_model(self, model, verbose=True):
        """
        Initialize YOLO model with given parameters and set it to evaluation mode.

        Args:
            model (str | Path | torch.nn.Module | None): Model to load or use.
            verbose (bool): Whether to print verbose output.
        """
        self.model = AutoBackend(
            weights=model or self.args.model,
            device=select_device(self.args.device, verbose=verbose),
            dnn=self.args.dnn,
            data=self.args.data,
            fp16=self.args.half,
            batch=self.args.batch,
            fuse=True,
            verbose=verbose,
        )

        self.device = self.model.device  # update device
        self.args.half = self.model.fp16  # update half
        self.model.eval()

    def write_results(self, i, p, im, s):
        """
        Write inference results to a file or directory.

        Args:
            i (int): Index of the current image in the batch.
            p (Path): Path to the current image.
            im (torch.Tensor): Preprocessed image tensor.
            s (List[str]): List of result strings.

        Returns:
            (str): String with result information.
        """
        string = ""  # print string
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim
        if self.source_type.stream or self.source_type.from_img or self.source_type.tensor:  # batch_size >= 1
            string += f"{i}: "
            frame = self.dataset.count
        else:
            match = re.search(r"frame (\d+)/", s[i])
            frame = int(match[1]) if match else None  # 0 if frame undetermined

        self.txt_path = self.save_dir / "labels" / (p.stem + ("" if self.dataset.mode == "image" else f"_{frame}"))
        string += "{:g}x{:g} ".format(*im.shape[2:])
        result = self.results[i]
        result.save_dir = self.save_dir.__str__()  # used in other locations
        string += f"{result.verbose()}{result.speed['inference']:.1f}ms"

        # Add predictions to image
        if self.args.save or self.args.show:
            self.plotted_img = result.plot(
                line_width=self.args.line_width,
                boxes=self.args.show_boxes,
                conf=self.args.show_conf,
                labels=self.args.show_labels,
                im_gpu=None if self.args.retina_masks else im[i],
            )

        # Save results
        if self.args.save_txt:
            result.save_txt(f"{self.txt_path}.txt", save_conf=self.args.save_conf)
        if self.args.save_crop:
            result.save_crop(save_dir=self.save_dir / "crops", file_name=self.txt_path.stem)
        if self.args.show:
            self.show(str(p))
        if self.args.save:
            self.save_predicted_images(str(self.save_dir / p.name), frame)

        return string

    def save_predicted_images(self, save_path="", frame=0):
        """
        Save video predictions as mp4 or images as jpg at specified path.

        Args:
            save_path (str): Path to save the results.
            frame (int): Frame number for video mode.
        """
        im = self.plotted_img

        # Save videos and streams
        if self.dataset.mode in {"stream", "video"}:
            fps = self.dataset.fps if self.dataset.mode == "video" else 30
            frames_path = f"{save_path.split('.', 1)[0]}_frames/"
            if save_path not in self.vid_writer:  # new video
                if self.args.save_frames:
                    Path(frames_path).mkdir(parents=True, exist_ok=True)
                suffix, fourcc = (".mp4", "avc1") if MACOS else (".avi", "WMV2") if WINDOWS else (".avi", "MJPG")
                self.vid_writer[save_path] = cv2.VideoWriter(
                    filename=str(Path(save_path).with_suffix(suffix)),
                    fourcc=cv2.VideoWriter_fourcc(*fourcc),
                    fps=fps,
                    frameSize=(im.shape[1], im.shape[0]),  # (width, height)
                )

            # Save video frame
            self.vid_writer[save_path].write(im)
            if self.args.save_frames:
                cv2.imwrite(f"{frames_path}{frame}.jpg", im)

        # Save images
        else:
            cv2.imwrite(str(Path(save_path).with_suffix(".jpg")), im)  # save to JPG for best support

    def show(self, p=""):
        """Display an image in a window."""
        im = self.plotted_img
        if platform.system() == "Linux" and p not in self.windows:
            self.windows.append(p)
            cv2.namedWindow(p, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux)
            cv2.resizeWindow(p, im.shape[1], im.shape[0])  # (width, height)
        cv2.imshow(p, im)
        cv2.waitKey(300 if self.dataset.mode == "image" else 1)  # 1 millisecond

    def run_callbacks(self, event: str):
        """Run all registered callbacks for a specific event."""
        for callback in self.callbacks.get(event, []):
            callback(self)

    def add_callback(self, event: str, func):
        """Add a callback function for a specific event."""
        self.callbacks[event].append(func)
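

# Illustrative usage sketch (not part of the original module). It assumes the `ultralytics` package, a
# `yolo11n.pt` checkpoint, and network access to the sample image URL; `DetectionPredictor` is the
# task-specific subclass normally used instead of instantiating `BasePredictor` directly. Passing
# `stream=True` consumes results one at a time, the pattern STREAM_WARNING above recommends for long
# videos and streams.
if __name__ == "__main__":  # pragma: no cover
    from ultralytics.models.yolo.detect import DetectionPredictor

    predictor = DetectionPredictor(
        overrides={"model": "yolo11n.pt", "source": "https://ultralytics.com/images/bus.jpg"}
    )
    for result in predictor(stream=True):  # generator of Results objects; RAM usage stays flat
        print(result.boxes)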