import copy
import functools
import importlib.metadata
import inspect
import json
import os
from abc import ABC, abstractmethod
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union

import torch
from packaging import version

from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_6

from .configuration_utils import PretrainedConfig
from .utils import is_hqq_available, is_optimum_quanto_available, is_torch_greater_or_equal, logging


if is_hqq_available():
    from hqq.core.quantize import Quantizer as HQQQuantizer


logger = logging.get_logger(__name__)


class CacheLayerMixin(ABC):
    """Base, abstract class for a single layer's cache."""

    is_compileable = False

    def __init__(self):
        self.keys, self.values = None, None

    @abstractmethod
    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        cache_kwargs: Optional[dict[str, Any]] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]: ...

    @abstractmethod
    def get_seq_length(self, cache_position=None) -> int: ...

    @abstractmethod
    def get_max_cache_shape(self) -> int: ...

    @abstractmethod
    def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: ...

    def reset(self) -> None:
        """Resets the cache values while preserving the objects"""
        self.keys.zero_()
        self.values.zero_()

    def reorder_cache(self, beam_idx: torch.LongTensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Reorders this layer's cache for beam search."""
        if self.keys.numel():
            device = self.keys.device
            self.keys = self.keys.index_select(0, beam_idx.to(device))
        if self.values.numel():
            device = self.values.device
            self.values = self.values.index_select(0, beam_idx.to(device))


class DynamicLayer(CacheLayerMixin):
    """
    A cache layer that grows dynamically as more tokens are generated. This is the default for generative models.
    It stores the Key and Value states as tensors with shape `[batch_size, num_heads, seq_len, head_dim]`.

    See `CacheLayerMixin` for details on common methods that are implemented by all cache layers.
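
    Example (a minimal sketch of using the layer directly; shapes are illustrative):

        ```python
        >>> import torch
        >>> layer = DynamicLayer()
        >>> keys = torch.randn(1, 8, 4, 64)  # [batch_size, num_heads, seq_len, head_dim]
        >>> values = torch.randn(1, 8, 4, 64)
        >>> layer.update(keys, values)[0].shape
        torch.Size([1, 8, 4, 64])
        >>> layer.update(torch.randn(1, 8, 1, 64), torch.randn(1, 8, 1, 64))[0].shape
        torch.Size([1, 8, 5, 64])
        ```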
    """

    is_sliding = False

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        cache_kwargs: Optional[dict[str, Any]] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            cache_kwargs (`dict[str, Any]`, *optional*):
                Additional arguments for the cache subclass. No additional arguments are used in `DynamicLayer`.

        Return:
            A tuple containing the updated key and value states.
        """
        if self.keys is None:
            self.keys = key_states
            self.values = value_states
        else:
            self.keys = torch.cat([self.keys, key_states], dim=-2)
            self.values = torch.cat([self.values, value_states], dim=-2)
        return self.keys, self.values

    def get_seq_length(self, cache_position=None) -> int:
        """Returns the sequence length of the cached states."""
        if self.keys is None or self.keys.numel() == 0:
            return 0
        return self.keys.shape[-2]

    def get_max_cache_shape(self) -> int:
        """Returns the maximum sequence length of the cache object. DynamicLayer does not have a maximum length."""
        return -1

    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
        """Reorders the cache for beam search, given the selected beam indices."""
        if self.keys is not None and self.keys.numel():
            self.keys = self.keys.index_select(0, beam_idx.to(self.keys.device))
            self.values = self.values.index_select(0, beam_idx.to(self.values.device))

    def crop(self, max_length: int) -> None:
        """
        Crop the past key/value states to a new `max_length`, expressed in number of tokens. `max_length` can also be
        negative, in which case that many tokens are removed from the end.
        """
        if max_length < 0:
            max_length = self.get_seq_length() - abs(max_length)

        if self.get_seq_length() <= max_length:
            return

        if self.keys is not None and self.keys.numel():
            self.keys = self.keys[..., :max_length, :]
            self.values = self.values[..., :max_length, :]

    def batch_repeat_interleave(self, repeats: int) -> None:
        """Repeat the cache `repeats` times in the batch dimension."""
        if self.keys is not None and self.keys.numel():
            self.keys = self.keys.repeat_interleave(repeats, dim=0)
            self.values = self.values.repeat_interleave(repeats, dim=0)

    def batch_select_indices(self, indices: torch.Tensor) -> None:
        """Only keep the `indices` in the batch dimension of the cache."""
        if self.keys is not None and self.keys.numel():
            self.keys = self.keys[indices, ...]
            self.values = self.values[indices, ...]

    def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
        """Return the length and offset of the cache, used to generate the mask"""
        kv_offset = 0
        query_length = cache_position.shape[0]
        past_seen_tokens = self.get_seq_length()
        kv_length = query_length + past_seen_tokens
        return kv_length, kv_offset

    @classmethod
    def from_tensors(cls, keys: torch.Tensor, values: torch.Tensor) -> "DynamicLayer":
        """
        Build a `DynamicLayer` instance from pre-existing key/value tensors.

        Args:
            keys (`torch.Tensor`):
                Key cache tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``.
            values (`torch.Tensor`):
                Value cache tensor of shape ``[batch_size, num_heads, seq_len, head_dim]``.

        Returns:
            `DynamicLayer`: The newly constructed layer whose internal cache directly references
            the supplied tensors.
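
        Example (a minimal sketch; shapes are illustrative):

            ```python
            >>> import torch
            >>> keys = torch.randn(1, 8, 5, 64)
            >>> values = torch.randn(1, 8, 5, 64)
            >>> layer = DynamicLayer.from_tensors(keys, values)
            >>> layer.get_seq_length()
            5
            ```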
        """
        layer = cls()
        layer.keys = keys
        layer.values = values
        return layer


class StaticLayer(CacheLayerMixin):
    """
    A static cache layer that stores the Key and Value states as static tensors with shape `[batch_size, num_heads, seq_len, head_dim]`.
    It allocates its full backing tensors up-front and mutates them in-place. Built for `torch.compile` support.

    See `CacheLayerMixin` for details on common methods that are implemented by all cache layers.
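
    Example (a minimal sketch of direct layer usage; sizes and shapes are illustrative):

        ```python
        >>> import torch
        >>> layer = StaticLayer(max_cache_len=16, batch_size=1, num_heads=8, head_dim=64)
        >>> keys = torch.randn(1, 8, 4, 64)
        >>> values = torch.randn(1, 8, 4, 64)
        >>> cache_kwargs = {"cache_position": torch.arange(4)}
        >>> layer.update(keys, values, cache_kwargs)[0].shape  # the full pre-allocated buffer is returned
        torch.Size([1, 8, 16, 64])
        ```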
    """

    is_compileable = True
    is_sliding = False

    def __init__(
        self,
        max_cache_len: int,
        batch_size: int,
        num_heads: int,
        head_dim: int,
        dtype: torch.dtype = torch.float32,
        device: str = "cpu",
        sliding_window: Optional[int] = None,
    ):
        """
        Args:
            max_cache_len (`int`):
                Maximum number of tokens that can be stored, used for tensor preallocation.
            batch_size (`int`):
                Maximum batch size the cache is pre-allocated for.
            num_heads (`int`):
                Number of attention heads.
            head_dim (`int`):
                Per-head hidden dimension.
            dtype (`torch.dtype`, defaults to `torch.float32`):
                Data type of the cache tensors.
            device (`str` or `torch.device`, defaults to `"cpu"`):
                Device on which the cache tensors will be materialised.

        Notes:
            The `sliding_window` argument is accepted for signature compatibility with other layer types but is
            unused by `StaticLayer`. Static layers allocate their full backing tensors up-front and mutate them
            in-place. See the documentation of `Cache` for shared helper methods that operate uniformly across
            all layer types.
        """
        self.max_cache_len = max_cache_len
        self.max_batch_size = batch_size
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.dtype = dtype
        self.device = device

        self.keys = torch.zeros(
            (batch_size, num_heads, self.max_cache_len, head_dim),
            dtype=dtype,
            device=device,
        )
        self.values = torch.zeros(
            (batch_size, num_heads, self.max_cache_len, head_dim),
            dtype=dtype,
            device=device,
        )
        # Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
        # preventing compiled graph breaks when updating the cache.
        torch._dynamo.mark_static_address(self.keys)
        torch._dynamo.mark_static_address(self.values)

    def get_max_cache_shape(self) -> int:
        """Return the maximum cache shape of the cache"""
        return self.max_cache_len

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        cache_kwargs: Optional[dict[str, Any]] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Update the static cache tensors in place.

        Args:
            key_states (`torch.Tensor`): The new key states to cache.
            value_states (`torch.Tensor`): The new value states to cache.
            cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.

        Returns:
            tuple[`torch.Tensor`, `torch.Tensor`]: The updated key and value states.
        """
        cache_position = cache_kwargs.get("cache_position") if cache_kwargs else None
        key_states = key_states.to(self.keys.dtype)
        value_states = value_states.to(self.values.dtype)

        # This may be needed if the Layer was not created with the right device in the beginning, i.e. if it did not respect
        # the device_map. However, even if it is the case, this will only run once, because then the new states received
        # will always have the same device
        if self.device != key_states.device:
            self.device = key_states.device
            self.keys = self.keys.to(self.device)
            self.values = self.values.to(self.device)

        if cache_position is None:
            # Prefill phase where seq_len potentially equals max_cache_len. Directly copy.
            self.keys.copy_(key_states)
            self.values.copy_(value_states)
        else:
            # Generation phase. Update specific positions.
            # Use index_copy_ for in-place update (compile-friendly).
            try:
                self.keys.index_copy_(2, cache_position, key_states)
                self.values.index_copy_(2, cache_position, value_states)
            except NotImplementedError:
                # Fallback for devices like MPS where index_copy_ might not be supported.
                self.keys[:, :, cache_position] = key_states
                self.values[:, :, cache_position] = value_states
        return self.keys, self.values

    def get_seq_length(self, cache_position=None) -> int:
        """Returns the sequence length of the cached states."""
        if cache_position is not None:
            return int(cache_position[-1] + 1)
        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
        # limit the check to the first batch member and head dimension.
        seq_length = (self.keys[0, 0].any(dim=-1)).sum() if self.keys is not None else 0
        return seq_length

    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
        """Reorders the cache for beam search, given the selected beam indices."""
        dev = self.keys.device
        beam_idx_dev = beam_idx.to(dev)
        self.keys = self.keys.index_select(0, beam_idx_dev)
        self.values = self.values.index_select(0, beam_idx_dev)

    def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
        """Return the length and offset of the cache, used to generate the attention mask"""
        kv_offset = 0
        kv_length = self.max_cache_len
        return kv_length, kv_offset


class SlidingWindowLayer(StaticLayer):
    """
    A static cache layer that implements sliding window attention caching.

    See `CacheLayerMixin` for details on common methods that are implemented by all cache layers.
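
    Example (a minimal sketch; `cache_position` must be provided on every update):

        ```python
        >>> import torch
        >>> layer = SlidingWindowLayer(sliding_window=4, batch_size=1, num_heads=8, head_dim=64)
        >>> keys = torch.randn(1, 8, 6, 64)  # prompt longer than the window
        >>> values = torch.randn(1, 8, 6, 64)
        >>> k, v = layer.update(keys, values, {"cache_position": torch.arange(6)})
        >>> k.shape, layer.keys.shape  # full states are returned, only the last 4 tokens are cached
        (torch.Size([1, 8, 6, 64]), torch.Size([1, 8, 4, 64]))
        ```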
    """

    is_sliding = True

    def __init__(self, sliding_window, *args, **kwargs):
        """
        Args:
            sliding_window (`int`):
                Size of the sliding window, i.e. the maximum number of past tokens kept in this layer's cache.
        """
        max_cache_len = kwargs.pop("max_cache_len", None)
        max_cache_len = min(sliding_window, max_cache_len) if max_cache_len is not None else sliding_window
        super().__init__(*args, max_cache_len=max_cache_len, **kwargs)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        cache_kwargs: Optional[dict[str, Any]] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Update the sliding window cache tensors in place.

        Args:
            key_states (`torch.Tensor`): The new key states to cache.
            value_states (`torch.Tensor`): The new value states to cache.
            cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.

        Returns:
            tuple[`torch.Tensor`, `torch.Tensor`]: The updated key and value states.
        """
        cache_position = cache_kwargs.get("cache_position") if cache_kwargs else None
        if cache_position is None:
            raise ValueError("`cache_position` must be provided for SlidingWindowLayer.")

        # This may be needed if the Layer was not created with the right device in the beginning, i.e. if it did not respect
        # the device_map. However, even if it is the case, this will only run once, because then the new states received
        # will always have the same device
        if self.device != key_states.device:
            self.device = key_states.device
            self.keys = self.keys.to(self.device)
            self.values = self.values.to(self.device)

        key_states = key_states.to(self.keys.dtype)
        value_states = value_states.to(self.values.dtype)

        # Handle prefill phase when prompt length > sliding_window_size.
        # Note that we store cropped key/value states in the cache but return the full key/value states.
        if cache_position.shape[0] > self.max_cache_len:
            new_k = key_states[:, :, -self.max_cache_len :, :]
            new_v = value_states[:, :, -self.max_cache_len :, :]
            self.keys.copy_(new_k)
            self.values.copy_(new_v)
            return key_states, value_states

        # Sliding window logic for generation phase or prefill < window
        slicing = torch.arange(self.max_cache_len, device=self.device)
        current_seq_len = cache_position[-1] + 1  # Use last position to determine current length
        to_shift = current_seq_len > self.max_cache_len
        indices = (slicing + to_shift.sum()) % self.max_cache_len

        k_out_shifted = self.keys[:, :, indices]
        v_out_shifted = self.values[:, :, indices]

        # Clamp cache_position to determine the *target index* within the shifted cache view
        update_position = cache_position.clamp(min=0, max=self.max_cache_len - 1)

        try:
            k_out_updated = k_out_shifted.index_copy(2, update_position, key_states)
            v_out_updated = v_out_shifted.index_copy(2, update_position, value_states)
        except NotImplementedError:
            # Fallback for MPS: clone and modify the clone
            k_out_updated = k_out_shifted.clone()
            v_out_updated = v_out_shifted.clone()
            k_out_updated[:, :, update_position] = key_states
            v_out_updated[:, :, update_position] = value_states

        self.keys.copy_(k_out_updated)
        self.values.copy_(v_out_updated)
        return self.keys, self.values

    def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
        """Return the length and offset of the cache, used to generate the attention mask"""
        query_length = cache_position.shape[0]
        first_cache_position = cache_position[0]

        kv_offset = torch.clamp(first_cache_position - self.max_cache_len + 1, min=0)
        # This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns
        kv_length = max(query_length, self.max_cache_len)
        return kv_length, kv_offset


class ChunkedSlidingLayer(SlidingWindowLayer):
    """
    An extended SlidingWindowLayer that supports prefill chunking, originally implemented for Llama 4.

    See `SlidingWindowLayer` for details on common methods that are implemented by all cache layers.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cumulative_length = 0

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        cache_kwargs: Optional[dict[str, Any]] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        cache_position = cache_kwargs.get("cache_position") if cache_kwargs else None
        if cache_position is None:
            raise ValueError("`cache_position` must be provided for ChunkedSlidingLayer.")

        # This may be needed if the Layer was not created with the right device in the beginning, i.e. if it did not respect
        # the device_map. However, even if it is the case, this will only run once, because then the new states received
        # will always have the same device
        if self.device != key_states.device:
            self.device = key_states.device
            self.keys = self.keys.to(self.device)
            self.values = self.values.to(self.device)

        cumulative_length = self.cumulative_length
        self.cumulative_length += key_states.shape[-2]
        is_full = cumulative_length >= self.max_cache_len

        if is_full:
            full_key_states = torch.cat((self.keys[:, :, 1:, :], key_states), dim=-2)
            full_value_states = torch.cat((self.values[:, :, 1:, :], value_states), dim=-2)
            # Fast decoding path -> here as the effective size is still the sliding window, it is extremely important
            # to return `self.keys` and `self.values`, as they have a fixed address in memory
            # (the values are the same as the full states, but not the address!)
            if key_states.shape[-2] == 1:
                self.keys.copy_(full_key_states)
                self.values.copy_(full_value_states)
                return self.keys, self.values
        elif cumulative_length + key_states.shape[-2] > self.max_cache_len:
            if cumulative_length == 0:
                full_key_states = key_states
                full_value_states = value_states
            else:
                full_key_states = torch.cat((self.keys[:, :, :cumulative_length, :], key_states), dim=-2)
                full_value_states = torch.cat((self.values[:, :, :cumulative_length, :], value_states), dim=-2)
        else:
            try:
                self.keys.index_copy_(2, cache_position, key_states)
                self.values.index_copy_(2, cache_position, value_states)
            except NotImplementedError:
                self.keys[:, :, cache_position] = key_states
                self.values[:, :, cache_position] = value_states
            return self.keys, self.values

        self.keys.copy_(full_key_states[:, :, -self.max_cache_len :, :])
        self.values.copy_(full_value_states[:, :, -self.max_cache_len :, :])
        return full_key_states, full_value_states

    def reset(self) -> None:
        super().reset()
        self.cumulative_length = 0

    def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
        query_length = cache_position.shape[0]
        first_cache_position = cache_position[0]
        sliding_window = self.max_cache_len

        kv_offset = torch.clamp(first_cache_position - sliding_window + 1, min=0)
        # This is the true general case for any Cache using local attention (sliding or chunked)
        if first_cache_position >= sliding_window:
            # Here the Cache is already full
            kv_length = sliding_window + query_length - 1
        elif first_cache_position < sliding_window and first_cache_position + query_length > sliding_window:
            # Here the Cache becomes full with the new input
            kv_length = first_cache_position + query_length
        else:
            # Here the Cache is still smaller than the local size, but we return the local size as it's static
            kv_length = sliding_window
        return kv_length, kv_offset


class CacheProcessor:
    """
    Base class for cache processors. It defines a pre-update and post-update methods that are called before and after the cache update.
    This class should be subclassed.
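
    Example (a minimal, hypothetical subclass, shown only to illustrate the hooks):

        ```python
        >>> class LoggingCacheProcessor(CacheProcessor):
        ...     def __init__(self, cache, **kwargs):
        ...         pass  # this toy processor needs no compatibility checks
        ...
        ...     def pre_update(self, cache, key_states, value_states, layer_idx, cache_kwargs=None):
        ...         print(f"layer {layer_idx}: caching {key_states.shape[-2]} new token(s)")
        ...         return key_states, value_states
        ```

    A processor class like this can then be attached to a cache through the `cache_processor` argument of `Cache`.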
    """

    def __init__(self, cache: "Cache", **kwargs) -> None:
        """
        Initialize the processor and perform compatibility checks with the cache.

        Args:
            cache (`Cache`): The cache instance this processor will be applied to.
            **kwargs: Additional arguments that may be needed for initialization.
        """
        raise NotImplementedError(f"Make sure to implement `init` in {self.__class__.__name__}.")

    def pre_update(
        self,
        cache: "Cache",
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[dict[str, Any]] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Function called before the cache update. Can modify the key/value states.

        Args:
            cache (`Cache`): The cache instance.
            key_states (`torch.Tensor`): The new key states to cache.
            value_states (`torch.Tensor`): The new value states to cache.
            layer_idx (`int`): The index of the layer to cache the states for.
            cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.

        Returns:
            The modified key and value states.
        """
        return key_states, value_states

    def post_update(
        self,
        cache: "Cache",
        key_tensors: torch.Tensor,
        value_tensors: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[dict[str, Any]] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Function called after the cache update. Can process the cached data.

        Args:
            cache (`Cache`): The cache instance.
            key_tensors (`torch.Tensor`): The key states that were cached.
            value_tensors (`torch.Tensor`): The value states that were cached.
            layer_idx (`int`): The index of the layer that was updated.
            cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.

        Returns:
            The final key and value states to return to the model.
        """
        return key_tensors, value_tensors


class OffloadedCacheProcessor(CacheProcessor):
    """
    A cache processor that offloads cache tensors to conserve accelerator memory.

    This processor manages moving cache tensors between accelerator and CPU memory,
    using asynchronous prefetching to minimize performance impact. Works with both
    dynamic and static layers.
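
    Example (a sketch; assumes a CUDA device and uses the `"offloaded"` processor alias accepted by `Cache`):

        ```python
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct", device_map="cuda")
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
        >>> inputs = tokenizer("Hello", return_tensors="pt").to(model.device)

        >>> past_key_values = DynamicCache(cache_processor="offloaded")
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        ```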
    """

    def __init__(self, cache: "Cache", offload_device: Union[str, torch.device] = "cpu", **kwargs):
        """Initialize the offload processor and check device compatibility."""
        self.offload_device = torch.device(offload_device)
        self.original_device = []
        self.prefetch_stream = None
        self.beam_idx = None

        if not (
            torch.cuda.is_available()
            or (is_torch_greater_or_equal("2.7", accept_dev=True) and torch.xpu.is_available())
        ):
            raise RuntimeError(
                "OffloadedCacheProcessor can only be used with a GPU"
                + (" or XPU" if is_torch_greater_or_equal("2.7", accept_dev=True) else "")
            )

        self.is_static = any(isinstance(layer, StaticLayer) for layer in cache.layers)
        if self.is_static:
            for i, layer in enumerate(cache.layers):
                device = cache.layer_init_kwargs["device"] if i == 0 else self.offload_device
                layer.keys = layer.keys.to(device)
                layer.values = layer.values.to(device)
                self.original_device.append(cache.layer_init_kwargs["device"])
            if len(cache) != cache.num_hidden_layers:
                raise ValueError("If static layers are used, all cache layers must be initialized")

        self.prefetch_stream = (
            torch.Stream() if is_torch_greater_or_equal("2.7", accept_dev=True) else torch.cuda.Stream()
        )

    def pre_update(
        self,
        cache: "Cache",
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[dict[str, Any]] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Handles prefetching and eviction before cache update."""
        # Update the cache
        if len(cache) < layer_idx:
            raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.")
        elif len(cache) == layer_idx:
            self.original_device.append(key_states.device)
            self._evict_previous_layer(cache, layer_idx)
        else:
            # Wait for the previous layer to be evicted (on default stream)
            if is_torch_greater_or_equal("2.7", accept_dev=True):
                torch.accelerator.current_stream().synchronize()
            else:
                torch.cuda.current_stream().synchronize()
            self._evict_previous_layer(cache, layer_idx)
            self._ensure_layer_on_device(cache, layer_idx)

            # Prefetch the next layer
            self._prefetch_layer(cache, (layer_idx + 1) % len(cache))
        return key_states, value_states

    def _prefetch_layer(self, cache: "Cache", layer_idx: int):
        """Starts prefetching the next layer cache."""
        if layer_idx < len(cache):
            with (
                self.prefetch_stream
                if is_torch_greater_or_equal("2.7", accept_dev=True)
                else torch.cuda.stream(self.prefetch_stream)
            ):
                # Prefetch next layer tensors to GPU
                device = self.original_device[layer_idx]
                cache.layers[layer_idx].keys = cache.layers[layer_idx].keys.to(device, non_blocking=True)
                cache.layers[layer_idx].values = cache.layers[layer_idx].values.to(device, non_blocking=True)

    def _evict_previous_layer(self, cache: "Cache", layer_idx: int):
        """Moves the previous layer cache to the CPU."""
        if len(cache) >= 2:  # Layer 0 stays on device to be on-device after all layers are created
            # We do it on the default stream so it occurs after all earlier computations on these tensors are done
            prev_layer_idx = (layer_idx - 1) % len(cache)
            cache.layers[prev_layer_idx].keys = cache.layers[prev_layer_idx].keys.to(
                self.offload_device, non_blocking=True
            )
            cache.layers[prev_layer_idx].values = cache.layers[prev_layer_idx].values.to(
                self.offload_device, non_blocking=True
            )

    def _ensure_layer_on_device(self, cache: "Cache", layer_idx: int):
        """Ensures the current layer is on the original device."""
        if layer_idx < len(cache):
            # Wait for the previous prefetch to be done
            self.prefetch_stream.synchronize()

            # Handle delayed beam search operations
            if self.beam_idx is not None:
                self.beam_idx = self.beam_idx.to(self.original_device[layer_idx])
                cache.layers[layer_idx].keys = cache.layers[layer_idx].keys.index_select(0, self.beam_idx)
                cache.layers[layer_idx].values = cache.layers[layer_idx].values.index_select(0, self.beam_idx)


class QuantizedCacheProcessor(CacheProcessor):
    """
    A cache processor that applies quantization to cache tensors to reduce memory usage.

    This processor quantizes cache tensors after they are stored, maintaining a residual
    length in original precision and quantizing older tokens.
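
    Example (a sketch; uses the `"quanto_quantized"` processor alias accepted by `Cache`, which requires `optimum-quanto`):

        ```python
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
        >>> inputs = tokenizer("Hello", return_tensors="pt")

        >>> past_key_values = DynamicCache(cache_processor="quanto_quantized")
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        ```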
    """

    def __init__(
        self,
        cache: "Cache",
        backend: str = "quanto",
        nbits: int = 4,
        axis_key: int = 0,
        axis_value: int = 0,
        q_group_size: int = 64,
        residual_length: int = 128,
        compute_dtype: torch.dtype = torch.float16,
        device: str = "cpu",
    ):
        """
        Parameters:
            backend (`str`, defaults to `"quanto"`):
                Backend to use when performing quantization. Can be one of [`quanto`, `HQQ`].
            nbits (`int`, defaults to 4):
                Number of bits, can be 2 or 4 for the `quanto` backend and one of [1, 2, 3, 4, 8] for the `HQQ` backend.
            axis_key (`int`, defaults to 0):
                Axis over which to perform grouping for the key tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend.
            axis_value (`int`, defaults to 0):
                Axis over which to perform grouping for the value tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend.
            q_group_size (`int`, defaults to 64):
                Size of the quantization group, should be a divisor of the model's hidden dimension.
            residual_length (`int`, defaults to 128):
                Length of the residual cache which will always be stored in original precision.
            compute_dtype (`torch.dtype`, defaults to `torch.float16`):
                The default dtype used for computations in the model. Keys and Values will be cast to this dtype after dequantization.
            device (`str`, defaults to `"cpu"`):
                Device on which to perform computations, should be same as the model's device.
        """
        self.backend = backend
        self.nbits = nbits
        self.axis_key = axis_key
        self.axis_value = axis_value
        self.q_group_size = q_group_size
        self.residual_length = residual_length
        self.compute_dtype = compute_dtype
        self.device = device
        self._quantized_keys: list[torch.Tensor] = []
        self._quantized_values: list[torch.Tensor] = []

        self.validate()
        self.erased_length = 0

        # Only compatible with DynamicCache
        if not isinstance(cache.layers[0], DynamicLayer):
            raise ValueError("QuantizedCacheProcessor is only compatible with DynamicCache")

    def validate(self):
        """Validates if the arguments passed are correct"""

        incorrect_arg_msg = (
            "Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value}` "
            "but found {found_value}"
        )
        # Check that the values are reasonable in general (nbits, axis)
        # Later in QuantizedCache init we check if they are supported for that particular backend
        if self.nbits not in [1, 2, 3, 4, 8]:
            raise ValueError(
                incorrect_arg_msg.format(
                    key="nbits",
                    correct_value="2 or 4 or 8",
                    found_value=self.nbits,
                ),
            )
        if self.q_group_size <= 0:
            raise ValueError(
                incorrect_arg_msg.format(
                    key="q_group_size",
                    correct_value="a positive integer",
                    found_value=self.q_group_size,
                ),
            )
        if self.residual_length < 0:
            raise ValueError(
                incorrect_arg_msg.format(
                    key="residual_length",
                    correct_value="a positive integer",
                    found_value=self.residual_length,
                ),
            )

        if self.axis_key not in [0, 1, -1]:
            raise ValueError(
                incorrect_arg_msg.format(
                    key="axis_key",
                    correct_value="`1` or `0`, `-1`",
                    found_value=self.axis_key,
                ),
            )

        if self.axis_value not in [0, 1, -1]:
            raise ValueError(
                incorrect_arg_msg.format(
                    key="axis_value",
                    correct_value="`1` or `0` or `-1`",
                    found_value=self.axis_value,
                ),
            )

    def post_update(
        self,
        cache: "Cache",
        key_tensors: torch.Tensor,
        value_tensors: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[dict[str, Any]] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply quantization after cache update."""

        if len(cache) < layer_idx:
            raise ValueError("QuantizedCache does not support model usage where layers are skipped. Use DynamicCache.")

        # `key_tensors` is the content of the residual cache, after having been updated by DynamicLayer
        # On the first forward pass, we quantize the whole prompt (prefill, quantize_length=0)
        # On subsequent passes, we accumulate the tokens in the residual cache and quantize when it is full.
        if self._is_quantized_length_zero(layer_idx):
            self._quantized_keys.append(self._quantize(key_tensors.contiguous(), axis=self.axis_key))
            self._quantized_values.append(self._quantize(value_tensors.contiguous(), axis=self.axis_value))

            # Clear the residual cache
            self.erased_length = key_tensors.shape[-2]
            cache.layers[layer_idx].keys = torch.zeros(
                0,
                dtype=key_tensors.dtype,
                device=key_tensors.device,
            )
            cache.layers[layer_idx].values = torch.zeros(
                0,
                dtype=value_tensors.dtype,
                device=value_tensors.device,
            )
            # On prefill, we return the original prompt
            keys_to_return, values_to_return = key_tensors, value_tensors

        else:
            # Prepend the previously quantized cache
            dequant_key = self._dequantize(self._quantized_keys[layer_idx])
            dequant_value = self._dequantize(self._quantized_values[layer_idx])
            keys_to_return = torch.cat([dequant_key, key_tensors], dim=-2)
            values_to_return = torch.cat([dequant_value, value_tensors], dim=-2)
            if key_tensors.shape[-2] >= self.residual_length:
                # Quantize and store
                self._quantized_keys[layer_idx] = self._quantize(keys_to_return.contiguous(), axis=self.axis_key)
                self._quantized_values[layer_idx] = self._quantize(values_to_return.contiguous(), axis=self.axis_value)

                # Clear the residual cache
                self.erased_length += key_tensors.shape[-2]
                cache.layers[layer_idx].keys = torch.zeros(
                    0,
                    dtype=key_tensors.dtype,
                    device=key_tensors.device,
                )
                cache.layers[layer_idx].values = torch.zeros(
                    0,
                    dtype=value_tensors.dtype,
                    device=value_tensors.device,
                )

        return keys_to_return, values_to_return

    def _quantize(self, tensor: torch.Tensor, axis: int) -> torch.Tensor:
        """Quantize a tensor - to be implemented by specific quantization backends."""
        raise NotImplementedError("Quantization backend must implement _quantize method")

    def _dequantize(self, tensor: torch.Tensor) -> torch.Tensor:
        """Dequantize a tensor - to be implemented by specific quantization backends."""
        raise NotImplementedError("Quantization backend must implement _dequantize method")

    def _is_quantized_length_zero(self, layer_idx: int) -> bool:
        """Check if quantized cache is empty for layer. Note: shape[-2] is unreliable since quantized tensors are bit-packed and flattened."""
        return layer_idx >= len(self._quantized_keys)


class QuantoQuantizedCacheProcessor(QuantizedCacheProcessor):
    """
    Quantized cache processor that uses `quanto` as a backend to perform quantization.
    Current implementation supports `int2` and `int4` dtypes only.
    """

    def __init__(
        self,
        cache: "Cache",
        backend: str = "quanto",
        nbits: int = 4,
        axis_key: int = 0,
        axis_value: int = 0,
        q_group_size: int = 64,
        residual_length: int = 128,
        compute_dtype: torch.dtype = torch.float16,
        device: str = "cpu",
    ) -> None:
        """Initialize the quanto quantization processor."""
        super().__init__(
            cache, backend, nbits, axis_key, axis_value, q_group_size, residual_length, compute_dtype, device
        )

        if backend != "quanto":
            raise ValueError(f"QuantoQuantizedCacheProcessor only supports `quanto` backend, but got {backend}")

        if is_optimum_quanto_available():
            optimum_quanto_version = version.parse(importlib.metadata.version("optimum-quanto"))
            if optimum_quanto_version <= version.parse("0.2.5"):
                raise ImportError(
                    f"You need the `optimum-quanto` package version to be greater than or equal to 0.2.5 to use `QuantoQuantizedCacheProcessor`. Detected version {optimum_quanto_version}."
                )
            from optimum.quanto import MaxOptimizer, qint2, qint4

        if self.nbits not in [2, 4]:
            raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}")

        if self.axis_key not in [0, -1]:
            raise ValueError(f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}")

        if self.axis_value not in [0, -1]:
            raise ValueError(
                f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}"
            )

        self.qtype = qint4 if self.nbits == 4 else qint2
        self.optimizer = MaxOptimizer()

    def _quantize(self, tensor: torch.Tensor, axis: int) -> torch.Tensor:
        """Quantize tensor using quanto backend."""
        if is_optimum_quanto_available():
            from optimum.quanto import quantize_weight

            scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size)
            qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size)
            return qtensor

    def _dequantize(self, qtensor: torch.Tensor) -> torch.Tensor:
        """Dequantize tensor using quanto backend."""
        return qtensor.dequantize()


class HQQQuantizedCacheProcessor(QuantizedCacheProcessor):
    """
    Quantized cache processor that uses `HQQ` as a backend to perform quantization.
    Current implementation supports `int2`, `int4`, `int8` dtypes.
    """

    def __init__(
        self,
        cache: "Cache",
        backend: str = "quanto",
        nbits: int = 4,
        axis_key: int = 0,
        axis_value: int = 0,
        q_group_size: int = 64,
        residual_length: int = 128,
        compute_dtype: torch.dtype = torch.float16,
        device: str = "cpu",
    ) -> None:
        """Initialize the HQQ quantization processor."""
        super().__init__(
            cache, backend, nbits, axis_key, axis_value, q_group_size, residual_length, compute_dtype, device
        )

        if backend != "quanto":
            raise ValueError(f"HQQQuantizedCacheProcessor only supports `quanto` backend, but got {backend}")

        if self.nbits not in [1, 2, 3, 4, 8]:
            raise ValueError(
                f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}"
            )

        if self.axis_key not in [0, 1]:
            raise ValueError(f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_key}")

        if self.axis_value not in [0, 1]:
            raise ValueError(f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_value}")

        self.quantizer = HQQQuantizer

    def _quantize(self, tensor: torch.Tensor, axis: int) -> tuple[torch.Tensor, dict]:
        """Quantize tensor using HQQ backend."""
        qtensor, meta = self.quantizer.quantize(
            tensor,
            axis=axis,
            device=self.device,
            compute_dtype=self.compute_dtype,
            nbits=self.nbits,
            group_size=self.q_group_size,
        )
        meta["compute_dtype"] = self.compute_dtype
        self.quantizer.cuda(qtensor, meta=meta, device=self.device)  # Move to device and cast to dtype
        meta["scale"] = meta["scale"].to(qtensor.device)
        meta["zero"] = meta["zero"].to(qtensor.device)
        return qtensor, meta

    def _dequantize(self, qtensor_and_meta: tuple[torch.Tensor, dict]) -> torch.Tensor:
        """Dequantize tensor using HQQ backend."""
        quant_tensor, meta = qtensor_and_meta
        tensor = self.quantizer.dequantize(quant_tensor, meta)
        return tensor


def apply_processors(
    fn: Callable[..., tuple[torch.Tensor, torch.Tensor]],
) -> Callable[..., tuple[torch.Tensor, torch.Tensor]]:
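    """
    Decorator for `Cache.update`: if a `CacheProcessor` is attached to the cache, its `pre_update` hook is applied
    to the incoming key/value states and its `post_update` hook to the tensors returned by the wrapped update.
    """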
    @functools.wraps(fn)
    def _wrapped_update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[dict[str, Any]] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Wrapper around the update method to apply cache processors.
        """
        if self.cache_processor is not None:
            key_states, value_states = self.cache_processor.pre_update(
                self, key_states, value_states, layer_idx, cache_kwargs
            )

        key_tensors, value_tensors = fn(self, key_states, value_states, layer_idx, cache_kwargs)

        if self.cache_processor is not None:
            key_tensors, value_tensors = self.cache_processor.post_update(
                self, key_tensors, value_tensors, layer_idx, cache_kwargs
            )

        return key_tensors, value_tensors

    return _wrapped_update


class KeyValuesWrapper:
    """Helper class for Cache that simulates layer-indexed key/value lists from a layered cache.
    This allows for BC access and writing, e.g., cache.key_cache[idx] = ...
    Deprecated in favor of Cache.layers[idx].keys/values. TODO: remove in v4.56.0"""

    def __init__(self, layers, cache_type="keys"):
        self.layers = layers
        self.cache_type = cache_type

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            return [getattr(layer, self.cache_type) for layer in self.layers[idx]]
        return getattr(self.layers[idx], self.cache_type)

    def __setitem__(self, idx, value):
        if isinstance(idx, slice):
            for layer, val in zip(self.layers[idx], value):
                setattr(layer, self.cache_type, val)
        else:
            setattr(self.layers[idx], self.cache_type, value)

    def __len__(self):
        return len(self.layers)

    def __iter__(self):
        for layer in self.layers:
            yield getattr(layer, self.cache_type)

    def __bool__(self):
        return bool(self.layers)


class Cache:
    """
    Base container for per-layer key/value caches.

    A `Cache` behaves like a list of `CacheLayerMixin` objects, one per model layer.
    Sub-classes such as `DynamicCache`, `StaticCache`, or `SlidingWindowCache`
    simply pre-select which `CacheLayerMixin` class to use and may attach a
    `CacheProcessor` (off-loading, quantization).

    Example:

        ```python
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

        >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
        >>> inputs = tokenizer("Hello", return_tensors="pt")

        >>> cache = DynamicCache()
        >>> outputs = model(**inputs, past_key_values=cache, use_cache=True)
        ```

    Parameters:
        layer_classes (`type[CacheLayerMixin]` or `list[type[CacheLayerMixin]]`):
            A list of `CacheLayerMixin` classes to instantiate for the cache. If only a `CacheLayerMixin` class is
            provided, then it is used for all layers.
        config (`PretrainedConfig`, *optional*):
            Model configuration used to infer number of layers, head sizes, default
            device/dtype, etc.
        cache_processor (`CacheProcessor` or `str`, *optional*):
            Cache processor to apply (e.g., "offloaded", "quanto_quantized", "hqq_quantized")
            or a CacheProcessor class.
        max_batch_size (`int`, *optional*): Maximum batch size for static caches.
        max_cache_len (`int`, *optional*): Maximum sequence length. For hybrid caches, SlidingWindowLayers are
            clamped to `min(sliding_window, max_cache_len)`, StaticLayers use full `max_cache_len`.
        device (`torch.device`, *optional*): Device for cache tensors.
        dtype (`torch.dtype`, *optional*): Data type for cache tensors.
        layer_device_map (`dict[int, Union[str, torch.device]]`, *optional*): Per-layer device mapping.
        tp_size (`int`, *optional*): Tensor parallel size to adjust the number of key/value heads.

    Additional keyword arguments are forwarded to the chosen layer constructor(s) and cache processor. See the
    documentation of the relevant `CacheLayerMixin` class and `CacheProcessor` class for more details.
    """

    def __init__(
        self,
        layer_classes: Union[list[type[CacheLayerMixin]], type[CacheLayerMixin]],
        config: Optional[PretrainedConfig] = None,
        cache_processor: Optional[Union[str, type[CacheProcessor]]] = None,
        max_batch_size: Optional[int] = None,
        max_cache_len: Optional[int] = None,
        device: Union[torch.device, str, None] = None,
        dtype: Optional[torch.dtype] = None,
        layer_device_map: Optional[dict[int, torch.device]] = None,
        tp_size: Optional[int] = None,
        **kwargs,
    ):
        self.layers: list[CacheLayerMixin] = []
        self.layer_classes = layer_classes

        processor_class = PROCESSOR_CLASS_MAP[cache_processor] if isinstance(cache_processor, str) else cache_processor
        kwargs.update(
            max_batch_size=max_batch_size,
            max_cache_len=max_cache_len,
            device=device,
            dtype=dtype,
            layer_device_map=layer_device_map,
            tp_size=tp_size,
        )
        processor_kwargs, kwargs = parse_processor_args(processor_class, kwargs)

        self.layer_init_kwargs = parse_layer_args_from_model_config(config, **kwargs)
        self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)

        self.append_new_layers(self.num_hidden_layers - 1)
        self.cache_processor = processor_class(self, **processor_kwargs) if processor_class is not None else None

    def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
        sequence length.
        """
        if layer_idx < len(self.layers):
            return self.layers[layer_idx].keys, self.layers[layer_idx].values
        else:
            raise KeyError(
                f"Cache only has {len(self.layers)} layers, attempted to access layer with index {layer_idx}"
            )

    def __iter__(self):
        """
        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
        keys and values
        """
        for layer_idx in range(len(self)):
            yield (self.layers[layer_idx].keys, self.layers[layer_idx].values)

    def __len__(self):
        """
        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
        to the number of layers in the model.
        """
        # Best effort BC support for old-style caches like Mambas, Falcon, HybridChunked that rely on __len__
        if getattr(self, "layers", None) is None:
            if getattr(self, "key_cache", None) is not None:
                return len(self.key_cache)
            return 0
        # Empty dynamic caches initialize an empty layer to be ready for first update
        dynamic_empty = (
            getattr(self, "layers", None) is not None
            and len(self.layers) == 1
            and isinstance(self.layers[0], DynamicLayer)
            and self.layers[0].keys is None
        )
        return len(self.layers) if not dynamic_empty else 0

    def __repr__(self):
        return f"{self.__class__.__name__}(layers={self.layers})"

    def append_new_layers(self, layer_idx: int) -> None:
        """
        Appends layers to the cache until the layer `layer_idx` is reached.
        Used for preallocation in static caches and on the fly in dynamic caches.

        Args:
            layer_idx (`int`):
                The index of the layer to append.
        """
        while len(self.layers) <= layer_idx:
            kwargs = self.layer_init_kwargs.copy()
            if self.layer_init_kwargs.get("layer_device_map", None) is not None:
                kwargs["device"] = kwargs.pop("layer_device_map")[len(self.layers)]

            new_layer_class = (
                self.layer_classes[len(self.layers)] if isinstance(self.layer_classes, list) else self.layer_classes
            )
            new_layer = new_layer_class(**kwargs)
            self.layers.append(new_layer)

    @apply_processors
    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[dict[str, Any]] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`dict[str, Any]`, *optional*):
                Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
                cache to be created.

        Return:
            A tuple containing the updated key and value states.
        """
        self.append_new_layers(layer_idx)
        return self.layers[layer_idx].update(key_states, value_states, cache_kwargs)

    def get_seq_length(self, layer_idx: int = 0, cache_position=None) -> int:
        """Returns the sequence length of the cache for the given layer. TODO: deprecate in favor of cache_position"""
        if layer_idx >= len(self.layers):
            return 0
        # Hack since QuantizedCache messes with keys shape as it becomes the residual cache
        if self.cache_processor is not None and isinstance(self.cache_processor, QuantizedCacheProcessor):
            return self.cache_processor.erased_length + self.layers[layer_idx].get_seq_length(cache_position)
        return self.layers[layer_idx].get_seq_length(cache_position)

    def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
        """
        Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
        the given layer at `layer_idx`.
        The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns (i.e. sliding_window, chunk_size),
        for each layer.
        """
        kv_length, kv_offset = self.layers[layer_idx].get_mask_sizes(cache_position)
        return kv_length, kv_offset

    @property
    def key_cache(self) -> KeyValuesWrapper:
        """List-like object of key cache tensors indexed by layer. Deprecated in favor of `cache.layers[idx].keys`"""
        logger.warning_once(
            "`cache.key_cache[idx]` is deprecated and will be removed in v4.56.0. Use `cache.layers[idx].keys` instead."
        )
        return KeyValuesWrapper(self.layers, "keys")

    @property
    def value_cache(self) -> KeyValuesWrapper:
        """List-like object of value cache tensors indexed by layer. Deprecated in favor of `cache.layers[idx].values`"""
        logger.warning_once(
            "`cache.value_cache[idx]` is deprecated and will be removed in v4.56.0. Use `cache.layers[idx].values` instead."
        )
        return KeyValuesWrapper(self.layers, "values")

    ### Wrappers for layer operations and properties ###

    def get_max_cache_shape(self, layer_idx: int = 0) -> int:
        """Returns maximum sequence length of the cache object. Dynamic caches do not have a maximum length."""
        return self.layers[layer_idx].get_max_cache_shape()

    def reset(self):
        """Recursively reset all layers tensors"""
        for layer_idx in range(len(self.layers)):
            self.layers[layer_idx].reset()

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorder the cache for beam search"""
        for layer_idx in range(len(self.layers)):
            self.layers[layer_idx].reorder_cache(beam_idx)

    def crop(self, max_length: int):
        """Crop the cache to the given length"""
        for layer_idx in range(len(self.layers)):
            self.layers[layer_idx].crop(max_length)

    def batch_repeat_interleave(self, repeats: int):
        """Repeat and interleave the cache"""
        for layer_idx in range(len(self.layers)):
            self.layers[layer_idx].batch_repeat_interleave(repeats)

    def batch_select_indices(self, indices: torch.Tensor):
        """Select indices from the cache"""
        for layer_idx in range(len(self.layers)):
            self.layers[layer_idx].batch_select_indices(indices)

    @property
    def max_batch_size(self) -> int:
        """Return the maximum batch size of the cache"""
        values = [layer.max_batch_size for layer in self.layers]
        if len(set(values)) > 1:
            raise ValueError(f"Max batch size is not consistent across layers: {values}")
        return values[0]

    @property
    def max_cache_len(self) -> int:
        """Return the maximum cache length of the cache"""
        values = [layer.max_cache_len for layer in self.layers]
        return max(values)

    @property
    def is_compileable(self) -> bool:
        """Return whether the cache is compileable"""
        return all(layer.is_compileable for layer in self.layers)

    @property
    def is_sliding(self) -> list[bool]:
        """Return whether the layers of the cache are sliding window"""
        return [getattr(layer, "is_sliding", False) for layer in self.layers]


class DynamicCache(Cache):
    """
    A cache that grows dynamically as more tokens are generated. This is the default for generative models.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.

    See `Cache` for details on common methods that are implemented by all cache classes.

    Example:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache

        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

        >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> past_key_values = DynamicCache()
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> outputs.past_key_values # access cache filled with key/values from generation
        DynamicCache()
        ```
    """

    # Specialized constructor for DDP cache data, needed for BC
    def __init__(self, ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None, *args, **kwargs):
        super().__init__(layer_classes=DynamicLayer, *args, **kwargs)
        # `ddp_cache_data` was originally added for compatibility with `torch.distributed` (DDP). See #36212
        # and #36373 for more information. In a nutshell, it is `map(gather_map, zip(*caches))`, i.e. each item in the
        # iterable contains the key and value states for a layer gathered across replicas by torch.distributed
        # (shape=[global batch size, num_heads, seq_len, head_dim]).
        # WARNING: `ddp_cache_data` must be the first argument in `__init__`, otherwise we'll break
        # compatibility. The name of the argument doesn't matter.
        if ddp_cache_data is not None:
            for key_states, value_states in ddp_cache_data:
                self.layers.append(DynamicLayer.from_tensors(key_states, value_states))

    def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]:
        """
        Converts the `Cache` instance into its equivalent in the legacy cache format. Used for
        backward compatibility.
        """
        legacy_cache = ()
        for layer in self.layers:
            legacy_cache += ((layer.keys, layer.values),)
        return legacy_cache

    @classmethod
    def from_legacy_cache(cls, past_key_values: tuple[tuple[torch.FloatTensor, torch.FloatTensor], ...]) -> "Cache":
        """
        Converts a cache in the legacy cache format into an equivalent `Cache`. Used for
        backward compatibility.
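
        Example (a minimal sketch with dummy tensors):

            ```python
            >>> import torch
            >>> from transformers import DynamicCache

            >>> legacy = ((torch.zeros(1, 2, 4, 8), torch.zeros(1, 2, 4, 8)),)  # one (key, value) pair per layer
            >>> cache = DynamicCache.from_legacy_cache(legacy)
            >>> cache.to_legacy_cache()[0][0].shape
            torch.Size([1, 2, 4, 8])
            ```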
        """
        cache = cls()
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                key_states, value_states = past_key_values[layer_idx]
                cache.update(key_states, value_states, layer_idx)
        return cache


# Utilities for `DynamicCache` <> torch.export support

if is_torch_greater_or_equal("2.3"):

    def _get_cache_dict(cache: DynamicCache):
        if any(not isinstance(layer, DynamicLayer) for layer in cache.layers):
            raise RuntimeError("This pytree flattening function should only be applied to DynamicCache")

        if not is_torch_greater_or_equal_than_2_6:
            logger.warning_once(
                "DynamicCache + torch.export is tested on torch 2.6.0+ and may not work on earlier versions."
            )

        return {
            "key_cache": [layer.keys for layer in cache.layers if layer.keys is not None],
            "value_cache": [layer.values for layer in cache.layers if layer.values is not None],
        }

    def _unflatten_dynamic_cache(
        values,
        context: torch.utils._pytree.Context,
    ):
        dictionary = torch.utils._pytree._dict_unflatten(values, context)
        cache = DynamicCache()
        # Reconstruct layers from keys and values lists
        key_list = dictionary.get("key_cache", [])
        value_list = dictionary.get("value_cache", [])
        for idx in range(max(len(key_list), len(value_list))):
            key = key_list[idx] if idx < len(key_list) else None
            value = value_list[idx] if idx < len(value_list) else None
            cache.update(key, value, idx)
        return cache

    torch.utils._pytree.register_pytree_node(
        DynamicCache,
        lambda dynamic_cache: torch.utils._pytree._dict_flatten(_get_cache_dict(dynamic_cache)),
        _unflatten_dynamic_cache,
        serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
        flatten_with_keys_fn=lambda dynamic_cache: torch.utils._pytree._dict_flatten_with_keys(
            _get_cache_dict(dynamic_cache)
        ),
    )
    # TODO (tmanlaibaatar) This won't be needed in torch 2.7.
    torch.fx._pytree.register_pytree_flatten_spec(
        DynamicCache, lambda cache, spec: torch.fx._pytree._dict_flatten_spec(_get_cache_dict(cache), spec)
    )


class OffloadedCache(DynamicCache):
    """
    A drop-in replacement for DynamicCache that conserves accelerator (GPU, XPU) memory at the expense of more CPU memory.
    Useful for generating from models with very long context.

    In addition to the default accelerator stream, where all forward() computations happen,
    this class uses another stream, the prefetch stream, which it creates itself.
    Since scheduling of operations on separate streams happens independently, this class uses
    the prefetch stream to asynchronously prefetch the KV cache of layer k+1 when layer k is executing.
    The movement of the layer k-1 cache to the CPU is handled by the default stream as a simple way to
    ensure the eviction is scheduled after all computations on that cache are finished.
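
    Example:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, OffloadedCache

        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

        >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward (assumes an accelerator is available for offloading)
        >>> past_key_values = OffloadedCache()
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> outputs.past_key_values # access cache filled with key/values from generation
        OffloadedCache()
        ```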
    """

    def __init__(self) -> None:
        # Create the underlying cache with offload processor
        super().__init__(cache_processor=OffloadedCacheProcessor)


class StaticCache(Cache):
    """
    Static Cache class to be used with `torch.compile(model)` and `torch.export()`.

    See `Cache` for details on common methods that are implemented by all cache classes.

    Example:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache

        >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

        >>> inputs = tokenizer(text="My name is Llama", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
        >>> max_generated_length = inputs.input_ids.shape[1] + 10
        >>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> outputs.past_key_values # access cache filled with key/values from generation
        StaticCache()
        ```
    """

    def __init__(self, *args, **kwargs):
        super().__init__(layer_classes=StaticLayer, *args, **kwargs)


class OffloadedStaticCache(StaticCache):
    """
    A drop-in replacement for StaticCache that conserves accelerator memory by offloading
    cache tensors to CPU when not actively being used.

    This cache maintains the compilation-friendly properties of StaticCache while enabling
    much longer sequences by offloading inactive layers to CPU memory.

    See `Cache` for details on common methods that are implemented by all cache classes.

    Example:
        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, OffloadedStaticCache

        >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
        >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")

        >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")

        >>> # Prepare a cache class with offloading
        >>> max_generated_length = inputs.input_ids.shape[1] + 10
        >>> past_key_values = OffloadedStaticCache(
        ...     config=model.config,
        ...     max_batch_size=1,
        ...     max_cache_len=max_generated_length,
        ...     device=model.device,
        ...     dtype=model.dtype
        ... )
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> outputs.past_key_values # access cache with offloaded layers
        OffloadedStaticCache()
        ```
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, cache_processor=OffloadedCacheProcessor, **kwargs)


class SlidingWindowCache(Cache):
    """
    Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention.
    Every time we try to update the cache, we compute the `indices` based on `cache_position >= self.sliding_window - 1`.
    If true (meaning the cache cannot hold all the old key/value states and the new states together because of the sliding
    window constraint), we perform a cyclic shift based on `indices` to replace the oldest states with the new key/value
    states passed in.

    `to_shift` is only true once we are above `sliding_window`. Thus, with `sliding_window == 64`:

    indices = (slicing + to_shift[-1].sum()-1) % self.sliding_window
    tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
        19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
        37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
        55, 56, 57, 58, 59, 60, 61, 62, 63,  0])

    We overwrite the cache using these indices, then we always write at `cache_position` (clamped to `sliding_window`).

    See `Cache` for details on common methods that are implemented by all cache classes.

    Example:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SlidingWindowCache

        >>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
        >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")

        >>> inputs = tokenizer(text="My name is Mistral", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
        >>> max_generated_length = inputs.input_ids.shape[1] + 10
        >>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> outputs.past_key_values # access cache filled with key/values from generation
        SlidingWindowCache()
        ```
    """

    def __init__(self, *args, **kwargs):
        super().__init__(layer_classes=SlidingWindowLayer, *args, **kwargs)


class HybridCache(Cache):
    """
    Hybrid Cache class to be used with `torch.compile` for models that alternate between a local sliding window
    attention and global attention in every other layer (originally implemented for Gemma2).
    Under the hood, the Hybrid Cache leverages `SlidingWindowLayer` for sliding window attention and `StaticLayer`
    for global attention. For more information, see the documentation of those layer types.

    See `Cache` for details on common methods that are implemented by all cache classes.

    Example:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache

        >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")

        >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
        >>> max_generated_length = inputs.input_ids.shape[1] + 10
        >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> outputs.past_key_values # access cache filled with key/values from generation
        HybridCache()
        ```
    """

    def __init__(self, config: PretrainedConfig, *args, **kwargs):
        if hasattr(config, "layer_types"):
            layer_classes = [LAYER_CLASS_MAP[layer_type] for layer_type in config.layer_types]
        else:
            # In this case, fall back to StaticCache
            layer_classes = [StaticLayer] * config.num_hidden_layers
        super().__init__(config=config, layer_classes=layer_classes, *args, **kwargs)


# The mapping already handles dispatching the correct layers in Hybrid, this is only used for BC
class HybridChunkedCache(HybridCache): ...


class OffloadedHybridCache(HybridChunkedCache):
    """
    A drop-in replacement for HybridChunkedCache that conserves accelerator memory by offloading
    cache tensors to CPU when not actively being used.

    This cache maintains the compilation-friendly properties of HybridChunkedCache while enabling
    much longer sequences by offloading inactive layers to CPU memory.

    See `Cache` for details on common methods that are implemented by all cache classes.
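
    Example (usage mirrors `HybridCache`, with inactive layers offloaded to CPU):

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, OffloadedHybridCache

        >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")

        >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt")

        >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
        >>> max_generated_length = inputs.input_ids.shape[1] + 10
        >>> past_key_values = OffloadedHybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> outputs.past_key_values # access cache with offloaded layers
        OffloadedHybridCache()
        ```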
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, cache_processor=OffloadedCacheProcessor, **kwargs)


class QuantizedCache(DynamicCache):
    """
    A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750).
    It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization.

    The cache has two types of storage, one for the original precision and one for the quantized cache. A `residual_length` is set as the maximum capacity of the
    original-precision cache. When the length goes beyond that capacity, the original-precision states are quantized and moved into the quantized cache, and the
    original-precision cache is emptied. The quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described
    in the paper.

    It stores Keys and Values as lists of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and
    Value states in original precision as lists of tensors, one for each layer. The size of each tensor
    is `[batch_size, num_heads, seq_len - residual_length, head_dim]`.

    See `Cache` for details on common methods that are implemented by all cache classes.
    """

    def __init__(self, backend, **kwargs) -> None:
        if backend == "quanto":
            processor = QuantoQuantizedCacheProcessor
        elif backend == "hqq":
            processor = HQQQuantizedCacheProcessor
        else:
            raise ValueError(f"Unknown quantization backend `{backend}`")

        super().__init__(cache_processor=processor, **kwargs)


class QuantoQuantizedCache(QuantizedCache):
    """
    A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://huggingface.co/papers/2402.02750).
    It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization.

    The cache has two types of storage, one for the original precision and one for the quantized cache. A `residual_length` is set as the maximum capacity of the
    original-precision cache. When the length goes beyond that capacity, the original-precision states are quantized and moved into the quantized cache, and the
    original-precision cache is emptied. The quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described
    in the paper.

    It stores Keys and Values as lists of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and
    Value states in original precision as lists of tensors, one for each layer. The size of each tensor
    is `[batch_size, num_heads, seq_len - residual_length, head_dim]`.

    Uses `quanto` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only.

    See `Cache` for details on common methods that are implemented by all cache classes.

    Example:

        ```python
        >>> # Run `pip install optimum-quanto` first if you don't have it yet
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache, QuantizedCacheConfig

        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

        >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> cache_config = QuantizedCacheConfig(nbits=4)
        >>> past_key_values = QuantoQuantizedCache(cache_config=cache_config)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> outputs.past_key_values # access cache filled with key/values from generation
        QuantoQuantizedCache()
        ```
    """

    def __init__(self, **kwargs) -> None:
        DynamicCache.__init__(self, cache_processor=QuantoQuantizedCacheProcessor, **kwargs)


class HQQQuantizedCache(QuantizedCache):
    """
    A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750).
    It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization.

    The cache has two types of storage, one for the original precision and one for the quantized cache. A `residual_length` is set as the maximum capacity of the
    original-precision cache. When the length goes beyond that capacity, the original-precision states are quantized and moved into the quantized cache, and the
    original-precision cache is emptied. The quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described
    in the paper.

    It stores Keys and Values as lists of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and
    Value states in original precision as lists of tensors, one for each layer. The size of each tensor
    is `[batch_size, num_heads, seq_len - residual_length, head_dim]`.

    Uses `HQQ` as a backend to perform quantization. Current implementation supports `int2`, `int4`, `int8` dtypes.

    See `Cache` for details on common methods that are implemented by all cache classes.

    Example:

        ```python
        >>> # Run `pip install hqq` first if you don't have it yet
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache, QuantizedCacheConfig

        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

        >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> cache_config = QuantizedCacheConfig(nbits=4, axis_key=1, axis_value=1)
        >>> past_key_values = HQQQuantizedCache(cache_config=cache_config)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> outputs.past_key_values # access cache filled with key/values from generation
        HQQQuantizedCache()
        ```
    """

    def __init__(self, backend="HQQ", **kwargs) -> None:
        if backend != "HQQ":
            raise ValueError(f"`HQQQuantizedCache` only supports the `HQQ` backend, got `{backend}`")
        DynamicCache.__init__(self, cache_processor=HQQQuantizedCacheProcessor, **kwargs)


class EncoderDecoderCache(Cache):
    """
    Cache class for encoder-decoder models. Holds a combination of a self-attention cache and a
    cross-attention cache.

    See `Cache` for details on common methods that are implemented by all cache classes.

    Example:

        ```python
        >>> from transformers import AutoProcessor, AutoModelForCausalLM, DynamicCache, EncoderDecoderCache

        >>> model = AutoModelForCausalLM.from_pretrained("openai/whisper-small")
        >>> processor = AutoProcessor.from_pretrained("openai/whisper-small")

        >>> inputs = processor(audio=YOUR-AUDIO, return_tensors="pt")

        >>> # Prepare cache classes for encoder and decoder and pass it to model's forward
        >>> self_attention_cache = DynamicCache()
        >>> cross_attention_cache = DynamicCache()
        >>> past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> outputs.past_key_values # access cache filled with key/values from generation
        EncoderDecoderCache()
        ```

    """

    # Override @property from Cache
    is_compileable = None

    def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
        super().__init__(layer_classes=DynamicLayer)
        self.self_attention_cache = self_attention_cache
        self.cross_attention_cache = cross_attention_cache
        self.is_compileable = getattr(self.self_attention_cache, "is_compileable", False)

        self.is_updated = {}
        for layer_idx in range(len(cross_attention_cache)):
            self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0)

    def __iter__(self):
        """
        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
        keys and values
        """
        for layer_idx in range(len(self)):
            yield (
                self.self_attention_cache.layers[layer_idx].keys,
                self.self_attention_cache.layers[layer_idx].values,
                self.cross_attention_cache.layers[layer_idx].keys,
                self.cross_attention_cache.layers[layer_idx].values,
            )

    def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
        sequence length.
        """
        if layer_idx < len(self):
            return (
                self.self_attention_cache.layers[layer_idx].keys,
                self.self_attention_cache.layers[layer_idx].values,
                self.cross_attention_cache.layers[layer_idx].keys,
                self.cross_attention_cache.layers[layer_idx].values,
            )
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __len__(self):
        """
        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
        to the number of layers in the model.
        """
        return len(self.self_attention_cache)

    def to_legacy_cache(self) -> tuple[tuple[torch.Tensor]]:
        """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format."""
        legacy_cache = ()
        if len(self.cross_attention_cache) > 0:
            for self_attn, cross_attn in zip(
                self.self_attention_cache.to_legacy_cache(), self.cross_attention_cache.to_legacy_cache()
            ):
                legacy_cache += (self_attn + cross_attn,)
        else:
            legacy_cache = self.self_attention_cache.to_legacy_cache()
        return legacy_cache

    @classmethod
    def from_legacy_cache(
        cls, past_key_values: tuple[tuple[torch.FloatTensor, torch.FloatTensor], ...]
    ) -> "EncoderDecoderCache":
        """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
        cache = cls(
            self_attention_cache=DynamicCache(),
            cross_attention_cache=DynamicCache(),
        )
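        # A minimal sketch of the expected legacy layout (dummy tensors, for illustration only): each layer is either
        # (self_attn_key, self_attn_value) or (self_attn_key, self_attn_value, cross_attn_key, cross_attn_value), e.g.
        #   EncoderDecoderCache.from_legacy_cache(((k, v, cross_k, cross_v),))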
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                key_states, value_states = past_key_values[layer_idx][:2]
                cache.self_attention_cache.update(key_states, value_states, layer_idx)
                if len(past_key_values[layer_idx]) > 2:
                    key_states, value_states = past_key_values[layer_idx][2:]
                    cache.cross_attention_cache.update(key_states, value_states, layer_idx)
                    cache.is_updated[layer_idx] = True
        return cache

    def get_seq_length(self, layer_idx: Optional[int] = 0, cache_position=None) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # Delegate to the self-attention cache, which tracks the decoder sequence length
        return self.self_attention_cache.get_seq_length(layer_idx, cache_position)

    def reset(self):
        if not hasattr(self.self_attention_cache, "reset") and not hasattr(self.cross_attention_cache, "reset"):
            raise ValueError(
                "Neither self nor cross-attention cache have valid `.reset()` methods. `.reset()` should "
                "only be called on compatible cache classes, such as `StaticCache` or `SlidingWindowCache`. "
                f"Got {self.self_attention_cache.__str__()} for the self attention cache and "
                f"{self.cross_attention_cache.__str__()} for the cross attention cache."
            )
        if hasattr(self.self_attention_cache, "reset"):
            self.self_attention_cache.reset()
        if hasattr(self.cross_attention_cache, "reset"):
            self.cross_attention_cache.reset()
        for layer_idx in self.is_updated:
            self.is_updated[layer_idx] = False

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        self.self_attention_cache.reorder_cache(beam_idx)
        self.cross_attention_cache.reorder_cache(beam_idx)

    def check_dynamic_cache(self, method: str):
        if not (
            isinstance(self.self_attention_cache, DynamicCache)
            and isinstance(self.cross_attention_cache, DynamicCache)
        ):
            raise ValueError(
                f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self "
                f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache."
            )

    # TODO(gante, sanchit-gandhi): move following functionality into `.generate`
    def crop(self, maximum_length: int):
        """
        Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
        negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search.
        """
        self.check_dynamic_cache(self.crop.__name__)
        self.self_attention_cache.crop(maximum_length)

    def batch_split(self, full_batch_size: int, split_size: int) -> "list[EncoderDecoderCache]":
        """
        Split the current instance into a list of `EncoderDecoderCache` by the batch size. This will be used by
        `_split_model_inputs()` in `generation.utils`.
        """
        self.check_dynamic_cache(self.batch_split.__name__)
        self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size)
        cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size)

        out = []
        for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache):
            out.append(EncoderDecoderCache(self_attn, cross_attn))
        return out

    def batch_repeat_interleave(self, repeats: int):
        """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
        self.check_dynamic_cache(self.batch_repeat_interleave.__name__)
        self.self_attention_cache.batch_repeat_interleave(repeats)
        self.cross_attention_cache.batch_repeat_interleave(repeats)

    def batch_select_indices(self, indices: torch.Tensor):
        """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
        self.check_dynamic_cache(self.batch_select_indices.__name__)
        self.self_attention_cache.batch_select_indices(indices)
        self.cross_attention_cache.batch_select_indices(indices)

    def get_max_cache_shape(self) -> int:
        """Returns the maximum sequence length (i.e. max capacity) of the cache object"""
        return self.self_attention_cache.get_max_cache_shape()

    def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
        """Return the mask sizes (key/value length and offset) for `layer_idx`, delegating to the self-attention cache."""
        return self.self_attention_cache.get_mask_sizes(cache_position, layer_idx)


def parse_processor_args(processor_class: Optional[type["CacheProcessor"]], kwargs: dict) -> tuple[dict, dict]:
    """
    Parse processor arguments from kwargs based on the processor class init signature.

    Args:
        processor_class: The processor class to inspect, or None
        kwargs: Dictionary of keyword arguments

    Returns:
        tuple: (processor_kwargs, remaining_kwargs)
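
    Example (a minimal sketch with a hypothetical processor class):

        ```python
        >>> class ToyProcessor:
        ...     def __init__(self, cache, residual_length: int = 128):
        ...         self.residual_length = residual_length

        >>> parse_processor_args(ToyProcessor, {"residual_length": 64, "max_cache_len": 32})
        ({'residual_length': 64}, {'max_cache_len': 32})
        ```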
    """
    try:
        params = list(inspect.signature(processor_class.__init__).parameters)[2:]
    except Exception:
        return {}, kwargs

    processor_kwargs = {k: kwargs[k] for k in params if k in kwargs}
    remaining_kwargs = {k: v for k, v in kwargs.items() if k not in processor_kwargs}
    return processor_kwargs, remaining_kwargs


def parse_layer_args_from_model_config(
    config: Optional[PretrainedConfig],
    batch_size: Optional[int] = None,
    max_cache_len: Optional[int] = None,
    device: Union[torch.device, str, None] = None,
    dtype: Optional[torch.dtype] = None,
    layer_device_map: Optional[dict[int, torch.device]] = None,
    tp_size: Optional[int] = None,
    max_batch_size: Optional[int] = None,
) -> dict:
    """
    Parse layer arguments from model configuration for cache initialization.

    Args:
        config (`Optional[PretrainedConfig]`): Model configuration containing shape/device info.
        batch_size (`Optional[int]`): Batch size for cache initialization.
        max_cache_len (`Optional[int]`): Maximum sequence length for cache.
        device (`Union[torch.device, str, None]`): Device for cache tensors.
        dtype (`Optional[torch.dtype]`): Data type for cache tensors.
        layer_device_map: Per-layer device mapping.
        tp_size (`Optional[int]`): Tensor parallel size to adjust number of key/value heads.
        max_batch_size (`Optional[int]`): Maximum batch size for cache initialization.

    Returns:
        `dict`: Dictionary containing parsed layer arguments for cache initialization.
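
    Example (a minimal sketch with a toy config namespace, not a real model config):

        ```python
        >>> import torch
        >>> from types import SimpleNamespace

        >>> cfg = SimpleNamespace(
        ...     max_position_embeddings=128, hidden_size=64, num_attention_heads=8,
        ...     num_key_value_heads=None, head_dim=None, layer_types=None, sliding_window=None,
        ... )
        >>> args = parse_layer_args_from_model_config(cfg, max_batch_size=1, max_cache_len=32, dtype=torch.float32)
        >>> args["head_dim"], args["num_heads"], args["max_cache_len"]
        (8, 8, 32)
        ```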
    """
    # No model config -> must be a dynamic cache, return bare dict
    if config is None:
        return {}
    # Build the args dict for hybrid, sliding or static
    else:
        # Hybrid/sliding caches require a model config with a valid `sliding_window` attribute
        if (
            getattr(config, "layer_types", None) is not None
            and "sliding_attention" in config.layer_types
            and "full_attention" in config.layer_types
        ):
            if getattr(config, "sliding_window", None) is None:
                raise ValueError(
                    "Setting up a hybrid or sliding window KVCache requires the model config supporting "
                    "sliding window attention, please check if there is a `sliding_window` field in the model "
                    "config and it's not set to None."
                )
        # Default to the model's maximum position embeddings when `max_cache_len` is not provided
        max_cache_len = max_cache_len or config.max_position_embeddings
        # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads:
        head_dim = (
            config.head_dim
            if getattr(config, "head_dim", None) is not None
            else config.hidden_size // config.num_attention_heads
        )
        num_heads = (
            config.num_attention_heads
            if getattr(config, "num_key_value_heads", None) is None
            else config.num_key_value_heads
        )
        if tp_size is not None and tp_size > 1:
            if num_heads % tp_size != 0:
                raise ValueError(
                    f"Number of key value heads {num_heads} must be divisible by tensor parallel size {tp_size}."
                )
            # If the model is using tensor parallelism, we need to adjust the number of heads accordingly.
            num_heads //= tp_size
        layer_args = {
            "batch_size": max_batch_size if max_batch_size is not None else batch_size,
            "max_cache_len": max_cache_len,
            "device": torch.device(device) if device is not None else None,
            "dtype": dtype,
            "layer_device_map": layer_device_map,
            "head_dim": head_dim,
            "num_heads": num_heads,
            "sliding_window": getattr(config, "sliding_window", None),
        }
        return {k: v for k, v in layer_args.items() if v is not None}


LAYER_CLASS_MAP: dict[str, type["CacheLayerMixin"]] = {
    "full_attention": StaticLayer,
    "sliding_attention": SlidingWindowLayer,
    "chunked_attention": ChunkedSlidingLayer,
}
PROCESSOR_CLASS_MAP: dict[str, type["CacheProcessor"]] = {
    "offloaded": OffloadedCacheProcessor,
    "quanto_quantized": QuantoQuantizedCacheProcessor,
    "hqq_quantized": HQQQuantizedCacheProcessor,
}


### Deprecated classes


class SinkCache(Cache):
    """
    `SinkCache` is now a `custom_generate` repository on the Hub: https://huggingface.co/transformers-community/sink_cache.
    See [these docs](https://huggingface.co/docs/transformers/generation_strategies#custom-decoding-methods) for
    general `custom_generate` usage.
    """

    # TODO (joao, manuel): Remove this class in v4.59.0
    def __init__(self, **kwargs) -> None:
        raise NotImplementedError(
            "`SinkCache` has been moved as a `custom_generate` repository on the Hub: "
            "https://huggingface.co/transformers-community/sink_cache. See the repository for usage examples."
        )


@dataclass
class CacheConfig:
    """
    Base class for cache configs. Deprecated in favor of a simpler dictionary.
    """

    cache_implementation: None

    def __post_init__(self):
        logger.warning_once(
            "CacheConfig is deprecated and will be removed in v4.55.0 in favor of a simpler dictionary."
        )

    @classmethod
    def from_dict(cls, config_dict, **kwargs):
        """
        Constructs a CacheConfig instance from a dictionary of parameters.
        Args:
            config_dict (dict[str, Any]): Dictionary containing configuration parameters.
            **kwargs: Additional keyword arguments to override dictionary values.

        Returns:
            CacheConfig: Instance of CacheConfig constructed from the dictionary.
        """
        logger.warning_once(
            "CacheConfig is deprecated and will be removed in v4.55.0 in favor of a simpler dictionary."
        )
        config = cls(**config_dict)
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)
                to_remove.append(key)
        for key in to_remove:
            kwargs.pop(key, None)
        return config

    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file
    def to_json_file(self, json_file_path: Union[str, os.PathLike]):
        """
        Save this instance to a JSON file.

        Args:
            json_file_path (`str` or `os.PathLike`):
                Path to the JSON file in which this configuration instance's parameters will be saved.
        """
        with open(json_file_path, "w", encoding="utf-8") as writer:
            config_dict = self.to_dict()
            json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

            writer.write(json_string)

    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict
    def to_dict(self) -> dict[str, Any]:
        """
        Serializes this instance to a Python dictionary. Returns:
            `dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        return copy.deepcopy(self.__dict__)

    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__
    def __iter__(self):
        """Allows `dict(obj)` for situations where `obj` may be a dict or a `CacheConfig`."""
        for attr, value in copy.deepcopy(self.__dict__).items():
            yield attr, value

    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__
    def __repr__(self):
        return f"{self.__class__.__name__} {self.to_json_string()}"

    def to_json_string(self):
        """
        Serializes this instance to a JSON formatted string.
        Returns:
            str: JSON formatted string representing the configuration instance.
        """
        return json.dumps(self.__dict__, indent=2) + "\n"

    # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update
    def update(self, **kwargs):
        """
        Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
        returning all the unused kwargs.

        Args:
            kwargs (`dict[str, Any]`):
                Dictionary of attributes to tentatively update this class.

        Returns:
            `dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
        """
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
                to_remove.append(key)

        # Remove all the attributes that were updated, without modifying the input dict
        unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
        return unused_kwargs


@dataclass
class QuantizedCacheConfig(CacheConfig):
    """
    Configuration class for quantized cache settings. Deprecated in favor of a simpler dictionary.

    Attributes:
        backend (`str`, *optional*, defaults to `"quanto"`):
            Backend to use when performing quantization. Can be one of [`quanto`, `HQQ`].
        nbits (`Optional[int]`, *optional*, defaults to 4):
            Number of bits, can be 2 or 4 for the `quanto` backend and one of [1, 2, 3, 4, 8] for the `HQQ` backend. Defaults to 4.
        axis_key (`int`, *optional*, defaults to 0):
            Axis over which to perform grouping for the key tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend.
        axis_value (`int`, *optional*, defaults to 0):
            Axis over which to perform grouping for the value tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend.
        q_group_size (`Optional[int]`, *optional*, defaults to 64):
            Size of the quantization group, should be a divisor of the model's hidden dimension.
            Defaults to 64.
        residual_length (`Optional[int]`, *optional*, defaults to 128):
            Length of the residual cache which will always be stored in original precision.
            Defaults to 128.
        compute_dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
            The default dtype used for computations in the model. Keys and Values will be cast to this dtype after dequantization.
        device (`str`, *optional*, defaults to `"cpu"`):
            Device on which to perform computations, should be same as the model's device.
    """

    def __init__(
        self,
        backend: str = "quanto",
        nbits: Optional[int] = 4,
        axis_key: Optional[int] = 0,
        axis_value: Optional[int] = 0,
        q_group_size: Optional[int] = 64,
        residual_length: Optional[int] = 128,
        compute_dtype: Optional[torch.dtype] = torch.float16,
        device: Optional[str] = "cpu",
    ):
        logger.warning_once(
            "CacheConfig is deprecated and will be removed in v4.55.0 in favor of a simpler dictionary."
        )
        self.backend = backend
        self.nbits = nbits
        self.axis_key = axis_key
        self.axis_value = axis_value
        self.q_group_size = q_group_size
        self.residual_length = residual_length
        self.compute_dtype = compute_dtype
        self.device = device

    def validate(self):
        """Validates if the arguments passed are correct"""

        incorrect_arg_msg = (
            "Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value} "
            "but found {found_value}"
        )
        # Check that the values are reasonable in general (nbits, axis)
        # Later in QuantizedCache init we check if they are supported for that particular backend
        if self.nbits not in [1, 2, 3, 4, 8]:
            raise ValueError(
                incorrect_arg_msg.format(
                    key="nbits",
                    correct_value="one of [1, 2, 3, 4, 8]",
                    found_value=self.nbits,
                ),
            )
        if self.q_group_size <= 0:
            raise ValueError(
                incorrect_arg_msg.format(
                    key="q_group_size",
                    correct_value="a positive integer",
                    found_value=self.q_group_size,
                ),
            )
        if self.residual_length < 0:
            raise ValueError(
                incorrect_arg_msg.format(
                    key="residual_length",
                    correct_value="a non-negative integer",
                    found_value=self.residual_length,
                ),
            )

        if self.axis_key not in [0, 1, -1]:
            raise ValueError(
                incorrect_arg_msg.format(
                    key="axis_key",
                    correct_value="`0`, `1` or `-1`",
                    found_value=self.axis_key,
                ),
            )

        if self.axis_value not in [0, 1, -1]:
            raise ValueError(
                incorrect_arg_msg.format(
                    key="axis_value",
                    correct_value="`0`, `1` or `-1`",
                    found_value=self.axis_value,
                ),
            )


@dataclass
class StaticCacheConfig(CacheConfig):
    """
    Configuration class for static cache settings.
    """

    cache_implementation = "static"

    def __init__(self, batch_size: int, max_cache_len: int, device="cpu"):
        logger.warning_once(
            "CacheConfig is deprecated and will be removed in v4.55.0 in favor of a simpler dictionary."
        )
        self.batch_size = batch_size
        self.max_cache_len = max_cache_len
        self.device = device

    def initialise_cache_layer(self, layer_idx, key_states):
        """Overridden to use the correct device for offloaded layers (and to pin memory)."""
        if len(self.key_cache) > layer_idx:
            return

        num_key_value_heads = key_states.shape[1]
        device = key_states.device if self.is_sliding[layer_idx] else self.offload_device
        pin_memory = not self.is_sliding[layer_idx]
        global_cache_shape = (self.max_batch_size, num_key_value_heads, self.max_cache_len, self.head_dim)
        sliding_cache_shape = (self.max_batch_size, num_key_value_heads, self.sliding_window, self.head_dim)
        # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
        # breaks when updating the cache.
        cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape
        new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device, pin_memory=pin_memory)
        new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device, pin_memory=pin_memory)
        torch._dynamo.mark_static_address(new_layer_key_cache)
        torch._dynamo.mark_static_address(new_layer_value_cache)
        self.key_cache.append(new_layer_key_cache)
        self.value_cache.append(new_layer_value_cache)

        # Make sure to initialize the on-device layer if it does not already exist
        if self.device_key_cache is None and not self.is_sliding[layer_idx]:
            self.device_key_cache = []
            self.device_value_cache = []
            # We need 2 layers to avoid race conditions when prefetching the next one
            for _ in range(2):
                device_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=key_states.device)
                device_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=key_states.device)
                torch._dynamo.mark_static_address(device_layer_key_cache)
                torch._dynamo.mark_static_address(device_layer_value_cache)
                self.device_key_cache.append(device_layer_key_cache)
                self.device_value_cache.append(device_layer_value_cache)

    def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
        # Wait for prefetch stream if needed
        if self._prefetch_stream is not None:
            torch.cuda.default_stream(key_states.device).wait_stream(self._prefetch_stream)

        # Get correct on-device layer
        k_out = self.device_key_cache[self.active_device_layer]
        v_out = self.device_value_cache[self.active_device_layer]

        # Let's prefetch the next layer as soon as possible
        self._prefetch_next_layer(layer_idx)

        # Copy to on-device layer
        k_out[:, :, cache_position] = key_states
        v_out[:, :, cache_position] = value_states

        # Copy to offloaded device
        self.key_cache[layer_idx][:, :, cache_position] = key_states.to(self.offload_device)
        self.value_cache[layer_idx][:, :, cache_position] = value_states.to(self.offload_device)

        return k_out, v_out

    def _prefetch_next_layer(self, layer_idx: int) -> None:
        """Based on current layer_idx, prefetch next full layer to the device."""

        # Switch the active layer
        self.active_device_layer = 0 if self.active_device_layer == 1 else 1

        # Find the next non-sliding layer
        try:
            next_layer = layer_idx + 1 + self.is_sliding[layer_idx + 1 :].index(False)
        # In this case, we are at the last layer, and we go back to prefetch the first one
        except ValueError:
            next_layer = self.is_sliding.index(False)

        # Prefetch on the dedicated stream when available, otherwise on the default stream
        if self._prefetch_stream is not None:
            with torch.cuda.stream(self._prefetch_stream):
                self._prefetch_layer_in_context(next_layer)
        else:
            self._prefetch_layer_in_context(next_layer)

    def _prefetch_layer_in_context(self, layer_idx: int) -> None:
        """Performs the actual copy of the layer to device cache."""
        if len(self.key_cache) > layer_idx:
            self.device_key_cache[self.active_device_layer].copy_(self.key_cache[layer_idx], non_blocking=True)
            self.device_value_cache[self.active_device_layer].copy_(self.value_cache[layer_idx], non_blocking=True)
        # The layer was not yet initialized
        else:
            self.device_key_cache[self.active_device_layer].fill_(0.0)
            self.device_value_cache[self.active_device_layer].fill_(0.0)


# TODO (manuel, joao): remove this class, it is here only for backwards compatibility
# PEP 562: Lazy loading for deprecated location of MambaCache
def __getattr__(name: str) -> Any:
    if name == "MambaCache":
        logger.warning_once(
            "Importing `MambaCache` from `transformers.cache_utils` is deprecated and will be removed "
            "in a future version. Please import it from `transformers` or `transformers.models.mamba.cache_mamba` instead."
        )

        class MambaCache:
            """
            Importing `MambaCache` from `transformers.cache_utils` is deprecated and will be removed
            in a future version. Please import it from `transformers` or `transformers.models.mamba.cache_mamba` instead.

            Cache for Mamba models, which do not have an attention mechanism or key/value states.

            Arguments:
                config (`PretrainedConfig`):
                    The configuration file defining the shape-related attributes required to initialize the static cache.
                max_batch_size (`int`):
                    The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used.
                dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
                    The default `dtype` to use when initializing the layer.
                device (`torch.device` or `str`, *optional*):
                    The device on which the cache should be initialized. Should be the same as the layer.

            Example:

                ```python
                >>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache

                >>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
                >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")

                >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt")

                >>> # Prepare a cache class and pass it to model's forward
                >>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
                >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
                >>> outputs.past_key_values
                MambaCache()
                ```
            """

            is_compileable = True

            # TODO (joao): add layer_device_map arg and update code in `generate` accordingly
            def __init__(
                self,
                config,
                max_batch_size: int,
                dtype: torch.dtype = torch.float16,
                device: Union[torch.device, str, None] = None,
            ):
                self.max_batch_size = max_batch_size
                self._dtype = dtype
                self.intermediate_size = config.intermediate_size
                self.ssm_state_size = config.state_size
                self.conv_kernel_size = config.conv_kernel

                self.conv_states: list[torch.Tensor] = []
                self.ssm_states: list[torch.Tensor] = []
                device = torch.device(device) if device is not None else None
                for _ in range(config.num_hidden_layers):
                    conv_state: torch.Tensor = torch.zeros(
                        self.max_batch_size,
                        self.intermediate_size,
                        self.conv_kernel_size,
                        device=device,
                        dtype=self._dtype,
                    )
                    ssm_state: torch.Tensor = torch.zeros(
                        self.max_batch_size,
                        self.intermediate_size,
                        self.ssm_state_size,
                        device=device,
                        dtype=self._dtype,
                    )

                    torch._dynamo.mark_static_address(conv_state)
                    torch._dynamo.mark_static_address(ssm_state)
                    self.conv_states.append(conv_state)
                    self.ssm_states.append(ssm_state)

            def update_conv_state(
                self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
            ) -> torch.Tensor:
                # This `if` block is only reached in multi-GPU setups when `layer_device_map` is not passed. It is used
                # when the cache is initialized in the forward pass (e.g. Mamba).
                if self.conv_states[layer_idx].device != new_conv_state.device:
                    self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device)

                conv_state = self.conv_states[layer_idx]
                cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)

                conv_state = conv_state.roll(shifts=-1, dims=-1)
                conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype)
                self.conv_states[layer_idx].zero_()
                self.conv_states[layer_idx] += conv_state
                return self.conv_states[layer_idx]

            def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
                self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device)
                return self.ssm_states[layer_idx]

            def reset(self):
                for layer_idx in range(len(self.conv_states)):
                    # In-place ops prevent breaking the static address
                    self.conv_states[layer_idx].zero_()
                    self.ssm_states[layer_idx].zero_()

        return MambaCache
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
