chromium/third_party/sentencepiece/src/src/sentencepiece_trainer.h

// Copyright 2018 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.!

#ifndef SENTENCEPIECE_TRAINER_H_
#define SENTENCEPIECE_TRAINER_H_

#include <string>
#include <unordered_map>

#include "sentencepiece_processor.h"

namespace sentencepiece {

// Forward declarations. This header only names these types in function
// declarations, so their full definitions are not required here.
class TrainerSpec;
class NormalizerSpec;

// Interface type accepted by SentencePieceTrainer::SetPretokenizerForTraining
// and returned by SentencePieceTrainer::GetPretokenizerForTraining.
namespace pretokenizer {
class PretokenizerForTrainingInterface;
}  // namespace pretokenizer

// Abstract source of training sentences, consumed one sentence at a time.
//
// Typical traversal:
//
//   for (; !it.done(); it.Next()) {
//     const std::string& sentence = it.value();
//   }
//   RETURN_IF_ERROR(it.status());
//
class SentenceIterator {
 public:
  virtual ~SentenceIterator() = default;

  // True once iteration has finished. This includes the error case, so call
  // status() afterwards to learn whether every sentence was loaded
  // successfully.
  virtual bool done() const = 0;

  // Moves on to the next sentence.
  virtual void Next() = 0;

  // The sentence at the current position.
  virtual const std::string& value() const = 0;

  // Outcome of the iteration; see done().
  virtual util::Status status() const = 0;
};

// Static-only entry points for training a SentencePiece model and for
// assembling the specs that drive training. The constructor and destructor
// are private; everything public is a static member function.
class SentencePieceTrainer {
 public:
  // Trains a SentencePiece model from `trainer_spec`, using the default
  // `normalizer_spec`.
  // If `sentence_iterator` is given, sentences are loaded from it.
  // NOTE(review): `serialized_model_proto` is presumably an optional output
  // for the serialized model -- confirm against the implementation. The same
  // applies to every Train() overload below.
  static util::Status Train(
      const TrainerSpec& trainer_spec,
      SentenceIterator* sentence_iterator = nullptr,
      std::string* serialized_model_proto = nullptr);

  // Trains a SentencePiece model from `trainer_spec` and `normalizer_spec`.
  // If `sentence_iterator` is given, sentences are loaded from it.
  static util::Status Train(
      const TrainerSpec& trainer_spec,
      const NormalizerSpec& normalizer_spec,
      SentenceIterator* sentence_iterator = nullptr,
      std::string* serialized_model_proto = nullptr);

  // Trains a SentencePiece model from `trainer_spec`, `normalizer_spec` and
  // `denormalizer_spec`.
  // If `sentence_iterator` is given, sentences are loaded from it.
  static util::Status Train(
      const TrainerSpec& trainer_spec,
      const NormalizerSpec& normalizer_spec,
      const NormalizerSpec& denormalizer_spec,
      SentenceIterator* sentence_iterator = nullptr,
      std::string* serialized_model_proto = nullptr);

  // Trains a SentencePiece model from a command-line style string in `args`,
  // for example:
  //   '--input=data --model_prefix=m --vocab_size=8192 --model_type=unigram'
  // If `sentence_iterator` is given, sentences are loaded from it.
  static util::Status Train(
      absl::string_view args,
      SentenceIterator* sentence_iterator = nullptr,
      std::string* serialized_model_proto = nullptr);

  // Trains a SentencePiece model from the key/value map in `kwargs`, e.g.
  //   {{"input", "data"}, {"model_prefix", "m"}, {"vocab_size", "8192"}, ...}
  static util::Status Train(
      const std::unordered_map<std::string, std::string>& kwargs,
      SentenceIterator* sentence_iterator = nullptr,
      std::string* serialized_model_proto = nullptr);

  // Convenience helper that builds a normalizer spec from the name of a
  // pre-compiled normalization rule. Do not use this in production: it
  // crashes when `name` is invalid. Useful for unit tests.
  static NormalizerSpec GetNormalizerSpec(absl::string_view name);

  // Fills in the necessary fields (precompiled_charmap) of `normalizer_spec`
  // from `NormalizerSpec::name` or `NormalizerSpec::normalization_rule_tsv`.
  // NOTE(review): `is_denormalizer` presumably marks the spec as a
  // denormalizer spec -- confirm against the implementation.
  static util::Status PopulateNormalizerSpec(
      NormalizerSpec* normalizer_spec,
      bool is_denormalizer = false);

  // Overrides `trainer_spec`, `normalizer_spec` and `denormalizer_spec` with
  // the entries of the std::unordered_map `kwargs`.
  static util::Status MergeSpecsFromArgs(
      const std::unordered_map<std::string, std::string>& kwargs,
      TrainerSpec* trainer_spec,
      NormalizerSpec* normalizer_spec,
      NormalizerSpec* denormalizer_spec);

  // Overrides `trainer_spec`, `normalizer_spec` and `denormalizer_spec` with
  // the command-line flags in `args`.
  static util::Status MergeSpecsFromArgs(
      absl::string_view args,
      TrainerSpec* trainer_spec,
      NormalizerSpec* normalizer_spec,
      NormalizerSpec* denormalizer_spec);

  // Installs a global pre-tokenizer that is applied at training time.
  // The pre-tokenizer is only used for extracting pieces.
  // TODO(taku): It would be better to inject per `trainer_spec`.
  static util::Status SetPretokenizerForTraining(
      const pretokenizer::PretokenizerForTrainingInterface* pretokenizer);

  // Returns the current pre-tokenizer, or nullptr if none is defined.
  static const pretokenizer::PretokenizerForTrainingInterface*
  GetPretokenizerForTraining();

  // Helpers that set `field_name=value` in `message`.
  // When `field_name` is a repeated field, several values can be supplied as
  // a comma-separated list. `field_name` must not be a nested message.
  // The bodies of these functions are generated automatically by
  // data/gen_spec_parser.pl
  static util::Status SetProtoField(
      absl::string_view name,
      absl::string_view value,
      TrainerSpec* message);

  static util::Status SetProtoField(
      absl::string_view name,
      absl::string_view value,
      NormalizerSpec* message);

  // Sets the model type of `trainer_spec` from its string form, e.g. "bpe".
  // Supported models: "unigram", "bpe", "word", "char".
  static util::Status PopulateModelTypeFromString(
      absl::string_view type,
      TrainerSpec* trainer_spec);

 private:
  // Private so that the class cannot be instantiated or destroyed from
  // outside.
  SentencePieceTrainer() {}
  ~SentencePieceTrainer() {}
};

}  // namespace sentencepiece

#endif  // SENTENCEPIECE_TRAINER_H_