
"""PyTorch TrOCR decoder model (based on RoBERTa)."""

import math
from typing import Optional, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...cache_utils import Cache, EncoderDecoderCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from ...modeling_utils import PreTrainedModel
from ...utils import auto_docstring, logging
from .configuration_trocr import TrOCRConfig


logger = logging.get_logger(__name__)

class TrOCRLearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # TrOCR is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0, position_ids: torch.Tensor = None):
        """`input_ids' shape is expected to be [bsz x seqlen]."""
        if position_ids is None:
            bsz, seq_len = input_ids.shape[:2]
            position_ids = torch.arange(
                past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
            ).expand(bsz, -1)
        else:
            position_ids = position_ids.unsqueeze(0)

        return super().forward(position_ids + self.offset)
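
# A minimal sketch of the offset behaviour above. `_demo_learned_positional_embedding`
# is a hypothetical helper, not part of the public API: with offset = 2, the table holds
# num_embeddings + 2 rows and position i is looked up at row i + 2.
def _demo_learned_positional_embedding():
    emb = TrOCRLearnedPositionalEmbedding(num_embeddings=512, embedding_dim=256)
    input_ids = torch.zeros(2, 10, dtype=torch.long)  # [bsz, seq_len]
    positions = emb(input_ids)  # positions 0..9 are looked up at rows 2..11
    assert positions.shape == (2, 10, 256)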
       `     e Zd ZdZd	dedededee   f fdZdej                  f fdZ
 xZS )
TrOCRScaledWordEmbeddingz\
    This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
    r   r   padding_idxembed_scalec                 6    t         |   |||       || _        y N)r   r   r?   )r   r   r   r>   r?   r    s        r!   r   z!TrOCRScaledWordEmbedding.__init__I   s    D&r"   r#   c                 <    t         |   |      | j                  z  S rA   )r   r1   r?   )r   r#   r    s     r!   r1   z TrOCRScaledWordEmbedding.forwardM   s    wy)D,<,<<<r"   )      ?)r4   r5   r6   r7   r8   r   floatr   r+   r9   r1   r:   r;   s   @r!   r=   r=   D   sE    's '3 'S '_ghm_n '= = =r"   r=   c            	            e Zd ZdZddededee   f fdZeddededee   fd       Z e	j                         dde	j                  d	efd
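
# Illustrative sketch (hypothetical helper, not part of the library API): when
# `config.scale_embedding` is set, `TrOCRDecoder` below constructs this embedding with
# embed_scale = sqrt(hidden_size), so token vectors are scaled once at lookup time.
def _demo_scaled_word_embedding():
    emb = TrOCRScaledWordEmbedding(100, 64, padding_idx=1, embed_scale=math.sqrt(64))
    input_ids = torch.tensor([[5, 7, 3]])
    scaled = emb(input_ids)
    assert torch.allclose(scaled, emb.weight[input_ids] * 8.0)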

class TrOCRSinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length."""

    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
        super().__init__()
        self.offset = 2
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.weights = self.get_embedding(num_positions, embedding_dim, padding_idx)
        self.register_buffer("_float_tensor", torch.FloatTensor(1))

    @staticmethod
    def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
        """
        Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the
        description in Section 3.5 of "Attention Is All You Need".
        """
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb.to(torch.get_default_dtype())

    @torch.no_grad()
    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
        bsz, seq_len = input_ids.size()
        # Create the position ids from the input token ids. Any padded tokens remain padded.
        position_ids = self.create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to(
            input_ids.device
        )

        # expand embeddings if needed
        max_pos = self.padding_idx + 1 + seq_len
        if self.weights is None or max_pos > self.weights.size(0):
            # recompute/expand embeddings if needed
            self.weights = self.get_embedding(max_pos, self.embedding_dim, self.padding_idx)
        self.weights = self.weights.to(self._float_tensor)

        x = self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()

        return x

    def create_position_ids_from_input_ids(
        self, input_ids: torch.Tensor, padding_idx: int, past_key_values_length: Optional[int] = 0
    ):
        """
        Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
        symbols are ignored. This is modified from fairseq's `utils.make_positions`.
        """
        mask = input_ids.ne(padding_idx).int()
        incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
        return incremental_indices.long() + padding_idx
ee   f fdZ	 	 	 	 	 	 dde	j                  dee	j                     dee   dee	j                     dee	j                     dee   dee	j                     dee	j                  ee	j                     eee	j                        f   fdZ xZS )TrOCRAttentionz>Multi-headed attention from 'Attention Is All You Need' paper.	embed_dim	num_headskdimvdimdropout
is_decoderbiasis_cross_attention	layer_idxc                 P   t         |           || _        ||n|| _        ||n|| _        || _        || _        ||z  | _        | j                  |z  | j                  k(  st        d| j                   d| d      | j                  dz  | _	        || _
        |
| _        t        j                  | j                  ||      | _        t        j                  | j                  ||      | _        t        j                  |||      | _        t        j                  |||      | _        y )Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: z).g      ࿩ru   )r   r   ro   rq   rr   rp   rs   head_dim
ValueErrorscalingrt   rw   r   Lineark_projv_projq_projout_proj)r   configro   rp   rq   rr   rs   rt   ru   rv   rw   r    s              r!   r   zTrOCRAttention.__init__   s    	" ,D)	 ,D)	"!Y.	)T^^;MdnnM] ^;b"  }}d*$"ii		94@ii		94@ii	94@		)YTBr"   hidden_stateskey_value_statespast_key_valueattention_masklayer_head_maskoutput_attentionscache_positionreturnc                 
   |du}|j                         \  }	}
}| j                  |      | j                  z  }|St        |t              rA|j
                  j                  | j                        }|r|j                  }n|j                  }n|}|r|n|}|rK|IrGj                  | j                     j                  }|j                  | j                     j                  }n| j                  |      }| j                  |      }|j                  |	d| j                   | j"                        j%                  dd      }|j                  |	d| j                   | j"                        j%                  dd      }|D|s|nd}j'                  ||| j                  d|i      \  }}|rd|j
                  | j                  <   |	| j                   z  d| j"                  f}|j                  |	|
| j                   | j"                        j%                  dd      } |j(                  | } |j(                  | } |j(                  | }|j                  d      }t+        j,                  ||j%                  dd            }|j                         |	| j                   z  |
|fk7  r/t/        d|	| j                   z  |
|f d|j                                |{|j                         |	d|
|fk7  r#t/        d	|	d|
|f d|j                                |j                  |	| j                   |
|      |z   }|j                  |	| j                   z  |
|      }t0        j2                  j5                  |d
      }||j                         | j                   fk7  r*t/        d| j                   f d|j                                |j                  dddd      |j                  |	| j                   |
|      z  }|j                  |	| j                   z  |
|      }|r?|j                  |	| j                   |
|      }|j                  |	| j                   z  |
|      }nd}t0        j2                  j7                  || j6                  | j8                        }t+        j,                  ||      }|j                         |	| j                   z  |
| j"                  fk7  r7t/        d|	| j                   |
| j"                  f d|j                                |j                  |	| j                   |
| j"                        }|j%                  dd      }|j)                  |	|
|      }| j;                  |      }||fS )z#Input shape: Batch x Time x ChannelNr)   r   r   r   Tz$Attention weights should be of size z	, but is z!Attention mask should be of size rO   z/Head mask for a single layer should be of size ptrainingz `attn_output` should be of size )r_   r   r|   
isinstancer
   
is_updatedgetrw   cross_attention_cacheself_attention_cachelayerskeysvaluesr~   r   rX   rp   rz   	transposeupdatereshaper+   bmmr{   r   
functionalsoftmaxrs   r   r   )r   r   r   r   r   r   r   r   rv   r2   tgt_lenro   query_statesr   curr_past_key_valuecurrent_states
key_statesvalue_states
proj_shapesrc_lenattn_weightsattn_weights_reshaped
attn_probsattn_outputs                           r!   r1   zTrOCRAttention.forward   s7    .T9"/"4"4"6Wi {{=1DLL@%.*=>+66::4>>J
%*8*N*N'*8*M*M'&4#-?)]."<,33DNNCHHJ.55dnnELLL^4J;;~6L#b$..$--PZZ[\^_`J',,S"dnndmmT^^_`bcdL)7It+>+E+Ednn?OQ_>`,(
L &@DN--dnn=DNN*B>
#((gt~~t}}U__`acde+|++Z8'Z''4
+|++Z8//!$yyz/C/CAq/IJ3#7'"JJ6dnn8LgW^7_6` a %%'(* 
 %""$a'(BB 7a'8R7SS\]k]p]p]r\st  (,,S$..'7SVddL',,S4>>-A7GTL}},,\r,B&##%$..):: Et~~FWEX Y',,./1  +//2q!<|?P?PQTVZVdVdfmov?wwL',,S4>>-A7GTL
 %1$5$5c4>>7T[$\!055cDNN6JGU\]L$(!]]**<4<<RVR_R_*`
ii
L9#"6!OO2CRVR_R_3`2a b$$&') 
 "&&sDNNGT]]S!++Aq1!))#w	BmmK0111r"   )NN        FTFN)NNNNFN)r4   r5   r6   r7   r8   r   rD   boolr   r+   r9   r	   tupler1   r:   r;   s   @r!   rn   rn      sj   H #"#&%*#-2$(!C !C 	!C
 sm!C sm!C %!C TN!C tn!C %TN!C D>!CL 48*.1526,115p2||p2 #5<<0p2 !	p2
 !.p2 "%,,/p2 $D>p2 !.p2 
u||Xell3XeELL>Q5RR	Sp2r"   rn   c                   ,    e Zd Zddef fdZ	 	 	 	 	 	 	 	 	 ddej                  deej                     deej                     deej                     deej                     deej                     d	ee   d
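
# Shape walkthrough for the attention above (hypothetical helper, not part of the
# library API): projections are folded to [bsz * num_heads, seq_len, head_dim] so the
# two torch.bmm calls batch over heads.
def _demo_attention_shapes():
    config = TrOCRConfig(d_model=64, decoder_attention_heads=4)
    attn = TrOCRAttention(config, embed_dim=64, num_heads=4, is_decoder=True)
    hidden_states = torch.randn(2, 5, 64)
    attn_output, attn_weights = attn(hidden_states, output_attentions=True)
    assert attn_output.shape == (2, 5, 64)
    assert attn_weights.shape == (2, 4, 5, 5)  # [bsz, num_heads, tgt_len, src_len]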
ee	   dee	   deej                     fdZ
 xZS )TrOCRDecoderLayerr   c                 b   t         |           |j                  | _        t	        || j                  |j
                  |j                  d|      | _        |j                  | _        t        |j                     | _        |j                  | _        t        j                  | j                        | _        |j                   rot	        || j                  |j
                  |j"                  |j"                  |j                  dd|	      | _        t        j                  | j                        | _        t        j(                  | j                  |j*                        | _        t        j(                  |j*                  | j                        | _        t        j                  | j                        | _        y )NT)ro   rp   rs   rt   rw   )ro   rp   rq   rr   rs   rt   rv   rw   )r   r   hidden_sizero   rn   decoder_attention_headsattention_dropout	self_attnrs   r   activation_functionactivation_fnactivation_dropoutr   	LayerNormself_attn_layer_normrt   cross_attention_hidden_sizeencoder_attnencoder_attn_layer_normr}   decoder_ffn_dimfc1fc2final_layer_norm)r   r   rw   r    s      r!   r   zTrOCRDecoderLayer.__init__)  s=   ++'nn44,,
 ~~#F$>$>?"(";";$&LL$@! ... 88777700#'#
!D ,.<<+GD(99T^^V-C-CD99V33T^^D "T^^ <r"   r   r   encoder_hidden_statesencoder_attention_maskr   cross_attn_layer_head_maskr   r   	use_cacher   c           	      2   |}| j                  ||||||
      \  }}t        j                  j                  || j                  | j                        }||z   }| j                  |      }d}|i|}| j                  |||||||
      \  }}t        j                  j                  || j                  | j                        }||z   }| j                  |      }|}| j                  | j                  |            }t        j                  j                  || j                  | j                        }| j                  |      }t        j                  j                  || j                  | j                        }||z   }| j                  |      }|f}|r|||fz  }|S )a  
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            encoder_hidden_states (`torch.FloatTensor`):
                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
                size *(decoder_attention_heads,)*.
            past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        )r   r   r   r   r   r   r   N)r   r   r   r   r   r   r   )r   r   r   rs   r   r   r   r   r   r   r   r   r   )r   r   r   r   r   r   r   r   r   r   r   residualself_attn_weightscross_attn_weightsoutputss                  r!   r1   zTrOCRDecoderLayer.forwardM  s   > ! ,0>>'))+/) ,: ,
(( --mt||VZVcVc-d =011-@ " ,$H040A0A+!65 :-"3- 1B 1-M- MM11-4<<Z^ZgZg1hM$}4M 88GM !**488M+BC--mt?V?Vaeanan-o/--mt||VZVcVc-d =0--m< ")+=>>Gr"   rA   )	NNNNNNFTN)r4   r5   r6   r   r   r+   r9   r   r	   r   r1   r:   r;   s   @r!   r   r   (  s    "={ "=N 268<9=26=A*.,1$(15Q||Q !.Q  (5	Q
 !) 6Q "%,,/Q %-U\\$:Q !Q $D>Q D>Q !.Qr"   r   c                   ,    e Zd ZU eed<   dZdZdgZd Zy)TrOCRPreTrainedModelr   modelTr   c                 6   | j                   j                  }t        |t        j                  t        j
                  f      rY|j                  j                  j                  d|       |j                  %|j                  j                  j                          y y t        |t        j                        rf|j                  j                  j                  d|       |j                  2|j                  j                  |j                     j                          y y y )Nr   )meanstd)r   init_stdr   r   r}   Conv1dr.   datanormal_ru   zero_	Embeddingr>   )r   moduler   s      r!   _init_weightsz"TrOCRPreTrainedModel._init_weights  s    kk""fryy"))45MM&&CS&9{{&  &&( '-MM&&CS&9!!-""6#5#56<<> . .r"   N)	r4   r5   r6   r   __annotations__base_model_prefixsupports_gradient_checkpointing_no_split_modulesr    r"   r!   r   r     s"    &*#,-	?r"   r   c                   J     e Zd ZdZdef fdZ	 	 	 	 	 	 	 	 	 	 	 	 	 ddZ xZS )TrOCRDecoderz

class TrOCRDecoder(TrOCRPreTrainedModel):
    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`TrOCRDecoderLayer`]

    Args:
        config: TrOCRConfig
    """

    def __init__(self, config: TrOCRConfig):
        super().__init__(config)
        self.dropout = config.dropout
        self.layerdrop = config.decoder_layerdrop
        self.padding_idx = config.pad_token_id
        embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0

        self.embed_tokens = TrOCRScaledWordEmbedding(
            config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=embed_scale
        )

        if config.use_learned_position_embeddings:
            self.embed_positions = TrOCRLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size)
        else:
            self.embed_positions = TrOCRSinusoidalPositionalEmbedding(
                config.max_position_embeddings + self.padding_idx + 1,
                config.hidden_size,
                self.padding_idx,
            )

        if config.layernorm_embedding:
            self.layernorm_embedding = nn.LayerNorm(config.hidden_size)
        else:
            self.layernorm_embedding = None

        self.layers = nn.ModuleList([TrOCRDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)])

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        cache_position=None,
    ):
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
                of the decoder.
            encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
                Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
                selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
                on hidden heads. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            input = input_ids
            input_ids = input_ids.view(-1, input_ids.shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            input = inputs_embeds[:, :, -1]
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..."
            )
            use_cache = False

        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):
            logger.warning_once(
                "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
                "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
                "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
            )
            return_legacy_cache = True
            past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)

        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if self.config.use_learned_position_embeddings:
            embed_pos = self.embed_positions(input, past_key_values_length=past_key_values_length)
        else:
            embed_pos = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)

        hidden_states = inputs_embeds + embed_pos

        if self.layernorm_embedding is not None:
            hidden_states = self.layernorm_embedding(hidden_states)

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        input_shape = input.shape
        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )

        # expand encoder attention mask
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            encoder_attention_mask = _prepare_4d_attention_mask(
                encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            )

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
            if attn_mask is not None:
                if attn_mask.size()[0] != len(self.layers):
                    raise ValueError(
                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                        f" {attn_mask.size()[0]}."
                    )
        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:
                    continue

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                cross_attn_layer_head_mask=(
                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
                ),
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if return_legacy_cache:
            past_key_values = past_key_values.to_legacy_cache()

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )

@auto_docstring(
    custom_intro="""
    The TrOCR Model with a language modeling head. Can be used for summarization.
    This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
    used in combination with the [`EncoderDecoderModel`] framework.
    """
)
class TrOCRDecoderWrapper(TrOCRPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.decoder = TrOCRDecoder(config)

    def forward(self, *args, **kwargs):
        return self.decoder(*args, **kwargs)
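
# End-to-end sketch of driving the decoder with encoder states, roughly what
# `VisionEncoderDecoderModel` does during cross-attention training (hypothetical
# helper with assumed sizes, not part of the library API). Note `is_decoder=True`
# is required so `TrOCRDecoderLayer` builds its cross-attention block.
def _demo_decoder_forward():
    config = TrOCRConfig(
        vocab_size=100, d_model=64, decoder_layers=2, decoder_attention_heads=4, decoder_ffn_dim=128, is_decoder=True
    )
    decoder = TrOCRDecoder(config)
    input_ids = torch.randint(2, 100, (2, 7))
    encoder_hidden_states = torch.randn(2, 12, 64)  # e.g. 12 patch embeddings from a ViT encoder
    outputs = decoder(input_ids=input_ids, encoder_hidden_states=encoder_hidden_states, use_cache=False)
    assert outputs.last_hidden_state.shape == (2, 7, 64)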

@auto_docstring(
    custom_intro="""
    The TrOCR Decoder with a language modeling head. Can be used as the decoder part of [`EncoderDecoderModel`] and
    [`VisionEncoderDecoder`].
    """
)
class TrOCRForCausalLM(TrOCRPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["output_projection.weight"]

    def __init__(self, config):
        config.is_decoder = True
        config.is_encoder_decoder = False
        super().__init__(config)
        self.model = TrOCRDecoderWrapper(config)

        self.output_projection = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.decoder.embed_tokens

    def set_input_embeddings(self, value):
        self.model.decoder.embed_tokens = value

    def get_output_embeddings(self):
        return self.output_projection

    def set_output_embeddings(self, new_embeddings):
        self.output_projection = new_embeddings

    def set_decoder(self, decoder):
        self.model.decoder = decoder

    def get_decoder(self):
        return self.model.decoder

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
    ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import (
        ...     TrOCRConfig,
        ...     TrOCRProcessor,
        ...     TrOCRForCausalLM,
        ...     ViTConfig,
        ...     ViTModel,
        ...     VisionEncoderDecoderModel,
        ... )
        >>> import requests
        >>> from PIL import Image

        >>> # TrOCR is a decoder model and should be used within a VisionEncoderDecoderModel
        >>> # init vision2text model with random weights
        >>> encoder = ViTModel(ViTConfig())
        >>> decoder = TrOCRForCausalLM(TrOCRConfig())
        >>> model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder)

        >>> # If you want to start from the pretrained model, load the checkpoint with `VisionEncoderDecoderModel`
        >>> processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
        >>> model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

        >>> # load image from the IAM dataset
        >>> url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
        >>> pixel_values = processor(image, return_tensors="pt").pixel_values
        >>> text = "industry, ' Mr. Brown commented icily. ' Let us have a"

        >>> # training
        >>> model.config.decoder_start_token_id = processor.tokenizer.eos_token_id
        >>> model.config.pad_token_id = processor.tokenizer.pad_token_id
        >>> model.config.vocab_size = model.config.decoder.vocab_size

        >>> labels = processor.tokenizer(text, return_tensors="pt").input_ids
        >>> outputs = model(pixel_values, labels=labels)
        >>> loss = outputs.loss
        >>> round(loss.item(), 2)
        5.30

        >>> # inference
        >>> generated_ids = model.generate(pixel_values)
        >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        >>> generated_text
        'industry, " Mr. Brown commented icily. " Let us have a'
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        logits = self.output_projection(outputs[0])

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )


__all__ = ["TrOCRForCausalLM", "TrOCRPreTrainedModel"]