from __future__ import annotations

import json
import os
from pathlib import Path

try:
    from typing import Self
except ImportError:
    from typing_extensions import Self

from torch import Tensor, nn
from transformers.utils import logging

from sentence_transformers.models.InputModule import InputModule
from sentence_transformers.models.Module import Module
from sentence_transformers.util import import_from_string, load_dir_path

# Module-level logger from ``transformers.utils.logging``; used by ``Router.max_seq_length`` to report
# conflicting per-route values or the absence of any settable input module.
logger = logging.get_logger(__name__)


class Router(InputModule):
    # Keyword arguments ``forward`` accepts besides ``features``. ``Router.forward`` itself filters kwargs for each
    # of its sub-modules using their ``forward_kwargs`` attribute in the same way.
    forward_kwargs = {"task"}
    # Attribute names that make up this module's configuration; they are written under the "parameters" key by
    # ``save`` (via ``get_config_dict``, defined on the base class) and passed back to ``__init__`` by ``load``.
    config_keys: list[str] = ["default_route", "allow_empty_key"]
    config_file_name = "router_config.json"

    def __init__(
        self, sub_modules: dict[str, list[Module]], default_route: str | None = None, allow_empty_key: bool = True
    ) -> None:
        r"""
        This module allows creating asymmetric SentenceTransformer models that apply different modules depending on the specified route,
        such as "query" or "document". Especially useful for models that have different encoders for queries and documents.

        Notably, the ``task`` argument of ``model.encode`` can be used to specify which route to use, and
        ``model.encode_query`` and ``model.encode_document`` are shorthands for using ``task="query"`` and
        ``task="document"``, respectively. These methods also optionally apply ``prompts`` specific to queries
        or documents.

        .. note::

            When training models with the :class:`~sentence_transformers.models.Router` module, you must use the
            ``router_mapping`` argument in the :class:`~sentence_transformers.training_args.SentenceTransformerTrainingArguments`
            or :class:`~sentence_transformers.sparse_encoder.training_args.SparseEncoderTrainingArguments` to map the
            training dataset columns to the correct route ("query" or "document"). For example, if your training dataset(s)
            have ``["question", "positive", "negative"]`` columns, then you can use the following mapping::

                args = SparseEncoderTrainingArguments(
                    ...,
                    router_mapping={
                        "question": "query",
                        "positive": "document",
                        "negative": "document",
                    }
                )

            Additionally, it is common to use a different learning rate for the different routes. For this, you should
            use the ``learning_rate_mapping`` argument in the :class:`~sentence_transformers.training_args.SentenceTransformerTrainingArguments`
            or :class:`~sentence_transformers.sparse_encoder.training_args.SparseEncoderTrainingArguments` to map parameter patterns
            to their learning rates. For example, if you want to use a learning rate of ``1e-3`` for a SparseStaticEmbedding module and
            ``2e-5`` for the rest of the model, you can do this::

                args = SparseEncoderTrainingArguments(
                    ...,
                    learning_rate=2e-5,
                    learning_rate_mapping={
                        r"SparseStaticEmbedding\.*": 1e-3,
                    }
                )

        In the below examples, the ``Router`` model is used to create asymmetric models with different encoders for
        queries and documents. In these examples, the "query" route is efficient (e.g., using SparseStaticEmbedding),
        while the "document" route uses a more complex model (e.g. a Transformers module). This allows for efficient
        query encoding while still using a powerful document encoder, but the combinations are not limited to this.

        Example:
            ::

                from sentence_transformers import SentenceTransformer
                from sentence_transformers.models import Router, Normalize

                # Use a regular SentenceTransformer for the document embeddings, and a static embedding model for the query embeddings
                document_embedder = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
                query_embedder = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1")
                router = Router.for_query_document(
                    query_modules=list(query_embedder.children()),
                    document_modules=list(document_embedder.children()),
                )
                normalize = Normalize()

                # Create an asymmetric model with different encoders for queries and documents
                model = SentenceTransformer(
                    modules=[router, normalize],
                )

                # ... requires more training to align the vector spaces

                # Use the query & document routes
                query_embedding = model.encode_query("What is the capital of France?")
                document_embedding = model.encode_document("Paris is the capital of France.")

            ::

                from sentence_transformers.models import Router
                from sentence_transformers.sparse_encoder import SparseEncoder
                from sentence_transformers.sparse_encoder.models import MLMTransformer, SparseStaticEmbedding, SpladePooling

                # Load an asymmetric model with different encoders for queries and documents
                doc_encoder = MLMTransformer("opensearch-project/opensearch-neural-sparse-encoding-doc-v3-distill")
                router = Router.for_query_document(
                    query_modules=[
                        SparseStaticEmbedding.from_json(
                            "opensearch-project/opensearch-neural-sparse-encoding-doc-v3-distill",
                            tokenizer=doc_encoder.tokenizer,
                            frozen=True,
                        ),
                    ],
                    document_modules=[
                        doc_encoder,
                        SpladePooling(pooling_strategy="max", activation_function="log1p_relu"),
                    ],
                )

                model = SparseEncoder(modules=[router], similarity_fn_name="dot")

                query = "What's the weather in ny now?"
                document = "Currently New York is rainy."

                query_embed = model.encode_query(query)
                document_embed = model.encode_document(document)

                sim = model.similarity(query_embed, document_embed)
                print(f"Similarity: {sim}")

                # Visualize top tokens for each text
                top_k = 10
                print(f"Top tokens {top_k} for each text:")

                decoded_query = model.decode(query_embed, top_k=top_k)
                decoded_document = model.decode(document_embed)

                for i in range(min(top_k, len(decoded_query))):
                    query_token, query_score = decoded_query[i]
                    doc_score = next((score for token, score in decoded_document if token == query_token), 0)
                    if doc_score != 0:
                        print(f"Token: {query_token}, Query score: {query_score:.4f}, Document score: {doc_score:.4f}")

                '''
                Similarity: tensor([[11.1105]], device='cuda:0')
                Top tokens 10 for each text:
                Token: ny, Query score: 5.7729, Document score: 0.8049
                Token: weather, Query score: 4.5684, Document score: 0.9710
                Token: now, Query score: 3.5895, Document score: 0.4720
                Token: ?, Query score: 3.3313, Document score: 0.0286
                Token: what, Query score: 2.7699, Document score: 0.0787
                Token: in, Query score: 0.4989, Document score: 0.0417
                '''

        Note:
            These models are not necessarily stronger than non-asymmetric models. Rudimentary experiments indicate
            that non-Router models perform better in many cases.

        Args:
            sub_modules: Mapping of route keys to lists of modules. Each key corresponds to a specific task type,
                often "query" or "document", and the list contains the modules to be applied for that task type.
            default_route: The default route to use if no task type is specified. If None, an exception will be thrown
                if no task type is specified. If ``allow_empty_key`` is True, the first key in sub_modules will be used as
                the default route. Defaults to None.
            allow_empty_key: If True, allows the default route to be set to the first key in `sub_modules` if
                ``default_route`` is None. Defaults to True.

        Raises:
            ValueError: If ``sub_modules`` is None or empty, or if ``default_route`` is not one of its keys.
        """
        super().__init__()
        if not sub_modules:
            raise ValueError("The routes dictionary cannot be empty.")
        if default_route is not None and default_route not in sub_modules:
            raise ValueError(f"Default route '{default_route}' not found in route keys: {list(sub_modules.keys())}")

        # Each route becomes an nn.Sequential so its modules are registered (parameters, .to(), state_dict, ...).
        self.sub_modules = nn.ModuleDict(
            {route_name: nn.Sequential(*modules) for route_name, modules in sub_modules.items()}
        )

        # If allow_empty_key is True, fall back to the first key in sub_modules as the default route.
        if allow_empty_key and default_route is None:
            default_route = next(iter(sub_modules.keys()))
        self.default_route = default_route
        self.allow_empty_key = allow_empty_key

    @classmethod
    def for_query_document(
        cls,
        query_modules: list[Module],
        document_modules: list[Module],
        default_route: str | None = None,
        allow_empty_key: bool = True,
    ) -> Self:
        """
        Creates a Router model specifically for query and document modules, allowing convenient usage via `model.encode_query`
        and `model.encode_document`.

        Args:
            query_modules: List of modules to be applied for the "query" task type.
            document_modules: List of modules to be applied for the "document" task type.
            default_route: The default route to use if no task type is specified. If None (or otherwise falsy),
                the "document" route is used as the default. Defaults to None.
            allow_empty_key: Stored on the Router (and saved with its configuration). Because ``default_route`` always
                resolves to a route here, it does not change which route is used as the default. Defaults to True.

        Returns:
            Router: An instance of the Router model with the specified query and document modules.
        """
        return cls(
            sub_modules={"query": query_modules, "document": document_modules},
            default_route=default_route or "document",
            allow_empty_key=allow_empty_key,
        )

    def _validate_task(self, task: str | None) -> str:
        """
        Check that ``task`` names an existing route and return it.

        Raises:
            ValueError: If ``task`` is None (with a training- or inference-specific hint depending on
                ``self.training``), or if it is not one of the configured routes.
        """
        if task is None:
            if self.training:
                raise ValueError(
                    "You must provide a `router_mapping` argument on the training arguments, "
                    "or set a default route in the `Router` module."
                )
            raise ValueError(
                "You must provide a `task` argument when calling this method, "
                "or set a default route in the `Router` module."
            )
        if task not in self.sub_modules:
            raise ValueError(
                f"No route found for task type '{task}'. Available routes: {list(self.sub_modules.keys())}"
            )
        return task

    def forward(self, features: dict[str, Tensor], task: str | None = None, **kwargs) -> dict[str, Tensor]:
        """
        Pass ``features`` through every module of the selected route, in order.

        The route is chosen from, in priority order: the explicit ``task`` argument, ``features["task"]``
        (set by :meth:`tokenize`), then ``self.default_route``.

        Args:
            features: Feature dictionary to feed into the route's first module.
            task: Optional route name overriding any route recorded in ``features``.
            **kwargs: Extra keyword arguments. Each sub-module only receives the keys listed in its own
                ``forward_kwargs`` attribute (``task`` is always added to the candidates).

        Returns:
            The features returned by the route's last module (the input ``features`` if the route is empty).

        Raises:
            ValueError: If no route can be determined or the route does not exist.
        """
        if task is None:
            task = features.get("task", self.default_route)
        task = self._validate_task(task)

        kwargs["task"] = task
        for module in self.sub_modules[task]:
            # Only pass the keyword arguments this module declares it accepts; modules without
            # ``forward_kwargs`` get none.
            accepted_kwargs = getattr(module, "forward_kwargs", ())
            module_kwargs = {key: value for key, value in kwargs.items() if key in accepted_kwargs}
            features = module(features, **module_kwargs)
        return features

    def get_sentence_embedding_dimension(self) -> int | None:
        """
        Return the embedding dimension reported by the route modules.

        Routes are checked in insertion order. Within a route, modules are searched from last to first, and the
        first module that defines ``get_sentence_embedding_dimension`` answers. Routes that produce different
        dimensions are not reconciled; the first answer wins.

        Returns:
            The reported dimension, or None if no module defines ``get_sentence_embedding_dimension``.
        """
        for route_modules in self.sub_modules.values():
            for module in reversed(route_modules):
                if hasattr(module, "get_sentence_embedding_dimension"):
                    return module.get_sentence_embedding_dimension()
        return None

    def save(self, output_path: str, safe_serialization: bool = True, **kwargs):
        """
        Save every route module into its own subdirectory of ``output_path``, plus a ``router_config.json``.

        Each module is stored under ``{route}_{index}_{ClassName}``. The config file records each module's fully
        qualified class (``types``), the ordered module ids per route (``structure``), and this Router's
        ``parameters`` (from ``get_config_dict``), so that :meth:`load` can rebuild the Router.

        Args:
            output_path: Directory to write into; it is created if missing.
            safe_serialization: Forwarded to each module's ``save``.
            **kwargs: Forwarded to each module's ``save``.
        """
        model_lookup: dict[str, nn.Module] = {}
        model_types: dict[str, str] = {}
        model_structure: dict[str, list[str]] = {}

        for route_name, route_modules in self.sub_modules.items():
            model_structure[route_name] = []
            for module_idx, module in enumerate(route_modules):
                model_id = f"{route_name}_{module_idx}_{type(module).__name__}"
                model_lookup[model_id] = module
                model_types[model_id] = f"{type(module).__module__}.{type(module).__name__}"
                model_structure[route_name].append(model_id)

        # Ensure the root exists even when no module subdirectory gets created (e.g. all routes are empty).
        os.makedirs(output_path, exist_ok=True)
        for model_id, module in model_lookup.items():
            model_path = os.path.join(output_path, model_id)
            os.makedirs(model_path, exist_ok=True)
            try:
                module.save(model_path, safe_serialization=safe_serialization, **kwargs)
            except TypeError:
                # Fallback for legacy modules whose ``save`` does not accept these keyword arguments.
                # NOTE(review): a TypeError raised *inside* a kwargs-aware save also lands here.
                module.save(model_path)

        with open(os.path.join(output_path, self.config_file_name), "w", encoding="utf8") as fOut:
            json.dump(
                {
                    "types": model_types,
                    "structure": model_structure,
                    "parameters": self.get_config_dict(),
                },
                fOut,
                indent=4,
            )

    def tokenize(
        self,
        texts: list[str] | list[tuple[str, str]] | list[dict[str, str]],
        task: str | None = None,
        **kwargs,
    ):
        """
        Tokenize ``texts`` with the input (first) module of the selected route.

        ``texts`` may also be a list of single-key dictionaries such as ``{"query": "..."}``. In that case the
        shared key selects the route (unless ``task`` is given), and the values are what get tokenized.

        Args:
            texts: Texts, text pairs, or ``{route: text}`` dictionaries to tokenize.
            task: Route to use. Defaults to the dictionary key (for dict input) or ``self.default_route``.
            **kwargs: Forwarded to the input module's ``tokenize``.

        Returns:
            The input module's tokenization output, with an added ``"task"`` entry so that :meth:`forward`
            can pick the same route.

        Raises:
            ValueError: If dictionaries use differing keys, or no valid route can be determined.
        """
        # Length check first, so an empty batch does not crash on ``texts[0]``.
        if len(texts) > 0 and isinstance(texts[0], dict):
            # Extract the task type key from the dictionaries
            if task is None:
                tasks = {key for text in texts for key in text}
                if len(tasks) > 1:
                    raise ValueError(
                        "You cannot pass a list of dictionaries with different task types. "
                        "Please ensure all dictionaries have the same task type key, or pass a single `task` argument."
                    )
                task = tasks.pop()

            # Remove dictionary structure
            texts = [text[task] for text in texts]

        if task is None:
            task = self.default_route
        task = self._validate_task(task)

        input_module = self.sub_modules[task][0]
        tokenized = input_module.tokenize(texts, **kwargs)
        tokenized["task"] = task
        return tokenized

    @classmethod
    def load(
        cls,
        model_name_or_path: str,
        subfolder: str = "",
        token: bool | str | None = None,
        cache_folder: str | None = None,
        revision: str | None = None,
        local_files_only: bool = False,
        **kwargs,
    ) -> Self:
        """
        Rebuild a Router saved by :meth:`save`, from a local directory or a Hub repository.

        It reads ``router_config.json``, falling back to the legacy ``config.json``. Each module listed under
        ``types`` is loaded from its own subfolder, and the routes are reassembled following ``structure``.

        Args:
            model_name_or_path: Local path or Hub repository id.
            subfolder: Subfolder within ``model_name_or_path`` that holds the Router.
            token, cache_folder, revision, local_files_only: Forwarded to config and module loading.
            **kwargs: Forwarded to each module's ``load``. Modules that reject these arguments with a TypeError
                are instead loaded from a resolved local directory, without extra arguments.

        Returns:
            The reconstructed Router.
        """
        hub_kwargs = {
            "token": token,
            "cache_folder": cache_folder,
            "revision": revision,
            "local_files_only": local_files_only,
        }
        # Try the official config file first, then fall back to the legacy config file
        config = cls.load_config(model_name_or_path=model_name_or_path, subfolder=subfolder, **hub_kwargs)
        if not config:
            config = cls.load_config(
                model_name_or_path=model_name_or_path, config_filename="config.json", subfolder=subfolder, **hub_kwargs
            )

        modules = {}
        for model_id, model_type in config["types"].items():
            module_class: type[Module] = import_from_string(model_type)
            module_subfolder = Path(subfolder, model_id).as_posix()
            try:
                module = module_class.load(model_name_or_path, subfolder=module_subfolder, **hub_kwargs, **kwargs)
            except TypeError:
                # Legacy modules only accept a local directory path.
                local_path = load_dir_path(
                    model_name_or_path=model_name_or_path, subfolder=module_subfolder, **hub_kwargs
                )
                module = module_class.load(local_path)
            modules[model_id] = module

        model_structure = {
            route_name: [modules[model_id] for model_id in model_ids]
            for route_name, model_ids in config["structure"].items()
        }
        return cls(model_structure, **config["parameters"])

    @property
    def tokenizer(self):
        """
        Return the tokenizer of the first route whose input module has a non-None ``tokenizer``, or None.

        Different routes may each have their own tokenizer, but only one can be returned here.
        """
        for route_modules in self.sub_modules.values():
            if len(route_modules) == 0:
                # An empty route has no input module to ask.
                continue
            route_tokenizer = getattr(route_modules[0], "tokenizer", None)
            if route_tokenizer is not None:
                return route_tokenizer
        return None

    @property
    def max_seq_length(self) -> int | None:
        """
        Return the ``max_seq_length`` of the routes' input modules.

        If the routes disagree, a warning is logged once and the largest value is returned. Returns None
        when no route's input module has a ``max_seq_length`` attribute.
        """
        max_seq_lengths = set()
        for route_modules in self.sub_modules.values():
            # Check emptiness *before* indexing: an empty route has no input module.
            if len(route_modules) > 0 and hasattr(route_modules[0], "max_seq_length"):
                max_seq_lengths.add(route_modules[0].max_seq_length)

        if not max_seq_lengths:
            return None
        if len(max_seq_lengths) == 1:
            # Only one unique max_seq_length
            return next(iter(max_seq_lengths))
        logger.warning_once(f"Different max_seq_lengths detected: {max_seq_lengths}. Using the maximum value.")
        return max(max_seq_lengths)

    @max_seq_length.setter
    def max_seq_length(self, value) -> None:
        """Set ``max_seq_length`` on every route input module that has that attribute; warn if there are none."""
        input_modules = [
            route_modules[0]
            for route_modules in self.sub_modules.values()
            if len(route_modules) > 0 and hasattr(route_modules[0], "max_seq_length")
        ]

        if not input_modules:
            logger.warning("No modules have a max_seq_length attribute to set.")
            return

        for input_module in input_modules:
            input_module.max_seq_length = value


# Backwards compatibility: ``Asym`` is the legacy name of this module. It is an alias to the very same class
# object (not a subclass), so ``isinstance`` checks and class attributes are shared with ``Router``.
Asym = Router
