o
    Ih*                     @   sP  d dl Z d dlZd dlZd dlmZmZ d dl mZ d dlmZm	Z	m
Z
mZ d dlZd dlm  mZ d dlmZ d dlmZ d dlmZ d dlmZmZmZmZ ejed	Z d
d Z!G dd dej"Z#G dd deZ$e j%dd Z&e$ fde$fddZ'G dd deZ(G dd de(Z)G dd de(Z*G dd de(Z+dej"fddZ,dS )     N)ABCabstractmethod)AbstractContextManager)AnyCallableOptionalUnion)$_functionalization_reapply_views_tls)_get_dispatch_mode_pre_dispatch)is_sparse_any)_detect_infra_mode_disable_infra_modereturn_and_correct_aliasingTorchDispatchModenot_implementedc                     s    fdd}|S )Nc                    s   | j |i i | S N)toselfargskwargsextra_kwargs W/var/www/vscode/kcb/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py_(   s   z&_conversion_method_template.<locals>._r   )r   r   r   r   r   _conversion_method_template'   s   r   c                       s  e Zd ZU dZejed< ejjj	Z
ejjejjjZejjjjejjjjejjjjejjjjejjjjejjjjejjjjejjjjejjjjejjjjejjjjejjjjejjjjejj j!jgZ"dZ#e$d  ed< dd Z%d*dd	Z&d
e'fddZ(e)dd Z*dd Z+d
e,fddZ-d+ddZ.d+ddZ/d+ddZ0d+ddZ1d
e2fddZ3 fddZ4d,dd Z5e6ej7d!Z8e6e!d"d#Z9e6ej:d!Z:e6ej;d!Z<e6ej=d!Z>e6ej?d!Z@e6ej,d!Z,e6ejAd!ZBe6ejCd!ZDe6ejEd!ZFd$d% ZGeHd&d' ZId(d) ZJ  ZKS )-FunctionalTensoraF  
    Functional tensors represent tensors that will remove mutations
    from a program. If you perform a mutable operation on a functional tensor,
    it will re-dispatch to the functional variant of that operation.

    Historically, functionalization is implemented in C++ in the dispatcher.
    This class is a lightweight python shim around the C++ functionalization logic.

    FunctionalTensor is required to be used with a corresponding
    FunctionalTensormode active, because it relies
    on using the mode for dispatch (which can properly handle factory functions).
    elemN_inference_mode_basec                 C   s   t |sJ tjt j|@ }t j| |jt	|s|
 nd t	|s'| nd d |j|j|jd|jd dd|}t j| ||_|jsnt  rnt jjjrn| r^d |_||j|j < |S |j|j  |_|jd usnJ |S NF)torch_is_functional_tensorr   _extra_dispatch_keys_C_dispatch_keysTensor_make_wrapper_subclassshaper   stridestorage_offsetdtypelayoutdevicerequires_grad_set_throw_on_mutable_data_ptrr   exportis_inference_mode_enabled	_inductorconfigenable_auto_functionalized_v2is_base_tensorr   _storage_to_baseuntyped_storage)clsr   modeextra_dispatch_keysoutr   r   r   __new__b   sH   zFunctionalTensor.__new__r   c                 C   s   dd |D }|rt d| tS |d u ri }|tjv rgt|dks$J |tjjj	j
tjjjjfv rNt|dkr?t|d tsAJ |t|d j|d S t|dkr[t|d ts]J |t|d jS td)Nc                 S   s$   g | ]}|t jt jjtfvr|qS r   )r!   r&   _subclasses
FakeTensorr   .0tr   r   r   
<listcomp>   s
    z7FunctionalTensor.__torch_dispatch__.<locals>.<listcomp>.FunctionalTensor unrecognized subclass(es): %sr         zqAttempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode())not_implemented_logdebugNotImplementedr   metadata_fnslenr!   opsatenis_strides_like_formatdefaultis_contiguousmemory_format
isinstance_from_functional_tensorr   RuntimeError)r   functypesr   r   unrecognized_typesr   r   r   __torch_dispatch__   s.   


z#FunctionalTensor.__torch_dispatch__returnc                 C   s   dt | j dS )NzFunctionalTensor())reprr   r   r   r   r   __repr__      zFunctionalTensor.__repr__c                 C   s~   t | rJ t | }tt jjj}|d usJ | t | | t||}t || W d    |S 1 s8w   Y  |S r   )	r!   r"   _to_functional_tensorr   r$   _TorchDispatchModeKey
FUNCTIONAL_mirror_autograd_meta_tor   )xx_functionalfunctional_moder;   r   r   r   to_functional   s   


zFunctionalTensor.to_functionalc                 C   s   t |  t | jS r   )r!   _syncrR   r   r[   r   r   r   from_functional   s   
z FunctionalTensor.from_functionalc                 C   s   t | jS r   )r!   _is_functional_tensor_baser   r[   r   r   r   r5         zFunctionalTensor.is_base_tensorc                 C   s   t | j| d S r   )r!   _functionalize_replacer   )r   outputr   r   r   replace_   r]   zFunctionalTensor.replace_c                 C      t | j d S r   )r!   _functionalize_commit_updater   r[   r   r   r   commit_update      zFunctionalTensor.commit_updatec                 C   rm   r   )r!   _functionalize_syncr   r[   r   r   r   sync   rp   zFunctionalTensor.syncc                 C   rm   r   )r!   1_functionalize_mark_mutation_hidden_from_autogradr   r[   r   r   r   "mark_mutation_hidden_from_autograd   rp   z3FunctionalTensor.mark_mutation_hidden_from_autogradc                 C   sF   | j  dkr| j  S | j  dkrdd | j D S dd | j D S )Nr   rE   c                 S      g | ]}|  qS r   )itemr@   r   r   r   r   rB         z+FunctionalTensor.tolist.<locals>.<listcomp>c                 S   ru   r   )tolistrw   r   r   r   rB     rx   )r   dimrv   r[   r   r   r   ry      s
   
zFunctionalTensor.tolistc                    sV   t tjjjjr"tdd |D dkr"t j|i i |ddiS t j|i |S )Nc                 S   s   g | ]	}t |tr|qS r   )rQ   boolr@   argr   r   r   rB     s    z'FunctionalTensor.to.<locals>.<listcomp>rE   copyT)	r   r!   r$   r_   r`   r0   rJ   superr   r   	__class__r   r   r     s   zFunctionalTensor.toc                 O   sF   |pt j }t|dkr| j|g|R i |S | jdd|i|S )Nr   r-   r   )r!   cudacurrent_devicerJ   r   )r   r-   r   r   r   r   r   r     s   zFunctionalTensor.cudar+   cpu)r-   c                 C   
   | j  S r   )r   to_denser[   r   r   r   r        
zFunctionalTensor.to_densec                 C   s   | j jS r   )r   r,   r[   r   r   r   r,   "  s   zFunctionalTensor.layoutc                 C   s   t |  S r   )r{   rv   r[   r   r   r   __bool__&  ri   zFunctionalTensor.__bool__r   NrX   Nr   )L__name__
__module____qualname____doc__r!   r&   __annotations__r$   r_   r`   	_mode_key,_additional_keys_to_prop_for_wrapper_tensorsaddDispatchKey
ZeroTensorr#   rK   rL   rO   rN   rP   rM   is_non_overlapping_and_densesizesym_sizer)   
sym_strider*   sym_storage_offsetnumel	sym_numelrz   primr-   rI   r   r   r<   rW   strr\   staticmethodre   rg   r{   r5   rl   ro   rr   rt   r   ry   r   r   r   int8charr   bfloat16uint8bytefloat64doublefloat32floatfloat16halfint32intint64longr   propertyr,   r   __classcell__r   r   r   r   r   .   sf   
 

	














A+






r   c                       sR   e Zd Zd fdd	Z fddZ fddZdd
dZedefddZ	  Z
S )FunctionalTensorModeFc                    sb   t    || _d| _g | _tjjj| _	|| _
|rtjjjnd | _i | _i | _|| _t | _d S r    )r   __init__r0   is_on_stackenter_stackr!   r$   r_   r`   r   pre_dispatchr   PreDispatch_dispatch_key_tokens_tokens_forward_output_allow_token_discoveryweakrefWeakKeyDictionaryr6   )r   r   r0   r   r   r   r   r   +  s   
	zFunctionalTensorMode.__init__c                    s<    fdd}| d u r j d t  S  j d  S )Nc                      s0    j tjjjkrttjjjS tjtjjjS r   )	r   r!   r$   r   r   r
   r_   r`   _get_dispatch_moder   r[   r   r   _get_prev_modeL  s   z6FunctionalTensorMode.__enter__.<locals>._get_prev_modeTF)r   appendr   	__enter__)r   r   r   r[   r   r   K  s   
	
zFunctionalTensorMode.__enter__c                    s&   | j  }|rt ||| d S d S r   )r   popr   __exit__)r   abcr   r   r   r   r   \  s   
zFunctionalTensorMode.__exit__r   Nc              
      s  d u ri j rbtjjjjkr#dd  tjjjj	||S tjjjj
krbtdd jjD }t|dd  |dd  D ]\}}||< qCdd  tjjjj	||d d S dd |D }|rstd| tS fdd	}	tjvr|	rtj r j|i }
|
tur|
W  d    S W d    n1 sw   Y  fd
d}dd }ddlm}m}m} |rtj tjjjsdd lm   m!} j s|j"s||S ||S ddl#m$}m%} ||rtj tjjjrJ |j&j'|S t()t||f\ }tj*tjjj}tj+tjjj}|s?|r?J tj, tj-tjjjB }tj. /tjjjtj0 }tj1|| z|t2d}tjv r i |}t()tj3||}nVj4tjjjg R i |}j r͈tjjj5j	tjjjj	fv r͇ fdd}| rt6 d rtjjj7j	t8 d  d j
d nt9| t()tj3||}W t:  t2| n
t:  t2| w W d    n	1 sw   Y  tj*tjjj}tj+tjjj}|s|rJ t;dd t(<|D r*tjjj=j	kr,|S tj>j?j@v r[tjjjAjBur[tjCjDE  |i  W d    n	1 sVw   Y  tF||S )Nr~   c                 s       | ]}|j V  qd S r   )namer|   r   r   r   	<genexpr>n      z:FunctionalTensorMode.__torch_dispatch__.<locals>.<genexpr>rE   c                 S   s,   g | ]}t |tjjs|tjtfvr|qS r   )
issubclassr!   r=   r>   r&   r   r?   r   r   r   rB   v  s    z;FunctionalTensorMode.__torch_dispatch__.<locals>.<listcomp>rC   c                    s    j r| tjjjjkrdS ddlm} || rdS tdd | j	j
D }|s*| j	jr,dS  j rH jrF| jdvrD|  rDtd|  d	 dS dS dS )
NFr   )#_should_decompose_because_unsafe_opTc                 s   r   r   )
alias_infor|   r   r   r   r     r   zRFunctionalTensorMode.__torch_dispatch__.<locals>._can_decompose.<locals>.<genexpr>)rL   r   zAt pre-dispatch tracing, we assume that any custom op marked with CompositeImplicitAutograd and have functional schema are safe to not decompose. Found z to be one such op.)r0   r!   rK   rL   dropoutrN   torch._decompr   any_schema	arguments
is_mutabler   	namespace_can_decomposewarningswarn)rT   r   alias_info_presentr[   r   r   r     s&   z?FunctionalTensorMode.__torch_dispatch__.<locals>._can_decomposec                    s2   t | trJ t | tjrt| rt|  S | S r   )rQ   r   r!   r&   r"   rb   r[   r   r   wrap  s   
z5FunctionalTensorMode.__torch_dispatch__.<locals>.wrapc                 S   s   | j S r   )r   r   r   r   r   unwrap  s   z7FunctionalTensorMode.__torch_dispatch__.<locals>.unwrapr   )can_auto_functionalizedo_auto_functionalizedo_auto_functionalize_v2)handle_effectshas_effectsTc                      s*   t jjjjkodv od  d jkS )z
                                Return True if the output of the op must be copied, not an alias
                                r+   r   )r!   rK   rL   _to_copyrN   r+   r   )args_unwrappedrT   r   r   r   	must_copy  s
   z:FunctionalTensorMode.__torch_dispatch__.<locals>.must_copyr   c                 s   s    | ]}t |tV  qd S r   )rQ   r   )r@   rb   r   r   r   r   A  s
    
)Gr0   r!   rK   rL   r   dtype_layoutr   rW   r   rN   r+   tupler   r   ziprF   rG   rH   r   rI   r$   _dispatch_has_kernelr   	decompose*torch._higher_order_ops.auto_functionalizer   r   r   %_dispatch_has_kernel_for_dispatch_keyr   Functionalizetorch._inductor.configr2   r3   r4   torch._higher_order_ops.effectsr   r   r   r   pytreetree_map_only&_dispatch_tls_is_dispatch_key_included&_dispatch_tls_is_dispatch_key_excluded_dispatch_tls_local_include_setDispatchKeySet_dispatch_tls_local_exclude_setremover#   _ForceDispatchKeyGuard#_functionalize_enable_reapply_viewsr&   _op_dkr   r"   _assert_tensor_metadatarR   _freeze_functional_tensor_disable_functionalizationr   tree_leaves
lift_freshTaginplace_viewtagsset_source_Tensorutils_mode_utilsno_dispatchr   )r   rT   rU   r   r   schemar}   r   rV   r   rr   r   r   r   r   inductor_configr   r   kwargs_unwrappedis_includedis_excludedinclude_to_setexclude_to_setold_apply_viewsouts_unwrappedouts_wrappedr   r   )r   rT   r   r   r   rW   a  s  "

*	


	
	



<
z'FunctionalTensorMode.__torch_dispatch__rX   c                 C   s   dS )NTr   )r8   r   r   r   is_infra_modeY     z"FunctionalTensorMode.is_infra_mode)FFFr   )r   r   r   r   r   r   rW   classmethodr{   r  r   r   r   r   r   r   *  s     
 yr   c                   C   s   t tjjjS r   )r   r!   r$   r_   r`   r   r   r   r   disable_functional_mode^  s   r  r9   c                    s&   dd dd   fdd}|S )Nc                 S   s   t | tjrt| S | S r   )rQ   r!   r&   r   re   rA   r   r   r   to_funm  s   
z&dispatch_functionalize.<locals>.to_func                 S   s>   t | tst | tjrt| rJ | S t|  t| jS r   )rQ   r   r!   r&   r"   rf   rR   r   r  r   r   r   from_funr  s   

z(dispatch_functionalize.<locals>.from_func               	      s   t jt jt jjj}|B / tt j| }tt j|}|i |}tt	 |}|W  d    W  d    S 1 sDw   Y  W d    d S 1 sTw   Y  d S r   )
r!   r$   _ExcludeDispatchKeyGuardr   r   r   r   r   r&   r   )r   r   disable_above	func_argsfunc_kwargsfunc_outputsoutputsr  rT   r9   r  r   r   inner{  s   Rz%dispatch_functionalize.<locals>.innerr   )rT   r9   r  r   r  r   dispatch_functionalizek  s   	r  c                   @   s   e Zd Zedee dee fddZedeej	eej	df f defddZ
ededefd	d
ZedefddZedddZedddZedddZedddZdS )BaseFunctionalizeAPIr   rX   c                 C      d S r   r   r   r   r   r   r   wrap_tensors  r  z!BaseFunctionalizeAPI.wrap_tensors.c                 C   r!  r   r   r"  r   r   r   unwrap_tensors  s   z#BaseFunctionalizeAPI.unwrap_tensorsinner_fc                 C   r!  r   r   r   r%  r   r   r   functionalize  r  z"BaseFunctionalizeAPI.functionalizec                 C   r!  r   r   r[   r   r   r   redispatch_to_next  r  z'BaseFunctionalizeAPI.redispatch_to_nextNc                 C   r!  r   r   r   input_tensoroutput_tensorr   r   r   replace  r  zBaseFunctionalizeAPI.replacec                 C   r!  r   r   r   tensorr   r   r   ro     r  z"BaseFunctionalizeAPI.commit_updatec                 C   r!  r   r   r-  r   r   r   rr     r  zBaseFunctionalizeAPI.syncc                 C   r!  r   r   r-  r   r   r   rt     r  z7BaseFunctionalizeAPI.mark_mutation_hidden_from_autogradr   )r   r   r   r   r   r   r#  r   r!   r&   r$  r   r'  r   r(  r,  ro   rr   rt   r   r   r   r   r     s*    r   c                       s   e Zd Z	ddee deddf fddZdee dee fd	d
Z	de
ejeejdf eej f defddZdedefddZdefddZdddZdddZdddZdddZ  ZS )PythonFunctionalizeAPINFr9   r   rX   c                    s$   t    |r	|nt | _|| _d S r   )r   r   r   r9   r   )r   r9   r   r   r   r   r     s   

zPythonFunctionalizeAPI.__init__r   c                 C   s@   | j  tjjtjtj|W  d    S 1 sw   Y  d S r   )r9   r!   r  _pytreer   r&   r   re   r"  r   r   r   r#    s
   
$z#PythonFunctionalizeAPI.wrap_tensors.c                 C   s   t jjttj|S r   )r!   r  r0  r   r   rg   r"  r   r   r   r$    s   z%PythonFunctionalizeAPI.unwrap_tensorsr%  c                 C   s   t || jS r   )r  r9   r&  r   r   r   r'    ri   z$PythonFunctionalizeAPI.functionalizec                 C   s   t  S r   )
contextlibnullcontextr[   r   r   r   r(    s   z)PythonFunctionalizeAPI.redispatch_to_nextc                 C   s*   t |tsJ t |trJ || d S r   )rQ   r   rl   r)  r   r   r   r,    s   zPythonFunctionalizeAPI.replacec                 C      t |tsJ |  d S r   )rQ   r   ro   r-  r   r   r   ro        z$PythonFunctionalizeAPI.commit_updatec                 C   r3  r   )rQ   r   rr   r-  r   r   r   rr     r4  zPythonFunctionalizeAPI.syncc                 C   r3  r   )rQ   r   rt   r-  r   r   r   rt     r4  z9PythonFunctionalizeAPI.mark_mutation_hidden_from_autogradr    r   )r   r   r   r   r   r{   r   r   r   r#  r   r!   r&   listr$  r   r'  r   r(  r,  ro   rr   rt   r   r   r   r   r   r/    s*    



r/  c                   @   s   e Zd Zdee dee fddZdeejeejdf f deejeejdf f fddZ	de
de
fd	d
ZdefddZdddZdddZdddZdddZdS )CppFunctionalizeAPIr   rX   c                 C   s   ddl m} ||ddS Nr   )_wrap_all_tensors_to_functional)level)!torch._functorch.eager_transformsr8  r   r   r8  r   r   r   r#    s   z CppFunctionalizeAPI.wrap_tensors.c                 C   s   ddl m} ||t dS Nr   )#_unwrap_all_tensors_from_functional)reapply_views)r:  r=  _reapply_viewsr   r   r=  r   r   r   r$    s   z"CppFunctionalizeAPI.unwrap_tensorsr%  c                 C   s   t j|S r   )r!   rT   r'  r&  r   r   r   r'    ri   z!CppFunctionalizeAPI.functionalizec                 C   s   t jt jt jjjS r   )r!   r$   r  r   r   r   r[   r   r   r   r(    s   z&CppFunctionalizeAPI.redispatch_to_nextNc                 C      t || d S r   r!   rj   r)  r   r   r   r,    rp   zCppFunctionalizeAPI.replacec                 C      t | d S r   r!   rn   r-  r   r   r   ro        z!CppFunctionalizeAPI.commit_updatec                 C   rC  r   r!   rq   r-  r   r   r   rr     rE  zCppFunctionalizeAPI.syncc                 C   rC  r   r!   rs   r-  r   r   r   rt      rE  z6CppFunctionalizeAPI.mark_mutation_hidden_from_autogradr   )r   r   r   r   r   r#  r   r!   r&   r$  r   r'  r   r(  r,  ro   rr   rt   r   r   r   r   r6    s    
	


r6  c                   @   s   e Zd Zdd Zdee dee fddZdeej	eej	df f deej	eej	df f fdd	Z
d
edefddZdefddZdddZdddZdddZdddZdS )FunctorchFunctionalizeAPIc                 C   s
   || _ d S r   )interpreter)r   rI  r   r   r   r     r   z"FunctorchFunctionalizeAPI.__init__r   rX   c                 C      ddl m} ||| j dS r7  )r:  r8  rI  r9  r;  r   r   r   r#    s   z&FunctorchFunctionalizeAPI.wrap_tensors.c                 C   rJ  r<  )r:  r=  rI  functionalize_add_back_viewsr@  r   r   r   r$    s   
z(FunctorchFunctionalizeAPI.unwrap_tensorsr%  c                 C   s"   t jj|| j rddS ddS )Nmutations_and_views	mutations)r   )r!   rT   r'  rI  rK  r&  r   r   r   r'    s   z'FunctorchFunctionalizeAPI.functionalizec                 C   r   r   )rI  lowerr[   r   r   r   r(  "  r   z,FunctorchFunctionalizeAPI.redispatch_to_nextNc                 C   rA  r   rB  r)  r   r   r   r,  %  rp   z!FunctorchFunctionalizeAPI.replacec                 C   rC  r   rD  r-  r   r   r   ro   (  rE  z'FunctorchFunctionalizeAPI.commit_updatec                 C   rC  r   rF  r-  r   r   r   rr   +  rE  zFunctorchFunctionalizeAPI.syncc                 C   rC  r   rG  r-  r   r   r   rt   .  rE  z<FunctorchFunctionalizeAPI.mark_mutation_hidden_from_autogradr   )r   r   r   r   r   r   r#  r   r!   r&   r$  r   r'  r   r(  r,  ro   rr   rt   r   r   r   r   rH    s    




rH  r.  c                 C   s   t | trt| jS | S r   )rQ   r   r!   rR   r   )r.  r   r   r   mb_unwrap_functional_tensor2  s   
rO  )-r1  r   r   abcr   r   r   typingr   r   r   r   r!   torch.utils._pytreer  r0  r   torch._Cr	   r?  
torch._opsr
   torch._subclasses.meta_utilsr   torch.utils._python_dispatchr   r   r   r   _logginggetArtifactLoggerr   rF   r   r&   r   r   contextmanagerr  r  r   r/  r6  rH  rO  r   r   r   r   <module>   s6    }  6
$2$.