chromium/third_party/tflite_support/src/tensorflow_lite_support/python/task/vision/core/tensor_image.py

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorImage class."""

import numpy as np

from tensorflow_lite_support.python.task.vision.core import color_space_type
from tensorflow_lite_support.python.task.vision.core.pybinds import image_utils


class TensorImage(object):
  """Wrapper class for the Image object.

  Holds an `image_utils.ImageData` together with a flag that records whether
  its storage has to be released through `image_utils.image_data_free` when
  this object is destroyed (true for the instances built by `create_from_file`
  and `create_from_buffer`) or not (instances built by `create_from_array`).
  """

  def __init__(self,
               image_data: image_utils.ImageData,
               is_from_numpy_array: bool = True) -> None:
    """Initializes the `TensorImage` object.

    Args:
      image_data: image_utils.ImageData, contains raw image data, width, height
        and channels info.
      is_from_numpy_array: boolean, whether `image_data` is loaded from a numpy
        array. If False, `image_data` was loaded by the stbi_load** functions
        in C++ and its storage needs to be freed in the destructor.
    """
    self._image_data = image_data
    self._is_from_numpy_array = is_from_numpy_array

  @classmethod
  def create_from_file(cls, file_name: str) -> "TensorImage":
    """Creates `TensorImage` object from the image file.

    Args:
      file_name: Image file name.

    Returns:
      `TensorImage` object.

    Raises:
      RuntimeError if the image file can't be decoded.
    """
    image_data = image_utils.decode_image_from_file(file_name)
    # The decoder owns the pixel storage, so ask the destructor to free it.
    return cls(image_data, is_from_numpy_array=False)

  @classmethod
  def create_from_array(cls, array: np.ndarray) -> "TensorImage":
    """Creates `TensorImage` object from the numpy array.

    Args:
      array: numpy array with dtype=uint8. Its shape should be either (h, w, 3)
        or (1, h, w, 3) for RGB images, either (h, w) or (1, h, w) for GRAYSCALE
        images and either (h, w, 4) or (1, h, w, 4) for RGBA images.

    Returns:
      `TensorImage` object.

    Raises:
      ValueError if the dtype of the numpy array is not `uint8`. Arrays whose
        dimensions are not one of the valid shapes above are passed on to
        `image_utils.ImageData`, which is expected to reject them.
    """
    if array.dtype != np.uint8:
      raise ValueError("Expect numpy array with dtype=uint8.")

    if array.ndim == 4 and array.shape[0] == 1 and array.shape[3] in (3, 4):
      # Batched RGB/RGBA input: remove only the batch axis. A blanket
      # `np.squeeze` would also drop a height or width of 1 and turn e.g.
      # (1, 1, w, 3) into the 2-D shape (w, 3).
      pixels = array[0]
    else:
      # Every other layout keeps the historical behaviour of dropping all
      # size-1 axes.
      pixels = np.squeeze(array)

    # NOTE(review): `pixels` is a view of `array`; confirm that the ImageData
    # binding keeps the underlying buffer alive for as long as it is used.
    image_data = image_utils.ImageData(pixels)
    return cls(image_data)

  @classmethod
  def create_from_buffer(cls, buffer: bytes) -> "TensorImage":
    """Creates `TensorImage` object from the binary buffer.

    Args:
      buffer: Binary memory buffer holding the encoded image. Its length, as
        reported by `len()`, is handed to the decoder as the buffer size.

    Returns:
      `TensorImage` object.

    Raises:
      RuntimeError if the binary buffer can't be decoded into `TensorImage`
        object.
    """
    image_data = image_utils.decode_image_from_buffer(buffer, len(buffer))
    # The decoder owns the pixel storage, so ask the destructor to free it.
    return cls(image_data, is_from_numpy_array=False)

  def __del__(self) -> None:
    """Destructor to free the storage of ImageData if it was decoded natively.

    Only instances created with `is_from_numpy_array=False` (i.e. through
    `create_from_file` or `create_from_buffer`) release their storage here.
    """
    # `__del__` also runs for instances whose `__init__` never completed (for
    # example when the constructor was called with missing arguments). Such
    # instances carry no attributes and own nothing, so treat them like the
    # numpy-backed case instead of raising AttributeError.
    if getattr(self, "_is_from_numpy_array", True):
      return

    # __del__ can be executed during interpreter shutdown, therefore
    # image_utils may not be available.
    # See https://docs.python.org/3/reference/datamodel.html#object.__del__
    if image_utils:
      image_utils.image_data_free(self._image_data)

  @property
  def buffer(self) -> np.ndarray:
    """Gets the numpy array that represents `self.image_data`.

    Returns:
      Numpy array that represents `self.image_data` which is an
        `image_utils.ImageData` object. To avoid a copy, the array is created
        with `np.array(..., copy=False)`. Therefore, this `TensorImage` object
        should outlive the returned numpy array.
    """
    return np.array(self._image_data, copy=False)

  @property
  def height(self) -> int:
    """Gets the height of the image."""
    return self._image_data.height

  @property
  def width(self) -> int:
    """Gets the width of the image."""
    return self._image_data.width

  @property
  def color_space_type(self) -> color_space_type.ColorSpaceType:
    """Gets the color space type of the image.

    The type is derived from the channel count of the wrapped image data:
    1 channel is GRAYSCALE, 3 channels is RGB and 4 channels is RGBA.

    Raises:
      ValueError if the channel count is none of 1, 3 or 4.
    """
    channels = self._image_data.channels
    if channels == 1:
      return color_space_type.ColorSpaceType.GRAYSCALE
    if channels == 3:
      return color_space_type.ColorSpaceType.RGB
    if channels == 4:
      return color_space_type.ColorSpaceType.RGBA
    raise ValueError("Unsupported color space type.")