// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "unigram_model.h"
#include <cmath>
#include <map>
#include <string>
#include <utility>
#include <vector>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "sentencepiece_model.pb.h"
#include "sentencepiece_processor.h"
#include "testharness.h"
#include "util.h"
namespace sentencepiece {
namespace unigram {
TEST(LatticeTest, SetSentenceTest) {
Lattice lattice;
EXPECT_EQ(0, lattice.size());
EXPECT_EQ(0, lattice.utf8_size());
lattice.SetSentence("");
EXPECT_EQ(0, lattice.size());
EXPECT_EQ(0, lattice.utf8_size());
EXPECT_STREQ("", lattice.sentence());
EXPECT_STREQ("", lattice.surface(0));
lattice.SetSentence("test");
EXPECT_EQ(4, lattice.size());
EXPECT_EQ(4, lattice.utf8_size());
EXPECT_STREQ("test", lattice.sentence());
EXPECT_STREQ("test", lattice.surface(0));
EXPECT_STREQ("est", lattice.surface(1));
EXPECT_STREQ("st", lattice.surface(2));
EXPECT_STREQ("t", lattice.surface(3));
Lattice::Node *bos = lattice.bos_node();
Lattice::Node *eos = lattice.eos_node();
EXPECT_EQ(-1, bos->id);
EXPECT_EQ(-1, eos->id);
EXPECT_EQ(bos, lattice.end_nodes(0).front());
EXPECT_EQ(eos, lattice.begin_nodes(4).front());
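// Multi-byte input: size() counts Unicode characters while utf8_size()
// counts bytes, so "テストab" is 5 characters but 11 UTF-8 bytes.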
lattice.SetSentence("テストab");
EXPECT_EQ(5, lattice.size());
EXPECT_EQ(11, lattice.utf8_size());
EXPECT_STREQ("テストab", lattice.sentence());
EXPECT_STREQ("テストab", lattice.surface(0));
EXPECT_STREQ("ストab", lattice.surface(1));
EXPECT_STREQ("トab", lattice.surface(2));
EXPECT_STREQ("ab", lattice.surface(3));
EXPECT_STREQ("b", lattice.surface(4));
lattice.Clear();
EXPECT_EQ(0, lattice.size());
EXPECT_EQ(0, lattice.utf8_size());
}
TEST(LatticeTest, InsertTest) {
Lattice lattice;
lattice.SetSentence("ABあい");
Lattice::Node *node[7];
node[0] = lattice.Insert(0, 1);
node[1] = lattice.Insert(1, 1);
node[2] = lattice.Insert(2, 1);
node[3] = lattice.Insert(3, 1);
node[4] = lattice.Insert(0, 2);
node[5] = lattice.Insert(1, 2);
node[6] = lattice.Insert(2, 2);
EXPECT_EQ("A", node[0]->piece);
EXPECT_EQ("B", node[1]->piece);
EXPECT_EQ("あ", node[2]->piece);
EXPECT_EQ("い", node[3]->piece);
EXPECT_EQ("AB", node[4]->piece);
EXPECT_EQ("Bあ", node[5]->piece);
EXPECT_EQ("あい", node[6]->piece);
EXPECT_EQ("A", node[0]->piece);
EXPECT_EQ("B", node[1]->piece);
EXPECT_EQ("あ", node[2]->piece);
EXPECT_EQ("い", node[3]->piece);
EXPECT_EQ("AB", node[4]->piece);
EXPECT_EQ("Bあ", node[5]->piece);
EXPECT_EQ("あい", node[6]->piece);
EXPECT_EQ(0, node[0]->pos);
EXPECT_EQ(1, node[1]->pos);
EXPECT_EQ(2, node[2]->pos);
EXPECT_EQ(3, node[3]->pos);
EXPECT_EQ(0, node[4]->pos);
EXPECT_EQ(1, node[5]->pos);
EXPECT_EQ(2, node[6]->pos);
EXPECT_EQ(1, node[0]->length);
EXPECT_EQ(1, node[1]->length);
EXPECT_EQ(1, node[2]->length);
EXPECT_EQ(1, node[3]->length);
EXPECT_EQ(2, node[4]->length);
EXPECT_EQ(2, node[5]->length);
EXPECT_EQ(2, node[6]->length);
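// As the checks below show, node_id reflects allocation order: the BOS and
// EOS nodes created by SetSentence() come first (0 and 1), followed by the
// inserted nodes in insertion order (2 through 8).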
EXPECT_EQ(0, lattice.bos_node()->node_id);
EXPECT_EQ(1, lattice.eos_node()->node_id);
EXPECT_EQ(2, node[0]->node_id);
EXPECT_EQ(3, node[1]->node_id);
EXPECT_EQ(4, node[2]->node_id);
EXPECT_EQ(5, node[3]->node_id);
EXPECT_EQ(6, node[4]->node_id);
EXPECT_EQ(7, node[5]->node_id);
EXPECT_EQ(8, node[6]->node_id);
EXPECT_EQ(2, lattice.begin_nodes(0).size());
EXPECT_EQ(2, lattice.begin_nodes(1).size());
EXPECT_EQ(2, lattice.begin_nodes(2).size());
EXPECT_EQ(1, lattice.begin_nodes(3).size());
EXPECT_EQ(1, lattice.begin_nodes(4).size()); // EOS
EXPECT_EQ(1, lattice.end_nodes(0).size()); // BOS
EXPECT_EQ(1, lattice.end_nodes(1).size());
EXPECT_EQ(2, lattice.end_nodes(2).size());
EXPECT_EQ(2, lattice.end_nodes(3).size());
EXPECT_EQ(2, lattice.end_nodes(4).size());
EXPECT_EQ(node[0], lattice.begin_nodes(0)[0]);
EXPECT_EQ(node[4], lattice.begin_nodes(0)[1]);
EXPECT_EQ(node[1], lattice.begin_nodes(1)[0]);
EXPECT_EQ(node[5], lattice.begin_nodes(1)[1]);
EXPECT_EQ(node[2], lattice.begin_nodes(2)[0]);
EXPECT_EQ(node[6], lattice.begin_nodes(2)[1]);
EXPECT_EQ(node[3], lattice.begin_nodes(3)[0]);
EXPECT_EQ(lattice.eos_node(), lattice.begin_nodes(4)[0]);
EXPECT_EQ(lattice.bos_node(), lattice.end_nodes(0)[0]);
EXPECT_EQ(node[0], lattice.end_nodes(1)[0]);
EXPECT_EQ(node[1], lattice.end_nodes(2)[0]);
EXPECT_EQ(node[4], lattice.end_nodes(2)[1]);
EXPECT_EQ(node[2], lattice.end_nodes(3)[0]);
EXPECT_EQ(node[5], lattice.end_nodes(3)[1]);
EXPECT_EQ(node[3], lattice.end_nodes(4)[0]);
EXPECT_EQ(node[6], lattice.end_nodes(4)[1]);
}
TEST(LatticeTest, ViterbiFromIncompleteLatticeTest) {
Lattice lattice;
lattice.SetSentence("ABC");
EXPECT_TRUE(lattice.Viterbi().first.empty());
// Still incomplete: covering only position 0 leaves no complete path.
lattice.Insert(0, 1);
EXPECT_TRUE(lattice.Viterbi().first.empty());
lattice.Insert(1, 1);
lattice.Insert(2, 1);
lattice.Viterbi();
}
std::string GetTokenized(const std::vector<Lattice::Node *> &nodes) {
std::vector<std::string> tokens;
for (auto *node : nodes) {
tokens.push_back(std::string(node->piece));
}
return absl::StrJoin(tokens, " ");
}
void InsertWithScore(Lattice *lattice, int pos, int length, float score) {
lattice->Insert(pos, length)->score = score;
}
void InsertWithScoreAndId(Lattice *lattice, int pos, int length, float score,
int id) {
auto *node = lattice->Insert(pos, length);
node->score = score;
node->id = id;
}
TEST(LatticeTest, ViterbiTest) {
Lattice lattice;
lattice.SetSentence("ABC");
InsertWithScore(&lattice, 0, 1, 0.0); // A
InsertWithScore(&lattice, 1, 1, 0.0); // B
InsertWithScore(&lattice, 2, 1, 0.0); // C
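// Viterbi() returns the highest-scoring complete path; as higher-scoring
// pieces are inserted below, the best segmentation changes accordingly.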
EXPECT_EQ("A B C", GetTokenized(lattice.Viterbi().first));
InsertWithScore(&lattice, 0, 2, 2.0); // AB
EXPECT_EQ("AB C", GetTokenized(lattice.Viterbi().first));
InsertWithScore(&lattice, 1, 2, 5.0); // BC
EXPECT_EQ("A BC", GetTokenized(lattice.Viterbi().first));
InsertWithScore(&lattice, 0, 3, 10.0); // ABC
EXPECT_EQ("ABC", GetTokenized(lattice.Viterbi().first));
}
TEST(LatticeTest, NBestTest) {
Lattice lattice;
lattice.SetSentence("ABC");
InsertWithScore(&lattice, 0, 1, 0.0); // A
InsertWithScore(&lattice, 1, 1, 0.0); // B
InsertWithScore(&lattice, 2, 1, 0.0); // C
InsertWithScore(&lattice, 0, 2, 2.0); // AB
InsertWithScore(&lattice, 1, 2, 5.0); // BC
InsertWithScore(&lattice, 0, 3, 10.0); // ABC
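// Total path scores: ABC = 10.0, A BC = 5.0, AB C = 2.0, A B C = 0.0,
// which determines the n-best order checked below.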
auto nbests = lattice.NBest(10, false, 0.0);
EXPECT_EQ(4, nbests.size());
EXPECT_EQ("ABC", GetTokenized(nbests[0].first));
EXPECT_EQ("A BC", GetTokenized(nbests[1].first));
EXPECT_EQ("AB C", GetTokenized(nbests[2].first));
EXPECT_EQ("A B C", GetTokenized(nbests[3].first));
auto nbests0 = lattice.NBest(0, false, 0.0);
EXPECT_TRUE(nbests0.empty());
auto nbests1 = lattice.NBest(1, false, 0.0);
EXPECT_EQ(1, nbests1.size());
}
TEST(LatticeTest, NBestSampleTest) {
Lattice lattice;
lattice.SetSentence("ABC");
InsertWithScore(&lattice, 0, 1, 0.0); // A
InsertWithScore(&lattice, 1, 1, 0.0); // B
InsertWithScore(&lattice, 2, 1, 0.1); // C
InsertWithScore(&lattice, 0, 2, 0.2); // AB
InsertWithScore(&lattice, 1, 2, 0.5); // BC
InsertWithScore(&lattice, 0, 3, 1.0); // ABC
// Calculate expected probabilities of each path
// Note that sampling without replacement affects the expected frequencies!
const std::vector<double> kInv_Theta = {0.0, 0.01, 0.5, 0.7, 1.0};
for (const auto inv_theta : kInv_Theta) {
std::vector<std::string> strings = {"ABC", "AB C", "A BC", "A B C"};
std::map<std::string, float> probs;
probs["ABC"] = std::exp(inv_theta * 1.0);
probs["AB C"] = std::exp(inv_theta * (0.2 + 0.1));
probs["A BC"] = std::exp(inv_theta * (0.0 + 0.5));
probs["A B C"] = std::exp(inv_theta * (0.0 + 0.0 + 0.1));
for (const auto& it : strings) {
EXPECT_EQ(1, probs.count(it));
}
double Z = 0.0;
for (const auto& it : probs) {
Z += it.second;
}
for (auto& it : probs) {
it.second /= Z;
}
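// For two samples drawn without replacement, pair_probs[(x, y)] is the
// probability of drawing path x first and then path y:
//   P(x, y) = P(x) * P(y) / (1 - P(x)).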
std::map<std::pair<std::string, std::string>, float> pair_probs;
for (const auto& first : strings) {
for (const auto& second : strings) {
if (first == second) {
pair_probs[std::make_pair(first, second)] = 0;
} else {
float first_prob = probs[first];
float second_prob = probs[second] / (1 - first_prob);
pair_probs[std::make_pair(first, second)] = first_prob * second_prob;
}
}
}
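// For the two-sample case, the expected fraction of drawn samples equal to
// x is the average of the probability that x is drawn first and the
// probability that it is drawn second, which is what
// counts / (kTrials * num_samples) estimates below.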
std::map<std::string, float> inclusion_probs;
for (const auto& string : strings) {
float inclusion_prob = 0.0;
for (const auto& other_string : strings) {
inclusion_prob += pair_probs[std::make_pair(string, other_string)];
}
for (const auto& other_string : strings) {
inclusion_prob += pair_probs[std::make_pair(other_string, string)];
}
inclusion_probs[string] = inclusion_prob / 2;
}
constexpr int kTrials = 10000;
const std::vector<int> kNumSamples = {1, 2};
for (const auto num_samples : kNumSamples) {
std::map<std::string, int> counts;
for (int i = 0; i < kTrials; i++) {
auto nbests = lattice.NBest(num_samples, true, inv_theta);
for (const auto& nbest : nbests) {
counts[GetTokenized(nbest.first)]++;
}
}
EXPECT_EQ(inclusion_probs.size(), counts.size());
// If we take multiple samples without replacement (WOR), we have to use
// the corrected inclusion probabilities.
std::map<std::string, float> probs_to_use =
(num_samples == 1 ? probs : inclusion_probs);
for (const auto& it : probs_to_use) {
EXPECT_NEAR(it.second, 1.0 * counts[it.first] / (kTrials * num_samples),
0.02);
}
}
}
}
TEST(LatticeTest, CalculateEntropyTest) {
Lattice lattice;
lattice.SetSentence("ABC");
InsertWithScore(&lattice, 0, 1, 0.0); // A
InsertWithScore(&lattice, 1, 1, 0.0); // B
InsertWithScore(&lattice, 2, 1, 0.1); // C
InsertWithScore(&lattice, 0, 2, 0.2); // AB
InsertWithScore(&lattice, 1, 2, 0.5); // BC
InsertWithScore(&lattice, 0, 3, 1.0); // ABC
// Calculate expected probabilities of each path
const std::vector<double> kInv_Theta = {0.0, 0.01, 0.5, 0.7, 1.0};
for (const auto inv_theta : kInv_Theta) {
std::vector<std::string> strings = {"ABC", "AB C", "A BC", "A B C"};
std::map<std::string, float> probs;
probs["ABC"] = std::exp(inv_theta * 1.0);
probs["AB C"] = std::exp(inv_theta * (0.2 + 0.1));
probs["A BC"] = std::exp(inv_theta * (0.0 + 0.5));
probs["A B C"] = std::exp(inv_theta * (0.0 + 0.0 + 0.1));
double Z = 0.0;
for (const auto& it : probs) {
Z += it.second;
}
for (auto& it : probs) {
it.second /= Z;
}
for (const auto& it : strings) {
EXPECT_EQ(1, probs.count(it));
}
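// The expected value is the Shannon entropy of the path distribution,
// H = -sum_x P(x) * log(P(x)).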
float entropy = 0.0;
for (const auto& it : probs) {
entropy += (it.second * std::log(it.second));
}
EXPECT_NEAR(-entropy, lattice.CalculateEntropy(inv_theta), 0.02);
}
}
TEST(LatticeTest, ForwardAlgorithmTest) {
Lattice lattice;
lattice.SetSentence("ABC");
InsertWithScore(&lattice, 0, 1, 0.0); // A
InsertWithScore(&lattice, 1, 1, 0.0); // B
InsertWithScore(&lattice, 2, 1, 0.1); // C
InsertWithScore(&lattice, 0, 2, 0.2); // AB
InsertWithScore(&lattice, 1, 2, 0.5); // BC
InsertWithScore(&lattice, 0, 3, 1.0); // ABC
const std::vector<float> kInv_Theta = {0.0, 0.01, 0.5, 0.7, 1.0};
for (const auto inv_theta : kInv_Theta) {
std::vector<float> alpha = lattice.ForwardAlgorithm(inv_theta);
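// The expectations below suggest that alpha[v] accumulates the log-sum-exp
// of the inv_theta-scaled scores of all paths from BOS up to the start
// position of node v, not including v's own score.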
EXPECT_EQ(alpha.size(), 8); // 6 nodes, plus BOS, EOS
// Only the nodes beginning at position 2 (C) and position 3 (EOS) have a
// non-zero alpha.
for (int i : {0, 1, 2, 3}) {
for (const auto& node : lattice.begin_nodes(i)) {
if (i < 2) {
EXPECT_EQ(alpha[node->node_id], 0.0);
} else if (i == 2) {
float Z = std::log(std::exp(inv_theta * (0.0 + 0.0)) +
std::exp(inv_theta * 0.2));
EXPECT_EQ(alpha[node->node_id], Z);
} else if (i == 3) {
float Z =
std::log(std::exp(inv_theta * (0.0 + 0.0 + 0.1)) + // A + B + C
std::exp(inv_theta * (0.2 + 0.1)) + // AB + C
std::exp(inv_theta * (0.0 + 0.5)) + // A + BC
std::exp(inv_theta * 1.0)); // ABC
EXPECT_EQ(Z, alpha[node->node_id]);
}
}
}
}
}
TEST(LatticeTest, PopulateMarginalTest) {
Lattice lattice;
lattice.SetSentence("ABC");
InsertWithScoreAndId(&lattice, 0, 1, 1.0, 0); // A
InsertWithScoreAndId(&lattice, 1, 1, 1.2, 1); // B
InsertWithScoreAndId(&lattice, 2, 1, 2.5, 2); // C
InsertWithScoreAndId(&lattice, 0, 2, 3.0, 3); // AB
InsertWithScoreAndId(&lattice, 1, 2, 4.0, 4); // BC
InsertWithScoreAndId(&lattice, 0, 3, 2.0, 5); // ABC
std::vector<float> probs(6, 0.0);
// Expand all paths:
// A B C : exp(1.0 + 1.2 + 2.5) => path1
// AB C : exp(3.0 + 2.5) => path2
// A BC : exp(1.0 + 4.0) => path3
// ABC : exp(2.0) => path4
const float p1 = std::exp(1.0 + 1.2 + 2.5);
const float p2 = std::exp(3.0 + 2.5);
const float p3 = std::exp(1.0 + 4.0);
const float p4 = std::exp(2.0);
const float Z = p1 + p2 + p3 + p4;
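// The marginal probability of each piece is the total probability of the
// paths containing it, e.g. P(A) = (p1 + p3) / Z, and PopulateMarginal()
// is expected to return log(Z).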
const float logZ = lattice.PopulateMarginal(1.0, &probs);
EXPECT_NEAR((p1 + p3) / Z, probs[0], 0.001); // A
EXPECT_NEAR(p1 / Z, probs[1], 0.001); // B
EXPECT_NEAR((p1 + p2) / Z, probs[2], 0.001); // C
EXPECT_NEAR(p2 / Z, probs[3], 0.001); // AB
EXPECT_NEAR(p3 / Z, probs[4], 0.001); // BC
EXPECT_NEAR(p4 / Z, probs[5], 0.001); // ABC
EXPECT_NEAR(std::log(static_cast<double>(Z)), logZ, 0.001);
}
TEST(LatticeTest, SampleTest) {
Lattice lattice;
lattice.SetSentence("ABC");
InsertWithScoreAndId(&lattice, 0, 1, 1.0, 0); // A
InsertWithScoreAndId(&lattice, 1, 1, 1.2, 1); // B
InsertWithScoreAndId(&lattice, 2, 1, 1.5, 2); // C
InsertWithScoreAndId(&lattice, 0, 2, 1.6, 3); // AB
InsertWithScoreAndId(&lattice, 1, 2, 1.7, 4); // BC
InsertWithScoreAndId(&lattice, 0, 3, 1.8, 5); // ABC
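// Sample() draws one segmentation per call; over many trials the observed
// frequencies should match the normalized path probabilities computed
// below, where each path is weighted by exp(inv_theta * total path score).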
const std::vector<double> kInv_Theta = {0.0, 0.01, 0.5, 0.7, 1.0};
for (const auto inv_theta : kInv_Theta) {
std::map<std::string, double> probs;
// Expands all paths in the lattice.
probs["A B C"] = std::exp(inv_theta * (1.0 + 1.2 + 1.5)); // A B C
probs["AB C"] = std::exp(inv_theta * (1.6 + 1.5)); // AB C
probs["A BC"] = std::exp(inv_theta * (1.0 + 1.7)); // A BC
probs["ABC"] = std::exp(inv_theta * 1.8); // ABC
// Computes expected probabilities.
double Z = 0.0;
for (const auto &it : probs) Z += it.second;
for (auto &it : probs) it.second /= Z;
// Samples `kTrial` times and verifies the probabilities.
constexpr int kTrial = 100000;
std::map<std::string, int> freq;
for (int n = 0; n < kTrial; ++n) {
freq[GetTokenized(lattice.Sample(inv_theta))]++;
}
EXPECT_EQ(probs.size(), freq.size());
for (const auto &it : probs) {
EXPECT_NEAR(it.second, 1.0 * freq[it.first] / kTrial, 0.02);
}
}
}
ModelProto MakeBaseModelProto() {
ModelProto model_proto;
auto *sp1 = model_proto.add_pieces();
auto *sp2 = model_proto.add_pieces();
auto *sp3 = model_proto.add_pieces();
sp1->set_type(ModelProto::SentencePiece::UNKNOWN);
sp1->set_piece("<unk>");
sp2->set_type(ModelProto::SentencePiece::CONTROL);
sp2->set_piece("<s>");
sp3->set_type(ModelProto::SentencePiece::CONTROL);
sp3->set_piece("</s>");
return model_proto;
}
// Returns the encoder versions used in the parameterized tests.
const std::vector<Model::EncoderVersion>& GetEncoderVersions() {
static const std::vector<Model::EncoderVersion>& v =
*new std::vector<Model::EncoderVersion>{Model::kOptimized,
Model::kOriginal};
return v;
}
class UnigramModelTest : public test::TestWithParam<Model::EncoderVersion> {
protected:
void SetUp() override { encoder_version_ = GetParam(); }
void TearDown() override {}
Model::EncoderVersion encoder_version_;
};
void AddPiece(ModelProto *model_proto, const std::string &piece,
float score = 0.0) {
auto *sp = model_proto->add_pieces();
sp->set_piece(piece);
sp->set_score(score);
}
TEST(UnigramModelTest, SetUnigramModelTest) {
ModelProto model_proto = MakeBaseModelProto();
AddPiece(&model_proto, "a");
AddPiece(&model_proto, "b");
AddPiece(&model_proto, "c");
AddPiece(&model_proto, "d");
const Model model(model_proto);
EXPECT_EQ(model_proto.SerializeAsString(),
model.model_proto().SerializeAsString());
}
TEST(UnigramModelTest, SampleEncodeAndScoreTest) {
// Test whether inclusion probabilities are correct
ModelProto model_proto = MakeBaseModelProto();
AddPiece(&model_proto, "A", 0.0); // 3
AddPiece(&model_proto, "B", 0.0); // 4
AddPiece(&model_proto, "C", 0.1); // 5
AddPiece(&model_proto, "AB", 0.2); // 6
AddPiece(&model_proto, "BC", 0.5); // 7
AddPiece(&model_proto, "ABC", 1.0); // 8
Model model(model_proto);
Lattice lattice;
lattice.SetSentence("ABC");
model.PopulateNodes(&lattice);
const std::vector<float> kInv_Theta = {0.0, 1.0};
for (const auto inv_theta : kInv_Theta) {
std::vector<std::string> strings = {"ABC", "AB C", "A BC", "A B C"};
std::map<std::string, float> probs;
probs["ABC"] = std::exp(inv_theta * 1.0);
probs["AB C"] = std::exp(inv_theta * (0.2 + 0.1));
probs["A BC"] = std::exp(inv_theta * (0.0 + 0.5));
probs["A B C"] = std::exp(inv_theta * (0.0 + 0.0 + 0.1));
for (const auto& it : strings) {
EXPECT_EQ(1, probs.count(it));
}
double Z = 0.0;
for (const auto& it : probs) {
Z += it.second;
}
for (auto& it : probs) {
it.second /= Z;
}
std::map<std::pair<std::string, std::string>, float> pair_probs;
for (const auto& first : strings) {
for (const auto& second : strings) {
if (first == second) {
pair_probs[std::make_pair(first, second)] = 0;
} else {
const float first_prob = probs[first];
const float second_prob = probs[second] / (1 - first_prob);
pair_probs[std::make_pair(first, second)] = first_prob * second_prob;
}
}
}
std::map<std::string, float> inclusion_probs;
for (const auto& string : strings) {
float inclusion_prob = 0.0;
for (const auto& other_string : strings) {
inclusion_prob += pair_probs[std::make_pair(string, other_string)];
}
for (const auto& other_string : strings) {
inclusion_prob += pair_probs[std::make_pair(other_string, string)];
}
inclusion_probs[string] = inclusion_prob / 2;
}
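// Same construction as in LatticeTest.NBestSampleTest: for two samples
// drawn without replacement, inclusion_probs[x] is the expected fraction
// of sampled segmentations equal to x.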
const std::vector<int> kNumSamples = {1, 2};
for (const auto num_samples : kNumSamples) {
std::map<std::string, int> counts;
std::map<std::string, float> scores;
constexpr int kTrials = 50000;
for (int i = 0; i < kTrials; i++) {
NBestEncodeResult sample = model.SampleEncodeAndScore(
"ABC", inv_theta, num_samples, true, false);
for (const auto& it : sample) {
std::vector<std::string> tokens;
for (const auto& inner_it : it.first) {
tokens.push_back(std::string(inner_it.first));
}
std::string sample_string = absl::StrJoin(tokens, " ");
counts[sample_string] += 1;
// Uses the fact that E[1_{i in sample} / score of i] = 1;
// see https://arxiv.org/pdf/1903.06059.pdf, Appendix D.
scores[sample_string] += std::exp(-it.second);
}
}
// Check that counts and probs are correct
std::map<std::string, float> probs_to_use =
(num_samples == 1 ? probs : inclusion_probs);
for (const auto& it : probs_to_use) {
EXPECT_NEAR(it.second, 1.0 * counts[it.first] / (kTrials * num_samples),
0.02);
// The expectation is quite loose; use a higher tolerance.
EXPECT_NEAR(1.0, scores[it.first] / kTrials, 0.30);
}
}
}
}
TEST_P(UnigramModelTest, PieceToIdTest) {
ModelProto model_proto = MakeBaseModelProto();
AddPiece(&model_proto, "a", 0.1);
AddPiece(&model_proto, "b", 0.2);
AddPiece(&model_proto, "c", 0.3);
AddPiece(&model_proto, "d", 0.4);
Model model(model_proto);
model.SetEncoderVersion(encoder_version_);
EXPECT_EQ(model_proto.SerializeAsString(),
model.model_proto().SerializeAsString());
EXPECT_NEAR(0.1, model.min_score(), 0.001);
EXPECT_NEAR(0.4, model.max_score(), 0.001);
EXPECT_EQ(0, model.PieceToId("<unk>"));
EXPECT_EQ(1, model.PieceToId("<s>"));
EXPECT_EQ(2, model.PieceToId("</s>"));
EXPECT_EQ(3, model.PieceToId("a"));
EXPECT_EQ(3, model.PieceToId(absl::string_view("a b", 1)));
EXPECT_EQ(4, model.PieceToId("b"));
EXPECT_EQ(5, model.PieceToId("c"));
EXPECT_EQ(6, model.PieceToId("d"));
EXPECT_EQ(0, model.PieceToId("e")); // unk
EXPECT_EQ(0, model.PieceToId("")); // unk
EXPECT_EQ("<unk>", model.IdToPiece(0));
EXPECT_EQ("<s>", model.IdToPiece(1));
EXPECT_EQ("</s>", model.IdToPiece(2));
EXPECT_EQ("a", model.IdToPiece(3));
EXPECT_EQ("b", model.IdToPiece(4));
EXPECT_EQ("c", model.IdToPiece(5));
EXPECT_EQ("d", model.IdToPiece(6));
EXPECT_TRUE(model.IsUnknown(0));
EXPECT_FALSE(model.IsUnknown(1));
EXPECT_FALSE(model.IsUnknown(2));
EXPECT_FALSE(model.IsUnknown(3));
EXPECT_FALSE(model.IsUnknown(4));
EXPECT_FALSE(model.IsUnknown(5));
EXPECT_FALSE(model.IsUnknown(6));
EXPECT_FALSE(model.IsControl(0));
EXPECT_TRUE(model.IsControl(1));
EXPECT_TRUE(model.IsControl(2));
EXPECT_FALSE(model.IsControl(3));
EXPECT_FALSE(model.IsControl(4));
EXPECT_FALSE(model.IsControl(5));
EXPECT_FALSE(model.IsControl(6));
EXPECT_NEAR(0, model.GetScore(0), 0.0001);
EXPECT_NEAR(0, model.GetScore(1), 0.0001);
EXPECT_NEAR(0, model.GetScore(2), 0.0001);
EXPECT_NEAR(0.1, model.GetScore(3), 0.0001);
EXPECT_NEAR(0.2, model.GetScore(4), 0.0001);
EXPECT_NEAR(0.3, model.GetScore(5), 0.0001);
EXPECT_NEAR(0.4, model.GetScore(6), 0.0001);
EXPECT_TRUE(model.Encode("").empty());
}
TEST_P(UnigramModelTest, PopulateNodesAllUnknownsTest) {
ModelProto model_proto = MakeBaseModelProto();
AddPiece(&model_proto, "x");
Model model(model_proto);
model.SetEncoderVersion(encoder_version_);
Lattice lattice;
lattice.SetSentence("abc");
model.PopulateNodes(&lattice);
EXPECT_EQ(1, lattice.begin_nodes(0).size());
EXPECT_EQ(1, lattice.begin_nodes(1).size());
EXPECT_EQ(1, lattice.begin_nodes(2).size());
EXPECT_EQ(0, lattice.begin_nodes(0)[0]->id);
EXPECT_EQ(0, lattice.begin_nodes(1)[0]->id);
EXPECT_EQ(0, lattice.begin_nodes(2)[0]->id);
}
TEST_P(UnigramModelTest, PopulateNodesTest) {
ModelProto model_proto = MakeBaseModelProto();
AddPiece(&model_proto, "a", 0.1); // 3
AddPiece(&model_proto, "b", 0.2); // 4
AddPiece(&model_proto, "ab", 0.3); // 5
AddPiece(&model_proto, "bc", 0.4); // 6
Model model(model_proto);
model.SetEncoderVersion(encoder_version_);
Lattice lattice;
lattice.SetSentence("abc");
model.PopulateNodes(&lattice);
EXPECT_EQ(2, lattice.begin_nodes(0).size()); // a,ab
EXPECT_EQ(2, lattice.begin_nodes(1).size()); // b,bc
EXPECT_EQ(1, lattice.begin_nodes(2).size()); // c(unk)
EXPECT_EQ(3, lattice.begin_nodes(0)[0]->id);
EXPECT_EQ(5, lattice.begin_nodes(0)[1]->id);
EXPECT_EQ(4, lattice.begin_nodes(1)[0]->id);
EXPECT_EQ(6, lattice.begin_nodes(1)[1]->id);
EXPECT_EQ(0, lattice.begin_nodes(2)[0]->id);
EXPECT_NEAR(0.1, lattice.begin_nodes(0)[0]->score, 0.001);
EXPECT_NEAR(0.3, lattice.begin_nodes(0)[1]->score, 0.001);
EXPECT_NEAR(0.2, lattice.begin_nodes(1)[0]->score, 0.001);
EXPECT_NEAR(0.4, lattice.begin_nodes(1)[1]->score, 0.001);
}
TEST_P(UnigramModelTest, PopulateNodesWithUnusedTest) {
ModelProto model_proto = MakeBaseModelProto();
AddPiece(&model_proto, "a", 0.1); // 3
AddPiece(&model_proto, "b", 0.2); // 4
AddPiece(&model_proto, "ab", 0.3); // 5
AddPiece(&model_proto, "bc", 0.4); // 6
model_proto.mutable_pieces(5)->set_type(ModelProto::SentencePiece::UNUSED);
model_proto.mutable_pieces(6)->set_type(ModelProto::SentencePiece::UNUSED);
Model model(model_proto);
model.SetEncoderVersion(encoder_version_);
Lattice lattice;
lattice.SetSentence("abc");
model.PopulateNodes(&lattice);
EXPECT_EQ(1, lattice.begin_nodes(0).size()); // a
EXPECT_EQ(1, lattice.begin_nodes(1).size()); // b
EXPECT_EQ(1, lattice.begin_nodes(2).size()); // c(unk)
EXPECT_EQ(3, lattice.begin_nodes(0)[0]->id);
EXPECT_EQ(4, lattice.begin_nodes(1)[0]->id);
EXPECT_EQ(0, lattice.begin_nodes(2)[0]->id);
}
TEST_P(UnigramModelTest, ModelNBestTest) {
ModelProto model_proto = MakeBaseModelProto();
AddPiece(&model_proto, "a", 0.0); // 3
AddPiece(&model_proto, "b", 0.0); // 4
AddPiece(&model_proto, "c", 0.0); // 5
AddPiece(&model_proto, "ab", 2.0); // 6
AddPiece(&model_proto, "bc", 5.0); // 7
AddPiece(&model_proto, "abc", 10.0); // 8
Model model(model_proto);
model.SetEncoderVersion(encoder_version_);
auto nbest = model.NBestEncode("", 10);
EXPECT_EQ(1, nbest.size());
EXPECT_TRUE(nbest[0].first.empty());
EXPECT_EQ(0.0, nbest[0].second);
nbest = model.NBestEncode("abc", 10);
EXPECT_EQ(4, nbest.size());
auto sample = model.SampleEncode("", 0.1);
EXPECT_EQ(0, sample.size());
sample = model.SampleEncode("abc", 0.1);
EXPECT_FALSE(sample.empty());
}
TEST_P(UnigramModelTest, EncodeTest) {
ModelProto model_proto = MakeBaseModelProto();
AddPiece(&model_proto, "ab", 0.0); // 3
AddPiece(&model_proto, "cd", -0.1); // 4
AddPiece(&model_proto, "abc", -0.2); // 5
AddPiece(&model_proto, "a", -0.3); // 6
AddPiece(&model_proto, "b", -0.4); // 7
AddPiece(&model_proto, "c", -0.5); // 8
AddPiece(&model_proto, "ABC", -0.5); // 9
AddPiece(&model_proto, "abcdabcd", -0.5); // 10
AddPiece(&model_proto, "q", -0.5); // 11
AddPiece(&model_proto, "r", -0.5); // 12
AddPiece(&model_proto, "qr", -0.5); // 13
model_proto.mutable_pieces(9)->set_type( // ABC
ModelProto::SentencePiece::USER_DEFINED);
model_proto.mutable_pieces(10)->set_type( // abcdabcd
ModelProto::SentencePiece::USER_DEFINED);
model_proto.mutable_pieces(11)->set_type( // q
ModelProto::SentencePiece::USER_DEFINED);
model_proto.mutable_pieces(12)->set_type( // r
ModelProto::SentencePiece::USER_DEFINED);
Model model(model_proto);
model.SetEncoderVersion(encoder_version_);
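// With these scores, the single piece "abc" (-0.2) outscores both
// "ab" + "c" (0.0 + -0.5) and "a" + "b" + "c" (-0.3 + -0.4 + -0.5),
// so "abc" is expected to be encoded as a single piece.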
EncodeResult result;
result = model.Encode("abc");
EXPECT_EQ(1, result.size());
EXPECT_EQ("abc", result[0].first);
result = model.Encode("AB");
EXPECT_EQ(2, result.size());
EXPECT_EQ("A", result[0].first);
EXPECT_EQ("B", result[1].first);
result = model.Encode("abcd");
EXPECT_EQ(2, result.size());
EXPECT_EQ("ab", result[0].first);
EXPECT_EQ("cd", result[1].first);
result = model.Encode("abcc");
EXPECT_EQ(2, result.size());
EXPECT_EQ("abc", result[0].first);
EXPECT_EQ("c", result[1].first);
result = model.Encode("xabcabaabcdd");
EXPECT_EQ(7, result.size());
EXPECT_EQ("x", result[0].first);
EXPECT_EQ("abc", result[1].first);
EXPECT_EQ("ab", result[2].first);
EXPECT_EQ("a", result[3].first);
EXPECT_EQ("ab", result[4].first);
EXPECT_EQ("cd", result[5].first);
EXPECT_EQ("d", result[6].first);
// All characters are unknown.
result = model.Encode("xyz東京");
EXPECT_EQ(5, result.size());
EXPECT_EQ("x", result[0].first);
EXPECT_EQ("y", result[1].first);
EXPECT_EQ("z", result[2].first);
EXPECT_EQ("東", result[3].first);
EXPECT_EQ("京", result[4].first);
// User defined
result = model.Encode("ABC");
EXPECT_EQ(1, result.size());
EXPECT_EQ("ABC", result[0].first);
result = model.Encode("abABCcd");
EXPECT_EQ(3, result.size());
EXPECT_EQ("ab", result[0].first);
EXPECT_EQ("ABC", result[1].first);
EXPECT_EQ("cd", result[2].first);
// middle "abcdabcd" is user defined.
result = model.Encode("ababcdabcdcd");
EXPECT_EQ(3, result.size());
EXPECT_EQ("ab", result[0].first);
EXPECT_EQ("abcdabcd", result[1].first);
EXPECT_EQ("cd", result[2].first);
result = model.Encode("abqrcd");
EXPECT_EQ(4, result.size());
EXPECT_EQ("ab", result[0].first);
EXPECT_EQ("q", result[1].first);
EXPECT_EQ("r", result[2].first);
EXPECT_EQ("cd", result[3].first);
}
TEST_P(UnigramModelTest, EncodeWithUnusedTest) {
ModelProto model_proto = MakeBaseModelProto();
AddPiece(&model_proto, "abcd", 10.0); // 3
AddPiece(&model_proto, "abc", 5.0); // 4
AddPiece(&model_proto, "ab", 2.0); // 5
AddPiece(&model_proto, "cd", 1.0); // 6
AddPiece(&model_proto, "a", 0.0); // 7
AddPiece(&model_proto, "b", 0.0); // 8
AddPiece(&model_proto, "c", 0.0); // 9
AddPiece(&model_proto, "d", 0.0); // 10
// No unused.
{
Model model(model_proto);
model.SetEncoderVersion(encoder_version_);
const auto result = model.Encode("abcd");
EXPECT_EQ(1, result.size());
EXPECT_EQ("abcd", result[0].first);
}
{
model_proto.mutable_pieces(3)->set_type(ModelProto::SentencePiece::UNUSED);
Model model(model_proto);
model.SetEncoderVersion(encoder_version_);
const auto result = model.Encode("abcd");
EXPECT_EQ(2, result.size());
EXPECT_EQ("abc", result[0].first);
EXPECT_EQ("d", result[1].first);
}
{
model_proto.mutable_pieces(3)->set_type(ModelProto::SentencePiece::UNUSED);
model_proto.mutable_pieces(5)->set_type(ModelProto::SentencePiece::UNUSED);
Model model(model_proto);
model.SetEncoderVersion(encoder_version_);
const auto result = model.Encode("abcd");
EXPECT_EQ(2, result.size());
EXPECT_EQ("abc", result[0].first);
EXPECT_EQ("d", result[1].first);
}
{
// This is different from BPE segmentation.
// Unigram language model simply finds the best path without unused nodes.
model_proto.mutable_pieces(3)->set_type(ModelProto::SentencePiece::UNUSED);
model_proto.mutable_pieces(4)->set_type(ModelProto::SentencePiece::UNUSED);
model_proto.mutable_pieces(5)->set_type(ModelProto::SentencePiece::NORMAL);
Model model(model_proto);
model.SetEncoderVersion(encoder_version_);
const auto result = model.Encode("abcd");
EXPECT_EQ(2, result.size());
EXPECT_EQ("ab", result[0].first);
EXPECT_EQ("cd", result[1].first);
}
}
TEST_P(UnigramModelTest, VerifyOutputsEquivalent) {
ModelProto model_proto = MakeBaseModelProto();
AddPiece(&model_proto, "abcd", 10.0); // 3
AddPiece(&model_proto, "abc", 5.0); // 4
AddPiece(&model_proto, "ab", 6.0); // 5
AddPiece(&model_proto, "cd", 4.0); // 6
AddPiece(&model_proto, "a", 4.0); // 7
AddPiece(&model_proto, "b", 1.9); // 8
AddPiece(&model_proto, "c", 2.0); // 9
AddPiece(&model_proto, "d", 1.0); // 10
Model model(model_proto);
model.SetEncoderVersion(encoder_version_);
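// The cases below suggest that two outputs are equivalent when their total
// scores match: "abcd" (10.0) equals "ab cd" (6.0 + 4.0), while "ab" (6.0)
// differs from "a b" (4.0 + 1.9).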
// Equivalent outputs.
EXPECT_TRUE(model.VerifyOutputsEquivalent("", ""));
EXPECT_TRUE(model.VerifyOutputsEquivalent("a b", "a b"));
EXPECT_TRUE(model.VerifyOutputsEquivalent("abcd", "ab cd"));
// Inequivalent outputs.
EXPECT_FALSE(model.VerifyOutputsEquivalent("a", "a b"));
EXPECT_FALSE(model.VerifyOutputsEquivalent("ab", "a b"));
}
INSTANTIATE_TEST_SUITE_P(ParametrizedUnigramModelTests,
UnigramModelTest,
test::ValuesIn(GetEncoderVersions()));
} // namespace unigram
} // namespace sentencepiece