chromium/third_party/distributed_point_functions/code/dpf/distributed_point_function_test.cc

// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "dpf/distributed_point_function.h"

#include <memory>
#include <ostream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/base/config.h"
#include "absl/numeric/int128.h"
#include "absl/random/random.h"
#include "absl/random/uniform_int_distribution.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "absl/utility/utility.h"
#include "dpf/distributed_point_function.pb.h"
#include "dpf/internal/proto_validator.h"
#include "dpf/internal/status_matchers.h"
#include "dpf/xor_wrapper.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"

namespace distributed_point_functions {
namespace {

using dpf_internal::IsOk;
using dpf_internal::IsOkAndHolds;
using dpf_internal::StatusIs;
using ::testing::HasSubstr;
using ::testing::Ne;
using ::testing::StartsWith;

TEST(DistributedPointFunction, TestCreate) {
  for (int log_domain_size = 0; log_domain_size <= 128; ++log_domain_size) {
    for (int element_bitsize = 1; element_bitsize <= 128;
         element_bitsize *= 2) {
      DpfParameters parameters;

      parameters.set_log_domain_size(log_domain_size);
      parameters.mutable_value_type()->mutable_integer()->set_bitsize(
          element_bitsize);

      EXPECT_THAT(DistributedPointFunction::Create(parameters),
                  IsOkAndHolds(Ne(nullptr)))
          << "log_domain_size=" << log_domain_size
          << " element_bitsize=" << element_bitsize;
    }
  }
}

TEST(DistributedPointFunction, TestCreateIncrementalLargeDomain) {
  std::vector<DpfParameters> parameters(2);
  parameters[0].mutable_value_type()->mutable_integer()->set_bitsize(128);
  parameters[1].mutable_value_type()->mutable_integer()->set_bitsize(128);

  // Test that creating an incremental DPF with a large total domain size works.
  parameters[0].set_log_domain_size(10);
  parameters[1].set_log_domain_size(100);

  EXPECT_THAT(DistributedPointFunction::CreateIncremental(parameters),
              IsOkAndHolds(Ne(nullptr)));
}

TEST(DistributedPointFunction, CreateFailsForTupleTypesWithDifferentIntModN) {
  DpfParameters parameters;
  parameters.set_log_domain_size(10);
  *(parameters.mutable_value_type()) =
      ToValueType<Tuple<IntModN<uint32_t, 3>, IntModN<uint64_t, 4>>>();

  EXPECT_THAT(
      DistributedPointFunction::Create(parameters),
      StatusIs(absl::StatusCode::kUnimplemented,
               "All elements of type IntModN in a tuple must be the same"));
}

TEST(DistributedPointFunction, CreateFailsForMissingValueType) {
  DpfParameters parameters;
  parameters.set_log_domain_size(10);

  EXPECT_THAT(
      DistributedPointFunction::Create(parameters),
      StatusIs(absl::StatusCode::kInvalidArgument, "`value_type` is required"));
}

TEST(DistributedPointFunction, CreateFailsForInvalidValueType) {
  DpfParameters parameters;
  parameters.set_log_domain_size(10);
  *(parameters.mutable_value_type()) = ValueType{};

  EXPECT_THAT(DistributedPointFunction::Create(parameters),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       StartsWith("ValidateValueType: Unsupported ValueType")));
}

TEST(DistributedPointFunction, TestGenerateKeysIncrementalVariadicTemplate) {
  std::vector<DpfParameters> parameters(2);

  parameters[0].set_log_domain_size(10);
  parameters[1].set_log_domain_size(20);
  *(parameters[0].mutable_value_type()) = ToValueType<uint16_t>();
  *(parameters[1].mutable_value_type()) =
      ToValueType<Tuple<uint32_t, uint64_t>>();
  DPF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<DistributedPointFunction> dpf,
      DistributedPointFunction::CreateIncremental(parameters));

  absl::StatusOr<std::pair<DpfKey, DpfKey>> keys = dpf->GenerateKeysIncremental(
      23, uint16_t{42}, Tuple<uint32_t, uint64_t>{123, 456});
  EXPECT_THAT(keys, IsOk());
}

TEST(DistributedPointFunction, TestGenerateKeysIncrementalTemplate) {
  std::vector<DpfParameters> parameters(2);
  using T = XorWrapper<absl::uint128>;

  parameters[0].set_log_domain_size(10);
  parameters[1].set_log_domain_size(20);
  *(parameters[0].mutable_value_type()) = ToValueType<T>();
  *(parameters[1].mutable_value_type()) = ToValueType<T>();
  DPF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<DistributedPointFunction> dpf,
      DistributedPointFunction::CreateIncremental(parameters));

  absl::StatusOr<std::pair<DpfKey, DpfKey>> keys =
      dpf->GenerateKeysIncremental(23, T{42}, T{123});
  EXPECT_THAT(keys, IsOk());
}

TEST(DistributedPointFunction, KeyGenerationFailsIfValueTypeNotRegistered) {
  DpfParameters parameters;
  parameters.set_log_domain_size(10);
  parameters.mutable_value_type()
      ->mutable_tuple()
      ->add_elements()
      ->mutable_integer()
      ->set_bitsize(32);
  DPF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DistributedPointFunction> dpf,
                           DistributedPointFunction::Create(parameters));

  // Tuple<uint32_t> should not be registered by default.
  absl::uint128 alpha = 23;
  Value beta;
  beta.mutable_tuple()->add_elements()->mutable_integer()->set_value_uint64(42);

  EXPECT_THAT(dpf->GenerateKeys(alpha, beta),
              StatusIs(absl::StatusCode::kFailedPrecondition,
                       StartsWith("No value correction function known")));
}

TEST(DistributedPointFunction, EvaluationFailsIfDomainSizeGapTooLarge) {
  std::vector<DpfParameters> parameters(2);
  parameters[0].mutable_value_type()->mutable_integer()->set_bitsize(128);
  parameters[1].mutable_value_type()->mutable_integer()->set_bitsize(128);
  parameters[0].set_log_domain_size(10);
  parameters[1].set_log_domain_size(100);
  DPF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<DistributedPointFunction> dpf,
      DistributedPointFunction::CreateIncremental(parameters));

  std::pair<DpfKey, DpfKey> keys;
  DPF_ASSERT_OK_AND_ASSIGN(keys, dpf->GenerateKeysIncremental(123, 456u, 789u));
  DPF_ASSERT_OK_AND_ASSIGN(EvaluationContext ctx,
                           dpf->CreateEvaluationContext(keys.first));

  EXPECT_THAT(
      dpf->EvaluateUntil<absl::uint128>(1, {}, ctx),
      StatusIs(absl::StatusCode::kInvalidArgument, StartsWith("Domain size")));
}

TEST(DistributedPointFunction, EvaluationFailsIfOutputSizeTooLarge) {
  std::vector<DpfParameters> parameters(2);
  parameters[0].mutable_value_type()->mutable_integer()->set_bitsize(128);
  parameters[1].mutable_value_type()->mutable_integer()->set_bitsize(128);
  parameters[0].set_log_domain_size(10);
  parameters[1].set_log_domain_size(72);
  DPF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<DistributedPointFunction> dpf,
      DistributedPointFunction::CreateIncremental(parameters));

  std::pair<DpfKey, DpfKey> keys;
  DPF_ASSERT_OK_AND_ASSIGN(keys, dpf->GenerateKeysIncremental(123, 456u, 789u));
  DPF_ASSERT_OK_AND_ASSIGN(EvaluationContext ctx,
                           dpf->CreateEvaluationContext(keys.first));

  // Evaluate on 2**2 prefixes, bringing the output size to 2**(72-10+2) =
  // 2**64, which overflows an int64_t. Assumes a size_t is at most 64 bits.
  std::vector<absl::uint128> prefixes = {123, 456, 789, 1011};
  DPF_ASSERT_OK(dpf->EvaluateUntil<absl::uint128>(0, {}, ctx));
  EXPECT_THAT(
      dpf->EvaluateUntil<absl::uint128>(1, prefixes, ctx),
      StatusIs(absl::StatusCode::kInvalidArgument, StartsWith("Output size")));
}

TEST(DistributedPointFunction, TestSinglePointPartialEvaluation) {
  // Two hierarchy levels: The first will be evaluated with only a single
  // prefix, the second will be fully evaluated.
  std::vector<DpfParameters> parameters(2);
  parameters[0].set_log_domain_size(108);
  parameters[0].mutable_value_type()->mutable_integer()->set_bitsize(32);
  parameters[1].set_log_domain_size(128);
  parameters[1].mutable_value_type()->mutable_integer()->set_bitsize(32);

  DPF_ASSERT_OK_AND_ASSIGN(
      auto dpf, DistributedPointFunction::CreateIncremental(parameters));
  absl::uint128 prefix = 0xdeadbeef;
  absl::uint128 suffix = 23;
  absl::uint128 alpha = (prefix << 20) + suffix;
  uint32_t beta = 42;
  DpfKey key_a, key_b;
  DPF_ASSERT_OK_AND_ASSIGN(std::tie(key_a, key_b),
                           dpf->GenerateKeysIncremental(alpha, {beta, beta}));

  // First evaluate directly up to `prefix`
  DPF_ASSERT_OK_AND_ASSIGN(EvaluationContext ctx_a,
                           dpf->CreateEvaluationContext(key_a));
  DPF_ASSERT_OK_AND_ASSIGN(EvaluationContext ctx_b,
                           dpf->CreateEvaluationContext(key_b));
  std::vector<uint32_t> result_a, result_b;
  DPF_ASSERT_OK_AND_ASSIGN(result_a,
                           dpf->EvaluateAt<uint32_t>(0, {prefix}, ctx_a));
  DPF_ASSERT_OK_AND_ASSIGN(result_b,
                           dpf->EvaluateAt<uint32_t>(0, {prefix}, ctx_b));
  EXPECT_EQ(result_a[0] + result_b[0], beta);

  // Now fully evaluate the second level.
  DPF_ASSERT_OK_AND_ASSIGN(result_a,
                           dpf->EvaluateUntil<uint32_t>(1, {prefix}, ctx_a));
  DPF_ASSERT_OK_AND_ASSIGN(result_b,
                           dpf->EvaluateUntil<uint32_t>(1, {prefix}, ctx_b));
  EXPECT_EQ(result_a.size(), result_b.size());
  EXPECT_EQ(result_a.size(), 1 << 20);
  for (int i = 0; i < static_cast<int>(result_a.size()); ++i) {
    if (i == suffix) {
      EXPECT_EQ(result_a[i] + result_b[i], beta);
    } else {
      EXPECT_EQ(result_a[i] + result_b[i], 0);
    }
  }
}

class RegularDpfKeyGenerationTest
    : public testing::TestWithParam<std::tuple<int, int>> {
 public:
  void SetUp() {
    std::tie(log_domain_size_, element_bitsize_) = GetParam();
    DpfParameters parameters;
    parameters.set_log_domain_size(log_domain_size_);
    parameters.mutable_value_type()->mutable_integer()->set_bitsize(
        element_bitsize_);
    DPF_ASSERT_OK_AND_ASSIGN(dpf_,
                             DistributedPointFunction::Create(parameters));
    DPF_ASSERT_OK_AND_ASSIGN(
        proto_validator_, dpf_internal::ProtoValidator::Create({parameters}));
  }

 protected:
  int log_domain_size_;
  int element_bitsize_;
  std::unique_ptr<DistributedPointFunction> dpf_;
  std::unique_ptr<dpf_internal::ProtoValidator> proto_validator_;
};

TEST_P(RegularDpfKeyGenerationTest, KeyHasCorrectFormat) {
  DpfKey key_a, key_b;
  DPF_ASSERT_OK_AND_ASSIGN(std::tie(key_a, key_b), dpf_->GenerateKeys(0, 0));

  // Check that party is set correctly.
  EXPECT_EQ(key_a.party(), 0);
  EXPECT_EQ(key_b.party(), 1);
  // Check that keys are accepted by proto_validator_.
  DPF_EXPECT_OK(proto_validator_->ValidateDpfKey(key_a));
  DPF_EXPECT_OK(proto_validator_->ValidateDpfKey(key_b));
}

TEST_P(RegularDpfKeyGenerationTest, FailsIfBetaHasTheWrongSize) {
  EXPECT_THAT(
      dpf_->GenerateKeysIncremental(0, {1, 2}),
      StatusIs(absl::StatusCode::kInvalidArgument,
               "`beta` has to have the same size as `parameters` passed at "
               "construction"));
}

TEST_P(RegularDpfKeyGenerationTest, FailsIfAlphaIsTooLarge) {
  if (log_domain_size_ >= 128) {
    // Alpha is an absl::uint128, so never too large in this case.
    return;
  }

  EXPECT_THAT(dpf_->GenerateKeys((absl::uint128{1} << log_domain_size_), 1),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       "`alpha` must be smaller than the output domain size"));
}

TEST_P(RegularDpfKeyGenerationTest, FailsIfBetaIsTooLarge) {
  if (element_bitsize_ >= 128) {
    // Beta is an absl::uint128, so never too large in this case.
    return;
  }

  const auto beta = absl::uint128{1} << element_bitsize_;

  // Not testing error message, as it's an implementation detail of
  // ProtoValidator.
  EXPECT_THAT(dpf_->GenerateKeys(0, beta),
              StatusIs(absl::StatusCode::kInvalidArgument));
}

INSTANTIATE_TEST_SUITE_P(VaryDomainAndElementSizes, RegularDpfKeyGenerationTest,
                         testing::Combine(testing::Values(0, 1, 2, 3, 4, 5, 6,
                                                          7, 8, 9, 10, 32, 62),
                                          testing::Values(8, 16, 32, 64, 128)));

struct DpfTestParameters {
  int log_domain_size;
  int element_bitsize;

  friend std::ostream& operator<<(std::ostream& os,
                                  const DpfTestParameters& parameters) {
    return os << "{ log_domain_size: " << parameters.log_domain_size
              << ", element_bitsize: " << parameters.element_bitsize << " }";
  }
};

class IncrementalDpfTest : public testing::TestWithParam<
                               std::tuple</*parameters*/
                                          std::vector<DpfTestParameters>,
                                          /*alpha*/ absl::uint128,
                                          /*beta*/ std::vector<absl::uint128>,
                                          /*level_step*/ int,
                                          /*single_point*/ bool>> {
 protected:
  void SetUp() {
    const std::vector<DpfTestParameters>& parameters = std::get<0>(GetParam());
    parameters_.resize(parameters.size());
    for (int i = 0; i < static_cast<int>(parameters.size()); ++i) {
      parameters_[i].set_log_domain_size(parameters[i].log_domain_size);
      parameters_[i].mutable_value_type()->mutable_integer()->set_bitsize(
          parameters[i].element_bitsize);
    }
    DPF_ASSERT_OK_AND_ASSIGN(
        dpf_, DistributedPointFunction::CreateIncremental(parameters_));
    alpha_ = std::get<1>(GetParam());
    beta_ = std::get<2>(GetParam());
    DPF_ASSERT_OK_AND_ASSIGN(keys_,
                             dpf_->GenerateKeysIncremental(alpha_, beta_));
    level_step_ = std::get<3>(
        GetParam());  // Number of hierarchy level to evaluate at once.
    single_point_ = std::get<4>(GetParam());
    DPF_ASSERT_OK_AND_ASSIGN(proto_validator_,
                             dpf_internal::ProtoValidator::Create(parameters_));
  }

  // Returns the prefix of `index` for the domain of `hierarchy_level`.
  absl::uint128 GetPrefixForLevel(int hierarchy_level, absl::uint128 index) {
    absl::uint128 result = 0;
    int shift_amount = parameters_.back().log_domain_size() -
                       parameters_[hierarchy_level].log_domain_size();
    if (shift_amount < 128) {
      result = index >> shift_amount;
    }
    return result;
  }

  // Evaluates both contexts `ctx0` and `ctx1` at `hierarchy level`, using the
  // appropriate prefixes of `evaluation_points`. Checks that the expansion of
  // both keys form correct DPF shares, i.e., they add up to
  // `beta_[ctx.hierarchy_level()]` under prefixes of `alpha_`, and to 0
  // otherwise. If `singl_point == true`, only evaluates at the prefixes of the
  // given `evaluation_points`. Otherwise, fully expands the given
  // `hierarchy_level`.
  template <typename T>
  void EvaluateAndCheckLevel(int hierarchy_level,
                             absl::Span<const absl::uint128> evaluation_points,
                             EvaluationContext& ctx0, EvaluationContext& ctx1) {
    int previous_hierarchy_level = ctx0.previous_hierarchy_level();
    int current_log_domain_size =
        parameters_[hierarchy_level].log_domain_size();
    int previous_log_domain_size = 0;
    int num_expansions = 1;
    bool is_first_evaluation = previous_hierarchy_level < 0;
    std::vector<absl::uint128> prefixes;
    if (single_point_) {
      // Single point evaluation: Generate prefixes for the current hierarchy
      // level.
      prefixes.resize(evaluation_points.size());
      for (int i = 0; i < static_cast<int>(evaluation_points.size()); ++i) {
        prefixes[i] = GetPrefixForLevel(hierarchy_level, evaluation_points[i]);
      }
    } else if (!is_first_evaluation) {
      // Full expansion: Generate prefixes for the previous hierarchy level if
      // we're not on the first level.
      num_expansions = static_cast<int>(evaluation_points.size());
      prefixes.resize(evaluation_points.size());
      previous_log_domain_size =
          parameters_[previous_hierarchy_level].log_domain_size();
      for (int i = 0; i < static_cast<int>(evaluation_points.size()); ++i) {
        prefixes[i] =
            GetPrefixForLevel(previous_hierarchy_level, evaluation_points[i]);
      }
    }

    absl::StatusOr<std::vector<T>> result_0, result_1;
    if (single_point_) {
      result_0 = dpf_->EvaluateAt<T>(hierarchy_level, prefixes, ctx0);
      result_1 = dpf_->EvaluateAt<T>(hierarchy_level, prefixes, ctx1);
    } else {
      result_0 = dpf_->EvaluateUntil<T>(hierarchy_level, prefixes, ctx0);
      result_1 = dpf_->EvaluateUntil<T>(hierarchy_level, prefixes, ctx1);
    }

    // Check results are ok.
    DPF_EXPECT_OK(result_0)
        << "hierarchy_level=" << hierarchy_level << "\nparameters={\n"
        << parameters_[hierarchy_level].DebugString() << "}\n";
    DPF_EXPECT_OK(result_1);
    if (result_0.ok() && result_1.ok()) {
      // Check output sizes match.
      ASSERT_EQ(result_0->size(), result_1->size());

      if (single_point_) {
        absl::uint128 current_alpha_prefix =
            GetPrefixForLevel(hierarchy_level, alpha_);
        for (int i = 0; i < result_0->size(); ++i) {
          if (prefixes[i] == current_alpha_prefix) {
            EXPECT_EQ(static_cast<T>((*result_0)[i] + (*result_1)[i]),
                      beta_[hierarchy_level])
                << "i=" << i << ", hierarchy_level=" << hierarchy_level
                << "\nparameters={\n"
                << parameters_[hierarchy_level].DebugString() << "}\n"
                << "current_alpha_prefix=" << current_alpha_prefix << "\n"
                << "prefixes[" << i << "]=" << prefixes[i] << "\n"
                << "evaluation_points[" << i << "]=" << evaluation_points[i]
                << "\n";
          } else {
            EXPECT_EQ(static_cast<T>((*result_0)[i] + (*result_1)[i]), 0)
                << "i=" << i << ", hierarchy_level=" << hierarchy_level
                << "\nparameters={\n"
                << parameters_[hierarchy_level].DebugString() << "}\n"
                << "current_alpha_prefix=" << current_alpha_prefix << "\n"
                << "prefixes[" << i << "]=" << prefixes[i] << "\n"
                << "evaluation_points[" << i << "]=" << evaluation_points[i]
                << "\n";
          }
        }
      } else {
        int64_t outputs_per_prefix =
            int64_t{1} << (current_log_domain_size - previous_log_domain_size);
        int64_t expected_output_size = num_expansions * outputs_per_prefix;
        ASSERT_EQ(result_0->size(), expected_output_size);

        // Iterate over the outputs and check that they sum up to 0 or to
        // beta_[current_hierarchy_level].
        absl::uint128 previous_alpha_prefix = 0;
        if (!is_first_evaluation) {
          previous_alpha_prefix =
              GetPrefixForLevel(previous_hierarchy_level, alpha_);
        }
        absl::uint128 current_alpha_prefix =
            GetPrefixForLevel(hierarchy_level, alpha_);
        for (int64_t i = 0; i < expected_output_size; ++i) {
          int prefix_index = i / outputs_per_prefix;
          int prefix_expansion_index = i % outputs_per_prefix;
          // The output is on the path to alpha, if we're at the first level or
          // under a prefix of alpha, and the current block in the expansion of
          // the prefix is also on the path to alpha.
          if ((is_first_evaluation ||
               prefixes[prefix_index] == previous_alpha_prefix) &&
              prefix_expansion_index ==
                  (current_alpha_prefix % outputs_per_prefix)) {
            // We need to static_cast here since otherwise operator+ returns an
            // unsigned int without doing a modular reduction, which causes the
            // test to fail on types with sizeof(T) < sizeof(unsigned).
            EXPECT_EQ(static_cast<T>((*result_0)[i] + (*result_1)[i]),
                      beta_[hierarchy_level])
                << "i=" << i << ", hierarchy_level=" << hierarchy_level
                << "\nparameters={\n"
                << parameters_[hierarchy_level].DebugString() << "}\n"
                << "previous_alpha_prefix=" << previous_alpha_prefix << "\n"
                << "current_alpha_prefix=" << current_alpha_prefix << "\n";
          } else {
            EXPECT_EQ(static_cast<T>((*result_0)[i] + (*result_1)[i]), 0)
                << "i=" << i << ", hierarchy_level=" << hierarchy_level
                << "\nparameters={\n"
                << parameters_[hierarchy_level].DebugString() << "}\n";
          }
        }
      }
    }
  }

  std::vector<DpfParameters> parameters_;
  std::unique_ptr<DistributedPointFunction> dpf_;
  absl::uint128 alpha_;
  std::vector<absl::uint128> beta_;
  std::pair<DpfKey, DpfKey> keys_;
  int level_step_;
  bool single_point_;
  std::unique_ptr<dpf_internal::ProtoValidator> proto_validator_;
};

TEST_P(IncrementalDpfTest, CreateEvaluationContextCreatesValidContext) {
  DPF_ASSERT_OK_AND_ASSIGN(EvaluationContext ctx,
                           dpf_->CreateEvaluationContext(keys_.first));
  DPF_EXPECT_OK(proto_validator_->ValidateEvaluationContext(ctx));
}

TEST_P(IncrementalDpfTest, FailsIfPrefixNotPresentInCtx) {
  if (parameters_.size() < 3 || parameters_[0].log_domain_size() < 1 ||
      parameters_[0].value_type().integer().bitsize() != 128 ||
      parameters_[1].value_type().integer().bitsize() != 128 ||
      parameters_[2].value_type().integer().bitsize() != 128) {
    return;
  }
  DPF_ASSERT_OK_AND_ASSIGN(EvaluationContext ctx,
                           dpf_->CreateEvaluationContext(keys_.first));

  // Evaluate on prefixes 0 and 1, then delete partial evaluations for prefix 0.
  DPF_ASSERT_OK(dpf_->EvaluateNext<absl::uint128>({}, ctx));
  DPF_ASSERT_OK(dpf_->EvaluateNext<absl::uint128>({0, 1}, ctx));
  ctx.mutable_partial_evaluations()->erase(ctx.partial_evaluations().begin());

  // The missing prefix corresponds to hierarchy level 1, even though we
  // evaluate at level 2.
  EXPECT_THAT(dpf_->EvaluateNext<absl::uint128>({0}, ctx),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       "Prefix not present in ctx.partial_evaluations at "
                       "hierarchy level 1"));
}

TEST_P(IncrementalDpfTest, FailsIfDuplicatePrefixInCtx) {
  if (parameters_.size() < 3 || parameters_[0].log_domain_size() < 1 ||
      parameters_[0].value_type().integer().bitsize() != 128 ||
      parameters_[1].value_type().integer().bitsize() != 128 ||
      parameters_[2].value_type().integer().bitsize() != 128) {
    return;
  }
  DPF_ASSERT_OK_AND_ASSIGN(EvaluationContext ctx,
                           dpf_->CreateEvaluationContext(keys_.first));

  // Evaluate on prefixes 0 and 1, then add duplicate prefix for 0 with
  // different seed.
  DPF_ASSERT_OK(dpf_->EvaluateNext<absl::uint128>({}, ctx));
  DPF_ASSERT_OK(dpf_->EvaluateNext<absl::uint128>({0, 1}, ctx));
  *(ctx.add_partial_evaluations()) = ctx.partial_evaluations(0);
  Block changed_seed = ctx.partial_evaluations(0).seed();
  changed_seed.set_low(changed_seed.low() + 1);
  *((ctx.mutable_partial_evaluations()->end() - 1)->mutable_seed()) =
      changed_seed;

  // The missing prefix corresponds to hierarchy level 1, even though we
  // evaluate at level 2.
  EXPECT_THAT(dpf_->EvaluateNext<absl::uint128>({0}, ctx),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       "Duplicate prefix in `ctx.partial_evaluations()` with "
                       "mismatching seed or control bit"));
}

TEST_P(IncrementalDpfTest, EvaluationFailsOnEmptyContext) {
  if (parameters_[0].value_type().integer().bitsize() != 128) {
    return;
  }
  EvaluationContext ctx;

  // We don't check the error message, since it depends on the ProtoValidator
  // implementation which is tested in the corresponding unit test.
  EXPECT_THAT(dpf_->EvaluateNext<absl::uint128>({}, ctx),
              StatusIs(absl::StatusCode::kInvalidArgument));
}

TEST_P(IncrementalDpfTest, EvaluationFailsIfHierarchyLevelNegative) {
  if (parameters_[0].value_type().integer().bitsize() != 128) {
    return;
  }
  DPF_ASSERT_OK_AND_ASSIGN(EvaluationContext ctx,
                           dpf_->CreateEvaluationContext(keys_.first));

  EXPECT_THAT(dpf_->EvaluateUntil<absl::uint128>(-1, {}, ctx),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       "`hierarchy_level` must be non-negative and less than "
                       "parameters_.size()"));
}

TEST_P(IncrementalDpfTest, EvaluationFailsIfHierarchyLevelTooLarge) {
  if (parameters_[0].value_type().integer().bitsize() != 128) {
    return;
  }
  DPF_ASSERT_OK_AND_ASSIGN(EvaluationContext ctx,
                           dpf_->CreateEvaluationContext(keys_.first));

  EXPECT_THAT(dpf_->EvaluateUntil<absl::uint128>(parameters_.size(), {}, ctx),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       "`hierarchy_level` must be non-negative and less than "
                       "parameters_.size()"));
}

TEST_P(IncrementalDpfTest, EvaluationFailsIfValueTypeDoesntMatch) {
  using SomeStrangeType = Tuple<uint8_t, uint32_t, uint8_t, uint16_t, uint8_t>;
  DPF_ASSERT_OK_AND_ASSIGN(EvaluationContext ctx,
                           dpf_->CreateEvaluationContext(keys_.first));

  EXPECT_THAT(
      dpf_->EvaluateUntil<SomeStrangeType>(0, {}, ctx),
      StatusIs(absl::StatusCode::kInvalidArgument,
               "Value type T doesn't match parameters at `hierarchy_level`"));
}

TEST_P(IncrementalDpfTest, EvaluationFailsIfLevelAlreadyEvaluated) {
  if (parameters_.size() < 2 ||
      parameters_[0].value_type().integer().bitsize() != 128) {
    return;
  }
  DPF_ASSERT_OK_AND_ASSIGN(EvaluationContext ctx,
                           dpf_->CreateEvaluationContext(keys_.first));

  DPF_ASSERT_OK(dpf_->EvaluateUntil<absl::uint128>(0, {}, ctx));

  EXPECT_THAT(dpf_->EvaluateUntil<absl::uint128>(0, {}, ctx),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       "`hierarchy_level` must be greater than "
                       "`ctx.previous_hierarchy_level`"));
}

TEST_P(IncrementalDpfTest, EvaluationFailsIfPrefixesNotEmptyOnFirstCall) {
  if (parameters_[0].value_type().integer().bitsize() != 128) {
    return;
  }
  DPF_ASSERT_OK_AND_ASSIGN(EvaluationContext ctx,
                           dpf_->CreateEvaluationContext(keys_.first));

  EXPECT_THAT(
      dpf_->EvaluateUntil<absl::uint128>(0, {0}, ctx),
      StatusIs(
          absl::StatusCode::kInvalidArgument,
          "`prefixes` must be empty if and only if this is the first call with "
          "`ctx`."));
}

TEST_P(IncrementalDpfTest, EvaluationFailsIfPrefixOutOfRange) {
  if (parameters_.size() < 2 ||
      parameters_[0].value_type().integer().bitsize() != 128 ||
      parameters_[1].value_type().integer().bitsize() != 128) {
    return;
  }
  DPF_ASSERT_OK_AND_ASSIGN(EvaluationContext ctx,
                           dpf_->CreateEvaluationContext(keys_.first));

  DPF_ASSERT_OK(dpf_->EvaluateUntil<absl::uint128>(0, {}, ctx));
  auto invalid_prefix = absl::uint128{1} << parameters_[0].log_domain_size();

  EXPECT_THAT(dpf_->EvaluateUntil<absl::uint128>(1, {invalid_prefix}, ctx),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       StrFormat("Index %d out of range for hierarchy level 0",
                                 invalid_prefix)));
}

TEST_P(IncrementalDpfTest, TestCorrectness) {
  // Generate a random set of evaluation points. The library should be able to
  // handle duplicates, so fixing the size to 1000 works even for smaller
  // domains.
  absl::BitGen rng;
  absl::uniform_int_distribution<uint64_t> dist;
  const int kNumEvaluationPoints = 1000;
  std::vector<absl::uint128> evaluation_points(kNumEvaluationPoints);
  for (int i = 0; i < kNumEvaluationPoints - 1; ++i) {
    evaluation_points[i] = absl::MakeUint128(dist(rng), dist(rng));
    if (parameters_.back().log_domain_size() < 128) {
      evaluation_points[i] %= absl::uint128{1}
                              << parameters_.back().log_domain_size();
    }
  }
  evaluation_points.back() = alpha_;  // Always evaluate on alpha_.

  int num_levels = static_cast<int>(parameters_.size());
  DPF_ASSERT_OK_AND_ASSIGN(EvaluationContext ctx0,
                           dpf_->CreateEvaluationContext(keys_.first));
  DPF_ASSERT_OK_AND_ASSIGN(EvaluationContext ctx1,
                           dpf_->CreateEvaluationContext(keys_.second));

  for (int i = level_step_ - 1; i < num_levels; i += level_step_) {
    switch (parameters_[i].value_type().integer().bitsize()) {
      case 8:
        EvaluateAndCheckLevel<uint8_t>(i, evaluation_points, ctx0, ctx1);
        break;
      case 16:
        EvaluateAndCheckLevel<uint16_t>(i, evaluation_points, ctx0, ctx1);
        break;
      case 32:
        EvaluateAndCheckLevel<uint32_t>(i, evaluation_points, ctx0, ctx1);
        break;
      case 64:
        EvaluateAndCheckLevel<uint64_t>(i, evaluation_points, ctx0, ctx1);
        break;
      case 128:
        EvaluateAndCheckLevel<absl::uint128>(i, evaluation_points, ctx0, ctx1);
        break;
      default:
        ASSERT_TRUE(0) << "Unsupported element_bitsize";
    }
  }
}

INSTANTIATE_TEST_SUITE_P(
    OneHierarchyLevelVaryElementSizes, IncrementalDpfTest,
    testing::Combine(
        // DPF parameters.
        testing::Values(
            // Vary element sizes, small domain size.
            std::vector<DpfTestParameters>{
                {.log_domain_size = 4, .element_bitsize = 8}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 4, .element_bitsize = 16}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 4, .element_bitsize = 32}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 4, .element_bitsize = 64}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 4, .element_bitsize = 128}},
            // Vary element sizes, medium domain size.
            std::vector<DpfTestParameters>{
                {.log_domain_size = 10, .element_bitsize = 8}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 10, .element_bitsize = 16}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 10, .element_bitsize = 32}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 10, .element_bitsize = 64}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 10, .element_bitsize = 128}}),
        testing::Values(0, 1, 15),  // alpha
        testing::Values(std::vector<absl::uint128>(1, 1),
                        std::vector<absl::uint128>(1, 100),
                        std::vector<absl::uint128>(1, 255)),  // beta
        testing::Values(1),                                   // level_step
        testing::Values(false, true)                          // single_point
        ));

INSTANTIATE_TEST_SUITE_P(
    OneHierarchyLevelVaryDomainSizes, IncrementalDpfTest,
    testing::Combine(
        // DPF parameters.
        testing::Values(
            // Vary domain sizes, small element size.
            std::vector<DpfTestParameters>{
                {.log_domain_size = 0, .element_bitsize = 8}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 1, .element_bitsize = 8}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 2, .element_bitsize = 8}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 3, .element_bitsize = 8}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 4, .element_bitsize = 8}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 5, .element_bitsize = 8}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 6, .element_bitsize = 8}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 7, .element_bitsize = 8}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 8, .element_bitsize = 8}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 9, .element_bitsize = 8}},
            // Vary domain sizes, medium element size.
            std::vector<DpfTestParameters>{
                {.log_domain_size = 0, .element_bitsize = 64}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 1, .element_bitsize = 64}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 2, .element_bitsize = 64}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 3, .element_bitsize = 64}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 4, .element_bitsize = 64}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 5, .element_bitsize = 64}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 6, .element_bitsize = 64}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 7, .element_bitsize = 64}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 8, .element_bitsize = 64}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 9, .element_bitsize = 64}},
            // Vary domain sizes, large element size.
            std::vector<DpfTestParameters>{
                {.log_domain_size = 0, .element_bitsize = 128}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 1, .element_bitsize = 128}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 2, .element_bitsize = 128}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 3, .element_bitsize = 128}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 4, .element_bitsize = 128}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 5, .element_bitsize = 128}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 6, .element_bitsize = 128}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 7, .element_bitsize = 128}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 8, .element_bitsize = 128}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 9, .element_bitsize = 128}}),
        testing::Values(0),  // alpha
        testing::Values(std::vector<absl::uint128>(1, 1),
                        std::vector<absl::uint128>(1, 100),
                        std::vector<absl::uint128>(1, 255)),  // beta
        testing::Values(1),                                   // level_step
        testing::Values(false, true)                          // single_point
        ));

INSTANTIATE_TEST_SUITE_P(
    TwoHierarchyLevels, IncrementalDpfTest,
    testing::Combine(
        // DPF parameters.
        testing::Values(
            // Equal element sizes.
            std::vector<DpfTestParameters>{
                {.log_domain_size = 5, .element_bitsize = 8},
                {.log_domain_size = 10, .element_bitsize = 8}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 5, .element_bitsize = 16},
                {.log_domain_size = 10, .element_bitsize = 16}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 5, .element_bitsize = 32},
                {.log_domain_size = 10, .element_bitsize = 32}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 5, .element_bitsize = 64},
                {.log_domain_size = 10, .element_bitsize = 64}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 5, .element_bitsize = 128},
                {.log_domain_size = 10, .element_bitsize = 128}},
            // First correction is in seed word.
            std::vector<DpfTestParameters>{
                {.log_domain_size = 0, .element_bitsize = 8},
                {.log_domain_size = 10, .element_bitsize = 128}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 0, .element_bitsize = 16},
                {.log_domain_size = 10, .element_bitsize = 128}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 0, .element_bitsize = 32},
                {.log_domain_size = 10, .element_bitsize = 128}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 0, .element_bitsize = 64},
                {.log_domain_size = 10, .element_bitsize = 128}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 0, .element_bitsize = 128},
                {.log_domain_size = 10, .element_bitsize = 128}}),
        testing::Values(0, 1, 2, 100, 1023),  // alpha
        testing::Values(std::vector<absl::uint128>({1, 2}),
                        std::vector<absl::uint128>({80, 90}),
                        std::vector<absl::uint128>({255, 255})),  // beta
        testing::Values(1, 2),                                    // level_step
        testing::Values(false, true)  // single_point
        ));

INSTANTIATE_TEST_SUITE_P(
    ThreeHierarchyLevels, IncrementalDpfTest,
    testing::Combine(
        // DPF parameters.
        testing::Values<std::vector<DpfTestParameters>>(
            // Equal element sizes.
            std::vector<DpfTestParameters>{
                {.log_domain_size = 5, .element_bitsize = 8},
                {.log_domain_size = 10, .element_bitsize = 8},
                {.log_domain_size = 15, .element_bitsize = 8}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 5, .element_bitsize = 16},
                {.log_domain_size = 10, .element_bitsize = 16},
                {.log_domain_size = 15, .element_bitsize = 16}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 5, .element_bitsize = 32},
                {.log_domain_size = 10, .element_bitsize = 32},
                {.log_domain_size = 15, .element_bitsize = 32}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 5, .element_bitsize = 64},
                {.log_domain_size = 10, .element_bitsize = 64},
                {.log_domain_size = 15, .element_bitsize = 64}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 5, .element_bitsize = 128},
                {.log_domain_size = 10, .element_bitsize = 128},
                {.log_domain_size = 15, .element_bitsize = 128}},
            // Varying element sizes
            std::vector<DpfTestParameters>{
                {.log_domain_size = 5, .element_bitsize = 8},
                {.log_domain_size = 10, .element_bitsize = 16},
                {.log_domain_size = 15, .element_bitsize = 32}},
            // Small level distances.
            std::vector<DpfTestParameters>{
                {.log_domain_size = 4, .element_bitsize = 8},
                {.log_domain_size = 5, .element_bitsize = 8},
                {.log_domain_size = 6, .element_bitsize = 8}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 3, .element_bitsize = 16},
                {.log_domain_size = 4, .element_bitsize = 16},
                {.log_domain_size = 5, .element_bitsize = 16}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 2, .element_bitsize = 32},
                {.log_domain_size = 3, .element_bitsize = 32},
                {.log_domain_size = 4, .element_bitsize = 32}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 1, .element_bitsize = 64},
                {.log_domain_size = 2, .element_bitsize = 64},
                {.log_domain_size = 3, .element_bitsize = 64}},
            std::vector<DpfTestParameters>{
                {.log_domain_size = 0, .element_bitsize = 128},
                {.log_domain_size = 1, .element_bitsize = 128},
                {.log_domain_size = 2, .element_bitsize = 128}}),
        testing::Values(0, 1),                                   // alpha
        testing::Values(std::vector<absl::uint128>({1, 2, 3})),  // beta
        testing::Values(1, 2),                                   // level_step
        testing::Values(false, true)                             // single_point
        ));

INSTANTIATE_TEST_SUITE_P(
    MaximumOutputDomainSize, IncrementalDpfTest,
    testing::Combine(
        // DPF parameters. We want to be able to evaluate at every bit, so this
        // lambda returns a vector with 129 parameters with log domain sizes
        // 0...128.
        testing::Values([]() -> std::vector<DpfTestParameters> {
          std::vector<DpfTestParameters> parameters(129);
          for (int i = 0; i < static_cast<int>(parameters.size()); ++i) {
            parameters[i].log_domain_size = i;
            parameters[i].element_bitsize = 64;
          }
          return parameters;
        }()),
        testing::Values(absl::MakeUint128(23, 42)),                 // alpha
        testing::Values(std::vector<absl::uint128>(129, 1234567)),  // beta
        testing::Values(1, 2, 3, 5, 7),  // level_step
        testing::Values(false, true)     // single_point
        ));

template <typename T>
class DpfEvaluationTest : public ::testing::Test {
 protected:
  void SetUp() { SetUp(10, 23); }
  void SetUp(int log_domain_size, absl::uint128 alpha) {
    return SetUp(absl::MakeConstSpan(&log_domain_size, 1), alpha);
  }
  void SetUp(absl::Span<const int> log_domain_size, absl::uint128 alpha) {
    log_domain_size_.resize(log_domain_size.size());
    absl::c_copy(log_domain_size, log_domain_size_.begin());
    alpha_ = alpha;
    beta_.resize(log_domain_size.size());
    for (T& beta : beta_) {
      SetTo42(beta);
    }
    parameters_.resize(log_domain_size.size());
    for (int i = 0; i < parameters_.size(); ++i) {
      parameters_[i].set_log_domain_size(log_domain_size_[i]);
      parameters_[i].set_security_parameter(48);
      *(parameters_[i].mutable_value_type()) = ToValueType<T>();
    }
    DPF_ASSERT_OK_AND_ASSIGN(
        dpf_, DistributedPointFunction::CreateIncremental(parameters_));
    DPF_ASSERT_OK(this->dpf_->template RegisterValueType<T>());
    DPF_ASSERT_OK_AND_ASSIGN(
        keys_, this->dpf_->GenerateKeysIncremental(
                   this->alpha_, absl::MakeConstSpan(this->beta_)));
  }

  // Helper function that recursively sets all elements of a tuple to 42.
  template <typename T0>
  static void SetTo42(T0& x) {
    x = T0(42);
  }
  template <typename T0, typename... Tn>
  static void SetTo42(T0& x0, Tn&... xn) {
    SetTo42(x0);
    SetTo42(xn...);
  }
  template <typename... Tn>
  static void SetTo42(Tuple<Tn...>& x) {
    absl::apply([](auto&... in) { SetTo42(in...); }, x.value());
  }

  std::vector<int> log_domain_size_;
  absl::uint128 alpha_;
  std::vector<T> beta_;
  std::vector<DpfParameters> parameters_;
  std::unique_ptr<DistributedPointFunction> dpf_;
  std::pair<DpfKey, DpfKey> keys_;
};

using MyIntModN = IntModN<uint32_t, 4294967291u>;                // 2**32 - 5.
using MyIntModN64 = IntModN<uint64_t, 18446744073709551557ull>;  // 2**64 - 59.
#ifdef ABSL_HAVE_INTRINSIC_INT128
using MyIntModN128 =
    IntModN<absl::uint128, (unsigned __int128)(absl::MakeUint128(
                               65535u, 18446744073709551551ull))>;  // 2**80-65
#endif
using DpfEvaluationTypes = ::testing::Types<
    // Integers
    uint8_t, uint32_t, uint64_t, absl::uint128,
    // Tuple
    Tuple<uint8_t>, Tuple<uint32_t>, Tuple<absl::uint128>,
    Tuple<uint32_t, uint32_t>, Tuple<uint32_t, uint64_t>,
    Tuple<uint64_t, uint64_t>, Tuple<uint8_t, uint16_t, uint32_t, uint64_t>,
    Tuple<uint32_t, uint32_t, uint32_t, uint32_t>,
    Tuple<uint32_t, Tuple<uint32_t, uint32_t>, uint32_t>,
    Tuple<uint32_t, absl::uint128>,
    // IntModN
    MyIntModN, Tuple<MyIntModN>, Tuple<uint32_t, MyIntModN>,
    Tuple<absl::uint128, MyIntModN>, Tuple<MyIntModN, Tuple<MyIntModN>>,
    Tuple<MyIntModN, MyIntModN, MyIntModN, MyIntModN, MyIntModN>,
    Tuple<MyIntModN64, MyIntModN64>
#ifdef ABSL_HAVE_INTRINSIC_INT128
    ,
    Tuple<MyIntModN128, MyIntModN128>,
#endif
    // XorWrapper
    XorWrapper<uint8_t>, XorWrapper<absl::uint128>,
    Tuple<XorWrapper<uint32_t>, absl::uint128>>;
TYPED_TEST_SUITE(DpfEvaluationTest, DpfEvaluationTypes);

TYPED_TEST(DpfEvaluationTest, TestRegularDpf) {
  int log_domain_size = 10;
  absl::uint128 alpha = 23;
  this->SetUp(log_domain_size, alpha);
  DPF_ASSERT_OK_AND_ASSIGN(
      EvaluationContext ctx_1,
      this->dpf_->CreateEvaluationContext(this->keys_.first));
  DPF_ASSERT_OK_AND_ASSIGN(
      EvaluationContext ctx_2,
      this->dpf_->CreateEvaluationContext(this->keys_.second));
  DPF_ASSERT_OK_AND_ASSIGN(
      std::vector<TypeParam> output_1,
      this->dpf_->template EvaluateNext<TypeParam>({}, ctx_1));
  DPF_ASSERT_OK_AND_ASSIGN(
      std::vector<TypeParam> output_2,
      this->dpf_->template EvaluateNext<TypeParam>({}, ctx_2));

  EXPECT_EQ(output_1.size(), 1 << log_domain_size);
  EXPECT_EQ(output_2.size(), 1 << log_domain_size);
  for (int i = 0; i < (1 << log_domain_size); ++i) {
    TypeParam sum = output_1[i] + output_2[i];
    if (i == this->alpha_) {
      EXPECT_EQ(sum, this->beta_[0]);
    } else {
      EXPECT_EQ(sum, TypeParam{});
    }
  }
}

TYPED_TEST(DpfEvaluationTest, TestBatchSinglePointEvaluation) {
  // Set Up with a large output domain, to make sure this works.
  for (int log_domain_size : {0, 1, 2, 32, 128}) {
    absl::uint128 max_evaluation_point = absl::Uint128Max();
    if (log_domain_size < 128) {
      max_evaluation_point = (absl::uint128{1} << log_domain_size) - 1;
    }
    const absl::uint128 alpha = 23 & max_evaluation_point;
    this->SetUp(log_domain_size, alpha);
    for (int num_evaluation_points : {0, 1, 2, 100, 1000}) {
      std::vector<absl::uint128> evaluation_points(num_evaluation_points);
      for (int i = 0; i < num_evaluation_points; ++i) {
        evaluation_points[i] = i & max_evaluation_point;
      }
      DPF_ASSERT_OK_AND_ASSIGN(std::vector<TypeParam> output_1,
                               this->dpf_->template EvaluateAt<TypeParam>(
                                   this->keys_.first, 0, evaluation_points));
      DPF_ASSERT_OK_AND_ASSIGN(std::vector<TypeParam> output_2,
                               this->dpf_->template EvaluateAt<TypeParam>(
                                   this->keys_.second, 0, evaluation_points));
      ASSERT_EQ(output_1.size(), output_2.size());
      ASSERT_EQ(output_1.size(), num_evaluation_points);

      for (int i = 0; i < num_evaluation_points; ++i) {
        TypeParam sum = output_1[i] + output_2[i];
        if (evaluation_points[i] == this->alpha_) {
          EXPECT_EQ(sum, this->beta_[0])
              << "i=" << i << ", log_domain_size=" << log_domain_size;
        } else {
          EXPECT_EQ(sum, TypeParam{})
              << "i=" << i << ", log_domain_size=" << log_domain_size;
        }
      }
    }
  }
}

TYPED_TEST(DpfEvaluationTest, TestEvaluateAndApplySimpleAddition) {
  std::vector<std::vector<int>> parameters = {
      {0, 1, 2}, {8, 16, 32, 64}, {0, 128}, {128}, {/* filled below */}};
  for (int i = 0; i <= 128; ++i) {
    parameters.back().push_back(i);
  }
  for (const auto& log_domain_sizes : parameters) {
    absl::uint128 max_domain_element = absl::Uint128Max();
    if (log_domain_sizes.back() < 128) {
      max_domain_element = (absl::uint128{1} << log_domain_sizes.back()) - 1;
    }
    absl::uint128 alpha = max_domain_element;
    this->SetUp(log_domain_sizes, alpha);

    std::vector<absl::uint128> evaluation_points = {23, 42, 123, 0,
                                                    absl::Uint128Max()};
    for (auto& point : evaluation_points) {
      point &= max_domain_element;
    }
    std::vector<const DpfKey*> keys = {
        &(this->keys_.first), &(this->keys_.second), &(this->keys_.first),
        &(this->keys_.second), &(this->keys_.first)};
    int num_levels = log_domain_sizes.size();
    int num_keys = keys.size();

    std::vector<TypeParam> sum(num_keys, TypeParam{});
    int count = 0;
    auto fn = [&sum, &count](absl::Span<const TypeParam> values) {
      for (int i = 0; i < values.size(); ++i) {
        sum[i] += values[i];
      }
      ++count;
      return true;
    };

    // Run evaluation level-by-level to compute the expected sum.
    std::vector<TypeParam> expected(num_keys, TypeParam{});
    for (int hierarchy_level = 0; hierarchy_level < num_levels;
         ++hierarchy_level) {
      const int shift_amount =
          (log_domain_sizes.back() - log_domain_sizes[hierarchy_level]);
      for (int i = 0; i < num_keys; ++i) {
        absl::uint128 prefix = 0;
        if (shift_amount < 128) {
          prefix = evaluation_points[i] >> shift_amount;
        }
        DPF_ASSERT_OK_AND_ASSIGN(
            auto result,
            this->dpf_->template EvaluateAt<TypeParam>(
                *keys[i], hierarchy_level, absl::MakeConstSpan(&prefix, 1)));
        expected[i] += result[0];
      }
    }

    EXPECT_THAT(this->dpf_->template EvaluateAndApply<TypeParam>(
                    keys, evaluation_points, fn),
                IsOk());
    EXPECT_EQ(sum, expected)
        << "log_domain_sizes=" << absl::StrJoin(log_domain_sizes, " ");
    EXPECT_EQ(count, num_levels);
  }
}

TYPED_TEST(DpfEvaluationTest,
           EvaluateAndApplyFailsWithTooManyEvaluationPoints) {
  std::vector<absl::uint128> evaluation_points = {0, 1};

  EXPECT_THAT(
      this->dpf_->template EvaluateAndApply<TypeParam>(
          absl::MakeConstSpan(&(this->keys_.first), 1), evaluation_points,
          [](absl::Span<const TypeParam>) { return true; }),
      StatusIs(absl::StatusCode::kInvalidArgument,
               HasSubstr("evaluation_points")));
}

TYPED_TEST(DpfEvaluationTest, EvaluateAndApplyFailsWithInvalidKey) {
  DpfKey key;

  EXPECT_THAT(this->dpf_->template EvaluateAndApply<TypeParam>(
                  absl::MakeConstSpan(&key, 1), {0},
                  [](absl::Span<const TypeParam>) { return true; }),
              StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("key")));
}

}  // namespace
}  // namespace distributed_point_functions