import functools
import logging
import time
from enum import Enum
from typing import Any, Callable, Optional, Union

import torch


class RequestStatus(Enum):
    """Lifecycle state of a generation request.

    Every member's value is the lowercase spelling of its name, so a status can be
    rebuilt from its string form, e.g. ``RequestStatus("decoding") is RequestStatus.DECODING``.
    """

    # Waiting and prompt-processing (prefill) states
    PENDING = "pending"
    PREFILLING = "prefilling"
    PREFILLING_SPLIT = "prefilling_split"
    SPLIT_PENDING_REMAINDER = "split_pending_remainder"

    # Token-generation state
    DECODING = "decoding"

    # Outcome states
    FINISHED = "finished"
    FAILED = "failed"


try:
    from opentelemetry import metrics
    from opentelemetry.trace import Status, StatusCode, get_tracer

    _has_opentelemetry = True
except ImportError:
    _has_opentelemetry = False


def attach_tracer(tracer_name_template=None):
    """
    Decorator that attaches a tracer to a class.

    This decorator should be applied to classes that need OpenTelemetry tracing.
    It adds a tracer attribute to the class instance that can be used by the traced decorator.
    When OpenTelemetry is not installed the class is returned unchanged.

    Args:
        tracer_name_template: Optional template string for the tracer name.
            If provided, it should contain {module} which will be replaced with the class's full module path
            and {class_name} for the class name.
            If None, a default naming scheme will be used where:
              - If the module is `transformers` or starts with "transformers.", it will use that directly
              - Otherwise, it will prepend "transformers." to the module name

    Returns:
        Class decorator function
    """
    if not _has_opentelemetry:
        return lambda cls: cls

    def decorator(cls):
        original_init = cls.__init__

        # The tracer name depends only on the class, so resolve it once here rather than on every
        # instantiation. As a consequence, a template with unknown placeholders fails at decoration time.
        module_name = cls.__module__
        class_name = cls.__qualname__
        if tracer_name_template is not None:
            tracer_name = tracer_name_template.format(module=module_name, class_name=class_name)
        elif module_name == "transformers" or module_name.startswith("transformers."):
            tracer_name = f"{module_name}.{class_name}"
        else:
            tracer_name = f"transformers.{module_name}.{class_name}"

        @functools.wraps(original_init)
        def init_with_tracer(self, *args, **kwargs):
            original_init(self, *args, **kwargs)
            # Assigned after the original __init__ runs: if this class derives from another decorated
            # class, the parent's wrapper runs first (via original_init) and this assignment overrides it.
            self.tracer = get_tracer(tracer_name)

        cls.__init__ = init_with_tracer
        return cls

    return decorator


def traced(
    func=None,
    *,
    span_name=None,
    standalone=False,
    additional_attributes: Optional[list[tuple[str, str, Union[Any, Callable[[Any], Any]]]]] = None,
):
    """
    Decorator to trace function calls with OpenTelemetry.

    Can be used as @traced or @traced(span_name="custom_name")
    When OpenTelemetry is not installed the function is returned unchanged.

    Args:
        func: The function to trace
        span_name: Optional custom name for the span (defaults to function name)
        standalone: If True, creates a parentless (root) span that is not made the current span
        additional_attributes: Optional list of additional attributes to set on the span.
          Only applied to methods (functions whose first parameter is named `self`).
          Each item is a tuple of (instance_attribute_name, span_attribute_key, value_or_transform_function)
          where:
            - instance_attribute_name: Name of the attribute to get from the class instance
            - span_attribute_key: Key to use when setting the attribute on the span
            - value_or_transform_function: Either a raw value to use directly, or a function to transform
              the attribute value before setting it on the span

    Returns:
        Decorated function with tracing
    """

    def decorator(func):
        if not _has_opentelemetry:
            return func

        import inspect

        from opentelemetry.context import Context

        # Decide once whether `func` is an instance method. This must come from the signature:
        # at decoration time `func` is the plain function from the class body and has no `__self__`.
        try:
            parameter_names = list(inspect.signature(func).parameters)
        except (TypeError, ValueError):
            parameter_names = []
        takes_self = bool(parameter_names) and parameter_names[0] == "self"

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            instance = args[0] if takes_self and args else None
            is_method = instance is not None

            # Prefer the per-class tracer installed by `attach_tracer`, if any.
            if is_method and hasattr(instance, "tracer"):
                tracer = instance.tracer
            else:
                tracer = get_tracer(f"transformers.{func.__module__}.{func.__name__}")

            name = span_name or func.__name__
            # Exceptions are recorded once by the handler below, so the span context managers
            # must not record them (or set the status) a second time.
            span_options = {"record_exception": False, "set_status_on_exception": False}
            if standalone:
                # An empty context carries no active span, so the new span is a root span.
                span_cm = tracer.start_span(name, context=Context(), **span_options)
            else:
                span_cm = tracer.start_as_current_span(name, **span_options)

            with span_cm as span:
                span.set_attribute("function.name", func.__name__)
                span.set_attribute("function.module", func.__module__)
                span.set_attribute("function.is_method", is_method)

                # Primitive arguments are recorded by value; anything else only by its type.
                for i, arg in enumerate(args):
                    if isinstance(arg, (str, int, float, bool)) or arg is None:
                        span.set_attribute(f"args.{i}", str(arg))
                    else:
                        span.set_attribute(f"args.{i}", str(type(arg)))
                for key, value in kwargs.items():
                    if isinstance(value, (str, int, float, bool)) or value is None:
                        span.set_attribute(f"kwargs.{key}", str(value))
                    else:
                        span.set_attribute(f"kwargs.{key}", str(type(value)))

                if additional_attributes and is_method:
                    for instance_attribute_name, span_attribute_key, value_or_transform_function in additional_attributes:
                        if hasattr(instance, instance_attribute_name):
                            attribute_value = getattr(instance, instance_attribute_name)
                            if callable(value_or_transform_function):
                                transformed_value = value_or_transform_function(attribute_value)
                            else:
                                transformed_value = value_or_transform_function
                            span.set_attribute(span_attribute_key, transformed_value)

                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    span.set_status(Status(StatusCode.ERROR, f"{type(e).__name__}: {e}"))
                    span.record_exception(e)
                    raise

        return wrapper

    if func is None:
        return decorator
    return decorator(func)


# Module-level logger for metrics setup messages, debug output and recording-failure warnings.
logger = logging.getLogger(__name__)


@attach_tracer()
class ContinuousBatchProcessorMetrics:
    """Metrics collection for ContinuousBatchProcessor.

    When OpenTelemetry is not installed no instruments are created and every ``record_*`` method
    returns immediately. Failures while recording are logged as warnings rather than raised.
    """

    def __init__(self, max_batch_tokens: int):
        """Initialize metrics for continuous batch processor.

        Args:
            max_batch_tokens: Maximum number of tokens in a batch; used as the denominator of the
                batch fill percentage.
        """
        self.max_batch_tokens = max_batch_tokens

        self._setup_metrics()

    def _setup_metrics(self):
        """Initialize OpenTelemetry metrics and tracing if the library is available."""

        if not _has_opentelemetry:
            logger.info("OpenTelemetry is not installed. Metrics and tracing will not be recorded.")
            return

        self.meter = metrics.get_meter("transformers.generation.continuous_batch_processor")

        # Define appropriate buckets for TTFT (typically ranges from ~50ms to several seconds)
        ttft_buckets = [10, 25, 50, 75, 100, 150, 200, 300, 500, 750, 1000, 2000, 5000, 10000]

        self.ttft_histogram = self.meter.create_histogram(
            name="ttft_milliseconds",
            description="Time to first token in milliseconds",
            unit="ms",
            explicit_bucket_boundaries_advisory=ttft_buckets,
        )

        self.active_requests_gauge = self.meter.create_gauge(
            name="active_requests_count",
            description="Number of active requests currently being processed",
            unit="requests",
        )

        self.waiting_requests_gauge = self.meter.create_gauge(
            name="waiting_requests_count",
            description="Number of requests waiting to be processed",
            unit="requests",
        )

        # Define appropriate buckets for request latency (similar to TTFT but with higher upper bounds)
        latency_buckets = [50, 100, 250, 500, 1000, 2000, 5000, 10000, 20000, 30000, 60000]

        self.request_latency_histogram = self.meter.create_histogram(
            name="request_latency_milliseconds",
            description="End-to-end latency for completed requests in milliseconds",
            unit="ms",
            explicit_bucket_boundaries_advisory=latency_buckets,
        )

        self.decode_prefill_ratio_gauge = self.meter.create_gauge(
            name="decode_prefill_ratio",
            description="Ratio of decode tokens to prefill tokens in a batch",
            unit="ratio",
        )

        self.prefill_tokens_counter = self.meter.create_counter(
            name="prefill_tokens_processed",
            description="Number of prefill tokens processed",
            unit="tokens",
        )

        self.decode_tokens_counter = self.meter.create_counter(
            name="decode_tokens_processed",
            description="Number of decode tokens processed",
            unit="tokens",
        )

        # Define appropriate buckets for batch fill percentage (0-100%)
        batch_fill_buckets = [5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 95, 98, 100]

        self.batch_fill_percentage_histogram = self.meter.create_histogram(
            name="batch_fill_percentage",
            description="Percentage of max_batch_tokens utilized in each batch",
            unit="percent",
            explicit_bucket_boundaries_advisory=batch_fill_buckets,
        )

        self.kv_cache_free_memory_gauge = self.meter.create_gauge(
            name="kv_cache_free_memory_bytes",
            description="Free memory of the PagedAttentionCache in bytes",
            unit="bytes",
        )

        self.kv_cache_memory_gauge = self.meter.create_gauge(
            name="kv_cache_memory_bytes",
            description="Memory usage of the PagedAttentionCache in bytes",
            unit="bytes",
        )

    @traced
    def record_ttft_metric(self, created_time: float, request_id: str) -> None:
        """Record Time to First Token (TTFT).

        Args:
            created_time: The time the request was created. Elapsed time is measured against
                ``time.time()``, so this must be a timestamp from the same clock (seconds).
            request_id: The ID of the request (only used for logging)
        """
        if not _has_opentelemetry:
            return

        ttft_ms = (time.time() - created_time) * 1000.0

        try:
            self.ttft_histogram.record(ttft_ms)
            logger.debug("Recorded TTFT for request %s: %.2fms", request_id, ttft_ms)
        except Exception as e:
            logger.warning("Failed to record TTFT metric: %s", e)

    @traced
    def record_batch_metrics(self, requests_in_batch: list) -> None:
        """Record metrics about the batch composition including decode/prefill ratio and batch fill percentage.

        Decoding requests count as one token each; prefilling requests (plain or split) count as
        ``len(state.prompt_ids)`` tokens. Requests in any other status are ignored.

        Args:
            requests_in_batch: List of request states in the current batch
        """
        if not _has_opentelemetry or not requests_in_batch:
            return

        decode_tokens = 0
        prefill_tokens = 0

        for state in requests_in_batch:
            if state.status == RequestStatus.DECODING:
                decode_tokens += 1
            elif state.status in (RequestStatus.PREFILLING, RequestStatus.PREFILLING_SPLIT):
                prefill_tokens += len(state.prompt_ids)

        total_batch_tokens = decode_tokens + prefill_tokens

        try:
            if prefill_tokens > 0:
                self.prefill_tokens_counter.add(prefill_tokens)

            if decode_tokens > 0:
                self.decode_tokens_counter.add(decode_tokens)

            # The ratio is undefined without prefill tokens, so the gauge keeps its previous value.
            if prefill_tokens > 0:
                self.decode_prefill_ratio_gauge.set(decode_tokens / prefill_tokens)

            fill_percentage = (total_batch_tokens / self.max_batch_tokens) * 100.0
            self.batch_fill_percentage_histogram.record(fill_percentage)
            logger.debug(
                "Batch metrics: %s decode tokens, %s prefill tokens, batch fill: %.2f%% (%s/%s)",
                decode_tokens,
                prefill_tokens,
                fill_percentage,
                total_batch_tokens,
                self.max_batch_tokens,
            )
        except Exception as e:
            logger.warning("Failed to record batch metrics: %s", e)

    @traced
    def record_kv_cache_memory_metrics(self, cache) -> None:
        """Record memory usage of the PagedAttentionCache without GPU synchronization.

        This calculates the theoretical memory usage based on cache configuration
        and the number of blocks currently in use.

        Args:
            cache: The PagedAttentionCache object to measure. Reads ``num_blocks``, ``_free_blocks``,
                ``key_cache``, ``dtype``, ``block_size``, ``num_key_value_heads`` and ``head_dim``.
        """
        if not _has_opentelemetry:
            return

        try:
            num_free_blocks = len(cache._free_blocks)
            num_used_blocks = cache.num_blocks - num_free_blocks
            num_layers = len(cache.key_cache)

            # Real element size of the cache dtype (2 for float16/bfloat16, 4 for float32, 1 for
            # float8, 8 for float64). A zero-element CPU tensor allocates nothing on the device.
            bytes_per_element = torch.empty(0, dtype=cache.dtype).element_size()

            # Each block stores key and value states, each with shape (num_kv_heads, block_size, head_dim),
            # in every layer.
            bytes_per_block = (
                num_layers
                * cache.block_size
                * cache.num_key_value_heads
                * cache.head_dim
                * 2  # For both key and value caches
                * bytes_per_element
            )
            memory_bytes = num_used_blocks * bytes_per_block
            free_memory_bytes = num_free_blocks * bytes_per_block

            self.kv_cache_memory_gauge.set(memory_bytes)
            self.kv_cache_free_memory_gauge.set(free_memory_bytes)

            # Guard the percentage so an empty cache does not turn a successful recording into a warning.
            used_percentage = num_used_blocks / cache.num_blocks * 100 if cache.num_blocks else 0.0
            logger.debug(
                "KV Cache memory: %.2fMB, Used blocks: %s/%s (%.1f%%)",
                memory_bytes / (1024 * 1024),
                num_used_blocks,
                cache.num_blocks,
                used_percentage,
            )
        except Exception as e:
            logger.warning("Failed to record KV cache memory metrics: %s", e)

    @traced
    def record_queue_metrics(self, active_requests: int, waiting_requests: int) -> None:
        """Record metrics about active and waiting requests.

        Args:
            active_requests: Number of active requests
            waiting_requests: Number of waiting requests
        """
        if not _has_opentelemetry:
            return

        try:
            self.active_requests_gauge.set(active_requests)
            self.waiting_requests_gauge.set(waiting_requests)
            logger.debug("Queue metrics: %s active requests, %s waiting requests", active_requests, waiting_requests)
        except Exception as e:
            logger.warning("Failed to record queue metrics: %s", e)

    @traced
    def record_request_completion(self, created_time: float, request_id: str) -> None:
        """Record metrics about a completed request.

        Args:
            created_time: The time the request was created. Elapsed time is measured against
                ``time.time()``, so this must be a timestamp from the same clock (seconds).
            request_id: The ID of the request (only used for logging)
        """
        if not _has_opentelemetry:
            return

        latency_ms = (time.time() - created_time) * 1000.0

        try:
            self.request_latency_histogram.record(latency_ms)

            logger.debug("Recorded request completion for %s: %.2fms", request_id, latency_ms)
        except Exception as e:
            logger.warning("Failed to record request completion metric: %s", e)
