# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import torch
import torch.nn as nn
import torch.nn.functional as F

from ultralytics.utils.metrics import OKS_SIGMA
from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors
from ultralytics.utils.torch_utils import autocast

from .metrics import bbox_iou, probiou
from .tal import bbox2dist


class VarifocalLoss(nn.Module):
    """
    Varifocal loss by Zhang et al.

    https://arxiv.org/abs/2008.13367.

    Args:
        gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.
        alpha (float): The balancing factor used to address class imbalance.
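
    Example:
        Illustrative sketch only; tensor shapes are assumptions for exposition, with `pred_score`
        as raw logits and `gt_score`/`label` of matching shape (batch, anchors, classes):
        >>> criterion = VarifocalLoss(gamma=2.0, alpha=0.75)
        >>> loss = criterion(torch.randn(2, 100, 80), torch.rand(2, 100, 80), torch.rand(2, 100, 80))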
           @      ?c                       t    || _|| _dS )z#Initialize the VarifocalLoss class.Nsuper__init__gammaalphaselfr   r   	__class__ J/var/www/vscode/kcb/lib/python3.10/site-packages/ultralytics/utils/loss.pyr         

zVarifocalLoss.__init__c                 C   s|   | j | | j d|  ||  }tdd tj| | dd| d	 }W d   |S 1 s7w   Y  |S )z<Compute varifocal loss between predictions and ground truth.r   F)enablednone	reductionN)
r   sigmoidpowr   r   F binary_cross_entropy_with_logitsfloatmeansum)r   
pred_scoregt_scorelabelweightlossr   r   r   forward!   s   &
zVarifocalLoss.forward)r   r   __name__
__module____qualname____doc__r   r1   __classcell__r   r   r   r   r      s    


class FocalLoss(nn.Module):
    """
    Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5).

    Args:
        gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.
        alpha (float): The balancing factor used to address class imbalance.
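
    Example:
        Illustrative sketch only; `pred` is assumed to be raw logits with the same shape as the
        binary target `label`:
        >>> criterion = FocalLoss(gamma=1.5, alpha=0.25)
        >>> loss = criterion(torch.randn(2, 100, 80), torch.randint(0, 2, (2, 100, 80)).float())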
          ?      ?c                    r   )z.Initialize FocalLoss class with no parameters.Nr   r   r   r   r   r   6   r    zFocalLoss.__init__c                 C   s   t j||dd}| }|| d| d|   }d| | j }||9 }| jdkr:|| j d| d| j   }||9 }|d S )zACalculate focal loss with modulating factors for class imbalance.r"   r#   r         ?r   )r'   r(   r%   r   r   r*   r+   )r   predr.   r0   	pred_probp_tmodulating_factoralpha_factorr   r   r   r1   <   s   
zFocalLoss.forward)r9   r:   r2   r   r   r   r   r8   -   s    r8   c                       s,   e Zd ZdZd	d
 fddZdd Z  ZS )DFLossz<Criterion class for computing Distribution Focal Loss (DFL).   returnNc                       t    || _dS )z6Initialize the DFL module with regularization maximum.N)r   r   reg_maxr   rE   r   r   r   r   P      

zDFLoss.__init__c                 C   s   | d| jd d }| }|d }|| }d| }tj||ddd|j| tj||ddd|j|  jdddS )	zZReturn sum of left and right DFL losses from https://ieeexplore.ieee.org/document/9792391.r   r   g{Gz?r"   r#   Tkeepdim)clamp_rE   longr'   cross_entropyviewshaper*   )r   	pred_disttargettltrwlwrr   r   r   __call__U   s     zDFLoss.__call__rB   rC   N)r3   r4   r5   r6   r   rV   r7   r   r   r   r   rA   M   s    rA   c                       s*   e Zd ZdZd fdd	Zdd Z  ZS )BboxLosszACriterion class for computing training losses for bounding boxes.rB   c                    s*   t    |dkrt|| _dS d| _dS )LInitialize the BboxLoss module with regularization maximum and DFL settings.r   N)r   r   rA   dfl_lossrF   r   r   r   r   e   s   
 zBboxLoss.__init__c                 C   s   | d| d}t|| || ddd}	d|	 |   | }
| jrIt||| jjd }| || d| jj|| | }|  | }|
|fS td	|j
}|
|fS )z.Compute IoU and DFL losses for bounding boxes.rH   FT)xywhCIoUr;   r           )r+   	unsqueezer   r[   r   rE   rN   torchtensortodevicer   rP   pred_bboxesanchor_pointstarget_bboxestarget_scorestarget_scores_sumfg_maskr/   iouloss_ioutarget_ltrbloss_dflr   r   r   r1   j   s   $zBboxLoss.forwardrW   r2   r   r   r   r   rY   b       rY   c                       s(   e Zd ZdZ fddZdd Z  ZS )RotatedBboxLosszICriterion class for computing training losses for rotated bounding boxes.c                    s   t  | dS )rZ   N)r   r   rF   r   r   r   r   ~   s   zRotatedBboxLoss.__init__c                 C   s   | d| d}t|| || }	d|	 |   | }
| jrNt|t|dddf | jjd }| || d| jj|| | }|  | }|
|fS t	d
|j}|
|fS )z6Compute IoU and DFL losses for rotated bounding boxes.rH   r;   .N   r   r^   )r+   r_   r   r[   r   r   rE   rN   r`   ra   rb   rc   rd   r   r   r   r1      s   $$zRotatedBboxLoss.forwardr2   r   r   r   r   rp   {   s    rp   c                       s*   e Zd ZdZd fddZdd Z  ZS )	KeypointLossz.Criterion class for computing keypoint losses.rC   Nc                    rD   )z7Initialize the KeypointLoss class with keypoint sigmas.N)r   r   sigmas)r   rs   r   r   r   r      rG   zKeypointLoss.__init__c                 C   s   |d |d   d|d |d   d }|jd tj|dkddd  }|d| j  d|d  d  }|dddt|  |   S )	zICalculate keypoint loss factor and Euclidean distance loss for keypoints..r      .r   r   r   dimg&.>rH   )r&   rO   r`   r+   rs   rN   expr*   )r   	pred_kptsgt_kptskpt_maskareadkpt_loss_factorer   r   r   r1      s   ,  $zKeypointLoss.forwardrX   r2   r   r   r   r   rr      ro   rr   c                   @   s2   e Zd ZdZdddZdd Zdd Zd	d
 ZdS )v8DetectionLosszJCriterion class for computing training losses for YOLOv8 object detection.

    def __init__(self, model, tal_topk=10):  # model must be de-paralleled
        """Initialize v8DetectionLoss with model parameters and task-aligned assignment settings."""
        device = next(model.parameters()).device  # get model device
        h = model.args  # hyperparameters

        m = model.model[-1]  # Detect() module
        self.bce = nn.BCEWithLogitsLoss(reduction="none")
        self.hyp = h
        self.stride = m.stride  # model strides
        self.nc = m.nc  # number of classes
        self.no = m.nc + m.reg_max * 4
        self.reg_max = m.reg_max
        self.device = device

        self.use_dfl = m.reg_max > 1

        self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)
        self.bbox_loss = BboxLoss(m.reg_max).to(device)
        self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)

    def preprocess(self, targets, batch_size, scale_tensor):
        """Preprocess targets by converting to tensor format and scaling coordinates."""
        nl, ne = targets.shape
        if nl == 0:
            out = torch.zeros(batch_size, 0, ne - 1, device=self.device)
        else:
            i = targets[:, 0]  # image index
            _, counts = i.unique(return_counts=True)
            counts = counts.to(dtype=torch.int32)
            out = torch.zeros(batch_size, counts.max(), ne - 1, device=self.device)
            for j in range(batch_size):
                matches = i == j
                if n := matches.sum():
                    out[j, :n] = targets[matches, 1:]
            out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))
        return out

    def bbox_decode(self, anchor_points, pred_dist):
        """Decode predicted object bounding box coordinates from anchor points and distribution."""
        if self.use_dfl:
            b, a, c = pred_dist.shape  # batch, anchors, channels
            pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
        return dist2bbox(pred_dist, anchor_points, xywh=False)

    def __call__(self, preds, batch):
        """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
        loss = torch.zeros(3, device=self.device)  # box, cls, dfl
        feats = preds[1] if isinstance(preds, tuple) else preds
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
            (self.reg_max * 4, self.nc), 1
        )

        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()

        dtype = pred_scores.dtype
        batch_size = pred_scores.shape[0]
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # Targets
        targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
        targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)

        # Pboxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)

        _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
            pred_scores.detach().sigmoid(),
            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )

        target_scores_sum = max(target_scores.sum(), 1)

        # Cls loss
        loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

        # Bbox loss
        if fg_mask.sum():
            target_bboxes /= stride_tensor
            loss[0], loss[2] = self.bbox_loss(
                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
            )

        loss[0] *= self.hyp.box  # box gain
        loss[1] *= self.hyp.cls  # cls gain
        loss[2] *= self.hyp.dfl  # dfl gain

        return loss * batch_size, loss.detach()  # loss(box, cls, dfl)
ejdejfddZ	dejdejdejdejdejdejdejdejde
dejfddZ  ZS )v8SegmentationLosszFCriterion class for computing training losses for YOLOv8 segmentation.c                    s   t  | |jj| _dS )zWInitialize the v8SegmentationLoss class with model parameters and mask overlap setting.N)r   r   r   overlap_maskoverlapr   r   r   r   r   r     s   zv8SegmentationLoss.__init__c                    s  t jdjd}t|dkr|n|d \ }}|j\}}}}	t  fdd D djd jfd\}
}|	ddd
 }|
	ddd
 }
|	ddd
 }|j}t j d jdd	 j|d
jd  }t jd\}}z=|d dd}t ||d dd|d fd}j|j||g d d}|dd\}}|jdddd}W n ty } ztd|d	}~ww ||
}|  | | |j|| |||\}}}}}t| d}||| | |d< | rK|
|||| |||\|d< |d< |d j }t|jdd	 ||	fkr:t j!|d	 ||	fddd }"||||||||j#	|d< n|d  |d  |d   7  < |d  j$j%9  < |d  j$j%9  < |d  j$j&9  < |d  j$j'9  < || | fS )zFCalculate and return the combined loss for detection and segmentation.rq   r   r   r   c                    r   r   r   r   r   r   r   r     r   z/v8SegmentationLoss.__call__.<locals>.<listcomp>ru   r   Nr   r   r   rH   r   r   r   r   r   TrI   r^   u  ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.
This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, i.e. 'yolo train model=yolo11n-seg.pt data=coco8.yaml'.
Verify your dataset is a correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' as an example.
See https://docs.ultralytics.com/datasets/segment/ for help.masksnearest)mode)(r`   r   rc   lenrO   r   r   rE   r   r   r   r   ra   r   r
   rN   r   rb   r+   r   RuntimeError	TypeErrorr   r   r   r%   r   r   r   r   r)   r   r'   interpolatecalculate_segmentation_lossr   r   r   r   r   )r   r   r   r0   
pred_masksprotor   r   mask_hmask_wr   r   r   r   rf   r   r   r   r   r   r   r   re   rg   rh   rj   target_gt_idxri   r   r   r   r   rV     sv   *" 

	


$zv8SegmentationLoss.__call__gt_maskr<   r   xyxyr}   rC   c                 C   s8   t d||}tj|| dd}t||jdd|  S )aX  
        Compute the instance segmentation loss for a single image.

        Args:
            gt_mask (torch.Tensor): Ground truth mask of shape (n, H, W), where n is the number of objects.
            pred (torch.Tensor): Predicted mask coefficients of shape (n, 32).
            proto (torch.Tensor): Prototype masks of shape (32, H, W).
            xyxy (torch.Tensor): Ground truth bounding boxes in xyxy format, normalized to [0, 1], of shape (n, 4).
            area (torch.Tensor): Area of each ground truth bounding box of shape (n,).

        Returns:
            (torch.Tensor): The calculated mask loss for a single image.

        Notes:
            The function uses the equation pred_mask = torch.einsum('in,nhw->ihw', pred, proto) to produce the
            predicted masks from the prototype masks and predicted mask coefficients.
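
        Example:
            Illustrative doctest-style sketch with assumed toy shapes (2 objects, 32 prototypes,
            80x80 masks, boxes normalized to [0, 1] as described above):
            >>> gt_mask, proto, pred = torch.rand(2, 80, 80), torch.rand(32, 80, 80), torch.rand(2, 32)
            >>> xyxy, area = torch.tensor([[0.1, 0.1, 0.5, 0.5], [0.2, 0.2, 0.9, 0.9]]), torch.rand(2)
            >>> loss = v8SegmentationLoss.single_mask_loss(gt_mask, pred, proto, xyxy, area)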
        zin,nhw->ihwr"   r#   )r   ru   rw   )r`   einsumr'   r(   r   r*   r+   )r   r<   r   r   r}   	pred_maskr0   r   r   r   single_mask_lossi  s   z#v8SegmentationLoss.single_mask_lossrj   r   r   rg   r   r   r   r   c
              
   C   s"  |j \}
}
}}d}||g d  }t|dddf d}|tj||||g|jd }tt|||||||D ]R\}}|\}}}}}}}| r||| }|	r_||d 	dddk}|
 }n||	d|k | }|| ||| ||| || 7 }q8||d  |d   7 }q8||  S )	aF  
        Calculate the loss for instance segmentation.

        Args:
            fg_mask (torch.Tensor): A binary tensor of shape (BS, N_anchors) indicating which anchors are positive.
            masks (torch.Tensor): Ground truth masks of shape (BS, H, W) if `overlap` is False, otherwise (BS, ?, H, W).
            target_gt_idx (torch.Tensor): Indexes of ground truth objects for each anchor of shape (BS, N_anchors).
            target_bboxes (torch.Tensor): Ground truth bounding boxes for each anchor of shape (BS, N_anchors, 4).
            batch_idx (torch.Tensor): Batch indices of shape (N_labels_in_batch, 1).
            proto (torch.Tensor): Prototype masks of shape (BS, 32, H, W).
            pred_masks (torch.Tensor): Predicted masks for each anchor of shape (BS, N_anchors, 32).
            imgsz (torch.Tensor): Size of the input image as a tensor of shape (2), i.e., (H, W).
            overlap (bool): Whether the masks in `masks` tensor overlap.

        Returns:
            (torch.Tensor): The calculated loss for instance segmentation.

        Notes:
            The batch loss can be computed for improved speed at higher memory usage.
            For example, pred_mask can be computed as follows:
                pred_mask = torch.einsum('in,nhw->ihw', pred, proto)  # (i, 32) @ (32, 160, 160) -> (i, 160, 160)
        """
        _, _, mask_h, mask_w = proto.shape
        loss = 0

        # Normalize to 0-1
        target_bboxes_normalized = target_bboxes / imgsz[[1, 0, 1, 0]]

        # Areas of target bboxes
        marea = xyxy2xywh(target_bboxes_normalized)[..., 2:].prod(2)

        # Normalize to mask size
        mxyxy = target_bboxes_normalized * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=proto.device)

        for i, single_i in enumerate(zip(fg_mask, target_gt_idx, pred_masks, proto, mxyxy, marea, masks)):
            fg_mask_i, target_gt_idx_i, pred_masks_i, proto_i, mxyxy_i, marea_i, masks_i = single_i
            if fg_mask_i.any():
                mask_idx = target_gt_idx_i[fg_mask_i]
                if overlap:
                    gt_mask = masks_i == (mask_idx + 1).view(-1, 1, 1)
                    gt_mask = gt_mask.float()
                else:
                    gt_mask = masks[batch_idx.view(-1) == i][mask_idx]

                loss += self.single_mask_loss(
                    gt_mask, pred_masks_i[fg_mask_i], proto_i, mxyxy_i[fg_mask_i], marea_i[fg_mask_i]
                )

            # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
            else:
                loss += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss

        return loss / fg_mask.sum()


class v8PoseLoss(v8DetectionLoss):
    """Criterion class for computing training losses for YOLOv8 pose estimation."""

    def __init__(self, model):  # model must be de-paralleled
        """Initialize v8PoseLoss with model parameters and keypoint-specific loss functions."""
        super().__init__(model)
        self.kpt_shape = model.model[-1].kpt_shape
        self.bce_pose = nn.BCEWithLogitsLoss()
        is_pose = self.kpt_shape == [17, 3]
        nkpt = self.kpt_shape[0]  # number of keypoints
        sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt
        self.keypoint_loss = KeypointLoss(sigmas=sigmas)

    def __call__(self, preds, batch):
        """Calculate the total loss and detach it for pose estimation."""
        loss = torch.zeros(5, device=self.device)  # box, pose, kobj, cls, dfl
        feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
            (self.reg_max * 4, self.nc), 1
        )

        # B, grids, ..
        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
        pred_kpts = pred_kpts.permute(0, 2, 1).contiguous()

        dtype = pred_scores.dtype
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # Targets
        batch_size = pred_scores.shape[0]
        batch_idx = batch["batch_idx"].view(-1, 1)
        targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
        targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)

        # Pboxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
        pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape))  # (b, h*w, 17, 3)

        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
            pred_scores.detach().sigmoid(),
            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )

        target_scores_sum = max(target_scores.sum(), 1)

        # Cls loss
        loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

        # Bbox loss
        if fg_mask.sum():
            target_bboxes /= stride_tensor
            loss[0], loss[4] = self.bbox_loss(
                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
            )
            keypoints = batch["keypoints"].to(self.device).float().clone()
            keypoints[..., 0] *= imgsz[1]
            keypoints[..., 1] *= imgsz[0]

            loss[1], loss[2] = self.calculate_keypoints_loss(
                fg_mask, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
            )

        loss[0] *= self.hyp.box  # box gain
        loss[1] *= self.hyp.pose  # pose gain
        loss[2] *= self.hyp.kobj  # kobj gain
        loss[3] *= self.hyp.cls  # cls gain
        loss[4] *= self.hyp.dfl  # dfl gain

        return loss * batch_size, loss.detach()  # loss(box, pose, kobj, cls, dfl)

    @staticmethod
    def kpts_decode(anchor_points, pred_kpts):
        """Decode predicted keypoints to image coordinates."""
        y = pred_kpts.clone()
        y[..., :2] *= 2.0
        y[..., 0] += anchor_points[:, [0]] - 0.5
        y[..., 1] += anchor_points[:, [1]] - 0.5
        return y

    def calculate_keypoints_loss(
        self, masks, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
    ):
        """
        Calculate the keypoints loss for the model.

        This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
        based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
        a binary classification loss that classifies whether a keypoint is present or not.

        Args:
            masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
            target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
            keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
            batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
            stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
            target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
            pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).

        Returns:
            kpts_loss (torch.Tensor): The keypoints loss.
            kpts_obj_loss (torch.Tensor): The keypoints object loss.
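
        Example:
            Illustrative sketch only; in training this method is invoked internally by __call__
            with tensors produced by the task-aligned assigner:
            >>> kpts_loss, kpts_obj_loss = self.calculate_keypoints_loss(
            ...     fg_mask, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
            ... )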
        """
        batch_idx = batch_idx.flatten()
        batch_size = len(masks)

        # Find the maximum number of keypoints in a single image
        max_kpts = torch.unique(batch_idx, return_counts=True)[1].max()

        # Create a tensor to hold batched keypoints
        batched_keypoints = torch.zeros(
            (batch_size, max_kpts, keypoints.shape[1], keypoints.shape[2]), device=keypoints.device
        )

        # TODO: any idea how to vectorize this?
        # Fill batched_keypoints with keypoints based on batch_idx
        for i in range(batch_size):
            keypoints_i = keypoints[batch_idx == i]
            batched_keypoints[i, : keypoints_i.shape[0]] = keypoints_i

        # Expand dimensions of target_gt_idx to match the shape of batched_keypoints
        target_gt_idx_expanded = target_gt_idx.unsqueeze(-1).unsqueeze(-1)

        # Use target_gt_idx_expanded to select keypoints from batched_keypoints
        selected_keypoints = batched_keypoints.gather(
            1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2])
        )

        # Divide coordinates by stride
        selected_keypoints[..., :2] /= stride_tensor.view(1, -1, 1, 1)

        kpts_loss = 0
        kpts_obj_loss = 0

        if masks.any():
            gt_kpt = selected_keypoints[masks]
            area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True)
            pred_kpt = pred_kpts[masks]
            kpt_mask = gt_kpt[..., 2] != 0 if gt_kpt.shape[-1] == 3 else torch.full_like(gt_kpt[..., 0], True)
            kpts_loss = self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area)  # pose loss

            if pred_kpt.shape[-1] == 3:
                kpts_obj_loss = self.bce_pose(pred_kpt[..., 2], kpt_mask.float())  # keypoint obj loss

        return kpts_loss, kpts_obj_loss


class v8ClassificationLoss:
    """Criterion class for computing training losses for classification."""

    def __call__(self, preds, batch):
        """Compute the classification loss between predictions and true labels."""
        preds = preds[1] if isinstance(preds, (list, tuple)) else preds
        loss = F.cross_entropy(preds, batch["cls"], reduction="mean")
        loss_items = loss.detach()
        return loss, loss_items
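

# Illustrative sketch (exposition only): v8ClassificationLoss is stateless and can be exercised
# standalone with assumed toy inputs, e.g.:
#   criterion = v8ClassificationLoss()
#   loss, loss_items = criterion(torch.randn(8, 10), {"cls": torch.randint(0, 10, (8,))})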


class v8OBBLoss(v8DetectionLoss):
    """Calculates losses for object detection, classification, and box distribution in rotated YOLO models."""

    def __init__(self, model):
        """Initialize v8OBBLoss with model, assigner, and rotated bbox loss; model must be de-paralleled."""
        super().__init__(model)
        self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
        self.bbox_loss = RotatedBboxLoss(self.reg_max).to(self.device)

    def preprocess(self, targets, batch_size, scale_tensor):
        """Preprocess targets for oriented bounding box detection."""
        if targets.shape[0] == 0:
            out = torch.zeros(batch_size, 0, 6, device=self.device)
        else:
            i = targets[:, 0]  # image index
            _, counts = i.unique(return_counts=True)
            counts = counts.to(dtype=torch.int32)
            out = torch.zeros(batch_size, counts.max(), 6, device=self.device)
            for j in range(batch_size):
                matches = i == j
                if n := matches.sum():
                    bboxes = targets[matches, 2:]
                    bboxes[..., :4].mul_(scale_tensor)
                    out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1)
        return out

    def __call__(self, preds, batch):
        """Calculate and return the loss for oriented bounding box detection."""
        loss = torch.zeros(3, device=self.device)  # box, cls, dfl
        feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]
        batch_size = pred_angle.shape[0]  # batch size
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
            (self.reg_max * 4, self.nc), 1
        )

        # b, grids, ..
        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
        pred_angle = pred_angle.permute(0, 2, 1).contiguous()

        dtype = pred_scores.dtype
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # Targets
        try:
            batch_idx = batch["batch_idx"].view(-1, 1)
            targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1)
            rw, rh = targets[:, 4] * imgsz[0].item(), targets[:, 5] * imgsz[1].item()
            targets = targets[(rw >= 2) & (rh >= 2)]  # filter rboxes of tiny size to stabilize training
            targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
            gt_labels, gt_bboxes = targets.split((1, 5), 2)  # cls, xywhr
            mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
        except RuntimeError as e:
            raise TypeError(
                "ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\n"
                "This error can occur when incorrectly training a 'OBB' model on a 'detect' dataset, "
                "i.e. 'yolo train model=yolo11n-obb.pt data=dota8.yaml'.\nVerify your dataset is a "
                "correctly formatted 'OBB' dataset using 'data=dota8.yaml' as an example.\n"
                "See https://docs.ultralytics.com/datasets/obb/ for help."
            ) from e

        # Pboxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri, pred_angle)  # xywh, (b, h*w, 4)

        bboxes_for_assigner = pred_bboxes.clone().detach()
        # Only the first four elements need to be scaled
        bboxes_for_assigner[..., :4] *= stride_tensor
        _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
            pred_scores.detach().sigmoid(),
            bboxes_for_assigner.type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )

        target_scores_sum = max(target_scores.sum(), 1)

        # Cls loss
        loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

        # Bbox loss
        if fg_mask.sum():
            target_bboxes[..., :4] /= stride_tensor
            loss[0], loss[2] = self.bbox_loss(
                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
            )
        else:
            loss[0] += (pred_angle * 0).sum()

        loss[0] *= self.hyp.box  # box gain
        loss[1] *= self.hyp.cls  # cls gain
        loss[2] *= self.hyp.dfl  # dfl gain

        return loss * batch_size, loss.detach()  # loss(box, cls, dfl)

    def bbox_decode(self, anchor_points, pred_dist, pred_angle):
        """
        Decode predicted object bounding box coordinates from anchor points and distribution.

        Args:
            anchor_points (torch.Tensor): Anchor points, (h*w, 2).
            pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4).
            pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1).

        Returns:
            (torch.Tensor): Predicted rotated bounding boxes with angles, (bs, h*w, 5).
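
        Example:
            Illustrative sketch with assumed toy shapes (one image, four anchors, reg_max=16, so
            4 * 16 = 64 distribution channels), where `criterion` is a constructed v8OBBLoss:
            >>> anchor_points, pred_dist, pred_angle = torch.rand(4, 2), torch.rand(1, 4, 64), torch.rand(1, 4, 1)
            >>> rboxes = criterion.bbox_decode(anchor_points, pred_dist, pred_angle)  # (1, 4, 5)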
        """
        if self.use_dfl:
            b, a, c = pred_dist.shape  # batch, anchors, channels
            pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
        return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1)


class E2EDetectLoss:
    """Criterion class for computing training losses for end-to-end detection."""

    def __init__(self, model):
        """Initialize E2EDetectLoss with one-to-many and one-to-one detection losses using the provided model."""
        self.one2many = v8DetectionLoss(model, tal_topk=10)
        self.one2one = v8DetectionLoss(model, tal_topk=1)

    def __call__(self, preds, batch):
        """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
        preds = preds[1] if isinstance(preds, tuple) else preds
        one2many = preds["one2many"]
        loss_one2many = self.one2many(one2many, batch)
        one2one = preds["one2one"]
        loss_one2one = self.one2one(one2one, batch)
        return loss_one2many[0] + loss_one2one[0], loss_one2many[1] + loss_one2one[1]


class TVPDetectLoss:
    """Criterion class for computing training losses for text-visual prompt detection."""

    def __init__(self, model):
        """Initialize TVPDetectLoss with task-prompt and visual-prompt criteria using the provided model."""
        self.vp_criterion = v8DetectionLoss(model)
        # NOTE: store the original head sizes, as they are overridden for visual-prompt batches
        self.ori_nc = self.vp_criterion.nc
        self.ori_no = self.vp_criterion.no
        self.ori_reg_max = self.vp_criterion.reg_max

    def __call__(self, preds, batch):
        """Calculate the loss for text-visual prompt detection."""
        feats = preds[1] if isinstance(preds, tuple) else preds
        assert self.ori_reg_max == self.vp_criterion.reg_max

        if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:  # no visual prompt in this batch
            loss = torch.zeros(3, device=self.vp_criterion.device, requires_grad=True)
            return loss, loss.detach()

        vp_feats = self._get_vp_features(feats)
        vp_loss = self.vp_criterion(vp_feats, batch)
        box_loss = vp_loss[0][1]
        return box_loss, vp_loss[1]

    def _get_vp_features(self, feats):
        """Extract visual-prompt features from the model output."""
        vnc = feats[0].shape[1] - self.ori_reg_max * 4 - self.ori_nc

        self.vp_criterion.nc = vnc
        self.vp_criterion.no = vnc + self.vp_criterion.reg_max * 4
        self.vp_criterion.assigner.num_classes = vnc

        return [
            torch.cat((box, cls_vp), dim=1)
            for box, _, cls_vp in [xi.split((self.ori_reg_max * 4, self.ori_nc, vnc), dim=1) for xi in feats]
        ]


class TVPSegmentLoss(TVPDetectLoss):
    """Criterion class for computing training losses for text-visual prompt segmentation."""

    def __init__(self, model):
        """Initialize TVPSegmentLoss with task-prompt and visual-prompt criteria using the provided model."""
        super().__init__(model)
        self.vp_criterion = v8SegmentationLoss(model)

    def __call__(self, preds, batch):
        """Calculate the loss for text-visual prompt segmentation."""
        feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
        assert self.ori_reg_max == self.vp_criterion.reg_max

        if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:  # no visual prompt in this batch
            loss = torch.zeros(4, device=self.vp_criterion.device, requires_grad=True)
            return loss, loss.detach()

        vp_feats = self._get_vp_features(feats)
        vp_loss = self.vp_criterion((vp_feats, pred_masks, proto), batch)
        cls_loss = vp_loss[0][2]
        return cls_loss, vp_loss[1]