// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "unigram_model_trainer.h"
#include <algorithm>
#include <cfloat>
#include <cmath>
#include <functional>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/str_split.h"
#include "normalizer.h"
#include "pretokenizer_for_training.h"
#include "sentencepiece_trainer.h"
#include "third_party/esaxx/esa.hxx" // Suffix array library.
#include "unicode_script.h"
#include "util.h"
namespace sentencepiece {
namespace unigram {
namespace {
constexpr char32 kSentenceBoundary = 0x0000;
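// Returns the digamma function psi(x) = d/dx log Gamma(x).
// Uses the recurrence psi(x) = psi(x + 1) - 1/x to shift x above 7, then
// evaluates the asymptotic expansion around z = x - 1/2:
//   psi(x) ~ log(z) + 1/(24 z^2) - 7/(960 z^4) + 31/(8064 z^6)
//            - 127/(30720 z^8).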
double Digamma(double x) {
double result = 0.0;
for (; x < 7; ++x) result -= 1 / x;
x -= 1.0 / 2.0;
const double xx = 1.0 / x;
const double xx2 = xx * xx;
const double xx4 = xx2 * xx2;
result += std::log(x) + (1.0 / 24.0) * xx2 - (7.0 / 960.0) * xx4 +
(31.0 / 8064.0) * xx4 * xx2 - (127.0 / 30720.0) * xx4 * xx4;
return result;
}
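// Normalizes the scores stored in the value field of [begin, end) into log
// probabilities: each value v becomes log(v) - log(sum of all values).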
template <typename IT>
void ToLogProb(IT begin, IT end) {
float sum = 0.0;
for (auto it = begin; it != end; ++it) {
sum += it->second;
}
float logsum = std::log(static_cast<double>(sum));
for (auto it = begin; it != end; ++it) {
it->second = std::log(static_cast<double>(it->second)) - logsum;
}
}
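// Keeps the top-|size| elements ordered by score (ties broken by the
// element value). To amortize the sorting cost, push() appends lazily and
// the buffer is sorted and truncated only once it grows beyond 4 * |size|;
// after a sorted prefix exists, pushes that cannot enter the top-|size|
// are rejected early.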
template <class T>
class BoundedPriorityQueue {
public:
explicit BoundedPriorityQueue(size_t size) : size_(size) {}
~BoundedPriorityQueue() = default;
void push(T elem, int64 score) {
if (queue_.size() > 4 * size_) {
resize();
}
    if (sorted_ && queue_.size() >= size_ && queue_[size_ - 1].second > score) {
return;
}
queue_.emplace_back(elem, score);
}
const std::vector<std::pair<T, int64>>& get() {
resize();
return queue_;
}
private:
void resize() {
std::sort(queue_.begin(), queue_.end(), [](const auto& p1, const auto& p2) {
return (p1.second > p2.second ||
(p1.second == p2.second && p1.first < p2.first));
});
    sorted_ = true;
if (queue_.size() > size_) {
queue_.resize(size_);
}
}
  bool sorted_ = false;
size_t size_ = 0;
std::vector<std::pair<T, int64>> queue_;
};
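// Usage sketch (illustrative only; the values are hypothetical):
//   BoundedPriorityQueue<int> queue(/*size=*/2);
//   queue.push(/*elem=*/10, /*score=*/5);
//   queue.push(/*elem=*/20, /*score=*/9);
//   queue.push(/*elem=*/30, /*score=*/1);
//   queue.get();  // => {(20, 9), (10, 5)}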
} // namespace
TrainerModel::TrainerModel(const TrainerSpec &trainer_spec,
const NormalizerSpec &normalizer_spec)
: trainer_spec_(trainer_spec), normalizer_spec_(normalizer_spec) {}
TrainerModel::~TrainerModel() {}
const TrainerModel::SentencePieces &TrainerModel::GetSentencePieces() const {
return sentencepieces_;
}
void TrainerModel::SetSentencePieces(SentencePieces &&sentencepieces) {
sentencepieces_ = std::move(sentencepieces);
CHECK(!sentencepieces_.empty());
min_score_ = FLT_MAX;
model_proto_data_.Clear();
model_proto_ = &model_proto_data_;
std::vector<std::pair<absl::string_view, int>> pieces;
for (size_t i = 0; i < sentencepieces_.size(); ++i) {
const absl::string_view w = sentencepieces_[i].first; // piece
const float score = sentencepieces_[i].second; // score.
CHECK(!std::isnan(score));
pieces.emplace_back(w, i);
min_score_ = std::min(min_score_, score);
auto *piece = model_proto_data_.add_pieces();
piece->set_piece(w.data(), w.size());
piece->set_score(score);
}
BuildTrie(&pieces);
CHECK(status().ok());
}
TrainerModel::SentencePieces Trainer::MakeSeedSentencePieces() {
return trainer_spec_.train_extremely_large_corpus()
? MakeSeedSentencePiecesInternal<int64>()
: MakeSeedSentencePiecesInternal<int32>();
}
// Returns seed sentencepieces for EM training.
template <typename node_int_type>
TrainerModel::SentencePieces Trainer::MakeSeedSentencePiecesInternal() {
CHECK(!sentences_.empty());
CHECK(!required_chars_.empty());
  // The pretokenizer is applied only at training time, and acts as a
  // constraint on piece extraction.
const auto *pretokenizer = SentencePieceTrainer::GetPretokenizerForTraining();
auto pretokenize_or_rewrite = [&](std::pair<std::string, int64>* w) {
if (pretokenizer) {
std::vector<char32> chars;
      for (const auto& token : pretokenizer->PreTokenize(w->first)) {
        for (const auto& c : string_util::UTF8ToUnicodeText(token)) {
chars.push_back(c);
}
chars.push_back(kSentenceBoundary);
}
return chars;
} else if (!trainer_spec_.pretokenization_delimiter().empty()) {
      // When a delimiter is specified, tokenizes the input with it.
      // For EM training, we assume the delimiter does not exist in the
      // text, so the original sentence is rewritten without it below.
std::vector<char32> chars;
absl::string_view delimiter = trainer_spec_.pretokenization_delimiter();
      for (const auto& token : absl::StrSplit(w->first, delimiter)) {
        for (const auto& c : string_util::UTF8ToUnicodeText(token)) {
chars.push_back(c);
}
chars.push_back(kSentenceBoundary);
}
// Removes the delimiter.
w->first = absl::StrReplaceAll(w->first, {{delimiter, ""}});
return chars;
}
return string_util::UTF8ToUnicodeText(w->first);
};
// Merges all sentences into one array with 0x0000 delimiter.
std::vector<char32> array;
absl::flat_hash_map<std::string, int64> all_chars;
const bool is_tsv = trainer_spec_.input_format() == "tsv";
for (auto& w : sentences_) {
const auto ut = pretokenize_or_rewrite(&w);
for (const auto &c : ut) {
array.push_back(c);
if (c != kUNKChar && c != kSentenceBoundary) {
all_chars[string_util::UnicodeCharToUTF8(c)] += w.second;
}
}
array.push_back(kSentenceBoundary); // sentence boundary marker.
    // Naive workaround to over-sample the input.
    // In TSV mode, the frequency field is not used to extract seed pieces.
    // Appending a copy of the input guarantees that every substring occurs
    // at least twice, so all candidate pieces can still be extracted.
if (is_tsv) {
for (const auto& c : ut) {
array.push_back(c);
}
array.push_back(kSentenceBoundary);
}
}
CHECK_LE(array.size(),
static_cast<size_t>(std::numeric_limits<node_int_type>::max()))
<< "Input corpus too large, try with train_extremely_large_corpus=true";
const node_int_type n = array.size();
std::vector<node_int_type> SA(n); // suffix array
std::vector<node_int_type> L(n); // left boundaries of internal node
std::vector<node_int_type> R(n); // right boundaries of internal node
std::vector<node_int_type> D(n); // depths of internal node
  // Builds a suffix array to extract all substrings occurring
  // at least twice in the corpus.
constexpr node_int_type kAlphabetSize = 0x110000; // All UCS4 range.
node_int_type node_num = 0;
LOG(INFO) << "Making suffix array...";
CHECK_EQ(0, esaxx(array.begin(), SA.begin(), L.begin(), R.begin(), D.begin(),
n, kAlphabetSize, node_num));
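  // esaxx fills SA with the suffix array of |array| and enumerates the
  // internal nodes of the (virtual) suffix tree: for node i, the suffixes
  // SA[L[i]] .. SA[R[i] - 1] share a common prefix of length D[i], so that
  // substring occurs exactly R[i] - L[i] times in the corpus.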
LOG(INFO) << "Extracting frequent sub strings... node_num=" << node_num;
BoundedPriorityQueue<node_int_type> queue(
static_cast<size_t>(trainer_spec_.seed_sentencepiece_size()));
for (node_int_type i = 0; i < node_num; ++i) {
const node_int_type offset = SA[L[i]];
const node_int_type len = D[i];
if (len <= 1) {
continue;
}
const char32* begin = &array[offset];
const char32* end = &array[offset + len];
// Skips if a substring contains a sentence boundary.
if (std::find(begin, end, kSentenceBoundary) != end) {
continue;
}
const UnicodeText uw(begin, end);
if (!IsValidSentencePiece(uw)) {
continue;
}
    // The character-wise coverage (occurrence count * length) is the
    // default score.
const node_int_type freq = R[i] - L[i];
const node_int_type score = freq * len;
queue.push(i, score);
}
// all_chars must be included in the seed sentencepieces.
TrainerModel::SentencePieces seed_sentencepieces;
for (const auto &it : Sorted(all_chars)) {
seed_sentencepieces.emplace_back(it);
}
for (const auto& p : queue.get()) {
const node_int_type offset = SA[L[p.first]];
const node_int_type len = D[p.first];
CHECK_GT(len, 0);
const char32 *begin = &array[offset];
const char32 *end = &array[offset + len];
const UnicodeText uw(begin, end);
const std::string w = string_util::UnicodeTextToUTF8(uw);
CHECK(IsValidSentencePiece(uw)); // just in case.
CHECK(!port::ContainsKey(all_chars, w));
seed_sentencepieces.emplace_back(w, p.second);
}
ToLogProb(seed_sentencepieces.begin(), seed_sentencepieces.end());
LOG(INFO) << "Initialized " << seed_sentencepieces.size()
<< " seed sentencepieces";
return seed_sentencepieces;
}
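// E step: for every sentence, builds the segmentation lattice under the
// current model and accumulates the marginal expected count of each piece
// (forward-backward over the lattice). Also returns the frequency-normalized
// negative log likelihood in |obj| and the Viterbi token count in
// |num_tokens|. Work is sharded over num_threads() by striding the sentences.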
std::vector<float> Trainer::RunEStep(const TrainerModel &model, float *obj,
int64 *num_tokens) const {
std::vector<std::vector<float>> expected(trainer_spec_.num_threads());
std::vector<float> objs(trainer_spec_.num_threads(), 0.0);
  std::vector<int64> ntokens(trainer_spec_.num_threads(), 0);
auto pool = absl::make_unique<ThreadPool>(trainer_spec_.num_threads());
pool->StartWorkers();
int64 all_sentence_freq = 0;
for (const auto &w : sentences_) {
all_sentence_freq += w.second;
}
// Executes E step in parallel
for (int n = 0; n < trainer_spec_.num_threads(); ++n) {
pool->Schedule([&, n]() {
Lattice lattice;
expected[n].resize(model.GetPieceSize(), 0.0);
for (size_t i = n; i < sentences_.size();
i += trainer_spec_.num_threads()) {
const std::string &w = sentences_[i].first;
const int64 freq = sentences_[i].second;
lattice.SetSentence(w);
model.PopulateNodes(&lattice);
const float Z = lattice.PopulateMarginal(freq, &expected[n]);
ntokens[n] += lattice.Viterbi().first.size();
CHECK(!std::isnan(Z))
<< "likelihood is NAN. Input sentence may be too long";
objs[n] -= Z / all_sentence_freq;
}
});
}
pool.reset(nullptr);
// Merges expectations
for (int n = 1; n < trainer_spec_.num_threads(); ++n) {
objs[0] += objs[n];
ntokens[0] += ntokens[n];
for (size_t k = 0; k < expected[0].size(); ++k) {
expected[0][k] += expected[n][k];
}
}
*obj = objs[0];
*num_tokens = ntokens[0];
CHECK(!std::isnan(*obj));
return expected[0];
}
TrainerModel::SentencePieces Trainer::RunMStep(
const TrainerModel &model, const std::vector<float> &expected) const {
const auto &sentencepieces = model.GetSentencePieces();
CHECK_EQ(sentencepieces.size(), expected.size());
TrainerModel::SentencePieces new_sentencepieces;
float sum = 0.0;
for (size_t i = 0; i < expected.size(); ++i) {
const float freq = expected[i];
// Filter infrequent sentencepieces here.
constexpr float kExpectedFrequencyThreshold = 0.5;
if (freq < kExpectedFrequencyThreshold) {
continue;
}
new_sentencepieces.emplace_back(sentencepieces[i].first, freq);
sum += freq;
}
// Here we do not use the original EM, but use the
// Bayesianified/DPified EM algorithm.
// https://cs.stanford.edu/~pliang/papers/tutorial-acl2007-talk.pdf
// This modification will act as a sparse prior.
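  // Concretely, plain EM would set log p_i = log(c_i) - log(sum_j c_j);
  // the DP-ified variant uses log p_i = psi(c_i) - psi(sum_j c_j) instead.
  // Since exp(psi(c)) ~ c - 0.5 for large c, small expected counts are
  // discounted more heavily, which encourages a sparse piece distribution.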
const float logsum = Digamma(sum);
for (auto &w : new_sentencepieces) {
w.second = Digamma(w.second) - logsum;
}
return new_sentencepieces;
}
TrainerModel::SentencePieces Trainer::PruneSentencePieces(
const TrainerModel &model) const {
const auto &sentencepieces = model.GetSentencePieces();
Lattice lattice;
std::vector<bool> always_keep(sentencepieces.size(), true);
std::vector<std::vector<int>> alternatives(sentencepieces.size());
  // First, segments each current sentencepiece to find out how it would be
  // re-segmented if it were removed from the vocabulary.
  // To do so, we take the second-best segmentation of sentencepieces[i];
  // alternatives[i] stores that sequence of pieces.
for (size_t i = 0; i < sentencepieces.size(); ++i) {
const auto &w = sentencepieces[i];
lattice.SetSentence(w.first);
model.PopulateNodes(&lattice);
const auto nbests = lattice.NBest(2, false, 0.0);
if (nbests.size() == 1) {
      // No second-best result is found. Always keep this sentencepiece.
always_keep[i] = true;
continue;
} else if (nbests[0].first.size() >= 2) {
// Can safely remove this sentencepiece if its Viterbi path is split.
always_keep[i] = false;
} else if (nbests[0].first.size() == 1) {
always_keep[i] = true;
for (const auto* node : nbests[1].first) {
alternatives[i].push_back(node->id);
}
}
}
  // Second, segments all sentences to compute the likelihood
  // under a unigram language model. inverted[i] stores the set of
  // sentence indices where sentencepieces[i] appears.
float vsum = 0.0;
std::vector<float> freq(sentencepieces.size(), 0.0);
std::vector<std::vector<int>> inverted(sentencepieces.size());
{
std::vector<float> vsums(trainer_spec_.num_threads(), 0.0);
std::vector<std::vector<float>> freqs(trainer_spec_.num_threads());
std::vector<std::vector<std::vector<int>>> inverteds(
trainer_spec_.num_threads());
auto pool = absl::make_unique<ThreadPool>(trainer_spec_.num_threads());
pool->StartWorkers();
for (int n = 0; n < trainer_spec_.num_threads(); ++n) {
freqs[n].resize(sentencepieces.size(), 0.0);
inverteds[n].resize(sentencepieces.size());
pool->Schedule([&, n]() {
Lattice lattice;
for (size_t i = n; i < sentences_.size();
i += trainer_spec_.num_threads()) {
const auto &w = sentences_[i];
lattice.SetSentence(w.first);
model.PopulateNodes(&lattice);
vsums[n] += w.second;
for (const auto* node : lattice.Viterbi().first) {
if (node->id >= 0) {
freqs[n][node->id] += w.second;
inverteds[n][node->id].push_back(i);
}
}
}
});
}
pool.reset(nullptr);
for (int n = 0; n < trainer_spec_.num_threads(); ++n) {
vsum += vsums[n];
for (size_t i = 0; i < sentencepieces.size(); ++i) {
freq[i] += freqs[n][i];
std::copy(inverteds[n][i].begin(), inverteds[n][i].end(),
std::back_inserter(inverted[i]));
}
}
}
const float sum = std::accumulate(freq.begin(), freq.end(), 0.0);
const float logsum = std::log(static_cast<double>(sum));
std::vector<std::pair<int, float>> candidates;
TrainerModel::SentencePieces new_sentencepieces;
  // Finally, computes how much the LM likelihood is reduced if
  // sentencepieces[i] is removed from the vocabulary.
  // Since the exact loss is expensive to compute, we approximate it by
  // assuming that every occurrence of sentencepieces[i] in the corpus is
  // replaced with alternatives[i] when sentencepieces[i] is removed.
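  // In that approximation, with F_i the normalized corpus frequency of
  // sentencepieces[i]:
  //   loss_i = F_i * (log P(piece_i) - sum_{a in alternatives[i]} log P'(a))
  // where P is the current unigram LM and P' is the LM whose counts are
  // adjusted as if freq[i] had been re-assigned to the alternatives.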
for (size_t i = 0; i < sentencepieces.size(); ++i) {
if (freq[i] == 0 || !always_keep[i]) {
      // Not found in any Viterbi path, or its own Viterbi path splits;
      // this entry can be removed safely.
continue;
} else if (alternatives[i].empty()) {
// no alternatives. Keeps this entry.
new_sentencepieces.push_back(sentencepieces[i]);
} else {
float F = 0.0; // the frequency of sentencepieces[i].
for (const int n : inverted[i]) {
F += sentences_[n].second;
}
F /= vsum; // normalizes by all sentence frequency.
// The logprob with the sentencepiece[i].
const float logprob_sp = std::log(static_cast<double>(freq[i])) - logsum;
// After removing the sentencepiece[i], its frequency freq[i] is
// re-assigned to alternatives.
      // new_sum = current_sum - freq[i] + freq[i] * alternatives[i].size()
      //         = current_sum + freq[i] * (alternatives[i].size() - 1)
const float logsum_alt = std::log(
static_cast<double>(sum + freq[i] * (alternatives[i].size() - 1)));
      // The frequencies of the alternatives are increased by freq[i].
float logprob_alt = 0.0;
for (const int n : alternatives[i]) {
logprob_alt +=
(std::log(static_cast<double>(freq[n] + freq[i])) - logsum_alt);
}
      // loss: the drop in corpus likelihood when sentencepieces[i] is removed.
const float loss = F * (logprob_sp - logprob_alt);
candidates.emplace_back(i, loss);
}
}
const int pruned_size =
std::max<int>(desired_vocab_size_,
trainer_spec_.shrinking_factor() * sentencepieces.size());
  // Keeps at most trainer_spec_.shrinking_factor * sentencepieces.size()
  // pieces, but never fewer than desired_vocab_size_.
  // shrinking_factor is 0.75 by default.
for (const auto &w : Sorted(candidates)) {
if (new_sentencepieces.size() == static_cast<size_t>(pruned_size)) {
break;
}
new_sentencepieces.emplace_back(sentencepieces[w.first]);
}
return new_sentencepieces;
}
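// Finalization: every character in required_chars_ is forced into the
// vocabulary; characters the model has dropped are re-added with
// min_score() plus a tiny, strictly increasing penalty so that their
// scores stay distinct. The remaining slots, up to vocab_size minus the
// number of meta pieces, are then filled with the highest-scoring pieces.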
TrainerModel::SentencePieces Trainer::FinalizeSentencePieces(
const TrainerModel &model) const {
const auto &sentencepieces = model.GetSentencePieces();
absl::flat_hash_map<std::string, float> final_sentencepieces;
absl::flat_hash_map<std::string, float> sp(sentencepieces.begin(),
sentencepieces.end());
// required_chars_ must be included in the final sentencepieces.
float min_score_penalty = 0.0;
constexpr float kMinScorePenaltyDelta = 0.0001;
for (const auto &w : Sorted(required_chars_)) {
const std::string s = string_util::UnicodeCharToUTF8(w.first);
if (port::ContainsKey(sp, s)) {
final_sentencepieces[s] = sp[s];
} else {
      // Adds a small penalty so that required pieces do not all end up
      // with the same score. Since required_chars_ is sorted by frequency,
      // more frequent pieces receive smaller penalties.
final_sentencepieces[s] = model.min_score() + min_score_penalty;
min_score_penalty += kMinScorePenaltyDelta;
}
}
const int vocab_size_size = trainer_spec_.vocab_size() - meta_pieces_.size();
CHECK_GT(vocab_size_size, 0);
// Then keeps sentencepieces with higher scores.
for (const auto &w : Sorted(sentencepieces)) {
if (port::ContainsKey(final_sentencepieces, w.first)) {
continue;
}
if (static_cast<size_t>(vocab_size_size) == final_sentencepieces.size()) {
break;
}
final_sentencepieces[w.first] = w.second;
}
return Sorted(final_sentencepieces);
}
util::Status Trainer::Train() {
RETURN_IF_ERROR(status());
CHECK_EQ_OR_RETURN(TrainerSpec::UNIGRAM, trainer_spec_.model_type());
CHECK_OR_RETURN(normalizer_spec_.escape_whitespaces());
TrainerModel model(trainer_spec_, normalizer_spec_);
RETURN_IF_ERROR(model.status());
RETURN_IF_ERROR(LoadSentences());
auto seed_sentencepieces = MakeSeedSentencePieces();
model.SetSentencePieces(std::move(seed_sentencepieces));
if (trainer_spec_.split_by_whitespace()) {
SplitSentencesByWhitespace();
}
LOG(INFO) << "Using " << sentences_.size() << " sentences for EM training";
desired_vocab_size_ = static_cast<size_t>(trainer_spec_.vocab_size() * 1.1);
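  // The training loop alternates sub-EM iterations with pruning: the
  // vocabulary shrinks by shrinking_factor per pruning step until it
  // reaches desired_vocab_size_ (10% above the final vocab_size), and
  // FinalizeSentencePieces() then trims it to exactly vocab_size.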
while (true) {
// Sub-EM iteration.
for (int iter = 0; iter < trainer_spec_.num_sub_iterations(); ++iter) {
// Executes E step
float objective = 0.0;
int64 num_tokens = 0;
const auto expected = RunEStep(model, &objective, &num_tokens);
// Executes M step.
auto new_sentencepieces = RunMStep(model, expected);
model.SetSentencePieces(std::move(new_sentencepieces));
LOG(INFO) << "EM sub_iter=" << iter << " size=" << model.GetPieceSize()
<< " obj=" << objective << " num_tokens=" << num_tokens
<< " num_tokens/piece="
<< 1.0 * num_tokens / model.GetPieceSize();
} // end of Sub EM iteration
    // Stops the iteration when the number of sentencepieces reaches
    // the desired vocabulary size.
if (model.GetPieceSize() <= desired_vocab_size_) {
break;
}
// Prunes pieces.
auto new_sentencepieces = PruneSentencePieces(model);
model.SetSentencePieces(std::move(new_sentencepieces));
} // end of EM iteration
  // Finally, adjusts the size of sentencepieces to be |vocab_size|.
final_pieces_ = FinalizeSentencePieces(model);
return Save();
}
} // namespace unigram
} // namespace sentencepiece