chromium/third_party/shell-encryption/src/sample_error_test.cc

/*
 * Copyright 2018 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "sample_error.h"

#include <cstdint>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "context.h"
#include "montgomery.h"
#include "symmetric_encryption.h"
#include "testing/parameters.h"
#include "testing/status_matchers.h"
#include "testing/status_testing.h"
#include "testing/testing_prng.h"

namespace {

using ::rlwe::testing::StatusIs;
using ::testing::HasSubstr;

const int kTestingRounds = 10;
const std::vector<rlwe::Uint64> variances = {8, 15, 29, 50};

// Typed fixture with no state of its own; each TYPED_TEST below is instantiated
// once for every type in rlwe::testing::ModularIntTypes.
template <typename T>
struct SampleErrorTest : ::testing::Test {};
TYPED_TEST_SUITE(SampleErrorTest, rlwe::testing::ModularIntTypes);

// For every context parameter set and every variance in `variances`, samples
// kTestingRounds error vectors. It checks that each vector has GetN()
// coefficients and that every coefficient, read as a centered representative
// modulo q, lies in [-2*variance, 2*variance].
TYPED_TEST(SampleErrorTest, CheckUpperBoundOnNoise) {
  using Int = typename TypeParam::Int;

  auto prng = absl::make_unique<rlwe::testing::TestingPrng>(0);

  for (const auto& params :
       rlwe::testing::ContextParameters<TypeParam>::Value()) {
    ASSERT_OK_AND_ASSIGN(auto context,
                         rlwe::RlweContext<TypeParam>::Create(params));
    // Exported values above q/2 stand for negative coefficients.
    const Int half_modulus = context->GetModulus() >> 1;

    for (auto variance : variances) {
      const rlwe::Uint64 bound = 2 * variance;
      for (int i = 0; i < kTestingRounds; i++) {
        ASSERT_OK_AND_ASSIGN(std::vector<TypeParam> error,
                             rlwe::SampleFromErrorDistribution<TypeParam>(
                                 context->GetN(), variance, prng.get(),
                                 context->GetModulusParams()));
        // Guard the iteration below: a short vector must fail loudly rather
        // than be silently under-checked.
        ASSERT_EQ(error.size(), static_cast<size_t>(context->GetN()));
        for (const auto& coefficient : error) {
          const Int reduced =
              coefficient.ExportInt(context->GetModulusParams());
          // Absolute value of the centered representative, computed in Int so
          // that small Int types are not promoted to a signed int here.
          Int magnitude = reduced;
          if (reduced > half_modulus) {
            magnitude = context->GetModulus() - reduced;
          }
          EXPECT_LE(magnitude, bound) << "variance = " << variance;
        }
      }
    }
  }
}

// A variance one above rlwe::kMaxVariance must be rejected for every context
// parameter set. The returned status must be kInvalidArgument, and its message
// must name both the requested variance and the maximum.
TYPED_TEST(SampleErrorTest, FailOnTooLargeVariance) {
  const rlwe::Uint64 too_large_variance = rlwe::kMaxVariance + 1;
  const auto expected_message =
      absl::StrCat("The variance, ", too_large_variance, ", must be at most ",
                   rlwe::kMaxVariance);

  auto prng = absl::make_unique<rlwe::testing::TestingPrng>(0);
  const auto& all_params = rlwe::testing::ContextParameters<TypeParam>::Value();
  for (const auto& params : all_params) {
    ASSERT_OK_AND_ASSIGN(auto context,
                         rlwe::RlweContext<TypeParam>::Create(params));

    auto sampled = rlwe::SampleFromErrorDistribution<TypeParam>(
        context->GetN(), too_large_variance, prng.get(),
        context->GetModulusParams());
    EXPECT_THAT(sampled, StatusIs(absl::StatusCode::kInvalidArgument,
                                  HasSubstr(expected_message)));
  }
}

}  // namespace