chromium/chromeos/ash/services/libassistant/grpc/rpc_method_driver.h

// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef CHROMEOS_ASH_SERVICES_LIBASSISTANT_GRPC_RPC_METHOD_DRIVER_H_
#define CHROMEOS_ASH_SERVICES_LIBASSISTANT_GRPC_RPC_METHOD_DRIVER_H_

#include <memory>
#include <utility>

#include "base/check.h"
#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/logging.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/scoped_refptr.h"
#include "base/memory/weak_ptr.h"
#include "base/sequence_checker.h"
#include "base/threading/thread.h"
#include "third_party/grpc/src/include/grpcpp/grpcpp.h"
#include "third_party/grpc/src/include/grpcpp/server_context.h"

namespace ash::libassistant {

// Implements an async RPC driver for a single unary RPC method.
// Request and Response are the RPC method's request and response protos.
//
// Sample usage for handling the SetCapabilities() RPC method from service
// ConversationInterface:
//
//  set_capabilities_rpc_driver_ =
//      std::make_unique<RpcMethodDriver<SetCapabilitiesRequest,
//                                       SetCapabilitiesResponse>>(
//          cq,
//          base::BindRepeating(
//              &ConversationInterface::AsyncService::RequestSetCapabilities,
//              service_.WeakPtr()),
//          base::BindRepeating(
//              &ConversationInterfaceImpl::SetCapabilities,
//              conversation_interface_impl_.WeakPtr()));
//
template <class Request, class Response>
class RpcMethodDriver {
 public:
  // Callback object encapsulating service.Request##RpcMethod() which looks
  // for the next incoming RPC. The trailing void* is the completion-queue tag;
  // this driver passes a heap-allocated base::OnceCallback<void(bool)>* there.
  using ServiceRpcCallFn =
      base::RepeatingCallback<void(grpc::ServerContext*,
                                   Request*,
                                   grpc::ServerAsyncResponseWriter<Response>*,
                                   grpc::CompletionQueue*,
                                   grpc::ServerCompletionQueue*,
                                   void*)>;

  // Callback object for calling RPC async business logic implementation.
  // The implementation is expected to run the supplied OnceCallback with the
  // status and response to send back. If that OnceCallback is destroyed
  // without being run, the per-RPC state is freed and no response is written.
  using RpcImplAsyncFn = base::RepeatingCallback<void(
      grpc::ServerContext*,
      const Request*,
      base::OnceCallback<void(const grpc::Status&, const Response&)>)>;

  // Constructs the driver and immediately starts waiting for the first RPC.
  // cq: completion queue the RPCs are delivered on. Must be non-null and
  //     outlive this object.
  // service_rpc_call_fn: Callback object encapsulating
  //         service.Request##RpcMethod() which looks for next incoming RPC.
  // rpc_impl_async_fn: Callback object for calling implementation of
  //              business logic of the RPC.
  RpcMethodDriver(grpc::ServerCompletionQueue* cq,
                  ServiceRpcCallFn service_rpc_call_fn,
                  RpcImplAsyncFn rpc_impl_async_fn)
      : cq_(cq),
        service_rpc_call_fn_(std::move(service_rpc_call_fn)),
        rpc_impl_async_fn_(std::move(rpc_impl_async_fn)) {
    DCHECK(cq);
    RequestNextRpc();
  }

  ~RpcMethodDriver() = default;
  RpcMethodDriver(const RpcMethodDriver&) = delete;
  RpcMethodDriver& operator=(const RpcMethodDriver&) = delete;

 private:
  // Asks the service to look for the next incoming RPC.
  void RequestNextRpc() {
    // Per-RPC state. It is owned by the bound callbacks below and is released
    // when the last of them (ultimately CleanupAfterRpc()'s) is destroyed.
    auto ctx = std::make_unique<grpc::ServerContext>();
    auto request = std::make_unique<Request>();
    auto responder =
        std::make_unique<grpc::ServerAsyncResponseWriter<Response>>(ctx.get());

    // Prestore valid pointers before std::move() nulls the smart pointers.
    auto* ctx_ptr = ctx.get();
    auto* request_ptr = request.get();
    auto* responder_ptr = responder.get();

    // A raw pointer has to be used here since |service_rpc_call_fn_| is
    // expecting void* as the parameter. It will be deleted by server cq
    // after being executed.
    auto* process_rpc_cb = new base::OnceCallback<void(bool)>(
        base::BindOnce(&RpcMethodDriver<Request, Response>::ProcessRpc,
                       weak_factory_.GetWeakPtr(), std::move(ctx),
                       std::move(request), std::move(responder)));

    DCHECK(service_rpc_call_fn_);
    service_rpc_call_fn_.Run(ctx_ptr, request_ptr, responder_ptr, cq_.get(),
                             cq_.get(), process_rpc_cb);
  }

  // Handles the completion of a RequestNextRpc() call. |ok| is the status
  // reported by the completion queue for that tag.
  void ProcessRpc(
      std::unique_ptr<grpc::ServerContext> ctx,
      std::unique_ptr<Request> request,
      std::unique_ptr<grpc::ServerAsyncResponseWriter<Response>> responder,
      bool ok) {
    DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

    if (!ok) {
      // If not okay, logs error and returns without requesting another RPC.
      // ctx/request/responder are freed when the unique_ptrs go out of scope.
      LOG(ERROR) << "OnEventFromLibas request not ok.";
      return;
    }

    // Start waiting for the next RPC before handling this one.
    RequestNextRpc();

    ExecuteRpc(std::move(ctx), std::move(request), std::move(responder));
  }

  // Runs the business logic for a received RPC and, once it replies, writes
  // the response back through |responder|.
  void ExecuteRpc(
      std::unique_ptr<grpc::ServerContext> ctx,
      std::unique_ptr<Request> request,
      std::unique_ptr<grpc::ServerAsyncResponseWriter<Response>> responder) {
    DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

    // Prestore valid pointers before std::move() nulls the smart pointers.
    auto* ctx_ptr = ctx.get();
    auto* request_ptr = request.get();
    auto* responder_ptr = responder.get();

    // The finish callback owns all per-RPC state. Until Finish() is issued it
    // is owned by |async_cb|, so if the implementation drops |async_cb|
    // without running it (e.g. it was bound to an invalidated WeakPtr), the
    // state is freed instead of leaked.
    auto finish_rpc_cb = std::make_unique<base::OnceCallback<void(bool)>>(
        base::BindOnce(&RpcMethodDriver<Request, Response>::CleanupAfterRpc,
                       weak_factory_.GetWeakPtr(), std::move(ctx),
                       std::move(request), std::move(responder)));

    auto async_cb = base::BindOnce(
        [](grpc::ServerAsyncResponseWriter<Response>* responder,
           std::unique_ptr<base::OnceCallback<void(bool)>> finish_rpc_cb,
           const grpc::Status& status, const Response& response) {
          // Ownership of the finish callback passes to the completion queue
          // as the Finish() tag; it is deleted by server cq after executing.
          responder->Finish(response, status, finish_rpc_cb.release());
        },
        responder_ptr, std::move(finish_rpc_cb));

    DCHECK(rpc_impl_async_fn_);
    // Call the async implementation of the RPC business logic.
    rpc_impl_async_fn_.Run(ctx_ptr, request_ptr, std::move(async_cb));
  }

  // Runs when Finish() completes. The per-RPC state is released when the
  // unique_ptr parameters go out of scope; |ignored_ok| is not inspected.
  void CleanupAfterRpc(
      std::unique_ptr<grpc::ServerContext> ctx,
      std::unique_ptr<Request> request,
      std::unique_ptr<grpc::ServerAsyncResponseWriter<Response>> responder,
      bool ignored_ok) {
    DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);

    DVLOG(3) << "OnEventFromLibas is finished.";

    // Unique_ptrs will delete the objects after they go out of scope
    // so no manual data clean-up needed here.
  }

  // Owned by |ServicesInitializerBase|.
  raw_ptr<grpc::ServerCompletionQueue> cq_ = nullptr;

  ServiceRpcCallFn service_rpc_call_fn_;
  RpcImplAsyncFn rpc_impl_async_fn_;

  // This sequence checker ensures that all callbacks are called on the
  // main sequence.
  SEQUENCE_CHECKER(sequence_checker_);

  // Must be the last member so weak pointers are invalidated first.
  base::WeakPtrFactory<RpcMethodDriver> weak_factory_{this};
};

}  // namespace ash::libassistant

#endif  // CHROMEOS_ASH_SERVICES_LIBASSISTANT_GRPC_RPC_METHOD_DRIVER_H_