chromium/third_party/sentencepiece/src/src/trainer_interface.cc

// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "trainer_interface.h"

#include <algorithm>
#include <cstdlib>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/random/distributions.h"
#include "absl/random/random.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "filesystem.h"
#include "model_factory.h"
#include "model_interface.h"
#include "normalizer.h"
#include "sentencepiece_processor.h"
#include "sentencepiece_trainer.h"
#include "unicode_script.h"
#include "util.h"

namespace sentencepiece {

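// Internal sentinel characters used during training:
//  - kWSChar/kWSStr (U+2581) marks whitespace inside normalized pieces.
//  - kUNKChar/kUNKStr (U+2585) replaces characters that are not part of
//    required_chars_.
//  - kUPPBoundaryChar/kUPPBoundaryStr (TAB) marks positions where meta
//    pieces were replaced during normalization.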
const char32 TrainerInterface::kWSChar = L'\u2581';
const char TrainerInterface::kWSStr[] = "\xe2\x96\x81";

const char32 TrainerInterface::kUNKChar = L'\u2585';
const char TrainerInterface::kUNKStr[] = "\xe2\x96\x85";

const char32 TrainerInterface::kUPPBoundaryChar = L'\u0009';
const char TrainerInterface::kUPPBoundaryStr[] = "\t";

namespace {
util::Status VerifySpec(const TrainerSpec& trainer_spec) {
  CHECK_GT_OR_RETURN(trainer_spec.vocab_size(), 0);

  if (trainer_spec.model_type() == TrainerSpec::UNIGRAM ||
      trainer_spec.model_type() == TrainerSpec::BPE) {
    CHECK_OR_RETURN(!trainer_spec.use_all_vocab())
        << "--use_all_vocab=true is valid for WORD/CHAR model.";
  }

#define CHECK_RANGE(variable, minval, maxval) \
  CHECK_OR_RETURN(variable >= minval && variable <= maxval)

  CHECK_RANGE(trainer_spec.character_coverage(), 0.98, 1.0);
  CHECK_RANGE(trainer_spec.max_sentencepiece_length(), 1, 512);
  CHECK_RANGE(trainer_spec.num_sub_iterations(), 1, 10);
  CHECK_RANGE(trainer_spec.num_threads(), 1, 1024);
  CHECK_RANGE(trainer_spec.self_test_sample_size(), 0, 1000);
  CHECK_RANGE(trainer_spec.shrinking_factor(), 0.5, 0.95);
  CHECK_RANGE(trainer_spec.max_sentence_length(), 10, 1073741824);
#undef CHECK_RANGE

  CHECK_OR_RETURN(trainer_spec.input_sentence_size() <= 0 ||
                  trainer_spec.input_sentence_size() > 100);

  CHECK_OR_RETURN(!trainer_spec.unk_piece().empty());
  CHECK_OR_RETURN(!trainer_spec.bos_piece().empty());
  CHECK_OR_RETURN(!trainer_spec.eos_piece().empty());
  CHECK_OR_RETURN(!trainer_spec.pad_piece().empty());

  if (SentencePieceTrainer::GetPretokenizerForTraining() ||
      !trainer_spec.pretokenization_delimiter().empty()) {
    CHECK_OR_RETURN(trainer_spec.model_type() == TrainerSpec::UNIGRAM ||
                    trainer_spec.model_type() == TrainerSpec::BPE)
        << "PretokenizerForTraining is only supported in UNIGRAM or BPE mode.";
  }

  return util::OkStatus();
}

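// Returns true for ASCII digits (U+0030..U+0039) and fullwidth digits
// (U+FF10..U+FF19).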
bool is_unicode_decimal_number(char32 c) {
  return (c >= 0x30 && c <= 0x39) || (c >= 0xff10 && c <= 0xff19);
}

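// Collects the training sentences. When input_sentence_size is positive, it
// either reservoir-samples that many sentences (shuffle_input_sentence=true)
// or keeps only the first input_sentence_size sentences; Add() returns false
// once no further input is needed.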
class SentenceSelector {
 public:
  using Sampler = random::ReservoirSampler<TrainerInterface::Sentence>;

  static constexpr int64 kTooBigSentencesSize = 1000000;

  SentenceSelector(TrainerInterface::Sentences *sentences,
                   const TrainerSpec &spec)
      : sentences_(sentences), spec_(&spec) {
    if (spec_->input_sentence_size() > 0) {
      if (spec_->shuffle_input_sentence()) {
        constexpr size_t kSeed = 12345678;
        sampler_ = absl::make_unique<Sampler>(
            sentences, spec_->input_sentence_size(), kSeed);
      } else {
        LOG(INFO)
            << "First " << spec_->input_sentence_size()
            << " sentences are selected. Remaining sentences are discarded.";
      }
    }
  }

  void Finish() const {
    if (sentences_->size() > kTooBigSentencesSize) {
      LOG(WARNING) << "Too many sentences are loaded! (" << sentences_->size()
                   << "), which may slow down training.";
      LOG(WARNING) << "Consider using "
                      "--input_sentence_size=<size> and "
                      "--shuffle_input_sentence=true.";
      LOG(WARNING) << "They allow to randomly sample <size> sentences from "
                      "the entire corpus.";
    }
  }

  bool Add(const std::pair<std::string, int64> &sentence) {
    if (spec_->input_sentence_size() == 0) {
      sentences_->emplace_back(sentence);
    } else {
      if (spec_->shuffle_input_sentence()) {
        sampler_->Add(sentence);
      } else {
        sentences_->emplace_back(sentence);
        if (sentences_->size() >=
            static_cast<size_t>(spec_->input_sentence_size())) {
          return false;
        }
      }
    }

    if (total_size() > 0 && total_size() % kTooBigSentencesSize == 0) {
      LOG(INFO) << "Loaded " << total_size() << " lines";
    }

    return true;
  }

  size_t total_size() const {
    return sampler_.get() ? sampler_->total_size() : sentences_->size();
  }

 private:
  TrainerInterface::Sentences *sentences_ = nullptr;
  const TrainerSpec *spec_ = nullptr;
  std::unique_ptr<Sampler> sampler_;
};
}  // namespace

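// Iterates over the lines of the given files in order. done() becomes true
// once every line has been read or a file fails to open; an open failure is
// surfaced through status().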
MultiFileSentenceIterator::MultiFileSentenceIterator(
    const std::vector<std::string>& files)
    : files_(files) {
  Next();
}

bool MultiFileSentenceIterator::done() const {
  return (!read_done_ && file_index_ == files_.size());
}

util::Status MultiFileSentenceIterator::status() const {
  CHECK_OR_RETURN(fp_);
  return fp_->status();
}

void MultiFileSentenceIterator::Next() {
  TryRead();

  if (!read_done_ && file_index_ < files_.size()) {
    const auto& filename = files_[file_index_++];
    fp_ = filesystem::NewReadableFile(filename);
    LOG(INFO) << "Loading corpus: " << filename;
    if (fp_->status() != util::OkStatus()) {
      file_index_ = files_.size();
      read_done_ = false;
      return;
    }

    TryRead();
  }
}

void MultiFileSentenceIterator::TryRead() {
  read_done_ = fp_ && fp_->ReadLine(&value_);
}

TrainerInterface::TrainerInterface(const TrainerSpec& trainer_spec,
                                   const NormalizerSpec& normalizer_spec,
                                   const NormalizerSpec& denormalizer_spec)
    : trainer_spec_(trainer_spec),
      normalizer_spec_(normalizer_spec),
      denormalizer_spec_(denormalizer_spec) {
  status_ = VerifySpec(trainer_spec_);
  if (status_.ok()) status_ = InitMetaPieces();
}

TrainerInterface::~TrainerInterface() {}

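// Checks whether a candidate piece is allowed to enter the vocabulary.
// For example, with split_by_whitespace=true and
// treat_whitespace_as_suffix=false, "▁foo" is accepted while "foo▁" is
// rejected, and a piece mixing two Unicode scripts is rejected when
// split_by_unicode_script is true.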
bool TrainerInterface::IsValidSentencePiece(
    const string_util::UnicodeText &sentencepiece) const {
  // Returns false if the length of piece is invalid.
  if (sentencepiece.empty() ||
      sentencepiece.size() >
          static_cast<size_t>(trainer_spec_.max_sentencepiece_length())) {
    return false;
  }

  constexpr unicode_script::ScriptType kAnyType =
      static_cast<unicode_script::ScriptType>(-1);

  unicode_script::ScriptType prev_script = kAnyType;
  bool all_whitespace_piece =
      std::all_of(sentencepiece.begin(), sentencepiece.end(),
                  [](char32 c) { return c == kWSChar; });

  for (size_t pos = 0; pos < sentencepiece.size(); ++pos) {
    const char32 c = sentencepiece[pos];
    if (c == kUNKChar) {  // UNK must not be included
      return false;
    }
    if (c == 0x0000) {  // NULL is not allowed for Darts (TRIE).
      return false;
    }
    if (c == kUPPBoundaryChar) {
      return false;
    }
    if (c == 0x0020) {
      LOG(WARNING) << "space must not be included in normalized string.";
      return false;
    }
    if (!string_util::IsValidCodepoint(c)) {
      return false;
    }

    if (c == kWSChar) {
      // Whitespace may only appear as a prefix of a piece (or as a suffix
      // when treat_whitespace_as_suffix is true). When split_by_whitespace
      // is false, whitespace may additionally appear in the middle of a
      // piece ("foo_bar"), but never on the opposite end ("foo_bar_" in
      // prefix mode). Pieces consisting solely of whitespace are allowed
      // only when allow_whitespace_only_pieces is true.
      if (!trainer_spec_.allow_whitespace_only_pieces() ||
          !all_whitespace_piece) {
        if (trainer_spec_.treat_whitespace_as_suffix()) {
          if ((trainer_spec_.split_by_whitespace() &&
               pos < sentencepiece.size() - 1) ||
              (!trainer_spec_.split_by_whitespace() &&
               pos < sentencepiece.size() - 1 && pos == 0)) {
            return false;
          }
        } else {
          if ((trainer_spec_.split_by_whitespace() && pos > 0) ||
              (!trainer_spec_.split_by_whitespace() && pos > 0 &&
               pos == sentencepiece.size() - 1)) {
            return false;
          }
        }
      }
    } else {
      auto s = unicode_script::GetScript(c);

      // Merge Hiragana/Katakana into Han.
      if (s == unicode_script::U_Hiragana || s == unicode_script::U_Katakana ||
          c == 0x30FC) {  // long vowel sound (Katakana) should be Katakana
        s = unicode_script::U_Han;
      } else if (s == unicode_script::U_Inherited) {
        s = prev_script;
      }

      if (!trainer_spec_.split_by_number() && is_unicode_decimal_number(c)) {
        s = kAnyType;
      }

      if (trainer_spec_.split_digits() && is_unicode_decimal_number(c)) {
        if (sentencepiece.size() > 1) {
          return false;
        }
      }

      // Do not allow a piece to include multiple Unicode scripts
      // when split_by_unicode_script() is true (default = true).
      if (trainer_spec_.split_by_unicode_script() && s != kAnyType &&
          prev_script != kAnyType && prev_script != s) {
        return false;
      }

      prev_script = s;
    }
  }
  return true;
}

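// Adds zero-mean Gaussian noise with standard deviation
// differential_privacy_noise_level to *to_update (rounded and clamped at
// zero), then zeroes any value below
// differential_privacy_clipping_threshold.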
template <typename T>
void AddDPNoise(const TrainerSpec& trainer_spec,
                absl::SharedBitGen& generator,
                T* to_update) {
  if (trainer_spec.differential_privacy_noise_level() > 0) {
    float random_num = absl::Gaussian<float>(
        generator, 0, trainer_spec.differential_privacy_noise_level());

    *to_update =
        std::round(std::max(0.f, random_num + static_cast<float>(*to_update)));
  }
  // Clip anything below the clipping threshold to 0.
  if (*to_update < trainer_spec.differential_privacy_clipping_threshold()) {
    *to_update = 0;
  }
}

util::Status TrainerInterface::LoadSentences() {
  RETURN_IF_ERROR(status());
  CHECK_OR_RETURN(sentences_.empty());
  CHECK_OR_RETURN(required_chars_.empty());
  CHECK_OR_RETURN(trainer_spec_.input_format().empty() ||
                  trainer_spec_.input_format() == "text" ||
                  trainer_spec_.input_format() == "tsv")
      << "Supported formats are 'text' and 'tsv'.";

  CHECK_OR_RETURN(
      (sentence_iterator_ != nullptr && trainer_spec_.input().empty()) ||
      (sentence_iterator_ == nullptr && !trainer_spec_.input().empty()))
      << "SentenceIterator and trainer_spec.input() must be exclusive.";

  CHECK_OR_RETURN(
      (output_model_proto_ != nullptr &&
       trainer_spec_.model_prefix().empty()) ||
      (output_model_proto_ == nullptr && !trainer_spec_.model_prefix().empty()))
      << "ModelProto and trainer_spec.model_prefix() must be exclusive.";

  const bool is_tsv = trainer_spec_.input_format() == "tsv";

  SentenceSelector selector(&sentences_, trainer_spec_);
  random::ReservoirSampler<std::string> test_sentence_sampler(
      &self_test_samples_, trainer_spec_.self_test_sample_size());

  int too_long_lines = 0;

  std::unique_ptr<SentenceIterator> sentence_iterator_impl;
  if (sentence_iterator_ == nullptr) {
    LOG(INFO) << "SentenceIterator is not specified. Using "
                 "MultiFileSentenceIterator.";
    sentence_iterator_impl =
        absl::make_unique<MultiFileSentenceIterator>(std::vector<std::string>(
            trainer_spec_.input().begin(), trainer_spec_.input().end()));
    sentence_iterator_ = sentence_iterator_impl.get();
  }

  for (; !sentence_iterator_->done(); sentence_iterator_->Next()) {
    int64 freq = 1;
    std::string sentence = sentence_iterator_->value();

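    // In "tsv" mode each line is "word<TAB>freq", e.g. "hello\t42" yields
    // sentence="hello" and freq=42.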
    if (is_tsv) {
      const std::vector<std::string> v = absl::StrSplit(sentence, '\t');
      CHECK_EQ_OR_RETURN(v.size(), 2)
          << "Input format must be: word <tab> freq. " << sentence;
      sentence = v[0];
      CHECK_OR_RETURN(absl::SimpleAtoi(v[1], &freq))
          << "Could not parse the frequency";
      CHECK_GE_OR_RETURN(freq, 1);
    }

    if (sentence.empty()) {
      continue;
    }

    if (static_cast<int>(sentence.size()) >
        trainer_spec_.max_sentence_length()) {
      if (too_long_lines == 0) {
        LOG(WARNING) << "Found too long line (" << sentence.size() << " > "
                     << trainer_spec_.max_sentence_length() << ").";
        LOG(WARNING) << "Too long lines are skipped in the training.";
        LOG(WARNING) << "The maximum length can be changed with "
                        "--max_sentence_length=<size> flag.";
      }
      ++too_long_lines;
      continue;
    }

    if (sentence.find(kUNKStr) != std::string::npos) {
      LOG(INFO) << "Reserved chars are found. Skipped: " << sentence;
      continue;
    }

    test_sentence_sampler.Add(sentence);

    if (!selector.Add(std::make_pair(sentence, freq))) {
      goto END;
    }
  }

  RETURN_IF_ERROR(sentence_iterator_->status());

END:
  // Emits a warning if too many sentences were loaded.
  selector.Finish();

  if (sentences_.size() == selector.total_size()) {
    LOG(INFO) << "Loaded all " << sentences_.size() << " sentences";
  } else {
    LOG(INFO) << "Sampled " << sentences_.size() << " sentences from "
              << selector.total_size() << " sentences.";
  }

  if (too_long_lines > 0)
    LOG(INFO) << "Skipped " << too_long_lines << " too long sentences.";
  if (self_test_samples_.size() > 0)
    LOG(INFO) << "Loaded " << self_test_samples_.size() << " test sentences";

  // Normalizes sentences and removes empty strings.
  {
    const normalizer::Normalizer normalizer(normalizer_spec_, trainer_spec_);
    std::set<absl::string_view> meta_pieces_set;
    for (const auto &it : meta_pieces_) {
      LOG(INFO) << "Adding meta_piece: " << it.second.first;
      meta_pieces_set.insert(it.second.first);
    }
    const normalizer::PrefixMatcher meta_pieces_matcher(meta_pieces_set);

    LOG(INFO) << "Normalizing sentences...";
    CHECK_OR_RETURN(!sentences_.empty());
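    // The pool is destroyed at the end of this scope, which joins all
    // workers before sentences_ is read again below. Worker n normalizes
    // every num_threads()-th sentence starting at index n and replaces any
    // embedded meta piece with the tab boundary marker (kUPPBoundaryStr).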
    {
      auto pool = absl::make_unique<ThreadPool>(trainer_spec_.num_threads());
      pool->StartWorkers();
      for (int n = 0; n < trainer_spec_.num_threads(); ++n) {
        pool->Schedule([&, n]() {
          for (size_t i = n; i < sentences_.size();
               i += trainer_spec_.num_threads()) {
            auto *s = &sentences_[i].first;
            *s = meta_pieces_matcher.GlobalReplace(normalizer.Normalize(*s),
                                                   kUPPBoundaryStr);
          }
        });
      }
    }

    for (size_t i = 0; i < sentences_.size();) {
      auto *s = &sentences_[i].first;
      CHECK_OR_RETURN(s->find(" ") == std::string::npos)
          << "Normalized string must not include spaces";
      if (s->empty()) {
        // Swap-and-pop. Do not advance i so that the element swapped in from
        // the back is also checked.
        std::swap(sentences_[i], sentences_[sentences_.size() - 1]);
        sentences_.resize(sentences_.size() - 1);
      } else {
        ++i;
      }
    }
  }

  // If DP is required, add the noise/clip the input.
  if (trainer_spec_.enable_differential_privacy()) {
    if (trainer_spec_.input_format() != "tsv") {
      LOG(ERROR)
          << "The DP version will not work correctly with text input format.";
    }
    if (trainer_spec_.differential_privacy_noise_level() <= 0) {
      LOG(WARNING) << "Private version with <=0 noise level will give "
                      "infinity epsilon guarantees.";
    }
    if (trainer_spec_.differential_privacy_clipping_threshold() <= 0) {
      LOG(WARNING) << "Private version with <=0 clipping threshold will give "
                      "infinity epsilon guarantees.";
    }

    // Add noise to the frequencies of all sentences via a thread pool.

    // Capping the worker count at sentences_.size() - 1 is mainly for tests
    // with a small number of sentences.
    const auto num_workers =
        std::min<uint64>(trainer_spec_.num_threads(), sentences_.size() - 1);

    {
      auto pool = absl::make_unique<ThreadPool>(num_workers);
      pool->StartWorkers();
      for (int n = 0; n < num_workers; ++n) {
        pool->Schedule([&, n]() {
          // One random generator per thread.
          absl::SharedBitGen generator;
          for (size_t i = n; i < sentences_.size(); i += num_workers) {
            AddDPNoise<int64>(trainer_spec_, generator,
                              &(sentences_[i].second));
          }
        });
      }
    }

    // Remove zero freq elements.
    const auto before_size = sentences_.size();
    auto it = std::remove_if(sentences_.begin(), sentences_.end(),
                             [](const Sentence& s) { return s.second <= 0; });
    const auto new_size = std::distance(sentences_.begin(), it);
    const int num_erased = before_size - new_size;
    sentences_.erase(it, sentences_.end());

    LOG(INFO) << "DP noise resulted in " << 1.0 * num_erased / before_size
              << " fraction of sentences removed.";
  }

  // Count character frequencies.
  int64 all_chars_count = 0;
  // A map from a character to {is_required_char, character count}.
  absl::flat_hash_map<char32, std::pair<bool, int64>> chars_count;
  for (const char32 c :
       string_util::UTF8ToUnicodeText(trainer_spec_.required_chars())) {
    CHECK_OR_RETURN(string_util::IsValidCodepoint(c));
    if (c == 0x0000) {
      LOG(INFO) << "Found null character. The required_chars field must be "
                   "encoded in utf-8.";
      continue;
    }
    chars_count[c].first = true;  // is_required_character.
  }
  for (const auto &w : sentences_) {
    for (const char32 c : string_util::UTF8ToUnicodeText(w.first)) {
      if (!string_util::IsValidCodepoint(c)) continue;
      if (c == 0x0000) {
        LOG(INFO)
            << "Found null character. The corpus must be encoded in utf-8.";
        continue;
      }
      if (c == 0x0020) {
        // UTF8ToUnicodeText returns a white space if the text
        // contains an interchange-invalid character.
        CHECK_OR_RETURN(w.first.find(" ") == std::string::npos)
            << "space must not be included in normalized string.";
        continue;
      }
      chars_count[c].second += w.second;
      all_chars_count += w.second;
    }
  }
  LOG(INFO) << "all chars count=" << all_chars_count;

  // Determines required_chars which must be included in the vocabulary.
  int64 accumulated_chars_count = 0;
  // Sorted() sorts the chars_count values in descending order of pair<>,
  // i.e., required characters come first, followed by the remaining
  // characters in order of decreasing frequency.
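  // For example, with --character_coverage=0.9995 the loop stops once the
  // selected characters account for 99.95% of all character occurrences;
  // the characters left out are replaced with kUNKChar further below.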
  for (const auto &w : Sorted(chars_count)) {
    const float coverage = 1.0 * accumulated_chars_count / all_chars_count;
    if (!trainer_spec_.use_all_vocab() &&
        coverage >= trainer_spec_.character_coverage()) {
      LOG(INFO) << "Done: " << 100.0 * coverage << "% characters are covered.";
      break;
    }
    accumulated_chars_count += w.second.second;
    CHECK_NE_OR_RETURN(w.first, 0x0020)
        << "space must not be included in normalized string.";
    if (w.first == kUPPBoundaryChar) continue;  // Tab is not included.
    required_chars_.emplace(w.first, w.second.second);
  }

  LOG(INFO) << "Alphabet size=" << required_chars_.size();
  LOG(INFO) << "Final character coverage="
            << 1.0 * accumulated_chars_count / all_chars_count;

  CHECK_OR_RETURN(!port::ContainsKey(required_chars_, kUNKChar));

  // Replaces rare characters (characters not included in required_chars_)
  // with kUNKChar.
  for (auto &w : sentences_) {
    string_util::UnicodeText uw2;
    for (const char32 c : string_util::UTF8ToUnicodeText(w.first)) {
      if (port::ContainsKey(required_chars_, c)) {
        uw2.push_back(c);
      } else {
        uw2.push_back(kUNKChar);
      }
    }
    w.first = string_util::UnicodeTextToUTF8(uw2);
  }

  // The required characters plus the meta pieces must fit in the vocabulary.
  if (trainer_spec_.model_type() != TrainerSpec::WORD &&
      trainer_spec_.model_type() != TrainerSpec::CHAR) {
    CHECK_LE_OR_RETURN(
        static_cast<int>(required_chars_.size() + meta_pieces_.size()),
        trainer_spec_.vocab_size())
        << "Vocabulary size is smaller than required_chars. "
        << trainer_spec_.vocab_size() << " vs "
        << required_chars_.size() + meta_pieces_.size() << ". "
        << "Increase vocab_size or decrease character_coverage with "
        << "--character_coverage option.";
  }

  LOG(INFO) << "Done! preprocessed " << sentences_.size() << " sentences.";

  return util::OkStatus();
}

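// Replaces sentences_ with a (token, frequency) list built by splitting each
// sentence on whitespace and accumulating the frequencies of equal tokens.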
void TrainerInterface::SplitSentencesByWhitespace() {
  LOG(INFO) << "Tokenizing input sentences with whitespace: "
            << sentences_.size();
  absl::flat_hash_map<std::string, int64> tokens;
  for (const auto &s : sentences_) {
    for (const auto& w :
         SplitIntoWords(s.first, trainer_spec_.treat_whitespace_as_suffix(),
                        trainer_spec_.allow_whitespace_only_pieces())) {
      tokens[std::string(w)] += s.second;
    }
  }
  sentences_ = Sorted(tokens);
  LOG(INFO) << "Done! " << sentences_.size();
}

util::Status TrainerInterface::Serialize(ModelProto* model_proto) const {
  RETURN_IF_ERROR(status());

  // Duplicated sentencepieces are not allowed.
  std::set<std::string> dup;

  model_proto->Clear();

#define CHECK_PIECE(piece)                                  \
  CHECK_OR_RETURN(string_util::IsStructurallyValid(piece)); \
  CHECK_OR_RETURN(!piece.empty());                          \
  CHECK_OR_RETURN(dup.insert(piece).second) << piece << " is already defined";

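  // Pieces are emitted in id order: ids reserved for meta pieces receive
  // their meta piece with score 0, and all remaining ids are filled with the
  // learned pieces from final_pieces_.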
  size_t fid = 0;
  for (int id = 0; id < trainer_spec_.vocab_size(); ++id) {
    const auto it = meta_pieces_.find(id);
    if (it != meta_pieces_.end()) {
      auto *sp = model_proto->add_pieces();
      sp->set_piece(it->second.first);
      sp->set_type(it->second.second);
      sp->set_score(0.0);
      CHECK_EQ_OR_RETURN(model_proto->pieces_size() - 1, it->first);
      CHECK_NE_OR_RETURN(ModelProto::SentencePiece::NORMAL, sp->type());
      CHECK_PIECE(sp->piece());
    } else if (fid < final_pieces_.size()) {
      const auto &w = final_pieces_[fid++];
      auto *sp = model_proto->add_pieces();
      sp->set_piece(w.first);
      sp->set_score(w.second);
      CHECK_PIECE(sp->piece());
    }
  }

  CHECK_EQ_OR_RETURN(fid, final_pieces_.size());

  *(model_proto->mutable_trainer_spec()) = trainer_spec_;
  *(model_proto->mutable_normalizer_spec()) = normalizer_spec_;

  if (!denormalizer_spec_.normalization_rule_tsv().empty()) {
    *(model_proto->mutable_denormalizer_spec()) = denormalizer_spec_;
  }

  if (!trainer_spec_.hard_vocab_limit() ||
      trainer_spec_.model_type() == TrainerSpec::CHAR) {
    CHECK_GE_OR_RETURN(trainer_spec_.vocab_size(), model_proto->pieces_size());
    CHECK_GE_OR_RETURN(trainer_spec_.vocab_size(),
                       static_cast<int32>(dup.size()));
    model_proto->mutable_trainer_spec()->set_vocab_size(
        model_proto->pieces_size());
  } else {
    CHECK_EQ_OR_RETURN(trainer_spec_.vocab_size(), model_proto->pieces_size())
        << absl::StrFormat(
               "Vocabulary size too high (%d). Please set it to a value <= %d.",
               trainer_spec_.vocab_size(), model_proto->pieces_size());
    CHECK_EQ_OR_RETURN(trainer_spec_.vocab_size(),
                       static_cast<int32>(dup.size()));
  }

  // Saves self-testing data.
  if (!self_test_samples_.empty()) {
    SentencePieceProcessor sp;
    RETURN_IF_ERROR(sp.Load(*model_proto));
    for (const auto &input : self_test_samples_) {
      std::vector<std::string> sps;
      RETURN_IF_ERROR(sp.Encode(input, &sps));
      auto* sample = model_proto->mutable_self_test_data()->add_samples();
      sample->set_input(input);
      sample->set_expected(absl::StrJoin(sps, " "));
    }
  }

  return util::OkStatus();
}

util::Status TrainerInterface::SaveModel(absl::string_view filename) const {
  LOG(INFO) << "Saving model: " << filename;
  ModelProto model_proto;

  RETURN_IF_ERROR(Serialize(&model_proto));

  auto output = filesystem::NewWritableFile(filename, true);
  RETURN_IF_ERROR(output->status());
  output->Write(model_proto.SerializeAsString());
  return util::OkStatus();
}

util::Status TrainerInterface::SaveVocab(absl::string_view filename) const {
  LOG(INFO) << "Saving vocabs: " << filename;
  ModelProto model_proto;
  RETURN_IF_ERROR(Serialize(&model_proto));
  auto output = filesystem::NewWritableFile(filename);
  RETURN_IF_ERROR(output->status());

  for (const auto& piece : model_proto.pieces()) {
    if (piece.piece().find_first_of(" \t\r\n") != std::string::npos) {
      LOG(WARNING) << "The piece [" << piece.piece()
                   << "] contains escaped characters that break the format of "
                   << filename;
    }
  }

  if (trainer_spec_.vocabulary_output_piece_score()) {
    for (const auto &piece : model_proto.pieces()) {
      std::ostringstream os;
      os << piece.piece() << "\t" << piece.score();
      CHECK_OR_RETURN(output->WriteLine(os.str()));
    }
  } else {
    for (const auto &piece : model_proto.pieces()) {
      CHECK_OR_RETURN(output->WriteLine(piece.piece()));
    }
  }

  return util::OkStatus();
}

util::Status TrainerInterface::Save() const {
  if (output_model_proto_) {
    RETURN_IF_ERROR(Serialize(output_model_proto_));
  } else {
    RETURN_IF_ERROR(SaveModel(trainer_spec_.model_prefix() + ".model"));
    RETURN_IF_ERROR(SaveVocab(trainer_spec_.model_prefix() + ".vocab"));
  }
  return util::OkStatus();
}

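// Reserves the ids of the unk/bos/eos/pad pieces as specified in the trainer
// spec, then assigns control symbols, user-defined symbols, and (when
// byte_fallback is enabled) the 256 byte pieces to the lowest unused ids.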
util::Status TrainerInterface::InitMetaPieces() {
  CHECK_OR_RETURN(meta_pieces_.empty());
  bool has_unk = false;

  auto insert_id = [&has_unk, this](int id, const std::string &w) -> bool {
    if (id < 0) return true;
    if (id >= trainer_spec_.vocab_size() ||
        meta_pieces_.find(id) != meta_pieces_.end() ||
        (has_unk && w == trainer_spec_.unk_piece()))
      return false;
    if (w == trainer_spec_.unk_piece()) has_unk = true;
    meta_pieces_[id] = std::make_pair(
        w, w == trainer_spec_.unk_piece() ? ModelProto::SentencePiece::UNKNOWN
                                          : ModelProto::SentencePiece::CONTROL);
    return true;
  };

  CHECK_OR_RETURN(insert_id(trainer_spec_.unk_id(), trainer_spec_.unk_piece()));
  CHECK_OR_RETURN(insert_id(trainer_spec_.bos_id(), trainer_spec_.bos_piece()));
  CHECK_OR_RETURN(insert_id(trainer_spec_.eos_id(), trainer_spec_.eos_piece()));
  CHECK_OR_RETURN(insert_id(trainer_spec_.pad_id(), trainer_spec_.pad_piece()));

  CHECK_OR_RETURN(has_unk) << trainer_spec_.unk_piece() << " must be defined.";

  std::set<std::string> dup;

  int id = 0;
  auto insert_meta_symbol =
      [&id, &dup, this](const std::string& w,
                        ModelProto::SentencePiece::Type type) -> util::Status {
    if (!dup.insert(w).second) {
      return util::InternalError(absl::StrCat(
          w, " is already defined. duplicated symbols are not allowed."));
    }

    if (w == trainer_spec_.unk_piece()) {
      return util::InternalError(
          absl::StrCat(trainer_spec_.unk_piece(),
                       " must not be defined with --control_symbols and "
                       "--user_defined_symbols."));
    }

    if (w == trainer_spec_.bos_piece() && trainer_spec_.bos_id() >= 0) {
      meta_pieces_[trainer_spec_.bos_id()].second = type;
    } else if (w == trainer_spec_.eos_piece() && trainer_spec_.eos_id() >= 0) {
      meta_pieces_[trainer_spec_.eos_id()].second = type;
    } else if (w == trainer_spec_.pad_piece() && trainer_spec_.pad_id() >= 0) {
      meta_pieces_[trainer_spec_.pad_id()].second = type;
    } else {
      while (meta_pieces_.find(id) != meta_pieces_.end())
        ++id;
      meta_pieces_[id] = std::make_pair(w, type);
    }

    return util::OkStatus();
  };

  for (const auto &w : trainer_spec_.control_symbols()) {
    RETURN_IF_ERROR(insert_meta_symbol(w, ModelProto::SentencePiece::CONTROL));
  }

  for (const auto &w : trainer_spec_.user_defined_symbols()) {
    RETURN_IF_ERROR(
        insert_meta_symbol(w, ModelProto::SentencePiece::USER_DEFINED));
  }

  if (trainer_spec_.byte_fallback()) {
    for (int i = 0; i < 256; ++i) {
      RETURN_IF_ERROR(
          insert_meta_symbol(ByteToPiece(i), ModelProto::SentencePiece::BYTE));
    }
  }

  return util::OkStatus();
}

}  // namespace sentencepiece