// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.!
#include "sentencepiece_processor.h"
#include <utility>
#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "builder.h"
#include "filesystem.h"
#include "model_interface.h"
#include "normalizer.h"
#include "sentencepiece.pb.h"
#include "sentencepiece_model.pb.h"
#include "sentencepiece_trainer.h"
#include "testharness.h"
#include "util.h"
namespace sentencepiece {
// Space symbol
#define WS "\xe2\x96\x81"
class MockModel : public ModelInterface {
public:
void SetEncodeResult(absl::string_view input, const EncodeResult &output) {
input_ = input;
output_ = output;
}
void SetNBestEncodeResult(absl::string_view input,
const NBestEncodeResult &output) {
input_ = input;
nbest_output_ = output;
}
EncodeResult Encode(absl::string_view normalized) const {
EXPECT_EQ(normalized, input_);
return output_;
}
EncodeResult SampleEncode(absl::string_view normalized, float alpha) const {
EXPECT_EQ(normalized, input_);
return output_;
}
NBestEncodeResult NBestEncode(absl::string_view normalized,
int nbest_size) const {
EXPECT_EQ(normalized, input_);
return nbest_output_;
}
bool IsSampleEncodeAvailable() const override { return true; }
bool IsNBestEncodeAvailable() const override { return true; }
bool IsControl(int id) const { return id == 1 || id == 2; }
bool IsUnknown(int id) const { return id == 0; }
int GetPieceSize() const { return 10; }
int PieceToId(absl::string_view piece) const { return 0; }
const std::string &IdToPiece(int id) const { return kEmptyString; }
float GetScore(int id) const { return 0.0; }
private:
absl::string_view input_;
EncodeResult output_;
NBestEncodeResult nbest_output_;
const std::string kEmptyString;
};
class ByteFallbackMockModel : public MockModel {
public:
bool ByteFallbackEnabled() const override { return true; }
};
std::vector<std::string> GetSpVec(const EncodeResult &pieces) {
std::vector<std::string> sps;
for (const auto &p : pieces) {
sps.emplace_back(std::string(p.first));
}
return sps;
}
std::vector<int> GetIdVec(const EncodeResult &pieces) {
std::vector<int> ids;
for (const auto &p : pieces) {
ids.emplace_back(p.second);
}
return ids;
}
std::vector<std::string> GetSpVec(const SentencePieceText &spt) {
std::vector<std::string> sps;
for (auto &sp : spt.pieces()) {
sps.emplace_back(sp.piece());
}
return sps;
}
NormalizerSpec MakeDefaultNormalizerSpec() {
return SentencePieceTrainer::GetNormalizerSpec("nmt_nfkc");
}
TEST(SentencepieceProcessorTest, StatusTest) {
SentencePieceProcessor sp;
EXPECT_FALSE(sp.status().ok());
auto mock = absl::make_unique<MockModel>();
sp.SetModel(std::move(mock));
EXPECT_FALSE(sp.status().ok());
}
TEST(SentencepieceProcessorTest, EncodeTest) {
const absl::string_view kInput = WS "ABC" WS "DEF";
SentencePieceProcessor sp;
const auto normalization_spec = MakeDefaultNormalizerSpec();
{
auto mock = absl::make_unique<MockModel>();
const EncodeResult result = {
{WS "ABC", 3}, {WS "DE", 4}, {"F", 0}, {"</s>", 2}};
mock->SetEncodeResult(kInput, result);
sp.SetModel(std::move(mock));
sp.SetNormalizer(
absl::make_unique<normalizer::Normalizer>(normalization_spec));
std::vector<std::string> output;
EXPECT_TRUE(sp.Encode("ABC DEF", &output).ok());
EXPECT_EQ(GetSpVec(result), output);
std::vector<int> ids;
EXPECT_TRUE(sp.Encode("ABC DEF", &ids).ok());
EXPECT_EQ(GetIdVec(result), ids);
SentencePieceText spt;
EXPECT_TRUE(sp.Encode("ABC DEF", &spt).ok());
EXPECT_EQ(4, spt.pieces_size());
for (int i = 0; i < 4; ++i) {
EXPECT_EQ(result[i].first, spt.pieces(i).piece());
}
SentencePieceText spt2;
EXPECT_TRUE(spt2.ParseFromString(sp.EncodeAsSerializedProto("ABC DEF")));
EXPECT_EQ(spt.SerializeAsString(), spt2.SerializeAsString());
EXPECT_EQ("ABC", spt.pieces(0).surface());
EXPECT_EQ(" DE", spt.pieces(1).surface());
EXPECT_EQ("F", spt.pieces(2).surface());
EXPECT_EQ("", spt.pieces(3).surface()); // </s>
EXPECT_EQ(3, spt.pieces(0).id());
EXPECT_EQ(4, spt.pieces(1).id());
EXPECT_EQ(0, spt.pieces(2).id());
EXPECT_EQ(2, spt.pieces(3).id());
EXPECT_EQ(0, spt.pieces(0).begin());
EXPECT_EQ(3, spt.pieces(0).end());
EXPECT_EQ(3, spt.pieces(1).begin());
EXPECT_EQ(6, spt.pieces(1).end());
EXPECT_EQ(6, spt.pieces(2).begin());
EXPECT_EQ(7, spt.pieces(2).end());
EXPECT_EQ(7, spt.pieces(3).begin());
EXPECT_EQ(7, spt.pieces(3).end());
}
// Unknown sequences.
{
auto mock = absl::make_unique<MockModel>();
const EncodeResult result = {
{WS "ABC", 3}, {WS "D", 4}, {"E", 0}, {"F", 0}, {"</s>", 2}};
const EncodeResult expected = {
{WS "ABC", 3}, {WS "D", 4}, {"EF", 0}, {"</s>", 2}};
mock->SetEncodeResult(kInput, result);
sp.SetModel(std::move(mock));
sp.SetNormalizer(
absl::make_unique<normalizer::Normalizer>(normalization_spec));
std::vector<std::string> output;
EXPECT_TRUE(sp.Encode("ABC DEF", &output).ok());
EXPECT_EQ(GetSpVec(expected), output);
std::vector<int> ids;
EXPECT_TRUE(sp.Encode("ABC DEF", &ids).ok());
EXPECT_EQ(GetIdVec(expected), ids);
SentencePieceText spt;
EXPECT_TRUE(sp.Encode("ABC DEF", &spt).ok());
EXPECT_EQ(4, spt.pieces_size());
for (int i = 0; i < 4; ++i) {
EXPECT_EQ(expected[i].first, spt.pieces(i).piece());
}
EXPECT_EQ("ABC", spt.pieces(0).surface());
EXPECT_EQ(" D", spt.pieces(1).surface());
EXPECT_EQ("EF", spt.pieces(2).surface());
EXPECT_EQ("", spt.pieces(3).surface()); // </s>
EXPECT_EQ(3, spt.pieces(0).id());
EXPECT_EQ(4, spt.pieces(1).id());
EXPECT_EQ(0, spt.pieces(2).id());
EXPECT_EQ(2, spt.pieces(3).id());
EXPECT_EQ(0, spt.pieces(0).begin());
EXPECT_EQ(3, spt.pieces(0).end());
EXPECT_EQ(3, spt.pieces(1).begin());
EXPECT_EQ(5, spt.pieces(1).end());
EXPECT_EQ(5, spt.pieces(2).begin());
EXPECT_EQ(7, spt.pieces(2).end());
EXPECT_EQ(7, spt.pieces(3).begin());
EXPECT_EQ(7, spt.pieces(3).end());
}
// Byte-fallback.
{
const absl::string_view kInput2 = WS "ABC" WS "DEFあ";
auto mock = absl::make_unique<ByteFallbackMockModel>();
const EncodeResult result = {{WS "ABC", 3}, {WS "D", 4}, {"E", 0},
{"F", 0}, {"あ", 0}, {"</s>", 2}};
// "E" -> 0x45
// "F" -> 0x46
// "あ" -> 0xe38182
const EncodeResult expected = {{WS "ABC", 3}, {WS "D", 4}, {"<0x45>", 0},
{"<0x46>", 0}, {"<0xE3>", 0}, {"<0x81>", 0},
{"<0x82>", 0}, {"</s>", 2}};
mock->SetEncodeResult(kInput2, result);
sp.SetModel(std::move(mock));
sp.SetNormalizer(
absl::make_unique<normalizer::Normalizer>(normalization_spec));
std::vector<std::string> output;
EXPECT_TRUE(sp.Encode("ABC DEFあ", &output).ok());
EXPECT_EQ(GetSpVec(expected), output);
std::vector<int> ids;
EXPECT_TRUE(sp.Encode("ABC DEFあ", &ids).ok());
EXPECT_EQ(GetIdVec(expected), ids);
SentencePieceText spt;
EXPECT_TRUE(sp.Encode("ABC DEFあ", &spt).ok());
EXPECT_EQ(8, spt.pieces_size());
for (int i = 0; i < 8; ++i) {
EXPECT_EQ(expected[i].first, spt.pieces(i).piece());
}
EXPECT_EQ("ABC", spt.pieces(0).surface());
EXPECT_EQ(" D", spt.pieces(1).surface());
EXPECT_EQ("E", spt.pieces(2).surface());
EXPECT_EQ("F", spt.pieces(3).surface());
EXPECT_EQ("", spt.pieces(4).surface()); // あ
EXPECT_EQ("", spt.pieces(5).surface()); // あ
EXPECT_EQ("あ", spt.pieces(6).surface()); // あ
EXPECT_EQ("", spt.pieces(7).surface()); // </s>
EXPECT_EQ(3, spt.pieces(0).id());
EXPECT_EQ(4, spt.pieces(1).id());
EXPECT_EQ(0, spt.pieces(2).id());
EXPECT_EQ(0, spt.pieces(3).id());
EXPECT_EQ(0, spt.pieces(4).id());
EXPECT_EQ(0, spt.pieces(5).id());
EXPECT_EQ(0, spt.pieces(6).id());
EXPECT_EQ(2, spt.pieces(7).id());
EXPECT_EQ(0, spt.pieces(0).begin());
EXPECT_EQ(3, spt.pieces(0).end());
EXPECT_EQ(3, spt.pieces(1).begin());
EXPECT_EQ(5, spt.pieces(1).end());
EXPECT_EQ(5, spt.pieces(2).begin());
EXPECT_EQ(6, spt.pieces(2).end());
EXPECT_EQ(6, spt.pieces(3).begin());
EXPECT_EQ(7, spt.pieces(3).end());
EXPECT_EQ(7, spt.pieces(4).begin()); // あ
EXPECT_EQ(7, spt.pieces(4).end());
EXPECT_EQ(7, spt.pieces(5).begin()); // あ
EXPECT_EQ(7, spt.pieces(5).end());
EXPECT_EQ(7, spt.pieces(6).begin()); // あ
EXPECT_EQ(10, spt.pieces(6).end());
EXPECT_EQ(10, spt.pieces(7).begin()); // </s>
EXPECT_EQ(10, spt.pieces(7).end());
}
// Crash if
// ModelInterface::Encode() returns shorter results.
{
auto mock = absl::make_unique<MockModel>();
const EncodeResult result = {{WS "ABC", 3}};
mock->SetEncodeResult(kInput, result);
sp.SetModel(std::move(mock));
sp.SetNormalizer(
absl::make_unique<normalizer::Normalizer>(normalization_spec));
SentencePieceText spt;
// Expects crash.
EXPECT_FALSE(sp.Encode("ABC DEF", &spt).ok());
}
// Crash if
// ModelInterface::Encode() returns longer results.
{
auto mock = absl::make_unique<MockModel>();
const EncodeResult result = {
{WS "ABC", 3}, {WS "DE", 4}, {"F", 5}, {"G", 6}};
mock->SetEncodeResult(kInput, result);
sp.SetModel(std::move(mock));
sp.SetNormalizer(
absl::make_unique<normalizer::Normalizer>(normalization_spec));
SentencePieceText spt;
// Expects crash.
EXPECT_FALSE(sp.Encode("ABC DEF", &spt).ok());
}
// Crash if
// ModelInterface::Encode() returns an empty piece.
{
auto mock = absl::make_unique<MockModel>();
const EncodeResult result = {
{WS "ABC", 3}, {WS "DE", 4}, {"", 5}, {"F", 6}};
mock->SetEncodeResult(kInput, result);
sp.SetModel(std::move(mock));
sp.SetNormalizer(
absl::make_unique<normalizer::Normalizer>(normalization_spec));
SentencePieceText spt;
// Expects crash.
EXPECT_FALSE(sp.Encode("ABC DEF", &spt).ok());
}
// Halfwidth to Fullwidith katakana normalization.
{
auto mock = absl::make_unique<MockModel>();
const EncodeResult result = {{WS "グー", 3}, {"グル", 4}, {"</s>", 2}};
const absl::string_view input = WS "グーグル";
mock->SetEncodeResult(input, result);
sp.SetModel(std::move(mock));
std::vector<std::string> output;
EXPECT_TRUE(sp.Encode("グーグル", &output).ok());
EXPECT_EQ(GetSpVec(result), output);
SentencePieceText spt;
EXPECT_TRUE(sp.Encode("グーグル", &spt).ok());
EXPECT_EQ(3, spt.pieces_size());
for (int i = 0; i < 3; ++i) {
EXPECT_EQ(result[i].first, spt.pieces(i).piece());
}
EXPECT_EQ("グー", spt.pieces(0).surface());
EXPECT_EQ("グル", spt.pieces(1).surface());
EXPECT_EQ("", spt.pieces(2).surface());
EXPECT_EQ(3, spt.pieces(0).id());
EXPECT_EQ(4, spt.pieces(1).id());
EXPECT_EQ(2, spt.pieces(2).id());
EXPECT_EQ(0, spt.pieces(0).begin());
EXPECT_EQ(9, spt.pieces(0).end());
EXPECT_EQ(9, spt.pieces(1).begin());
EXPECT_EQ(18, spt.pieces(1).end());
EXPECT_EQ(18, spt.pieces(2).begin()); // </s>
EXPECT_EQ(18, spt.pieces(2).end());
}
// One to many normalization.
{
auto mock = absl::make_unique<MockModel>();
const EncodeResult result = {{WS "株式", 3}, {"会社", 4}, {"</s>", 2}};
const absl::string_view input = WS "株式会社";
mock->SetEncodeResult(input, result);
sp.SetModel(std::move(mock));
std::vector<std::string> output;
EXPECT_TRUE(sp.Encode("㍿", &output).ok());
EXPECT_EQ(GetSpVec(result), output);
SentencePieceText spt;
EXPECT_TRUE(sp.Encode("㍿", &spt).ok());
EXPECT_EQ(3, spt.pieces_size());
for (int i = 0; i < 3; ++i) {
EXPECT_EQ(result[i].first, spt.pieces(i).piece());
}
EXPECT_EQ("", spt.pieces(0).surface());
EXPECT_EQ("㍿", spt.pieces(1).surface());
EXPECT_EQ("", spt.pieces(2).surface());
EXPECT_EQ(3, spt.pieces(0).id());
EXPECT_EQ(4, spt.pieces(1).id());
EXPECT_EQ(2, spt.pieces(2).id());
EXPECT_EQ(0, spt.pieces(0).begin()); // 株式
EXPECT_EQ(0, spt.pieces(0).end());
EXPECT_EQ(0, spt.pieces(1).begin()); // 会社
EXPECT_EQ(3, spt.pieces(1).end());
EXPECT_EQ(3, spt.pieces(2).begin()); // </s>
EXPECT_EQ(3, spt.pieces(2).end());
}
}
TEST(SentencepieceProcessorTest, NBestEncodeTest) {
const std::string kInput = WS "ABC" WS "DEF";
SentencePieceProcessor sp;
const auto normalization_spec = MakeDefaultNormalizerSpec();
auto mock = absl::make_unique<MockModel>();
const NBestEncodeResult result = {
{{{WS "ABC", 3}, {WS "DE", 4}, {"F", 0}, {"</s>", 2}},
static_cast<float>(1.0)},
{{{WS "AB", 5}, {WS "CD", 6}, {"EF", 7}, {"</s>", 2}},
static_cast<float>(0.9)}};
mock->SetNBestEncodeResult(kInput, result);
sp.SetModel(std::move(mock));
sp.SetNormalizer(
absl::make_unique<normalizer::Normalizer>(normalization_spec));
std::vector<std::vector<std::string>> output;
EXPECT_TRUE(sp.NBestEncode("ABC DEF", 2, &output).ok());
EXPECT_EQ(2, output.size());
EXPECT_EQ(GetSpVec(result[0].first), output[0]);
EXPECT_EQ(GetSpVec(result[1].first), output[1]);
std::vector<std::vector<int>> ids;
EXPECT_TRUE(sp.NBestEncode("ABC DEF", 2, &ids).ok());
EXPECT_EQ(2, ids.size());
EXPECT_EQ(GetIdVec(result[0].first), ids[0]);
EXPECT_EQ(GetIdVec(result[1].first), ids[1]);
NBestSentencePieceText spt;
EXPECT_TRUE(sp.NBestEncode("ABC DEF", 2, &spt).ok());
EXPECT_EQ(2, spt.nbests_size());
EXPECT_EQ(4, spt.nbests(0).pieces_size());
EXPECT_EQ(4, spt.nbests(1).pieces_size());
EXPECT_NEAR(result[0].second, spt.nbests(0).score(), 0.001);
EXPECT_NEAR(result[1].second, spt.nbests(1).score(), 0.001);
for (int i = 0; i < 4; ++i) {
EXPECT_EQ(result[0].first[i].first, spt.nbests(0).pieces(i).piece());
EXPECT_EQ(result[1].first[i].first, spt.nbests(1).pieces(i).piece());
}
NBestSentencePieceText spt2;
EXPECT_TRUE(
spt2.ParseFromString(sp.NBestEncodeAsSerializedProto("ABC DEF", 2)));
EXPECT_EQ(spt.SerializeAsString(), spt2.SerializeAsString());
auto mock_empty = absl::make_unique<MockModel>();
mock_empty->SetNBestEncodeResult(kInput, {});
sp.SetModel(std::move(mock_empty));
EXPECT_FALSE(sp.NBestEncode("ABC DEF", 2, &output).ok());
}
TEST(SentencepieceProcessorTest, SampleEncodeTest) {
const std::string kInput = WS "ABC" WS "DEF";
SentencePieceProcessor sp;
const auto normalization_spec = MakeDefaultNormalizerSpec();
auto mock = absl::make_unique<MockModel>();
const EncodeResult result = {
{WS "ABC", 3}, {WS "DE", 4}, {"F", 0}, {"</s>", 2}};
const NBestEncodeResult nbest_result = {
{{{WS "ABC", 3}, {WS "DE", 4}, {"F", 0}, {"</s>", 2}},
static_cast<float>(1.0)},
{{{WS "AB", 5}, {WS "CD", 6}, {"EF", 7}, {"</s>", 2}},
static_cast<float>(0.1)}};
mock->SetNBestEncodeResult(kInput, nbest_result);
mock->SetEncodeResult(kInput, result);
sp.SetModel(std::move(mock));
sp.SetNormalizer(
absl::make_unique<normalizer::Normalizer>(normalization_spec));
std::vector<std::string> output;
EXPECT_TRUE(sp.SampleEncode("ABC DEF", -1, 0.5, &output).ok());
EXPECT_EQ(4, output.size());
EXPECT_EQ(GetSpVec(result), output);
std::vector<int> ids;
EXPECT_TRUE(sp.SampleEncode("ABC DEF", -1, 0.5, &ids).ok());
EXPECT_EQ(4, ids.size());
EXPECT_EQ(GetIdVec(result), ids);
SentencePieceText spt;
EXPECT_TRUE(sp.SampleEncode("ABC DEF", -1, 0.5, &spt).ok());
EXPECT_EQ(4, spt.pieces_size());
for (int i = 0; i < 4; ++i) {
EXPECT_EQ(result[i].first, spt.pieces(i).piece());
EXPECT_EQ(result[i].second, spt.pieces(i).id());
}
SentencePieceText spt2;
EXPECT_TRUE(spt2.ParseFromString(
sp.SampleEncodeAsSerializedProto("ABC DEF", -1, 0.5)));
EXPECT_EQ(spt.SerializeAsString(), spt2.SerializeAsString());
EXPECT_FALSE(sp.SampleEncode("ABC DEF", 1024, 0.5, &output).ok());
EXPECT_TRUE(sp.SampleEncode("ABC DEF", 0, 0.5, &output).ok());
EXPECT_TRUE(sp.SampleEncode("ABC DEF", 1, 0.5, &output).ok());
std::vector<int> freq(2, 0);
for (int i = 0; i < 5000; ++i) {
EXPECT_TRUE(sp.SampleEncode("ABC DEF", 20, 0.5, &output).ok());
EXPECT_EQ(4, output.size());
if (GetSpVec(nbest_result[0].first) == output)
freq[0]++;
else if (GetSpVec(nbest_result[1].first) == output)
freq[1]++;
else
LOG(FATAL) << "Invalid result.";
}
const float expected_prob =
std::exp(0.5 * 1.0) / (std::exp(0.5 * 1.0) + std::exp(0.5 * 0.1));
const float prob = 1.0 * freq[0] / (freq[0] + freq[1]);
EXPECT_NEAR(prob, expected_prob, 0.05);
auto mock_empty = absl::make_unique<MockModel>();
mock_empty->SetNBestEncodeResult(kInput, {});
sp.SetModel(std::move(mock_empty));
EXPECT_FALSE(sp.SampleEncode("ABC DEF", 10, 0.5, &output).ok());
}
TEST(SentencepieceProcessorTest, DecodeTest) {
class DecodeMockModel : public ModelInterface {
public:
EncodeResult Encode(absl::string_view normalized) const override {
return {};
}
int GetPieceSize() const override { return 7; }
int PieceToId(absl::string_view piece) const override {
static absl::flat_hash_map<absl::string_view, int> kMap = {
{"<unk>", 0}, {"<s>", 1}, {"</s>", 2}, {WS "ABC", 3},
{WS "DE", 4}, {"F", 5}, {"G" WS "H", 6}};
return port::FindWithDefault(kMap, piece, 0);
}
const std::string &IdToPiece(int id) const override {
static std::vector<std::string> kMap = {
"<unk>", "<s>", "</s>", WS "ABC", WS "DE", "F", "G" WS "H"};
return kMap[id];
}
bool IsUnknown(int id) const override { return (id == 0); }
bool IsControl(int id) const override { return (id == 1 || id == 2); }
bool IsByte(int id) const override { return false; }
float GetScore(int id) const override { return 0.0; }
};
const std::vector<std::string> input = {"<s>", WS "ABC", "<unk>", WS "DE",
"F", "G" WS "H", "I", "</s>"};
{
SentencePieceProcessor sp;
auto mock = absl::make_unique<DecodeMockModel>();
sp.SetModel(std::move(mock));
const auto normalization_spec = MakeDefaultNormalizerSpec();
sp.SetNormalizer(
absl::make_unique<normalizer::Normalizer>(normalization_spec));
SentencePieceText spt;
EXPECT_TRUE(sp.Decode(input, &spt).ok());
EXPECT_EQ("ABC \xE2\x81\x87 DEFG HI", spt.text());
EXPECT_EQ(8, spt.pieces_size());
for (int i = 0; i < 6; ++i) {
EXPECT_EQ(input[i], spt.pieces(i).piece());
}
EXPECT_EQ("", spt.pieces(0).surface());
EXPECT_EQ("ABC", spt.pieces(1).surface());
EXPECT_EQ(" \xE2\x81\x87 ", spt.pieces(2).surface());
EXPECT_EQ(" DE", spt.pieces(3).surface());
EXPECT_EQ("F", spt.pieces(4).surface());
EXPECT_EQ("G H", spt.pieces(5).surface());
EXPECT_EQ("I", spt.pieces(6).surface());
EXPECT_EQ("", spt.pieces(7).surface());
EXPECT_EQ(0, spt.pieces(0).begin());
EXPECT_EQ(0, spt.pieces(0).end());
EXPECT_EQ(0, spt.pieces(1).begin());
EXPECT_EQ(3, spt.pieces(1).end());
EXPECT_EQ(3, spt.pieces(2).begin());
EXPECT_EQ(8, spt.pieces(2).end());
EXPECT_EQ(8, spt.pieces(3).begin());
EXPECT_EQ(11, spt.pieces(3).end());
EXPECT_EQ(11, spt.pieces(4).begin());
EXPECT_EQ(12, spt.pieces(4).end());
EXPECT_EQ(12, spt.pieces(5).begin());
EXPECT_EQ(15, spt.pieces(5).end());
EXPECT_EQ(15, spt.pieces(6).begin());
EXPECT_EQ(16, spt.pieces(6).end());
EXPECT_EQ(16, spt.pieces(7).begin());
EXPECT_EQ(16, spt.pieces(7).end());
SentencePieceText spt2;
EXPECT_TRUE(spt2.ParseFromString(sp.DecodePiecesAsSerializedProto(input)));
EXPECT_EQ(spt.SerializeAsString(), spt2.SerializeAsString());
}
// unk_surface is not defined.
{
SentencePieceProcessor sp;
auto proto = absl::make_unique<ModelProto>();
sp.Load(std::move(proto)).IgnoreError();
auto mock = absl::make_unique<DecodeMockModel>();
sp.SetModel(std::move(mock));
const auto normalization_spec = MakeDefaultNormalizerSpec();
sp.SetNormalizer(
absl::make_unique<normalizer::Normalizer>(normalization_spec));
SentencePieceText spt;
EXPECT_TRUE(sp.Decode(input, &spt).ok());
EXPECT_EQ("ABC \xE2\x81\x87 DEFG HI", spt.text());
EXPECT_EQ(8, spt.pieces_size());
}
{
SentencePieceProcessor sp;
auto proto = absl::make_unique<ModelProto>();
proto->mutable_trainer_spec()->set_unk_surface("");
sp.Load(std::move(proto)).IgnoreError();
auto mock = absl::make_unique<DecodeMockModel>();
sp.SetModel(std::move(mock));
const auto normalization_spec = MakeDefaultNormalizerSpec();
sp.SetNormalizer(
absl::make_unique<normalizer::Normalizer>(normalization_spec));
SentencePieceText spt;
EXPECT_TRUE(sp.Decode(input, &spt).ok());
EXPECT_EQ("ABC DEFG HI", spt.text());
EXPECT_EQ(8, spt.pieces_size());
}
{
SentencePieceProcessor sp;
auto proto = absl::make_unique<ModelProto>();
proto->mutable_trainer_spec()->set_unk_surface("<UNK>");
sp.Load(std::move(proto)).IgnoreError();
auto mock = absl::make_unique<DecodeMockModel>();
sp.SetModel(std::move(mock));
const auto normalization_spec = MakeDefaultNormalizerSpec();
sp.SetNormalizer(
absl::make_unique<normalizer::Normalizer>(normalization_spec));
SentencePieceText spt;
EXPECT_TRUE(sp.Decode(input, &spt).ok());
EXPECT_EQ("ABC<UNK> DEFG HI", spt.text());
EXPECT_EQ(8, spt.pieces_size());
}
{
SentencePieceProcessor sp;
auto proto = absl::make_unique<ModelProto>();
proto->mutable_trainer_spec()->set_unk_surface("");
proto->mutable_normalizer_spec()->set_add_dummy_prefix(false);
proto->mutable_normalizer_spec()->set_remove_extra_whitespaces(false);
sp.Load(std::move(proto)).IgnoreError();
auto mock = absl::make_unique<DecodeMockModel>();
sp.SetModel(std::move(mock));
const auto normalization_spec = MakeDefaultNormalizerSpec();
sp.SetNormalizer(
absl::make_unique<normalizer::Normalizer>(normalization_spec));
SentencePieceText spt;
EXPECT_TRUE(sp.Decode(input, &spt).ok());
EXPECT_EQ(" ABC DEFG HI", spt.text());
EXPECT_EQ(8, spt.pieces_size());
}
}
TEST(SentencepieceProcessorTest, DummyPrefixDecodeTest) {
class DecodeMockModel : public ModelInterface {
public:
EncodeResult Encode(absl::string_view normalized) const override {
return {};
}
int GetPieceSize() const override { return 7; }
int PieceToId(absl::string_view piece) const override {
static absl::flat_hash_map<absl::string_view, int> kMap = {
{"<unk>", 0}, {"<s>", 1}, {"</s>", 2}, {WS "ABC", 3},
{WS "DE", 4}, {"F", 5}, {"G" WS "H", 6}, {WS, 7}};
return port::FindWithDefault(kMap, piece, 0);
}
const std::string& IdToPiece(int id) const override {
static std::vector<std::string> kMap = {
"<unk>", "<s>", "</s>", WS "ABC", WS "DE", "F", "G" WS "H", WS};
return kMap[id];
}
bool IsUnknown(int id) const override { return (id == 0); }
bool IsControl(int id) const override { return (id == 1 || id == 2); }
bool IsByte(int id) const override { return false; }
float GetScore(int id) const override { return 0.0; }
};
// start the sequence with a whitespace token
const std::vector<std::string> input = {
"<s>", WS, WS "ABC", "<unk>", WS "DE", "F", "G" WS "H", "I", "</s>"};
{
SentencePieceProcessor sp;
auto proto = absl::make_unique<ModelProto>();
proto->mutable_trainer_spec()->set_unk_surface("");
proto->mutable_normalizer_spec()->set_add_dummy_prefix(true);
proto->mutable_normalizer_spec()->set_remove_extra_whitespaces(false);
sp.Load(std::move(proto)).IgnoreError();
auto mock = absl::make_unique<DecodeMockModel>();
sp.SetModel(std::move(mock));
const auto normalization_spec = MakeDefaultNormalizerSpec();
sp.SetNormalizer(
absl::make_unique<normalizer::Normalizer>(normalization_spec));
SentencePieceText spt;
EXPECT_TRUE(sp.Decode(input, &spt).ok());
EXPECT_EQ(" ABC DEFG HI", spt.text());
EXPECT_EQ(9, spt.pieces_size());
}
{
SentencePieceProcessor sp;
auto proto = absl::make_unique<ModelProto>();
proto->mutable_trainer_spec()->set_unk_surface("");
proto->mutable_normalizer_spec()->set_add_dummy_prefix(true);
proto->mutable_normalizer_spec()->set_remove_extra_whitespaces(true);
sp.Load(std::move(proto)).IgnoreError();
auto mock = absl::make_unique<DecodeMockModel>();
sp.SetModel(std::move(mock));
const auto normalization_spec = MakeDefaultNormalizerSpec();
sp.SetNormalizer(
absl::make_unique<normalizer::Normalizer>(normalization_spec));
SentencePieceText spt;
EXPECT_TRUE(sp.Decode(input, &spt).ok());
EXPECT_EQ("ABC DEFG HI", spt.text());
EXPECT_EQ(9, spt.pieces_size());
}
}
TEST(SentencepieceProcessorTest, ByteFallbackDecodeTest) {
class ByteFallbackDecodeMockModel : public ModelInterface {
public:
EncodeResult Encode(absl::string_view normalized) const override {
return {};
}
int PieceToId(absl::string_view piece) const override {
using Map = absl::flat_hash_map<std::string, int>;
static const Map kMap = []() -> Map {
Map m = {
{"<unk>", 0}, {"<s>", 1}, {"</s>", 2}, {"A", 3}, {"B", 4}, {"C", 5},
};
for (int i = 0; i < 256; ++i) {
m[ByteToPiece(i)] = 6 + i;
}
return m;
}();
return port::FindWithDefault(kMap, std::string(piece), 0);
}
const std::string& IdToPiece(int id) const override {
static std::vector<std::string> kMap = []() -> std::vector<std::string> {
std::vector<std::string> m = {"<unk>", "<s>", "</s>", "A", "B", "C"};
for (int i = 0; i < 256; ++i) {
m.push_back(ByteToPiece(i));
}
return m;
}();
return kMap[id];
}
int GetPieceSize() const override { return 256; }
bool IsUnknown(int id) const override { return (id == 0); }
bool IsControl(int id) const override { return (id == 1 || id == 2); }
bool IsByte(int id) const override { return id >= 6; }
bool ByteFallbackEnabled() const override { return true; }
};
SentencePieceProcessor sp;
auto mock = absl::make_unique<ByteFallbackDecodeMockModel>();
sp.SetModel(std::move(mock));
const auto normalization_spec = MakeDefaultNormalizerSpec();
sp.SetNormalizer(
absl::make_unique<normalizer::Normalizer>(normalization_spec));
{
const std::vector<std::string> input = {
"<s>",
"A",
"B",
// "あ" -> 0xE3 0x81 0x82
"<0xE3>",
"<0x81>",
"<0x82>",
// "Z" -> 0x5A
"<0x5A>",
// "Ω" -> 0xCE 0xA9
"<0xCE>",
"<0xA9>",
"C",
// Invalid UTF-8 bytes.
"<0xE0>",
"<0x80>",
// "い" -> 0xE3 0x81 0x84
"<0xE3>",
"<0x81>",
"<0x84>",
// REPLACEMENT CHARACTER as byte pieces.
"<0xEF>",
"<0xBF>",
"<0xBD>",
};
SentencePieceText spt;
EXPECT_TRUE(sp.Decode(input, &spt).ok());
EXPECT_EQ("ABあZΩC\xEF\xBF\xBD\xEF\xBF\xBDい\xEF\xBF\xBD", spt.text());
EXPECT_EQ(18, spt.pieces_size());
for (int i = 0; i < 18; ++i) {
EXPECT_EQ(input[i], spt.pieces(i).piece());
}
EXPECT_EQ("", spt.pieces(0).surface());
EXPECT_EQ(0, spt.pieces(0).begin());
EXPECT_EQ(0, spt.pieces(0).end());
EXPECT_EQ("A", spt.pieces(1).surface());
EXPECT_EQ(0, spt.pieces(1).begin());
EXPECT_EQ(1, spt.pieces(1).end());
EXPECT_EQ("B", spt.pieces(2).surface());
EXPECT_EQ(1, spt.pieces(2).begin());
EXPECT_EQ(2, spt.pieces(2).end());
EXPECT_EQ("", spt.pieces(3).surface());
EXPECT_EQ("", spt.pieces(4).surface());
EXPECT_EQ("あ", spt.pieces(5).surface());
EXPECT_EQ(2, spt.pieces(3).begin());
EXPECT_EQ(2, spt.pieces(3).end());
EXPECT_EQ(2, spt.pieces(4).begin());
EXPECT_EQ(2, spt.pieces(4).end());
EXPECT_EQ(2, spt.pieces(5).begin());
EXPECT_EQ(5, spt.pieces(5).end());
EXPECT_EQ("Z", spt.pieces(6).surface());
EXPECT_EQ(5, spt.pieces(6).begin());
EXPECT_EQ(6, spt.pieces(6).end());
EXPECT_EQ("", spt.pieces(7).surface());
EXPECT_EQ("Ω", spt.pieces(8).surface());
EXPECT_EQ(6, spt.pieces(7).begin());
EXPECT_EQ(6, spt.pieces(7).end());
EXPECT_EQ(6, spt.pieces(8).begin());
EXPECT_EQ(8, spt.pieces(8).end());
EXPECT_EQ("C", spt.pieces(9).surface());
EXPECT_EQ(8, spt.pieces(9).begin());
EXPECT_EQ(9, spt.pieces(9).end());
EXPECT_EQ("\xEF\xBF\xBD", spt.pieces(10).surface());
EXPECT_EQ(9, spt.pieces(10).begin());
EXPECT_EQ(12, spt.pieces(10).end());
EXPECT_EQ("\xEF\xBF\xBD", spt.pieces(11).surface());
EXPECT_EQ(12, spt.pieces(11).begin());
EXPECT_EQ(15, spt.pieces(11).end());
EXPECT_EQ("", spt.pieces(12).surface());
EXPECT_EQ("", spt.pieces(13).surface());
EXPECT_EQ("い", spt.pieces(14).surface());
EXPECT_EQ(15, spt.pieces(12).begin());
EXPECT_EQ(15, spt.pieces(12).end());
EXPECT_EQ(15, spt.pieces(13).begin());
EXPECT_EQ(15, spt.pieces(13).end());
EXPECT_EQ(15, spt.pieces(14).begin());
EXPECT_EQ(18, spt.pieces(14).end());
EXPECT_EQ("", spt.pieces(15).surface());
EXPECT_EQ("", spt.pieces(16).surface());
EXPECT_EQ("\xEF\xBF\xBD", spt.pieces(17).surface());
EXPECT_EQ(18, spt.pieces(15).begin());
EXPECT_EQ(18, spt.pieces(15).end());
EXPECT_EQ(18, spt.pieces(16).begin());
EXPECT_EQ(18, spt.pieces(16).end());
EXPECT_EQ(18, spt.pieces(17).begin());
EXPECT_EQ(21, spt.pieces(17).end());
}
}
void AddPiece(ModelProto *model_proto, absl::string_view piece,
float score = 0.0) {
auto *sp = model_proto->add_pieces();
sp->set_piece(std::string(piece));
sp->set_score(score);
}
TEST(SentencePieceProcessorTest, LoadInvalidModelTest) {
SentencePieceProcessor sp;
EXPECT_FALSE(sp.Load("").ok());
EXPECT_FALSE(sp.Load("__UNKNOWN_FILE__").ok());
}
TEST(SentencePieceProcessorTest, LoadSerializedProtoTest) {
ModelProto model_proto;
auto *sp1 = model_proto.add_pieces();
sp1->set_type(ModelProto::SentencePiece::UNKNOWN);
sp1->set_piece("<unk>");
AddPiece(&model_proto, WS, 0.0);
*(model_proto.mutable_normalizer_spec()) = MakeDefaultNormalizerSpec();
SentencePieceProcessor sp;
EXPECT_FALSE(sp.LoadFromSerializedProto("__NOT_A_PROTO__").ok());
EXPECT_TRUE(sp.LoadFromSerializedProto(model_proto.SerializeAsString()).ok());
EXPECT_EQ(model_proto.SerializeAsString(),
sp.model_proto().SerializeAsString());
}
TEST(SentencePieceProcessorTest, EndToEndTest) {
ModelProto model_proto;
auto *sp1 = model_proto.add_pieces();
auto *sp2 = model_proto.add_pieces();
auto *sp3 = model_proto.add_pieces();
sp1->set_type(ModelProto::SentencePiece::UNKNOWN);
sp1->set_piece("<unk>");
sp2->set_type(ModelProto::SentencePiece::CONTROL);
sp2->set_piece("<s>");
sp3->set_type(ModelProto::SentencePiece::CONTROL);
sp3->set_piece("</s>");
AddPiece(&model_proto, "a", 0.0);
AddPiece(&model_proto, "b", 0.3);
AddPiece(&model_proto, "c", 0.2);
AddPiece(&model_proto, "ab", 1.0);
AddPiece(&model_proto, "\xE2\x96\x81", 3.0); // kSpaceSymbol
*(model_proto.mutable_normalizer_spec()) = MakeDefaultNormalizerSpec();
{
auto output = filesystem::NewWritableFile(
util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "model"), true);
output->Write(model_proto.SerializeAsString());
}
SentencePieceProcessor sp;
EXPECT_TRUE(
sp.Load(util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "model")).ok());
EXPECT_EQ(model_proto.SerializeAsString(),
sp.model_proto().SerializeAsString());
EXPECT_EQ(8, sp.GetPieceSize());
EXPECT_EQ(0, sp.PieceToId("<unk>"));
EXPECT_EQ(1, sp.PieceToId("<s>"));
EXPECT_EQ(2, sp.PieceToId("</s>"));
EXPECT_EQ(3, sp.PieceToId("a"));
EXPECT_EQ(4, sp.PieceToId("b"));
EXPECT_EQ(5, sp.PieceToId("c"));
EXPECT_EQ(6, sp.PieceToId("ab"));
EXPECT_EQ(7, sp.PieceToId("\xE2\x96\x81"));
EXPECT_EQ("<unk>", sp.IdToPiece(0));
EXPECT_EQ("<s>", sp.IdToPiece(1));
EXPECT_EQ("</s>", sp.IdToPiece(2));
EXPECT_EQ("a", sp.IdToPiece(3));
EXPECT_EQ("b", sp.IdToPiece(4));
EXPECT_EQ("c", sp.IdToPiece(5));
EXPECT_EQ("ab", sp.IdToPiece(6));
EXPECT_EQ("\xE2\x96\x81", sp.IdToPiece(7));
EXPECT_NEAR(0.0, sp.GetScore(0), 0.001);
EXPECT_NEAR(0.0, sp.GetScore(1), 0.001);
EXPECT_NEAR(0.0, sp.GetScore(2), 0.001);
EXPECT_NEAR(0.0, sp.GetScore(3), 0.001);
EXPECT_NEAR(0.3, sp.GetScore(4), 0.001);
EXPECT_NEAR(0.2, sp.GetScore(5), 0.001);
EXPECT_NEAR(1.0, sp.GetScore(6), 0.001);
EXPECT_NEAR(3.0, sp.GetScore(7), 0.001);
EXPECT_TRUE(sp.IsUnknown(0));
EXPECT_FALSE(sp.IsUnknown(1));
EXPECT_FALSE(sp.IsUnknown(2));
EXPECT_FALSE(sp.IsUnknown(3));
EXPECT_FALSE(sp.IsUnknown(4));
EXPECT_FALSE(sp.IsUnknown(5));
EXPECT_FALSE(sp.IsUnknown(6));
EXPECT_FALSE(sp.IsUnknown(7));
EXPECT_FALSE(sp.IsControl(0));
EXPECT_TRUE(sp.IsControl(1));
EXPECT_TRUE(sp.IsControl(2));
EXPECT_FALSE(sp.IsControl(3));
EXPECT_FALSE(sp.IsControl(4));
EXPECT_FALSE(sp.IsControl(5));
EXPECT_FALSE(sp.IsControl(6));
EXPECT_FALSE(sp.IsControl(7));
EXPECT_EQ(0, sp.unk_id());
EXPECT_EQ(1, sp.bos_id());
EXPECT_EQ(2, sp.eos_id());
EXPECT_EQ(-1, sp.pad_id());
{
std::vector<std::string> sps;
const std::vector<std::string> expected_str = {WS, "ab", "c"};
EXPECT_TRUE(sp.Encode("abc", &sps).ok());
EXPECT_EQ(expected_str, sps);
std::vector<int> ids;
const std::vector<int> expected_id = {7, 6, 5};
EXPECT_TRUE(sp.Encode("abc", &ids).ok());
EXPECT_EQ(expected_id, ids);
}
{
EXPECT_TRUE(sp.SetEncodeExtraOptions("bos").ok());
std::vector<std::string> sps;
const std::vector<std::string> expected_str = {"<s>", WS, "ab", "c"};
EXPECT_TRUE(sp.Encode("abc", &sps).ok());
EXPECT_EQ(expected_str, sps);
std::vector<int> ids;
const std::vector<int> expected_id = {1, 7, 6, 5};
EXPECT_TRUE(sp.Encode("abc", &ids).ok());
EXPECT_EQ(expected_id, ids);
}
{
EXPECT_TRUE(sp.SetEncodeExtraOptions("eos").ok());
std::vector<std::string> sps;
const std::vector<std::string> expected_str = {WS, "ab", "c", "</s>"};
EXPECT_TRUE(sp.Encode("abc", &sps).ok());
EXPECT_EQ(expected_str, sps);
std::vector<int> ids;
const std::vector<int> expected_id = {7, 6, 5, 2};
EXPECT_TRUE(sp.Encode("abc", &ids).ok());
EXPECT_EQ(expected_id, ids);
}
{
EXPECT_TRUE(sp.SetEncodeExtraOptions("reverse").ok());
std::vector<std::string> sps;
const std::vector<std::string> expected_str = {"c", "ab", WS};
EXPECT_TRUE(sp.Encode("abc", &sps).ok());
EXPECT_EQ(expected_str, sps);
std::vector<int> ids;
const std::vector<int> expected_id = {5, 6, 7};
EXPECT_TRUE(sp.Encode("abc", &ids).ok());
EXPECT_EQ(expected_id, ids);
}
{
EXPECT_TRUE(sp.SetEncodeExtraOptions("bos:eos").ok());
std::vector<std::string> sps;
const std::vector<std::string> expected_str = {"<s>", WS, "ab", "c",
"</s>"};
EXPECT_TRUE(sp.Encode("abc", &sps).ok());
EXPECT_EQ(expected_str, sps);
std::vector<int> ids;
const std::vector<int> expected_id = {1, 7, 6, 5, 2};
EXPECT_TRUE(sp.Encode("abc", &ids).ok());
EXPECT_EQ(expected_id, ids);
}
{
EXPECT_TRUE(sp.SetEncodeExtraOptions("reverse:bos:eos").ok());
std::vector<std::string> sps;
const std::vector<std::string> expected_str = {"<s>", "c", "ab", WS,
"</s>"};
EXPECT_TRUE(sp.Encode("abc", &sps).ok());
EXPECT_EQ(expected_str, sps);
std::vector<int> ids;
const std::vector<int> expected_id = {1, 5, 6, 7, 2};
EXPECT_TRUE(sp.Encode("abc", &ids).ok());
EXPECT_EQ(expected_id, ids);
}
{
EXPECT_TRUE(sp.SetEncodeExtraOptions("bos:eos:reverse").ok());
std::vector<std::string> sps;
const std::vector<std::string> expected_str = {"</s>", "c", "ab", WS,
"<s>"};
EXPECT_TRUE(sp.Encode("abc", &sps).ok());
EXPECT_EQ(expected_str, sps);
std::vector<int> ids;
const std::vector<int> expected_id = {2, 5, 6, 7, 1};
EXPECT_TRUE(sp.Encode("abc", &ids).ok());
EXPECT_EQ(expected_id, ids);
}
{
std::string output;
const std::vector<std::string> sps = {"ab", "c"};
EXPECT_TRUE(sp.Decode(sps, &output).ok());
EXPECT_EQ("abc", output);
const std::vector<int> ids = {3, 4, 5};
EXPECT_TRUE(sp.Decode(ids, &output).ok());
EXPECT_EQ("abc", output);
}
{
EXPECT_TRUE(sp.SetDecodeExtraOptions("bos").ok());
std::string output;
const std::vector<std::string> sps = {"ab", "c"};
EXPECT_TRUE(sp.Decode(sps, &output).ok());
EXPECT_EQ("abc", output);
const std::vector<int> ids = {3, 4, 5};
EXPECT_TRUE(sp.Decode(ids, &output).ok());
EXPECT_EQ("abc", output);
}
{
EXPECT_TRUE(sp.SetDecodeExtraOptions("eos").ok());
std::string output;
const std::vector<std::string> sps = {"ab", "c"};
EXPECT_TRUE(sp.Decode(sps, &output).ok());
EXPECT_EQ("abc", output);
const std::vector<int> ids = {3, 4, 5};
EXPECT_TRUE(sp.Decode(ids, &output).ok());
EXPECT_EQ("abc", output);
}
{
EXPECT_TRUE(sp.SetDecodeExtraOptions("reverse").ok());
std::string output;
const std::vector<std::string> sps = {"ab", "c"};
EXPECT_TRUE(sp.Decode(sps, &output).ok());
EXPECT_EQ("cab", output);
const std::vector<int> ids = {3, 4, 5};
EXPECT_TRUE(sp.Decode(ids, &output).ok());
EXPECT_EQ("cba", output);
}
{
EXPECT_TRUE(sp.SetDecodeExtraOptions("bos:eos").ok());
std::string output;
const std::vector<std::string> sps = {"ab", "c"};
EXPECT_TRUE(sp.Decode(sps, &output).ok());
EXPECT_EQ("abc", output);
const std::vector<int> ids = {3, 4, 5};
EXPECT_TRUE(sp.Decode(ids, &output).ok());
EXPECT_EQ("abc", output);
}
{
EXPECT_TRUE(sp.SetDecodeExtraOptions("reverse:bos:eos").ok());
std::string output;
const std::vector<std::string> sps = {"ab", "c"};
EXPECT_TRUE(sp.Decode(sps, &output).ok());
EXPECT_EQ("cab", output);
const std::vector<int> ids = {3, 4, 5};
EXPECT_TRUE(sp.Decode(ids, &output).ok());
EXPECT_EQ("cba", output);
}
// Out of range
{
std::string output;
const std::vector<int> ids = {3, 4, 127};
EXPECT_FALSE(sp.Decode(ids, &output).ok());
}
{
EXPECT_TRUE(sp.SetDecodeExtraOptions("bos:eos:reverse").ok());
std::string output;
const std::vector<std::string> sps = {"ab", "c"};
EXPECT_TRUE(sp.Decode(sps, &output).ok());
EXPECT_EQ("cab", output);
const std::vector<int> ids = {3, 4, 5};
EXPECT_TRUE(sp.Decode(ids, &output).ok());
EXPECT_EQ("cba", output);
}
{
EXPECT_TRUE(sp.SetDecodeExtraOptions("reverse:reverse").ok());
std::string output;
const std::vector<std::string> sps = {"ab", "c"};
EXPECT_TRUE(sp.Decode(sps, &output).ok());
EXPECT_EQ("abc", output);
const std::vector<int> ids = {3, 4, 5};
EXPECT_TRUE(sp.Decode(ids, &output).ok());
EXPECT_EQ("abc", output);
}
EXPECT_TRUE(sp.SetEncodeExtraOptions("").ok());
EXPECT_TRUE(sp.SetDecodeExtraOptions("").ok());
EXPECT_FALSE(sp.SetEncodeExtraOptions("foo").ok());
EXPECT_FALSE(sp.SetDecodeExtraOptions("foo").ok());
auto RunTest = [&model_proto](const SentencePieceProcessor &sp) {
EXPECT_EQ(model_proto.SerializeAsString(),
sp.model_proto().SerializeAsString());
EXPECT_EQ(8, sp.GetPieceSize());
EXPECT_EQ(0, sp.PieceToId("<unk>"));
EXPECT_EQ(1, sp.PieceToId("<s>"));
EXPECT_EQ(2, sp.PieceToId("</s>"));
EXPECT_EQ(3, sp.PieceToId("a"));
EXPECT_EQ(4, sp.PieceToId("b"));
EXPECT_EQ(5, sp.PieceToId("c"));
EXPECT_EQ(6, sp.PieceToId("ab"));
EXPECT_EQ(7, sp.PieceToId("\xE2\x96\x81"));
EXPECT_EQ("<unk>", sp.IdToPiece(0));
EXPECT_EQ("<s>", sp.IdToPiece(1));
EXPECT_EQ("</s>", sp.IdToPiece(2));
EXPECT_EQ("a", sp.IdToPiece(3));
EXPECT_EQ("b", sp.IdToPiece(4));
EXPECT_EQ("c", sp.IdToPiece(5));
EXPECT_EQ("ab", sp.IdToPiece(6));
EXPECT_EQ("\xE2\x96\x81", sp.IdToPiece(7));
EXPECT_TRUE(sp.IsUnknown(0));
EXPECT_FALSE(sp.IsUnknown(1));
EXPECT_FALSE(sp.IsUnknown(2));
EXPECT_FALSE(sp.IsUnknown(3));
EXPECT_FALSE(sp.IsUnknown(4));
EXPECT_FALSE(sp.IsUnknown(5));
EXPECT_FALSE(sp.IsUnknown(6));
EXPECT_FALSE(sp.IsUnknown(7));
EXPECT_FALSE(sp.IsControl(0));
EXPECT_TRUE(sp.IsControl(1));
EXPECT_TRUE(sp.IsControl(2));
EXPECT_FALSE(sp.IsControl(3));
EXPECT_FALSE(sp.IsControl(4));
EXPECT_FALSE(sp.IsControl(5));
EXPECT_FALSE(sp.IsControl(6));
EXPECT_FALSE(sp.IsControl(7));
{
std::vector<std::string> sps;
const std::vector<std::string> expected_str = {WS, "ab", "c"};
EXPECT_TRUE(sp.Encode("abc", &sps).ok());
EXPECT_EQ(expected_str, sps);
std::vector<int> ids;
const std::vector<int> expected_id = {7, 6, 5};
EXPECT_TRUE(sp.Encode("abc", &ids).ok());
EXPECT_EQ(expected_id, ids);
}
{
std::string output;
const std::vector<std::string> sps = {"ab", "c"};
EXPECT_TRUE(sp.Decode(sps, &output).ok());
EXPECT_EQ("abc", output);
const std::vector<int> ids = {3, 4, 5};
EXPECT_TRUE(sp.Decode(ids, &output).ok());
EXPECT_EQ("abc", output);
}
};
// Copies ModelProto.
{
SentencePieceProcessor sp;
const ModelProto copied = model_proto;
EXPECT_TRUE(sp.Load(copied).ok());
RunTest(sp);
}
// Moves ModelProto.
{
SentencePieceProcessor sp;
auto moved = absl::make_unique<ModelProto>();
const ModelProto *moved_ptr = moved.get();
*moved = model_proto;
EXPECT_TRUE(sp.Load(std::move(moved)).ok());
EXPECT_EQ(moved_ptr, &sp.model_proto());
RunTest(sp);
}
// Restrict Vocabulary.
{
SentencePieceProcessor sp;
EXPECT_TRUE(sp.Load(model_proto).ok());
EXPECT_TRUE(sp.SetVocabulary({"a", "b", "c"}).ok()); // remove "ab"
const std::vector<std::string> expected_str = {WS, "a", "b", "c"};
std::vector<std::string> sps;
EXPECT_TRUE(sp.Encode("abc", &sps).ok());
EXPECT_EQ(expected_str, sps);
std::vector<int> ids;
const std::vector<int> expected_id = {7, 3, 4, 5};
EXPECT_TRUE(sp.Encode("abc", &ids).ok());
EXPECT_EQ(expected_id, ids);
}
}
TEST(SentencePieceProcessorTest, SkipNormalizationTest) {
ModelProto model_proto;
auto *sp1 = model_proto.add_pieces();
auto *sp2 = model_proto.add_pieces();
sp1->set_type(ModelProto::SentencePiece::UNKNOWN);
sp1->set_piece("<unk>");
sp2->set_type(ModelProto::SentencePiece::USER_DEFINED);
sp2->set_piece("<USER>");
AddPiece(&model_proto, "a", 0.0);
AddPiece(&model_proto, "b", 0.3);
AddPiece(&model_proto, "c", 0.2);
AddPiece(&model_proto, "u", 0.2);
AddPiece(&model_proto, "s", 0.2);
AddPiece(&model_proto, "e", 0.2);
AddPiece(&model_proto, "r", 0.2);
*(model_proto.mutable_normalizer_spec()) =
SentencePieceTrainer::GetNormalizerSpec("nmt_nfkc_cf");
SentencePieceProcessor sp;
EXPECT_TRUE(sp.Load(model_proto).ok());
std::vector<std::string> pieces;
EXPECT_TRUE(sp.Encode("AB<USER>C<uSEr>", &pieces).ok());
for (const auto &sp : pieces) LOG(INFO) << sp;
EXPECT_EQ(std::vector<std::string>(
{WS, "a", "b", "<USER>", "c", "<", "u", "s", "e", "r", ">"}),
pieces);
}
TEST(SentencePieceProcessorTest, ExtraOptionsUndefinedTest) {
ModelProto model_proto;
auto *sp1 = model_proto.add_pieces();
// No BOS/EOS.
sp1->set_type(ModelProto::SentencePiece::UNKNOWN);
sp1->set_piece("<unk>");
AddPiece(&model_proto, "a", 0.0);
AddPiece(&model_proto, "b", 0.3);
AddPiece(&model_proto, "c", 0.2);
AddPiece(&model_proto, "ab", 1.0);
SentencePieceProcessor sp;
EXPECT_TRUE(sp.Load(model_proto).ok());
EXPECT_FALSE(sp.SetEncodeExtraOptions("bos").ok());
EXPECT_FALSE(sp.SetDecodeExtraOptions("eos").ok());
}
TEST(SentencePieceProcessorTest, OverrideSpecialPieceTest) {
ModelProto model_proto;
auto *sp1 = model_proto.add_pieces();
auto *sp2 = model_proto.add_pieces();
auto *sp3 = model_proto.add_pieces();
model_proto.mutable_trainer_spec()->set_unk_piece("__UNK__");
model_proto.mutable_trainer_spec()->set_bos_piece("__BOS__");
model_proto.mutable_trainer_spec()->set_eos_piece("__EOS__");
model_proto.mutable_trainer_spec()->set_pad_piece("__PAD__");
// No BOS/EOS.
sp1->set_type(ModelProto::SentencePiece::UNKNOWN);
sp1->set_piece("__UNK__");
sp2->set_type(ModelProto::SentencePiece::CONTROL);
sp2->set_piece("__BOS__");
sp3->set_type(ModelProto::SentencePiece::CONTROL);
sp3->set_piece("__EOS__");
AddPiece(&model_proto, "a", 0.0);
AddPiece(&model_proto, "b", 0.3);
SentencePieceProcessor sp;
EXPECT_TRUE(sp.Load(model_proto).ok());
EXPECT_EQ(0, sp.unk_id());
EXPECT_EQ(1, sp.bos_id());
EXPECT_EQ(2, sp.eos_id());
EXPECT_EQ(-1, sp.pad_id());
EXPECT_EQ("__UNK__", sp.IdToPiece(sp.unk_id()));
EXPECT_EQ("__BOS__", sp.IdToPiece(sp.bos_id()));
EXPECT_EQ("__EOS__", sp.IdToPiece(sp.eos_id()));
}
TEST(SentencePieceProcessorTest, VocabularyTest) {
ModelProto model_proto;
auto *sp1 = model_proto.add_pieces();
auto *sp2 = model_proto.add_pieces();
auto *sp3 = model_proto.add_pieces();
auto GetInlineFilename = [](const std::string content) {
{
auto out = filesystem::NewWritableFile(
util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "vocab.txt"));
out->Write(content);
}
return util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "vocab.txt");
};
sp1->set_type(ModelProto::SentencePiece::UNKNOWN);
sp1->set_piece("<unk>");
sp2->set_type(ModelProto::SentencePiece::CONTROL);
sp2->set_piece("<s>");
sp3->set_type(ModelProto::SentencePiece::CONTROL);
sp3->set_piece("</s>");
AddPiece(&model_proto, "aa", 0.0);
AddPiece(&model_proto, "bb", 0.0);
AddPiece(&model_proto, "cc", 0.0);
AddPiece(&model_proto, "dd", 0.0);
AddPiece(&model_proto, "e", 0.0);
SentencePieceProcessor sp;
EXPECT_TRUE(sp.Load(model_proto).ok());
EXPECT_FALSE(sp.IsUnused(0));
EXPECT_FALSE(sp.IsUnused(1));
EXPECT_FALSE(sp.IsUnused(2));
EXPECT_FALSE(sp.IsUnused(3));
EXPECT_FALSE(sp.IsUnused(4));
EXPECT_FALSE(sp.IsUnused(5));
EXPECT_FALSE(sp.IsUnused(6));
EXPECT_FALSE(sp.IsUnused(7));
EXPECT_TRUE(sp.SetVocabulary({"aa", "dd", "e"}).ok());
EXPECT_FALSE(sp.IsUnused(0));
EXPECT_FALSE(sp.IsUnused(1));
EXPECT_FALSE(sp.IsUnused(2));
EXPECT_FALSE(sp.IsUnused(3));
EXPECT_TRUE(sp.IsUnused(4));
EXPECT_TRUE(sp.IsUnused(5));
EXPECT_FALSE(sp.IsUnused(6));
EXPECT_FALSE(sp.IsUnused(7)); // single char "e" is always used.
EXPECT_TRUE(sp.ResetVocabulary().ok());
EXPECT_FALSE(sp.IsUnused(3));
EXPECT_FALSE(sp.IsUnused(4));
EXPECT_FALSE(sp.IsUnused(5));
EXPECT_FALSE(sp.IsUnused(6));
EXPECT_FALSE(sp.IsUnused(7));
EXPECT_TRUE(sp.SetVocabulary({"bb"}).ok());
EXPECT_TRUE(sp.IsUnused(3));
EXPECT_FALSE(sp.IsUnused(4));
EXPECT_TRUE(sp.IsUnused(5));
EXPECT_TRUE(sp.IsUnused(6));
EXPECT_FALSE(sp.IsUnused(7));
EXPECT_TRUE(sp.LoadVocabulary(GetInlineFilename("aa\t1\ndd\t2\n"), 2).ok());
EXPECT_TRUE(sp.IsUnused(3));
EXPECT_TRUE(sp.IsUnused(4));
EXPECT_TRUE(sp.IsUnused(5));
EXPECT_FALSE(sp.IsUnused(6));
EXPECT_FALSE(sp.IsUnused(7));
EXPECT_TRUE(sp.LoadVocabulary(GetInlineFilename("aa\t1\ndd\t1\n"), 2).ok());
EXPECT_TRUE(sp.IsUnused(3));
EXPECT_TRUE(sp.IsUnused(4));
EXPECT_TRUE(sp.IsUnused(5));
EXPECT_TRUE(sp.IsUnused(6));
EXPECT_FALSE(sp.IsUnused(7));
EXPECT_TRUE(sp.LoadVocabulary(GetInlineFilename("aa\t1\ndd\t1\n"), 1).ok());
EXPECT_FALSE(sp.IsUnused(3));
EXPECT_TRUE(sp.IsUnused(4));
EXPECT_TRUE(sp.IsUnused(5));
EXPECT_FALSE(sp.IsUnused(6));
EXPECT_FALSE(sp.IsUnused(7));
EXPECT_TRUE(sp.LoadVocabulary(GetInlineFilename("aa\t0\ndd\t0\n"), 0).ok());
EXPECT_FALSE(sp.IsUnused(3));
EXPECT_TRUE(sp.IsUnused(4));
EXPECT_TRUE(sp.IsUnused(5));
EXPECT_FALSE(sp.IsUnused(6));
EXPECT_FALSE(sp.IsUnused(7));
// No frequency.
EXPECT_TRUE(sp.LoadVocabulary(GetInlineFilename("aa\ndd\n"), 1).ok());
EXPECT_FALSE(sp.IsUnused(3));
EXPECT_TRUE(sp.IsUnused(4));
EXPECT_TRUE(sp.IsUnused(5));
EXPECT_FALSE(sp.IsUnused(6));
EXPECT_FALSE(sp.IsUnused(7));
}
TEST(SentencePieceProcessorTest, ImmutableSentencePieceTextTest) {
ImmutableSentencePieceText spt;
EXPECT_TRUE(spt.text().empty());
EXPECT_EQ(spt.score(), 0.0);
EXPECT_TRUE(spt.SerializeAsString().empty());
auto* v = spt.mutable_proto();
v->set_text("hello world");
v->set_score(1.0);
for (int i = 0; i < 10; ++i) {
auto* p = v->add_pieces();
p->set_surface(absl::StrCat("surface_", i));
p->set_piece(absl::StrCat("surface_", i));
p->set_id(i);
p->set_begin(i + 10);
p->set_end(i + 20);
}
EXPECT_EQ(v->pieces_size(), spt.pieces_size());
for (int i = 0; i < spt.pieces_size(); ++i) {
EXPECT_EQ(v->pieces(i).surface(), spt.pieces(i).surface());
EXPECT_EQ(v->pieces(i).piece(), spt.pieces(i).piece());
EXPECT_EQ(v->pieces(i).id(), spt.pieces(i).id());
EXPECT_EQ(v->pieces(i).begin(), spt.pieces(i).begin());
EXPECT_EQ(v->pieces(i).end(), spt.pieces(i).end());
}
auto check_proto = [&v](const ImmutableSentencePieceText& s) {
int n = 0;
for (auto& p : s.pieces()) {
EXPECT_EQ(v->pieces(n).surface(), p.surface());
EXPECT_EQ(v->pieces(n).piece(), p.piece());
EXPECT_EQ(v->pieces(n).id(), p.id());
EXPECT_EQ(v->pieces(n).begin(), p.begin());
EXPECT_EQ(v->pieces(n).end(), p.end());
++n;
}
EXPECT_EQ(v->text(), s.text());
EXPECT_EQ(v->score(), s.score());
EXPECT_EQ(v->SerializeAsString(), s.SerializeAsString());
};
// test copy.
const auto spt2 = spt;
check_proto(spt2);
// test assign.
const ImmutableSentencePieceText spt3(spt);
check_proto(spt3);
// default piece.
const ImmutableSentencePieceText_ImmutableSentencePiece piece;
EXPECT_TRUE(piece.surface().empty());
EXPECT_TRUE(piece.piece().empty());
EXPECT_EQ(piece.begin(), 0);
EXPECT_EQ(piece.end(), 0);
EXPECT_EQ(piece.id(), 0);
}
TEST(SentencePieceProcessorTest, ImmutableNBestSentencePieceTextTest) {
ImmutableNBestSentencePieceText spt;
EXPECT_EQ(spt.nbests_size(), 0);
EXPECT_TRUE(spt.SerializeAsString().empty());
auto* v = spt.mutable_proto();
for (int i = 0; i < 10; ++i) {
auto* p = v->add_nbests();
p->set_text(absl::StrCat("text_", i));
p->set_score(2.0 * i);
}
auto check_proto = [&v](const ImmutableNBestSentencePieceText& s) {
EXPECT_EQ(v->nbests_size(), s.nbests_size());
for (int i = 0; i < v->nbests_size(); ++i) {
EXPECT_EQ(v->nbests(i).text(), s.nbests(i).text());
EXPECT_EQ(v->nbests(i).score(), s.nbests(i).score());
}
EXPECT_EQ(v->SerializeAsString(), s.SerializeAsString());
};
check_proto(spt);
// test copy.
const auto spt2 = spt;
check_proto(spt2);
// test assign.
const ImmutableNBestSentencePieceText spt3(spt);
check_proto(spt3);
}
TEST(SentencePieceProcessorTest, ConvertToUnicodeSpansTest) {
auto make_spt = [&](const std::vector<std::string>& tokens) {
ImmutableSentencePieceText ispt;
auto* spt = ispt.mutable_proto();
int prev = 0;
std::string text;
for (const auto& tok : tokens) {
auto* piece = spt->add_pieces();
piece->set_surface(tok);
piece->set_piece(tok);
piece->set_begin(prev);
piece->set_end(prev + tok.size());
prev += tok.size();
text += tok;
}
spt->set_text(text);
ispt.ConvertToUnicodeSpans();
return ispt;
};
{
const auto spt = make_spt({"hello", "_world", "."});
EXPECT_EQ(spt.pieces_size(), 3);
EXPECT_EQ(spt.pieces(0).begin(), 0);
EXPECT_EQ(spt.pieces(0).end(), 5);
EXPECT_EQ(spt.pieces(1).begin(), 5);
EXPECT_EQ(spt.pieces(1).end(), 11);
EXPECT_EQ(spt.pieces(2).begin(), 11);
EXPECT_EQ(spt.pieces(2).end(), 12);
}
{
const auto spt = make_spt({"これは", "test", "です"});
EXPECT_EQ(spt.pieces_size(), 3);
EXPECT_EQ(spt.pieces(0).begin(), 0);
EXPECT_EQ(spt.pieces(0).end(), 3);
EXPECT_EQ(spt.pieces(1).begin(), 3);
EXPECT_EQ(spt.pieces(1).end(), 7);
EXPECT_EQ(spt.pieces(2).begin(), 7);
EXPECT_EQ(spt.pieces(2).end(), 9);
}
{
const auto spt = make_spt({"いABは", "にほCD", "へと"});
EXPECT_EQ(spt.pieces_size(), 3);
EXPECT_EQ(spt.pieces(0).begin(), 0);
EXPECT_EQ(spt.pieces(0).end(), 4);
EXPECT_EQ(spt.pieces(1).begin(), 4);
EXPECT_EQ(spt.pieces(1).end(), 8);
EXPECT_EQ(spt.pieces(2).begin(), 8);
EXPECT_EQ(spt.pieces(2).end(), 10);
}
}
} // namespace sentencepiece