# Copyright 2025 NXAI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""xLSTM configuration."""

from typing import Optional

from ...configuration_utils import PretrainedConfig
from ...utils import is_xlstm_available, logging


if is_xlstm_available():
    from xlstm.xlstm_large.model import (
        BackendModeType,
        ChunkwiseKernelType,
        DtypeType,
        SequenceKernelType,
        StepKernelType,
        WeightModeType,
        round_up_to_next_multiple_of,
        xLSTMLargeConfig,
    )

    external_xlstm = True
else:
    from typing import Literal

    BackendModeType = Literal["train", "train_with_padding", "inference"]
    ChunkwiseKernelType = Literal[
        "chunkwise--native_autograd",
        "parallel--native_autograd",
    ]
    DtypeType = Literal["float32", "bfloat16", "float16"]
    SequenceKernelType = Literal["native_sequence__native"]
    StepKernelType = Literal["native"]
    WeightModeType = Literal["single", "fused"]

    def round_up_to_next_multiple_of(x: int, multiple_of: int) -> int:
        """Rounds up x to the next multiple of multiple_of."""
        return int(((x + multiple_of - 1) // multiple_of) * multiple_of)

    external_xlstm = False


logger = logging.get_logger(__name__)


class xLSTMConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`xLSTM`]. It is used to instantiate a xLSTM
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the xLSTM-7b [NX-AI/xLSTM-7b](https://huggingface.co/NX-AI/xLSTM-7b) model.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (int, optional, *optional*, defaults to 50304):
            Vocabulary size of the xLSTM model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`xLSTMModel`]. Defaults to the GPT2-NeoX tokenizer size.
        hidden_size (int, optional, *optional*, defaults to 4096):
            Dimensionality of the embeddings or hidden states.
        embedding_dim (int, optional, *optional*, defaults to 4096):
            Dimensionality of the embeddings or hidden states, use hidde_size if None.
        num_hidden_layers (int, optional, *optional*, defaults to 32):
            Number of blocks of the xLSTM model.
        num_blocks (int, optional, *optional*, defaults to 32):
            Number of blocks of the xLSTM model, use num_hidden_layers if None.
        num_heads (int, optional, *optional*, defaults to 8):
            Number of heads for the xLSTM Layer/Cell.
        use_bias (bool, optional, *optional*, defaults to `False`):
            Whether to use biases in the xLSTM model.
        norm_reduction_force_float32 (bool, optional, *optional*, defaults to `True`):
            Whether to force the float32 norm reduction op to be done in fp32 precision.
        tie_word_embeddings (bool, optional, *optional*, defaults to `False`):
            Whether to tie word embeddings to the lm head weights.
        add_out_norm (bool, optional, *optional*, defaults to `True`):
            Whether to add an output norm after the blocks before the LMHead.
        norm_eps (float, optional, *optional*, defaults to 1e-06):
            Norm eps for RMSNorm and Layer Norm.
        qk_dim_factor (float, optional, *optional*, defaults to 0.5):
            Scale factor for the query and key dimension.
        v_dim_factor (float, optional, *optional*, defaults to 1.0):
            Scale factor for the value dimension.
        chunkwise_kernel (ChunkwiseKernelType, optional, *optional*, defaults to `"chunkwise--native_autograd"`):
            Kernel type for chunkwise processing mode.
        sequence_kernel (SequenceKernelType, optional, *optional*, defaults to `"native_sequence__native"`):
            Kernel type for sequence processing mode.
        step_kernel (StepKernelType, optional, *optional*, defaults to `"native"`):
            Kernel type for step processing mode.
        mode (BackendModeType, optional, *optional*, defaults to `"inference"`):
            Operation mode (inference is needed for generation).
        chunk_size (int, optional, *optional*, defaults to 64):
            Internal chunk size.
        return_last_states (bool, optional, *optional*, defaults to `True`):
            If to return the last states / cache internally. Needed as True for generation.
        autocast_kernel_dtype (DtypeType, optional, *optional*, defaults to `"bfloat16"`):
            Kernel dtype for the states.
        eps (float, optional, *optional*, defaults to 1e-06):
            Epsilon for the mLSTM cell post norm.
        inference_state_dtype (DtypeType, optional, *optional*, defaults to `"float32"`):
            Kernel dtype for states in inference.
        ffn_proj_factor (float, optional, *optional*, defaults to 2.667):
            Size factor of the post-up projection gated Feed Forward network.
        ffn_round_up_to_multiple_of (int, optional, *optional*, defaults to 64):
            Size factor round value of the post-up projection gated Feed Forward network.
        gate_soft_cap (float, optional, *optional*, defaults to 15.0):
            Gate soft cap scale.
        output_logit_soft_cap (float, optional, *optional*, defaults to 30.0):
            Output logit soft cap scale.
        weight_mode (`Literal`, *optional*, defaults to `"single"`):
            Whether parallel linear layers are separated or fused (single).
        use_cache (bool, optional, *optional*, defaults to `True`):
            Whether to use the cache (xLSTMCache).
        pad_token_id (int, optional, *optional*, defaults to 1):
            Pad token id needed for generation.
        bos_token_id (int, optional, *optional*, defaults to 0):
            BOS token id needed for generation.
        eos_token_id (int, optional, *optional*, defaults to 2):
            EOS token id needed for generation.
        max_inference_chunksize (int, optional, *optional*, defaults to 16384):
            Limit the chunk size for inference to save memory.

    Example:

    ```python
    >>> from transformers import xLSTMConfig, xLSTMModel

    >>> # Initializing a xLSTM configuration
    >>> configuration = xLSTMConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = xLSTMModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "xlstm"

    def __init__(
        self,
        vocab_size: int = 50304,
        hidden_size: int = 4096,
        embedding_dim: Optional[int] = None,
        num_hidden_layers: Optional[int] = 32,
        num_blocks: Optional[int] = None,
        num_heads: int = 8,
        use_bias: bool = False,
        norm_reduction_force_float32: bool = True,
        tie_word_embeddings: bool = False,
        add_out_norm: bool = True,
        norm_eps: float = 1e-6,
        # mlstm_layer
        qk_dim_factor: float = 0.5,
        v_dim_factor: float = 1.0,
        # mlstm backend
        chunkwise_kernel: ChunkwiseKernelType = "chunkwise--native_autograd",
        sequence_kernel: SequenceKernelType = "native_sequence__native",
        step_kernel: StepKernelType = "native",
        # nedded to enable generation
        mode: BackendModeType = "inference",
        chunk_size: int = 64,
        # needed to be true for generation
        return_last_states: bool = True,
        autocast_kernel_dtype: DtypeType = "bfloat16",
        eps: float = 1e-6,
        inference_state_dtype: DtypeType = "float32",
        # feedforward
        ffn_proj_factor: float = 2.667,
        ffn_round_up_to_multiple_of: int = 64,
        # capping
        gate_soft_cap: float = 15.0,
        output_logit_soft_cap: float = 30.0,
        # weights
        weight_mode: WeightModeType = "single",
        # HF interface
        use_cache: bool = True,
        pad_token_id: int = 1,
        bos_token_id: int = 0,
        eos_token_id: int = 2,
        max_inference_chunksize: int = 16384,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size if hidden_size is not None else embedding_dim
        self.embedding_dim = embedding_dim if embedding_dim is not None else hidden_size
        self.num_hidden_layers = num_hidden_layers if num_hidden_layers is not None else num_blocks
        self.num_blocks = num_blocks if num_blocks is not None else num_hidden_layers
        self.num_heads = num_heads
        self.use_bias = use_bias
        self.tie_word_embeddings = tie_word_embeddings
        self.add_out_norm = add_out_norm
        self.norm_eps = norm_eps
        self.norm_reduction_force_float32 = norm_reduction_force_float32
        # mlstm_layer
        self.qk_dim_factor = qk_dim_factor
        self.v_dim_factor = v_dim_factor
        # mlstm backend
        self.chunkwise_kernel = chunkwise_kernel
        self.sequence_kernel = sequence_kernel
        self.step_kernel = step_kernel
        self.mode = mode
        self.chunk_size = chunk_size
        self.return_last_states = return_last_states
        self.autocast_kernel_dtype = autocast_kernel_dtype
        self.eps = eps
        self.inference_state_dtype = inference_state_dtype
        # feedforward
        self.ffn_proj_factor = ffn_proj_factor
        self.ffn_round_up_to_multiple_of = ffn_round_up_to_multiple_of
        # capping
        self.gate_soft_cap = gate_soft_cap
        self.output_logit_soft_cap = output_logit_soft_cap
        self.weight_mode = weight_mode

        self.use_cache = use_cache
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.max_inference_chunksize = max_inference_chunksize

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    @property
    def qk_dim(self):
        return round_up_to_next_multiple_of(
            self.hidden_size * self.qk_dim_factor,
            multiple_of=64,
        )

    @property
    def v_dim(self):
        return round_up_to_next_multiple_of(
            self.hidden_size * self.v_dim_factor,
            multiple_of=64,
        )

    @property
    def qk_head_dim(self):
        return self.qk_dim // self.num_heads

    @property
    def v_head_dim(self):
        return self.v_dim // self.num_heads

    def to_xlstm_block_config(self):
        if external_xlstm:
            return xLSTMLargeConfig(
                vocab_size=self.vocab_size,
                embedding_dim=self.hidden_size,
                num_blocks=self.num_hidden_layers,
                num_heads=self.num_heads,
                use_bias=self.use_bias,
                add_out_norm=self.add_out_norm,
                norm_eps=self.norm_eps,
                norm_reduction_force_float32=self.norm_reduction_force_float32,
                # mlstm_layer
                qk_dim_factor=self.qk_dim_factor,
                v_dim_factor=self.v_dim_factor,
                # mlstm backend
                chunkwise_kernel=self.chunkwise_kernel,
                sequence_kernel=self.sequence_kernel,
                step_kernel=self.step_kernel,
                mode=self.mode,
                chunk_size=self.chunk_size,
                return_last_states=self.return_last_states,
                autocast_kernel_dtype=self.autocast_kernel_dtype,
                eps=self.eps,
                inference_state_dtype=self.inference_state_dtype,
                # feedforward
                ffn_proj_factor=self.ffn_proj_factor,
                ffn_round_up_to_multiple_of=self.ffn_round_up_to_multiple_of,
                # capping
                gate_soft_cap=self.gate_soft_cap,
                output_logit_soft_cap=self.output_logit_soft_cap,
                weight_mode=self.weight_mode,
            )
        else:
            return self


__all__ = ["xLSTMConfig"]
