chromium/chromeos/ash/services/libassistant/grpc/grpc_state.h

// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef CHROMEOS_ASH_SERVICES_LIBASSISTANT_GRPC_GRPC_STATE_H_
#define CHROMEOS_ASH_SERVICES_LIBASSISTANT_GRPC_GRPC_STATE_H_

#include <memory>
#include <string>
#include <utility>

#include "base/check.h"
#include "base/location.h"
#include "base/logging.h"
#include "base/memory/raw_ptr.h"
#include "base/task/sequenced_task_runner.h"
#include "chromeos/ash/services/libassistant/grpc/grpc_util.h"
#include "chromeos/assistant/internal/grpc_transport/grpc_client_cq_tag.h"
#include "third_party/grpc/src/include/grpcpp/client_context.h"
#include "third_party/grpc/src/include/grpcpp/generic/generic_stub.h"
#include "third_party/grpc/src/include/grpcpp/grpcpp.h"
#include "third_party/grpc/src/include/grpcpp/impl/codegen/client_context.h"
#include "third_party/grpc/src/include/grpcpp/support/status.h"

namespace google {
namespace protobuf {
class MessageLite;
}  // namespace protobuf
}  // namespace google

namespace ash::libassistant {

// Options controlling how an RPCState instance issues its call: the retry
// budget, the timeout applied to each call attempt, and whether the call waits
// for the channel to become ready.
struct StateConfig {
  StateConfig() = default;
  StateConfig(int32_t retry_limit, int64_t deadline_ms)
      : max_retries{retry_limit}, timeout_in_ms{deadline_ms} {}
  ~StateConfig() = default;

  // Upper bound on the number of additional attempts made after a failed call.
  int32_t max_retries{0};

  // Deadline, in milliseconds, applied to the client call.
  int64_t timeout_in_ms{2000};

  // When true, the RPC is queued instead of failing fast while the channel is
  // in TRANSIENT_FAILURE or CONNECTING state, and proceeds once the channel is
  // READY. When false, such RPCs fail immediately.
  bool wait_for_ready{true};
};

// Object allocated per active RPC.
// Manages the state of a single asynchronous unary RPC request. The instance
// owns itself: once the final result is known it posts the result callback and
// deletes itself (see StateDone()), so callers must neither keep nor delete the
// pointer returned by `new`.
//
// If `max_retries` is greater than 0, a call that finishes with a non-OK status
// other than CANCELLED is retried until the retry budget is exhausted. Every
// attempt gets a fresh ClientContext whose deadline is computed from
// `timeout_in_ms`, i.e. the timeout applies per attempt, not to the RPC as a
// whole.
template <class Response>
class RPCState : public chromeos::libassistant::GrpcClientCQTag {
 public:
  // Serializes `request` and starts the first call attempt of `method` on `cq`.
  // `done` is posted to `callback_task_runner` once with the final status and
  // the (possibly default-constructed) response.
  //
  // NOTE(review): if serializing `request` fails, the callback is posted and
  // the object deletes itself before the constructor returns.
  RPCState(std::shared_ptr<grpc::Channel> channel,
           grpc::CompletionQueue* cq,
           const grpc::string& method,
           const google::protobuf::MessageLite& request,
           ResponseCallback<grpc::Status, Response> done,
           scoped_refptr<base::SequencedTaskRunner> callback_task_runner,
           StateConfig config)
      : async_cb_(std::move(done)),
        callback_task_runner_(std::move(callback_task_runner)),
        cq_(cq),
        stub_(std::move(channel)),
        method_(method),
        timeout_in_ms_(config.timeout_in_ms),
        // A negative retry count would wrap around to a huge size_t and retry
        // effectively forever; treat it as "no retries" instead.
        max_retries_(config.max_retries > 0
                         ? static_cast<size_t>(config.max_retries)
                         : size_t{0}),
        wait_for_ready_(config.wait_for_ready) {
    DCHECK(cq);
    DCHECK(callback_task_runner_);

    grpc::Status s = GrpcSerializeProto(request, &request_buf_);
    if (!s.ok()) {
      LOG(ERROR) << "GrpcSerializeProto returned with non-ok status: "
                 << s.error_message();
      // Report the serialization failure to the caller rather than the
      // default (OK) status. Skip retry logic if we fail to parse our request.
      status_ = std::move(s);
      StateDone();
      return;
    }

    StartCall();
  }

  RPCState(const RPCState&) = delete;
  RPCState& operator=(const RPCState&) = delete;
  ~RPCState() override = default;

  // Starts one attempt of the call. Invoked once from the constructor and again
  // for every retry; a new ClientContext is created each time because a context
  // must not be reused across RPCs.
  void StartCall() {
    // Release the previous attempt's reader before replacing the context that
    // backs it: gRPC allocates the response reader on the call's arena, which
    // is kept alive by the ClientContext.
    call_.reset();
    context_ = std::make_unique<grpc::ClientContext>();
    context_->set_wait_for_ready(wait_for_ready_);

    if (timeout_in_ms_ > 0) {
      context_->set_deadline(
          gpr_time_from_millis(timeout_in_ms_, GPR_TIMESPAN));
    }

    VLOG(3) << "Starting call: " << method_;
    call_ = stub_.PrepareUnaryCall(context_.get(), method_, request_buf_, cq_);
    call_->StartCall();
    // Request that upon the completion of an RPC call, |response_buf_| will be
    // updated with server's response. Tag the call with |this| to identify this
    // request.
    call_->Finish(&response_buf_, &status_, /*tag=*/this);
  }

  // GrpcClientCQTag overrides:
  // Invoked from the completion queue thread when an attempt finishes. Either
  // completes the RPC (success, cancellation, CQ shutdown, retries exhausted)
  // or starts another attempt.
  void OnCompleted(State state) override {
    VLOG(3) << "Completed call: " << method_;

    if (state == State::kShutdown) {
      LOG(WARNING) << "Unary RPC done with CQ has been shutting down.";
      ParseAndCallDone();
      return;
    }

    if (status_.ok() || status_.error_code() == grpc::StatusCode::CANCELLED) {
      ParseAndCallDone();
      return;
    }

    LOG_IF(WARNING, ShouldLogGrpcError())
        << method_ << " returned with non-ok status: " << status_.error_code()
        << " Retries: " << num_retries_ << " Max: " << max_retries_ << "\n";
    // TODO(nanping): Retry only for logical errors by having them in the
    // config.
    // Retry if we have any attempts left.
    if (num_retries_ < max_retries_) {
      ++num_retries_;
      response_buf_.Clear();
      LOG_IF(WARNING, ShouldLogGrpcError())
          << "Retrying call for " << method_ << " Retry: " << num_retries_
          << " of " << max_retries_;
      // Nothing may touch members after this: the next completion can arrive
      // on the completion queue thread at any time.
      StartCall();
      return;
    }

    LOG_IF(ERROR, ShouldLogGrpcError())
        << "RPC call failed: " << method_ << " " << status_.error_code() << " "
        << status_.error_message();
    StateDone();
  }

  // Runs on the completion queue thread. Parses |response_buf_| into
  // |async_response_| and finishes the RPC. A parse failure on an otherwise
  // successful call is reported as INTERNAL instead of OK.
  void ParseAndCallDone() {
    if (!GrpcParseProto(&response_buf_, &async_response_)) {
      LOG(ERROR) << "RPC parse response failed.";
      if (status_.ok()) {
        status_ = grpc::Status(grpc::StatusCode::INTERNAL,
                               "RPC parse response failed.");
      }
    }
    StateDone();
  }

  // Posts |async_cb_| with the final status and response to
  // |callback_task_runner_|, then deletes this object. Called from OnCompleted()
  // on the completion queue thread, or from the constructor when request
  // serialization fails.
  void StateDone() {
    DCHECK(async_cb_);
    // |async_cb_| must be invoked from its original sequence. The state is
    // moved into the bound callback because |this| is deleted right after.
    callback_task_runner_->PostTask(
        FROM_HERE, base::BindOnce(std::move(async_cb_), std::move(status_),
                                  std::move(async_response_)));

    delete this;
  }

 private:
  bool ShouldLogGrpcError() const {
    // Some grpc errors are legitimate/expected. Ex: ReadSecureFile() may return
    // NOT_FOUND if the file doesn't exist. Do not log warning/errors since it's
    // just spam. The caller can log the error if desired.
    return status_.error_code() != grpc::StatusCode::NOT_FOUND;
  }

  // An instance managing the context settings, e.g. deadline, relevant to the
  // call they are invoked with. Same object should not be reused across RPCs.
  std::unique_ptr<::grpc::ClientContext> context_;

  // Message response of type |Response| received from the server.
  Response async_response_;

  // Buffer filled in with request/response.
  grpc::ByteBuffer request_buf_;
  grpc::ByteBuffer response_buf_;

  // Status of a RPC call. The status is OK if the call finished with no errors.
  grpc::Status status_;

  // Result callback; posted exactly once to |callback_task_runner_|.
  ResponseCallback<grpc::Status, Response> async_cb_;
  scoped_refptr<base::SequencedTaskRunner> callback_task_runner_;

  // An instance used by an async gRPC client to manage asynchronous rpc
  // operations. An RPC call is bound to a CompletionQueue when performed
  // using the stub.
  raw_ptr<grpc::CompletionQueue> cq_ = nullptr;

  // An instance used by a gRPC client to invoke rpc methods implemented in
  // the server.
  grpc::GenericStub stub_;

  // Holds the current unary RPC call attempt and exposes methods to start and
  // finish it with the server's response. Declared after |context_| so it is
  // destroyed first.
  std::unique_ptr<::grpc::GenericClientAsyncResponseReader> call_;

  // The name of a RPC method.
  grpc::string method_;

  // Config options for a RPC call.
  int64_t timeout_in_ms_;
  // Never negative; see the constructor.
  size_t max_retries_;
  bool wait_for_ready_;
  size_t num_retries_ = 0;
};

}  // namespace ash::libassistant

#endif  // CHROMEOS_ASH_SERVICES_LIBASSISTANT_GRPC_GRPC_STATE_H_