chromium/third_party/tflite_support/src/tensorflow_lite_support/codegen/python/codegen.py

# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generates Android Java sources from a TFLite model with metadata."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import shutil
import sys
from absl import app
from absl import flags
from absl import logging

from tensorflow_lite_support.codegen.python import _pywrap_codegen

FLAGS = flags.FLAGS

flags.DEFINE_string('model', None, 'Path to model (.tflite) flatbuffer file.')
flags.DEFINE_string('destination', None, 'Path of destination of generation.')
flags.DEFINE_string('package_name', 'org.tensorflow.lite.support',
                    'Name of generated java package to put the wrapper class.')
flags.DEFINE_string(
    'model_class_name', 'MyModel',
    'Name of generated wrapper class (should not contain package name).')
flags.DEFINE_string(
    'model_asset_path', '',
    '(Optional) Path to the model in generated assets/ dir. If not set, '
    'generator will use base name of input model.'
)


def get_model_buffer(path):
  if not os.path.isfile(path):
    logging.error('Cannot find model at path %s.', path)
  with open(path, 'rb') as f:
    buf = f.read()
    return buf


def prepare_directory_for_file(file_path):
  target_dir = os.path.dirname(file_path)
  if not os.path.exists(target_dir):
    os.makedirs(target_dir)
    return
  if not os.path.isdir(target_dir):
    logging.error('Cannot write to %s', target_dir)


def run_main(argv):
  """Main function of the codegen."""

  if len(argv) > 1:
    logging.error('None flag arguments found: [%s]', ', '.join(argv[1:]))

  codegen = _pywrap_codegen.AndroidJavaGenerator(FLAGS.destination)
  model_buffer = get_model_buffer(FLAGS.model)
  model_asset_path = FLAGS.model_asset_path
  if not model_asset_path:
    model_asset_path = os.path.basename(FLAGS.model)
  result = codegen.generate(model_buffer, FLAGS.package_name,
                            FLAGS.model_class_name, model_asset_path)
  error_message = codegen.get_error_message().strip()
  if error_message:
    logging.error(error_message)
  if not result.files:
    logging.error('Generation failed!')
    return

  for each in result.files:
    prepare_directory_for_file(each.path)
    with open(each.path, 'w') as f:
      f.write(each.content)

  logging.info('Generation succeeded!')
  model_asset_path = os.path.join(FLAGS.destination, 'src/main/assets',
                                  model_asset_path)
  prepare_directory_for_file(model_asset_path)
  shutil.copy(FLAGS.model, model_asset_path)
  logging.info('Model copied into assets!')


# Simple wrapper to make the code pip-friendly
def main():
  flags.mark_flag_as_required('model')
  flags.mark_flag_as_required('destination')
  app.run(main=run_main, argv=sys.argv)


if __name__ == '__main__':
  app.run(main)