chromium/chromeos/ash/services/libassistant/grpc/grpc_http_connection_client.cc

// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chromeos/ash/services/libassistant/grpc/grpc_http_connection_client.h"

#include "base/functional/bind.h"
#include "base/notreached.h"
#include "base/task/sequenced_task_runner.h"
#include "chromeos/ash/services/libassistant/grpc/grpc_client_thread.h"
#include "chromeos/ash/services/libassistant/grpc/grpc_http_connection_delegate.h"
#include "chromeos/assistant/internal/grpc_transport/streaming/bidi_streaming_rpc_call.h"
#include "chromeos/assistant/internal/grpc_transport/streaming/streaming_write_queue.h"
#include "third_party/grpc/src/include/grpc/grpc_security_constants.h"
#include "third_party/grpc/src/include/grpc/impl/codegen/grpc_types.h"
#include "third_party/grpc/src/include/grpcpp/create_channel.h"
#include "third_party/grpc/src/include/grpcpp/security/credentials.h"
#include "third_party/grpc/src/include/grpcpp/security/server_credentials.h"
#include "third_party/grpc/src/include/grpcpp/support/channel_arguments.h"

namespace ash::libassistant {

namespace {
using ::assistant::api::StreamHttpConnectionRequest;
using ::assistant::api::StreamHttpConnectionResponse;
using assistant_client::HttpConnection;
using ::chromeos::libassistant::BidiStreamingRpcCall;
using ::chromeos::libassistant::StreamingWriteQueue;
using ::chromeos::libassistant::StreamingWriter;

// Maps the method carried in a StreamHttpConnectionResponse onto the
// corresponding assistant_client::HttpConnection::Method.
//
// METHOD_UNSPECIFIED is not expected here: it triggers
// NOTREACHED_IN_MIGRATION() and falls back to GET. The same fallback covers any
// value not named by the cases below (a C++ enum object can hold such values,
// e.g. an unrecognized wire value), so control never flows off the end of this
// non-void function.
HttpConnection::Method ConvertToHttpConnectionMethod(
    StreamHttpConnectionResponse::Method method) {
  switch (method) {
    case StreamHttpConnectionResponse::GET:
      return HttpConnection::GET;
    case StreamHttpConnectionResponse::POST:
      return HttpConnection::POST;
    case StreamHttpConnectionResponse::HEAD:
      return HttpConnection::HEAD;
    case StreamHttpConnectionResponse::PATCH:
      return HttpConnection::PATCH;
    case StreamHttpConnectionResponse::PUT:
      return HttpConnection::PUT;
    case StreamHttpConnectionResponse::DELETE:
      return HttpConnection::DELETE;
    case StreamHttpConnectionResponse::METHOD_UNSPECIFIED:
      break;
  }
  // Reached for METHOD_UNSPECIFIED and for out-of-range values.
  NOTREACHED_IN_MIGRATION();
  return HttpConnection::GET;
}

// Ensures the enclosing member function runs on |task_runner_|'s sequence.
// Logs the enclosing function name at DVLOG(3). If the current sequence is not
// |task_runner_|'s (presumably the gRPC completion-queue thread for the RPC
// callbacks below), re-posts |method| to |task_runner_| bound to
// weak_factory_.GetWeakPtr() plus the given arguments, then returns from the
// enclosing function. Because the receiver is a WeakPtr, the re-posted task is
// dropped if this object is destroyed first. Bound arguments are stored by the
// callback, so reference parameters are copied.
// Expands to multiple statements: use it at the top of a function body, never
// as the body of an unbraced if/else.
#define ENSURE_CALLING_SEQUENCE(method, ...)                                \
  DVLOG(3) << __func__;                                                     \
  if (!task_runner_->RunsTasksInCurrentSequence()) {                        \
    task_runner_->PostTask(                                                 \
        FROM_HERE,                                                          \
        base::BindOnce(method, weak_factory_.GetWeakPtr(), ##__VA_ARGS__)); \
    return;                                                                 \
  }

}  // namespace

// Builds the gRPC channel and HttpConnectionService stub for |server_address|.
// The completion-queue thread is created here, and the current default
// sequenced task runner is captured as the sequence this client lives on.
GrpcHttpConnectionClient::GrpcHttpConnectionClient(
    assistant_client::HttpConnectionFactory* http_connection_factory,
    const std::string& server_address)
    : http_connection_factory_(http_connection_factory),
      cq_thread_(std::make_unique<GrpcClientThread>("http_connection_cq")),
      task_runner_(base::SequencedTaskRunner::GetCurrentDefault()) {
  grpc::ChannelArguments args;
  // Reconnect backoff: 200 ms initial and minimum, capped at 2000 ms.
  args.SetInt(GRPC_ARG_INITIAL_RECONNECT_BACKOFF_MS, 200);
  args.SetInt(GRPC_ARG_MIN_RECONNECT_BACKOFF_MS, 200);
  args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 2000);
  // Make sure to turn off compression.
  args.SetCompressionAlgorithm(grpc_compression_algorithm::GRPC_COMPRESS_NONE);

  // Local (same-host) credentials, with the connect type derived from the
  // address.
  const grpc_local_connect_type local_type =
      GetGrpcLocalConnectType(server_address);
  const auto credentials = grpc::experimental::LocalCredentials(local_type);

  channel_ = grpc::CreateCustomChannel(server_address, credentials, args);
  stub_ = ::assistant::api::HttpConnectionService::NewStub(channel_);
}

// Tears the client down on its owning sequence.
// 1. CleanUp() closes and forgets every tracked HTTP connection.
// 2. |is_shutting_down_| is raised under |write_queue_lock_|, which turns
//    ScheduleRequest() and OnRpcWriteAvailable() into no-ops.
// 3. If a stream exists, its write queue is dropped (under the lock), the call
//    is cancelled, and |cq_thread_| is destroyed.
// NOTE(review): destroying |cq_thread_| presumably stops/joins the
// completion-queue thread, so that no callback bound with
// base::Unretained(this) in Start() can run afterwards. Confirm in
// GrpcClientThread. The statement order appears deliberate; do not reorder
// without checking that.
GrpcHttpConnectionClient::~GrpcHttpConnectionClient() {
  DCHECK(task_runner_->RunsTasksInCurrentSequence());

  CleanUp();

  {
    base::AutoLock lock(write_queue_lock_);
    is_shutting_down_ = true;
  }

  if (call_) {
    {
      base::AutoLock lock(write_queue_lock_);
      write_queue_.reset();
    }

    // Cancel before destroying the cq thread so the stream is told to finish.
    call_->TryCancel();
    cq_thread_.reset();
  }
}

void GrpcHttpConnectionClient::Start() {
  DCHECK(task_runner_->RunsTasksInCurrentSequence());

  if (call_) {
    {
      base::AutoLock lock(write_queue_lock_);
      write_queue_.reset();
    }

    call_->TryCancel();
    call_.reset();
  }

  {
    base::AutoLock lock(write_queue_lock_);
    write_queue_ =
        std::make_unique<StreamingWriteQueue<StreamHttpConnectionRequest>>();
  }

  // Create a bidi streaming call to relay http connection from Libassistant.
  BidiStreamingRpcCall<StreamHttpConnectionRequest,
                       StreamHttpConnectionResponse>::CallbackParams cb_params;
  cb_params.write_available_cb = base::BindRepeating(
      &GrpcHttpConnectionClient::OnRpcWriteAvailable, base::Unretained(this));
  cb_params.read_available_cb = base::BindRepeating(
      &GrpcHttpConnectionClient::OnRpcReadAvailable, base::Unretained(this));
  cb_params.exited_cb = base::BindRepeating(
      &GrpcHttpConnectionClient::OnRpcExited, base::Unretained(this));
  call_ = std::make_unique<BidiStreamingRpcCall<StreamHttpConnectionRequest,
                                                StreamHttpConnectionResponse>>(
      std::move(cb_params));
  auto stream = stub_->PrepareAsyncStreamHttpConnection(
      call_->ctx(), cq_thread_->completion_queue());
  call_->Start(std::move(stream));
}

// Closes every tracked HTTP connection, then forgets all connections and
// their delegates. Must run on |task_runner_|'s sequence.
void GrpcHttpConnectionClient::CleanUp() {
  DCHECK(task_runner_->RunsTasksInCurrentSequence());

  // Call Close() on each remaining connection before dropping it.
  for (const auto& entry : http_connections_) {
    entry.second->Close();
  }

  http_connections_.clear();
  delegates_.clear();
}

// Queues |request| for sending on the current stream. The request is silently
// dropped once shutdown has begun or when no write queue exists.
// Synchronized by |write_queue_lock_|.
void GrpcHttpConnectionClient::ScheduleRequest(
    StreamHttpConnectionRequest request) {
  base::AutoLock lock(write_queue_lock_);
  if (is_shutting_down_ || !write_queue_) {
    return;
  }
  write_queue_->ScheduleWrite(std::move(request));
}

// Invoked through |write_available_cb| (wired up in Start()) when the stream
// can accept another write. The first write on each stream is a REGISTER
// request. After that, |writer| is handed to |write_queue_|. Does nothing once
// shutdown has begun.
void GrpcHttpConnectionClient::OnRpcWriteAvailable(
    grpc::ClientContext* context,
    StreamingWriter<StreamHttpConnectionRequest>* writer) {
  {
    base::AutoLock lock(write_queue_lock_);
    if (is_shutting_down_) {
      return;
    }
  }

  if (init_request_sent_) {
    // Registration already went out on this stream; let the queue write.
    base::AutoLock lock(write_queue_lock_);
    if (write_queue_) {
      write_queue_->OnRpcWriteAvailable(writer);
    }
    return;
  }

  // First write on this stream: signal readiness for streaming.
  DVLOG(1) << "Sending GrpcHttpConnectionClient registration request.";
  init_request_sent_ = true;
  StreamHttpConnectionRequest register_request;
  register_request.set_command(StreamHttpConnectionRequest::REGISTER);
  writer->Write(std::move(register_request));
}

// Handles one command read from the stream. Hops to |task_runner_|'s sequence
// first (the response is copied into the posted task).
// CREATE adds a new connection under |response.id()|. Every other command
// targets an existing connection and is ignored if that id is unknown.
void GrpcHttpConnectionClient::OnRpcReadAvailable(
    grpc::ClientContext* context,
    const StreamHttpConnectionResponse& response) {
  ENSURE_CALLING_SEQUENCE(&GrpcHttpConnectionClient::OnRpcReadAvailable,
                          context, response);

  DCHECK(response.has_id());
  const int http_connection_id = response.id();
  const auto iter = http_connections_.find(http_connection_id);
  if (iter == http_connections_.end() &&
      response.command() != StreamHttpConnectionResponse::CREATE) {
    DVLOG(2) << "Ignoring the HttpConnection request because the http "
                "connection does not exist.";
    return;
  }

  switch (response.command()) {
    case StreamHttpConnectionResponse::CREATE: {
      DVLOG(1) << "StreamHttpConnectionResponse::CREATE";
      if (iter != http_connections_.end()) {
        LOG(ERROR) << "Failed to create the http connection because of "
                      "duplicated id: "
                   << http_connection_id;
        return;
      }
      DVLOG(1) << "Create the http connection " << http_connection_id;
      // NOTE(review): ownership of |delegate| is not visible here; presumably
      // |delegates_| owns it (it is dropped from there on CLOSE/CleanUp()).
      // Confirm against the header and prefer std::make_unique if so.
      auto* delegate = new GrpcHttpConnectionDelegate(http_connection_id, this);
      auto* http_connection = http_connection_factory_->Create(delegate);
      http_connections_.insert({http_connection_id, http_connection});
      delegates_.insert({http_connection_id, delegate});
      break;
    }
    case StreamHttpConnectionResponse::START: {
      DVLOG(1) << "StreamHttpConnectionResponse::START";
      DCHECK(response.has_parameters());
      const auto& param = response.parameters();
      auto* http_connection = iter->second;
      http_connection->SetRequest(
          param.url(), ConvertToHttpConnectionMethod(param.method()));
      for (const auto& header : param.headers()) {
        http_connection->AddHeader(header.name(), header.value());
      }
      // Plain and chunked upload are mutually exclusive. Only the first branch
      // can observe both set; in the `else` branch the plain type is
      // necessarily empty.
      if (!param.upload_content_type().empty()) {
        DCHECK(param.chunked_upload_content_type().empty());
        http_connection->SetUploadContent(param.upload_content(),
                                          param.upload_content_type());
      } else if (!param.chunked_upload_content_type().empty()) {
        http_connection->SetChunkedUploadContentType(
            param.chunked_upload_content_type());
      }
      if (param.enable_header_response()) {
        http_connection->EnableHeaderResponse();
      }
      if (param.enable_partial_response()) {
        http_connection->EnablePartialResults();
      }
      http_connection->Start();
      break;
    }
    case StreamHttpConnectionResponse::PAUSE:
      DVLOG(1) << "StreamHttpConnectionResponse::PAUSE";
      iter->second->Pause();
      break;
    case StreamHttpConnectionResponse::RESUME:
      DVLOG(1) << "StreamHttpConnectionResponse::RESUME";
      iter->second->Resume();
      break;
    case StreamHttpConnectionResponse::CLOSE: {
      DVLOG(1) << "StreamHttpConnectionResponse::CLOSE";
      // Close the connection, then forget it and its delegate.
      iter->second->Close();
      http_connections_.erase(iter);

      const auto delegate_iter = delegates_.find(http_connection_id);
      DCHECK(delegate_iter != delegates_.end());
      delegates_.erase(delegate_iter);
      break;
    }
    case StreamHttpConnectionResponse::UPLOAD_DATA:
      DVLOG(1) << "StreamHttpConnectionResponse::UPLOAD_DATA";
      iter->second->UploadData(response.chunked_data().data(),
                               response.chunked_data().is_last_chunk());
      break;
    case StreamHttpConnectionResponse::COMMAND_UNSPECIFIED:
      NOTREACHED_IN_MIGRATION();
      break;
  }
}

// Invoked when the stream ends. Runs on |task_runner_|'s sequence (hopping
// there if needed). Clears the registration flag, schedules a restart when the
// stream failed, and always closes every connection that belonged to it.
void GrpcHttpConnectionClient::OnRpcExited(grpc::ClientContext* context,
                                           const grpc::Status& status) {
  ENSURE_CALLING_SEQUENCE(&GrpcHttpConnectionClient::OnRpcExited, context,
                          status);
  DVLOG(1) << "GrpcHttpConnectionClient streaming exited with status "
           << (status.ok() ? "ok" : status.error_message());

  // The next stream has to send its REGISTER request again.
  init_request_sent_ = false;

  if (status.ok()) {
    DVLOG(1) << "GrpcHttpConnection exited.";
  } else {
    // This class initiates the streaming connection, so it is the only party
    // able to repair a broken session. Libassistant (the server) cannot, which
    // is why the client must keep the connection healthy by restarting it.
    DVLOG(2) << "Retry to establish GrpcHttpConnection streaming session.";
    task_runner_->PostTask(FROM_HERE,
                           base::BindOnce(&GrpcHttpConnectionClient::Start,
                                          weak_factory_.GetWeakPtr()));
  }

  CleanUp();
}

}  // namespace ash::libassistant