chromium/third_party/sentencepiece/src/src/model_interface_test.cc

// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "model_interface.h"

#include "absl/container/flat_hash_map.h"
#include "model_factory.h"
#include "testharness.h"
#include "util.h"

namespace sentencepiece {
namespace {

// UTF-8 bytes of U+2581 (LOWER ONE EIGHTH BLOCK), used by the SplitIntoWords
// tests below as the word-boundary marker. It is a macro rather than a
// constant so that it can be concatenated with adjacent string literals,
// e.g. `WS "this" WS "is"`.
#define WS "\xe2\x96\x81"

// Model types that the parameterized-by-loop tests below iterate over.
const std::vector<TrainerSpec::ModelType> kModelTypes{
    TrainerSpec::UNIGRAM,
    TrainerSpec::BPE,
    TrainerSpec::WORD,
    TrainerSpec::CHAR,
};

// Builds the minimal ModelProto shared by the tests in this file.
//
// The returned proto contains three pieces, appended in this order:
// "<unk>" (UNKNOWN), "<s>" (CONTROL) and "</s>" (CONTROL). Its trainer spec
// carries the requested `type` and `byte_fallback` flag.
ModelProto MakeBaseModelProto(TrainerSpec::ModelType type,
                              bool byte_fallback = false) {
  ModelProto proto;

  auto *unk = proto.add_pieces();
  unk->set_piece("<unk>");
  unk->set_type(ModelProto::SentencePiece::UNKNOWN);

  auto *bos = proto.add_pieces();
  bos->set_piece("<s>");
  bos->set_type(ModelProto::SentencePiece::CONTROL);

  auto *eos = proto.add_pieces();
  eos->set_piece("</s>");
  eos->set_type(ModelProto::SentencePiece::CONTROL);

  auto *spec = proto.mutable_trainer_spec();
  spec->set_model_type(type);
  spec->set_byte_fallback(byte_fallback);

  return proto;
}

// Appends a new piece with text `piece` and the given `score` to
// `model_proto`. The piece type is left at its proto default.
void AddPiece(ModelProto *model_proto, const std::string &piece,
              float score = 0.0) {
  auto &entry = *model_proto->add_pieces();
  entry.set_score(score);
  entry.set_piece(piece);
}

// Appends a BYTE-typed piece to `model_proto` whose text is the
// ByteToPiece() spelling of `byte`.
void AddBytePiece(ModelProto* model_proto, unsigned char byte) {
  auto& entry = *model_proto->add_pieces();
  entry.set_type(ModelProto::SentencePiece::BYTE);
  entry.set_piece(ByteToPiece(byte));
}

// Checks the unk/bos/eos/pad piece names reported by the trainer spec and by
// a created model, for default, cleared and overridden spec values.
TEST(ModelInterfaceTest, GetDefaultPieceTest) {
  // A default-constructed proto already exposes the default names.
  {
    ModelProto proto;
    const auto &spec = proto.trainer_spec();
    EXPECT_EQ("<unk>", spec.unk_piece());
    EXPECT_EQ("<s>", spec.bos_piece());
    EXPECT_EQ("</s>", spec.eos_piece());
    EXPECT_EQ("<pad>", spec.pad_piece());
  }

  // A model built from an untouched trainer spec reports the defaults.
  {
    ModelProto proto = MakeBaseModelProto(TrainerSpec::UNIGRAM);
    AddPiece(&proto, "a");
    auto created = ModelFactory::Create(proto);
    EXPECT_EQ("<unk>", created->unk_piece());
    EXPECT_EQ("<s>", created->bos_piece());
    EXPECT_EQ("</s>", created->eos_piece());
    EXPECT_EQ("<pad>", created->pad_piece());
  }

  // Explicitly clearing the fields still yields the defaults.
  {
    ModelProto proto = MakeBaseModelProto(TrainerSpec::UNIGRAM);
    AddPiece(&proto, "a");
    auto *spec = proto.mutable_trainer_spec();
    spec->clear_unk_piece();
    spec->clear_bos_piece();
    spec->clear_eos_piece();
    spec->clear_pad_piece();
    auto created = ModelFactory::Create(proto);
    EXPECT_EQ("<unk>", created->unk_piece());
    EXPECT_EQ("<s>", created->bos_piece());
    EXPECT_EQ("</s>", created->eos_piece());
    EXPECT_EQ("<pad>", created->pad_piece());
  }

  // Custom names set on the trainer spec are reported by the model.
  {
    ModelProto proto = MakeBaseModelProto(TrainerSpec::UNIGRAM);
    AddPiece(&proto, "a");
    auto *spec = proto.mutable_trainer_spec();
    spec->set_unk_piece("UNK");
    spec->set_bos_piece("BOS");
    spec->set_eos_piece("EOS");
    spec->set_pad_piece("PAD");
    auto created = ModelFactory::Create(proto);
    EXPECT_EQ("UNK", created->unk_piece());
    EXPECT_EQ("BOS", created->bos_piece());
    EXPECT_EQ("EOS", created->eos_piece());
    EXPECT_EQ("PAD", created->pad_piece());
  }
}

// For every model type, the model created from a proto must expose a
// model_proto() that serializes to the same bytes as the input proto.
TEST(ModelInterfaceTest, SetModelInterfaceTest) {
  for (const auto model_type : kModelTypes) {
    ModelProto proto = MakeBaseModelProto(model_type);
    for (const char *piece : {"a", "b", "c", "d"}) {
      AddPiece(&proto, piece);
    }

    auto created = ModelFactory::Create(proto);
    const auto expected_bytes = proto.SerializeAsString();
    EXPECT_EQ(expected_bytes, created->model_proto().SerializeAsString());
  }
}

// Exercises the id <-> piece mapping, the per-piece type predicates and the
// scores for every model type, using a small eight-piece vocabulary.
TEST(ModelInterfaceTest, PieceToIdTest) {
  using Piece = ModelProto::SentencePiece;

  // Expected vocabulary, indexed by piece id. Ids 0..2 come from
  // MakeBaseModelProto(); ids 3..7 are added below.
  const std::vector<std::string> kPieces = {"<unk>", "<s>", "</s>", "a",
                                            "b",     "c",   "d",    "e"};
  const double kScores[] = {0, 0, 0, 0.1, 0.2, 0.3, 0.4, 0.5};
  const int kNumPieces = static_cast<int>(kPieces.size());

  for (const auto model_type : kModelTypes) {
    ModelProto proto = MakeBaseModelProto(model_type);

    AddPiece(&proto, "a", 0.1);  // id 3
    AddPiece(&proto, "b", 0.2);  // id 4
    AddPiece(&proto, "c", 0.3);  // id 5
    AddPiece(&proto, "d", 0.4);  // id 6
    AddPiece(&proto, "e", 0.5);  // id 7
    proto.mutable_pieces(6)->set_type(Piece::UNUSED);
    proto.mutable_pieces(7)->set_type(Piece::USER_DEFINED);

    auto created = ModelFactory::Create(proto);

    EXPECT_EQ(proto.SerializeAsString(),
              created->model_proto().SerializeAsString());

    // Piece -> id, including lookups that must fall back to the unknown id.
    for (int id = 0; id < kNumPieces; ++id) {
      EXPECT_EQ(id, created->PieceToId(kPieces[id]));
    }
    EXPECT_EQ(0, created->PieceToId("f"));  // unk
    EXPECT_EQ(0, created->PieceToId(""));   // unk

    // Id -> piece.
    for (int id = 0; id < kNumPieces; ++id) {
      EXPECT_EQ(kPieces[id], created->IdToPiece(id));
    }

    // Only id 0 is the unknown piece.
    for (int id = 0; id < kNumPieces; ++id) {
      EXPECT_EQ(id == 0, created->IsUnknown(id));
    }

    // Only ids 1 and 2 are control pieces.
    for (int id = 0; id < kNumPieces; ++id) {
      EXPECT_EQ(id == 1 || id == 2, created->IsControl(id));
    }

    // Only id 6 was marked UNUSED.
    for (int id = 0; id < kNumPieces; ++id) {
      EXPECT_EQ(id == 6, created->IsUnused(id));
    }

    // Only id 7 was marked USER_DEFINED.
    for (int id = 0; id < kNumPieces; ++id) {
      EXPECT_EQ(id == 7, created->IsUserDefined(id));
    }

    // Scores round-trip within a small tolerance.
    for (int id = 0; id < kNumPieces; ++id) {
      EXPECT_NEAR(kScores[id], created->GetScore(id), 0.0001);
    }
  }
}

// Each scenario below builds a malformed proto and expects the created model
// to report a non-OK status.
TEST(ModelInterfaceTest, InvalidModelTest) {
  // A piece with empty text.
  {
    ModelProto proto = MakeBaseModelProto(TrainerSpec::UNIGRAM);
    AddPiece(&proto, "");
    auto created = ModelFactory::Create(proto);
    const bool ok = created->status().ok();
    EXPECT_FALSE(ok);
  }

  // The same piece text added twice.
  {
    ModelProto proto = MakeBaseModelProto(TrainerSpec::UNIGRAM);
    AddPiece(&proto, "a");
    AddPiece(&proto, "a");
    auto created = ModelFactory::Create(proto);
    const bool ok = created->status().ok();
    EXPECT_FALSE(ok);
  }

  // More than one UNKNOWN piece.
  {
    ModelProto proto = MakeBaseModelProto(TrainerSpec::UNIGRAM);
    proto.mutable_pieces(1)->set_type(ModelProto::SentencePiece::UNKNOWN);
    auto created = ModelFactory::Create(proto);
    const bool ok = created->status().ok();
    EXPECT_FALSE(ok);
  }

  // No UNKNOWN piece at all.
  {
    ModelProto proto = MakeBaseModelProto(TrainerSpec::UNIGRAM);
    proto.mutable_pieces(0)->set_type(ModelProto::SentencePiece::CONTROL);
    auto created = ModelFactory::Create(proto);
    const bool ok = created->status().ok();
    EXPECT_FALSE(ok);
  }
}

// Checks how the `byte_fallback` flag and the number of BYTE pieces in the
// proto affect the status of the created model.
TEST(ModelInterfaceTest, ByteFallbackModelTest) {
  // Builds a UNIGRAM proto with byte pieces for bytes 0..num_byte_pieces-1
  // followed by the ordinary piece "a".
  const auto make_proto = [](bool byte_fallback, int num_byte_pieces) {
    ModelProto proto = MakeBaseModelProto(TrainerSpec::UNIGRAM, byte_fallback);
    for (int b = 0; b < num_byte_pieces; ++b) {
      AddBytePiece(&proto, b);
    }
    AddPiece(&proto, "a");
    return proto;
  };

  // `byte_fallback` is true and all 256 byte pieces are present.
  {
    ModelProto proto = make_proto(true, 256);
    auto created = ModelFactory::Create(proto);
    EXPECT_TRUE(created->status().ok());
  }

  // `byte_fallback` is true, but there are not 256 byte pieces.
  {
    ModelProto proto = make_proto(true, 10);
    auto created = ModelFactory::Create(proto);
    EXPECT_FALSE(created->status().ok());
  }

  // `byte_fallback` is false, but a byte piece is found.
  {
    ModelProto proto = make_proto(false, 10);
    auto created = ModelFactory::Create(proto);
    EXPECT_FALSE(created->status().ok());
  }
}

// Returns a pseudo-random string built with rand().
//
// The result has between 1 and `length` characters (inclusive), each drawn
// from digits, the punctuation "!@#$%^&*" and ASCII letters. When `length`
// is zero or negative an empty string is returned; this guards the
// `rand() % length` below, which would otherwise divide by zero or produce a
// non-positive size.
std::string RandomString(int length) {
  static constexpr char kAlphabet[] =
      "0123456789"
      "!@#$%^&*"
      "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
      "abcdefghijklmnopqrstuvwxyz";
  // sizeof() counts the terminating NUL, which must never be emitted.
  constexpr int kAlphabetSize = sizeof(kAlphabet) - 1;

  if (length <= 0) {
    return std::string();
  }

  const int size = rand() % length + 1;
  std::string result;
  result.reserve(size);
  for (int i = 0; i < size; ++i) {
    result += kAlphabet[rand() % kAlphabetSize];
  }
  return result;
}

// Stress test: for every model type, repeatedly builds a vocabulary of
// random unique pieces and checks that PieceToId() and IdToPiece() agree
// with the ids the pieces were inserted at.
TEST(ModelInterfaceTest, PieceToIdStressTest) {
  constexpr int kRounds = 100;
  constexpr int kCandidatesPerRound = 1000;
  constexpr int kMaxPieceLength = 10;

  for (const auto model_type : kModelTypes) {
    for (int round = 0; round < kRounds; ++round) {
      absl::flat_hash_map<std::string, int> piece_to_id;
      absl::flat_hash_map<int, std::string> id_to_piece;
      ModelProto proto = MakeBaseModelProto(model_type);

      for (int k = 0; k < kCandidatesPerRound; ++k) {
        const std::string candidate = RandomString(kMaxPieceLength);
        // The next piece appended to the proto gets this id.
        const int next_id = proto.pieces_size();
        // Skip candidates already in the vocabulary.
        if (!piece_to_id.emplace(candidate, next_id).second) {
          continue;
        }
        id_to_piece.emplace(next_id, candidate);
        AddPiece(&proto, candidate);
      }

      auto created = ModelFactory::Create(proto);
      for (const auto &entry : piece_to_id) {
        EXPECT_EQ(entry.second, created->PieceToId(entry.first));
      }
      for (const auto &entry : id_to_piece) {
        EXPECT_EQ(entry.second, created->IdToPiece(entry.first));
      }
    }
  }
}

// SplitIntoWords() with default arguments: each WS marker is expected to
// start a new word and stay attached to the front of it.
TEST(ModelInterfaceTest, SplitIntoWordsTest) {
  // Input that begins with a marker.
  {
    const auto words = SplitIntoWords(WS "this" WS "is" WS "a" WS "pen");
    EXPECT_EQ(4, words.size());
    EXPECT_EQ(WS "this", words[0]);
    EXPECT_EQ(WS "is", words[1]);
    EXPECT_EQ(WS "a", words[2]);
    EXPECT_EQ(WS "pen", words[3]);
  }

  // Input whose first word has no leading marker.
  {
    const auto words = SplitIntoWords("this" WS "is" WS "a" WS "pen");
    EXPECT_EQ(4, words.size());
    EXPECT_EQ("this", words[0]);
    EXPECT_EQ(WS "is", words[1]);
    EXPECT_EQ(WS "a", words[2]);
    EXPECT_EQ(WS "pen", words[3]);
  }

  // Two consecutive markers: the first one becomes a word of its own.
  {
    const auto words = SplitIntoWords(WS "this" WS WS "is");
    EXPECT_EQ(3, words.size());
    EXPECT_EQ(WS "this", words[0]);
    EXPECT_EQ(WS, words[1]);
    EXPECT_EQ(WS "is", words[2]);
  }

  // Empty input yields no words.
  {
    const auto words = SplitIntoWords("");
    EXPECT_TRUE(words.empty());
  }

  // Input without any marker is a single word.
  {
    const auto words = SplitIntoWords("hello");
    EXPECT_EQ(1, words.size());
    EXPECT_EQ("hello", words[0]);
  }
}

// SplitIntoWords() with the second argument set to true: each WS marker is
// expected to end a word and stay attached to the back of it.
TEST(ModelInterfaceTest, SplitIntoWordsSuffixTest) {
  // Input that ends with a marker.
  {
    const auto words =
        SplitIntoWords("this" WS "is" WS "a" WS "pen" WS, true);
    EXPECT_EQ(4, words.size());
    EXPECT_EQ("this" WS, words[0]);
    EXPECT_EQ("is" WS, words[1]);
    EXPECT_EQ("a" WS, words[2]);
    EXPECT_EQ("pen" WS, words[3]);
  }

  // Input whose last word has no trailing marker.
  {
    const auto words = SplitIntoWords("this" WS "is" WS "a" WS "pen", true);
    EXPECT_EQ(4, words.size());
    EXPECT_EQ("this" WS, words[0]);
    EXPECT_EQ("is" WS, words[1]);
    EXPECT_EQ("a" WS, words[2]);
    EXPECT_EQ("pen", words[3]);
  }

  // Leading marker and a doubled marker each form a word of their own.
  {
    const auto words = SplitIntoWords(WS "this" WS WS "is", true);
    EXPECT_EQ(4, words.size());
    EXPECT_EQ(WS, words[0]);
    EXPECT_EQ("this" WS, words[1]);
    EXPECT_EQ(WS, words[2]);
    EXPECT_EQ("is", words[3]);
  }

  // Empty input yields no words.
  {
    const auto words = SplitIntoWords("", true);
    EXPECT_TRUE(words.empty());
  }

  // Input without any marker is a single word.
  {
    const auto words = SplitIntoWords("hello", true);
    EXPECT_EQ(1, words.size());
    EXPECT_EQ("hello", words[0]);
  }

  // Doubled trailing marker: the second one is a separate word.
  {
    const auto words = SplitIntoWords("hello" WS WS, true);
    EXPECT_EQ(2, words.size());
    EXPECT_EQ("hello" WS, words[0]);
    EXPECT_EQ(WS, words[1]);
  }

  // Doubled markers on both sides of a single word.
  {
    const auto words = SplitIntoWords(WS WS "hello" WS WS, true);
    EXPECT_EQ(4, words.size());
    EXPECT_EQ(WS, words[0]);
    EXPECT_EQ(WS, words[1]);
    EXPECT_EQ("hello" WS, words[2]);
    EXPECT_EQ(WS, words[3]);
  }
}

// SplitIntoWords() with the third argument set to true, in both prefix
// (second argument false) and suffix (second argument true) modes. The
// expectations below show runs of consecutive WS markers being kept together
// instead of being split one marker at a time.
TEST(ModelInterfaceTest, SplitIntoWordsWhiteSpaceOnly) {
  // Single trailing markers, suffix mode.
  {
    const auto words =
        SplitIntoWords("this" WS "is" WS "a" WS "pen" WS, true, true);
    EXPECT_EQ(4, words.size());
    EXPECT_EQ("this" WS, words[0]);
    EXPECT_EQ("is" WS, words[1]);
    EXPECT_EQ("a" WS, words[2]);
    EXPECT_EQ("pen" WS, words[3]);
  }

  // A leading run of markers stays attached to the word, prefix mode.
  {
    const auto words = SplitIntoWords(WS WS WS "a", false, true);
    EXPECT_EQ(1, words.size());
    EXPECT_EQ(WS WS WS "a", words[0]);
  }

  // A trailing run of markers stays attached to the word, suffix mode.
  {
    const auto words = SplitIntoWords("a" WS WS WS, true, true);
    EXPECT_EQ(1, words.size());
    EXPECT_EQ("a" WS WS WS, words[0]);
  }

  // Input made only of markers is a single word.
  {
    const auto words = SplitIntoWords(WS WS, true, true);
    EXPECT_EQ(1, words.size());
    EXPECT_EQ(WS WS, words[0]);
  }

  // Markers on both sides, suffix mode: the leading run is its own word.
  {
    const auto words = SplitIntoWords(WS WS "a" WS, true, true);
    EXPECT_EQ(2, words.size());
    EXPECT_EQ(WS WS, words[0]);
    EXPECT_EQ("a" WS, words[1]);
  }

  // Markers on both sides, prefix mode: the trailing marker is its own word.
  {
    const auto words = SplitIntoWords(WS WS "a" WS, false, true);
    EXPECT_EQ(2, words.size());
    EXPECT_EQ(WS WS "a", words[0]);
    EXPECT_EQ(WS, words[1]);
  }
}

// ByteToPiece() is expected to spell a byte as "<0xHH>" with two uppercase
// hexadecimal digits.
TEST(ModelInterfaceTest, ByteToPieceTest) {
  struct Case {
    unsigned char byte;
    const char *piece;
  };
  const Case kCases[] = {
      {0, "<0x00>"},  {1, "<0x01>"},    {10, "<0x0A>"},
      {16, "<0x10>"}, {255, "<0xFF>"},
  };

  for (const auto &c : kCases) {
    EXPECT_EQ(ByteToPiece(c.byte), c.piece);
  }
}

// PieceToByte() is expected to parse well-formed "<0xHH>" pieces back into
// their byte value and to return -1 for anything else.
TEST(ModelInterfaceTest, PieceToByteTest) {
  struct Case {
    const char *piece;
    int expected;
  };

  // Valid byte pieces.
  const Case kValid[] = {
      {"<0x00>", 0},  {"<0x01>", 1},    {"<0x0A>", 10},
      {"<0x10>", 16}, {"<0xFF>", 255},
  };
  for (const auto &c : kValid) {
    EXPECT_EQ(PieceToByte(c.piece), c.expected);
  }

  // Invalid byte pieces: wrong digit count, lowercase hex, non-hex digit,
  // and a piece that is not in "<0x..>" form at all.
  const Case kInvalid[] = {
      {"<0x0>", -1},  {"<0x000>", -1}, {"<0x001>", -1},
      {"<0xff>", -1}, {"<0xFG>", -1},  {"a", -1},
  };
  for (const auto &c : kInvalid) {
    EXPECT_EQ(PieceToByte(c.piece), c.expected);
  }
}

// For every model type, VerifyOutputsEquivalent() must accept identical
// strings and reject strings that differ.
TEST(ModelInterfaceTest, VerifyOutputsEquivalent) {
  for (const auto model_type : kModelTypes) {
    ModelProto proto = MakeBaseModelProto(model_type);
    AddPiece(&proto, "a", 1.0);
    AddPiece(&proto, "b", 2.0);
    auto created = ModelFactory::Create(proto);

    // Identical inputs, including the empty string, are equivalent.
    EXPECT_TRUE(created->VerifyOutputsEquivalent("", ""));
    EXPECT_TRUE(created->VerifyOutputsEquivalent("a b", "a b"));

    // Differing inputs are not equivalent.
    EXPECT_FALSE(created->VerifyOutputsEquivalent("a", "a b"));
  }
}

}  // namespace
}  // namespace sentencepiece