// chromium/third_party/distributed_point_functions/code/dpf/internal/evaluate_prg_hwy.cc

// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "dpf/internal/evaluate_prg_hwy.h"

#include <algorithm>
#include <cstdint>
#include <memory>
#include <vector>

#include "absl/base/config.h"
#include "absl/base/optimization.h"
#include "absl/container/inlined_vector.h"
#include "absl/log/absl_check.h"
#include "absl/numeric/int128.h"
#include "absl/status/status.h"
#include "absl/types/span.h"
#include "dpf/aes_128_fixed_key_hash.h"
#include "dpf/status_macros.h"
#include "hwy/aligned_allocator.h"
#include "openssl/aes.h"

// clang-format off
// hwy/foreach_target.h re-includes the file named by HWY_TARGET_INCLUDE once
// per enabled SIMD target, so the macro must expand to this file's own path,
// spelled relative to the same include root as the other project includes.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "dpf/internal/evaluate_prg_hwy.cc"
#include "hwy/foreach_target.h"
// clang-format on

#include "dpf/internal/aes_128_fixed_key_hash_hwy.h"
#include "hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace distributed_point_functions {
namespace dpf_internal {
namespace HWY_NAMESPACE {

hn;

#if HWY_TARGET == HWY_SCALAR

// HWY_SCALAR build of EvaluateSeedsHwy: performs no work of its own and simply
// forwards every argument, unchanged and in the same order, to
// EvaluateSeedsNoHwy, returning that call's status.
//
// NOTE(review): this overload takes neither `num_correction_words` nor
// `paths_rightshift`, both of which the non-scalar EvaluateSeedsHwy in the
// `#else` branch below accepts. Confirm that whatever dispatches to
// EvaluateSeedsHwy (not visible here) is compatible with both signatures.
absl::Status EvaluateSeedsHwy(
    int64_t num_seeds, int num_levels, const absl::uint128* seeds_in,
    const bool* control_bits_in, const absl::uint128* paths,
    const absl::uint128* correction_seeds, const bool* correction_controls_left,
    const bool* correction_controls_right, const Aes128FixedKeyHash& prg_left,
    const Aes128FixedKeyHash& prg_right, absl::uint128* seeds_out,
    bool* control_bits_out) {
  absl::Status result = EvaluateSeedsNoHwy(
      num_seeds, num_levels, seeds_in, control_bits_in, paths,
      correction_seeds, correction_controls_left, correction_controls_right,
      prg_left, prg_right, seeds_out, control_bits_out);
  return result;
}

#else

// Intended contract: converts a bool array to a block-level mask suitable for
// vectors described by `d`. The mask value for each integer in the i-th block
// is set to input[i]. If `max_blocks > 0`, returns after reading `max_blocks`
// bools from `input`; the default of 0 means no such limit is requested.
//
// NOTE(review): the body is empty in this view, so as written the `auto`
// return type deduces to `void` and no mask is produced. The implementation
// needs to be supplied before this helper can be used.
// NOTE(review): a "block" is presumably one 128-bit (absl::uint128) element --
// confirm against the implementation.
template <typename D>
auto MaskFromBools(D d, const bool* input, int max_blocks = 0) {}

// Intended contract: converts a mask for types `d` to a bool array. Assumes
// that the mask value for all integers in the i-th block is equal, and writes
// that value to output[i]. If `max_blocks > 0`, returns after writing
// `max_blocks` bools to `output`; the default of 0 means no such limit is
// requested.
//
// NOTE(review): the body is empty in this view, so as written nothing is
// written to `output`. The implementation needs to be supplied before this
// helper can be used.
template <typename D, typename M>
void BoolsFromMask(D d, M mask, bool* output, int max_blocks = 0) {}

// Select operation on masks of type `M`.
// NOTE(review): judging by the name and parameters, this presumably yields
// `true_value` where `condition` is set and `false_value` elsewhere -- confirm
// against the implementation.
// NOTE(review): the body is empty in this view. As written, any call flows off
// the end of a function that returns `M`, which is undefined behavior; the
// implementation needs to be supplied before this helper can be used.
template <typename M>
M IfThenElseMask(M condition, M true_value, M false_value) {}

// Intended contract: returns a mask that is `true` on all blocks where
// `input[i] & (1 << index)` is nonzero. The mask is a 64-bit-level mask,
// suitable for AES hashing.
//
// NOTE(review): the body is empty in this view, so as written the `auto`
// return type deduces to `void` and no mask is produced. The implementation
// needs to be supplied before this helper can be used.
// NOTE(review): the valid range of `index` is not visible here -- presumably
// a bit position within one block; confirm against the implementation.
template <typename V, typename D>
auto IsBitSet(D d, const V input, int index) {}

// Empty tag type declared with HWY_ALIGN. It exists only to expose the
// HWY_ALIGN alignment as a number (e.g. through alignof(Aligned128)), for
// testing whether an array of absl::uint128 is aligned.
struct HWY_ALIGN Aligned128 {};

// EvaluateSeedsHwy for every Highway target other than HWY_SCALAR (this is
// the `#else` branch of the HWY_TARGET == HWY_SCALAR check).
//
// Compared with the HWY_SCALAR overload earlier in this file, this signature
// additionally takes `num_correction_words` (after `num_levels`) and
// `paths_rightshift` (after `paths`); the remaining parameters have the same
// names, types and relative order.
//
// NOTE(review): the body is empty in this view. As written, any call flows off
// the end of a function that returns absl::Status, which is undefined
// behavior; the implementation needs to be supplied.
// NOTE(review): presumably `seeds_in`/`control_bits_in`/`paths` hold
// `num_seeds` elements each and `seeds_out`/`control_bits_out` receive
// `num_seeds` results -- confirm against the header and the implementation.
absl::Status EvaluateSeedsHwy(
    int64_t num_seeds, int num_levels, int num_correction_words,
    const absl::uint128* seeds_in, const bool* control_bits_in,
    const absl::uint128* paths, int paths_rightshift,
    const absl::uint128* correction_seeds, const bool* correction_controls_left,
    const bool* correction_controls_right, const Aes128FixedKeyHash& prg_left,
    const Aes128FixedKeyHash& prg_right, absl::uint128* seeds_out,
    bool* control_bits_out) {}

#endif  // HWY_TARGET == HWY_SCALAR

}  // namespace HWY_NAMESPACE
}  // namespace dpf_internal
}  // namespace distributed_point_functions
HWY_AFTER_NAMESPACE();

#if HWY_ONCE || HWY_IDE
namespace distributed_point_functions {
namespace dpf_internal {

// Target-independent part of this translation unit.
// NOTE(review): no definitions are visible here. EvaluateSeedsNoHwy, which the
// HWY_SCALAR path calls, is not defined in the visible part of this file, and
// neither is any code that dispatches to the per-target EvaluateSeedsHwy.
// Presumably both belong in this namespace -- confirm against the complete
// file.

}  // namespace dpf_internal
}  // namespace distributed_point_functions
#endif