chromium/third_party/tflite_support/src/tensorflow_lite_support/metadata/python/tests/metadata_test.py

# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow_lite_support.metadata.metadata."""

import enum
import os

from absl.testing import parameterized
import flatbuffers
import six
import tensorflow as tf

from tensorflow_lite_support.metadata import metadata_schema_py_generated as _metadata_fb
from tensorflow_lite_support.metadata import schema_py_generated as _schema_fb
from tensorflow_lite_support.metadata.python import metadata as _metadata


class Tokenizer(enum.Enum):
  """Tokenizer flavours exercised by the recorded-associated-file tests."""

  # Built by MetadataPopulatorTest._create_bert_tokenizer().
  BERT_TOKENIZER = 0
  # Built by MetadataPopulatorTest._create_sentence_piece_tokenizer().
  SENTENCE_PIECE = 1


class TensorType(enum.Enum):
  """Side of the subgraph (input or output) a process unit is attached to."""

  INPUT = 0  # Attach to the input tensors / input process units.
  OUTPUT = 1  # Attach to the output tensors / output process units.


def _read_file(file_name, mode="rb"):
  """Returns the whole content of `file_name`.

  Args:
    file_name: path of the file to read.
    mode: mode passed to `open()`; the default "rb" yields bytes.

  Returns:
    The file content (bytes for binary modes, str for text modes).
  """
  with open(file_name, mode) as handle:
    content = handle.read()
  return content


class MetadataTest(tf.test.TestCase, parameterized.TestCase):
  """Base test case holding the model/metadata fixtures shared by the tests.

  `setUp()` creates:
    * `self._model_buf` / `self._model_file`: a TFLite model with one subgraph
      (two inputs, one output), two dummy metadata entries and three buffers.
    * `self._metadata_file`: model metadata matching that model; it records
      "file1" (model level) and "file2" (output tensor) as associated files.
    * `self._metadata_file_with_version`: the same metadata with
      `minParserVersion` set to "1.0.0".
    * `self._file1`, `self._file2`, `self._file3`: temp files named "file1",
      "file2" (with `self._file2_content` written into it) and "file3".
  """

  # A deliberately wrong flatbuffer file identifier, used to build model and
  # metadata buffers that the library is expected to reject.
  _WRONG_IDENTIFIER = b"widn"

  def setUp(self):
    super(MetadataTest, self).setUp()
    self._invalid_model_buf = None
    self._invalid_file = "not_existed_file"
    self._model_buf = self._create_model_buf()
    self._model_file = self.create_tempfile().full_path
    with open(self._model_file, "wb") as f:
      f.write(self._model_buf)
    self._metadata_file = self._create_metadata_file()
    self._metadata_file_with_version = self._create_metadata_file_with_version(
        self._metadata_file, "1.0.0")
    self._file1 = self.create_tempfile("file1").full_path
    self._file2 = self.create_tempfile("file2").full_path
    self._file2_content = b"file2_content"
    with open(self._file2, "wb") as f:
      f.write(self._file2_content)
    self._file3 = self.create_tempfile("file3").full_path

  def _create_model_buf(self):
    """Returns a serialized TFLite model matching `_create_metadata_file()`.

    The model has one subgraph with two inputs and one output, plus two
    metadata entries and three buffers that exist only for testing purposes.
    """
    metadata_field = _schema_fb.MetadataT()
    subgraph = _schema_fb.SubGraphT()
    subgraph.inputs = [0, 1]
    subgraph.outputs = [2]

    metadata_field.name = "meta"
    buffer_field = _schema_fb.BufferT()
    model = _schema_fb.ModelT()
    model.subgraphs = [subgraph]
    # Creates the metadata and buffer fields for testing purposes.
    model.metadata = [metadata_field, metadata_field]
    model.buffers = [buffer_field, buffer_field, buffer_field]
    model_builder = flatbuffers.Builder(0)
    model_builder.Finish(
        model.Pack(model_builder),
        _metadata.MetadataPopulator.TFLITE_FILE_IDENTIFIER)
    return model_builder.Output()

  def _create_metadata_file(self):
    """Writes model metadata matching `_create_model_buf()` to a temp file.

    Side effect: sets `self.expected_recorded_files` to the names of the two
    associated files ("file1", "file2") referenced by the metadata.

    Returns:
      The full path of the written metadata file.
    """
    associated_file1 = _metadata_fb.AssociatedFileT()
    associated_file1.name = b"file1"
    associated_file2 = _metadata_fb.AssociatedFileT()
    associated_file2.name = b"file2"
    self.expected_recorded_files = [
        six.ensure_str(associated_file1.name),
        six.ensure_str(associated_file2.name)
    ]

    input_meta = _metadata_fb.TensorMetadataT()
    output_meta = _metadata_fb.TensorMetadataT()
    output_meta.associatedFiles = [associated_file2]
    subgraph = _metadata_fb.SubGraphMetadataT()
    # Metadata for a model with two inputs and one output.
    subgraph.inputTensorMetadata = [input_meta, input_meta]
    subgraph.outputTensorMetadata = [output_meta]

    model_meta = _metadata_fb.ModelMetadataT()
    model_meta.name = "Mobilenet_quantized"
    model_meta.associatedFiles = [associated_file1]
    model_meta.subgraphMetadata = [subgraph]
    b = flatbuffers.Builder(0)
    b.Finish(
        model_meta.Pack(b),
        _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)

    metadata_file = self.create_tempfile().full_path
    with open(metadata_file, "wb") as f:
      f.write(b.Output())
    return metadata_file

  def _create_model_buffer_with_wrong_identifier(self):
    """Returns an empty serialized model finished with a wrong identifier."""
    model = _schema_fb.ModelT()
    model_builder = flatbuffers.Builder(0)
    model_builder.Finish(model.Pack(model_builder), self._WRONG_IDENTIFIER)
    return model_builder.Output()

  def _create_metadata_buffer_with_wrong_identifier(self):
    """Returns empty serialized metadata finished with a wrong identifier."""
    metadata = _metadata_fb.ModelMetadataT()
    metadata_builder = flatbuffers.Builder(0)
    metadata_builder.Finish(
        metadata.Pack(metadata_builder), self._WRONG_IDENTIFIER)
    return metadata_builder.Output()

  def _populate_metadata_with_identifier(self, model_buf, metadata_buf,
                                         identifier):
    """Embeds `metadata_buf` into `model_buf`, finishing with `identifier`.

    For testing purposes only: MetadataPopulator cannot populate metadata with
    wrong identifiers, so this rewrites the model fields directly. The model's
    existing buffers and metadata entries are replaced by a single buffer that
    holds `metadata_buf` and a single metadata entry pointing to it.

    Args:
      model_buf: a serialized TFLite model.
      metadata_buf: the serialized metadata to embed.
      identifier: the flatbuffer file identifier for the resulting model.

    Returns:
      The serialized model buffer.
    """
    model = _schema_fb.ModelT.InitFromObj(
        _schema_fb.Model.GetRootAsModel(model_buf, 0))
    buffer_field = _schema_fb.BufferT()
    buffer_field.data = metadata_buf
    model.buffers = [buffer_field]
    # Creates a new metadata field.
    metadata_field = _schema_fb.MetadataT()
    metadata_field.name = _metadata.MetadataPopulator.METADATA_FIELD_NAME
    metadata_field.buffer = len(model.buffers) - 1
    model.metadata = [metadata_field]
    b = flatbuffers.Builder(0)
    b.Finish(model.Pack(b), identifier)
    return b.Output()

  def _create_metadata_file_with_version(self, metadata_file, min_version):
    """Copies `metadata_file` with `minParserVersion` set to `min_version`.

    Args:
      metadata_file: path of an existing serialized metadata file.
      min_version: the minimum parser version string to record.

    Returns:
      The full path of the newly written metadata file.
    """
    metadata_buf = bytearray(_read_file(metadata_file))

    metadata = _metadata_fb.ModelMetadataT.InitFromObj(
        _metadata_fb.ModelMetadata.GetRootAsModelMetadata(metadata_buf, 0))
    metadata.minParserVersion = min_version

    b = flatbuffers.Builder(0)
    b.Finish(
        metadata.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)

    metadata_file_with_version = self.create_tempfile().full_path
    with open(metadata_file_with_version, "wb") as f:
      f.write(b.Output())
    return metadata_file_with_version


class MetadataPopulatorTest(MetadataTest):
  """Tests for `_metadata.MetadataPopulator`.

  Unless stated otherwise, the tests populate into the model fixture built by
  `MetadataTest.setUp()`, which has one subgraph with two input tensors and
  one output tensor, so any metadata loaded here must describe exactly that.
  """

  def _create_bert_tokenizer(self):
    """Creates a BERT tokenizer process unit.

    Returns:
      A `(process_unit, file_names)` tuple, where `file_names` lists the names
      of the associated files the tokenizer references.
    """
    vocab_file_name = "bert_vocab"
    vocab = _metadata_fb.AssociatedFileT()
    vocab.name = vocab_file_name
    vocab.type = _metadata_fb.AssociatedFileType.VOCABULARY
    tokenizer = _metadata_fb.ProcessUnitT()
    tokenizer.optionsType = _metadata_fb.ProcessUnitOptions.BertTokenizerOptions
    tokenizer.options = _metadata_fb.BertTokenizerOptionsT()
    tokenizer.options.vocabFile = [vocab]
    return tokenizer, [vocab_file_name]

  def _create_sentence_piece_tokenizer(self):
    """Creates a SentencePiece tokenizer process unit.

    Returns:
      A `(process_unit, file_names)` tuple, where `file_names` lists the
      SentencePiece model file name followed by the vocabulary file name.
    """
    sp_model_name = "sp_model"
    vocab_file_name = "sp_vocab"
    sp_model = _metadata_fb.AssociatedFileT()
    sp_model.name = sp_model_name
    vocab = _metadata_fb.AssociatedFileT()
    vocab.name = vocab_file_name
    vocab.type = _metadata_fb.AssociatedFileType.VOCABULARY
    tokenizer = _metadata_fb.ProcessUnitT()
    tokenizer.optionsType = (
        _metadata_fb.ProcessUnitOptions.SentencePieceTokenizerOptions)
    tokenizer.options = _metadata_fb.SentencePieceTokenizerOptionsT()
    tokenizer.options.sentencePieceModel = [sp_model]
    tokenizer.options.vocabFile = [vocab]
    return tokenizer, [sp_model_name, vocab_file_name]

  def _create_tokenizer(self, tokenizer_type):
    """Builds the tokenizer process unit selected by `tokenizer_type`.

    Args:
      tokenizer_type: a `Tokenizer` enum member.

    Returns:
      The `(process_unit, file_names)` tuple of the matching factory.

    Raises:
      ValueError: if `tokenizer_type` is not a supported `Tokenizer`.
    """
    if tokenizer_type is Tokenizer.BERT_TOKENIZER:
      return self._create_bert_tokenizer()
    elif tokenizer_type is Tokenizer.SENTENCE_PIECE:
      return self._create_sentence_piece_tokenizer()
    else:
      raise ValueError(
          "The tokenizer type, {0}, is unsupported.".format(tokenizer_type))

  def _create_tempfiles(self, file_names):
    """Creates one temp file per name; returns their full paths in order."""
    return [self.create_tempfile(name).full_path for name in file_names]

  def _create_model_meta_with_subgraph_meta(self, subgraph_meta):
    """Serializes a ModelMetadata whose only subgraph is `subgraph_meta`.

    Returns:
      The metadata buffer, finished with the metadata file identifier.
    """
    model_meta = _metadata_fb.ModelMetadataT()
    model_meta.subgraphMetadata = [subgraph_meta]
    b = flatbuffers.Builder(0)
    b.Finish(
        model_meta.Pack(b),
        _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
    return b.Output()

  def _assert_golden_metadata(self, model_file):
    """Asserts `model_file` holds `self._metadata_file_with_version`.

    The model fixture has two entries in its metadata array before population,
    so the populated TFLite metadata is expected to be the third entry.
    """
    model_buf_from_file = _read_file(model_file)
    model = _schema_fb.Model.GetRootAsModel(model_buf_from_file, 0)
    metadata_field = model.Metadata(2)
    self.assertEqual(
        six.ensure_str(metadata_field.Name()),
        six.ensure_str(_metadata.MetadataPopulator.METADATA_FIELD_NAME))

    buffer_index = metadata_field.Buffer()
    buffer_data = model.Buffers(buffer_index)
    metadata_buf_np = buffer_data.DataAsNumpy()
    metadata_buf = metadata_buf_np.tobytes()
    expected_metadata_buf = bytearray(
        _read_file(self._metadata_file_with_version))
    self.assertEqual(metadata_buf, expected_metadata_buf)

  def testToValidModelFile(self):
    """with_model_file() accepts an existing model file."""
    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
    self.assertIsInstance(populator, _metadata.MetadataPopulator)

  def testToInvalidModelFile(self):
    """with_model_file() raises IOError for a missing file."""
    with self.assertRaises(IOError) as error:
      _metadata.MetadataPopulator.with_model_file(self._invalid_file)
    self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file),
                     str(error.exception))

  def testToValidModelBuffer(self):
    """with_model_buffer() accepts a valid model buffer."""
    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
    self.assertIsInstance(populator, _metadata.MetadataPopulator)

  def testToInvalidModelBuffer(self):
    """with_model_buffer() rejects a None model buffer."""
    with self.assertRaises(ValueError) as error:
      _metadata.MetadataPopulator.with_model_buffer(self._invalid_model_buf)
    self.assertEqual("model_buf cannot be empty.", str(error.exception))

  def testToModelBufferWithWrongIdentifier(self):
    """with_model_buffer() rejects a model with a wrong file identifier."""
    model_buf = self._create_model_buffer_with_wrong_identifier()
    with self.assertRaises(ValueError) as error:
      _metadata.MetadataPopulator.with_model_buffer(model_buf)
    self.assertEqual(
        "The model provided does not have the expected identifier, and "
        "may not be a valid TFLite model.", str(error.exception))

  def testSinglePopulateAssociatedFile(self):
    """A single loaded associated file is packed into the model."""
    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
    populator.load_associated_files([self._file1])
    populator.populate()

    packed_files = populator.get_packed_associated_file_list()
    expected_packed_files = [os.path.basename(self._file1)]
    self.assertEqual(set(packed_files), set(expected_packed_files))

  def testRepeatedPopulateAssociatedFile(self):
    """Loading the same associated file twice packs it only once."""
    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
    populator.load_associated_files([self._file1, self._file2])
    # Loads file2 multiple times.
    populator.load_associated_files([self._file2])
    populator.populate()

    packed_files = populator.get_packed_associated_file_list()
    expected_packed_files = [
        os.path.basename(self._file1),
        os.path.basename(self._file2)
    ]
    self.assertLen(packed_files, 2)
    self.assertEqual(set(packed_files), set(expected_packed_files))

    # Check if the model buffer read from file is the same as that read from
    # get_model_buffer().
    model_buf_from_file = _read_file(self._model_file)
    model_buf_from_getter = populator.get_model_buffer()
    self.assertEqual(model_buf_from_file, model_buf_from_getter)

  def testPopulateInvalidAssociatedFile(self):
    """load_associated_files() raises IOError for a missing file."""
    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
    with self.assertRaises(IOError) as error:
      populator.load_associated_files([self._invalid_file])
    self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file),
                     str(error.exception))

  def testPopulatePackedAssociatedFile(self):
    """Populating an already packed associated file raises ValueError."""
    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
    populator.load_associated_files([self._file1])
    populator.populate()
    with self.assertRaises(ValueError) as error:
      populator.load_associated_files([self._file1])
      populator.populate()
    self.assertEqual(
        "File, '{0}', has already been packed.".format(
            os.path.basename(self._file1)), str(error.exception))

  def testLoadAssociatedFileBuffers(self):
    """A file buffer loaded by path is packed under the file's basename."""
    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
    file_buffer = _read_file(self._file1)
    populator.load_associated_file_buffers({self._file1: file_buffer})
    populator.populate()

    packed_files = populator.get_packed_associated_file_list()
    expected_packed_files = [os.path.basename(self._file1)]
    self.assertEqual(set(packed_files), set(expected_packed_files))

  def testRepeatedLoadAssociatedFileBuffers(self):
    """Loading the same associated file buffer twice packs it only once."""
    file_buffer1 = _read_file(self._file1)
    file_buffer2 = _read_file(self._file2)
    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)

    populator.load_associated_file_buffers({
        self._file1: file_buffer1,
        self._file2: file_buffer2
    })
    # Loads file2 multiple times.
    populator.load_associated_file_buffers({self._file2: file_buffer2})
    populator.populate()

    packed_files = populator.get_packed_associated_file_list()
    expected_packed_files = [
        os.path.basename(self._file1),
        os.path.basename(self._file2)
    ]
    # Like testRepeatedPopulateAssociatedFile, also guard against duplicates,
    # which the set comparison below cannot detect.
    self.assertLen(packed_files, 2)
    self.assertEqual(set(packed_files), set(expected_packed_files))

    # Check if the model buffer read from file is the same as that read from
    # get_model_buffer().
    model_buf_from_file = _read_file(self._model_file)
    model_buf_from_getter = populator.get_model_buffer()
    self.assertEqual(model_buf_from_file, model_buf_from_getter)

  def testLoadPackedAssociatedFileBuffersFails(self):
    """Populating an already packed file buffer raises ValueError."""
    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
    file_buffer = _read_file(self._file1)
    populator.load_associated_file_buffers({self._file1: file_buffer})
    populator.populate()

    # Load file1 again should fail.
    with self.assertRaises(ValueError) as error:
      populator.load_associated_file_buffers({self._file1: file_buffer})
      populator.populate()
    self.assertEqual(
        "File, '{0}', has already been packed.".format(
            os.path.basename(self._file1)), str(error.exception))

  def testGetPackedAssociatedFileList(self):
    """A model without packed files reports an empty packed file list."""
    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
    packed_files = populator.get_packed_associated_file_list()
    self.assertEqual(packed_files, [])

  def testPopulateMetadataFileToEmptyModelFile(self):
    """Metadata and its associated files are populated into the model file."""
    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
    populator.load_metadata_file(self._metadata_file)
    populator.load_associated_files([self._file1, self._file2])
    populator.populate()

    model_buf_from_file = _read_file(self._model_file)
    model = _schema_fb.Model.GetRootAsModel(model_buf_from_file, 0)
    # self._model_file already has two elements in the metadata field, so the
    # populated TFLite metadata will be the third element.
    metadata_field = model.Metadata(2)
    self.assertEqual(
        six.ensure_str(metadata_field.Name()),
        six.ensure_str(_metadata.MetadataPopulator.METADATA_FIELD_NAME))

    buffer_index = metadata_field.Buffer()
    buffer_data = model.Buffers(buffer_index)
    metadata_buf_np = buffer_data.DataAsNumpy()
    metadata_buf = metadata_buf_np.tobytes()
    expected_metadata_buf = bytearray(
        _read_file(self._metadata_file_with_version))
    self.assertEqual(metadata_buf, expected_metadata_buf)

    recorded_files = populator.get_recorded_associated_file_list()
    self.assertEqual(set(recorded_files), set(self.expected_recorded_files))

    # Up to now, we've proved the correctness of the model buffer read from
    # file. Then we'll test if get_model_buffer() gives the same model buffer.
    model_buf_from_getter = populator.get_model_buffer()
    self.assertEqual(model_buf_from_file, model_buf_from_getter)

  def testPopulateMetadataFileWithoutAssociatedFiles(self):
    """populate() fails when a file recorded in the metadata is not loaded."""
    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
    populator.load_metadata_file(self._metadata_file)
    populator.load_associated_files([self._file1])
    # self._file2 is supposed to be populated too, because it is recorded in
    # the metadata.
    with self.assertRaises(ValueError) as error:
      populator.populate()
    self.assertEqual(("File, '{0}', is recorded in the metadata, but has "
                      "not been loaded into the populator.").format(
                          os.path.basename(self._file2)), str(error.exception))

  def testPopulateMetadataBufferWithWrongIdentifier(self):
    """load_metadata_buffer() rejects metadata with a wrong identifier."""
    metadata_buf = self._create_metadata_buffer_with_wrong_identifier()
    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
    with self.assertRaises(ValueError) as error:
      populator.load_metadata_buffer(metadata_buf)
    self.assertEqual(
        "The metadata buffer does not have the expected identifier, and may not"
        " be a valid TFLite Metadata.", str(error.exception))

  def testPopulateMetadataFileToModelWithMetadataAndAssociatedFiles(self):
    """Populating into an already populated model stores the new metadata."""
    # First, creates a dummy metadata different from self._metadata_file. It
    # needs to have the same input/output tensor numbers as self._model_file.
    # Populates it and the associated files into the model.
    input_meta = _metadata_fb.TensorMetadataT()
    output_meta = _metadata_fb.TensorMetadataT()
    subgraph = _metadata_fb.SubGraphMetadataT()
    # Metadata for a model with two inputs and one output.
    subgraph.inputTensorMetadata = [input_meta, input_meta]
    subgraph.outputTensorMetadata = [output_meta]
    metadata_buf = self._create_model_meta_with_subgraph_meta(subgraph)

    # Populate the metadata.
    populator1 = _metadata.MetadataPopulator.with_model_file(self._model_file)
    populator1.load_metadata_buffer(metadata_buf)
    populator1.load_associated_files([self._file1, self._file2])
    populator1.populate()

    # Then, populate the metadata again.
    populator2 = _metadata.MetadataPopulator.with_model_file(self._model_file)
    populator2.load_metadata_file(self._metadata_file)
    populator2.populate()

    # Test if the metadata is populated correctly.
    self._assert_golden_metadata(self._model_file)

  def testPopulateMetadataFileToModelFileWithMetadataAndBufFields(self):
    """Metadata populates into a model with existing metadata/buffer fields."""
    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
    populator.load_metadata_file(self._metadata_file)
    populator.load_associated_files([self._file1, self._file2])
    populator.populate()

    # Tests if the metadata is populated correctly.
    self._assert_golden_metadata(self._model_file)

    recorded_files = populator.get_recorded_associated_file_list()
    self.assertEqual(set(recorded_files), set(self.expected_recorded_files))

    # Up to now, we've proved the correctness of the model buffer read from
    # file. Then we'll test if get_model_buffer() gives the same model buffer.
    model_buf_from_file = _read_file(self._model_file)
    model_buf_from_getter = populator.get_model_buffer()
    self.assertEqual(model_buf_from_file, model_buf_from_getter)

  def testPopulateInvalidMetadataFile(self):
    """load_metadata_file() raises IOError for a missing file."""
    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
    with self.assertRaises(IOError) as error:
      populator.load_metadata_file(self._invalid_file)
    self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file),
                     str(error.exception))

  def testPopulateInvalidMetadataBuffer(self):
    """load_metadata_buffer() rejects an empty metadata buffer."""
    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
    with self.assertRaises(ValueError) as error:
      populator.load_metadata_buffer([])
    self.assertEqual("The metadata to be populated is empty.",
                     str(error.exception))

  def testGetModelBufferBeforePopulatingData(self):
    """get_model_buffer() returns the input model before populate() runs."""
    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
    model_buf = populator.get_model_buffer()
    expected_model_buf = self._model_buf
    self.assertEqual(model_buf, expected_model_buf)

  def testLoadMetadataBufferWithNoSubgraphMetadataThrowsException(self):
    """Metadata without any SubgraphMetadata is rejected."""
    # Create a dummy metadata without Subgraph.
    model_meta = _metadata_fb.ModelMetadataT()
    builder = flatbuffers.Builder(0)
    builder.Finish(
        model_meta.Pack(builder),
        _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
    meta_buf = builder.Output()

    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
    with self.assertRaises(ValueError) as error:
      populator.load_metadata_buffer(meta_buf)
    self.assertEqual(
        "The number of SubgraphMetadata should be exactly one, but got 0.",
        str(error.exception))

  def testLoadMetadataBufferWithWrongInputMetaNumberThrowsException(self):
    """Metadata whose input tensor count mismatches the model is rejected."""
    # Create a dummy metadata with no input tensor metadata, while the expected
    # number is 2.
    output_meta = _metadata_fb.TensorMetadataT()
    subgraph_meta = _metadata_fb.SubGraphMetadataT()
    subgraph_meta.outputTensorMetadata = [output_meta]
    meta_buf = self._create_model_meta_with_subgraph_meta(subgraph_meta)

    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
    with self.assertRaises(ValueError) as error:
      populator.load_metadata_buffer(meta_buf)
    self.assertEqual(
        ("The number of input tensors (2) should match the number of "
         "input tensor metadata (0)"), str(error.exception))

  def testLoadMetadataBufferWithWrongOutputMetaNumberThrowsException(self):
    """Metadata whose output tensor count mismatches the model is rejected."""
    # Create a dummy metadata with no output tensor metadata, while the expected
    # number is 1.
    input_meta = _metadata_fb.TensorMetadataT()
    subgraph_meta = _metadata_fb.SubGraphMetadataT()
    subgraph_meta.inputTensorMetadata = [input_meta, input_meta]
    meta_buf = self._create_model_meta_with_subgraph_meta(subgraph_meta)

    populator = _metadata.MetadataPopulator.with_model_buffer(self._model_buf)
    with self.assertRaises(ValueError) as error:
      populator.load_metadata_buffer(meta_buf)
    self.assertEqual(
        ("The number of output tensors (1) should match the number of "
         "output tensor metadata (0)"), str(error.exception))

  def testLoadMetadataAndAssociatedFilesShouldSucceeds(self):
    """Metadata and files can be copied from a populated model to another."""
    # Create a src model with metadata and two associated files.
    src_model_buf = self._create_model_buf()
    populator_src = _metadata.MetadataPopulator.with_model_buffer(src_model_buf)
    populator_src.load_metadata_file(self._metadata_file)
    populator_src.load_associated_files([self._file1, self._file2])
    populator_src.populate()

    # Create a model to be populated with the metadata and files from
    # src_model_buf.
    dst_model_buf = self._create_model_buf()
    populator_dst = _metadata.MetadataPopulator.with_model_buffer(dst_model_buf)
    populator_dst.load_metadata_and_associated_files(
        populator_src.get_model_buffer())
    populator_dst.populate()

    # Tests if the metadata and associated files are populated correctly.
    dst_model_file = self.create_tempfile().full_path
    with open(dst_model_file, "wb") as f:
      f.write(populator_dst.get_model_buffer())
    self._assert_golden_metadata(dst_model_file)

    recorded_files = populator_dst.get_recorded_associated_file_list()
    self.assertEqual(set(recorded_files), set(self.expected_recorded_files))

  @parameterized.named_parameters(
      {
          "testcase_name": "InputTensorWithBert",
          "tensor_type": TensorType.INPUT,
          "tokenizer_type": Tokenizer.BERT_TOKENIZER
      }, {
          "testcase_name": "OutputTensorWithBert",
          "tensor_type": TensorType.OUTPUT,
          "tokenizer_type": Tokenizer.BERT_TOKENIZER
      }, {
          "testcase_name": "InputTensorWithSentencePiece",
          "tensor_type": TensorType.INPUT,
          "tokenizer_type": Tokenizer.SENTENCE_PIECE
      }, {
          "testcase_name": "OutputTensorWithSentencePiece",
          "tensor_type": TensorType.OUTPUT,
          "tokenizer_type": Tokenizer.SENTENCE_PIECE
      })
  def testGetRecordedAssociatedFileListWithSubgraphTensor(
      self, tensor_type, tokenizer_type):
    """Files referenced by tensor-level tokenizers are recorded."""
    # Creates a metadata with the tokenizer in the tensor process units.
    tokenizer, expected_files = self._create_tokenizer(tokenizer_type)

    # Create the tensor with process units.
    tensor = _metadata_fb.TensorMetadataT()
    tensor.processUnits = [tokenizer]

    # Create the subgraph with the tensor; the tensor counts (two inputs, one
    # output) must match self._model_file.
    subgraph = _metadata_fb.SubGraphMetadataT()
    dummy_tensor_meta = _metadata_fb.TensorMetadataT()
    if tensor_type is TensorType.INPUT:
      subgraph.inputTensorMetadata = [tensor, dummy_tensor_meta]
      subgraph.outputTensorMetadata = [dummy_tensor_meta]
    elif tensor_type is TensorType.OUTPUT:
      subgraph.inputTensorMetadata = [dummy_tensor_meta, dummy_tensor_meta]
      subgraph.outputTensorMetadata = [tensor]
    else:
      raise ValueError(
          "The tensor type, {0}, is unsupported.".format(tensor_type))

    # Create a model metadata with the subgraph metadata
    meta_buffer = self._create_model_meta_with_subgraph_meta(subgraph)

    # Creates the tempfiles.
    tempfiles = self._create_tempfiles(expected_files)

    # Creates the MetadataPopulator object.
    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
    populator.load_metadata_buffer(meta_buffer)
    populator.load_associated_files(tempfiles)
    populator.populate()

    recorded_files = populator.get_recorded_associated_file_list()
    self.assertEqual(set(recorded_files), set(expected_files))

  @parameterized.named_parameters(
      {
          "testcase_name": "InputTensorWithBert",
          "tensor_type": TensorType.INPUT,
          "tokenizer_type": Tokenizer.BERT_TOKENIZER
      }, {
          "testcase_name": "OutputTensorWithBert",
          "tensor_type": TensorType.OUTPUT,
          "tokenizer_type": Tokenizer.BERT_TOKENIZER
      }, {
          "testcase_name": "InputTensorWithSentencePiece",
          "tensor_type": TensorType.INPUT,
          "tokenizer_type": Tokenizer.SENTENCE_PIECE
      }, {
          "testcase_name": "OutputTensorWithSentencePiece",
          "tensor_type": TensorType.OUTPUT,
          "tokenizer_type": Tokenizer.SENTENCE_PIECE
      })
  def testGetRecordedAssociatedFileListWithSubgraphProcessUnits(
      self, tensor_type, tokenizer_type):
    """Files referenced by subgraph-level tokenizers are recorded."""
    # Creates a metadata with the tokenizer in the subgraph process units.
    tokenizer, expected_files = self._create_tokenizer(tokenizer_type)

    # Create the subgraph with process units.
    subgraph = _metadata_fb.SubGraphMetadataT()
    if tensor_type is TensorType.INPUT:
      subgraph.inputProcessUnits = [tokenizer]
    elif tensor_type is TensorType.OUTPUT:
      subgraph.outputProcessUnits = [tokenizer]
    else:
      raise ValueError(
          "The tensor type, {0}, is unsupported.".format(tensor_type))

    # Creates the input and output tensor meta to match self._model_file.
    dummy_tensor_meta = _metadata_fb.TensorMetadataT()
    subgraph.inputTensorMetadata = [dummy_tensor_meta, dummy_tensor_meta]
    subgraph.outputTensorMetadata = [dummy_tensor_meta]

    # Create a model metadata with the subgraph metadata
    meta_buffer = self._create_model_meta_with_subgraph_meta(subgraph)

    # Creates the tempfiles.
    tempfiles = self._create_tempfiles(expected_files)

    # Creates the MetadataPopulator object.
    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
    populator.load_metadata_buffer(meta_buffer)
    populator.load_associated_files(tempfiles)
    populator.populate()

    recorded_files = populator.get_recorded_associated_file_list()
    self.assertEqual(set(recorded_files), set(expected_files))

  def testPopulatedFullPathAssociatedFileShouldSucceed(self):
    """A full-path associated file name is recorded by its basename only."""
    # Create AssociatedFileT using the full path file name.
    associated_file = _metadata_fb.AssociatedFileT()
    associated_file.name = self._file1

    # Create model metadata with the associated file.
    subgraph = _metadata_fb.SubGraphMetadataT()
    subgraph.associatedFiles = [associated_file]
    # Creates the input and output tensor metadata to match self._model_file.
    dummy_tensor = _metadata_fb.TensorMetadataT()
    subgraph.inputTensorMetadata = [dummy_tensor, dummy_tensor]
    subgraph.outputTensorMetadata = [dummy_tensor]
    md_buffer = self._create_model_meta_with_subgraph_meta(subgraph)

    # Populate the metadata to a model.
    populator = _metadata.MetadataPopulator.with_model_file(self._model_file)
    populator.load_metadata_buffer(md_buffer)
    populator.load_associated_files([self._file1])
    populator.populate()

    # The recorded file name in metadata should only contain file basename; file
    # directory should not be included.
    recorded_files = populator.get_recorded_associated_file_list()
    self.assertEqual(set(recorded_files), set([os.path.basename(self._file1)]))


class MetadataDisplayerTest(MetadataTest):
  """Tests for `_metadata.MetadataDisplayer`.

  setUp builds `self._model_with_meta_file`, a model file populated with the
  metadata in `self._metadata_file` and the associated files `self._file1` and
  `self._file2`.
  """

  def setUp(self):
    super(MetadataDisplayerTest, self).setUp()
    self._model_with_meta_file = (
        self._create_model_with_metadata_and_associated_files())

  def _create_model_with_metadata_and_associated_files(self):
    """Writes a model populated with metadata and two associated files.

    Returns:
      The path of a temp file that holds the populated model.
    """
    model_buf = self._create_model_buf()
    model_file = self.create_tempfile().full_path
    with open(model_file, "wb") as f:
      f.write(model_buf)

    populator = _metadata.MetadataPopulator.with_model_file(model_file)
    populator.load_metadata_file(self._metadata_file)
    populator.load_associated_files([self._file1, self._file2])
    populator.populate()
    return model_file

  def _read_golden_json(self):
    """Returns the text content of testdata/golden_json.json."""
    golden_json_file_path = tf.compat.v1.resource_loader.get_path_to_datafile(
        "testdata/golden_json.json")
    return _read_file(golden_json_file_path, "r")

  def testLoadModelBufferMetadataBufferWithWrongIdentifierThrowsException(self):
    # The model is stamped with the valid TFLite identifier, but its metadata
    # buffer carries a wrong one.
    model_buf = self._create_model_buffer_with_wrong_identifier()
    metadata_buf = self._create_metadata_buffer_with_wrong_identifier()
    model_buf = self._populate_metadata_with_identifier(
        model_buf, metadata_buf,
        _metadata.MetadataPopulator.TFLITE_FILE_IDENTIFIER)
    with self.assertRaises(ValueError) as error:
      _metadata.MetadataDisplayer.with_model_buffer(model_buf)
    self.assertEqual(
        "The metadata buffer does not have the expected identifier, and may not"
        " be a valid TFLite Metadata.", str(error.exception))

  def testLoadModelBufferModelBufferWithWrongIdentifierThrowsException(self):
    # Valid metadata, but the model itself is finished with a bogus identifier.
    model_buf = self._create_model_buffer_with_wrong_identifier()
    metadata_file = self._create_metadata_file()
    wrong_identifier = b"widn"
    metadata_buf = bytearray(_read_file(metadata_file))
    model_buf = self._populate_metadata_with_identifier(model_buf, metadata_buf,
                                                        wrong_identifier)
    with self.assertRaises(ValueError) as error:
      _metadata.MetadataDisplayer.with_model_buffer(model_buf)
    self.assertEqual(
        "The model provided does not have the expected identifier, and "
        "may not be a valid TFLite model.", str(error.exception))

  def testLoadModelFileInvalidModelFileThrowsException(self):
    with self.assertRaises(IOError) as error:
      _metadata.MetadataDisplayer.with_model_file(self._invalid_file)
    self.assertEqual("File, '{0}', does not exist.".format(self._invalid_file),
                     str(error.exception))

  def testLoadModelFileModelWithoutMetadataThrowsException(self):
    # self._model_file is the plain model written in setUp, without metadata.
    with self.assertRaises(ValueError) as error:
      _metadata.MetadataDisplayer.with_model_file(self._model_file)
    self.assertEqual("The model does not have metadata.", str(error.exception))

  def testLoadModelFileModelWithMetadata(self):
    displayer = _metadata.MetadataDisplayer.with_model_file(
        self._model_with_meta_file)
    self.assertIsInstance(displayer, _metadata.MetadataDisplayer)

  def testLoadModelBufferInvalidModelBufferThrowsException(self):
    # file1 is used as the model buffer. The expected error implies its
    # content is empty; setUp creates it via create_tempfile without content.
    with self.assertRaises(ValueError) as error:
      _metadata.MetadataDisplayer.with_model_buffer(_read_file(self._file1))
    self.assertEqual("model_buffer cannot be empty.", str(error.exception))

  def testLoadModelBufferModelWithOutMetadataThrowsException(self):
    with self.assertRaises(ValueError) as error:
      _metadata.MetadataDisplayer.with_model_buffer(self._create_model_buf())
    self.assertEqual("The model does not have metadata.", str(error.exception))

  def testLoadModelBufferModelWithMetadata(self):
    displayer = _metadata.MetadataDisplayer.with_model_buffer(
        _read_file(self._model_with_meta_file))
    self.assertIsInstance(displayer, _metadata.MetadataDisplayer)

  def testGetAssociatedFileBufferShouldSucceed(self):
    # _model_with_meta_file contains file1 and file2. Associated files are
    # recorded by basename, so look file2 up by its basename instead of
    # hard-coding the name.
    displayer = _metadata.MetadataDisplayer.with_model_file(
        self._model_with_meta_file)

    actual_content = displayer.get_associated_file_buffer(
        os.path.basename(self._file2))
    self.assertEqual(actual_content, self._file2_content)

  def testGetAssociatedFileBufferFailsWithNonExistentFile(self):
    # _model_with_meta_file contains file1 and file2.
    displayer = _metadata.MetadataDisplayer.with_model_file(
        self._model_with_meta_file)

    non_existent_file = "non_existent_file"
    with self.assertRaises(ValueError) as error:
      displayer.get_associated_file_buffer(non_existent_file)
    self.assertEqual(
        "The file, {}, does not exist in the model.".format(non_existent_file),
        str(error.exception))

  def testGetMetadataBufferShouldSucceed(self):
    displayer = _metadata.MetadataDisplayer.with_model_file(
        self._model_with_meta_file)
    actual_buffer = displayer.get_metadata_buffer()
    actual_json = _metadata.convert_to_json(actual_buffer)

    # Verifies the generated json file.
    self.assertEqual(actual_json, self._read_golden_json())

  def testGetMetadataJsonModelWithMetadata(self):
    displayer = _metadata.MetadataDisplayer.with_model_file(
        self._model_with_meta_file)
    actual = displayer.get_metadata_json()

    # Verifies the generated json file.
    self.assertEqual(actual, self._read_golden_json())

  def testGetPackedAssociatedFileListModelWithMetadata(self):
    displayer = _metadata.MetadataDisplayer.with_model_file(
        self._model_with_meta_file)
    packed_files = displayer.get_packed_associated_file_list()

    # Packed files are reported by basename, not by full path.
    expected_packed_files = [
        os.path.basename(self._file1),
        os.path.basename(self._file2)
    ]
    self.assertLen(
        packed_files, 2,
        "The following two associated files packed to the model: {0}; {1}"
        .format(expected_packed_files[0], expected_packed_files[1]))
    self.assertEqual(set(packed_files), set(expected_packed_files))


class MetadataUtilTest(MetadataTest):
  """Tests for the module-level utilities in `_metadata`."""

  def test_convert_to_json_should_succeed(self):
    # Converting the versioned metadata flatbuffer must reproduce the golden
    # JSON text exactly.
    converted_json = _metadata.convert_to_json(
        _read_file(self._metadata_file_with_version))

    golden_json = _read_file(
        tf.compat.v1.resource_loader.get_path_to_datafile(
            "testdata/golden_json.json"), "r")
    self.assertEqual(converted_json, golden_json)


if __name__ == "__main__":
  # Entry point when the module is run directly: hands control to TensorFlow's
  # test runner.
  tf.test.main()