chromium/third_party/tflite_support/src/tensorflow_lite_support/custom_ops/python/sentencepiece_tokenizer.py

# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Python class that implements Sentencepiece tokenizer.

It follows the design of the TF.Text tokenizers.

"""
import tensorflow.compat.v2 as tf  # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.ops.ragged import ragged_tensor  # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.framework import load_library
from tensorflow.python.platform import resource_loader
gen_sentencepiece_detokenizer_op = load_library.load_op_library(resource_loader.get_path_to_datafile('../kernel/sentencepiece/sentencepiece_detokenizer_op.so'))
from tensorflow.python.framework import load_library
from tensorflow.python.platform import resource_loader
gen_sentencepiece_tokenizer_op = load_library.load_op_library(resource_loader.get_path_to_datafile('../kernel/sentencepiece/sentencepiece_tokenizer_op.so'))
from tensorflow_lite_support.custom_ops.kernel.sentencepiece.py import pywrap_model_converter as model_converter


class SentencepieceTokenizer:
  """Sentencepiece tokenizer exposing a tf.text-style interface.

  The constructor converts the given model once into the two formats used by
  the custom tokenize and detokenize ops. Both converted models are kept as
  uint8 tensors for the lifetime of the object.
  """

  def __init__(self, model, reverse=False, add_bos=False, add_eos=False):
    """Converts `model` and stores the options forwarded to the tokenize op.

    Args:
      model: The Sentencepiece model, in whatever form
        `model_converter.convert_sentencepiece_model` accepts (presumably the
        serialized model proto; TODO confirm).
      reverse: Passed through as the `reverse` argument of the tokenize op.
      add_bos: Passed through as the `add_bos` argument of the tokenize op.
      add_eos: Passed through as the `add_eos` argument of the tokenize op.
    """
    encoder_blob = model_converter.convert_sentencepiece_model(model)
    decoder_blob = (
        model_converter.convert_sentencepiece_model_for_decoder(model))
    # The blobs are stored as uint8 tensors instead of strings so that nothing
    # can alter the bytes, e.g. by truncating at an embedded '\0'.
    self._converted_model = tf.constant(list(encoder_blob), dtype=tf.uint8)
    self._converted_model_detokenizer = tf.constant(
        list(decoder_blob), dtype=tf.uint8)
    # The vocabulary size is derived from the encoder-side conversion.
    self._vocab_size = model_converter.get_vocabulary_size(encoder_blob)
    self._reverse = reverse
    self._add_bos = add_bos
    self._add_eos = add_eos

  def tokenize(self, inputs):
    """Splits strings into Sentencepiece token ids.

    Args:
      inputs: A `Tensor` or `RaggedTensor` of strings (or anything
        convertible to one) whose rank is statically known.

    Returns:
      For a scalar input, a rank-1 tensor of token ids. Otherwise, a
      `RaggedTensor` with one more ragged dimension than `inputs`, holding
      the token ids of each string.

    Raises:
      ValueError: If the rank of `inputs` is not statically known.
    """
    tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(inputs)
    rank = tensor.shape.ndims
    if rank is None:
      raise ValueError("Rank of input_tensor must be statically known.")

    if ragged_tensor.is_ragged(tensor):
      # Normalize the row splits to int32, tokenize the innermost strings and
      # re-attach the original nesting around the result.
      tensor = tensor.with_row_splits_dtype(tf.int32)
      return tensor.with_flat_values(self.tokenize(tensor.flat_values))

    if rank == 0:
      # Tokenize the scalar as a batch of one, then drop the batch dimension.
      return self.tokenize(tf.stack([tensor])).values

    if rank > 1:
      # Dense inputs of rank >= 2 go through their ragged representation.
      as_ragged = tf.RaggedTensor.from_tensor(
          tensor, row_splits_dtype=tf.int32)
      return self.tokenize(as_ragged)

    # A rank-1 string tensor is handed to the kernel as-is. The two literal 0
    # arguments are presumably the sampling parameters (nbest size / alpha)
    # of the op; TODO confirm against the op registration.
    ids, splits = gen_sentencepiece_tokenizer_op.tf_sentencepiece_tokenize_op(
        self._converted_model, tensor, 0, 0, self._add_bos, self._add_eos,
        self._reverse)
    return tf.RaggedTensor.from_nested_row_splits(
        flat_values=ids, nested_row_splits=[splits], validate=False)

  def detokenize(self, input):  # pylint: disable=redefined-builtin
    """Joins Sentencepiece token ids back into text.

    Args:
      input: A `RaggedTensor` or `Tensor` of int32 token ids with a statically
        known rank of at least 1.

    Returns:
      A string `Tensor` or `RaggedTensor` with one fewer dimension than
      `input`; a rank-1 input yields a scalar string.

    Raises:
      ValueError: If the rank of `input` is unknown or is 0.
    """
    tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(input)
    rank = tensor.shape.ndims
    if rank is None:
      raise ValueError("Rank of input_tensor must be statically known.")
    if rank == 0:
      raise ValueError("Rank of input_tensor must be at least 1.")

    if not ragged_tensor.is_ragged(tensor):
      if rank > 1:
        # Dense inputs of rank >= 2 go through their ragged representation.
        return self.detokenize(
            tf.RaggedTensor.from_tensor(tensor, row_splits_dtype=tf.int32))
      # A single id sequence is detokenized as a batch of one, then unwrapped
      # to a scalar string.
      batch_of_one = self.detokenize(tf.stack([tensor]))
      return tf.reshape(batch_of_one, [])

    if tensor.flat_values.shape.ndims > 1:
      # Multi-dimensional flat values are detokenized on their own; the
      # outer nested splits are carried over unchanged.
      return tensor.with_flat_values(self.detokenize(tensor.flat_values))

    if tensor.ragged_rank > 1:
      # Peel off one ragged dimension at a time.
      return tensor.with_values(self.detokenize(tensor.values))

    # Only a single ragged dimension remains, so the kernel can consume it.
    return gen_sentencepiece_detokenizer_op.tf_sentencepiece_detokenize_op(
        self._converted_model_detokenizer, tensor.flat_values,
        tensor.row_splits)

  def vocab_size(self):
    """Returns the vocabulary size computed from the model at construction."""
    return self._vocab_size