// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cstdio>
#include <string>
#include "bpe_model.h"
#include "model_interface.h"
#include "testharness.h"
namespace sentencepiece {
namespace bpe {
namespace {
// Builds the ModelProto every test in this file starts from. It contains
// exactly three pieces, appended in this order (so they occupy indices 0-2):
//   0: "<unk>"  type UNKNOWN
//   1: "<s>"    type CONTROL
//   2: "</s>"   type CONTROL
// No score is set on any of them.
ModelProto MakeBaseModelProto() {
  ModelProto base;

  auto *unknown = base.add_pieces();
  unknown->set_type(ModelProto::SentencePiece::UNKNOWN);
  unknown->set_piece("<unk>");

  auto *first_ctrl = base.add_pieces();
  first_ctrl->set_type(ModelProto::SentencePiece::CONTROL);
  first_ctrl->set_piece("<s>");

  auto *second_ctrl = base.add_pieces();
  second_ctrl->set_type(ModelProto::SentencePiece::CONTROL);
  second_ctrl->set_piece("</s>");

  return base;
}
// Appends one piece to the end of `model_proto`.
//
//   model_proto: proto to extend; must not be null (it is dereferenced).
//   piece:       surface string of the new piece.
//   score:       score stored on the new piece (defaults to 0.0).
//
// The piece type is not set here; callers that need a specific type change it
// afterwards through mutable_pieces(index)->set_type(...).
void AddPiece(ModelProto *model_proto, const std::string &piece,
              float score = 0.0) {
  auto &entry = *model_proto->add_pieces();
  entry.set_score(score);
  entry.set_piece(piece);
}
// Checks Model::Encode() on a small vocabulary that mixes regular pieces with
// user-defined ones.
TEST(BPEModelTest, EncodeTest) {
  // Indices 0-2 are taken by the base pieces; the trailing comment on each
  // line is the index of the piece being added.
  ModelProto model_proto = MakeBaseModelProto();

  AddPiece(&model_proto, "ab", 0.0);         // 3
  AddPiece(&model_proto, "cd", -0.1);        // 4
  AddPiece(&model_proto, "abc", -0.2);       // 5
  AddPiece(&model_proto, "a", -0.3);         // 6
  AddPiece(&model_proto, "b", -0.4);         // 7
  AddPiece(&model_proto, "c", -0.5);         // 8
  AddPiece(&model_proto, "ABC", -0.5);       // 9
  AddPiece(&model_proto, "abcdabcd", -0.5);  // 10
  AddPiece(&model_proto, "q", -0.5);         // 11
  AddPiece(&model_proto, "r", -0.5);         // 12
  AddPiece(&model_proto, "qr", -0.5);        // 13

  // "ABC" (9), "abcdabcd" (10), "q" (11) and "r" (12) are user defined.
  for (const int id : {9, 10, 11, 12}) {
    model_proto.mutable_pieces(id)->set_type(
        ModelProto::SentencePiece::USER_DEFINED);
  }

  const Model model(model_proto);

  // Empty input yields no pieces.
  {
    const auto pieces = model.Encode("");
    EXPECT_TRUE(pieces.empty());
  }

  {
    const auto pieces = model.Encode("abc");
    EXPECT_EQ(1, pieces.size());
    EXPECT_EQ("abc", pieces[0].first);
  }

  // Neither "A" nor "B" is in the vocabulary; each comes back on its own.
  {
    const auto pieces = model.Encode("AB");
    EXPECT_EQ(2, pieces.size());
    EXPECT_EQ("A", pieces[0].first);
    EXPECT_EQ("B", pieces[1].first);
  }

  {
    const auto pieces = model.Encode("abcd");
    EXPECT_EQ(2, pieces.size());
    EXPECT_EQ("ab", pieces[0].first);
    EXPECT_EQ("cd", pieces[1].first);
  }

  {
    const auto pieces = model.Encode("abcc");
    EXPECT_EQ(2, pieces.size());
    EXPECT_EQ("abc", pieces[0].first);
    EXPECT_EQ("c", pieces[1].first);
  }

  {
    const auto pieces = model.Encode("xabcabaabcdd");
    EXPECT_EQ(7, pieces.size());
    EXPECT_EQ("x", pieces[0].first);
    EXPECT_EQ("abc", pieces[1].first);
    EXPECT_EQ("ab", pieces[2].first);
    EXPECT_EQ("a", pieces[3].first);
    EXPECT_EQ("ab", pieces[4].first);
    EXPECT_EQ("cd", pieces[5].first);
    EXPECT_EQ("d", pieces[6].first);
  }

  // Every character is missing from the vocabulary: one piece per character,
  // including the multi-byte ones.
  {
    const auto pieces = model.Encode("xyz東京");
    EXPECT_EQ(5, pieces.size());
    EXPECT_EQ("x", pieces[0].first);
    EXPECT_EQ("y", pieces[1].first);
    EXPECT_EQ("z", pieces[2].first);
    EXPECT_EQ("東", pieces[3].first);
    EXPECT_EQ("京", pieces[4].first);
  }

  // A user-defined piece on its own.
  {
    const auto pieces = model.Encode("ABC");
    EXPECT_EQ(1, pieces.size());
    EXPECT_EQ("ABC", pieces[0].first);
  }

  // A user-defined piece surrounded by regular pieces.
  {
    const auto pieces = model.Encode("abABCcd");
    EXPECT_EQ(3, pieces.size());
    EXPECT_EQ("ab", pieces[0].first);
    EXPECT_EQ("ABC", pieces[1].first);
    EXPECT_EQ("cd", pieces[2].first);
  }

  // The "abcdabcd" in the middle is user defined and must stay in one piece.
  {
    const auto pieces = model.Encode("ababcdabcdcd");
    EXPECT_EQ(3, pieces.size());
    EXPECT_EQ("ab", pieces[0].first);
    EXPECT_EQ("abcdabcd", pieces[1].first);
    EXPECT_EQ("cd", pieces[2].first);
  }

  // "qr" exists as a regular piece, yet the user-defined "q" and "r" are
  // expected to come back as two separate pieces.
  {
    const auto pieces = model.Encode("abqrcd");
    EXPECT_EQ(4, pieces.size());
    EXPECT_EQ("ab", pieces[0].first);
    EXPECT_EQ("q", pieces[1].first);
    EXPECT_EQ("r", pieces[2].first);
    EXPECT_EQ("cd", pieces[3].first);
  }
}
// Checks which merge wins when several segmentations of the input are
// possible, plus the handling of an incomplete UTF-8 byte sequence.
TEST(BPEModelTest, EncodeAmbiguousTest) {
  ModelProto model_proto = MakeBaseModelProto();

  AddPiece(&model_proto, "aa", -0.1);
  AddPiece(&model_proto, "bb", -0.2);
  AddPiece(&model_proto, "ab", -0.3);
  AddPiece(&model_proto, "a", -0.4);
  AddPiece(&model_proto, "b", -0.5);

  const Model model(model_proto);

  // Of the two possible "aa" merges, the leftmost one is taken.
  {
    const auto pieces = model.Encode("aaa");
    EXPECT_EQ(2, pieces.size());
    EXPECT_EQ("aa", pieces[0].first);
    EXPECT_EQ("a", pieces[1].first);
  }

  // "bb" is merged before "ab" gets a chance.
  {
    const auto pieces = model.Encode("aabb");
    EXPECT_EQ(2, pieces.size());
    EXPECT_EQ("aa", pieces[0].first);
    EXPECT_EQ("bb", pieces[1].first);
  }

  // Same here: "bb" is merged before "ab".
  {
    const auto pieces = model.Encode("aaabbb");
    EXPECT_EQ(4, pieces.size());
    EXPECT_EQ("aa", pieces[0].first);
    EXPECT_EQ("a", pieces[1].first);
    EXPECT_EQ("bb", pieces[2].first);
    EXPECT_EQ("b", pieces[3].first);
  }

  {
    const auto pieces = model.Encode("aaaba");
    EXPECT_EQ(3, pieces.size());
    EXPECT_EQ("aa", pieces[0].first);
    EXPECT_EQ("ab", pieces[1].first);
    EXPECT_EQ("a", pieces[2].first);
  }

  // Keep only the first byte of "あ": with a UTF-8 encoded source file this is
  // an incomplete (broken) UTF-8 sequence. It must come back unchanged as a
  // single piece.
  {
    const std::string broken_utf8(std::string("あ"), 0, 1);
    const auto pieces = model.Encode(broken_utf8);
    EXPECT_EQ(1, pieces.size());
    EXPECT_EQ(broken_utf8, pieces[0].first);
  }
}
// NBestEncode() on the BPE model is expected to hand back a result equal to a
// default-constructed NBestEncodeResult, whatever the input and n-best size.
TEST(BPEModelTest, NotSupportedTest) {
  ModelProto model_proto = MakeBaseModelProto();
  const Model model(model_proto);

  const NBestEncodeResult expected;
  const auto actual = model.NBestEncode("test", 10);
  EXPECT_EQ(expected, actual);
}
// Checks how Encode("abcd") changes as pieces are progressively marked UNUSED.
// Each scope below mutates the same proto and builds a fresh Model from it, so
// the type changes accumulate from one scope to the next.
TEST(BPEModelTest, EncodeWithUnusedTest) {
  using Piece = ModelProto::SentencePiece;

  ModelProto model_proto = MakeBaseModelProto();

  AddPiece(&model_proto, "abcd", 10.0);  // 3
  AddPiece(&model_proto, "abc", 5.0);    // 4
  AddPiece(&model_proto, "ab", 2.0);     // 5
  AddPiece(&model_proto, "cd", 1.0);     // 6
  AddPiece(&model_proto, "a", 0.0);      // 7
  AddPiece(&model_proto, "b", 0.0);      // 8
  AddPiece(&model_proto, "c", 0.0);      // 9
  AddPiece(&model_proto, "d", 0.0);      // 10

  // Nothing is unused: the whole input is a single piece.
  {
    const Model model(model_proto);
    const auto pieces = model.Encode("abcd");
    EXPECT_EQ(1, pieces.size());
    EXPECT_EQ("abcd", pieces[0].first);
  }

  // "abcd" is unused: it is split back into "abc" + "d".
  {
    model_proto.mutable_pieces(3)->set_type(Piece::UNUSED);

    const Model model(model_proto);
    const auto pieces = model.Encode("abcd");
    EXPECT_EQ(2, pieces.size());
    EXPECT_EQ("abc", pieces[0].first);
    EXPECT_EQ("d", pieces[1].first);
  }

  // "abcd" and "ab" are unused: the parent rule "abc" stays alive even though
  // its child "ab" is unused.
  {
    model_proto.mutable_pieces(3)->set_type(Piece::UNUSED);
    model_proto.mutable_pieces(5)->set_type(Piece::UNUSED);

    const Model model(model_proto);
    const auto pieces = model.Encode("abcd");
    EXPECT_EQ(2, pieces.size());
    EXPECT_EQ("abc", pieces[0].first);
    EXPECT_EQ("d", pieces[1].first);
  }

  // Tricky case: "abcd" and "abc" are unused, "ab" is back to NORMAL.
  // Even though "cd" is alive, it does not show up in the output because it
  // is never merged during the segmentation step.
  //   Segmentation:   a|b|c|d => ab|c|d => abc|d => abcd
  //   Resegmentation: abcd => abc|d => ab|c|d  ("abcd" and "abc" are unused)
  {
    model_proto.mutable_pieces(3)->set_type(Piece::UNUSED);
    model_proto.mutable_pieces(4)->set_type(Piece::UNUSED);
    model_proto.mutable_pieces(5)->set_type(Piece::NORMAL);

    const Model model(model_proto);
    const auto pieces = model.Encode("abcd");
    EXPECT_EQ(3, pieces.size());
    EXPECT_EQ("ab", pieces[0].first);
    EXPECT_EQ("c", pieces[1].first);
    EXPECT_EQ("d", pieces[2].first);
  }
}
// Checks Model::SampleEncode(): with alpha == 0.0 every call must produce the
// same segmentation, while any of the tested non-zero alphas must produce
// more than one distinct segmentation over many trials.
TEST(SampleModelTest, EncodeTest) {
  ModelProto model_proto = MakeBaseModelProto();

  AddPiece(&model_proto, "ab", 0.0);
  AddPiece(&model_proto, "cd", -0.1);
  AddPiece(&model_proto, "abc", -0.2);
  AddPiece(&model_proto, "abcd", -0.3);

  // No regularization: plain Encode() merges all the way up to "abcd".
  {
    const Model model(model_proto);
    const auto result = model.Encode("abcd");
    EXPECT_EQ(1, result.size());
    EXPECT_EQ("abcd", result[0].first);
  }

  {
    // Joins the pieces of a segmentation with single spaces so that the
    // segmentation can be used as a histogram key.
    auto get_tokens = [](const EncodeResult &result) {
      std::string out;
      for (const auto &r : result) {
        // Put the separator only between pieces. The check has to look at the
        // string built so far: `result` is never empty inside this loop, so
        // testing it would prepend a space to every key.
        if (!out.empty()) {
          out += ' ';
        }
        out += std::string(r.first);
      }
      return out;
    };

    const Model model(model_proto);
    const std::vector<double> kAlpha = {0.0, 0.1, 0.5, 0.7, 0.9};
    for (const auto alpha : kAlpha) {
      constexpr int kTrial = 100000;

      // Distinct segmentation -> number of times it was sampled.
      std::map<std::string, int> freq;
      for (int n = 0; n < kTrial; ++n) {
        freq[get_tokens(
            model.SampleEncode("abcd", static_cast<float>(alpha)))]++;
      }

      if (alpha == 0.0) {
        // alpha of zero: only one segmentation may ever be observed.
        EXPECT_EQ(1, freq.size());
      } else {
        // Non-zero alpha: sampling must produce some variety.
        EXPECT_GT(freq.size(), 1);
      }

      // Sanity check: every trial landed in exactly one bucket.
      int num = 0;
      for (const auto &it : freq) {
        num += it.second;
      }
      EXPECT_EQ(num, kTrial);
    }
  }
}
} // namespace
} // namespace bpe
} // namespace sentencepiece