"""PyTorch XGLM model."""

import math
from typing import Optional, Union

import torch
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache, EncoderDecoderCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from ...modeling_utils import PreTrainedModel
from ...utils import auto_docstring, logging
from .configuration_xglm import XGLMConfig


logger = logging.get_logger(__name__)


class XGLMScaledWordEmbedding(nn.Embedding):
    """
    This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor):
        return super().forward(input_ids) * self.embed_scale


class XGLMSinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length."""

    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
        super().__init__()
        self.offset = 2
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)

    def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
        emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
        if hasattr(self, "weights"):
            # in forward, cast the freshly built table to the dtype and device of the existing buffer
            emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)

        self.register_buffer("weights", emb_weights, persistent=False)

    @staticmethod
    def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
        """
        Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
        "Attention Is All You Need".
        """
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad the last dimension for odd embedding sizes
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0

        return emb.to(torch.get_default_dtype())

    @torch.no_grad()
    def forward(self, position_ids: Optional[torch.Tensor] = None, past_key_values_length: int = 0):
        bsz, seq_len = position_ids.size()
        position_ids += self.offset

        # Expand embeddings if needed. `position_ids.max()` is NOT used to keep torch.fx compatibility.
        max_pos = 2 + seq_len + past_key_values_length
        if max_pos > self.weights.size(0):
            self.make_weights(max_pos, self.embedding_dim, self.padding_idx)

        return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, self.weights.shape[-1]).detach()


class XGLMAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        layer_idx: Optional[int] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads

        if self.head_dim * num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.layer_idx = layer_idx

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Cache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        cache_position: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()
        src_len = key_value_states.shape[1] if is_cross_attention else tgt_len

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling

        if past_key_value is not None:
            if isinstance(past_key_value, EncoderDecoderCache):
                is_updated = past_key_value.is_updated.get(self.layer_idx)
                if is_cross_attention:
                    # after the first generated id, we can subsequently re-use all key/value_states from cache
                    curr_past_key_value = past_key_value.cross_attention_cache
                else:
                    curr_past_key_value = past_key_value.self_attention_cache
            else:
                curr_past_key_value = past_key_value

        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_value is not None and is_updated:
            # reuse k,v, cross_attentions
            key_states = curr_past_key_value.layers[self.layer_idx].keys
            value_states = curr_past_key_value.layers[self.layer_idx].values
        else:
            key_states = self.k_proj(current_states)
            value_states = self.v_proj(current_states)
            key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
            value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)

            if past_key_value is not None:
                # save all key/value_states to cache to be re-used for fast auto-regressive generation
                cache_position = cache_position if not is_cross_attention else None
                key_states, value_states = curr_past_key_value.update(
                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                )
                # set flag that this layer's cross-attention cache is filled so it can be re-used in later calls
                if is_cross_attention:
                    past_key_value.is_updated[self.layer_idx] = True

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
        query_states = query_states.reshape(*proj_shape)
        key_states = key_states.reshape(*proj_shape)
        value_states = value_states.reshape(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = torch.max(
                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        # upcast to fp32 before the softmax when the weights are in fp16 to avoid numerical issues
        if attn_weights.dtype == torch.float16:
            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16)
        else:
            attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # reshape twice so that the returned attention weights keep their gradient
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped


class XGLMDecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: XGLMConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.embed_dim = config.d_model
        self.self_attn = XGLMAttention(
            embed_dim=self.embed_dim,
            num_heads=config.attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            layer_idx=layer_idx,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout

        if config.add_cross_attention:
            self.encoder_attn = XGLMAttention(
                embed_dim=self.embed_dim,
                num_heads=config.attention_heads,
                dropout=config.attention_dropout,
                is_decoder=True,
                layer_idx=layer_idx,
            )
            self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim)
        self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
        cache_position: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            encoder_hidden_states (`torch.FloatTensor`):
                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
                size `(decoder_attention_heads,)`.
            past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        # Cross-Attention Block
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            residual = hidden_states
            hidden_states = self.encoder_attn_layer_norm(hidden_states)

            hidden_states, cross_attn_weights = self.encoder_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                cache_position=cache_position,
            )
            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
            hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        return outputs


@auto_docstring
class XGLMPreTrainedModel(PreTrainedModel):
    config: XGLMConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["XGLMDecoderLayer"]

    def _init_weights(self, module):
        std = self.config.init_std
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


@auto_docstring
class XGLMModel(XGLMPreTrainedModel):
    def __init__(self, config: XGLMConfig, embed_tokens: Optional[nn.Embedding] = None):
        r"""
        embed_tokens (`nn.Embedding`, *optional*):
            output embeddings
        """
        super().__init__(config)
        self.dropout = config.dropout
        self.layerdrop = config.layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

        if embed_tokens is not None:
            self.embed_tokens = embed_tokens
        else:
            self.embed_tokens = XGLMScaledWordEmbedding(
                config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
            )

        self.embed_positions = XGLMSinusoidalPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
            config.pad_token_id,
        )
        self.layers = nn.ModuleList([XGLMDecoderLayer(config, layer_idx=i) for i in range(config.num_layers)])
        self.layer_norm = nn.LayerNorm(config.d_model)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
    ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        r"""
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of
            the decoder.
        encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
            Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
            selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        cross_attn_head_mask (`torch.Tensor` of shape `(num_layers, attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache = True` is incompatible with gradient checkpointing`. Setting `use_cache = False`..."
                )
                use_cache = False

        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):
            return_legacy_cache = True
            logger.warning_once(
                "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
                "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
                "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
            )
            past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)

        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0

        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )

        if position_ids is None:
            position_ids = torch.arange(
                past_key_values_length,
                input_shape[-1] + past_key_values_length,
                dtype=torch.long,
                device=input_ids.device if input_ids is not None else inputs_embeds.device,
            )
            position_ids = position_ids.unsqueeze(0)

        # expand encoder attention mask
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            encoder_attention_mask = _prepare_4d_attention_mask(
                encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            )

        hidden_states = inputs_embeds + self.embed_positions(position_ids, past_key_values_length).to(
            inputs_embeds.device
        )
        hidden_states = nn.functional.dropout(hidden_states, p=float(self.dropout), training=self.training)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
            if attn_mask is not None:
                if attn_mask.size()[0] != len(self.layers):
                    raise ValueError(
                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                        f" {attn_mask.size()[0]}."
                    )

        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:
                    continue

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        hidden_states = self.layer_norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if return_legacy_cache:
            past_key_values = past_key_values.to_legacy_cache()

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )


@auto_docstring(
    custom_intro="""
    The XGLM Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """
)
class XGLMForCausalLM(XGLMPreTrainedModel, GenerationMixin):
    base_model_prefix = "model"
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = XGLMModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
        r"""
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of
            the decoder.
        encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
            Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
            selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        cross_attn_head_mask (`torch.Tensor` of shape `(num_layers, attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        logits = self.lm_head(outputs[0])

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits,
                labels,
                vocab_size=self.config.vocab_size,
                pad_token_id=self.config.pad_token_id,
                **kwargs,
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )


__all__ = ["XGLMForCausalLM", "XGLMModel", "XGLMPreTrainedModel"]
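

# Illustrative usage sketch, not part of the upstream module: it shows how the classes above are
# typically driven through `from_pretrained`, and how passing `labels` produces the causal language
# modeling loss computed in `XGLMForCausalLM.forward`. The checkpoint name "facebook/xglm-564M" and
# the use of `AutoTokenizer` are assumptions made for this example only; the block only runs when the
# file is executed directly, so importing the module is unaffected.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("facebook/xglm-564M")
    model = XGLMForCausalLM.from_pretrained("facebook/xglm-564M")

    inputs = tokenizer("I wanted to conserve energy.", return_tensors="pt")
    # Supplying `labels` (here, the input ids themselves) triggers the LM loss path.
    outputs = model(**inputs, labels=inputs["input_ids"])
    print(outputs.loss, outputs.logits.shape)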