# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generates Android Java sources from a TFLite model with metadata."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import shutil
import sys
from absl import app
from absl import flags
from absl import logging
from tensorflow_lite_support.codegen.python import _pywrap_codegen
# Handle to absl's flag values; run_main() reads the flags defined below
# through it.
FLAGS = flags.FLAGS
# Input model file. Defaults to None; main() marks this flag as required.
flags.DEFINE_string('model', None, 'Path to model (.tflite) flatbuffer file.')
# Destination handed to the generator; the model is also copied under
# <destination>/src/main/assets. Defaults to None; main() marks it required.
flags.DEFINE_string('destination', None, 'Path of destination of generation.')
# Java package passed to the generator for the wrapper class.
flags.DEFINE_string('package_name', 'org.tensorflow.lite.support',
'Name of generated java package to put the wrapper class.')
# Class name passed to the generator for the wrapper class.
flags.DEFINE_string(
'model_class_name', 'MyModel',
'Name of generated wrapper class (should not contain package name).')
# Optional asset path; when left empty, run_main() falls back to the base
# name of --model.
flags.DEFINE_string(
'model_asset_path', '',
'(Optional) Path to the model in generated assets/ dir. If not set, '
'generator will use base name of input model.'
)
def get_model_buffer(path):
  """Reads the model file at `path` and returns its raw bytes.

  Args:
    path: Location of the model file on disk.

  Returns:
    The full content of the file as a bytes object.

  Raises:
    OSError: Propagated from open() when the file cannot be opened, for
      example when `path` does not exist.
  """
  model_is_present = os.path.isfile(path)
  if not model_is_present:
    # The problem is only reported here; open() below still raises for a
    # path that cannot be opened.
    logging.error('Cannot find model at path %s.', path)
  with open(path, 'rb') as model_file:
    return model_file.read()
def prepare_directory_for_file(file_path):
  """Makes sure the directory that will contain `file_path` exists.

  Missing parent directories are created. When the parent path exists but is
  not a directory, an error is logged and nothing is created.

  Args:
    file_path: Path of a file that is about to be written.
  """
  target_dir = os.path.dirname(file_path)
  if not target_dir:
    # A bare file name has no directory component, so there is nothing to
    # create; os.makedirs('') would raise FileNotFoundError.
    return
  if not os.path.exists(target_dir):
    os.makedirs(target_dir)
    return
  if not os.path.isdir(target_dir):
    logging.error('Cannot write to %s', target_dir)
def run_main(argv):
  """Main function of the codegen.

  Runs the Android Java generator on the model given by --model, writes every
  generated file to the path the generator reports for it, and copies the
  model to <destination>/src/main/assets/<model_asset_path>.

  Args:
    argv: Arguments handed over by absl. Entries after the first one are
      reported as unexpected non-flag arguments and otherwise ignored.

  Returns:
    None on success, or 1 when the generator produced no files. absl's
    app.run() forwards this value to sys.exit(), so a failed generation
    results in a non-zero exit status.
  """
  if len(argv) > 1:
    logging.error('Non-flag arguments found: [%s]', ', '.join(argv[1:]))

  codegen = _pywrap_codegen.AndroidJavaGenerator(FLAGS.destination)
  model_buffer = get_model_buffer(FLAGS.model)

  # Fall back to the model's file name when no asset path was given.
  model_asset_path = FLAGS.model_asset_path or os.path.basename(FLAGS.model)

  result = codegen.generate(model_buffer, FLAGS.package_name,
                            FLAGS.model_class_name, model_asset_path)

  # Generator messages are logged whether or not any files were produced.
  error_message = codegen.get_error_message().strip()
  if error_message:
    logging.error(error_message)

  if not result.files:
    logging.error('Generation failed!')
    # Signal the failure to the caller instead of exiting with status 0.
    return 1

  for generated in result.files:
    prepare_directory_for_file(generated.path)
    with open(generated.path, 'w') as f:
      f.write(generated.content)
  logging.info('Generation succeeded!')

  asset_destination = os.path.join(FLAGS.destination, 'src/main/assets',
                                   model_asset_path)
  prepare_directory_for_file(asset_destination)
  shutil.copy(FLAGS.model, asset_destination)
  logging.info('Model copied into assets!')
# Simple wrapper to make the code pip-friendly
def main():
  """Marks the mandatory flags and hands control over to absl.

  --model and --destination are made required, then app.run() is started
  with run_main as the program body and sys.argv as its arguments.
  """
  for required_flag in ('model', 'destination'):
    flags.mark_flag_as_required(required_flag)
  app.run(main=run_main, argv=sys.argv)
if __name__ == '__main__':
  # main() takes no arguments and starts absl's app.run() itself, so it is
  # called directly; passing it to app.run() would invoke it with argv and
  # raise TypeError.
  main()