chromium/third_party/sentencepiece/src/src/unigram_model_trainer_test.cc

// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.!

#include "unigram_model_trainer.h"

#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "filesystem.h"
#include "sentencepiece_model.pb.h"
#include "sentencepiece_processor.h"
#include "sentencepiece_trainer.h"
#include "testharness.h"
#include "util.h"

namespace sentencepiece {
namespace unigram {

// Space symbol
#define WS "\xe2\x96\x81"

TEST(UnigramTrainerTest, TrainerModelTest) {
  TrainerSpec trainer_spec;
  NormalizerSpec normalizer_spec;
  const TrainerModel model(trainer_spec, normalizer_spec);
  EXPECT_EQ(EncodeResult(), model.Encode("test"));
}

struct TrainerResult {
  std::string sentence_pieces;
  std::vector<std::pair<std::string, float>> seed_pieces_and_probs;
};

TrainerResult RunTrainer(const std::vector<std::string>& input,
                         int size,
                         const bool use_dp = false,
                         const float dp_noise = 0.0,
                         const uint32 dp_clip = 0) {
  const std::string input_file =
      util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "input");
  const std::string model_prefix =
      util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "model");
  {
    auto output = filesystem::NewWritableFile(input_file);
    for (const auto& line : input) {
      output->WriteLine(line);
    }
  }

  TrainerSpec trainer_spec;
  trainer_spec.set_input_format("tsv");
  trainer_spec.set_model_type(TrainerSpec::UNIGRAM);
  trainer_spec.add_input(input_file);
  trainer_spec.set_vocab_size(size - 3);  // remove <unk>, <s>, </s>
  trainer_spec.set_model_prefix(model_prefix);

  trainer_spec.set_enable_differential_privacy(use_dp);
  trainer_spec.set_differential_privacy_noise_level(dp_noise);
  trainer_spec.set_differential_privacy_clipping_threshold(dp_clip);

  NormalizerSpec normalizer_spec;
  normalizer_spec.set_name("identity");
  normalizer_spec.set_add_dummy_prefix(false);

  NormalizerSpec denormalizer_spec;

  std::vector<std::pair<std::string, float>> seed_pieces;

  {
    Trainer trainer(trainer_spec, normalizer_spec, denormalizer_spec);
    EXPECT_OK(trainer.LoadSentences());
    TrainerModel::SentencePieces res = trainer.MakeSeedSentencePieces();

    for (const auto& piece : res) {
      seed_pieces.emplace_back(piece.first, piece.second);
    }
  }

  std::vector<std::string> pieces;

  {
    Trainer trainer(trainer_spec, normalizer_spec, denormalizer_spec);
    EXPECT_TRUE(trainer.Train().ok());

    SentencePieceProcessor processor;
    EXPECT_TRUE(processor.Load(model_prefix + ".model").ok());

    const auto& model = processor.model_proto();

    // remove <unk>, <s>, </s>
    for (int i = 3; i < model.pieces_size(); ++i) {
      pieces.emplace_back(model.pieces(i).piece());
    }
  }

  TrainerResult res;
  res.seed_pieces_and_probs = seed_pieces;
  std::sort(pieces.begin(), pieces.end());
  res.sentence_pieces = absl::StrJoin(pieces, " ");
  return res;
}

TEST(UnigramTrainerTest, BasicTest) {
  const auto& res = RunTrainer(
      {"magnanimity \t 5", "Pineapple \t 6", "i have an apple and a pen \t 1",
       "Overly \t 6", "Available \t 3"},
      30);

  // Check seed pieces.
  EXPECT_EQ(27, res.seed_pieces_and_probs.size());

  // Check final pieces.
  EXPECT_EQ("A O P a an apple b d e g h i l le m n p r t v ve y ▁ ▁an",
            res.sentence_pieces);
}

TEST(UnigramTrainerTest, BasicDPTest) {
  // no noise, clipping.
  {
    const auto& res = RunTrainer(
        {"magnanimity \t 5", "Pineapple \t 6", "i have an apple and a pen \t 1",
         "Overly \t 6", "Available \t 5"},
        22, true /*use_dp*/, 0 /*dp_noise*/, 4 /*dp_clipping*/);

    // Got 16 instead of 27 seeds.
    EXPECT_EQ(16, res.seed_pieces_and_probs.size());

    // And they are equiv to if the last sentence was not there.
    const auto& res_nodp = RunTrainer(
        {"magnanimity \t 5", "Pineapple \t 6", "Overly \t 6", "Available \t 5"},
        22);

    EXPECT_EQ(res.seed_pieces_and_probs, res_nodp.seed_pieces_and_probs);

    // Check final pieces.
    EXPECT_EQ(res.sentence_pieces, res_nodp.sentence_pieces);
  }
}

namespace {

static constexpr char kTestInputData[] = "wagahaiwa_nekodearu.txt";

TEST(UnigramTrainerTest, EndToEndTest) {
  const std::string input =
      util::JoinPath(absl::GetFlag(FLAGS_test_srcdir), kTestInputData);

  ASSERT_TRUE(
      SentencePieceTrainer::Train(
          absl::StrCat(
              "--model_prefix=",
              util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "tmp_model"),
              " --input=", input,
              " --vocab_size=8000 --normalization_rule_name=identity",
              " --model_type=unigram --user_defined_symbols=<user>",
              " --control_symbols=<ctrl> --max_sentence_length=2048"))
          .ok());

  SentencePieceProcessor sp;
  EXPECT_TRUE(sp.Load(util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir),
                                     "tmp_model.model"))
                  .ok());
  EXPECT_EQ(8000, sp.GetPieceSize());

  const int cid = sp.PieceToId("<ctrl>");
  const int uid = sp.PieceToId("<user>");
  EXPECT_TRUE(sp.IsControl(cid));
  EXPECT_FALSE(sp.IsUnknown(uid));

  std::vector<std::string> tok;

  EXPECT_TRUE(sp.Encode("", &tok).ok());
  EXPECT_TRUE(tok.empty());

  EXPECT_TRUE(sp.Encode("吾輩《わがはい》は猫である。名前はまだ無い。"
                        "どこで生れたかとんと見当《けんとう》がつかぬ。"
                        "何でも薄暗いじめじめした所でニャーニャー泣いていた事だ"
                        "けは記憶している"
                        "。",
                        &tok)
                  .ok());
  // TODO(taku): Temporally disable this test on Windows.
#ifndef OS_WIN
  LOG(INFO) << "[" << absl::StrJoin(tok, " ") << std::endl;
  EXPECT_EQ(
      WS
      " 吾輩 《 わが はい 》 は猫である 。 名前はまだ 無 い 。 どこ で 生 れた "
      "か とん と 見当 《 けん とう 》 が つか ぬ 。 何でも 薄 暗 い じめ じめ "
      "した 所で ニャーニャー 泣 い ていた 事 だけは 記憶 している 。",
      absl::StrJoin(tok, " "));
#endif
}

}  // namespace
}  // namespace unigram
}  // namespace sentencepiece