chromium/third_party/sentencepiece/src/src/trainer_interface.h

// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef TRAINER_INTERFACE_H_
#define TRAINER_INTERFACE_H_

#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "common.h"
#include "filesystem.h"
#include "sentencepiece_model.pb.h"
#include "sentencepiece_processor.h"
#include "sentencepiece_trainer.h"
#include "util.h"

namespace sentencepiece {

// Returns a sorted copy of `m`; the input itself is left untouched.
// Entries are ordered by descending `second`; entries whose `second` compare
// equal are ordered by ascending `first`.
template <typename K, typename V>
std::vector<std::pair<K, V>> Sorted(const std::vector<std::pair<K, V>> &m) {
  using Entry = std::pair<K, V>;

  // True when `lhs` must be placed before `rhs`.
  const auto ranks_before = [](const Entry &lhs, const Entry &rhs) {
    if (lhs.second > rhs.second) return true;
    if (!(lhs.second == rhs.second)) return false;
    // Same value: fall back to the key so the order is fully determined.
    return lhs.first < rhs.first;
  };

  std::vector<Entry> ordered(m);
  std::sort(ordered.begin(), ordered.end(), ranks_before);
  return ordered;
}

// Copies every (key, value) entry of `m` into a vector and returns the result
// of the vector overload of Sorted() applied to that copy.
template <typename K, typename V>
std::vector<std::pair<K, V>> Sorted(const absl::flat_hash_map<K, V> &m) {
  using Entries = std::vector<std::pair<K, V>>;
  return Sorted(Entries(m.begin(), m.end()));
}

// SentenceIterator implementation that is constructed from a list of strings
// (`files`).
// NOTE(review): the out-of-line members are defined outside this header;
// presumably `files` are input file paths that are read one after another,
// yielding one sentence per step -- confirm against the implementation file.
class MultiFileSentenceIterator : public SentenceIterator {
 public:
  // `files`: presumably the paths of the input files -- TODO confirm.
  explicit MultiFileSentenceIterator(const std::vector<std::string>& files);
  ~MultiFileSentenceIterator() {}

  // SentenceIterator overrides.
  bool done() const override;
  void Next() override;
  // Returns a reference to the member `value_`, so the reference is only
  // usable while this iterator object is alive.
  const std::string& value() const override { return value_; }
  util::Status status() const override;

 private:
  // Defined outside this header.
  // NOTE(review): presumably reads the next sentence into `value_` -- confirm.
  void TryRead();

  // Starts out false.
  // NOTE(review): presumably set once all input is consumed -- confirm.
  bool read_done_ = false;
  // Starts at 0.
  // NOTE(review): presumably an index into `files_` -- confirm.
  size_t file_index_ = 0;
  std::vector<std::string> files_;
  // String exposed through value().
  std::string value_;
  // Owning handle to a readable file.
  std::unique_ptr<filesystem::ReadableFile> fp_;
};

// Base trainer class.
// Holds the trainer/normalizer specs, the loaded sentences and the resulting
// pieces, and exposes virtual Train() entry points for subclasses to override.
class TrainerInterface {
 public:
  // A sentence string paired with an int64 (the SplitSentencesByWhitespace
  // example below uses it as a frequency count).
  using Sentence = std::pair<std::string, int64>;
  using Sentences = std::vector<Sentence>;

  // Special characters and their string forms; the values are defined outside
  // this header.
  static const char32 kWSChar;
  static const char32 kUNKChar;
  static const char32 kUPPBoundaryChar;
  static const char kWSStr[];
  static const char kUNKStr[];
  static const char kUPPBoundaryStr[];

  TrainerInterface(const TrainerSpec& trainer_spec,
                   const NormalizerSpec& normalizer_spec,
                   const NormalizerSpec& denormalizer_spec);

  virtual ~TrainerInterface();

  // Loads sentences from `sentence_iterator` and stores the model
  // to `output_model_proto`.
  // This overload only records the two pointers in `sentence_iterator_` and
  // `output_model_proto_` and then dispatches to the virtual no-argument
  // Train().
  // NOTE(review): the pointers are kept as raw members; presumably they are
  // not owned and must outlive the training run -- confirm.
  virtual util::Status Train(SentenceIterator* sentence_iterator,
                             ModelProto* output_model_proto) {
    sentence_iterator_ = sentence_iterator;
    output_model_proto_ = output_model_proto;
    return Train();
  }

  // Base implementation performs no training and simply returns status().
  virtual util::Status Train() { return status(); }

  // Returns `status_`.
  virtual util::Status status() const { return status_; }

  // Tests that are granted access to non-public members.
  FRIEND_TEST(TrainerInterfaceTest, IsValidSentencePieceTest);
  FRIEND_TEST(TrainerInterfaceTest, OverrideSpecialPiecesTest);
  FRIEND_TEST(TrainerInterfaceTest, BytePiecesTest);
  FRIEND_TEST(TrainerInterfaceTest, SerializeTest);
  FRIEND_TEST(TrainerInterfaceTest, CharactersTest);

  // Loads all sentences from spec.input() or SentenceIterator.
  // It loads at most input_sentence_size sentences.
  util::Status LoadSentences();

 protected:
  // Returns true if |piece| is a valid sentence piece.
  // The result is affected by
  // max_sentencepiece_length, split_by_whitespace, split_by_unicode_script.
  bool IsValidSentencePiece(const string_util::UnicodeText &piece) const;

  // Splits all sentences by whitespaces and
  // replaces |sentences_| with the tokenized strings.
  // e.g.,
  //  [ ["hello world", 1], ["hi world", 1] ] =>
  //  [ ["hello", 1], ["hi", 1], ["world", 2] ]
  void SplitSentencesByWhitespace();

  // Save model files into spec.model_prefix().
  util::Status Save() const;

  // Set of characters which must be included in the final vocab.
  // The value of this map stores the frequency.
  absl::flat_hash_map<char32, int64> required_chars_;

  // Final output pieces, each stored as a (string, float) pair.
  std::vector<std::pair<std::string, float>> final_pieces_;

  // All sentences.
  Sentences sentences_;

  // Trainer spec.
  TrainerSpec trainer_spec_;

  // Normalizer spec
  NormalizerSpec normalizer_spec_;

  // Denormalizer spec
  NormalizerSpec denormalizer_spec_;

  // Reserved control pieces. e.g., <unk>, <s>, </s>.
  // key is vocab id.
  std::map<int, std::pair<std::string, ModelProto::SentencePiece::Type>>
      meta_pieces_;

  // Detect errors on initialization.
  // This is the value returned by status().
  util::Status status_;

  // Loads sentences from SentenceIterator if not null.
  // Set by Train(SentenceIterator*, ModelProto*).
  SentenceIterator* sentence_iterator_ = nullptr;

  // Emits model to this proto instead of file.
  // Set by Train(SentenceIterator*, ModelProto*).
  ModelProto* output_model_proto_ = nullptr;

 private:
  // Serialize final_pieces_ to |model_proto|.
  util::Status Serialize(ModelProto* model_proto) const;

  // Saves the best sentence split with the current model for debugging.
  util::Status SaveSplits(absl::string_view filename) const;

  // Saves model file.
  util::Status SaveModel(absl::string_view filename) const;

  // Saves vocabulary file for NMT.
  util::Status SaveVocab(absl::string_view filename) const;

  // Initializes `meta_pieces_` from TrainerSpec.
  util::Status InitMetaPieces();

  // Randomly sampled raw sentences for self-testing.
  std::vector<std::string> self_test_samples_;
};
}  // namespace sentencepiece
#endif  // TRAINER_INTERFACE_H_