chromium/third_party/distributed_point_functions/fuzz/dpf_fuzzer.cc

// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "third_party/distributed_point_functions/code/dpf/distributed_point_function.h"

#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

#include <fuzzer/FuzzedDataProvider.h>

#include <algorithm>
#include <memory>
#include <utility>
#include <vector>

// Aborts the process with a diagnostic naming the enclosing function, file and
// line when `x` evaluates to false.
//
// The body is wrapped in `do { ... } while (0)` so that the macro expands to a
// single statement and can be used safely in an unbraced `if`/`else`. The
// message is written to stderr because `abort()` is not required to flush
// buffered stdout output, which could otherwise drop the diagnostic.
#define DPF_FUZZER_ASSERT(x)                                             \
  do {                                                                   \
    if (!(x)) {                                                          \
      fprintf(stderr,                                                    \
              "DPF assertion failed: function %s, file %s, line %d.\n",  \
              __PRETTY_FUNCTION__, __FILE__, __LINE__);                  \
      abort();                                                           \
    }                                                                    \
  } while (0)

namespace {

// Combined size in bytes of the two 64-bit halves that make up one 128-bit
// value. Used to check that enough fuzzer input remains before a 128-bit value
// is consumed.
constexpr size_t UINT128_SIZE = sizeof(uint64_t) * 2;

// Builds an `absl::uint128` from two 64-bit unsigned integers taken from
// `data_provider`. The first integer consumed is passed as the high half and
// the second one as the low half.
absl::uint128 ConsumeUint128(FuzzedDataProvider& data_provider) {
  // Two separate statements keep the consumption order well defined; the
  // evaluation order of function arguments would be unspecified.
  const uint64_t upper_half = data_provider.ConsumeIntegral<uint64_t>();
  const uint64_t lower_half = data_provider.ConsumeIntegral<uint64_t>();
  return absl::MakeUint128(upper_half, lower_half);
}

// Computes the prefix of `index` that belongs to the domain of
// `hierarchy_level` by shifting out the bits by which the log domain size of
// the last entry of `parameters` exceeds the one of `hierarchy_level`. Yields 0
// when 128 or more bits would have to be shifted out.
// Adapted from `DpfEvaluationTest::GetPrefixForLevel()`.
absl::uint128 GetPrefixForLevel(
    int hierarchy_level,
    absl::uint128 index,
    const std::vector<distributed_point_functions::DpfParameters>& parameters) {
  const int last_log_domain_size = parameters.back().log_domain_size();
  const int level_log_domain_size =
      parameters[hierarchy_level].log_domain_size();
  const int bits_to_drop = last_log_domain_size - level_log_domain_size;

  // A shift by the full width (or more) of a 128-bit value is not performed;
  // the prefix is defined to be zero in that case.
  if (bits_to_drop >= 128)
    return 0;
  return index >> bits_to_drop;
}

// Evaluates both contexts `ctx0` and `ctx1` at `hierarchy level`, using the
// appropriate prefixes of `evaluation_points`. Checks that the expansion of
// both keys from correct DPF shares, i.e., they add up to
// `beta[ctx.hierarchy_level()]` under prefixes of `alpha`, and to 0 otherwise.
// Adapted from `DpfEvaluationTest::EvaluateAndCheckLevel()`.
template <typename T>
void EvaluateAndCheckLevel(
    int hierarchy_level,
    absl::Span<const absl::uint128> evaluation_points,
    absl::uint128 alpha,
    const std::vector<absl::uint128>& beta,
    distributed_point_functions::EvaluationContext& ctx0,
    distributed_point_functions::EvaluationContext& ctx1,
    const std::vector<distributed_point_functions::DpfParameters>& parameters,
    const distributed_point_functions::DistributedPointFunction& dpf) {
  int previous_hierarchy_level = ctx0.previous_hierarchy_level();
  int current_log_domain_size = parameters[hierarchy_level].log_domain_size();
  int previous_log_domain_size = 0;
  int num_expansions = 1;
  bool is_first_evaluation = previous_hierarchy_level < 0;
  // Generate prefixes if we're not on the first level.
  std::vector<absl::uint128> prefixes;
  if (!is_first_evaluation) {
    num_expansions = static_cast<int>(evaluation_points.size());
    prefixes.resize(evaluation_points.size());
    previous_log_domain_size =
        parameters[previous_hierarchy_level].log_domain_size();
    for (int i = 0; i < static_cast<int>(evaluation_points.size()); ++i)
      prefixes[i] = GetPrefixForLevel(previous_hierarchy_level,
                                      evaluation_points[i], parameters);
  }

  // Evaluating a key with N correction words leads to an O(2^N) malloc, which
  // will unsurprisingly cause a fuzzer crash. See <https://crbug.com/1494260>.
  constexpr size_t kMaxCorrectionWords = 30;
  if (ctx0.key().correction_words().size() > kMaxCorrectionWords) {
    return;
  }
  absl::StatusOr<std::vector<T>> result_0 =
      dpf.EvaluateUntil<T>(hierarchy_level, prefixes, ctx0);
  DPF_FUZZER_ASSERT(result_0.ok());
  if (ctx1.key().correction_words().size() > kMaxCorrectionWords) {
    return;
  }
  absl::StatusOr<std::vector<T>> result_1 =
      dpf.EvaluateUntil<T>(hierarchy_level, prefixes, ctx1);
  DPF_FUZZER_ASSERT(result_1.ok());

  DPF_FUZZER_ASSERT(result_0->size() == result_1->size());
  int64_t outputs_per_prefix =
      int64_t{1} << (current_log_domain_size - previous_log_domain_size);
  int64_t expected_output_size = num_expansions * outputs_per_prefix;
  DPF_FUZZER_ASSERT(static_cast<int64_t>(result_0->size()) ==
                    expected_output_size);

  // Iterator over the outputs and check that they sum up to 0 or to
  // `beta[current_hierarchy_level]`;
  absl::uint128 previous_alpha_prefix = 0;
  if (!is_first_evaluation)
    previous_alpha_prefix =
        GetPrefixForLevel(previous_hierarchy_level, alpha, parameters);

  absl::uint128 current_alpha_prefix =
      GetPrefixForLevel(hierarchy_level, alpha, parameters);
  for (int64_t i = 0; i < expected_output_size; ++i) {
    int prefix_index = i / outputs_per_prefix;
    int prefix_expansion_index = i % outputs_per_prefix;
    // The output is on the path to `alpha`, if we're at the first level or
    // under a prefix of `alpha`, and the current block in the expansion of the
    // prefix is also on the path to `alpha`.
    if ((is_first_evaluation ||
         prefixes[prefix_index] == previous_alpha_prefix) &&
        prefix_expansion_index == (current_alpha_prefix % outputs_per_prefix)) {
      // We need to static_cast here since otherwise operator+ returns an
      // unsigned int without doing a modular reduction, which causes the test
      // to fail on types with sizeof(T) < sizeof(unsigned).
      DPF_FUZZER_ASSERT(
          absl::uint128{static_cast<T>((*result_0)[i] + (*result_1)[i])} ==
          beta[hierarchy_level]);
    } else {
      DPF_FUZZER_ASSERT(static_cast<T>((*result_0)[i] + (*result_1)[i]) == 0U);
    }
  }
}

}  // namespace

extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
  // Use magic separator to split the input into two parts. The first part will
  // generate alpha, and an array of parameters and betas. The second part will
  // generate level step and an array of evaluation points.
  const uint8_t separator[] = {0xDE, 0xAD, 0xBE, 0xEF};

  const uint8_t* pos =
      std::search(data, data + size, separator, separator + sizeof(separator));

  const uint8_t* data1 = data;
  size_t size1 = pos - data;

  const uint8_t* data2 =
      (pos == data + size) ? nullptr : pos + sizeof(separator);
  size_t size2 = data2 ? (data + size) - (pos + sizeof(separator)) : 0;

  FuzzedDataProvider data_provider1(data1, size1);

  if (data_provider1.remaining_bytes() < UINT128_SIZE)
    return 0;

  absl::uint128 alpha = ConsumeUint128(data_provider1);

  std::vector<int32_t> log_domain_sizes;
  std::vector<int32_t> element_bitsizes;
  std::vector<distributed_point_functions::DpfParameters> parameters;
  std::vector<absl::uint128> beta;

  // log_domain_size(int32_t), element_bitsize(int32_t),
  // beta(uint128)
  while (data_provider1.remaining_bytes() >=
         (2 * sizeof(int32_t) + UINT128_SIZE)) {
    int32_t log_domain_size = data_provider1.ConsumeIntegral<int32_t>();
    int32_t element_bitsize = data_provider1.ConsumeIntegral<int32_t>();
    log_domain_sizes.push_back(log_domain_size);
    element_bitsizes.push_back(element_bitsize);

    distributed_point_functions::DpfParameters parameter;
    parameter.set_log_domain_size(log_domain_size);
    parameter.mutable_value_type()->mutable_integer()->set_bitsize(
        element_bitsize);
    parameters.push_back(parameter);

    beta.push_back(ConsumeUint128(data_provider1));
  }

  absl::StatusOr<
      std::unique_ptr<distributed_point_functions::DistributedPointFunction>>
      status_or_dpf = distributed_point_functions::DistributedPointFunction::
          CreateIncremental(parameters);

  size_t num_levels = parameters.size();

  if (!status_or_dpf.ok()) {
    // `log_domain_size` is expected to be in ascending order and
    // `element_bitsize` is expected to be non-decreasing. As it is hard for the
    // fuzzer to land upon a valid input, we sort the parameters and try again
    // if the construction fails.
    std::sort(log_domain_sizes.begin(), log_domain_sizes.end());
    std::sort(element_bitsizes.begin(), element_bitsizes.end());
    for (size_t i = 0; i < num_levels; ++i) {
      parameters[i].set_log_domain_size(log_domain_sizes[i]);
      parameters[i].mutable_value_type()->mutable_integer()->set_bitsize(
          element_bitsizes[i]);
    }

    status_or_dpf = distributed_point_functions::DistributedPointFunction::
        CreateIncremental(parameters);
  }

  if (!status_or_dpf.ok())
    return 0;

  std::unique_ptr<distributed_point_functions::DistributedPointFunction> dpf =
      std::move(status_or_dpf).value();

  absl::StatusOr<std::pair<distributed_point_functions::DpfKey,
                           distributed_point_functions::DpfKey>>
      status_or_keys = dpf->GenerateKeysIncremental(alpha, beta);
  if (!status_or_keys.ok())
    return 0;

  std::pair<distributed_point_functions::DpfKey,
            distributed_point_functions::DpfKey>
      keys = std::move(status_or_keys).value();

  // Adapted from `DpfEvaluationTest.TestCorrectness()`.
  absl::StatusOr<distributed_point_functions::EvaluationContext>
      status_or_ctx0 = dpf->CreateEvaluationContext(keys.first);
  DPF_FUZZER_ASSERT(status_or_ctx0.ok());

  absl::StatusOr<distributed_point_functions::EvaluationContext>
      status_or_ctx1 = dpf->CreateEvaluationContext(keys.second);
  DPF_FUZZER_ASSERT(status_or_ctx1.ok());

  distributed_point_functions::EvaluationContext ctx0 =
      std::move(status_or_ctx0).value();
  distributed_point_functions::EvaluationContext ctx1 =
      std::move(status_or_ctx1).value();

  // Generate evaluation points.
  FuzzedDataProvider data_provider2(data2, size2);
  if (data_provider2.remaining_bytes() < sizeof(int))
    return 0;

  int level_step = data_provider2.ConsumeIntegralInRange<int>(1, 10);

  std::vector<absl::uint128> evaluation_points;
  while (data_provider2.remaining_bytes() >= UINT128_SIZE) {
    evaluation_points.push_back(ConsumeUint128(data_provider2));
    if (parameters.back().log_domain_size() < 128)
      evaluation_points.back() %=
          (absl::uint128{1} << parameters.back().log_domain_size());
  }

  // Always evaluate on alpha.
  evaluation_points.push_back(alpha);

  int32_t previous_log_domain_size = 0;
  for (int i = level_step - 1; i < static_cast<int>(num_levels);
       i += level_step) {
    // If any gap in the log_domain_sizes used in successive evaluations is
    // larger than 62, validation will fail in `EvaluateAndCheckLevel`.
    int32_t current_log_domain_size = parameters[i].log_domain_size();
    if (current_log_domain_size - previous_log_domain_size > 62)
      return 0;
    previous_log_domain_size = current_log_domain_size;

    switch (parameters[i].value_type().integer().bitsize()) {
      case 8:
        EvaluateAndCheckLevel<uint8_t>(i, evaluation_points, alpha, beta, ctx0,
                                       ctx1, parameters, *dpf);
        break;
      case 16:
        EvaluateAndCheckLevel<uint16_t>(i, evaluation_points, alpha, beta, ctx0,
                                        ctx1, parameters, *dpf);
        break;
      case 32:
        EvaluateAndCheckLevel<uint32_t>(i, evaluation_points, alpha, beta, ctx0,
                                        ctx1, parameters, *dpf);
        break;
      case 64:
        EvaluateAndCheckLevel<uint64_t>(i, evaluation_points, alpha, beta, ctx0,
                                        ctx1, parameters, *dpf);
        break;
      case 128:
        EvaluateAndCheckLevel<absl::uint128>(i, evaluation_points, alpha, beta,
                                             ctx0, ctx1, parameters, *dpf);
        break;
      default:
        // DPF construction should've failed if the parameters were invalid.
        DPF_FUZZER_ASSERT(false);
        break;
    }
  }

  return 0;
}