chromium/chromeos/ash/components/boca/babelorca/transcript_sender.cc

// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chromeos/ash/components/boca/babelorca/transcript_sender.h"

#include <cstddef>
#include <memory>
#include <string>
#include <string_view>
#include <utility>

#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/functional/callback_helpers.h"
#include "base/location.h"
#include "base/logging.h"
#include "base/sequence_checker.h"
#include "base/strings/utf_string_conversions.h"
#include "base/task/thread_pool.h"
#include "base/uuid.h"
#include "chromeos/ash/components/boca/babelorca/proto/babel_orca_message.pb.h"
#include "chromeos/ash/components/boca/babelorca/proto/tachyon.pb.h"
#include "chromeos/ash/components/boca/babelorca/proto/tachyon_common.pb.h"
#include "chromeos/ash/components/boca/babelorca/proto/tachyon_enums.pb.h"
#include "chromeos/ash/components/boca/babelorca/response_callback_wrapper.h"
#include "chromeos/ash/components/boca/babelorca/response_callback_wrapper_impl.h"
#include "chromeos/ash/components/boca/babelorca/tachyon_authed_client.h"
#include "chromeos/ash/components/boca/babelorca/tachyon_constants.h"
#include "chromeos/ash/components/boca/babelorca/tachyon_request_data_provider.h"
#include "chromeos/ash/components/boca/babelorca/tachyon_utils.h"
#include "media/mojo/mojom/speech_recognition_result.h"
#include "net/traffic_annotation/network_traffic_annotation.h"

namespace ash::babelorca {
namespace {

// Returns the byte offset into `new_text` from which the next transcript part
// should be sent.
//
// The offset starts at the first byte where `new_text` differs from
// `current_text` (or at the end of the shorter string when one is a prefix of
// the other). If the changed suffix is shorter than `max_allowed_char`, the
// offset is moved backwards so that up to `max_allowed_char` bytes are sent,
// clamping at 0.
//
// Offsets are byte positions in the `std::string`s, not code points.
// NOTE(review): this means an offset may fall inside a multi-byte UTF-8
// sequence — confirm that consumers tolerate this.
int GetTranscriptPartIndex(const std::string& current_text,
                           const std::string& new_text,
                           size_t max_allowed_char) {
  const size_t common_len = new_text.length() < current_text.length()
                                ? new_text.length()
                                : current_text.length();
  size_t diff_index = 0;
  while (diff_index < common_len &&
         new_text[diff_index] == current_text[diff_index]) {
    ++diff_index;
  }
  const size_t diff_len = new_text.length() - diff_index;
  if (diff_len < max_allowed_char) {
    // Number of already-sent bytes that still fit in the message. The
    // comparison and subtraction stay in `size_t` so that a very large
    // `max_allowed_char` clamps to 0 instead of wrapping to an offset past
    // the end of `new_text`.
    const size_t padding = max_allowed_char - diff_len;
    diff_index = diff_index > padding ? diff_index - padding : 0;
  }
  return static_cast<int>(diff_index);
}

// Builds an `InboxSendRequest` that delivers `message` to the group
// `group_id` on behalf of `sender_email`, authenticated with `tachyon_token`,
// and returns the serialized request.
std::string CreateRequestString(BabelOrcaMessage message,
                                std::string tachyon_token,
                                std::string group_id,
                                std::string sender_email) {
  // The group is both the request destination and the message receiver.
  Id group_receiver;
  group_receiver.set_id(std::move(group_id));
  group_receiver.set_app(kTachyonAppName);
  group_receiver.set_type(IdType::GROUP_ID);

  InboxSendRequest request;

  // Request header carrying the auth token.
  auto* header = request.mutable_header();
  *header = GetRequestHeaderTemplate();
  header->set_auth_token_payload(std::move(tachyon_token));

  *request.mutable_dest_id() = group_receiver;

  // Inbox message wrapping the serialized BabelOrca payload.
  auto* inbox_message = request.mutable_message();
  inbox_message->set_message_id(
      base::Uuid::GenerateRandomV4().AsLowercaseString());
  inbox_message->set_message(message.SerializeAsString());
  *inbox_message->mutable_receiver_id() = group_receiver;

  auto* sender = inbox_message->mutable_sender_id();
  sender->set_id(std::move(sender_email));
  sender->set_type(IdType::EMAIL);
  sender->set_app(kTachyonAppName);

  inbox_message->set_message_type(InboxMessage::GROUP);
  inbox_message->set_message_class(InboxMessage::USER);

  request.set_fanout_sender(MessageFanout::OTHER_SENDER_DEVICES);

  return request.SerializeAsString();
}

}  // namespace

// Stores the collaborators and configuration used to send transcripts.
//
// `authed_client` is used to issue the send requests and
// `request_data_provider` supplies the tachyon token and group id for each
// request. Both are kept as raw pointers.
// NOTE(review): presumably both must outlive this object — confirm against
// the header.
// `failure_cb` is run once the number of consecutive failed sends reaches
// `options.max_errors_num`.
// A random v4 UUID (lowercase string) is generated here and attached to every
// message as the sender UUID.
TranscriptSender::TranscriptSender(
    TachyonAuthedClient* authed_client,
    TachyonRequestDataProvider* request_data_provider,
    std::string_view sender_email,
    const net::NetworkTrafficAnnotationTag& network_traffic_annotation,
    Options options,
    base::OnceClosure failure_cb)
    : authed_client_(authed_client),
      request_data_provider_(request_data_provider),
      sender_email_(sender_email),
      network_traffic_annotation_(network_traffic_annotation),
      options_(std::move(options)),
      failure_cb_(std::move(failure_cb)),
      sender_uuid_(base::Uuid::GenerateRandomV4().AsLowercaseString()) {}

// Verifies that destruction happens on the sequence tracked by
// `sequence_checker_`, like every other method of this class.
TranscriptSender::~TranscriptSender() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
}

// Sends the part of `transcript` that changed since the last update.
//
// Returns false without sending anything once the number of consecutive
// errors has reached `options_.max_errors_num`; returns true otherwise. The
// request string is built on the thread pool and handed to `Send()` on this
// sequence.
bool TranscriptSender::SendTranscriptionUpdate(
    const media::SpeechRecognitionResult& transcript,
    const std::string& language) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  const bool error_limit_reached = errors_num_ >= options_.max_errors_num;
  if (error_limit_reached) {
    return false;
  }

  // Final transcripts get one retry; interim ones get none.
  const int max_retries = transcript.is_final ? 1 : 0;

  const int start_index = GetTranscriptPartIndex(
      current_transcript_text_, transcript.transcription,
      options_.max_allowed_char);
  BabelOrcaMessage outgoing =
      GenerateMessage(transcript, start_index, language);

  auto build_request = base::BindOnce(
      CreateRequestString, std::move(outgoing),
      request_data_provider_->tachyon_token(),
      request_data_provider_->group_id(), sender_email_);
  auto send_request =
      base::BindOnce(&TranscriptSender::Send, weak_ptr_factory.GetWeakPtr(),
                     max_retries);
  base::ThreadPool::PostTaskAndReplyWithResult(
      FROM_HERE, std::move(build_request), std::move(send_request));

  // Must run after `GenerateMessage`, which reads the stored transcript state
  // that this call overwrites.
  UpdateTranscripts(transcript, language);
  return true;
}

// Builds the `BabelOrcaMessage` for `transcript`.
//
// The current transcript part is the text of `transcript` starting at
// `part_index`. If that part is shorter than `options_.max_allowed_char` and
// a previous transcript exists, the tail of the previous transcript is
// attached as well, filling the remaining space. Each call consumes one value
// of `message_order_`.
BabelOrcaMessage TranscriptSender::GenerateMessage(
    const media::SpeechRecognitionResult& transcript,
    int part_index,
    const std::string& language) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  BabelOrcaMessage message;

  // Message-level metadata.
  message.set_sender_uuid(sender_uuid_);
  message.set_order(message_order_);
  ++message_order_;

  // Current transcript part.
  std::string new_part = transcript.transcription.substr(part_index);
  const size_t new_part_len = new_part.length();
  TranscriptPart* current_part = message.mutable_current_transcript();
  current_part->set_transcript_id(current_transcript_index_);
  current_part->set_text_index(part_index);
  current_part->set_text(std::move(new_part));
  current_part->set_is_final(transcript.is_final);
  current_part->set_language(language);

  // Nothing more to add when the current part already uses the whole
  // `options_.max_allowed_char` budget or there is no previous transcript.
  if (new_part_len >= options_.max_allowed_char ||
      previous_transcript_text_.empty()) {
    return message;
  }

  // Fill the remaining budget with the tail of the previous transcript.
  const size_t remaining_budget = options_.max_allowed_char - new_part_len;
  const size_t prev_len = previous_transcript_text_.length();
  const int prev_index =
      prev_len < remaining_budget ? 0 : prev_len - remaining_budget;
  std::string prev_tail = previous_transcript_text_.substr(prev_index);
  TranscriptPart* previous_part = message.mutable_previous_transcript();
  previous_part->set_transcript_id(current_transcript_index_ - 1);
  previous_part->set_text_index(prev_index);
  previous_part->set_text(std::move(prev_tail));
  previous_part->set_is_final(true);
  previous_part->set_language(previous_language_);
  return message;
}

// Records `transcript` as the latest state used for diffing the next update.
//
// An interim transcript only replaces the stored current text. A final
// transcript becomes the stored previous transcript (with its language),
// advances the transcript index and resets the current text.
void TranscriptSender::UpdateTranscripts(
    const media::SpeechRecognitionResult& transcript,
    const std::string& language) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  if (transcript.is_final) {
    ++current_transcript_index_;
    previous_language_ = language;
    previous_transcript_text_ = transcript.transcription;
    current_transcript_text_ = "";
    return;
  }
  current_transcript_text_ = transcript.transcription;
}

void TranscriptSender::Send(int max_retries, std::string request_string) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  if (request_string.empty()) {
    LOG(ERROR) << "Send request is empty.";
    return;
  }

  auto response_callback_wrapper =
      std::make_unique<ResponseCallbackWrapperImpl<InboxSendResponse>>(
          base::BindOnce(&TranscriptSender::OnSendResponse,
                         weak_ptr_factory.GetWeakPtr()));
  authed_client_->StartAuthedRequestString(
      network_traffic_annotation_, std::move(request_string), kSendMessageUrl,
      max_retries, std::move(response_callback_wrapper));
}

// Tracks consecutive send failures.
//
// A successful response resets the error counter. A failed one increments it,
// and once the counter reaches `options_.max_errors_num` the failure callback
// is run if it is still set.
void TranscriptSender::OnSendResponse(
    base::expected<InboxSendResponse,
                   ResponseCallbackWrapper::TachyonRequestError> response) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  if (!response.has_value()) {
    ++errors_num_;
    if (errors_num_ < options_.max_errors_num) {
      return;
    }
    if (failure_cb_) {
      std::move(failure_cb_).Run();
    }
    return;
  }
  errors_num_ = 0;
}

}  // namespace ash::babelorca