chromium/third_party/mediapipe/src/mediapipe/tasks/cc/text/utils/vocab_convert_utils.cc

#include "mediapipe/tasks/cc/text/utils/vocab_convert_utils.h"

#include <fstream>
#include <istream>
#include <streambuf>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/container/node_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/util/resource_util.h"
#include "nlohmann/json.hpp"  // from @com_github_nlohmann_json
#include "nlohmann/json_fwd.hpp"
#include "sentencepiece/src/builder.h"  // from @com_google_sentencepiece
#include "sentencepiece/src/sentencepiece_model.pb.h"  // from @com_google_sentencepiece

namespace mediapipe {
namespace tasks {
namespace text {
namespace {

using ::nlohmann::json;
using ::sentencepiece::ModelProto;
using ::sentencepiece::NormalizerSpec;
using ::sentencepiece::TrainerSpec;
using ::sentencepiece::normalizer::Builder;

// Loads Hugging Face's `tokenizer_config.json` and `tokenizer.json`. The
// files include the preprocessing and postprocessing steps and the token
// mappings. The loaded jsons are returned as a pair containing
// `tokenizer_config.json` and `tokenizer.json` in the same order.
// Loads Hugging Face's `tokenizer_config.json` and `tokenizer.json` from the
// directory `path`. The files include the preprocessing and postprocessing
// steps and the token mappings. The loaded jsons are returned as a pair
// containing `tokenizer_config.json` and `tokenizer.json` in that order.
//
// Returns InternalError if either file fails to parse, and propagates the
// underlying status if either file cannot be read.
absl::StatusOr<std::pair<json, json>> LoadHFTokenizerConfigs(
    absl::string_view path) {
  std::string contents;
  MP_RETURN_IF_ERROR(mediapipe::file::GetContents(
      absl::StrCat(path, "/tokenizer_config.json"), &contents));
  // Parse with allow_exceptions=false so malformed input yields a discarded
  // value instead of throwing.
  auto config_json =
      json::parse(contents, /*cb=*/nullptr, /*allow_exceptions=*/false);
  if (config_json.is_discarded()) {
    return absl::InternalError("Failed to parse tokenizer_config.json");
  }
  MP_RETURN_IF_ERROR(mediapipe::file::GetContents(
      absl::StrCat(path, "/tokenizer.json"), &contents));
  // Fix: this call previously used the throwing json::parse overload, which
  // made the is_discarded() check below unreachable and leaked an exception
  // on malformed input instead of returning a status.
  auto tokenizer_json =
      json::parse(contents, /*cb=*/nullptr, /*allow_exceptions=*/false);
  if (tokenizer_json.is_discarded()) {
    return absl::InternalError("Failed to parse tokenizer.json");
  }
  return std::make_pair(std::move(config_json), std::move(tokenizer_json));
}

// Turns off sentencepiece's default input normalization so tokens are matched
// verbatim: no whitespace escaping, no extra-whitespace removal, and no dummy
// prefix. Always returns OkStatus.
absl::Status ConfigureNormalizerSpecs(NormalizerSpec* spec) {
  spec->set_escape_whitespaces(false);
  spec->set_remove_extra_whitespaces(false);
  spec->set_add_dummy_prefix(false);
  return absl::OkStatus();
}

// Mirrors ConfigureNormalizerSpecs for the denormalizer: disables whitespace
// escaping, extra-whitespace removal, and the dummy prefix so output text is
// reproduced verbatim. Always returns OkStatus.
absl::Status ConfigureDenormalizerSpecs(NormalizerSpec* spec) {
  spec->set_escape_whitespaces(false);
  spec->set_remove_extra_whitespaces(false);
  spec->set_add_dummy_prefix(false);
  return absl::OkStatus();
}
}  // namespace
// Converts a Hugging Face tokenizer (read from the `hf_tokenizer` directory
// as `tokenizer_config.json` + `tokenizer.json`) into a serialized
// sentencepiece ModelProto written to `output_vocab_path`. Missing parent
// directories of the output path are created.
//
// Returns OkStatus on success; propagates read/parse/write failures.
absl::Status ConvertHfTokenizer(const std::string& hf_tokenizer,
                                const std::string& output_vocab_path) {
  MP_ASSIGN_OR_RETURN(auto configs, LoadHFTokenizerConfigs(hf_tokenizer));

  ModelProto model_proto;

  MP_RETURN_IF_ERROR(
      ConfigureNormalizerSpecs(model_proto.mutable_normalizer_spec()));
  MP_RETURN_IF_ERROR(
      ConfigureDenormalizerSpecs(model_proto.mutable_denormalizer_spec()));

  // The scores assigned here are heuristic based and only capture the
  // ordering of elements within HF configs. This may not be optimal.
  // NOTE(review): assumes the HF vocab ids form a dense range
  // [0, vocab_size) -- an out-of-range id would index past the vector.
  // TODO confirm against the HF tokenizer.json schema.
  std::vector<std::string> normal_vocabs(
      configs.second["model"]["vocab"].size());
  for (const auto& [vocab, id] : configs.second["model"]["vocab"].items()) {
    normal_vocabs[id] = vocab;
  }
  std::string unk_token = configs.first.at("unk_token").get<std::string>();
  for (int i = 0; i < static_cast<int>(normal_vocabs.size()); ++i) {
    auto* sp = model_proto.add_pieces();
    const auto& vocab = normal_vocabs[i];
    // The piece matching `unk_token` becomes the model's UNKNOWN piece.
    sp->set_type(unk_token == vocab ? ModelProto::SentencePiece::UNKNOWN
                                    : ModelProto::SentencePiece::NORMAL);
    sp->set_piece(vocab);
    sp->set_score(static_cast<float>(-i));
  }
  // Added tokens marked "normalized" are appended after the normal vocab as
  // USER_DEFINED pieces, continuing the descending score order.
  const auto& added_tokens = configs.second["added_tokens"];
  for (int i = 0; i < static_cast<int>(added_tokens.size()); ++i) {
    if (added_tokens[i]["normalized"]) {
      auto vocab = added_tokens[i]["content"];
      auto* sp = model_proto.add_pieces();
      sp->set_type(ModelProto::SentencePiece::USER_DEFINED);
      sp->set_piece(vocab);
      // Fix: negate after converting to float. The original expression
      // `-(normal_vocabs.size() + i)` negated an unsigned size_t, which
      // wraps to a huge positive value, giving added tokens an enormous
      // positive score instead of a negative one.
      sp->set_score(-static_cast<float>(normal_vocabs.size() + i));
    }
  }

  auto* trainer_spec = model_proto.mutable_trainer_spec();
  trainer_spec->set_model_type(TrainerSpec::BPE);
  trainer_spec->set_vocab_size(model_proto.pieces_size());

  // Ensure the output directory exists before writing the serialized proto.
  absl::string_view output_dir = ::mediapipe::file::Dirname(output_vocab_path);
  if (!::mediapipe::file::IsDirectory(output_dir).ok()) {
    MP_RETURN_IF_ERROR(::mediapipe::file::RecursivelyCreateDir(output_dir));
  }

  MP_RETURN_IF_ERROR(mediapipe::file::SetContents(
      output_vocab_path, model_proto.SerializeAsString()));

  return absl::OkStatus();
}

}  // namespace text
}  // namespace tasks
}  // namespace mediapipe