
from typing import TYPE_CHECKING, Any, Optional

import sentencepiece as spm
import torch
from torch import nn

from ...cache_utils import Cache, DynamicCache
from ...configuration_utils import PretrainedConfig
from ...masking_utils import create_causal_mask
from ...modeling_outputs import BaseModelOutputWithPast
from ...processing_utils import Unpack
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...utils import TransformersKwargs, logging
from ..llama.modeling_llama import (
    LlamaForCausalLM,
    LlamaForSequenceClassification,
    LlamaForTokenClassification,
    LlamaMLP,
    LlamaModel,
)
from ..llama.tokenization_llama import LlamaTokenizer


if TYPE_CHECKING:
    from ...tokenization_utils_base import TextInput

VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}

SPIECE_UNDERLINE = "▁"

logger = logging.get_logger(__name__)


class GemmaConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate a Gemma
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the Gemma-7B.
    e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    Args:
        vocab_size (`int`, *optional*, defaults to 256000):
            Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
            `input_ids` passed when calling [`GemmaModel`].
        hidden_size (`int`, *optional*, defaults to 3072):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 24576):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 28):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 16):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details, check out [this
            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
            `num_attention_heads`.
        head_dim (`int`, *optional*, defaults to 256):
            The attention head dimension.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The legacy activation function. It is overwritten by the `hidden_activation`.
        hidden_activation (`str` or `function`, *optional*):
            The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
            if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
        max_position_embeddings (`int`, *optional*, defaults to 8192):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 0):
            Padding token id.
        eos_token_id (`int`, *optional*, defaults to 1):
            End of stream token id.
        bos_token_id (`int`, *optional*, defaults to 2):
            Beginning of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
    ```python
    >>> from transformers import GemmaModel, GemmaConfig
    >>> # Initializing a Gemma gemma-7b style configuration
    >>> configuration = GemmaConfig()
    >>> # Initializing a model from the gemma-7b style configuration
    >>> model = GemmaModel(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
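    >>> # Illustrative only (assumed sizes, not an official Gemma checkpoint): override a few fields
    >>> # to build a tiny configuration, e.g. `num_key_value_heads=1` to exercise multi-query attention
    >>> tiny_configuration = GemmaConfig(num_hidden_layers=2, hidden_size=256, num_key_value_heads=1)
    >>> tiny_model = GemmaModel(tiny_configuration)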
    ```"""

    model_type = "gemma"
    keys_to_ignore_at_inference = ["past_key_values"]
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=256000,
        hidden_size=3072,
        intermediate_size=24576,
        num_hidden_layers=28,
        num_attention_heads=16,
        num_key_value_heads=16,
        head_dim=256,
        hidden_act="gelu_pytorch_tanh",
        hidden_activation=None,
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        bos_token_id=2,
        tie_word_embeddings=True,
        rope_theta=10000.0,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.hidden_activation = hidden_activation
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


class GemmaTokenizer(LlamaTokenizer, PreTrainedTokenizer):
    """
    Construct a Gemma tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is
    no padding token in the original model.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<bos>"`):
            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
        eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<eos>"`):
            The end of sequence token.
        pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<pad>"`):
            A special token used to make arrays of tokens the same size for batching purposes. It will then be ignored by
            attention mechanisms or loss computation.
        sp_model_kwargs (`dict[str, Any]`, *optional*):
            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
            to set:

            - `enable_sampling`: Enable subword regularization.
            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.

              - `nbest_size = {0,1}`: No sampling is performed.
              - `nbest_size > 1`: samples from the nbest_size results.
              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
                using forward-filtering-and-backward-sampling algorithm.

            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
              BPE-dropout.

        add_bos_token (`bool`, *optional*, defaults to `True`):
            Whether or not to add a `bos_token` at the start of sequences.
        add_eos_token (`bool`, *optional*, defaults to `False`):
            Whether or not to add an `eos_token` at the end of sequences.
        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
            Whether or not to clean up spaces after decoding; cleanup consists of removing potential artifacts like
            extra spaces.
        use_default_system_prompt (`bool`, *optional*, defaults to `False`):
            Whether or not the default system prompt for Gemma should be used.
        spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
            Whether or not to add spaces between special tokens.
    """

    def __init__(
        self,
        vocab_file,
        unk_token="<unk>",
        bos_token="<bos>",
        eos_token="<eos>",
        pad_token="<pad>",
        sp_model_kwargs: Optional[dict[str, Any]] = None,
        add_bos_token=True,
        add_eos_token=False,
        clean_up_tokenization_spaces=False,
        use_default_system_prompt=False,
        spaces_between_special_tokens=False,
        **kwargs,
    ):
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
        bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
        eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
        unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
        pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token

        self.vocab_file = vocab_file
        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token
        self.use_default_system_prompt = use_default_system_prompt
        # Load the SentencePiece model eagerly so the vocabulary is available before the base
        # tokenizer initialization runs.
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(vocab_file)

        # Deliberately skip LlamaTokenizer.__init__ and initialize PreTrainedTokenizer directly.
        PreTrainedTokenizer.__init__(
            self,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            add_bos_token=add_bos_token,
            add_eos_token=add_eos_token,
            sp_model_kwargs=self.sp_model_kwargs,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            use_default_system_prompt=use_default_system_prompt,
            spaces_between_special_tokens=spaces_between_special_tokens,
            **kwargs,
        )

    def get_spm_processor(self):
        raise AttributeError("Not needed for Gemma")

    def unk_token_length(self):
        raise AttributeError("Not needed for Gemma")

    def tokenize(self, text: "TextInput", **kwargs) -> list[str]:
        """
        Args:
            text: TextInput
        Simply calls PreTrainedTokenizer's method
        """
        return PreTrainedTokenizer.tokenize(self, text, **kwargs)

    def _tokenize(self, text, **kwargs):
        """
        Args:
            text: TextInput
        Returns a tokenized string. The Gemma tokenizer never adds a prefix space.
        """
        return self.sp_model.encode(text, out_type=str)

    def _decode(
        self,
        token_ids: list[int],
        skip_special_tokens: bool = False,
        spaces_between_special_tokens: bool = False,
        **kwargs,
    ) -> str:
        sub_texts = []
        current_sub_text = []
        for ids in token_ids:
            if skip_special_tokens and ids in self.all_special_ids:
                continue
            if ids in self._added_tokens_decoder:
                # Flush the pending sentencepiece ids before emitting the added token verbatim.
                if current_sub_text:
                    sub_texts.append(self.sp_model.decode(current_sub_text))
                sub_texts.append(self._added_tokens_decoder[ids].content)
                current_sub_text = []
            else:
                current_sub_text.append(ids)
        if current_sub_text:
            sub_texts.append(self.sp_model.decode(current_sub_text))

        if spaces_between_special_tokens:
            sub_texts = " ".join(sub_texts)
        else:
            sub_texts = "".join(sub_texts)

        return sub_texts.replace(SPIECE_UNDERLINE, " ")

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        current_sub_tokens = []
        out_string = ""
        for token in tokens:
            # make sure that special tokens are not decoded using the sentencepiece model
            if token in self._added_tokens_encoder:
                out_string += self.sp_model.decode(current_sub_tokens) + token
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
        out_string += self.sp_model.decode(current_sub_tokens)
        return out_string


class GemmaRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float())
        # Gemma stores a zero-initialized weight and scales by (1 + weight), unlike Llama.
        output = output * (1.0 + self.weight.float())
        return output.type_as(x)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"


class GemmaMLP(LlamaMLP):
    def __init__(self, config):
        super().__init__(config)
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)


class GemmaModel(LlamaModel):
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # Gemma scales the embeddings by sqrt(hidden_size), computed in the hidden-state dtype.
        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
        hidden_states = hidden_states * normalizer

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )


class GemmaForCausalLM(LlamaForCausalLM):
    def forward(self, **super_kwargs):
        r"""
        Example:

        ```python
        >>> from transformers import AutoTokenizer, GemmaForCausalLM

        >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")

        >>> prompt = "What is your favorite condiment?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "What is your favorite condiment?"
        ```"""
        return super().forward(**super_kwargs)


class GemmaForSequenceClassification(LlamaForSequenceClassification):
    pass


class GemmaForTokenClassification(LlamaForTokenClassification):
    pass


__all__ = [
    "GemmaConfig",
    "GemmaTokenizer",
    "GemmaModel",
    "GemmaForCausalLM",
    "GemmaForSequenceClassification",
    "GemmaForTokenClassification",
    "GemmaPreTrainedModel",  # defined in the generated modeling file, not in this modular source
]