# chromium/third_party/tflite_support/src/tensorflow_lite_support/python/test/task/vision/image_segmenter_test.py

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for image_segmenter."""

import enum

from absl.testing import parameterized
import numpy as np
import tensorflow as tf

from tensorflow_lite_support.python.task.core import base_options as base_options_module
from tensorflow_lite_support.python.task.processor.proto import segmentation_options_pb2
from tensorflow_lite_support.python.task.processor.proto import segmentations_pb2
from tensorflow_lite_support.python.task.vision import image_segmenter
from tensorflow_lite_support.python.task.vision.core import tensor_image
from tensorflow_lite_support.python.test import test_util

# Short aliases for the task-library types exercised by these tests.
_BaseOptions = base_options_module.BaseOptions
_ColoredLabel = segmentations_pb2.ColoredLabel
_OutputType = segmentation_options_pb2.OutputType
_ImageSegmenter = image_segmenter.ImageSegmenter
_ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions

# Test data file names; resolved to full paths with
# `test_util.get_test_data_path` in `ImageSegmenterTest.setUp`.
_MODEL_FILE = 'deeplabv3.tflite'
_IMAGE_FILE = 'segmentation_input_rotation0.jpg'
_SEGMENTATION_FILE = 'segmentation_golden_rotation0.png'

# Colored labels the segmenter is expected to report, in output order. Every
# entry has an empty display name; only the color and category name vary.
_EXPECTED_COLORED_LABELS = [
    _ColoredLabel(color=rgb, category_name=category, display_name='')
    for rgb, category in (
        ((0, 0, 0), 'background'),
        ((128, 0, 0), 'aeroplane'),
        ((0, 128, 0), 'bicycle'),
        ((128, 128, 0), 'bird'),
        ((0, 0, 128), 'boat'),
        ((128, 0, 128), 'bottle'),
        ((0, 128, 128), 'bus'),
        ((128, 128, 128), 'car'),
        ((64, 0, 0), 'cat'),
        ((192, 0, 0), 'chair'),
        ((64, 128, 0), 'cow'),
        ((192, 128, 0), 'dining table'),
        ((64, 0, 128), 'dog'),
        ((192, 0, 128), 'horse'),
        ((64, 128, 128), 'motorbike'),
        ((192, 128, 128), 'person'),
        ((0, 64, 0), 'potted plant'),
        ((128, 64, 0), 'sheep'),
        ((0, 192, 0), 'sofa'),
        ((128, 192, 0), 'train'),
        ((0, 64, 128), 'tv'),
    )
]

# Predicted category indices are multiplied by this factor before being
# compared with the pixels of the golden segmentation image.
_MASK_MAGNIFICATION_FACTOR = 10
# Largest tolerated fraction of pixels that may differ from the golden image.
_MATCH_PIXELS_THRESHOLD = 0.01


def _create_segmenter_from_options(base_options, **segmentation_options):
  """Builds an image segmenter from base options and segmentation kwargs.

  Args:
    base_options: Passed as `base_options` to `ImageSegmenterOptions`.
    **segmentation_options: Keyword arguments forwarded unchanged to the
      `SegmentationOptions` constructor (e.g. `output_type`).

  Returns:
    The object returned by `ImageSegmenter.create_from_options`.
  """
  return _ImageSegmenter.create_from_options(
      _ImageSegmenterOptions(
          base_options=base_options,
          segmentation_options=segmentation_options_pb2.SegmentationOptions(
              **segmentation_options)))


class ModelFileType(enum.Enum):
  """Selects how a test supplies the model to `BaseOptions`."""

  # The model bytes are read from disk and passed as `file_content`.
  FILE_CONTENT = 1
  # The model path is passed as `file_name`.
  FILE_NAME = 2


class ImageSegmenterTest(parameterized.TestCase, tf.test.TestCase):
  """Tests `ImageSegmenter` creation and its segmentation outputs."""

  def setUp(self):
    super().setUp()
    # Resolve the paths of the input image, golden mask and model once per
    # test so individual tests only deal with ready-to-use paths.
    self.test_image_path = test_util.get_test_data_path(_IMAGE_FILE)
    self.test_seg_path = test_util.get_test_data_path(_SEGMENTATION_FILE)
    self.model_path = test_util.get_test_data_path(_MODEL_FILE)

  def test_create_from_file_succeeds_with_valid_model_path(self):
    # Creates with default option and valid model file successfully.
    segmenter = _ImageSegmenter.create_from_file(self.model_path)
    self.assertIsInstance(segmenter, _ImageSegmenter)

  def test_create_from_options_succeeds_with_valid_model_path(self):
    # Creates with options containing model file successfully.
    base_options = _BaseOptions(file_name=self.model_path)
    options = _ImageSegmenterOptions(base_options=base_options)
    segmenter = _ImageSegmenter.create_from_options(options)
    self.assertIsInstance(segmenter, _ImageSegmenter)

  def test_create_from_options_fails_with_invalid_model_path(self):
    # Invalid empty model path.
    with self.assertRaisesRegex(
        ValueError,
        r"ExternalFile must specify at least one of 'file_content', "
        r"'file_name' or 'file_descriptor_meta'."):
      base_options = _BaseOptions(file_name='')
      options = _ImageSegmenterOptions(base_options=base_options)
      _ImageSegmenter.create_from_options(options)

  def test_create_from_options_succeeds_with_valid_model_content(self):
    # Creates with options containing model content successfully.
    # Only the read needs the open file handle; close it before building the
    # segmenter.
    with open(self.model_path, 'rb') as f:
      model_content = f.read()
    base_options = _BaseOptions(file_content=model_content)
    options = _ImageSegmenterOptions(base_options=base_options)
    segmenter = _ImageSegmenter.create_from_options(options)
    self.assertIsInstance(segmenter, _ImageSegmenter)

  @parameterized.parameters(
      (ModelFileType.FILE_NAME, _EXPECTED_COLORED_LABELS),
      (ModelFileType.FILE_CONTENT, _EXPECTED_COLORED_LABELS))
  def test_segment_model(self, model_file_type, expected_colored_labels):
    # Creates segmenter.
    if model_file_type is ModelFileType.FILE_NAME:
      base_options = _BaseOptions(file_name=self.model_path)
    elif model_file_type is ModelFileType.FILE_CONTENT:
      with open(self.model_path, 'rb') as f:
        model_content = f.read()
      base_options = _BaseOptions(file_content=model_content)
    else:
      # Should never happen
      raise ValueError('model_file_type is invalid.')

    segmenter = _create_segmenter_from_options(base_options)

    # Loads image.
    image = tensor_image.TensorImage.create_from_file(self.test_image_path)

    # Performs image segmentation on the input.
    segmentation = segmenter.segment(image).segmentations[0]
    colored_labels = segmentation.colored_labels

    # Comparing results.
    self.assertEqual(colored_labels, expected_colored_labels,
                     'Colored labels do not match.')

  def test_segmentation_category_mask(self):
    """Check if category mask matches with ground truth."""
    # Creates segmenter.
    base_options = _BaseOptions(file_name=self.model_path)
    segmenter = _create_segmenter_from_options(
        base_options, output_type=_OutputType.CATEGORY_MASK)

    # Loads image.
    image = tensor_image.TensorImage.create_from_file(self.test_image_path)

    # Performs image segmentation on the input.
    segmentation = segmenter.segment(image).segmentations[0]
    result_pixels = segmentation.category_mask.flatten()

    # Check if data type of the category mask is correct.
    self.assertEqual(result_pixels.dtype, np.uint8)

    # Loads ground truth segmentation file.
    gt_segmentation = tensor_image.TensorImage.create_from_file(
        self.test_seg_path)
    gt_segmentation_array = gt_segmentation.buffer
    gt_segmentation_shape = gt_segmentation_array.shape
    num_pixels = gt_segmentation_shape[0] * gt_segmentation_shape[1]
    ground_truth_pixels = gt_segmentation_array.flatten()

    self.assertEqual(
        len(result_pixels), len(ground_truth_pixels),
        'Segmentation mask size does not match the ground truth mask size.')

    # Compare all pixels in one vectorized pass instead of a Python-level
    # loop. The predicted indices are widened to int32 before scaling so the
    # product cannot wrap around in uint8 and mask a genuine mismatch.
    scaled_result_pixels = (
        result_pixels[:num_pixels].astype(np.int32) *
        _MASK_MAGNIFICATION_FACTOR)
    inconsistent_pixels = np.count_nonzero(
        scaled_result_pixels != ground_truth_pixels[:num_pixels])

    self.assertLessEqual(
        inconsistent_pixels / num_pixels, _MATCH_PIXELS_THRESHOLD,
        'Number of pixels in the candidate mask differing from that of the '
        f'ground truth mask exceeds {_MATCH_PIXELS_THRESHOLD}.')

  def test_segmentation_confidence_mask_matches_category_mask(self):
    """Check if the confidence mask matches with the category mask."""
    # Create BaseOptions from model file.
    base_options = _BaseOptions(file_name=self.model_path)

    # Loads image.
    image = tensor_image.TensorImage.create_from_file(self.test_image_path)

    # Run segmentation on the model in CATEGORY_MASK mode.
    segmenter = _create_segmenter_from_options(
        base_options, output_type=_OutputType.CATEGORY_MASK)

    # Performs image segmentation on the input and gets the category mask.
    segmentation = segmenter.segment(image).segmentations[0]
    category_mask = segmentation.category_mask

    # Run segmentation on the model in CONFIDENCE_MASK mode.
    segmenter = _create_segmenter_from_options(
        base_options, output_type=_OutputType.CONFIDENCE_MASK)

    # Performs image segmentation on the input again.
    segmentation = segmenter.segment(image).segmentations[0]
    # Gets the list of confidence masks and colored_labels.
    confidence_masks = segmentation.confidence_masks
    colored_labels = segmentation.colored_labels

    # Check that there is exactly one confidence mask per category.
    self.assertEqual(
        len(confidence_masks), len(colored_labels),
        'Number of confidence masks must match with number of categories.')

    # Gather the confidence masks in a single array `confidence_mask_array`,
    # stacked along a new leading axis (one entry per category).
    confidence_mask_array = np.array(
        [confidence_mask.value for confidence_mask in confidence_masks])

    # Check if data type of `confidence_masks` are correct.
    self.assertEqual(confidence_mask_array.dtype, float)

    # Compute the category mask from the created confidence mask: for each
    # pixel, pick the index of the mask with the highest confidence.
    calculated_category_mask = np.argmax(confidence_mask_array, axis=0)
    self.assertListEqual(
        calculated_category_mask.tolist(), category_mask.tolist(),
        'Confidence mask does not match with the category mask.')


# Run the tests through the TensorFlow test runner when executed as a script.
if __name__ == '__main__':
  tf.test.main()