chromium/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/metadata_writers/metadata_info.py

# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper classes for common model metadata information."""

import collections
import copy
import csv
import os
from typing import List, Optional, Type, Union

from tensorflow_lite_support.metadata import metadata_schema_py_generated as _metadata_fb
from tensorflow_lite_support.metadata import schema_py_generated as _schema_fb
from tensorflow_lite_support.metadata.python.metadata_writers import writer_utils

# Min and max values for UINT8 tensors. Used as the single-element min/max
# stats of UINT8 image and classification tensors.
_MIN_UINT8 = 0
_MAX_UINT8 = 255

# Default description for vocabulary files. Shared by the Regex, Bert, and
# sentence piece tokenizer metadata containers.
_VOCAB_FILE_DESCRIPTION = ("Vocabulary file to convert natural language "
                           "words to embedding vectors.")


class GeneralMd:
  """Holds the general, model-level metadata of a model.

  Attributes:
    name: name of the model.
    version: version of the model.
    description: description of what the model does.
    author: author of the model.
    licenses: licenses of the model.
  """

  def __init__(self,
               name: Optional[str] = None,
               version: Optional[str] = None,
               description: Optional[str] = None,
               author: Optional[str] = None,
               licenses: Optional[str] = None):
    self.name = name
    self.version = version
    self.description = description
    self.author = author
    self.licenses = licenses

  def create_metadata(self) -> _metadata_fb.ModelMetadataT:
    """Builds a `ModelMetadataT` object from the stored general information.

    Note that `licenses` is written to the Flatbuffers `license` field.

    Returns:
      A Flatbuffers Python object holding the model metadata.
    """
    metadata = _metadata_fb.ModelMetadataT()
    field_values = (
        ("name", self.name),
        ("version", self.version),
        ("description", self.description),
        ("author", self.author),
        ("license", self.licenses),
    )
    for field, value in field_values:
      setattr(metadata, field, value)
    return metadata


class AssociatedFileMd:
  """Holds the metadata of one file associated with a model or tensor.

  Attributes:
    file_path: path to the associated file.
    description: description of the associated file.
    file_type: file type of the associated file [1].
    locale: locale of the associated file [2].
    [1]:
      https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L77
    [2]:
      https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L154
  """

  def __init__(
      self,
      file_path: str,
      description: Optional[str] = None,
      file_type: Optional[_metadata_fb.AssociatedFileType] = (
          _metadata_fb.AssociatedFileType.UNKNOWN),
      locale: Optional[str] = None):
    self.file_path = file_path
    self.description = description
    self.file_type = file_type
    self.locale = locale

  def create_metadata(self) -> _metadata_fb.AssociatedFileT:
    """Builds an `AssociatedFileT` object from the stored information.

    Only the base name of `file_path` is recorded; the directory part is
    dropped.

    Returns:
      A Flatbuffers Python object holding the associated file metadata.
    """
    metadata = _metadata_fb.AssociatedFileT()
    metadata.name = os.path.basename(self.file_path)
    metadata.type = self.file_type
    metadata.description = self.description
    metadata.locale = self.locale
    return metadata


class LabelFileMd(AssociatedFileMd):
  """Associated file metadata for a label file (TENSOR_AXIS_LABELS)."""

  _LABEL_FILE_DESCRIPTION = ("Labels for categories that the model can "
                             "recognize.")
  _FILE_TYPE = _metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS

  def __init__(self, file_path: str, locale: Optional[str] = None):
    """Creates a LabelFileMd object with a fixed description and file type.

    Args:
      file_path: file_path of the label file.
      locale: locale of the label file [1].
      [1]:
        https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L154
    """
    super().__init__(
        file_path=file_path,
        description=self._LABEL_FILE_DESCRIPTION,
        file_type=self._FILE_TYPE,
        locale=locale)


class RegexTokenizerMd:
  """A container for the Regex tokenizer [1] metadata information.

  [1]:
    https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L459
  """

  def __init__(self, delim_regex_pattern: str, vocab_file_path: str):
    """Initializes a RegexTokenizerMd object.

    Args:
      delim_regex_pattern: the regular expression to segment strings and create
        tokens.
      vocab_file_path: path to the vocabulary file.
    """
    self._delim_regex_pattern = delim_regex_pattern
    self._vocab_file_path = vocab_file_path

  def create_metadata(self) -> _metadata_fb.ProcessUnitT:
    """Creates the Regex tokenizer metadata based on the information.

    The vocabulary file is recorded under `vocab_file_path` exactly as given
    (no base-name stripping is done here).

    Returns:
      A Flatbuffers Python object of the Regex tokenizer metadata.
    """
    vocab = _metadata_fb.AssociatedFileT()
    vocab.name = self._vocab_file_path
    vocab.description = _VOCAB_FILE_DESCRIPTION
    vocab.type = _metadata_fb.AssociatedFileType.VOCABULARY

    # Create the RegexTokenizer.
    tokenizer = _metadata_fb.ProcessUnitT()
    tokenizer.optionsType = (
        _metadata_fb.ProcessUnitOptions.RegexTokenizerOptions)
    tokenizer.options = _metadata_fb.RegexTokenizerOptionsT()
    tokenizer.options.delimRegexPattern = self._delim_regex_pattern
    tokenizer.options.vocabFile = [vocab]
    return tokenizer


class BertTokenizerMd:
  """Holds the metadata of a Bert tokenizer [1].

  [1]:
    https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L436
  """

  def __init__(self, vocab_file_path: str):
    """Initializes a BertTokenizerMd object.

    Args:
      vocab_file_path: path to the vocabulary file.
    """
    self._vocab_file_path = vocab_file_path

  def create_metadata(self) -> _metadata_fb.ProcessUnitT:
    """Builds the Bert tokenizer process unit.

    The vocabulary file is recorded under `vocab_file_path` exactly as given.

    Returns:
      A Flatbuffers Python object of the Bert tokenizer metadata.
    """
    vocab_file = _metadata_fb.AssociatedFileT()
    vocab_file.name = self._vocab_file_path
    vocab_file.description = _VOCAB_FILE_DESCRIPTION
    vocab_file.type = _metadata_fb.AssociatedFileType.VOCABULARY

    bert_options = _metadata_fb.BertTokenizerOptionsT()
    bert_options.vocabFile = [vocab_file]

    process_unit = _metadata_fb.ProcessUnitT()
    process_unit.optionsType = (
        _metadata_fb.ProcessUnitOptions.BertTokenizerOptions)
    process_unit.options = bert_options
    return process_unit


class SentencePieceTokenizerMd:
  """Holds the metadata of a sentence piece tokenizer [1].

  [1]:
    https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L473
  """

  _SP_MODEL_DESCRIPTION = "The sentence piece model file."
  _SP_VOCAB_FILE_DESCRIPTION = _VOCAB_FILE_DESCRIPTION + (
      " This file is optional during tokenization, while the sentence piece "
      "model is mandatory.")

  def __init__(self,
               sentence_piece_model_path: str,
               vocab_file_path: Optional[str] = None):
    """Initializes a SentencePieceTokenizerMd object.

    Args:
      sentence_piece_model_path: path to the sentence piece model file.
      vocab_file_path: optional path to the vocabulary file.
    """
    self._sentence_piece_model_path = sentence_piece_model_path
    self._vocab_file_path = vocab_file_path

  def create_metadata(self) -> _metadata_fb.ProcessUnitT:
    """Builds the sentence piece tokenizer process unit.

    The model file is always recorded; the vocabulary file is recorded only
    when a non-empty `vocab_file_path` was given.

    Returns:
      A Flatbuffers Python object of the sentence piece tokenizer metadata.
    """
    sp_options = _metadata_fb.SentencePieceTokenizerOptionsT()

    model_file = _metadata_fb.AssociatedFileT()
    model_file.name = self._sentence_piece_model_path
    model_file.description = self._SP_MODEL_DESCRIPTION
    sp_options.sentencePieceModel = [model_file]

    if self._vocab_file_path:
      vocab_file = _metadata_fb.AssociatedFileT()
      vocab_file.name = self._vocab_file_path
      vocab_file.description = self._SP_VOCAB_FILE_DESCRIPTION
      vocab_file.type = _metadata_fb.AssociatedFileType.VOCABULARY
      sp_options.vocabFile = [vocab_file]

    process_unit = _metadata_fb.ProcessUnitT()
    process_unit.optionsType = (
        _metadata_fb.ProcessUnitOptions.SentencePieceTokenizerOptions)
    process_unit.options = sp_options
    return process_unit


class ScoreCalibrationMd:
  """A container for score calibration [1] metadata information.

  [1]:
    https://github.com/tensorflow/tflite-support/blob/5e0cdf5460788c481f5cd18aab8728ec36cf9733/tensorflow_lite_support/metadata/metadata_schema.fbs#L434
  """

  _SCORE_CALIBRATION_FILE_DESCRIPTION = (
      "Contains sigmoid-based score calibration parameters. The main purposes "
      "of score calibration is to make scores across classes comparable, so "
      "that a common threshold can be used for all output classes.")
  _FILE_TYPE = _metadata_fb.AssociatedFileType.TENSOR_AXIS_SCORE_CALIBRATION

  def __init__(self,
               score_transformation_type: _metadata_fb.ScoreTransformationType,
               default_score: float, file_path: str):
    """Creates a ScoreCalibrationMd object.

    The score calibration file is read once here to sanity check its format:
    each non-empty CSV line must have 3 or 4 values, and the first value
    (the scale) must be a non-negative number.

    Args:
      score_transformation_type: type of the function used for transforming the
        uncalibrated score before applying score calibration.
      default_score: the default calibrated score to apply if the uncalibrated
        score is below min_score or if no parameters were specified for a given
        index.
      file_path: file_path of the score calibration file [1].
      [1]:
        https://github.com/tensorflow/tflite-support/blob/5e0cdf5460788c481f5cd18aab8728ec36cf9733/tensorflow_lite_support/metadata/metadata_schema.fbs#L122

    Raises:
      ValueError: if the score_calibration file is malformed.
      OSError: if the score calibration file cannot be opened.
    """
    self._score_transformation_type = score_transformation_type
    self._default_score = default_score
    self._file_path = file_path

    # Sanity check the score calibration file. The csv module requires files
    # to be opened with newline="" so that it handles line endings itself.
    with open(self._file_path, newline="") as calibration_file:
      for row in csv.reader(calibration_file, delimiter=","):
        # Empty lines are allowed and carry no parameters.
        if not row:
          continue

        if len(row) not in (3, 4):
          raise ValueError(
              f"Expected empty lines or 3 or 4 parameters per line in score"
              f" calibration file, but got {len(row)}.")

        # float() itself raises ValueError for a non-numeric scale.
        scale = float(row[0])
        if scale < 0:
          raise ValueError(
              f"Expected scale to be a non-negative value, but got "
              f"{scale}.")

  def create_metadata(self) -> _metadata_fb.ProcessUnitT:
    """Creates the score calibration metadata based on the information.

    Returns:
      A Flatbuffers Python object of the score calibration metadata.
    """
    score_calibration = _metadata_fb.ProcessUnitT()
    score_calibration.optionsType = (
        _metadata_fb.ProcessUnitOptions.ScoreCalibrationOptions)
    options = _metadata_fb.ScoreCalibrationOptionsT()
    options.scoreTransformation = self._score_transformation_type
    options.defaultScore = self._default_score
    score_calibration.options = options
    return score_calibration

  def create_score_calibration_file_md(self) -> AssociatedFileMd:
    """Creates the associated file metadata of the score calibration file.

    Returns:
      An AssociatedFileMd for `file_path`, typed as
      TENSOR_AXIS_SCORE_CALIBRATION.
    """
    return AssociatedFileMd(self._file_path,
                            self._SCORE_CALIBRATION_FILE_DESCRIPTION,
                            self._FILE_TYPE)


class TensorMd:
  """A container for common tensor metadata information.

  Attributes:
    name: name of the tensor.
    description: description of what the tensor is.
    min_values: per-channel minimum value of the tensor.
    max_values: per-channel maximum value of the tensor.
    content_type: content_type of the tensor.
    associated_files: information of the associated files in the tensor, as
      AssociatedFileMd instances (e.g. LabelFileMd).
    tensor_name: name of the corresponding tensor [1] in the TFLite model. It is
      used to locate the corresponding tensor and decide the order of the tensor
      metadata [2] when populating model metadata.
    [1]:
      https://github.com/tensorflow/tensorflow/blob/cb67fef35567298b40ac166b0581cd8ad68e5a3a/tensorflow/lite/schema/schema.fbs#L1129-L1136
    [2]:
      https://github.com/tensorflow/tflite-support/blob/b2a509716a2d71dfff706468680a729cc1604cff/tensorflow_lite_support/metadata/metadata_schema.fbs#L595-L612
  """

  def __init__(self,
               name: Optional[str] = None,
               description: Optional[str] = None,
               min_values: Optional[List[float]] = None,
               max_values: Optional[List[float]] = None,
               content_type: _metadata_fb.ContentProperties = _metadata_fb
               .ContentProperties.FeatureProperties,
               associated_files: Optional[List[AssociatedFileMd]] = None,
               tensor_name: Optional[str] = None):
    self.name = name
    self.description = description
    self.min_values = min_values
    self.max_values = max_values
    self.content_type = content_type
    self.associated_files = associated_files
    self.tensor_name = tensor_name

  def create_metadata(self) -> _metadata_fb.TensorMetadataT:
    """Creates the tensor metadata based on the information.

    Stats are always attached, even when both `min_values` and `max_values`
    are None. A content properties object is created for the Feature, Image,
    BoundingBox, and Audio content types; for any other `content_type` it is
    left unset while `contentPropertiesType` still records `content_type`.

    Returns:
      A Flatbuffers Python object of the tensor metadata.
    """
    tensor_metadata = _metadata_fb.TensorMetadataT()
    tensor_metadata.name = self.name
    tensor_metadata.description = self.description

    # Create min and max values
    stats = _metadata_fb.StatsT()
    stats.max = self.max_values
    stats.min = self.min_values
    tensor_metadata.stats = stats

    # Create content properties. Enum values are compared by value (`==`),
    # not identity (`is`): identity of equal values is an interpreter detail.
    content_properties = _metadata_fb.ContentProperties
    content = _metadata_fb.ContentT()
    if self.content_type == content_properties.FeatureProperties:
      content.contentProperties = _metadata_fb.FeaturePropertiesT()
    elif self.content_type == content_properties.ImageProperties:
      content.contentProperties = _metadata_fb.ImagePropertiesT()
    elif self.content_type == content_properties.BoundingBoxProperties:
      content.contentProperties = _metadata_fb.BoundingBoxPropertiesT()
    elif self.content_type == content_properties.AudioProperties:
      content.contentProperties = _metadata_fb.AudioPropertiesT()

    content.contentPropertiesType = self.content_type
    tensor_metadata.content = content

    # TODO(b/174091474): check if multiple label files have populated locale.
    # Create associated files
    if self.associated_files:
      tensor_metadata.associatedFiles = [
          file.create_metadata() for file in self.associated_files
      ]
    return tensor_metadata


class InputImageTensorMd(TensorMd):
  """A container for input image tensor metadata information.

  Attributes:
    norm_mean: the mean value used in tensor normalization [1].
    norm_std: the std value used in the tensor normalization [1]. norm_mean and
      norm_std must have the same dimension.
    color_space_type: the color space type of the input image [2].
    [1]:
      https://www.tensorflow.org/lite/convert/metadata#normalization_and_quantization_parameters
    [2]:
      https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L172
  """

  # Min and max float values for image pixels.
  _MIN_PIXEL = 0.0
  _MAX_PIXEL = 255.0

  def __init__(
      self,
      name: Optional[str] = None,
      description: Optional[str] = None,
      norm_mean: Optional[List[float]] = None,
      norm_std: Optional[List[float]] = None,
      color_space_type: Optional[
          _metadata_fb.ColorSpaceType] = _metadata_fb.ColorSpaceType.UNKNOWN,
      tensor_type: Optional[_schema_fb.TensorType] = None):
    """Initializes the instance of InputImageTensorMd.

    The min/max stats are derived from `tensor_type`: UINT8 tensors get
    [0]/[255]; FLOAT32 tensors with both `norm_mean` and `norm_std` get the
    per-channel values (pixel - mean) / std for pixel 0.0 and 255.0; any other
    case leaves the stats as None.

    Args:
      name: name of the tensor.
      description: description of what the tensor is.
      norm_mean: the mean value used in tensor normalization [1].
      norm_std: the std value used in the tensor normalization [1]. norm_mean
        and norm_std must have the same dimension.
      color_space_type: the color space type of the input image [2].
      tensor_type: data type of the tensor.
      [1]:
        https://www.tensorflow.org/lite/convert/metadata#normalization_and_quantization_parameters
      [2]:
        https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L172

    Raises:
      ValueError: if norm_mean and norm_std have different dimensions.
    """
    if norm_std and norm_mean and len(norm_std) != len(norm_mean):
      raise ValueError(
          f"norm_mean and norm_std are expected to be the same dim. But got "
          f"{len(norm_mean)} and {len(norm_std)}")

    # Compare enum values by value (`==`), not identity (`is`).
    if tensor_type == _schema_fb.TensorType.UINT8:
      min_values = [_MIN_UINT8]
      max_values = [_MAX_UINT8]
    elif tensor_type == _schema_fb.TensorType.FLOAT32 and norm_std and norm_mean:
      min_values = [
          float(self._MIN_PIXEL - mean) / std
          for mean, std in zip(norm_mean, norm_std)
      ]
      max_values = [
          float(self._MAX_PIXEL - mean) / std
          for mean, std in zip(norm_mean, norm_std)
      ]
    else:
      # Uint8 and Float32 are the two major types currently. And Task library
      # doesn't support other types so far.
      min_values = None
      max_values = None

    super().__init__(name, description, min_values, max_values,
                     _metadata_fb.ContentProperties.ImageProperties)
    self.norm_mean = norm_mean
    self.norm_std = norm_std
    self.color_space_type = color_space_type

  def create_metadata(self) -> _metadata_fb.TensorMetadataT:
    """Creates the input image metadata based on the information.

    A normalization process unit is added only when both `norm_mean` and
    `norm_std` are non-empty.

    Returns:
      A Flatbuffers Python object of the input image metadata.
    """
    tensor_metadata = super().create_metadata()
    tensor_metadata.content.contentProperties.colorSpace = self.color_space_type
    # Create normalization parameters
    if self.norm_mean and self.norm_std:
      normalization = _metadata_fb.ProcessUnitT()
      normalization.optionsType = (
          _metadata_fb.ProcessUnitOptions.NormalizationOptions)
      normalization.options = _metadata_fb.NormalizationOptionsT()
      normalization.options.mean = self.norm_mean
      normalization.options.std = self.norm_std
      tensor_metadata.processUnits = [normalization]
    return tensor_metadata


class InputTextTensorMd(TensorMd):
  """Holds the metadata of an input text tensor.

  Attributes:
    tokenizer_md: information of the tokenizer in the input text tensor, if any.
  """

  def __init__(self,
               name: Optional[str] = None,
               description: Optional[str] = None,
               tokenizer_md: Optional[RegexTokenizerMd] = None):
    """Initializes the instance of InputTextTensorMd.

    Args:
      name: name of the tensor.
      description: description of what the tensor is.
      tokenizer_md: information of the tokenizer in the input text tensor, if
        any. Only `RegexTokenizer` [1] is currently supported. If the tokenizer
        is `BertTokenizer` [2] or `SentencePieceTokenizer` [3], refer to
        `bert_nl_classifier.MetadataWriter`.
        [1]:
        https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L475
        [2]:
        https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L436
        [3]:
        https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L473
    """
    super().__init__(name, description)
    self.tokenizer_md = tokenizer_md

  def create_metadata(self) -> _metadata_fb.TensorMetadataT:
    """Builds the input text tensor metadata.

    When a tokenizer is set, its process unit is attached to the tensor.

    Returns:
      A Flatbuffers Python object of the input text metadata.

    Raises:
      ValueError: if tokenizer_md is neither None nor a RegexTokenizerMd.
    """
    tokenizer_md = self.tokenizer_md
    if tokenizer_md is not None and not isinstance(tokenizer_md,
                                                   RegexTokenizerMd):
      raise ValueError(f"The type of tokenizer_options, {type(tokenizer_md)}, "
                       f"is unsupported")

    tensor_metadata = super().create_metadata()
    if tokenizer_md:
      tensor_metadata.processUnits = [tokenizer_md.create_metadata()]
    return tensor_metadata


class InputAudioTensorMd(TensorMd):
  """Holds the metadata of an input audio tensor.

  Attributes:
    sample_rate: the sample rate in Hz when the audio was captured.
    channels: the channel count of the audio.
  """

  def __init__(self,
               name: Optional[str] = None,
               description: Optional[str] = None,
               sample_rate: int = 0,
               channels: int = 0):
    """Initializes the instance of InputAudioTensorMd.

    Args:
      name: name of the tensor.
      description: description of what the tensor is.
      sample_rate: the sample rate in Hz when the audio was captured.
      channels: the channel count of the audio.
    """
    super().__init__(
        name=name,
        description=description,
        content_type=_metadata_fb.ContentProperties.AudioProperties)
    self.sample_rate = sample_rate
    self.channels = channels

  def create_metadata(self) -> _metadata_fb.TensorMetadataT:
    """Builds the input audio tensor metadata.

    Returns:
      A Flatbuffers Python object of the input audio metadata.

    Raises:
      ValueError: if sample_rate or channels is negative (checked in that
        order).
    """
    # 0 is the Flatbuffers default, so only negative values are rejected.
    for field_name, value in (("sample_rate", self.sample_rate),
                              ("channels", self.channels)):
      if value < 0:
        raise ValueError(
            f"{field_name} should be non-negative, but got {value}.")

    tensor_metadata = super().create_metadata()
    audio_properties = tensor_metadata.content.contentProperties
    audio_properties.sampleRate = self.sample_rate
    audio_properties.channels = self.channels
    return tensor_metadata


class ClassificationTensorMd(TensorMd):
  """A container for the classification tensor metadata information.

  Attributes:
    label_files: information of the label files [1] in the classification
      tensor.
    score_calibration_md: information of the score calibration operation [2] in
      the classification tensor.
    [1]:
      https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L95
    [2]:
      https://github.com/tensorflow/tflite-support/blob/5e0cdf5460788c481f5cd18aab8728ec36cf9733/tensorflow_lite_support/metadata/metadata_schema.fbs#L434
  """

  # Min and max float values for classification results.
  _MIN_FLOAT = 0.0
  _MAX_FLOAT = 1.0

  def __init__(self,
               name: Optional[str] = None,
               description: Optional[str] = None,
               label_files: Optional[List[LabelFileMd]] = None,
               tensor_type: Optional[_schema_fb.TensorType] = None,
               score_calibration_md: Optional[ScoreCalibrationMd] = None,
               tensor_name: Optional[str] = None):
    """Initializes the instance of ClassificationTensorMd.

    The associated files of the tensor are the label files followed by the
    score calibration file (if `score_calibration_md` is set). The caller's
    `label_files` list is not modified.

    Args:
      name: name of the tensor.
      description: description of what the tensor is.
      label_files: information of the label files [1] in the classification
        tensor.
      tensor_type: data type of the tensor.
      score_calibration_md: information of the score calibration files operation
        [2] in the classification tensor.
      tensor_name: name of the corresponding tensor [3] in the TFLite model. It
        is used to locate the corresponding classification tensor and decide the
        order of the tensor metadata [4] when populating model metadata.
      [1]:
        https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L95
      [2]:
        https://github.com/tensorflow/tflite-support/blob/5e0cdf5460788c481f5cd18aab8728ec36cf9733/tensorflow_lite_support/metadata/metadata_schema.fbs#L434
      [3]:
        https://github.com/tensorflow/tensorflow/blob/cb67fef35567298b40ac166b0581cd8ad68e5a3a/tensorflow/lite/schema/schema.fbs#L1129-L1136
      [4]:
        https://github.com/tensorflow/tflite-support/blob/b2a509716a2d71dfff706468680a729cc1604cff/tensorflow_lite_support/metadata/metadata_schema.fbs#L595-L612
    """
    self.score_calibration_md = score_calibration_md

    # Compare enum values by value (`==`), not identity (`is`).
    if tensor_type == _schema_fb.TensorType.UINT8:
      min_values = [_MIN_UINT8]
      max_values = [_MAX_UINT8]
    elif tensor_type == _schema_fb.TensorType.FLOAT32:
      min_values = [self._MIN_FLOAT]
      max_values = [self._MAX_FLOAT]
    else:
      # Uint8 and Float32 are the two major types currently. And Task library
      # doesn't support other types so far.
      min_values = None
      max_values = None

    # Build a new list: appending to `label_files` directly would leak the
    # score calibration file into the caller's list (and duplicate it if the
    # list is reused).
    associated_files = list(label_files) if label_files else []
    if self.score_calibration_md:
      associated_files.append(
          self.score_calibration_md.create_score_calibration_file_md())

    super().__init__(name, description, min_values, max_values,
                     _metadata_fb.ContentProperties.FeatureProperties,
                     associated_files, tensor_name)

  def create_metadata(self) -> _metadata_fb.TensorMetadataT:
    """Creates the classification tensor metadata based on the information.

    When score calibration is set, its process unit is attached.

    Returns:
      A Flatbuffers Python object of the classification tensor metadata.
    """
    tensor_metadata = super().create_metadata()
    if self.score_calibration_md:
      tensor_metadata.processUnits = [
          self.score_calibration_md.create_metadata()
      ]
    return tensor_metadata


class CategoryTensorMd(TensorMd):
  """A container for the category tensor metadata information."""

  def __init__(self,
               name: Optional[str] = None,
               description: Optional[str] = None,
               label_files: Optional[List[LabelFileMd]] = None):
    """Initializes a CategoryTensorMd object.

    The label files are stored as shallow copies whose file type is
    TENSOR_VALUE_LABELS; the caller's LabelFileMd objects are not modified.

    Args:
      name: name of the tensor.
      description: description of what the tensor is.
      label_files: information of the label files [1] in the category tensor.
      [1]:
        https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L108
    """
    # In category tensors, label files are in the type of TENSOR_VALUE_LABELS.
    # Copy each file first so that LabelFileMd objects shared with other
    # tensors (e.g. a classification tensor) keep their own file type.
    value_label_files = label_files
    if label_files:
      value_label_files = []
      for label_file in label_files:
        value_label_file = copy.copy(label_file)
        value_label_file.file_type = (
            _metadata_fb.AssociatedFileType.TENSOR_VALUE_LABELS)
        value_label_files.append(value_label_file)

    super().__init__(
        name=name, description=description, associated_files=value_label_files)


class BertInputTensorsMd:
  """A container for the input tensor metadata information of Bert models."""

  _IDS_NAME = "ids"
  _IDS_DESCRIPTION = "Tokenized ids of the input text."
  _MASK_NAME = "mask"
  _MASK_DESCRIPTION = ("Mask with 1 for real tokens and 0 for padding "
                       "tokens.")
  _SEGMENT_IDS_NAME = "segment_ids"
  _SEGMENT_IDS_DESCRIPTION = (
      "0 for the first sequence, 1 for the second sequence if exists.")

  def __init__(self,
               model_buffer: bytearray,
               ids_name: str,
               mask_name: str,
               segment_name: str,
               ids_md: Optional[TensorMd] = None,
               mask_md: Optional[TensorMd] = None,
               segment_ids_md: Optional[TensorMd] = None,
               tokenizer_md: Union[None, BertTokenizerMd,
                                   SentencePieceTokenizerMd] = None):
    """Initializes a BertInputTensorsMd object.

    `ids_name`, `mask_name`, and `segment_name` correspond to the `Tensor.name`
    in the TFLite schema, which help to determine the tensor order when
    populating metadata.

    Args:
      model_buffer: valid buffer of the model file.
      ids_name: name of the ids tensor, which represents the tokenized ids of
        the input text.
      mask_name: name of the mask tensor, which represents the mask with 1 for
        real tokens and 0 for padding tokens.
      segment_name: name of the segment ids tensor, where `0` stands for the
        first sequence, and `1` stands for the second sequence if exists.
      ids_md: input ids tensor information. Defaults to a TensorMd named "ids".
      mask_md: input mask tensor information. Defaults to a TensorMd named
        "mask".
      segment_ids_md: input segment tensor information. Defaults to a TensorMd
        named "segment_ids".
      tokenizer_md: information of the tokenizer used to process the input
        string, if any. Supported tokenizers are: `BertTokenizer` [1] and
        `SentencePieceTokenizer` [2]. If the tokenizer is `RegexTokenizer` [3],
        refer to `nl_classifier.MetadataWriter`.
        [1]:
        https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L436
        [2]:
        https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L473
        [3]:
        https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L475

    Raises:
      ValueError: if `ids_name`, `mask_name`, and `segment_name` do not match
        the input tensor names read from `model_buffer`, or if the type of
        `tokenizer_md` is unsupported.
    """

    # Tensor names as supplied by the user, in ids/mask/segment order.
    self._input_names = [ids_name, mask_name, segment_name]

    # Get the input tensor names in order from the model. Later, we need to
    # order the input metadata according to this tensor order.
    self._ordered_input_names = writer_utils.get_input_tensor_names(
        model_buffer)

    # Verify that self._ordered_input_names (read from the model) and
    # self._input_names (collected from users) are aligned, ignoring order.
    if collections.Counter(self._ordered_input_names) != collections.Counter(
        self._input_names):
      raise ValueError(
          f"The input tensor names ({self._input_names}) do not match "
          f"the tensor names read from the model ({self._ordered_input_names}).")

    if ids_md is None:
      ids_md = TensorMd(name=self._IDS_NAME, description=self._IDS_DESCRIPTION)

    if mask_md is None:
      mask_md = TensorMd(
          name=self._MASK_NAME, description=self._MASK_DESCRIPTION)

    if segment_ids_md is None:
      segment_ids_md = TensorMd(
          name=self._SEGMENT_IDS_NAME,
          description=self._SEGMENT_IDS_DESCRIPTION)

    # The order of self._input_md matches the order of self._input_names.
    self._input_md = [ids_md, mask_md, segment_ids_md]

    if not isinstance(tokenizer_md,
                      (type(None), BertTokenizerMd, SentencePieceTokenizerMd)):
      raise ValueError(
          f"The type of tokenizer_options, {type(tokenizer_md)}, is unsupported"
      )

    self._tokenizer_md = tokenizer_md

  def create_input_tesnor_metadata(self) -> List[_metadata_fb.TensorMetadataT]:
    """Creates the input metadata for the three input tensors.

    Returns:
      The tensor metadata ordered as the input tensors appear in the model.
    """
    # The order of the three input tensors may vary with each model conversion.
    # We need to order the input metadata according to the tensor order in the
    # model.
    ordered_metadata = []
    name_md_dict = dict(zip(self._input_names, self._input_md))
    for name in self._ordered_input_names:
      ordered_metadata.append(name_md_dict[name].create_metadata())
    return ordered_metadata

  def create_input_process_unit_metadata(
      self) -> List[_metadata_fb.ProcessUnitT]:
    """Creates the input process unit metadata.

    Returns:
      A single-element list with the tokenizer process unit, or an empty list
      when no tokenizer is set.
    """
    if self._tokenizer_md:
      return [self._tokenizer_md.create_metadata()]
    else:
      return []

  def get_tokenizer_associated_files(self) -> List[str]:
    """Gets the associated files that are packed in the tokenizer.

    Returns:
      The file list reported by `writer_utils.get_tokenizer_associated_files`
      for the tokenizer options, or an empty list when no tokenizer is set.
    """
    if self._tokenizer_md:
      return writer_utils.get_tokenizer_associated_files(
          self._tokenizer_md.create_metadata().options)
    else:
      return []