from __future__ import annotations

import random
from collections.abc import Iterable
from typing import Any

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import (
    CachedGISTEmbedLoss,
    CachedMultipleNegativesRankingLoss,
    CachedMultipleNegativesSymmetricRankingLoss,
)


def shrink(tensor: Tensor, dim: int) -> Tensor:
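    """
    Truncate ``tensor`` to its first ``dim`` values along the last axis and re-normalize
    it to unit length. For example, a (hypothetical) batch of 768-dimensional embeddings
    can be shrunk to 256 dimensions like so::

        embeddings = torch.randn(32, 768)
        small = shrink(embeddings, 256)  # shape: (32, 256), L2-normalized
    """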
    tensor_dim = tensor.shape[-1]
    if dim > tensor_dim:
        raise ValueError(
            f"Dimension {dim} in matryoshka_dims cannot be greater than the model's embedding dimension: {tensor_dim}"
        )
    tensor = tensor[..., :dim]
    tensor = F.normalize(tensor, p=2, dim=-1)
    return tensor


class ForwardDecorator:
    """
    This decorator is used to cache the output of the Sentence Transformer's forward pass,
    so that it can be shrunk and reused for multiple loss calculations. This prevents the
    model from recalculating the embeddings for each desired Matryoshka dimensionality.

    This decorator is applied to `SentenceTransformer.forward`.
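
    A rough usage sketch, mirroring in simplified form what :class:`MatryoshkaLoss` does
    in its ``forward`` below (``underlying_loss`` stands in for the wrapped loss)::

        decorated = ForwardDecorator(model.forward)
        model.forward = decorated
        for dim in (768, 512, 256):  # largest first, so cached outputs can be truncated further
            decorated.set_dim(dim)
            loss = loss + underlying_loss(sentence_features, labels)
        model.forward = decorated.fn  # restore the original forward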
    """

    def __init__(self, fn) -> None:
        self.fn = fn

        self.dim = None        # Dimensionality to truncate to; set via `set_dim`
        self.cache = []        # Cached forward outputs, one entry per forward call
        self.cache_dim = None  # The (largest) dimensionality that filled the cache
        self.idx = 0           # Index of the next cached output to reuse

    def set_dim(self, dim) -> None:
        self.dim = dim
        # Start reading the cache from the beginning for the new dimensionality
        self.idx = 0

    def __call__(self, features: dict[str, Tensor]) -> dict[str, Tensor]:
        # Growing cache:
        if self.cache_dim is None or self.cache_dim == self.dim:
            output = self.fn(features)
            self.cache.append(output)
            self.cache_dim = self.dim
        # Using cache:
        else:
            output = self.cache[self.idx]
        if "token_embeddings" in output:
            output["token_embeddings"] = shrink(output["token_embeddings"], self.dim)
        output["sentence_embedding"] = shrink(output["sentence_embedding"], self.dim)
        self.idx += 1
        return output


class CachedLossDecorator:
    """
    This decorator is used with the Cached... losses to compute the underlying loss
    for each Matryoshka dimensionality. It shrinks the pre-computed embeddings to each
    desired dimensionality and then passes them to the underlying loss function, once
    per dimensionality.

    This decorator is applied to the `calculate_loss` method of the Cached... losses.
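
    For example, with ``matryoshka_dims=[768, 512]`` and weights ``[1, 1]``, a single call
    roughly computes (``original_calculate_loss`` stands in for the wrapped method, and
    ``shrink`` is applied to every tensor in ``reps``)::

        loss = 1 * original_calculate_loss(shrink(reps, 768))
             + 1 * original_calculate_loss(shrink(reps, 512))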
    """

    def __init__(
        self, fn, matryoshka_dims: list[int], matryoshka_weights: list[float | int], n_dims_per_step: int = -1
    ) -> None:
        self.fn = fn
        self.matryoshka_dims = matryoshka_dims
        self.matryoshka_weights = matryoshka_weights
        self.n_dims_per_step = n_dims_per_step

    def __call__(self, reps: list[list[Tensor]], *args, **kwargs) -> Tensor:
        dim_indices = range(len(self.matryoshka_dims))
        if 0 < self.n_dims_per_step < len(dim_indices):
            dim_indices = random.sample(dim_indices, self.n_dims_per_step)

        loss = 0.0
        for idx in dim_indices:
            dim = self.matryoshka_dims[idx]
            weight = self.matryoshka_weights[idx]

            # `reps` is a list over input columns (e.g. anchors, positives), each containing
            # the pre-computed minibatch embedding tensors from the Cached... loss.
            truncated = [[shrink(r, dim) for r in minibatch] for minibatch in reps]
            compute_gradients = torch.is_grad_enabled()
            # We need to detach the truncated embeddings; otherwise, the first backward pass of
            # the underlying function would clear the computation graph of the embedding truncation.
            if compute_gradients:
                matryoshka_reps = [[r.detach().requires_grad_() for r in minibatch] for minibatch in truncated]
            else:
                matryoshka_reps = truncated
            loss += weight * self.fn(matryoshka_reps, *args, **kwargs)
            # After computing the gradients in minibatches, we need to continue the backward pass
            # through the truncation calculation. The gradients must be multiplied with the weight,
            # because otherwise the matryoshka weights are not considered in the backward pass.
            if compute_gradients:
                for t_minibatch, d_minibatch in zip(truncated, matryoshka_reps):
                    for t, d in zip(t_minibatch, d_minibatch):
                        t.backward(weight * d.grad)
        return loss


class MatryoshkaLoss(nn.Module):
    def __init__(
        self,
        model: SentenceTransformer,
        loss: nn.Module,
        matryoshka_dims: list[int],
        matryoshka_weights: list[float | int] | None = None,
        n_dims_per_step: int = -1,
    ) -> None:
        """
        The MatryoshkaLoss can be seen as a loss *modifier* that allows you to use other loss functions at various
        embedding dimensions. This is useful when you want to train a model where users have the option to lower
        the embedding dimension to speed up their embedding comparisons and lower their storage costs.

        This loss is also compatible with the Cached... losses, which are in-batch negative losses that allow for
        higher batch sizes. Such batch sizes provide more negatives to learn from, often resulting in a stronger model.

        Args:
            model: SentenceTransformer model
            loss: The loss function to be used, e.g.
                :class:`MultipleNegativesRankingLoss`,
                :class:`CoSENTLoss`, etc.
            matryoshka_dims: A list of embedding dimensions to be used
                for the loss function, e.g. [768, 512, 256, 128, 64].
            matryoshka_weights: A list of weights to be used for the
                loss function, e.g. [1, 1, 1, 1, 1]. If None, then the
                weights will be set to 1 for all dimensions.
            n_dims_per_step: The number of dimensions to use per step.
                If -1, then all dimensions are used. If > 0, then a
                random sample of n_dims_per_step dimensions is used per
                step. The default value is -1.

        References:
            - The concept was introduced in this paper: https://arxiv.org/abs/2205.13147
            - `Matryoshka Embeddings <../../../examples/sentence_transformer/training/matryoshka/README.html>`_

        Inputs:
            +---------------------------------------+--------+
            | Texts                                 | Labels |
            +=======================================+========+
            | any                                   | any    |
            +---------------------------------------+--------+

        Relations:
            - :class:`Matryoshka2dLoss` uses this loss in combination with :class:`AdaptiveLayerLoss`, which allows for
                layer reduction for faster inference.

        Example:
            ::

                from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses
                from datasets import Dataset

                model = SentenceTransformer("microsoft/mpnet-base")
                train_dataset = Dataset.from_dict({
                    "anchor": ["It's nice weather outside today.", "He drove to work."],
                    "positive": ["It's so sunny.", "He took the car to the office."],
                })
                loss = losses.MultipleNegativesRankingLoss(model)
                loss = losses.MatryoshkaLoss(model, loss, [768, 512, 256, 128, 64])

                trainer = SentenceTransformerTrainer(
                    model=model,
                    train_dataset=train_dataset,
                    loss=loss,
                )
                trainer.train()
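
            This loss modifier can also be applied on top of one of the Cached losses;
            a sketch of what that may look like::

                loss = losses.CachedMultipleNegativesRankingLoss(model, mini_batch_size=16)
                loss = losses.MatryoshkaLoss(model, loss, [768, 512, 256, 128, 64])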
        """
        super().__init__()
        self.model = model
        self.loss = loss

        if matryoshka_weights is None:
            matryoshka_weights = [1] * len(matryoshka_dims)
        # Sort the dimensions and weights in descending order
        dims_weights = zip(matryoshka_dims, matryoshka_weights)
        self.matryoshka_dims, self.matryoshka_weights = zip(*sorted(dims_weights, key=lambda x: x[0], reverse=True))
        self.n_dims_per_step = n_dims_per_step

        # The Cached... losses require a special treatment as their backward pass is incompatible with the
        # ForwardDecorator approach. Instead, we use a CachedLossDecorator to compute the loss for each
        # Matryoshka dimensionality given pre-computed embeddings passed to `calculate_loss`.
        self.cached_losses = (
            CachedMultipleNegativesRankingLoss,
            CachedGISTEmbedLoss,
            CachedMultipleNegativesSymmetricRankingLoss,
        )
        if isinstance(loss, self.cached_losses):
            loss.calculate_loss = CachedLossDecorator(
                loss.calculate_loss, self.matryoshka_dims, self.matryoshka_weights
            )

    def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
        # For the Cached... losses, the CachedLossDecorator has been applied to the `calculate_loss` method,
        # so we can directly call the loss function.
        if isinstance(self.loss, self.cached_losses):
            return self.loss(sentence_features, labels)

        # Otherwise, we apply the ForwardDecorator to the model's forward pass, which will cache the output
        # embeddings for each Matryoshka dimensionality, allowing it to be reused for the smaller dimensions.
        original_forward = self.model.forward
        try:
            decorated_forward = ForwardDecorator(original_forward)
            self.model.forward = decorated_forward

            dim_indices = range(len(self.matryoshka_dims))
            if 0 < self.n_dims_per_step < len(dim_indices):
                dim_indices = random.sample(dim_indices, self.n_dims_per_step)
                # Process the dimensions from largest to smallest: the cached outputs are
                # truncated in place, so each dimension must not exceed the previous one.
                dim_indices.sort()

            loss = 0.0
            for idx in dim_indices:
                dim = self.matryoshka_dims[idx]
                weight = self.matryoshka_weights[idx]
                decorated_forward.set_dim(dim)
                loss += weight * self.loss(sentence_features, labels)
        finally:
            self.model.forward = original_forward
        return loss

    def get_config_dict(self) -> dict[str, Any]:
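        """Return the configuration of this loss (wrapped loss name, Matryoshka dimensions, weights, and dims per step) as a dictionary."""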
        return {
            "loss": self.loss.__class__.__name__,
            "matryoshka_dims": self.matryoshka_dims,
            "matryoshka_weights": self.matryoshka_weights,
            "n_dims_per_step": self.n_dims_per_step,
        }

    @property
    def citation(self) -> str:
        return """
@misc{kusupati2024matryoshka,
    title={Matryoshka Representation Learning},
    author={Aditya Kusupati and Gantavya Bhatt and Aniket Rege and Matthew Wallingford and Aditya Sinha and Vivek Ramanujan and William Howard-Snyder and Kaifeng Chen and Sham Kakade and Prateek Jain and Ali Farhadi},
    year={2024},
    eprint={2205.13147},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}
"""
