# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import math
from typing import Optional, Union

import torch

from ...image_processing_utils_fast import (
    BaseImageProcessorFast,
    BatchFeature,
    DefaultFastImageProcessorKwargs,
    SizeDict,
    group_images_by_shape,
    reorder_images,
)
from ...image_utils import (
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    ImageInput,
    PILImageResampling,
    make_nested_list_of_images,
)
from ...processing_utils import Unpack
from ...utils import TensorType, auto_docstring, is_torchvision_available, logging


if is_torchvision_available():
    from torchvision.transforms import functional as F


logger = logging.get_logger(__name__)

MAX_IMAGE_SIZE = 4096  # 4k resolution as absolute maximum


def _resize_output_size_rescale_to_max_len(
    height: int, width: int, min_len: Optional[int] = 1, max_len: Optional[int] = None
) -> tuple[int, int]:
    """
    Get the output size of the image after rescaling its longest edge to `max_len` while preserving the aspect ratio.
    Args:
        height (`int`):
            Height of the input image.
        width (`int`):
            Width of the input image.
        min_len (`int`, *optional*, defaults to 1):
            Minimum size of the output image.
        max_len (`int`, *optional*, defaults to the maximum size of the image):
            Maximum size of the output image.
    Returns:
        The output size of the image after resizing.
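    Example:
        For `height=300`, `width=600` and `max_len=364`, the width becomes 364 and the height becomes
        `int(364 / 2.0) = 182` (already even), so the function returns `(182, 364)`.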
    """
    max_len = max(height, width) if max_len is None else max_len
    aspect_ratio = width / height

    if width >= height:
        width = max_len
        height = int(width / aspect_ratio)
        if height % 2 != 0:
            height += 1
    elif height > width:
        height = max_len
        width = int(height * aspect_ratio)
        if width % 2 != 0:
            width += 1

    # Avoid resizing to a size smaller than min_len
    height = max(height, min_len)
    width = max(width, min_len)
    return height, width


def _resize_output_size_scale_below_upper_bound(
    height: int, width: int, max_len: Optional[int] = None
) -> tuple[int, int]:
    """
    Get the output size of the image after capping its longest edge at `max_len` while preserving the aspect ratio.
    Args:
        height (`int`):
            Height of the input image.
        width (`int`):
            Width of the input image.
        max_len (`int`, *optional*, defaults to the maximum size of the image):
            Maximum size of the longest edge of the output image.
    Returns:
        The output size of the image after resizing.
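    Example:
        For `height=5000`, `width=2500` and `max_len=4096`, the height is capped at 4096 and the width becomes
        `int(4096 * 0.5) = 2048`, so the function returns `(4096, 2048)`.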
    """
    max_len = max(height, width) if max_len is None else max_len

    aspect_ratio = width / height
    if width >= height and width > max_len:
        width = max_len
        height = int(width / aspect_ratio)
    elif height > width and height > max_len:
        height = max_len
        width = int(height * aspect_ratio)

    # Avoid resizing to a size smaller than 1
    height = max(height, 1)
    width = max(width, 1)
    return height, width


def get_resize_output_image_size(
    image,
    resolution_max_side: int,
) -> tuple[int, int]:
    """
    Get the output size of the image after resizing its longest edge to `resolution_max_side`, with both dimensions capped at `MAX_IMAGE_SIZE`.
    Args:
        image (`torch.Tensor`):
            Image to resize.
        resolution_max_side (`int`):
            The longest edge of the image will be resized to this value. The shortest edge will be resized to keep the
            input aspect ratio.
    Returns:
        The output size of the image after resizing.
    """
    height, width = image.size()[-2:]

    # Find the output size, when rescaling the longest edge to max_len and preserving the aspect ratio
    height, width = _resize_output_size_rescale_to_max_len(height, width, max_len=resolution_max_side)
    # Find the output size when scaling the image to be below the MAX_IMAGE_SIZE
    height, width = _resize_output_size_scale_below_upper_bound(height, width, max_len=MAX_IMAGE_SIZE)
    return height, width


def get_max_height_width(images_list: list[list["torch.Tensor"]]) -> tuple[int, int]:
    """
    Get the maximum height and width across all images in a batch.
    """
    image_sizes = []
    for images in images_list:
        for image in images:
            image_sizes.append(image.size()[-2:])

    max_height = max(size[0] for size in image_sizes)
    max_width = max(size[1] for size in image_sizes)
    return (max_height, max_width)


def make_pixel_mask(image: "torch.Tensor", output_size: tuple[int, int]) -> "torch.Tensor":
    """
    Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.

    Args:
        image (`torch.Tensor`):
            Image to make the pixel mask for.
        output_size (`Tuple[int, int]`):
            Output size of the mask.
    """
    input_height, input_width = image.size()[-2:]
    mask = torch.zeros(output_size, dtype=torch.int64, device=image.device)
    mask[:input_height, :input_width] = 1
    return mask


class Idefics3FastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
    """
    do_pad (`bool`, *optional*):
        Whether to pad the images to the largest height, width and number of images in the batch, so that the
        returned pixel values have shape (batch_size, max_num_images, num_channels, max_height, max_width).
        Padding is applied to the bottom and right with zeros.
    do_image_splitting (`bool`, *optional*, defaults to `True`):
        Whether to split the image into sub-images concatenated with the original image. They are split into patches
        such that each patch has a size of `max_image_size["height"]` x `max_image_size["width"]`.
    max_image_size (`Dict`, *optional*, defaults to `{"longest_edge": 364}`):
        Maximum resolution of the patches of images accepted by the model. This is a dictionary containing the key "longest_edge".
    return_row_col_info (`bool`, *optional*, defaults to `False`):
        Whether to return the row and column information of the images.
    """

    do_pad: Optional[bool]
    do_image_splitting: Optional[bool]
    max_image_size: Optional[dict[str, int]]
    return_row_col_info: Optional[bool]


@auto_docstring
class Idefics3ImageProcessorFast(BaseImageProcessorFast):
    resample = PILImageResampling.LANCZOS
    image_mean = IMAGENET_STANDARD_MEAN
    image_std = IMAGENET_STANDARD_STD
    size = {"longest_edge": 4 * 364}
    max_image_size = {"longest_edge": 364}
    do_resize = True
    do_rescale = True
    do_normalize = True
    do_convert_rgb = True
    do_image_splitting = True
    do_pad = True
    return_row_col_info = False
    valid_kwargs = Idefics3FastImageProcessorKwargs

    def _prepare_images_structure(self, images: ImageInput, expected_ndims: int = 3) -> ImageInput:
        """
        Prepare a nested images structure for processing.
        """
        return make_nested_list_of_images(images, expected_ndims=expected_ndims)

    def resize(
        self,
        image: "torch.Tensor",
        size: SizeDict,
        interpolation: "F.InterpolationMode" = None,
        antialias: bool = True,
        **kwargs,
    ) -> "torch.Tensor":
        """
        Resize an image. The longest edge of the image is resized to size.longest_edge, with the shortest edge
        resized to keep the input aspect ratio. Can also be used with size.height and size.width.
        Args:
            image (`torch.Tensor`):
                Image to resize.
            size (`Dict[str, int]`):
                Size of the output image.
            interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
                `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
            antialias (`bool`, *optional*, defaults to `True`):
                Whether to use antialiasing when resizing the image.
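        Example:
            For a `3x300x600` input tensor with `size=SizeDict(longest_edge=364)`, the longest edge is scaled
            to 364 while preserving the aspect ratio, giving an output of shape `3x182x364`.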
        """
        interpolation = interpolation if interpolation is not None else F.InterpolationMode.BILINEAR
        if interpolation == F.InterpolationMode.LANCZOS:
            logger.warning_once(
                "You have used fast image processor with LANCZOS resample which not yet supported for torch.Tensor. "
                "BICUBIC resample will be used as an alternative. Please fall back to slow image processor if you "
                "want full consistency with the original model."
            )
            interpolation = F.InterpolationMode.BICUBIC

        if size.longest_edge:
            size = get_resize_output_image_size(image, resolution_max_side=size.longest_edge)
        elif size.height and size.width:
            size = (size.height, size.width)
        else:
            raise ValueError("size must be a dictionary with key 'longest_edge' or 'height' and 'width'.")

        return F.resize(image, size, interpolation=interpolation, antialias=antialias)

    def split_images(
        self,
        images: torch.Tensor,
        max_image_size: dict[str, int],
        interpolation: "F.InterpolationMode" = None,
    ):
        """
        Split an image into squares of side `max_image_size` and append the original image resized to `max_image_size`.
        A single image thus becomes a sequence of images.
        This is a "trick" to spend more compute on each image with no changes to the vision encoder.
        1) If one side of the image is larger than `max_image_size`, divide the image into
        `ceil(height / max_image_size)` x `ceil(width / max_image_size)` sub-images, each of size
        (`max_image_size`, `max_image_size`); typically 364x364.
        2) Resize the original image to (`max_image_size`, `max_image_size`) and append it as a global view.
        3) Return the list of crops and the global image, along with the number of splits along the height and the width.
        Args:
            images (`torch.Tensor`):
                Images to split.
            max_image_size (`Dict[str, int]`):
                Maximum size of the output patches. If the image is larger than this size, it is split into
                patches of this size, and the original image, resized to this size, is appended to the patches.
            interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
                `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
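        Example:
            For a batch of `3x728x1092` images with `max_image_size={"longest_edge": 364}`, each image is split
            into `2 x 3 = 6` crops of 364x364 plus the original image resized to 364x364, i.e. 7 frames per
            image, and the function returns 2 splits along the height and 3 along the width.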
        """
        batch_size, num_channels, height, width = images.size()
        height_dim, width_dim = 2, 3

        max_height = max_width = max_image_size["longest_edge"]

        frames = []
        if height > max_height or width > max_width:
            # Calculate the number of splits
            num_splits_h = math.ceil(height / max_height)
            num_splits_w = math.ceil(width / max_width)

            # Split the images by height, then by width
            frames = (
                images.unfold(height_dim, size=max_height, step=max_height)
                .unfold(width_dim, size=max_width, step=max_width)
                .contiguous()
                .view(batch_size, num_channels, -1, max_height, max_width)
                .permute(0, 2, 1, 3, 4)
            )  # batch_size x n_frames x num_channels x height x width

            # For the global image at the end, we resize it to match max_image_size, for CPU memory efficiency
            global_image_height, global_image_width = max_height, max_width
            images = self.resize(
                images, SizeDict(height=global_image_height, width=global_image_width), interpolation=interpolation
            )

            frames = torch.cat((frames, images.unsqueeze(1)), dim=1)
        else:
            num_splits_h, num_splits_w = 0, 0
            frames = images.unsqueeze(1)

        num_splits_h = [num_splits_h] * batch_size
        num_splits_w = [num_splits_w] * batch_size

        return frames, num_splits_h, num_splits_w

    def resize_for_vision_encoder(
        self,
        image: torch.Tensor,
        vision_encoder_max_size: int,
        interpolation: "F.InterpolationMode" = None,
    ):
        """
        Resize images so that both dimensions are multiples of `vision_encoder_max_size`, approximately preserving the aspect ratio.
        Args:
            image (`torch.Tensor`):
                Images to resize.
            vision_encoder_max_size (`int`):
                Value the output dimensions are rounded up to a multiple of, so that the image can later be
                split into patches of exactly this size.
            interpolation (`InterpolationMode`, *optional*, defaults to `InterpolationMode.BILINEAR`):
                `InterpolationMode` filter to use when resizing the image e.g. `InterpolationMode.BICUBIC`.
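        Example:
            For a `500x800` image with `vision_encoder_max_size=364`, the width is rounded up to
            `ceil(800 / 364) * 364 = 1092`, the height becomes `int(1092 / 1.6) = 682` and is rounded up to
            `ceil(682 / 364) * 364 = 728`, so the image is resized to `728x1092`.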
        """
        height, width = image.size()[-2:]

        aspect_ratio = width / height
        if width >= height:
            width = math.ceil(width / vision_encoder_max_size) * vision_encoder_max_size
            height = int(width / aspect_ratio)
            height = math.ceil(height / vision_encoder_max_size) * vision_encoder_max_size
        elif height > width:
            height = math.ceil(height / vision_encoder_max_size) * vision_encoder_max_size
            width = int(height * aspect_ratio)
            width = math.ceil(width / vision_encoder_max_size) * vision_encoder_max_size
        new_size = SizeDict(height=height, width=width)
        return self.resize(image, size=new_size, interpolation=interpolation)

    def pad(
        self,
        image: torch.Tensor,
        padded_size: tuple[int, int],
        fill: int = 0,
        return_pixel_mask: bool = True,
    ):
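        """
        Pad an image on the bottom and the right to `padded_size` with the `fill` value and optionally build the
        corresponding pixel mask, where 1 indicates a valid pixel and 0 indicates padding.

        Args:
            image (`torch.Tensor`):
                Image to pad.
            padded_size (`Tuple[int, int]`):
                Target (height, width) of the padded image; must not be smaller than the input size.
            fill (`int`, *optional*, defaults to 0):
                Value used for the padded pixels.
            return_pixel_mask (`bool`, *optional*, defaults to `True`):
                Whether to also return the pixel mask.
        """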
        original_size = image.shape[-2:]
        padding_bottom = padded_size[0] - original_size[0]
        padding_right = padded_size[1] - original_size[1]

        if padding_bottom < 0 or padding_right < 0:
            raise ValueError(
                f"Padding dimensions are negative. Please make sure that the padded size is larger than the "
                f"original size. Got padded size: {padded_size}, original size: {original_size}."
            )

        # Only pad if necessary
        if original_size != padded_size:
            padding = (0, 0, padding_right, padding_bottom)
            image = F.pad(image, padding, fill=fill, padding_mode="constant")

        # Make a pixel mask for the image, where 1 indicates a valid pixel and 0 indicates padding.
        pixel_mask = None
        if return_pixel_mask:
            pixel_mask = torch.zeros_like(image[..., 0, :, :], dtype=torch.int64)
            pixel_mask[: original_size[0], : original_size[1]] = 1

        return image, pixel_mask

    @auto_docstring
    def preprocess(self, images: ImageInput, **kwargs: Unpack[Idefics3FastImageProcessorKwargs]) -> BatchFeature:
        return super().preprocess(images, **kwargs)

    def _preprocess(
        self,
        images: list[list["torch.Tensor"]],
        do_resize: bool,
        size: SizeDict,
        interpolation: Optional["F.InterpolationMode"],
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Optional[Union[float, list[float]]],
        image_std: Optional[Union[float, list[float]]],
        do_pad: Optional[bool],
        do_image_splitting: Optional[bool],
        max_image_size: Optional[dict[str, int]],
        return_row_col_info: Optional[bool],
        disable_grouping: Optional[bool],
        return_tensors: Optional[Union[str, TensorType]],
        **kwargs,
    ) -> BatchFeature:
        """
        Process a batch of images for the model.
        """

        grouped_images, grouped_images_index = group_images_by_shape(
            images, is_nested=True, disable_grouping=disable_grouping
        )
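        # Group images by shape so that same-sized images can be resized in a single batched call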
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_resize:
                stacked_images = self.resize(stacked_images, size, interpolation=interpolation)
            resized_images_grouped[shape] = stacked_images
        resized_images = reorder_images(resized_images_grouped, grouped_images_index, is_nested=True)

        grouped_images, grouped_images_index = group_images_by_shape(
            resized_images, is_nested=True, disable_grouping=disable_grouping
        )
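        # Either split each resized image into fixed-size patches plus a resized global view, or resize it to a square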
        split_images_grouped = {}
        if do_image_splitting:
            rows_grouped = {}
            cols_grouped = {}
            for shape, stacked_images in grouped_images.items():
                stacked_images = self.resize_for_vision_encoder(
                    stacked_images, max_image_size["longest_edge"], interpolation=interpolation
                )
                stacked_images, rows, cols = self.split_images(
                    stacked_images, max_image_size=max_image_size, interpolation=interpolation
                )
                split_images_grouped[shape] = stacked_images
                rows_grouped[shape] = rows
                cols_grouped[shape] = cols
            processed_images = reorder_images(split_images_grouped, grouped_images_index, is_nested=True)
            rows = reorder_images(rows_grouped, grouped_images_index, is_nested=True)
            cols = reorder_images(cols_grouped, grouped_images_index, is_nested=True)
            # flatten the doubly nested list into a nested list
            for i, group_images in enumerate(processed_images):
                processed_images[i] = [image for sublist in group_images for image in sublist]
        else:
            for shape, stacked_images in grouped_images.items():
                # Resize the images to a max_image_size x max_image_size square
                stacked_images = self.resize(
                    image=stacked_images,
                    size=SizeDict(height=max_image_size["longest_edge"], width=max_image_size["longest_edge"]),
                    interpolation=interpolation,
                )
                split_images_grouped[shape] = stacked_images
            processed_images = reorder_images(split_images_grouped, grouped_images_index, is_nested=True)
            rows = [[0] * len(images) for images in processed_images]
            cols = [[0] * len(images) for images in processed_images]
        # Group images by size for further processing
        # Needed in case do_resize is False, or resize returns images with different sizes
        grouped_images, grouped_images_index = group_images_by_shape(
            processed_images, is_nested=True, disable_grouping=disable_grouping
        )
        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            # Fused rescale and normalize
            stacked_images = self.rescale_and_normalize(
                stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
            )
            processed_images_grouped[shape] = stacked_images
        processed_images = reorder_images(processed_images_grouped, grouped_images_index, is_nested=True)
        if do_pad:
            # Get max images per batch
            max_num_images = max(len(images_) for images_ in processed_images)
            max_height, max_width = get_max_height_width(processed_images)

            processed_images_padded = torch.zeros(
                len(processed_images),
                max_num_images,
                *(processed_images[0][0].shape[0], max_height, max_width),
                device=processed_images[0][0].device,
            )
            pixel_attention_masks = torch.zeros(
                len(processed_images),
                max_num_images,
                *(max_height, max_width),
                device=processed_images[0][0].device,
            )
            for i, images in enumerate(processed_images):
                for j, image in enumerate(images):
                    processed_images_padded[i, j], pixel_attention_masks[i, j] = self.pad(
                        image, (max_height, max_width)
                    )
            processed_images = processed_images_padded

        if do_pad:
            data = {"pixel_values": processed_images, "pixel_attention_mask": pixel_attention_masks}
        elif return_tensors == "pt":
            data = {"pixel_values": torch.stack([torch.stack(images) for images in processed_images])}
        else:
            data = {"pixel_values": processed_images}
        # This is needed for generating correct text inputs in the processor - we don't pad to the max number of images
        encoding = BatchFeature(data=data, tensor_type=return_tensors)

        if return_row_col_info:
            encoding["rows"] = rows
            encoding["cols"] = cols

        return encoding

    def to_dict(self):
        encoder_dict = super().to_dict()
        encoder_dict.pop("_valid_processor_keys", None)
        encoder_dict.pop("return_row_col_info", None)
        return encoder_dict

    def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
        """
        A utility that returns number of image patches for a given image size.

        Args:
            height (`int`):
                Height of the input image.
            width (`int`):
                Width of the input image.
            images_kwargs (`dict`, *optional*):
                Any kwargs to override defaults of the image processor.
        Returns:
            `tuple[int, int, int]`: Number of patches per image, number of rows and number of columns of the split.
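        Example:
            With the class defaults (`size={"longest_edge": 1456}`, `max_image_size={"longest_edge": 364}`),
            a 1000x2000 image is rescaled to 728x1456, then rounded to multiples of 364 (728x1456), giving
            `num_rows=2`, `num_cols=4` and `num_patches=2 * 4 + 1 = 9` (the extra patch is the global image).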
        """
        images_kwargs = images_kwargs if images_kwargs is not None else {}
        do_image_splitting = images_kwargs.get("do_image_splitting", self.do_image_splitting)
        max_image_size = images_kwargs.get("max_image_size", self.max_image_size)
        size = images_kwargs.get("size", self.size)

        num_patches = num_rows = num_cols = 1
        if do_image_splitting:
            height, width = _resize_output_size_rescale_to_max_len(height, width, max_len=size["longest_edge"])
            height, width = _resize_output_size_scale_below_upper_bound(height, width, max_len=MAX_IMAGE_SIZE)
            aspect_ratio = width / height

            if width >= height:
                resized_width = math.ceil(width / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
                resized_height = int(resized_width / aspect_ratio)
                resized_height = math.ceil(resized_height / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
            elif height > width:
                resized_height = math.ceil(height / max_image_size["longest_edge"]) * max_image_size["longest_edge"]
                resized_width = int(resized_height * aspect_ratio)
                resized_width = math.ceil(resized_width / max_image_size["longest_edge"]) * max_image_size["longest_edge"]

            max_height = max_width = max_image_size["longest_edge"]
            if resized_height > max_height or resized_width > max_width:
                # Calculate the number of splits
                num_rows = math.ceil(resized_height / max_height)
                num_cols = math.ceil(resized_width / max_width)
                num_patches = num_rows * num_cols + 1

        return num_patches, num_rows, num_cols


__all__ = ["Idefics3ImageProcessorFast"]
