chromium/chromeos/ash/services/libassistant/grpc/grpc_http_connection_client_unittests.cc

// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chromeos/ash/services/libassistant/grpc/grpc_http_connection_client.h"

#include "base/memory/raw_ptr.h"
#include "base/run_loop.h"
#include "base/test/task_environment.h"
#include "chromeos/assistant/internal/grpc_transport/streaming/streaming_writer.h"
#include "chromeos/assistant/internal/libassistant/shared_headers.h"
#include "chromeos/assistant/internal/proto/shared/proto/v2/http_connection_interface.pb.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace ash::libassistant {

namespace {

using ::assistant::api::StreamHttpConnectionRequest;
using ::assistant::api::StreamHttpConnectionResponse;
using assistant_client::HttpConnection;

// Matches a protobuf argument whose serialized bytes equal those of `message`.
// NOTE(review): byte-wise comparison assumes both messages serialize
// deterministically (e.g. no map fields or unknown fields whose order could
// differ).
MATCHER_P(SerializedProtoEquals, message, "") {
  std::string expected_bytes;
  std::string actual_bytes;
  message.SerializeToString(&expected_bytes);
  arg.SerializeToString(&actual_bytes);
  return actual_bytes == expected_bytes;
}

// gMock implementation of assistant_client::HttpConnection. Besides mocking
// the connection API, it exposes Send*() helpers that let a test play the role
// of the network stack by invoking the delegate supplied at construction.
class MockHttpConnection : public HttpConnection {
 public:
  explicit MockHttpConnection(HttpConnection::Delegate* delegate)
      : delegate_(delegate) {}
  MockHttpConnection(const MockHttpConnection&) = delete;
  MockHttpConnection& operator=(const MockHttpConnection&) = delete;
  ~MockHttpConnection() override = default;

  // assistant_client::HttpConnection implementation:
  MOCK_METHOD(void,
              SetRequest,
              (const std::string& url, HttpConnection::Method),
              (override));
  MOCK_METHOD(void,
              AddHeader,
              (const std::string& name, const std::string& value),
              (override));
  MOCK_METHOD(void,
              SetUploadContent,
              (const std::string& content, const std::string& content_type),
              (override));
  MOCK_METHOD(void,
              SetChunkedUploadContentType,
              (const std::string& content_type),
              (override));
  MOCK_METHOD(void, EnableHeaderResponse, (), (override));
  MOCK_METHOD(void, EnablePartialResults, (), (override));
  MOCK_METHOD(void, Start, (), (override));
  MOCK_METHOD(void, Pause, (), (override));
  MOCK_METHOD(void, Resume, (), (override));
  MOCK_METHOD(void, Close, (), (override));
  MOCK_METHOD(void,
              UploadData,
              (const std::string& data, bool is_last_chunk),
              (override));

  // Simulated network events. Each one is forwarded verbatim to the delegate.
  void SendOnHeaderResponse(const std::string& raw_headers) {
    delegate_ref().OnHeaderResponse(raw_headers);
  }

  void SendOnPartialResponse(const std::string& partial_response) {
    delegate_ref().OnPartialResponse(partial_response);
  }

  void SendOnCompleteResponse(int http_status,
                              const std::string& raw_headers,
                              const std::string& response) {
    delegate_ref().OnCompleteResponse(http_status, raw_headers, response);
  }

  void SendOnNetworkError(int error_code, const std::string& message) {
    delegate_ref().OnNetworkError(error_code, message);
  }

  void SendOnConnectionDestroyed() {
    delegate_ref().OnConnectionDestroyed();
  }

 private:
  // The delegate passed to the constructor. It is dereferenced unchecked,
  // exactly as the Send*() helpers require.
  HttpConnection::Delegate& delegate_ref() { return *delegate_; }

  raw_ptr<HttpConnection::Delegate> delegate_;
};

// Connection factory that builds MockHttpConnection instances and keeps
// ownership of the most recently created one so tests can inspect it. Creating
// a new connection destroys the previous one.
class TestHttpConnectionFactory
    : public assistant_client::HttpConnectionFactory {
 public:
  TestHttpConnectionFactory() = default;
  TestHttpConnectionFactory(const TestHttpConnectionFactory&) = delete;
  TestHttpConnectionFactory& operator=(const TestHttpConnectionFactory&) =
      delete;
  ~TestHttpConnectionFactory() override = default;

  // assistant_client::HttpConnectionFactory implementation:
  HttpConnection* Create(HttpConnection::Delegate* delegate) override {
    auto created = std::make_unique<MockHttpConnection>(delegate);
    MockHttpConnection* const result = created.get();
    http_connection_ = std::move(created);
    return result;
  }

  // Returns the last connection handed out by Create(), or null if none.
  MockHttpConnection* http_connection() { return http_connection_.get(); }

 private:
  std::unique_ptr<MockHttpConnection> http_connection_;
};

// gMock StreamingWriter that records the StreamHttpConnectionRequest messages
// the client sends back to the service.
class MockStreamingWriter : public chromeos::libassistant::StreamingWriter<
                                StreamHttpConnectionRequest> {
 public:
  MockStreamingWriter() = default;
  // gMock members already make this type non-copyable; spell it out to match
  // the other test doubles in this file.
  MockStreamingWriter(const MockStreamingWriter&) = delete;
  MockStreamingWriter& operator=(const MockStreamingWriter&) = delete;
  ~MockStreamingWriter() override = default;

  // StreamingWriter implementation:
  MOCK_METHOD(void, Write, (StreamHttpConnectionRequest msg), (override));
  MOCK_METHOD(void, WritesDone, (), (override));
};

}  // namespace

class TestGrpcHttpConnectionService {
 public:
  explicit TestGrpcHttpConnectionService(GrpcHttpConnectionClient* client)
      : client_(client) {
    CreateWriter();
  }

  TestGrpcHttpConnectionService(const TestGrpcHttpConnectionService&) = delete;
  TestGrpcHttpConnectionService& operator=(
      const TestGrpcHttpConnectionService&) = delete;
  ~TestGrpcHttpConnectionService() = default;

  void SendCreateCommand() {
    StreamHttpConnectionResponse response;
    response.set_id(1);
    response.set_command(StreamHttpConnectionResponse::CREATE);
    WriteResponse(std::move(response));
  }

  void SendStartCommand(
      const std::string& url,
      StreamHttpConnectionResponse::Method method,
      const std::vector<std::pair<std::string, std::string>>& headers,
      const std::string& upload_content,
      const std::string& upload_content_type,
      bool enable_header_response,
      bool handle_partial_response) {
    StreamHttpConnectionResponse response;
    response.set_id(1);
    response.set_command(StreamHttpConnectionResponse::START);
    auto* parameters = response.mutable_parameters();
    parameters->set_url(url);
    parameters->set_method(method);
    for (const auto& header : headers) {
      auto* new_header = parameters->add_headers();
      new_header->set_name(header.first);
      new_header->set_value(header.second);
    }
    parameters->set_upload_content(upload_content);
    parameters->set_upload_content_type(upload_content_type);
    parameters->set_enable_header_response(enable_header_response);
    parameters->set_enable_partial_response(handle_partial_response);
    WriteResponse(std::move(response));
  }

  void SendPauseCommand() {
    StreamHttpConnectionResponse response;
    response.set_id(1);
    response.set_command(StreamHttpConnectionResponse::PAUSE);
    WriteResponse(std::move(response));
  }

  void SendResumeCommand() {
    StreamHttpConnectionResponse response;
    response.set_id(1);
    response.set_command(StreamHttpConnectionResponse::RESUME);
    WriteResponse(std::move(response));
  }

  void SendCloseCommand() {
    StreamHttpConnectionResponse response;
    response.set_id(1);
    response.set_command(StreamHttpConnectionResponse::CLOSE);
    WriteResponse(std::move(response));
  }

  void SendUploadDataCommand(const std::string& data, bool is_last_chunk) {
    StreamHttpConnectionResponse response;
    response.set_id(1);
    response.set_command(StreamHttpConnectionResponse::UPLOAD_DATA);
    auto* chunked_data = response.mutable_chunked_data();
    chunked_data->set_data(data);
    chunked_data->set_is_last_chunk(is_last_chunk);
    WriteResponse(std::move(response));
  }

  void SetWriteAvailable() {
    client_->OnRpcWriteAvailable(nullptr, writer_.get());
  }
  MockStreamingWriter& writer() { return *writer_; }
  void CreateWriter() { writer_ = std::make_unique<MockStreamingWriter>(); }
  void ResetWriter() { writer_.reset(); }

 private:
  void WriteResponse(StreamHttpConnectionResponse response) {
    client_->OnRpcReadAvailable(nullptr, response);
  }

  raw_ptr<GrpcHttpConnectionClient> client_;
  std::unique_ptr<MockStreamingWriter> writer_;
};

// Fixture that connects a GrpcHttpConnectionClient to a TestHttpConnectionFactory
// (producing MockHttpConnection objects) and to a TestGrpcHttpConnectionService
// that injects commands into the client and supplies a MockStreamingWriter.
class GrpcHttpConnectionClientTest : public testing::Test {
 public:
  GrpcHttpConnectionClientTest() = default;
  GrpcHttpConnectionClientTest(const GrpcHttpConnectionClientTest&) = delete;
  GrpcHttpConnectionClientTest& operator=(const GrpcHttpConnectionClientTest&) =
      delete;
  ~GrpcHttpConnectionClientTest() override = default;

  void SetUp() override {
    client_ = std::make_unique<GrpcHttpConnectionClient>(
        &http_connection_factory_,
        /*server_address=*/"unix:///tmp/test.socket");
    // The service keeps a raw pointer to `client_`, so it is created after it.
    service_ = std::make_unique<TestGrpcHttpConnectionService>(client_.get());
    client_->Start();
  }

 protected:
  // The connection most recently created by the client, or null if none.
  MockHttpConnection* http_connection() {
    return http_connection_factory_.http_connection();
  }

  base::test::SingleThreadTaskEnvironment environment_;
  // NOTE(review): members are destroyed in reverse declaration order, so
  // `client_` is destroyed before `service_`, which still holds a raw_ptr to
  // it. This may trip raw_ptr dangling-pointer checks. Reordering is not
  // obviously safe, though: `service_` also owns the mock writer handed to the
  // client, and destroying `service_` first would destroy that writer before
  // `client_`. Confirm against GrpcHttpConnectionClient's destructor.
  std::unique_ptr<TestGrpcHttpConnectionService> service_;
  // Declared before `client_` (which receives a pointer to it), so it is
  // destroyed after `client_`. Destroying it also destroys the mock connection,
  // which is when gMock verifies that connection's expectations.
  TestHttpConnectionFactory http_connection_factory_;
  std::unique_ptr<GrpcHttpConnectionClient> client_;
};

// A CREATE command should make the client create a connection through the
// factory. No CLOSE command is sent, yet Close() is expected, presumably from
// client teardown. It is verified when the factory destroys the mock after
// `client_` is gone.
TEST_F(GrpcHttpConnectionClientTest, CreateHttpConnection) {
  service_->SendCreateCommand();
  auto* connection = http_connection();
  ASSERT_TRUE(connection);
  EXPECT_CALL(*connection, Close());
}

// A START command should be translated into HttpConnection calls that apply
// each start parameter, followed by Start(). Each call is expected exactly
// once. Their relative order is not checked because no InSequence is used.
TEST_F(GrpcHttpConnectionClientTest, StartHttpConnection) {
  service_->SendCreateCommand();
  auto* connection = http_connection();
  ASSERT_TRUE(connection);

  const std::string url = "url";
  const auto method = StreamHttpConnectionResponse::POST;
  std::vector<std::pair<std::string, std::string>> headers;
  headers.push_back({"name", "value"});
  const std::string upload_content = "upload_content";
  const std::string upload_content_type = "upload_content_type";
  const bool enable_header_response = true;
  // Sent as `enable_partial_response`; expected to map to
  // EnablePartialResults().
  const bool handle_partial_response = true;

  EXPECT_CALL(*connection, SetRequest(url, HttpConnection::POST));
  EXPECT_CALL(*connection, AddHeader("name", "value"));
  EXPECT_CALL(*connection,
              SetUploadContent(upload_content, upload_content_type));
  EXPECT_CALL(*connection, EnableHeaderResponse());
  EXPECT_CALL(*connection, EnablePartialResults());
  EXPECT_CALL(*connection, Start());
  // Presumably from client teardown (no CLOSE command is sent).
  EXPECT_CALL(*connection, Close());
  service_->SendStartCommand(url, method, headers, upload_content,
                             upload_content_type, enable_header_response,
                             handle_partial_response);
}

// A PAUSE command should be forwarded as HttpConnection::Pause().
TEST_F(GrpcHttpConnectionClientTest, PauseHttpConnection) {
  service_->SendCreateCommand();
  auto* connection = http_connection();
  ASSERT_TRUE(connection);

  EXPECT_CALL(*connection, Pause());
  // Presumably from client teardown (no CLOSE command is sent).
  EXPECT_CALL(*connection, Close());
  service_->SendPauseCommand();
}

// A RESUME command should be forwarded as HttpConnection::Resume().
TEST_F(GrpcHttpConnectionClientTest, ResumeHttpConnection) {
  service_->SendCreateCommand();
  auto* connection = http_connection();
  ASSERT_TRUE(connection);

  EXPECT_CALL(*connection, Resume());
  // Presumably from client teardown (no CLOSE command is sent).
  EXPECT_CALL(*connection, Close());
  service_->SendResumeCommand();
}

// A CLOSE command should be forwarded as HttpConnection::Close().
TEST_F(GrpcHttpConnectionClientTest, CloseHttpConnection) {
  service_->SendCreateCommand();
  auto* connection = http_connection();
  ASSERT_TRUE(connection);

  // Exactly once over the whole test, so the connection must not be closed a
  // second time when the client is torn down.
  EXPECT_CALL(*connection, Close()).Times(1);
  service_->SendCloseCommand();
}

// An UPLOAD_DATA command should be forwarded as HttpConnection::UploadData()
// with the chunk payload and its last-chunk flag.
TEST_F(GrpcHttpConnectionClientTest, UploadData) {
  service_->SendCreateCommand();
  auto* connection = http_connection();
  ASSERT_TRUE(connection);

  const std::string data = "data";
  const bool is_last_chunk = true;
  EXPECT_CALL(*connection, UploadData(data, is_last_chunk));
  // Presumably from client teardown (no CLOSE command is sent).
  EXPECT_CALL(*connection, Close());
  service_->SendUploadDataCommand(data, is_last_chunk);
}

// The first time a writer becomes available, the client should write a
// REGISTER request (with no id set).
TEST_F(GrpcHttpConnectionClientTest, RegisterOnTheFirstWriteAvailable) {
  StreamHttpConnectionRequest request;
  request.set_command(StreamHttpConnectionRequest::REGISTER);
  EXPECT_CALL(service_->writer(), Write(SerializedProtoEquals(request)));
  // Will trigger registering client.
  service_->SetWriteAvailable();
}

// An OnHeaderResponse() notification from the connection should be written
// back to the service as a HANDLE_HEADER_RESPONSE request. The request carries
// the raw headers and id 1, which is the id the fake service uses for its
// commands.
TEST_F(GrpcHttpConnectionClientTest, ReceiveOnHeaderResponse) {
  StreamHttpConnectionRequest request;
  request.set_command(StreamHttpConnectionRequest::REGISTER);
  EXPECT_CALL(service_->writer(), Write(SerializedProtoEquals(request)));
  // Will trigger registering client.
  service_->SetWriteAvailable();

  service_->SendCreateCommand();
  auto* connection = http_connection();
  ASSERT_TRUE(connection);
  EXPECT_CALL(*connection, Close());

  // Reuse `request` to describe the next expected outgoing message. gMock
  // searches newer expectations first, so this one does not collide with the
  // already-satisfied REGISTER expectation.
  const std::string raw_headers = "raw_headers";
  request.Clear();
  request.set_id(1);
  request.set_command(StreamHttpConnectionRequest::HANDLE_HEADER_RESPONSE);
  request.set_raw_headers(raw_headers);
  EXPECT_CALL(service_->writer(), Write(SerializedProtoEquals(request)));
  // Offer the writer again, then fire the delegate callback.
  service_->SetWriteAvailable();
  connection->SendOnHeaderResponse(raw_headers);
  // Drain any tasks the client may have posted.
  base::RunLoop().RunUntilIdle();
}

// An OnPartialResponse() notification from the connection should be written
// back to the service as a HANDLE_PARTIAL_RESPONSE request carrying the
// partial payload and id 1.
TEST_F(GrpcHttpConnectionClientTest, ReceiveOnPartialResponse) {
  StreamHttpConnectionRequest request;
  request.set_command(StreamHttpConnectionRequest::REGISTER);
  EXPECT_CALL(service_->writer(), Write(SerializedProtoEquals(request)));
  // Will trigger registering client.
  service_->SetWriteAvailable();

  service_->SendCreateCommand();
  auto* connection = http_connection();
  ASSERT_TRUE(connection);
  EXPECT_CALL(*connection, Close());

  // Expected outgoing message for the partial response.
  const std::string partial_response = "partial_response";
  request.Clear();
  request.set_id(1);
  request.set_command(StreamHttpConnectionRequest::HANDLE_PARTIAL_RESPONSE);
  request.set_partial_response(partial_response);
  EXPECT_CALL(service_->writer(), Write(SerializedProtoEquals(request)));
  // Offer the writer again, then fire the delegate callback.
  service_->SetWriteAvailable();
  connection->SendOnPartialResponse(partial_response);
  // Drain any tasks the client may have posted.
  base::RunLoop().RunUntilIdle();
}

// An OnCompleteResponse() notification from the connection should be written
// back to the service as a HANDLE_COMPLETE_RESPONSE request. The request has id
// 1, and its `complete_response` holds the status code, headers and body.
TEST_F(GrpcHttpConnectionClientTest, ReceiveOnCompleteResponse) {
  StreamHttpConnectionRequest request;
  request.set_command(StreamHttpConnectionRequest::REGISTER);
  EXPECT_CALL(service_->writer(), Write(SerializedProtoEquals(request)));
  // Will trigger registering client.
  service_->SetWriteAvailable();

  service_->SendCreateCommand();
  auto* connection = http_connection();
  ASSERT_TRUE(connection);
  EXPECT_CALL(*connection, Close());

  // Expected outgoing message for the complete response.
  int http_status = 200;
  const std::string raw_headers = "raw_headers";
  const std::string response = "response";
  request.Clear();
  request.set_id(1);
  request.set_command(StreamHttpConnectionRequest::HANDLE_COMPLETE_RESPONSE);
  auto* res = request.mutable_complete_response();
  res->set_response_code(http_status);
  res->set_raw_headers(raw_headers);
  res->set_response(response);
  EXPECT_CALL(service_->writer(), Write(SerializedProtoEquals(request)));
  // Offer the writer again, then fire the delegate callback.
  service_->SetWriteAvailable();
  connection->SendOnCompleteResponse(http_status, raw_headers, response);
  // Drain any tasks the client may have posted.
  base::RunLoop().RunUntilIdle();
}

// An OnNetworkError() notification from the connection should be written back
// to the service as a HANDLE_NETWORK_ERROR request. The request has id 1, and
// its `error` holds the error code and message.
TEST_F(GrpcHttpConnectionClientTest, ReceiveOnNetworkError) {
  StreamHttpConnectionRequest request;
  request.set_command(StreamHttpConnectionRequest::REGISTER);
  EXPECT_CALL(service_->writer(), Write(SerializedProtoEquals(request)));
  // Will trigger registering client.
  service_->SetWriteAvailable();

  service_->SendCreateCommand();
  auto* connection = http_connection();
  ASSERT_TRUE(connection);
  EXPECT_CALL(*connection, Close());

  // Expected outgoing message for the network error.
  int error_code = 501;
  const std::string message = "message";
  request.Clear();
  request.set_id(1);
  request.set_command(StreamHttpConnectionRequest::HANDLE_NETWORK_ERROR);
  auto* error = request.mutable_error();
  error->set_error_code(error_code);
  error->set_error_message(message);
  EXPECT_CALL(service_->writer(), Write(SerializedProtoEquals(request)));
  // Offer the writer again, then fire the delegate callback.
  service_->SetWriteAvailable();
  connection->SendOnNetworkError(error_code, message);
  // Drain any tasks the client may have posted.
  base::RunLoop().RunUntilIdle();
}

// The client must not crash when the writer it was given is destroyed and
// pending tasks then run. A header response is produced while no writer is
// available. The writer is then offered, which is expected to flush the pending
// request, and destroyed immediately afterwards.
TEST_F(GrpcHttpConnectionClientTest, NotCrashWhenWriterGone) {
  StreamHttpConnectionRequest request;
  request.set_command(StreamHttpConnectionRequest::REGISTER);
  EXPECT_CALL(service_->writer(), Write(SerializedProtoEquals(request)));
  // Will trigger registering client.
  service_->SetWriteAvailable();
  base::RunLoop().RunUntilIdle();

  service_->SendCreateCommand();
  auto* connection = http_connection();
  ASSERT_TRUE(connection);
  EXPECT_CALL(*connection, Close());

  const std::string raw_headers = "raw_headers";
  request.Clear();
  request.set_id(1);
  request.set_command(StreamHttpConnectionRequest::HANDLE_HEADER_RESPONSE);
  request.set_raw_headers(raw_headers);
  // Set the expectation before triggering the delegate callback. gMock
  // requires expectations to be in place before the calls they are meant to
  // match, and this also matches the order used by the other tests.
  EXPECT_CALL(service_->writer(), Write(SerializedProtoEquals(request)));
  connection->SendOnHeaderResponse(raw_headers);

  // Offer the writer, then destroy it (the service's writer pointer becomes
  // null). Running pending tasks afterwards should not crash.
  service_->SetWriteAvailable();
  service_->ResetWriter();
  base::RunLoop().RunUntilIdle();
}

}  // namespace ash::libassistant