chromium/third_party/tflite_support/src/tensorflow_lite_support/python/task/processor/proto/embedding_pb2.py

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Embedding result protobuf."""

import dataclasses
from typing import Any, List

import numpy as np
from tensorflow_lite_support.cc.task.processor.proto import embedding_pb2
from tensorflow_lite_support.python.task.core.optional_dependencies import doc_controls

_FeatureVectorProto = embedding_pb2.FeatureVector
_EmbeddingProto = embedding_pb2.Embedding
_EmbeddingResultProto = embedding_pb2.EmbeddingResult


@dataclasses.dataclass
class FeatureVector:
  """A dense feature vector.

  The protobuf representation carries the vector in exactly one of two fields
  (`value_float` or `value_string`); this class exposes whichever one is set
  as a single NumPy array.
  Feature vectors are assumed to be one-dimensional and L2-normalized.

  Attributes:
    value: A NumPy array indicating the raw output of the embedding layer. The
      datatype of elements in the array can be either float or uint8 if
      `quantize` is set to True in `EmbeddingOptions`.
  """

  value: np.ndarray

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _FeatureVectorProto:
    """Generates a protobuf object to pass to the C++ layer.

    Returns:
      A protobuf with `value_float` populated when `value` has a
      floating-point dtype, or `value_string` populated when `value` has
      dtype np.uint8.

    Raises:
      ValueError: If the dtype of `value` is neither floating-point nor
        np.uint8.
    """
    dtype = self.value.dtype

    # `dtype == float` only matches native float64, which would reject float32
    # arrays. Accept every floating-point precision and normalize to float64;
    # `np.asarray` returns the input unchanged when it is already float64.
    if np.issubdtype(dtype, np.floating):
      return _FeatureVectorProto(
          value_float=np.asarray(self.value, dtype=float))

    if dtype == np.uint8:
      return _FeatureVectorProto(value_string=bytes(self.value))

    raise ValueError("Invalid dtype. Only float and np.uint8 are supported.")

  @classmethod
  @doc_controls.do_not_generate_docs
  def create_from_pb2(cls, pb2_obj: _FeatureVectorProto) -> "FeatureVector":
    """Creates a `FeatureVector` object from the given protobuf object.

    Args:
      pb2_obj: The protobuf object to convert.

    Returns:
      A `FeatureVector` holding a float array built from `value_float`, or a
      uint8 array built from the bytes of `value_string`.

    Raises:
      ValueError: If both `value_float` and `value_string` are empty.
    """
    # `value_float` takes precedence; an empty field is falsy and falls
    # through to the next candidate.
    if pb2_obj.value_float:
      return FeatureVector(value=np.array(pb2_obj.value_float, dtype=float))

    if pb2_obj.value_string:
      return FeatureVector(
          value=np.array(bytearray(pb2_obj.value_string), dtype=np.uint8))

    raise ValueError("Either value_float or value_string must exist.")

  def __eq__(self, other: Any) -> bool:
    """Checks if this object is equal to the given object.

    Equality is decided by comparing the protobuf representations of both
    objects, so any error raised by `to_pb2` propagates from here.

    Args:
      other: The object to be compared with.

    Returns:
      True if the objects are equal.
    """
    if not isinstance(other, FeatureVector):
      return False

    return self.to_pb2() == other.to_pb2()


@dataclasses.dataclass
class Embedding:
  """One embedding emitted by a single output layer of the embedder model.

  Attributes:
    feature_vector: The feature vector produced by the output layer.
    output_index: The index of the model output layer that produced this
      feature vector.
  """

  feature_vector: FeatureVector
  output_index: int

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _EmbeddingProto:
    """Builds the protobuf counterpart of this object for the C++ layer."""
    vector_proto = self.feature_vector.to_pb2()
    return _EmbeddingProto(
        feature_vector=vector_proto, output_index=self.output_index)

  @classmethod
  @doc_controls.do_not_generate_docs
  def create_from_pb2(cls, pb2_obj: _EmbeddingProto) -> "Embedding":
    """Builds an `Embedding` object out of the given protobuf object."""
    vector = FeatureVector.create_from_pb2(pb2_obj.feature_vector)
    return Embedding(feature_vector=vector, output_index=pb2_obj.output_index)

  def __eq__(self, other: Any) -> bool:
    """Compares this object with another one for equality.

    Two embeddings are compared through their protobuf representations.

    Args:
      other: The object to compare against.

    Returns:
      True if `other` is an `Embedding` with an equal protobuf
      representation, False otherwise.
    """
    if isinstance(other, Embedding):
      return self.to_pb2() == other.to_pb2()

    return False


@dataclasses.dataclass
class EmbeddingResult:
  """The collection of embeddings returned by the Embedder.

  Attributes:
    embeddings: One embedding per model output layer. Except in advanced
      cases, the embedding model has a single output layer, so this list
      holds a single element.
  """

  embeddings: List[Embedding]

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _EmbeddingResultProto:
    """Builds the protobuf counterpart of this object for the C++ layer."""
    embedding_protos = [item.to_pb2() for item in self.embeddings]
    return _EmbeddingResultProto(embeddings=embedding_protos)

  @classmethod
  @doc_controls.do_not_generate_docs
  def create_from_pb2(cls, pb2_obj: _EmbeddingResultProto) -> "EmbeddingResult":
    """Builds an `EmbeddingResult` object out of the given protobuf object."""
    converted = list(map(Embedding.create_from_pb2, pb2_obj.embeddings))
    return EmbeddingResult(embeddings=converted)

  def __eq__(self, other: Any) -> bool:
    """Compares this object with another one for equality.

    Two results are compared through their protobuf representations.

    Args:
      other: The object to compare against.

    Returns:
      True if `other` is an `EmbeddingResult` with an equal protobuf
      representation, False otherwise.
    """
    if isinstance(other, EmbeddingResult):
      return self.to_pb2() == other.to_pb2()

    return False