
    rh\                       d Z ddlZddlmZ ddlmZmZmZ ddlZddlm	Z	 ddl
mZmZ ddlmZ dd	lmZmZmZmZ dd
lmZ ddlmZ ddlmZmZ ddlmZ ddlmZ  ej>                  e       Z!e ed       G d de                    Z"ee G d de                    Z# ed      e G d de                    Z$e G d de$             Z% ed       G d de$             Z& ed       G d d e$e             Z'g d!Z(y)"zRAG model implementation.    N)	dataclass)CallableOptionalUnion)nn   )CacheEncoderDecoderCache)PretrainedConfig)GenerationConfigGenerationMixinLogitsProcessorListStoppingCriteriaList)ModelOutput)PreTrainedModel)auto_docstringlogging   )	RagConfig)RagRetrieverzI
    Base class for retriever augmented marginalized models outputs.
    )custom_introc                      e Zd ZU dZdZeej                     ed<   dZ	eej                     ed<   dZ
eej                     ed<   dZee   ed<   dZeej                     ed<   dZeej                     ed<   dZeej                     ed	<   dZeej                     ed
<   dZeej                     ed<   dZeeej                  df      ed<   dZeeej                  df      ed<   dZeej                     ed<   dZeeej                  df      ed<   dZeeej                  df      ed<   dZeeej                  df      ed<   dZeeej                  df      ed<   dZeeej                  df      ed<   y)RetrievAugLMMarginOutputa  
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss.
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head. The score is possibly marginalized over all documents for
        each vocabulary token.
    doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
        Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
        `question_encoder_last_hidden_state`.
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size,
        num_heads, sequence_length, embed_size_per_head)`).

        Contains precomputed hidden-states (key and values in the attention blocks) of the decoder that can be used
        (see `past_key_values` input) to speed up sequential decoding.
    retrieved_doc_embeds (`torch.FloatTensor` of shape `(batch_size, config.n_docs, hidden_size)`, *optional*, returned when *output_retrieved=True*):
        Embedded documents retrieved by the retriever. Is used with `question_encoder_last_hidden_state` to compute
        the `doc_scores`.
    retrieved_doc_ids (`torch.LongTensor` of shape `(batch_size, config.n_docs)`, *optional*, returned when *output_retrieved=True*):
        The indexes of the embedded documents retrieved by the retriever.
    context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
        Input ids post-processed from the retrieved documents and the question encoder input_ids by the retriever.
    context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
        Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
        retriever.
    question_encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
        Sequence of hidden states at the output of the last layer of the question encoder pooled output of the
        model.
    question_enc_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
        shape `(batch_size, sequence_length, hidden_size)`.

        Hidden states of the question encoder at the output of each layer plus the initial embedding outputs.
    question_enc_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.

        Attentions weights of the question encoder, after the attention softmax, used to compute the weighted
        average in the self-attention heads.
    generator_enc_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
        Sequence of hidden-states at the output of the last layer of the generator encoder of the model.
    generator_enc_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
        shape `(batch_size, sequence_length, hidden_size)`.

        Hidden states of the generator encoder at the output of each layer plus the initial embedding outputs.
    generator_enc_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.

        Attentions weights of the generator encoder, after the attention softmax, used to compute the weighted
        average in the self-attention heads.
    generator_dec_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
        shape `(batch_size, sequence_length, hidden_size)`.

        Hidden states of the generator decoder at the output of each layer plus the initial embedding outputs.
    generator_dec_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.

        Attentions weights of the generator decoder, after the attention softmax, used to compute the weighted
        average in the self-attention heads.
    generator_cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.

        Cross-attentions weights of the generator decoder, after the attention softmax, used to compute the
        weighted average in the cross-attention heads.
    Nlosslogits
doc_scorespast_key_valuesretrieved_doc_embedsretrieved_doc_idscontext_input_idscontext_attention_mask"question_encoder_last_hidden_state.question_enc_hidden_statesquestion_enc_attentionsgenerator_enc_last_hidden_stategenerator_enc_hidden_statesgenerator_enc_attentionsgenerator_dec_hidden_statesgenerator_dec_attentionsgenerator_cross_attentions)__name__
__module____qualname____doc__r   r   torchFloatTensor__annotations__r   r   r   r	   r   r   
LongTensorr    r!   r"   r#   tupler$   r%   r&   r'   r(   r)   r*        w/var/www/html/ai-insurance-compliance-backend/venv/lib/python3.12/site-packages/transformers/models/rag/modeling_rag.pyr   r   %   s   EN )-D(5$$
%,*.FHU&&'..2J**+2'+OXe_+8<(5#4#45<48x 0 01848x 0 0189=HU%5%56=FJ&1B1B(CJJNu/@/@#/E)F GNGKXeE,=,=s,B&CDKCG#Xe.?.?%@GKO%0A0A30F*G!HOHLhuU->->-C'DELKO%0A0A30F*G!HOHLhuU->->-C'DELJNu/@/@#/E)F GNr5   r   c                      e Zd ZU dZdZeej                     ed<   dZ	eej                     ed<   dZ
ee   ed<   dZeej                     ed<   dZeej                     ed<   dZeej                     ed<   dZeej                     ed	<   dZeej                     ed
<   dZeeej                  df      ed<   dZeeej                  df      ed<   dZeej                     ed<   dZeeej                  df      ed<   dZeeej                  df      ed<   dZeeej                  df      ed<   dZeeej                  df      ed<   dZeeej                  df      ed<   y)RetrievAugLMOutputa7  
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head. The score is possibly marginalized over all documents for
        each vocabulary token.
    doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
        Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
        `question_encoder_last_hidden_state`.
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size,
        num_heads, sequence_length, embed_size_per_head)`).

        Contains precomputed hidden-states (key and values in the attention blocks) of the decoder that can be used
        (see `past_key_values` input) to speed up sequential decoding.
    retrieved_doc_embeds (`torch.FloatTensor` of shape `(batch_size, config.n_docs, hidden_size)`, *optional*, returned when *output_retrieved=True*):
        Embedded documents retrieved by the retriever. Is used with `question_encoder_last_hidden_state` to compute
        the `doc_scores`.
    retrieved_doc_ids (`torch.LongTensor` of shape `(batch_size, config.n_docs)`, *optional*, returned when *output_retrieved=True*):
        The indexes of the embedded documents retrieved by the retriever.
    context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
        Input ids post-processed from the retrieved documents and the question encoder input_ids by the retriever.
    context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
        Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
        retriever.
    question_encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
        Sequence of hidden states at the output of the last layer of the question encoder pooled output of the
        model.
    question_enc_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
        shape `(batch_size, sequence_length, hidden_size)`.

        Hidden states of the question encoder at the output of each layer plus the initial embedding outputs.
    question_enc_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.

        Attentions weights of the question encoder, after the attention softmax, used to compute the weighted
        average in the self-attention heads.
    generator_enc_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
        Sequence of hidden-states at the output of the last layer of the generator encoder of the model.
    generator_enc_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
        shape `(batch_size, sequence_length, hidden_size)`.

        Hidden states of the generator encoder at the output of each layer plus the initial embedding outputs.
    generator_enc_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.

        Attentions weights of the generator encoder, after the attention softmax, used to compute the weighted
        average in the self-attention heads.
    generator_dec_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
        shape `(batch_size, sequence_length, hidden_size)`.

        Hidden states of the generator decoder at the output of each layer plus the initial embedding outputs.
    generator_dec_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.

        Attentions weights of the generator decoder, after the attention softmax, used to compute the weighted
        average in the self-attention heads.
    generator_cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.

        Cross-attentions weights of the generator decoder, after the attention softmax, used to compute the
        weighted average in the cross-attention heads.
    Nr   r   r   r   r   r    r!   r"   .r#   r$   r%   r&   r'   r(   r)   r*   )r+   r,   r-   r.   r   r   r/   r0   r1   r   r   r	   r   r   r2   r    r!   r"   r#   r3   r$   r%   r&   r'   r(   r)   r*   r4   r5   r6   r8   r8      s   CJ +/FHU&&'..2J**+2'+OXe_+8<(5#4#45<48x 0 01848x 0 0189=HU%5%56=FJ&1B1B(CJJNu/@/@#/E)F GNGKXeE,=,=s,B&CDKCG#Xe.?.?%@GKO%0A0A30F*G!HOHLhuU->->-C'DELKO%0A0A30F*G!HOHLhuU->->-C'DELJNu/@/@#/E)F GNr5   r8   a  
    RAG models were released with the paper [Retrieval-Augmented Generation for Knowledge-Intensive NLP
    Tasks](https://huggingface.co/papers/2005.11401) by Patrick Lewis, Ethan Perez, Aleksandra Piktus et al.

    RAG is a retriever augmented model and encapsulate three components: a question encoder, a dataset retriever and a
    generator, the encoder and generator are trainable while the retriever is just an indexed dataset.
    c            
       Z    e Zd ZU eed<   dZdZdZe	 	 	 d
de	e
   de	e
   dedefd	       Zy)RagPreTrainedModelconfigragTN.question_encoder_pretrained_model_name_or_path'generator_pretrained_model_name_or_path	retrieverreturnc                 ^   |j                         D ci c]%  \  }}|j                  d      r|t        d      d |' }}}|j                         D ci c]%  \  }}|j                  d      r|t        d      d |' }}}|D ]  }	|d|	z   = 
 |D ]  }	|d|	z   = 
 |j                  dd      }
|
K|J d       ddlm} d|vr%dd	lm}  |j                  |fi |d
di\  }}||d<    |j                  |fi |}
|j                  dd      }|K|J d       ddlm	} d|vr%dd	lm}  |j                  |fi |d
di\  }}||d<    |j                  |fi |}|j                  d      }|+t        j                  |
j                  |j                  fi |} | |
|||      S c c}}w c c}}w )a  
        Instantiates an question encoder and a generator from one or two base classes of the library from pretrained
        model checkpoints.

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
        the model, you need to first set it back in training mode with `model.train()`.

        Params:
            question_encoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
                Information necessary to initiate the question encoder. Can be either:

                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                    - A path to a *directory* containing model weights saved using
                      [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
                    - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
                      this case, `from_tf` should be set to `True` and a configuration object should be provided as
                      `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
                      PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

            generator_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
                Information necessary to initiate the generator. Can be either:

                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                    - A path to a *directory* containing model weights saved using
                      [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
                    - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
                      this case, `from_tf` should be set to `True` and a configuration object should be provided as
                      `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
                      PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

            model_args (remaining positional arguments, *optional*):
                All remaining positional arguments will be passed to the underlying model's `__init__` method.
            retriever ([`RagRetriever`], *optional*):
                The retriever to use.
            kwwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
                `output_attentions=True`).

                - To update the question_encoder configuration, use the prefix *question_encoder_* for each
                  configuration parameter.
                - To update the generator configuration, use the prefix *generator_* for each configuration parameter.
                - To update the parent model configuration, do not use a prefix for each configuration parameter.

                Behaves differently depending on whether a `config` is provided or automatically loaded.

        Example:

        ```python
        >>> from transformers import RagModel

        >>> # initialize a RAG from two pretrained models.
        >>> model = RagModel.from_pretrained_question_encoder_generator(
        ...     "facebook/dpr-question_encoder-single-nq-base", "google-t5/t5-small"
        ... )
        >>> # saving model after fine-tuning
        >>> model.save_pretrained("./rag")
        >>> # load fine-tuned model
        >>> model = RagModel.from_pretrained("./rag")
        ```question_encoder_N
generator_modelznIf `model` is not defined as an argument, a `question_encoder_pretrained_model_name_or_path` has to be defined   	AutoModelr;   )
AutoConfigreturn_unused_kwargsTzqIf `generator_model` is not defined as an argument, a `generator_pretrained_model_name_or_path` has to be definedAutoModelForSeq2SeqLM)question_encoder	generatorr;   r?   )items
startswithlenpopauto.modeling_autorG   auto.configuration_autorH   from_pretrainedrK   getr   'from_question_encoder_generator_configsr;   )clsr=   r>   r?   kwargsargumentvaluekwargs_question_encoderkwargs_generatorkeyrL   rG   rH   question_encoder_configrM   rK   generator_configr;   s                     r6   *from_pretrained_question_encoder_generatorz=RagPreTrainedModel.from_pretrained_question_encoder_generator   so   L $*<<>#
%""#67 S,-/0%7#
 #
 $*<<>
%""<0 S&()50
 
 + 	2C*S01	2# 	+C|c)*	+ 366wE#AM M 766@C]:C]C]BD-D *.D@')@
 5L'18y88> BY  %(($7	:F !F C//@5OZ5O5O;6?O6fj62 "2 .> *=-==7;KI
 H%>FF '')9)9=CF $4	RXdmnnO#

s   *F#*F))NNN)r+   r,   r-   r   r1   base_model_prefix_supports_flash_attn_supports_sdpaclassmethodr   strr   r   r`   r4   r5   r6   r:   r:      st     N IMAE"&	Jo8@Jo 2:#Jo  	Jo 
Jo Jor5   r:   c            "           e Zd Z	 	 	 	 ddee   dee   dee   dee   f fdZe	 	 	 	 	 	 	 	 	 	 	 	 	 	 ddee	j                     dee	j                     deeee	j                           d	ee	j                     d
ee	j                     dee   dee	j                     dee	j                     dee	j                     dee   dee   dee   dee   dee   deee	j                     ef   fd       Z xZS )RagModelr;   rL   rM   r?   c                 H   |||J d       |,t        j                  |j                  |j                  fi |}n-t        || j                        sJ d| d| j                          t
        |   |       |!ddlm} |j                  |j                        }|!ddlm} |j                  |j                        }|| _        | j                  5t        |t              sJ dt        | j                         d	       || _        || _
        || _        d| _        d
| _        y)  
        question_encoder (`PreTrainedModel`, *optional*):
            The model responsible for encoding the question into hidden states for retrieval.
        generator (`PreTrainedModel`, *optional*):
            The model responsible for generating text based on retrieved documents.
        retriever (`RagRetriever`, *optional*):
            The component responsible for retrieving documents from a knowledge base given the encoded question.
        NzQEither a configuration or an question_encoder and a generator has to be provided.zconfig: z has to be of type rE   rF   rJ   z`self.retriever` is of type z&, but should be of type `RagRetriever`F)r   rV   r;   
isinstanceconfig_classsuper__init__rR   rG   from_configrL   rK   rM   r?   r   typectx_encodercontext_encoder_training)	selfr;   rL   rM   r?   rX   rG   rK   	__class__s	           r6   rm   zRagModel.__init__  s<     !&6&ByG\ 	
_	
] >FF '')9)9=CF fd&7&78sHVHL_`d`q`q_r:ss8 #6(44V5L5LMB-99&:J:JKI">>%i6 .tDNN/C.DDjk6 'DN 0"(-%r5   	input_idsattention_maskencoder_outputsdecoder_input_idsdecoder_attention_maskr   r   r    r!   	use_cacheoutput_attentionsoutput_hidden_statesoutput_retrievedn_docsr@   c                    ||n| j                   j                  }|
|
n| j                   j                  }
||n| j                   j                  }||n| j                   j                  }||n| j                   j
                  }| j                  duxr |du xs
 |	du xs |du xr |du }|)|r| j                  ||d      }|d   }| j                  ||j                         j                  dt        j                        j                         | j                  j                   j                  |d      }| j                  r|d	   |d
   |d   |d   |d   |d   f\  }}	}}}}|j                  |      }|	j                  |      }	|j                  |      }|j                  |      }| j!                  ||d      j"                  }|j%                  d||j&                  d         }t        j(                  |j+                  d      |j-                  dd            j/                  d      }n|d	   |d
   |d   |d   f\  }}	}}|j                  |      }|j                  |      }|	j                  |      }	t        j(                  |j+                  d      |j-                  dd            j/                  d      }n|J d       |	J d       |J d       |J d       |j&                  d   |z  dk(  sJ d| d|j&                  d    d       ||j1                  |d      }||j1                  |d      }| j                  ||	|||||
|d	      }|sd}d}d}d}d}nj2                  }|j4                  }|r|sd}d}	d}d}t7        d)i d|j8                  d|d|j:                  d	|d
|	ddd d!|d"|d#|j<                  d$|j>                  d%|j@                  d&|jB                  d'|jD                  d(|jF                  S )*ay  
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. [`RagConfig`], used to initialize the model, specifies
            which generator to use, it also specifies a compatible generator tokenizer. Use that tokenizer class to
            obtain the indices.

            [What are input IDs?](../glossary#input-ids)
        encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*)
            Tuple consists of (`generator_enc_last_hidden_state`, *optional*: `generator_enc_hidden_states`,
            *optional*: `generator_enc_attentions`). `generator_enc_last_hidden_state` of shape `(batch_size, n_docs *
            sequence_length, hidden_size)` is a sequence of hidden-states at the output of the last layer of the
            generator's encoder.

            Used by the ([`RagModel`]) model during decoding.
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Provide for generation tasks. `None` by default, construct as per instructions for the generator model
            you're using with your RAG instance.
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size,  target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
            Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
            `question_encoder_last_hidden_state`. If the model has is not initialized with a `retriever` `doc_scores`
            has to be provided to the forward pass. `doc_scores` can be computed via
            `question_encoder_last_hidden_state` and `retrieved_doc_embeds`, see examples for more information.
        context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
            Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the
            retriever. If the model was not initialized with a `retriever` ``context_input_ids` has to be provided to
            the forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
        context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`,*optional*, returned when *output_retrieved=True*):
            Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
            retriever. If the model has is not initialized with a `retriever` `context_attention_mask` has to be
            provided to the forward pass. `context_attention_mask` are returned by [`~RagRetriever.__call__`].
        output_retrieved (`bool`, *optional*):
            Whether or not to return the `retrieved_doc_embeds`, `retrieved_doc_ids`, `context_input_ids` and
            `context_attention_mask`. See returned tensors for more detail.
        n_docs (`int`, *optional*):
            The number of documents to retrieve.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, RagRetriever, RagModel
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/rag-token-base")
        >>> retriever = RagRetriever.from_pretrained(
        ...     "facebook/rag-token-base", index_name="exact", use_dummy_dataset=True
        ... )
        >>> # initialize with RagRetriever to do everything in one forward call
        >>> model = RagModel.from_pretrained("facebook/rag-token-base", retriever=retriever)

        >>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
        >>> outputs = model(input_ids=inputs["input_ids"])
        ```NT)ru   return_dictr   cpudevicedtypeptprefixr}   return_tensorsr    r!   r   tokenized_doc_idstokenized_doc_attention_maskdoc_idsr   rE   zMake sure that `context_input_ids` are passed, if no `retriever` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function.zMake sure that `context_attention_mask` are passed, if no `retriever` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function.zMake sure that `doc_scores` are passed, if no `retriever` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function.z^Make sure that `doc_scores` are passed when passing `encoder_outputs` to the forward function.M The first dimension of `context_input_ids` should be a multiple of `n_docs`=	, but is .dim)	rt   ru   rv   rw   rx   r   ry   rz   r   Nr   r   r   r   r"   r#   r$   r%   r&   r'   r(   r)   r*   r4   )$r;   r}   ry   rz   r{   r|   r?   rL   detachtor/   float32numpyrM   r   rq   rp   pooler_outputviewshapebmm	unsqueeze	transposesqueezerepeat_interleavehidden_states
attentionsr8   r   r   encoder_last_hidden_stateencoder_hidden_statesencoder_attentionsdecoder_hidden_statesdecoder_attentionscross_attentions)rr   rt   ru   rv   rw   rx   r   r   r    r!   ry   rz   r{   r|   r}   has_to_retrievequestion_enc_outputsr"   retriever_outputsr   retrieved_doc_input_idsretrieved_doc_attention_maskr   gen_outputsr#   r$   s                             r6   forwardzRagModel.forward  s   R "-4;;3E3E!*!6IDKK<Q<Q	1B1N-TXT_T_TqTq$8$D $++JjJj 	 0@/K+QUQ\Q\QmQm NN$& ("d*b.D.LbPZ^bPb(4' 	 "'+'<'<n$ (= ($ 6J!5L2$(NN6==?BB%W\WdWdBekkm>>0077!#' %3 %! 00 **=>)*BC)*@A)*=>)*HI))4).,/4) ):(<(<Y(G%-C-F-Fy-Q*.E.H.H.S+3O3R3RS\3]0+/+;+;/@\jn ,< ,#m ) ,@+D+DF$F$L$LQ$O,(
 "':DDQGI]IgIghiklIm"gaj  **=>)*BC)*@A))4	jf%'=?SUf ,@+B+BCe+f((9(<(<Y(G%-C-F-Fy-Q* "':DDQGI]IgIghiklIm"gaj  )4 P4 .9 T9 "- J-
 % 	
l	
%   #f,2 	
[\b[c d!''*+1.	
2 ( 1 C CFPQ C R!-%;%M%MfZ[%M%\"nn'1+/#9+/ % 

 15.)-&&*##'  $)=)K)K&&:&E&E#&6 '%)"#'  $! 
%%
!
 (77
 0	

 $:
 "6
 0
 0R
 (B
 %<
 -8,Q,Q
 )4(I(I
 &1%C%C
 )4(I(I
 &1%C%C
  (3'C'C!
 	
r5   NNNN)NNNNNNNNNNNNNN)r+   r,   r-   r   r   r   r   rm   r   r/   r2   Tensorr3   r0   
BoolTensorr	   boolintr   r8   r   __classcell__rs   s   @r6   rg   rg   ~  s    .26:/3,00.)*0. #?30. O,	0.
 L)0.d  1515EI8<=A+/268<=A$(,0/3+/ $d
E,,-d
 !.d
 "%e.?.?(@"AB	d

 $E$4$45d
 !))9)9 :d
 "%d
 U../d
 $E$4$45d
 !))9)9 :d
 D>d
 $D>d
 'tnd
 #4.d
 d
  
uU\\"$66	7!d
 d
r5   rg   zu
    A RAG-sequence model implementation. It performs RAG-sequence specific marginalization in the forward pass.
    c            &       r    e Zd Z	 	 	 	 d%dee   dee   dee   dee   f fdZdefdZdefdZ	e
	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 d&d	eej                     d
eej                     deeeej                           deej                     deej                     dee   deej                     deej                     deej"                     dee   dee   dee   dee   dee   dee   deej                     dee   def$d       Zed        Zed        Zed        Z ej4                         	 	 	 	 	 	 	 	 	 d'd	eej                     d
eej                     deej                     deej                     deej"                     dee   d ee   d!ee   dee   dej                  fd"       Z	 d(d#Zed$        Z xZS ))RagSequenceForGenerationr;   rL   rM   r?   c                     |||J d       |+t        j                  |j                  |j                  fi |}t        |   |       t        ||||      | _        yri   NzHEither a configuration or an encoder and a generator has to be provided.)r;   rL   rM   r?   r   rV   r;   rl   rm   rg   r<   rr   r;   rL   rM   r?   rX   rs   s         r6   rm   z!RagSequenceForGeneration.__init__  s|      !&6&ByG\ 	
V	
] >FF '')9)9=CF 	  6<LXamvwr5   c                 &    || j                   _        y r   r<   r?   rr   r?   s     r6   set_retrieverz&RagSequenceForGeneration.set_retriever      &r5   rp   c                 H    d| j                   _        || j                   _        y NTr<   rq   rp   rr   rp   s     r6    set_context_encoder_for_trainingz9RagSequenceForGeneration.set_context_encoder_for_training      ,0)*r5   rt   ru   rv   rw   rx   r   r    r!   r   ry   rz   r{   r|   exclude_bos_scorereduce_losslabelsr}   r@   c                 ,   ||n| j                   j                  }||n| j                   j                  }||n| j                   j                  }|||}d}
| j	                  ||||||||	||
||||      }d}|@| j                  |j                  |j                  ||| j                   j                  ||      }t        di d|d|j                  d|j                  d|j                  d	|j                  d
|j                  d|j                  d|j                  d|j                  d|j                   d|j"                  d|j$                  d|j&                  d|j(                  d|j*                  d|j,                  d|j.                  S )a3  
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. [`RagConfig`], used to initialize the model, specifies
            which generator to use, it also specifies a compatible generator tokenizer. Use that tokenizer class to
            obtain the indices.

            [What are input IDs?](../glossary#input-ids)
        encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*)
            Tuple consists of (`generator_enc_last_hidden_state`, *optional*: `generator_enc_hidden_states`,
            *optional*: `generator_enc_attentions`). `generator_enc_last_hidden_state` of shape `(batch_size, n_docs *
            sequence_length, hidden_size)` is a sequence of hidden-states at the output of the last layer of the
            generator's encoder.

            Used by the ([`RagModel`]) model during decoding.
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Provide for generation tasks. `None` by default, construct as per instructions for the generator model
            you're using with your RAG instance.
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size,  target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
            Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the
            retriever. If the model was not initialized with a `retriever` ``context_input_ids` has to be provided to
            the forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
        context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`,*optional*, returned when *output_retrieved=True*):
            Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
            retriever. If the model has is not initialized with a `retriever` `context_attention_mask` has to be
            provided to the forward pass. `context_attention_mask` are returned by [`~RagRetriever.__call__`].
        doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
            Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
            `question_encoder_last_hidden_state`. If the model has is not initialized with a `retriever` `doc_scores`
            has to be provided to the forward pass. `doc_scores` can be computed via
            `question_encoder_last_hidden_state` and `retrieved_doc_embeds`, see examples for more information.
        output_retrieved (`bool`, *optional*):
            Whether or not to return the `retrieved_doc_embeds`, `retrieved_doc_ids`, `context_input_ids` and
            `context_attention_mask`. See returned tensors for more detail.
        exclude_bos_score (`bool`, *optional*):
            Only relevant if `labels` is passed. If `True`, the score of the BOS token is disregarded when computing
            the loss.
        reduce_loss (`bool`, *optional*):
            Only relevant if `labels` is passed. If `True`, the NLL loss is reduced using the `torch.Tensor.sum`
            operation.
        n_docs (`int`, *optional*):
            The number of documents to retrieve.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, RagRetriever, RagSequenceForGeneration
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/rag-sequence-nq")
        >>> retriever = RagRetriever.from_pretrained(
        ...     "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
        ... )
        >>> # initialize with RagRetriever to do everything in one forward call
        >>> model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

        >>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
        >>> targets = tokenizer(text_target="In Paris, there are 10 million people.", return_tensors="pt")
        >>> input_ids = inputs["input_ids"]
        >>> labels = targets["input_ids"]
        >>> outputs = model(input_ids=input_ids, labels=labels)

        >>> # or use retriever separately
        >>> model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", use_dummy_dataset=True)
        >>> # 1. Encode
        >>> question_hidden_states = model.question_encoder(input_ids)[0]
        >>> # 2. Retrieve
        >>> docs_dict = retriever(input_ids.numpy(), question_hidden_states.detach().numpy(), return_tensors="pt")
        >>> doc_scores = torch.bmm(
        ...     question_hidden_states.unsqueeze(1), docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)
        ... ).squeeze(1)
        >>> # 3. Forward to generator
        >>> outputs = model(
        ...     context_input_ids=docs_dict["context_input_ids"],
        ...     context_attention_mask=docs_dict["context_attention_mask"],
        ...     doc_scores=doc_scores,
        ...     decoder_input_ids=labels,
        ... )
        ```NFrt   ru   rv   rw   rx   r    r!   r   r   ry   rz   r{   r|   r}   )r   epsilonr   r}   r   r   r   r   r    r!   r   r   r"   r#   r$   r%   r&   r'   r(   r)   r*   r4   )r;   r}   r   r   r<   get_nllr   r   label_smoothingr   r   r    r!   r   r   r"   r#   r$   r%   r&   r'   r(   r)   r*   )rr   rt   ru   rv   rw   rx   r   r    r!   r   ry   rz   r{   r|   r   r   r   r}   rX   outputsr   s                        r6   r   z RagSequenceForGeneration.forward  s   N "-4;;3E3E1B1N-TXT_T_TqTq%0%<k$++BYBY ($*!I(()+/#9/#9!+/!5-  
" <<""!'33"3   D ( 

>>
 ))
 $33	

 &77
 $+#A#A
 ")!=!=
 &77
 07/Y/Y
 (/'I'I
 %,$C$C
 -4,S,S
 )0(K(K
 &-%E%E
 )0(K(K
  &-%E%E!
" (/'I'I#
 	
r5   c                 .    | j                   j                  S r   r   rr   s    r6   r?   z"RagSequenceForGeneration.retrievere      xx!!!r5   c                 .    | j                   j                  S r   r<   rM   r   s    r6   rM   z"RagSequenceForGeneration.generatori  r   r5   c                 .    | j                   j                  S r   r<   rL   r   s    r6   rL   z)RagSequenceForGeneration.question_encoderm      xx(((r5   do_deduplicationnum_return_sequences	num_beamsc
                    |	|	n| j                   j                  }	||n| j                   j                  }||n| j                   j                  }||n| j                   j                  }|	|J d       | j
                  || j                  ||      d   }| j                  ||j                         j                  dt        j                        j                         | j                  j                   j                  |	d      d	   }|j                  |      }g }||
d
<   ||
d<   d|
d<   ||j                  d   n|j                  d   |	z  }t        |      D ]T  }|||	z  |dz   |	z   } | j                  j                   |fi |
}|rRt        j"                  t%        |D ci c]  }t'        |j)                               | c}j+                                     }|j                  d   }|$|||dz    j-                  |d      } | ||d      }nq|J d       |J d       |j-                  |d      }|||	z  |dz   |	z   }|j-                  |d      }|||dz   ddf   }|j-                  |d      } | ||||d      }|d    j/                  |      d   }|j1                  ||          W | j3                  || j                   j                  j4                        S c c}w )a  
        Implements RAG sequence "thorough" decoding. Read the [`~generation.GenerationMixin.generate`]` documentation
        for more information on how to set other generate input parameters.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                The sequence used as a prompt for the generation. If `input_ids` is not passed, then
                `context_input_ids` has to be provided.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
                Input IDs post-processed from the retrieved documents and the question encoder input_ids by the
                retriever.
            context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
                Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
                retriever.

                If the model is not initialized with a `retriever` or `input_ids` is not given, `context_input_ids` and
                `context_attention_mask` have to be provided to the forward pass. They are returned by
                [`~RagRetriever.__call__`].
            doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
                Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
                `question_encoder_last_hidden_state`.

                If the model is not initialized with a `retriever` or `input_ids` is not given, `doc_scores` has to be
                provided to the forward pass. `doc_scores` are returned by [`~RagRetriever.__call__`].
            do_deduplication (`bool`, *optional*):
                Whether or not to deduplicate the generations from different context documents for a given input. Has
                to be set to `False` if used while training with distributed backend.
            num_return_sequences(`int`, *optional*, defaults to 1):
                The number of independently computed returned sequences for each element in the batch. Note that this
                is not the value we pass to the `generator`'s `[`~generation.GenerationMixin.generate`]` function,
                where we set `num_return_sequences` to `num_beams`.
            num_beams (`int`, *optional*, defaults to 1):
                Number of beams for beam search. 1 means no beam search.
            n_docs (`int`, *optional*, defaults to `config.n_docs`)
                Number of documents to retrieve and/or number of documents for which to generate an answer.
            kwargs (`dict[str, Any]`, *optional*):
                Additional kwargs will be passed to [`~generation.GenerationMixin.generate`].

        Return:
            `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated
            sequences. The second dimension (sequence length) is either equal to `max_length` or shorter if all batches
            finished early due to the `eos_token_id`.
        Nz= At least one of input_ids or context_input_ids must be givenru   r   r   r   r   r   r    r   r   ru   r   T)r   r   zMake sure that `context_attention_mask` are passed, if no `input_ids` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function.zMake sure that `doc_scores` are passed, if no `input_ids` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function.)r    r!   r   r   r   r   )pad_token_id)r;   r}   r   r   r   r?   rL   r   r   r/   r   r   rM   r   r   rangegeneratestacklistre   tolistvaluesrepeattopkappend_cat_and_padr   )rr   rt   ru   r    r!   r   r   r   r   r}   model_kwargsnum_doc_return_sequencesquestion_hidden_stateshypos
batch_sizeindexgenerator_input_idsoutput_sequencesknum_candidatesnew_input_idsr   individual_input_idsindividual_attention_maskindividual_doc_scorestop_cand_indss                             r6   r   z!RagSequenceForGeneration.generateq  sx   B "-4;;3E3E/?/K+QUQ\Q\QmQm$8$D $++JjJj 	! "+!6IDKK<Q<Q	$(9(E 	
K	
E >>%*;*C%)%:%:9Uc%:%def%g" $&--/22%u}}2U[[]~~,,33# !/ ! "!# !2 4 4Y ?$-[!/8+,)-%&+4+@Y__Q'FWF]F]^_F`djFj
:& 3	:E"3EFNeaiSYEY"Z6t~~66#    #(;;tQa4bAS_a5G4b4i4i4k/l#m -33N
 $ )%%!) < C CNTU V}5EY]^-9 T9 "- J-
 (;'A'A"A($ -C56>UZ]^U^bhTh,i),E,L,L^]^,_)(25EAI3F3I(J%(=(D(D^UV(W%&:+D4+&* &fo-334LMaPM LL)-89g3	:j   T[[5J5J5W5W XXW 5cs   #!K#c                     t        j                  d d dd f   j                  j                  d   d      j	                   j
                  j                  j                        gd      ||n j
                  j                  } j
                  j                  xs   j
                  j                  j                  }|d uxr& d d df   j                  |      j                         }	 fd}
t        j                  j                  |d      j                  |j                  d   |z  |d|j!                  d            }t        j                  j                  |d      j#                  d      j#                  d      }|d d d d d dd d f   }|d d d d ddd d f   }|d d d d dd d d f   }t        j                  |||z   |gd      }j#                  d      j#                  d      j%                  d|dd      j'                         |j'                         k(  sJ |j)                  d      }|j+                  dd	      } |
||      \  }}|r|	r|d d d d dd f   j+                  d      n|j+                  d      }|j+                  d      }|j-                  d      }|j-                  d      }| }| }|r |j+                         }|j+                         }||j!                  d      z  }d
|z
  |z  ||z  z   }|S )Nr   r   c                    j                  j                  j                  j                        }|j	                         r$| j                  |d       |j                  |d       | j                  d      |j                  d      fS N        r   eqr;   rM   r   anymasked_fill_r   ll
smooth_objpad_maskrr   targets      r6   
_mask_padsz4RagSequenceForGeneration.get_nll.<locals>._mask_pads  f    yy!6!6!C!CDH||~#.''#6::b>:#5#5b#999r5   r   r   rE   r   r   Tr   keepdim      ?)r/   catnewr   fill_r;   rM   r   r}   bos_token_idr   allr   
functionallog_softmaxr   sizer   r   r   gathersum	logsumexp)rr   
seq_logitsr   r  r   r   r   r}   r  use_bosr  seq_logprobsdoc_logprobsfirst_token_scoressecond_token_scores	remainderrag_logprobsr   r   nll_losssmooth_losseps_ir   s   `  `                   r6   r   z RagSequenceForGeneration.get_nll  s    AqrE]FJJv||A:@@AVAVAcAcdegh
 "-4;;3E3E {{//U4;;3H3H3U3Ud*Rvad||/L/P/P/R	: }}000DIIQ6)62zr7J
 }}000CMMbQ[[\^_ *!QA+6*1a1a<8 Aqr1-	yy"46IL6XZc!djkl !!!$..r299!VQJzz||//1111  Rv 6!%%"d%;
#B
3J %6'R1ab\a rvvay^^A&
\\!_))!,
3!k||~H%//+K,++B//g)EK,??r5   c           
      x   | d   j                  t        | D cg c]  }|j                  d    c}      t        | D cg c]  }|j                  d    c}            j	                  |      }d}| D ]<  }|||||j                  d   z   d |j                  d   f<   ||j                  d   z  }> |S c c}w c c}w )Nr   r   )r	  r  r   maxr
  )tensorsr   toutputinds        r6   r   z%RagSequenceForGeneration._cat_and_padC  s     AJNN3G<q
<=sX_C`STAGGAJC`?abhhiuv 	  	A;<F3qwwqz))<QWWQZ<781771:C	   =C`s
   B2B7r   NNNNNNNNNNNNNNNNN)	NNNNNNNNN)Fr   FN) r+   r,   r-   r   r   r   r   rm   r   r   r   r/   r2   r   r3   r   r	   r0   r   r   r   r   propertyr?   rM   rL   no_gradr   r   staticmethodr   r   r   s   @r6   r   r     se    .26:/3,0x)*x #?3x O,	x
 L)x:'| '+O +  1515@D8<=A+/8<=A26$(,0/3+/,0&*-1 $%^
E,,-^
 !.^
 "%ell(;"<=	^

 $E$4$45^
 !))9)9 :^
 "%^
 $E$4$45^
 !))9)9 :^
 U../^
 D>^
 $D>^
 'tn^
 #4.^
 $D>^
  d^!^
" ))*#^
$ %^
( 
")^
 ^
@ " " " " ) ) U]]_ 15598<=A26+/.2#' $TYE,,-TY !!1!12TY $E$4$45	TY
 !))9)9 :TY U../TY #4.TY 'smTY C=TY TY 
		TY TYn os9v  r5   r   zo
    A RAG-token model implementation. It performs RAG-token specific marginalization in the forward pass.
    c            &           e Zd Z	 	 	 	 d.dee   dee   dee   dee   f fdZdefdZdefd	Z		 	 	 	 	 	 d/d
Z
ed        Zed        Zed        Zed        Zd0dZe	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 d1deej(                     deej*                     deeeej.                           deej(                     deej0                     dee   deej(                     deej(                     deej*                     dee   dee   dee   dee   dee   dee   deej(                     d ee   d!ef$d"       Z ej<                         dddddddd e        e        f
deej(                     deej(                     deej(                     deej(                     deej*                     d ee   d#ee!   d$ee"eej.                  ge#e   f      d%ee   d&ee    d!ej(                  fd'       Z$d( Z%d) Z&d* Z'd+ Z(d0d,Z)d2d-Z* xZ+S )3RagTokenForGenerationNr;   rL   rM   r?   c                     |||J d       |+t        j                  |j                  |j                  fi |}t        |   |       t        ||||      | _        yr   r   r   s         r6   rm   zRagTokenForGeneration.__init__U  s|      !&6&ByG\ 	
V	
] >FF '')9)9=CF 	  6<LXamvwr5   c                 &    || j                   _        y r   r   r   s     r6   r   z#RagTokenForGeneration.set_retrievers  r   r5   rp   c                 H    d| j                   _        || j                   _        y r   r   r   s     r6   r   z6RagTokenForGeneration.set_context_encoder_for_trainingv  r   r5   c           
      4    ||d d dd f   }d ||||||d|d	S )Nr   T)	rt   rv   r   r!   rw   r   ry   do_marginalizer}   r4   )	rr   rw   r   ru   ry   rv   r   r}   rX   s	            r6   prepare_inputs_for_generationz3RagTokenForGeneration.prepare_inputs_for_generationz  sB     & 1!RS& 9 .$&4!2.""

 
	
r5   c                 .    | j                   j                  S r   r   r   s    r6   r?   zRagTokenForGeneration.retriever  r   r5   c                 .    | j                   j                  S r   r   r   s    r6   rM   zRagTokenForGeneration.generator  r   r5   c                 .    | j                   j                  S r   r   r   s    r6   rL   z&RagTokenForGeneration.question_encoder  r   r5   c                     d d}| D ]  }|t        fd|D              fz  } t        | t              rt        j                  |      }|S )zeReorders cache for generation. BART-inspired but we need to take care of the extra dimension for docsc                     | j                   d   |j                   d   z  } | j                  d|g| j                   dd   } | j                  d|      }  | j                  dg| j                   dd   }|S )Nr   r   r   rE   )r   r   index_select)r   	new_orderr}   results       r6   _reorder_stackedz>RagTokenForGeneration._reorder_cache.<locals>._reorder_stacked  s    "((+yq/AAF.M..r6TM<O<OPQPR<STM)66q)DM']''E]-@-@-DEFMr5   r4   c              3   b   K   | ]&  } |j                  |j                               ( y wr   )r   r   ).0
past_stater8  beam_idxs     r6   	<genexpr>z7RagTokenForGeneration._reorder_cache.<locals>.<genexpr>  s)     pWa&z8;;z?P?P3QRps   ,/)r3   rj   r
   from_legacy_cache)r   r<  reordered_past
layer_pastr8  s    `  @r6   _reorder_cachez$RagTokenForGeneration._reorder_cache  s_    	 ) 	Jpeopp N	
 o':;0BB>RNr5   c                 |   ||n| j                   j                  }t        j                  j	                  |d      j                  |j                  d   |z  |d|j                  d            }t        j                  |d      }||j                  d      j                  d      z   }t        j                  |d      S )Nr   r   r   r   )r;   r}   r   r  r  r   r   r  r/   r   r  )rr   r  r   r}   r  r  log_prob_sums          r6   marginalizez!RagTokenForGeneration.marginalize  s    !-4;;3E3E }}000DIIQ6)62zr7J
 ((;#l&<&<R&@&J&J2&NN|33r5   rt   ru   rv   rw   rx   r   r    r!   r   ry   rz   r{   r|   r.  r   r   r}   r@   c                 t   ||n| j                   j                  }||n| j                   j                  }||n| j                   j                  }|||}d}
| j	                  ||||||||	||
||||      }d}|j
                  }|C|J | j                  |j
                  |j                  ||| j                   j                  |      }|r| j                  ||j                  |      }t        di d|d|d|j                  d|j                  d	|j                  d
|j                  d|j                  d|j                  d|j                   d|j"                  d|j$                  d|j&                  d|j(                  d|j*                  d|j,                  d|j.                  d|j0                  S )a  
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. [`RagConfig`], used to initialize the model, specifies
            which generator to use, it also specifies a compatible generator tokenizer. Use that tokenizer class to
            obtain the indices.

            [What are input IDs?](../glossary#input-ids)
        encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*)
            Tuple consists of (`generator_enc_last_hidden_state`, *optional*: `generator_enc_hidden_states`,
            *optional*: `generator_enc_attentions`). `generator_enc_last_hidden_state` of shape `(batch_size, n_docs *
            sequence_length, hidden_size)` is a sequence of hidden-states at the output of the last layer of the
            generator's encoder.

            Used by the ([`RagModel`]) model during decoding.
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Provide for generation tasks. `None` by default, construct as per instructions for the generator model
            you're using with your RAG instance.
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size,  target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
            Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the
            retriever. If the model was not initialized with a `retriever` ``context_input_ids` has to be provided to
            the forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
        context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`,*optional*, returned when *output_retrieved=True*):
            Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
            retriever. If the model has is not initialized with a `retriever` `context_attention_mask` has to be
            provided to the forward pass. `context_attention_mask` are returned by [`~RagRetriever.__call__`].
        doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
            Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
            `question_encoder_last_hidden_state`. If the model has is not initialized with a `retriever` `doc_scores`
            has to be provided to the forward pass. `doc_scores` can be computed via
            `question_encoder_last_hidden_state` and `retrieved_doc_embeds`, see examples for more information.
        output_retrieved (`bool`, *optional*):
            Whether or not to return the `retrieved_doc_embeds`, `retrieved_doc_ids`, `context_input_ids` and
            `context_attention_mask`. See returned tensors for more detail.
        do_marginalize (`bool`, *optional*):
            If `True`, the logits are marginalized over all documents by making use of
            `torch.nn.functional.log_softmax`.
        reduce_loss (`bool`, *optional*):
            Only relevant if `labels` is passed. If `True`, the NLL loss is reduced using the `torch.Tensor.sum`
            operation.
        n_docs (`int`, *optional*):
            The number of documents to retrieve.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, RagRetriever, RagTokenForGeneration
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/rag-token-nq")
        >>> retriever = RagRetriever.from_pretrained(
        ...     "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
        ... )
        >>> # initialize with RagRetriever to do everything in one forward call
        >>> model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

        >>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
        >>> targets = tokenizer(text_target="In Paris, there are 10 million people.", return_tensors="pt")
        >>> input_ids = inputs["input_ids"]
        >>> labels = targets["input_ids"]
        >>> outputs = model(input_ids=input_ids, labels=labels)

        >>> # or use retriever separately
        >>> model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", use_dummy_dataset=True)
        >>> # 1. Encode
        >>> question_hidden_states = model.question_encoder(input_ids)[0]
        >>> # 2. Retrieve
        >>> docs_dict = retriever(input_ids.numpy(), question_hidden_states.detach().numpy(), return_tensors="pt")
        >>> doc_scores = torch.bmm(
        ...     question_hidden_states.unsqueeze(1), docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)
        ... ).squeeze(1)
        >>> # 3. Forward to generator
        >>> outputs = model(
        ...     context_input_ids=docs_dict["context_input_ids"],
        ...     context_attention_mask=docs_dict["context_attention_mask"],
        ...     doc_scores=doc_scores,
        ...     decoder_input_ids=labels,
        ... )

        >>> # or directly generate
        >>> generated = model.generate(
        ...     context_input_ids=docs_dict["context_input_ids"],
        ...     context_attention_mask=docs_dict["context_attention_mask"],
        ...     doc_scores=doc_scores,
        ... )
        >>> generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True)
        ```NFr   )r   r   r}   r   r   r   r   r    r!   r   r   r"   r#   r$   r%   r&   r'   r(   r)   r*   r4   )r;   r}   r.  r   r<   r   r   r   r   rD  r   r   r    r!   r   r   r"   r#   r$   r%   r&   r'   r(   r)   r*   )rr   rt   ru   rv   rw   rx   r   r    r!   r   ry   rz   r{   r|   r.  r   r   r}   rX   r   r   r   s                         r6   r   zRagTokenForGeneration.forward  s"   ^ "-4;;3E3E+9+E4;;KeKe%0%<k$++BYBY ($*!I(()+/#9/#9!+/!5-  
" $000<<""'33   D %%fg.@.@&IF' 


 ))
 $33	

 &77
 $+#A#A
 ")!=!=
 &77
 07/Y/Y
 (/'I'I
 %,$C$C
 -4,S,S
 )0(K(K
 &-%E%E
 )0(K(K
  &-%E%E!
" (/'I'I#
 	
r5   generation_configprefix_allowed_tokens_fnlogits_processorstopping_criteriac           	         || j                   }t        j                  |      } |j                  d&i |}|j	                  dd      du}| j                  ||       n| j                  j                  | j                  || j                  ||      d   }| j                  ||j                         j                  dt        j                        j                         | j                  j                  j                   d      }|d	   |d
   |d   }}}|j                  |      }|j                  |      }|j                  |      }t        j"                  |j%                  d      |j'                  dd            j)                  d      }|j*                  d   z  dk(  sJ d d|j*                  d    d       |j*                  d   z  | j,                  j                  j/                         } |||d      }t        j0                  |j2                  z  df|j4                  t        j6                  t9        | j;                               j<                        }|j*                  d   }|d   }d'fd	} |||j2                        } |||j2                        |d<   |j?                  |j2                  d      }||d<   ||d<   ||d<   |d<   | jA                  |||||	|j<                        }| jC                  ||
      }| jE                  ||d|j*                  d   |jF                  dz
  |j<                         |j2                  dk(  rA|jH                  dkD  rtK        d|jH                   d        | jL                  |f|||d!dd"|S |j2                  dkD  r<|jH                  |j2                  kD  rtK        d#       | jN                  |f|||d!d$|S tK        d%|j2                         )(a  
        Implements RAG token decoding.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                The sequence used as a prompt for the generation. If `input_ids` is not passed, then
                `context_input_ids` has to be provided.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
                Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the
                retriever.

                If the model has is not initialized with a `retriever`, `context_input_ids` has to be provided to the
                forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
            context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
                Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
                retriever.

                If the model has is not initialized with a `retriever`, `context_input_ids` has to be provided to the
                forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
            doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
                Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
                `question_encoder_last_hidden_state`.

                If the model has is not initialized with a `retriever`, `context_input_ids` has to be provided to the
                forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
            n_docs (`int`, *optional*, defaults to `config.n_docs`)
                Number of documents to retrieve and/or number of documents for which to generate an answer.
            generation_config (`~generation.GenerationConfig`, *optional*):
                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
                passed to generate matching the attributes of `generation_config` will override them. If
                `generation_config` is not provided, the default will be used, which has the following loading
                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
                default values, whose documentation should be checked to parameterize generation.
            prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], list[int]]`, *optional*):
                If provided, this function constraints the beam search to allowed tokens only at each step. If not
                provided no constraint is applied. This function takes 2 arguments `inputs_ids` and the batch ID
                `batch_id`. It has to return a list with the allowed tokens for the next generation step conditioned on
                the previously generated tokens `inputs_ids` and the batch ID `batch_id`. This argument is useful for
                constrained generation conditioned on the prefix, as described in [Autoregressive Entity
                Retrieval](https://huggingface.co/papers/2010.00904).
            logits_processor (`LogitsProcessorList`, *optional*):
                Custom logits processors that complement the default logits processors built from arguments and a
                model's config. If a logit processor is passed that is already created with the arguments or a model's
                config an error is thrown.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                Custom stopping criteria that complement the default stopping criteria built from arguments and a
                model's config. If a stopping criteria is passed that is already created with the arguments or a
                model's config an error is thrown.
            kwargs (`dict[str, Any]`, *optional*):
                Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
                forwarded to the `forward` function of the model.

        Return:
            `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated
            sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches
            finished early due to the `eos_token_id`.
        Nru   r   r   r   r   r   r   r    r!   r   r   rE   r   r   r   T)rt   ru   r   )r   r   r   last_hidden_statec                     | d d d d f   j                  df| j                  dd  z         } | j                  |f| j                  dd  z         } | j                  |z  z  f| j                  dd  z         S )Nr   r   )reshaper   expand)tensorr   r   r}   s     r6   extend_enc_outputz9RagTokenForGeneration.generate.<locals>.extend_enc_output  s    D$M*22J63JV\\Z[Z\M]3]^F]]J	6#BV\\RSRTEU#UVF>>:	#9F#B"Dv||TUTVGW"WXXr5   )r   r   r   rv   r}   )rF  input_ids_seq_lengthencoder_input_idsrG  rH  r   )rF  rI  )assistant_modelr   max_cache_lengthr   z)num_return_sequences has to be 1, but is z when doing greedy search.F)rH  rI  rF  synced_gpusstreamerzA`num_return_sequences` has to be smaller or equal to `num_beams`.)rH  rI  rF  rU  uH   `num_beams` has to be an integer strictly superior to 0 (≥ 1), but is r4   r   )(rF  copydeepcopyupdaterU   _prepare_special_tokensr;   r}   r?   rL   r   r   r/   r   r   rM   r   r   r   r   r   r   r<   get_encoderfullr   decoder_start_token_idlongnext
parametersr   r   _get_logits_processor_get_stopping_criteria_prepare_cache_for_generation
max_lengthr   
ValueError_sample_beam_search)rr   rt   ru   r    r!   r   r}   rF  rG  rH  rI  rX   r   kwargs_has_attention_maskr   outr   encoderrv   rQ  rK  rP  pre_processorprepared_stopping_criteriar   s         `                 @r6   r   zRagTokenForGeneration.generateq  s   b $ $ 6 6 MM*;</(//9&9$0$4$45Et$LTX$X!$$%68QR "-4;;3E3E >>%*;*C%)%:%:9Uc%:%def%g"..&--/22%u}}2U[[]~~,,33# ! C '(,-*+ 8L5 $8#:#:;Q#R  1 4 4Y ?%;%>%>y%I" #9#C#CA#FH\HfHfghjkHlmuuJ "''*V39 	
[\b[c d!''*+1.	
9 ',,Q/69
(($$002!,=NdrvwJJ+555q944**)*11	
	  )r2+,?@	Y "33IUfUpUp!q/@):)D)D0
+,  112C2M2MST1U
 &0\"*9&')?%&!'X22/!5/%=-## 3 
 &*%@%@/CT &A &
" 	**  q).99A=## 	+ 	
 &&!+ 559 ?@Q@f@f?g h& &   4<<!."<"3!   ((1, 558I8S8SS !dee$4$$!."<"3!   Z[l[v[vZwx r5   c                 *    | j                  ||      }|S r   )rA  )rr   r   r<  s      r6   _temporary_reorder_cachez.RagTokenForGeneration._temporary_reorder_cacheH  s     --oxHr5   c                 J    | j                   j                  j                         S r   )r<   rM   get_input_embeddingsr   s    r6   rp  z*RagTokenForGeneration.get_input_embeddingsO  s    xx!!6688r5   c                 J    | j                   j                  j                         S r   )r<   rM   get_output_embeddingsr   s    r6   rr  z+RagTokenForGeneration.get_output_embeddingsR  s    xx!!7799r5   c                 L    | j                   j                  j                  |      S r   )r<   rM   set_output_embeddings)rr   new_embeddingss     r6   rt  z+RagTokenForGeneration.set_output_embeddingsU  s    xx!!77GGr5   c                     || j                   j                  }|j                  |j                        }|ddddf   j	                         |ddddf<   ||dddf<   |S )zCShift input ids one token to the right, and pad with start_token_idNr   r   r   )r;   r]  	new_zerosr   clone)rr   rt   start_token_idshifted_input_idss       r6   shift_tokens_rightz(RagTokenForGeneration.shift_tokens_rightX  sh    !![[??N%//	@#,QV#4#:#:#<!QR% "0!Q$  r5   c                     ||n j                   j                  }t        j                  d d dd f   j	                  j
                  d   d      j                   j                   j                  j                        gd       fd} j                  |||      }j                  d      j                         |j                         k(  sJ |j                  d      }	|j                  dd      }
 ||	|
      \  }	}
|	j                  d      }	|
j                  d      }
|	 }|
 }|r |j                         }|j                         }||j                  d      z  }d|z
  |z  ||z  z   }|S )	Nr   r   c                    j                  j                  j                  j                        }|j	                         r$| j                  |d       |j                  |d       | j                  d      |j                  d      fS r   r   r   s      r6   r  z1RagTokenForGeneration.get_nll.<locals>._mask_padsh  r  r5   r   r  Tr  r  )r;   r}   r/   r  r	  r   r
  rM   r   rD  r   r   r  r  r  )rr   r  r   r  r   r   r}   r  r  r   r   r  r  r  r   s   `  `           r6   r   zRagTokenForGeneration.get_nlla  sa   !-4;;3E3EAqrE]FJJv||A:@@AVAVAcAcdegh
	: ''
JG!!"%zz||//1111  Rv 6!%%"d%;
#B
3JVVAY^^A&
3!k||~H%//+K,++B//g)EK,??r5   r   )NNNNNNr   r$  )Fr   N),r+   r,   r-   r   r   r   r   rm   r   r   r/  r%  r?   rM   rL   r'  rA  rD  r   r/   r2   r0   r3   r   r   r	   r   r   r   r   r&  r   r   r   r   r   r   rn  rp  rr  rt  r{  r   r   r   s   @r6   r)  r)  O  s    .26:/3,0x)*x #?3x O,	x
 L)x<'| '+O + 
: " " " " ) )  *	4  156:@D8<=A+/8<=A26$(,0/3+/)-&*-1 $%j
E,,-j
 !!2!23j
 "%ell(;"<=	j

 $E$4$45j
 !))9)9 :j
 "%j
 $E$4$45j
 !))9)9 :j
 U../j
 D>j
 $D>j
 'tnj
 #4.j
 !j
  d^!j
" ))*#j
$ %j
( 
")j
 j
X U]]_ 15598<=A26 $8<W[:M:O<P<RSE,,-S !!1!12S $E$4$45	S
 !))9)9 :S U../S S $$45S #+8S%,,4Gc4R+S"TS ##67S $$89S 
		S Sl9:H!"r5   r)  )rg   r:   r   r)  ))r.   rW  dataclassesr   typingr   r   r   r/   r   cache_utilsr	   r
   configuration_utilsr   
generationr   r   r   r   modeling_outputsr   modeling_utilsr   utilsr   r   configuration_ragr   retrieval_ragr   
get_loggerr+   loggerr   r8   r:   rg   r   r)  __all__r4   r5   r6   <module>r     s]      ! , ,   5 3 f f + - , ( ' 
		H	% 
XO{ XO XOv UO UO  UOp  Qo Qo Qoh X
! X
 X
v 
m1 m
m` 
o. o
od br5   