chromium/chromeos/ash/components/boca/babelorca/tachyon_authed_client_impl.cc

// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chromeos/ash/components/boca/babelorca/tachyon_authed_client_impl.h"

#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <utility>

#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/weak_ptr.h"
#include "base/sequence_checker.h"
#include "base/task/thread_pool.h"
#include "base/types/expected.h"
#include "chromeos/ash/components/boca/babelorca/request_data_wrapper.h"
#include "chromeos/ash/components/boca/babelorca/response_callback_wrapper.h"
#include "chromeos/ash/components/boca/babelorca/tachyon_client.h"
#include "chromeos/ash/components/boca/babelorca/token_manager.h"
#include "net/traffic_annotation/network_traffic_annotation.h"
#include "third_party/protobuf/src/google/protobuf/message_lite.h"

namespace ash::babelorca {
namespace {

// Serializes `request_proto` into its wire-format string. StartAuthedRequest()
// posts this to the thread pool and consumes the result on its own sequence.
// Returns std::nullopt when the proto fails to serialize.
std::optional<std::string> SerializeProtoToString(
    std::unique_ptr<google::protobuf::MessageLite> request_proto) {
  std::optional<std::string> serialized(std::in_place);
  if (request_proto->SerializeToString(&*serialized)) {
    return serialized;
  }
  return std::nullopt;
}

}  // namespace

TachyonAuthedClientImpl::TachyonAuthedClientImpl(
    std::unique_ptr<TachyonClient> client,
    TokenManager* oauth_token_manager)
    : client_(std::move(client)), oauth_token_manager_(oauth_token_manager) {}

TachyonAuthedClientImpl::~TachyonAuthedClientImpl() = default;

// Serializes `request_proto` on the thread pool, then continues on this
// sequence in OnRequestProtoSerialized(), which attaches an OAuth token and
// hands the request to `client_`. Errors detected along the way are reported
// through `response_cb`. If this object is destroyed before serialization
// completes, the WeakPtr-bound reply is skipped and `response_cb` is dropped
// without being run.
void TachyonAuthedClientImpl::StartAuthedRequest(
    const net::NetworkTrafficAnnotationTag& annotation_tag,
    std::unique_ptr<google::protobuf::MessageLite> request_proto,
    std::string_view url,
    int max_retries,
    std::unique_ptr<ResponseCallbackWrapper> response_cb) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  auto serialize_cb =
      base::BindOnce(SerializeProtoToString, std::move(request_proto));
  // `url` is a non-owning view, but the reply below runs asynchronously after
  // the thread-pool hop, when the caller's backing string may already be gone.
  // Bind an owned copy so the reply never reads a dangling view.
  auto reply_post_cb =
      base::BindOnce(&TachyonAuthedClientImpl::OnRequestProtoSerialized,
                     weak_ptr_factory_.GetWeakPtr(), annotation_tag,
                     std::string(url), max_retries, std::move(response_cb));
  base::ThreadPool::PostTaskAndReplyWithResult(
      FROM_HERE, std::move(serialize_cb), std::move(reply_post_cb));
}

// Same as StartAuthedRequest(), but for a request body that is already
// serialized, so the thread-pool serialization step is skipped and the
// request proceeds synchronously into OnRequestProtoSerialized().
void TachyonAuthedClientImpl::StartAuthedRequestString(
    const net::NetworkTrafficAnnotationTag& annotation_tag,
    std::string request_string,
    std::string_view url,
    int max_retries,
    std::unique_ptr<ResponseCallbackWrapper> response_cb) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  // `request_string` is owned by value here; move it instead of copying the
  // (potentially large) payload into the std::optional parameter.
  OnRequestProtoSerialized(annotation_tag, url, max_retries,
                           std::move(response_cb), std::move(request_string));
}

// Continuation once the request body is available. Reports kInternalError
// through `response_cb` when serialization failed. Otherwise it packages the
// request into a RequestDataWrapper and sends it with the cached OAuth token,
// or force-fetches a token first when none is cached.
void TachyonAuthedClientImpl::OnRequestProtoSerialized(
    const net::NetworkTrafficAnnotationTag& annotation_tag,
    std::string_view url,
    int max_retries,
    std::unique_ptr<ResponseCallbackWrapper> response_cb,
    std::optional<std::string> request_string) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  if (!request_string.has_value()) {
    response_cb->Run(base::unexpected(
        ResponseCallbackWrapper::TachyonRequestError::kInternalError));
    return;
  }

  auto request_data = std::make_unique<RequestDataWrapper>(
      annotation_tag, url, max_retries, std::move(response_cb));
  request_data->content_data = std::move(request_string.value());

  if (oauth_token_manager_->GetTokenString() == nullptr) {
    // No cached token: fetch one and resume in StartAuthedRequestInternal().
    oauth_token_manager_->ForceFetchToken(
        base::BindOnce(&TachyonAuthedClientImpl::StartAuthedRequestInternal,
                       weak_ptr_factory_.GetWeakPtr(),
                       std::move(request_data)));
    return;
  }
  StartAuthedRequestInternal(std::move(request_data),
                             /*has_oauth_token=*/true);
}

void TachyonAuthedClientImpl::StartAuthedRequestInternal(
    std::unique_ptr<RequestDataWrapper> request_data,
    bool has_oauth_token) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  if (!has_oauth_token) {
    request_data->response_cb->Run(base::unexpected(
        ResponseCallbackWrapper::TachyonRequestError::kAuthError));
    return;
  }
  std::string oauth_token = *(oauth_token_manager_->GetTokenString());
  request_data->oauth_version = oauth_token_manager_->GetFetchedVersion();
  client_->StartRequest(
      std::move(request_data), std::move(oauth_token),
      base::BindOnce(&TachyonAuthedClientImpl::OnRequestAuthFailure,
                     base::Unretained(this)));
}

// Handles an authentication failure reported by `client_` for `request_data`.
// Gives up with kAuthError once kMaxAuthRetries retries have been used.
// Otherwise it retries in one of two ways:
//   - immediately, if the token manager holds a token whose version differs
//     from the one that was rejected;
//   - after force-fetching a fresh token, in all other cases.
void TachyonAuthedClientImpl::OnRequestAuthFailure(
    std::unique_ptr<RequestDataWrapper> request_data) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  static constexpr int kMaxAuthRetries = 1;
  const bool retries_exhausted =
      request_data->oauth_retry_num >= kMaxAuthRetries;
  if (retries_exhausted) {
    request_data->response_cb->Run(base::unexpected(
        ResponseCallbackWrapper::TachyonRequestError::kAuthError));
    return;
  }
  request_data->oauth_retry_num += 1;

  // Only reuse the cached token if its version differs from the one used in
  // the failed request; retrying with the same token would fail again.
  const bool has_different_token =
      oauth_token_manager_->GetTokenString() != nullptr &&
      request_data->oauth_version != oauth_token_manager_->GetFetchedVersion();
  if (!has_different_token) {
    oauth_token_manager_->ForceFetchToken(
        base::BindOnce(&TachyonAuthedClientImpl::StartAuthedRequestInternal,
                       weak_ptr_factory_.GetWeakPtr(), std::move(request_data)));
    return;
  }
  StartAuthedRequestInternal(std::move(request_data),
                             /*has_oauth_token=*/true);
}

}  // namespace ash::babelorca