chromium/third_party/sentencepiece/src/src/spm_train_main.cc

// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <map>

#include "absl/flags/flag.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "filesystem.h"
#include "init.h"
#include "sentencepiece_model.pb.h"
#include "sentencepiece_trainer.h"
#include "util.h"

using sentencepiece::NormalizerSpec;
using sentencepiece::TrainerSpec;

namespace {
static sentencepiece::TrainerSpec kDefaultTrainerSpec;
static sentencepiece::NormalizerSpec kDefaultNormalizerSpec;
}  // namespace

// Input corpus and overall model shape. Where a flag's default is not a
// literal, it is read from kDefaultTrainerSpec.

// Each entry is a corpus file that is handed to TrainerSpec::input.
ABSL_FLAG(std::string, input, "", "comma separated list of input files");
ABSL_FLAG(std::string,
          input_format,
          kDefaultTrainerSpec.input_format(),
          "Input format. Supported format is `text` or `tsv`.");
ABSL_FLAG(std::string, model_prefix, "", "output model prefix");
ABSL_FLAG(std::string,
          model_type,
          "unigram",
          "model algorithm: unigram, bpe, word or char");
ABSL_FLAG(int32,
          vocab_size,
          kDefaultTrainerSpec.vocab_size(),
          "vocabulary size");
ABSL_FLAG(std::string,
          accept_language,
          "",
          "comma-separated list of languages this model can accept");
ABSL_FLAG(int32,
          self_test_sample_size,
          kDefaultTrainerSpec.self_test_sample_size(),
          "the size of self test samples");
ABSL_FLAG(double,
          character_coverage,
          kDefaultTrainerSpec.character_coverage(),
          "character coverage to determine the minimum symbols");
ABSL_FLAG(std::uint64_t,
          input_sentence_size,
          kDefaultTrainerSpec.input_sentence_size(),
          "maximum size of sentences the trainer loads");
ABSL_FLAG(bool,
          shuffle_input_sentence,
          kDefaultTrainerSpec.shuffle_input_sentence(),
          "Randomly sample input sentences in advance. Valid when "
          "--input_sentence_size > 0");
// Seed pieces and EM training parameters (defaults from kDefaultTrainerSpec).
ABSL_FLAG(int32, seed_sentencepiece_size,
          kDefaultTrainerSpec.seed_sentencepiece_size(),
          "the size of seed sentencepieces");
ABSL_FLAG(double, shrinking_factor, kDefaultTrainerSpec.shrinking_factor(),
          "Keeps top shrinking_factor pieces with respect to the loss");
ABSL_FLAG(int32, num_threads, kDefaultTrainerSpec.num_threads(),
          "number of threads for training");
ABSL_FLAG(int32, num_sub_iterations, kDefaultTrainerSpec.num_sub_iterations(),
          "number of EM sub-iterations");

// Length limits.
ABSL_FLAG(int32, max_sentencepiece_length,
          kDefaultTrainerSpec.max_sentencepiece_length(),
          "maximum length of sentence piece");
ABSL_FLAG(int32, max_sentence_length,
          kDefaultTrainerSpec.max_sentence_length(),
          "maximum length of sentence in byte");

// How candidate pieces are split.
ABSL_FLAG(bool, split_by_unicode_script,
          kDefaultTrainerSpec.split_by_unicode_script(),
          "use Unicode script to split sentence pieces");
ABSL_FLAG(bool, split_by_number, kDefaultTrainerSpec.split_by_number(),
          "split tokens by numbers (0-9)");
ABSL_FLAG(bool, split_by_whitespace, kDefaultTrainerSpec.split_by_whitespace(),
          "use a white space to split sentence pieces");
ABSL_FLAG(bool, split_digits, kDefaultTrainerSpec.split_digits(),
          "split all digits (0-9) into separate pieces");
ABSL_FLAG(std::string, pretokenization_delimiter,
          kDefaultTrainerSpec.pretokenization_delimiter(),
          "specifies the delimiter of pre-tokenization");

// Whitespace handling.
ABSL_FLAG(bool, treat_whitespace_as_suffix,
          kDefaultTrainerSpec.treat_whitespace_as_suffix(),
          "treat whitespace marker as suffix instead of prefix.");
ABSL_FLAG(bool, allow_whitespace_only_pieces,
          kDefaultTrainerSpec.allow_whitespace_only_pieces(),
          "allow pieces that only contain (consecutive) whitespace tokens");
// Special symbols. Each list can come from the flag itself (comma separated)
// and/or from a companion *_file flag read line by line in main().
ABSL_FLAG(std::string, control_symbols, "",
          "comma separated list of control symbols");
ABSL_FLAG(std::string, control_symbols_file, "",
          "load control_symbols from file.");
ABSL_FLAG(std::string, user_defined_symbols, "",
          "comma separated list of user defined symbols");
ABSL_FLAG(std::string, user_defined_symbols_file, "",
          "load user_defined_symbols from file.");

// Characters that are always kept in the character set.
ABSL_FLAG(std::string, required_chars, "",
          "UTF8 characters in this flag are always used in the character "
          "set regardless of --character_coverage");
ABSL_FLAG(std::string, required_chars_file, "",
          "load required_chars from file.");

// Vocabulary output options (defaults from kDefaultTrainerSpec).
ABSL_FLAG(bool, byte_fallback, kDefaultTrainerSpec.byte_fallback(),
          "decompose unknown pieces into UTF-8 byte pieces");
ABSL_FLAG(bool, vocabulary_output_piece_score,
          kDefaultTrainerSpec.vocabulary_output_piece_score(),
          "Define score in vocab file");
// Normalizer configuration. These flags populate the NormalizerSpec in main();
// the boolean defaults come from kDefaultNormalizerSpec.

// The help text lists every built-in rule name, including the default
// (`nmt_nfkc`), which the previous text omitted.
ABSL_FLAG(std::string,
          normalization_rule_name,
          "nmt_nfkc",
          "Normalization rule name. "
          "Choose from nmt_nfkc, nfkc, nmt_nfkc_cf, nfkc_cf or identity");
ABSL_FLAG(std::string,
          normalization_rule_tsv,
          "",
          "Normalization rule TSV file.");
// Only consulted by main() when non-empty; it then configures a separate
// denormalizer spec.
ABSL_FLAG(std::string,
          denormalization_rule_tsv,
          "",
          "Denormalization rule TSV file.");
ABSL_FLAG(bool,
          add_dummy_prefix,
          kDefaultNormalizerSpec.add_dummy_prefix(),
          "Add dummy whitespace at the beginning of text");
ABSL_FLAG(bool,
          remove_extra_whitespaces,
          kDefaultNormalizerSpec.remove_extra_whitespaces(),
          "Removes leading, trailing, and "
          "duplicate internal whitespace");
// Vocabulary size policy (defaults from kDefaultTrainerSpec).
ABSL_FLAG(bool, hard_vocab_limit, kDefaultTrainerSpec.hard_vocab_limit(),
          "If set to false, --vocab_size is considered as a soft limit.");
ABSL_FLAG(bool, use_all_vocab, kDefaultTrainerSpec.use_all_vocab(),
          "If set to true, use all tokens as vocab. "
          "Valid for word/char models.");

// Reserved ids for the special pieces.
ABSL_FLAG(int32, unk_id, kDefaultTrainerSpec.unk_id(),
          "Override UNK (<unk>) id.");
ABSL_FLAG(int32, bos_id, kDefaultTrainerSpec.bos_id(),
          "Override BOS (<s>) id. Set -1 to disable BOS.");
ABSL_FLAG(int32, eos_id, kDefaultTrainerSpec.eos_id(),
          "Override EOS (</s>) id. Set -1 to disable EOS.");
ABSL_FLAG(int32, pad_id, kDefaultTrainerSpec.pad_id(),
          "Override PAD (<pad>) id. Set -1 to disable PAD.");

// Surface strings of the special pieces.
ABSL_FLAG(std::string, unk_piece, kDefaultTrainerSpec.unk_piece(),
          "Override UNK (<unk>) piece.");
ABSL_FLAG(std::string, bos_piece, kDefaultTrainerSpec.bos_piece(),
          "Override BOS (<s>) piece.");
ABSL_FLAG(std::string, eos_piece, kDefaultTrainerSpec.eos_piece(),
          "Override EOS (</s>) piece.");
ABSL_FLAG(std::string, pad_piece, kDefaultTrainerSpec.pad_piece(),
          "Override PAD (<pad>) piece.");
ABSL_FLAG(std::string, unk_surface, kDefaultTrainerSpec.unk_surface(),
          "Dummy surface string for <unk>. In decoding <unk> is decoded to "
          "`unk_surface`.");

ABSL_FLAG(bool, train_extremely_large_corpus,
          kDefaultTrainerSpec.train_extremely_large_corpus(),
          "Increase bit depth for unigram tokenization.");

// The default (all bits set) is a sentinel: main() only calls
// SetRandomGeneratorSeed when the flag differs from it.
ABSL_FLAG(uint32, random_seed, static_cast<uint32>(-1),
          "Seed value for random generator.");

// Differential-privacy (DP) options, copied verbatim into the TrainerSpec.
ABSL_FLAG(bool, enable_differential_privacy, false,
          "Whether to add DP while training. Currently supported only by "
          "UNIGRAM model.");

ABSL_FLAG(float, differential_privacy_noise_level, 0.0f,
          "Amount of noise to add for DP");
ABSL_FLAG(std::uint64_t, differential_privacy_clipping_threshold, 0,
          "Threshold for clipping the counts for DP");

int main(int argc, char *argv[]) {
  sentencepiece::ScopedResourceDestructor cleaner;
  sentencepiece::ParseCommandLineFlags(argv[0], &argc, &argv, true);

  sentencepiece::TrainerSpec trainer_spec;
  sentencepiece::NormalizerSpec normalizer_spec;
  NormalizerSpec denormalizer_spec;

  CHECK(!absl::GetFlag(FLAGS_input).empty());
  CHECK(!absl::GetFlag(FLAGS_model_prefix).empty());

  if (absl::GetFlag(FLAGS_random_seed) != -1) {
    sentencepiece::SetRandomGeneratorSeed(absl::GetFlag(FLAGS_random_seed));
  }

  auto load_lines = [](absl::string_view filename) {
    std::vector<std::string> lines;
    auto input = sentencepiece::filesystem::NewReadableFile(filename);
    CHECK_OK(input->status());
    std::string line;
    while (input->ReadLine(&line)) {
      lines.emplace_back(line);
    }
    return lines;
  };

// Populates the value from flags to spec.
#define SetTrainerSpecFromFlag(name) \
  trainer_spec.set_##name(absl::GetFlag(FLAGS_##name));

#define SetNormalizerSpecFromFlag(name) \
  normalizer_spec.set_##name(absl::GetFlag(FLAGS_##name));

#define SetTrainerSpecFromFile(name)                                   \
  if (!absl::GetFlag(FLAGS_##name##_file).empty()) {                   \
    const auto lines = load_lines(absl::GetFlag(FLAGS_##name##_file)); \
    trainer_spec.set_##name(absl::StrJoin(lines, ""));                 \
  }

#define SetRepeatedTrainerSpecFromFlag(name)                                \
  if (!absl::GetFlag(FLAGS_##name).empty()) {                               \
    for (const auto& v :                                                    \
         sentencepiece::util::StrSplitAsCSV(absl::GetFlag(FLAGS_##name))) { \
      trainer_spec.add_##name(v);                                           \
    }                                                                       \
  }

#define SetRepeatedTrainerSpecFromFile(name)                               \
  if (!absl::GetFlag(FLAGS_##name##_file).empty()) {                       \
    for (const auto& v : load_lines(absl::GetFlag(FLAGS_##name##_file))) { \
      trainer_spec.add_##name(v);                                          \
    }                                                                      \
  }

  SetRepeatedTrainerSpecFromFlag(input);

  SetTrainerSpecFromFlag(input_format);
  SetTrainerSpecFromFlag(model_prefix);
  SetTrainerSpecFromFlag(vocab_size);
  SetTrainerSpecFromFlag(self_test_sample_size);
  SetTrainerSpecFromFlag(character_coverage);
  SetTrainerSpecFromFlag(input_sentence_size);
  SetTrainerSpecFromFlag(shuffle_input_sentence);
  SetTrainerSpecFromFlag(seed_sentencepiece_size);
  SetTrainerSpecFromFlag(shrinking_factor);
  SetTrainerSpecFromFlag(num_threads);
  SetTrainerSpecFromFlag(num_sub_iterations);
  SetTrainerSpecFromFlag(max_sentencepiece_length);
  SetTrainerSpecFromFlag(max_sentence_length);
  SetTrainerSpecFromFlag(split_by_unicode_script);
  SetTrainerSpecFromFlag(split_by_whitespace);
  SetTrainerSpecFromFlag(split_by_number);
  SetTrainerSpecFromFlag(split_digits);
  SetTrainerSpecFromFlag(pretokenization_delimiter);
  SetTrainerSpecFromFlag(byte_fallback);
  SetTrainerSpecFromFlag(treat_whitespace_as_suffix);
  SetTrainerSpecFromFlag(allow_whitespace_only_pieces);
  SetTrainerSpecFromFlag(hard_vocab_limit);
  SetTrainerSpecFromFlag(use_all_vocab);
  SetTrainerSpecFromFlag(unk_id);
  SetTrainerSpecFromFlag(bos_id);
  SetTrainerSpecFromFlag(eos_id);
  SetTrainerSpecFromFlag(pad_id);
  SetTrainerSpecFromFlag(unk_piece);
  SetTrainerSpecFromFlag(bos_piece);
  SetTrainerSpecFromFlag(eos_piece);
  SetTrainerSpecFromFlag(pad_piece);
  SetTrainerSpecFromFlag(unk_surface);
  SetTrainerSpecFromFlag(required_chars);
  SetTrainerSpecFromFile(required_chars);
  SetTrainerSpecFromFlag(vocabulary_output_piece_score);
  SetRepeatedTrainerSpecFromFlag(accept_language);
  SetRepeatedTrainerSpecFromFlag(control_symbols);
  SetRepeatedTrainerSpecFromFlag(user_defined_symbols);
  SetTrainerSpecFromFlag(train_extremely_large_corpus);
  // DP related.
  SetTrainerSpecFromFlag(enable_differential_privacy);
  SetTrainerSpecFromFlag(differential_privacy_noise_level);
  SetTrainerSpecFromFlag(differential_privacy_clipping_threshold);

  SetRepeatedTrainerSpecFromFile(control_symbols);
  SetRepeatedTrainerSpecFromFile(user_defined_symbols);

  normalizer_spec.set_name(absl::GetFlag(FLAGS_normalization_rule_name));
  SetNormalizerSpecFromFlag(normalization_rule_tsv);
  SetNormalizerSpecFromFlag(add_dummy_prefix);
  SetNormalizerSpecFromFlag(remove_extra_whitespaces);

  if (!absl::GetFlag(FLAGS_denormalization_rule_tsv).empty()) {
    denormalizer_spec.set_normalization_rule_tsv(
        absl::GetFlag(FLAGS_denormalization_rule_tsv));
    denormalizer_spec.set_add_dummy_prefix(false);
    denormalizer_spec.set_remove_extra_whitespaces(false);
    denormalizer_spec.set_escape_whitespaces(false);
  }

  CHECK_OK(sentencepiece::SentencePieceTrainer::PopulateModelTypeFromString(
      absl::GetFlag(FLAGS_model_type), &trainer_spec));

  CHECK_OK(sentencepiece::SentencePieceTrainer::Train(
      trainer_spec, normalizer_spec, denormalizer_spec));

  return 0;
}