chromium/components/policy/test_support/request_handler_for_psm_auto_enrollment.cc

// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "components/policy/test_support/request_handler_for_psm_auto_enrollment.h"

#include <vector>

#include "base/files/file_path.h"
#include "base/files/file_util.h"
#include "base/path_service.h"
#include "components/policy/core/common/cloud/cloud_policy_constants.h"
#include "components/policy/proto/device_management_backend.pb.h"
#include "components/policy/test_support/test_server_helpers.h"
#include "net/http/http_status_code.h"
#include "net/test/embedded_test_server/http_request.h"
#include "net/test/embedded_test_server/http_response.h"
#include "third_party/private_membership/src/internal/testing/regression_test_data/regression_test_data.pb.h"

using net::test_server::HttpRequest;
using net::test_server::HttpResponse;

namespace em = enterprise_management;
using RlweTestData =
    private_membership::rlwe::PrivateMembershipRlweClientRegressionTestData;

namespace policy {
namespace {

// Returns a pointer to the first test case in `test_data` whose expected OPRF
// request serializes to exactly the same bytes as `request`, or nullptr if no
// test case matches. The returned pointer refers into `test_data`.
const RlweTestData::TestCase* FindOprfTestCase(
    const RlweTestData& test_data,
    const private_membership::rlwe::PrivateMembershipRlweOprfRequest& request) {
  // The incoming request does not change while scanning, so serialize it once
  // instead of once per candidate test case.
  const std::string serialized_request = request.SerializeAsString();
  for (const auto& test_case : test_data.test_cases()) {
    if (serialized_request ==
        test_case.expected_oprf_request().SerializeAsString()) {
      return &test_case;
    }
  }
  return nullptr;
}

// Returns a pointer to the first test case in `test_data` whose expected query
// request serializes to exactly the same bytes as `request`, or nullptr if no
// test case matches. The returned pointer refers into `test_data`.
const RlweTestData::TestCase* FindQueryTestCase(
    const RlweTestData& test_data,
    const private_membership::rlwe::PrivateMembershipRlweQueryRequest&
        request) {
  // The incoming request does not change while scanning, so serialize it once
  // instead of once per candidate test case.
  const std::string serialized_request = request.SerializeAsString();
  for (const auto& test_case : test_data.test_cases()) {
    if (serialized_request ==
        test_case.expected_query_request().SerializeAsString()) {
      return &test_case;
    }
  }
  return nullptr;
}

}  // namespace

// static
// Loads the PSM RLWE regression test data from
// third_party/private_membership/src/internal/testing/regression_test_data/
// test_data.binarypb, resolved relative to base::DIR_SRC_TEST_DATA_ROOT.
//
// If the file does not exist, a warning is logged and an empty RlweTestData is
// returned. CHECK-fails if the root directory cannot be resolved, or if an
// existing file cannot be read or parsed.
std::unique_ptr<RlweTestData>
RequestHandlerForPsmAutoEnrollment::LoadTestData() {
  base::FilePath src_root_dir;
  CHECK(base::PathService::Get(base::DIR_SRC_TEST_DATA_ROOT, &src_root_dir));

  // Directory that holds the regression test data, then the file inside it.
  const base::FilePath regression_data_dir =
      src_root_dir.AppendASCII("third_party")
          .AppendASCII("private_membership")
          .AppendASCII("src")
          .AppendASCII("internal")
          .AppendASCII("testing")
          .AppendASCII("regression_test_data");
  const base::FilePath path_to_test_data =
      regression_data_dir.AppendASCII("test_data.binarypb");

  // The file system calls below block, so allow blocking for this scope.
  base::ScopedAllowBlockingForTesting allow_blocking;

  auto test_data = std::make_unique<RlweTestData>();
  if (base::PathExists(path_to_test_data)) {
    std::string serialized_test_data;
    CHECK(base::ReadFileToString(path_to_test_data, &serialized_test_data));
    CHECK(test_data->ParseFromString(serialized_test_data));
  } else {
    // A missing file is tolerated: the caller gets an empty data set.
    LOG(WARNING) << "Path to psm test data does not exist (this is expected in "
                    "tast tests, but not in unit tests: "
                 << path_to_test_data;
  }
  return test_data;
}

// Constructs the handler, forwarding `parent` to the
// EmbeddedPolicyTestServer::RequestHandler base class constructor.
RequestHandlerForPsmAutoEnrollment::RequestHandlerForPsmAutoEnrollment(
    EmbeddedPolicyTestServer* parent)
    : EmbeddedPolicyTestServer::RequestHandler(parent) {}

// Defaulted destructor, defined out of line in this translation unit.
RequestHandlerForPsmAutoEnrollment::~RequestHandlerForPsmAutoEnrollment() =
    default;

// Returns the request type string this handler is associated with:
// dm_protocol::kValueRequestPsmHasDeviceState.
std::string RequestHandlerForPsmAutoEnrollment::RequestType() {
  std::string request_type = dm_protocol::kValueRequestPsmHasDeviceState;
  return request_type;
}

// Answers a PSM (private set membership) request by looking it up in the
// regression test data and replying with the recorded response.
//
// `request.content` is parsed as a serialized em::DeviceManagementRequest. Its
// RLWE request must carry either an OPRF request or a query request that
// matches one of the test cases byte-for-byte after serialization.
//
// Returns HTTP 200 with a serialized em::DeviceManagementResponse on a match.
// Returns HTTP 400 if the body cannot be parsed, if no test case matches the
// request, or if neither an OPRF nor a query request is present.
std::unique_ptr<HttpResponse> RequestHandlerForPsmAutoEnrollment::HandleRequest(
    const HttpRequest& request) {
  // Load the test data lazily, on the first request handled.
  if (!test_data_) {
    test_data_ = LoadTestData();
  }

  em::DeviceManagementRequest device_management_request;
  // Reject malformed bodies explicitly rather than continuing with whatever
  // was left in the message after a failed parse.
  if (!device_management_request.ParseFromString(request.content)) {
    return CreateHttpResponse(net::HTTP_BAD_REQUEST,
                              "Failed to parse DeviceManagementRequest");
  }

  em::DeviceManagementResponse device_management_response;
  em::PrivateSetMembershipResponse* psm_response =
      device_management_response.mutable_private_set_membership_response();
  const auto& rlwe_request =
      device_management_request.private_set_membership_request().rlwe_request();
  if (rlwe_request.has_oprf_request()) {
    const auto* test_case =
        FindOprfTestCase(*test_data_, rlwe_request.oprf_request());
    if (!test_case) {
      return CreateHttpResponse(net::HTTP_BAD_REQUEST,
                                "PSM RLWE OPRF request not as expected");
    }
    // Reply with the OPRF response recorded for the matching test case.
    *psm_response->mutable_rlwe_response()->mutable_oprf_response() =
        test_case->oprf_response();
  } else if (rlwe_request.has_query_request()) {
    const auto* test_case =
        FindQueryTestCase(*test_data_, rlwe_request.query_request());
    if (!test_case) {
      return CreateHttpResponse(net::HTTP_BAD_REQUEST,
                                "PSM RLWE query request not as expected");
    }
    // Reply with the query response recorded for the matching test case.
    *psm_response->mutable_rlwe_response()->mutable_query_response() =
        test_case->query_response();
  } else {
    return CreateHttpResponse(
        net::HTTP_BAD_REQUEST,
        "PSM RLWE oprf_request, or query_request fields must be filled");
  }

  return CreateHttpResponse(net::HTTP_OK,
                            device_management_response.SerializeAsString());
}

}  // namespace policy