# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for ViTMatte."""

from typing import Optional, Union

from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import (
    BaseImageProcessorFast,
    DefaultFastImageProcessorKwargs,
    group_images_by_shape,
    reorder_images,
)
from ...image_utils import (
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    ChannelDimension,
    ImageInput,
    get_image_size,
)
from ...processing_utils import Unpack
from ...utils import (
    TensorType,
    auto_docstring,
    filter_out_non_signature_kwargs,
    is_torch_available,
    is_torchvision_available,
    is_torchvision_v2_available,
    logging,
)


if is_torch_available():
    import torch

if is_torchvision_available():
    if is_torchvision_v2_available():
        from torchvision.transforms.v2 import functional as F
    else:
        from torchvision.transforms import functional as F


# Module-level logger, named after this module through `__name__`.
logger = logging.get_logger(__name__)


class VitMatteFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
    """
    do_pad (`bool`, *optional*, defaults to `True`):
        Whether to pad the image to make the width and height divisible by `size_divisibility`. Can be overridden
        by the `do_pad` parameter in the `preprocess` method.
    size_divisibility (`int`, *optional*, defaults to 32):
        The width and height of the image will be padded to be divisible by this number.
    """

    # ViTMatte-specific keyword arguments, declared on top of the ones inherited from
    # `DefaultFastImageProcessorKwargs`. This class is registered as `valid_kwargs` on the fast image processor
    # and is used to type the `**kwargs` of its `__init__` and `preprocess`.
    # NOTE(review): the docstring above uses a per-argument entry format and is presumably consumed by
    # `auto_docstring` — confirm before rewording or reformatting it.
    do_pad: Optional[bool]
    size_divisibility: int


@auto_docstring
class VitMatteImageProcessorFast(BaseImageProcessorFast):
    # Class-level defaults for the preprocessing options handled by `_preprocess`: rescaling, normalization with
    # the ImageNet standard statistics, and padding to a multiple of `size_divisibility`.
    # NOTE(review): no class docstring is added on purpose; the class is decorated with `auto_docstring`, which
    # presumably generates it — confirm before adding one by hand.
    do_rescale: bool = True
    rescale_factor: Union[int, float] = 1 / 255
    do_normalize: bool = True
    image_mean: Optional[Union[float, list[float]]] = IMAGENET_STANDARD_MEAN
    image_std: Optional[Union[float, list[float]]] = IMAGENET_STANDARD_STD
    do_pad: bool = True
    size_divisibility: int = 32
    valid_kwargs = VitMatteFastImageProcessorKwargs

    # Thin constructor: it exists only to type the accepted keyword arguments; all handling is delegated to the
    # base class.
    def __init__(self, **kwargs: Unpack[VitMatteFastImageProcessorKwargs]) -> None:
        super().__init__(**kwargs)

    def _pad_image(
        self,
        images: "torch.Tensor",
        size_divisibility: int = 32,
    ) -> "torch.Tensor":
        """
        Pads an image or a batch of images with a constant value, on the right and bottom edges only, so that
        height and width are divisible by `size_divisibility`.

        Args:
            images (`torch.Tensor`):
                Image or batch of images to pad, in channels-first layout.
            size_divisibility (`int`, *optional*, defaults to 32):
                The width and height of the image will be padded to be divisible by this number.

        Returns:
            `torch.Tensor`: The padded images, or the input itself when both dimensions are already divisible by
            `size_divisibility`.
        """
        height, width = get_image_size(images, channel_dim=ChannelDimension.FIRST)

        # `-x % d` is the distance from `x` up to the next multiple of `d` (0 when `x` is already a multiple).
        pad_height = -height % size_divisibility
        pad_width = -width % size_divisibility

        if pad_width + pad_height > 0:
            # torchvision expects (left, top, right, bottom): the top-left corner of the content stays in place.
            padding = (0, 0, pad_width, pad_height)
            images = F.pad(images, padding)

        return images

    @auto_docstring
    def preprocess(
        self,
        images: list["torch.Tensor"],
        trimaps: list["torch.Tensor"],
        **kwargs: Unpack[VitMatteFastImageProcessorKwargs],
    ) -> BatchFeature:
        r"""
        trimaps (`list[torch.Tensor]`):
            The trimaps to preprocess.
        """
        return super().preprocess(images, trimaps, **kwargs)

    def _preprocess_image_like_inputs(
        self,
        images: ImageInput,
        trimaps: ImageInput,
        do_convert_rgb: bool,
        input_data_format: ChannelDimension,
        device: Optional[Union[str, "torch.device"]] = None,
        **kwargs: Unpack[VitMatteFastImageProcessorKwargs],
    ) -> BatchFeature:
        """
        Prepare the images and the trimaps, then run them through `_preprocess`.

        Args:
            images (`ImageInput`):
                The images to preprocess.
            trimaps (`ImageInput`):
                The trimaps to preprocess. They are prepared with `expected_ndims=2` and without RGB conversion.
            do_convert_rgb (`bool`):
                Forwarded to the preparation of `images` only.
            input_data_format (`ChannelDimension`):
                Forwarded to the preparation of `images` only.
            device (`str` or `torch.device`, *optional*):
                Forwarded to the preparation of both `images` and `trimaps`.
            kwargs:
                Remaining preprocessing options, forwarded unchanged to `_preprocess`.

        Returns:
            `BatchFeature`: The output of `_preprocess`.
        """
        images = self._prepare_image_like_inputs(
            images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
        )
        trimaps = self._prepare_image_like_inputs(images=trimaps, expected_ndims=2, device=device)

        return self._preprocess(images, trimaps, **kwargs)

    @filter_out_non_signature_kwargs()
    def _preprocess(
        self,
        images: list["torch.Tensor"],
        trimaps: list["torch.Tensor"],
        do_rescale: Optional[bool] = None,
        rescale_factor: Optional[float] = None,
        do_normalize: Optional[bool] = None,
        image_mean: Optional[Union[float, list[float]]] = None,
        image_std: Optional[Union[float, list[float]]] = None,
        do_pad: Optional[bool] = None,
        size_divisibility: Optional[int] = None,
        disable_grouping: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchFeature:
        """
        Rescale/normalize the images, rescale the trimaps, concatenate each trimap to its image and optionally pad
        the result.

        Args:
            images (`list[torch.Tensor]`):
                The prepared images.
            trimaps (`list[torch.Tensor]`):
                The prepared trimaps. The trimap group is looked up with the key of the image group, so every
                trimap is expected to fall into the same group as its image.
            do_rescale (`bool`, *optional*):
                Whether to rescale the images and the trimaps by `rescale_factor`.
            rescale_factor (`float`, *optional*):
                The rescaling factor.
            do_normalize (`bool`, *optional*):
                Whether to normalize the images. Trimaps are never normalized.
            image_mean (`float` or `list[float]`, *optional*):
                Mean used for normalization.
            image_std (`float` or `list[float]`, *optional*):
                Standard deviation used for normalization.
            do_pad (`bool`, *optional*):
                Whether to pad the concatenated tensors with `_pad_image`.
            size_divisibility (`int`, *optional*):
                The width and height are padded to be divisible by this number. Falls back to
                `self.size_divisibility` when `None`.
            disable_grouping (`bool`, *optional*):
                Forwarded to `group_images_by_shape`.
            return_tensors (`str` or `TensorType`, *optional*):
                When set, the processed images are stacked into a single tensor and the value is also passed to
                `BatchFeature` as `tensor_type`.

        Returns:
            `BatchFeature`: A batch with a single `"pixel_values"` entry.
        """
        # Honor a per-call `size_divisibility`; previously the instance attribute was always used, so the
        # argument was silently ignored.
        if size_divisibility is None:
            size_divisibility = self.size_divisibility

        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        # Only the image index is needed to restore the input order; trimaps are consumed group by group.
        grouped_trimaps, _ = group_images_by_shape(trimaps, disable_grouping=disable_grouping)

        processed_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            stacked_trimaps = grouped_trimaps[shape]
            # Fused rescale and normalize
            stacked_images = self.rescale_and_normalize(
                stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
            )
            # Trimaps share the rescaling but are never normalized (`do_normalize` is forced to `False`).
            stacked_trimaps = self.rescale_and_normalize(
                stacked_trimaps, do_rescale, rescale_factor, False, image_mean, image_std
            )
            # Append the trimap to its image along dim 1 (presumably the channel axis of the stacked group).
            stacked_images = torch.cat([stacked_images, stacked_trimaps], dim=1)
            if do_pad:
                stacked_images = self._pad_image(stacked_images, size_divisibility)
            processed_images_grouped[shape] = stacked_images

        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

        return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)


# Names exported by `from <this module> import *`: only the fast image processor class.
__all__ = ["VitMatteImageProcessorFast"]
