chromium/chrome/services/speech/soda/cros_soda_client.cc

// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifdef UNSAFE_BUFFERS_BUILD
// TODO(crbug.com/40285824): Remove this and convert code to safer constructs.
#pragma allow_unsafe_buffers
#endif

#include "chrome/services/speech/soda/cros_soda_client.h"
#include "base/run_loop.h"
#include "chromeos/services/machine_learning/public/cpp/service_connection.h"
#include "chromeos/services/machine_learning/public/mojom/machine_learning_service.mojom.h"
#include "mojo/public/cpp/bindings/remote.h"

namespace {

// Builds a media::SpeechRecognitionResult marked as final from a SODA
// final-result event.
//
// The transcription is the first entry of `final_hypotheses`. That vector is
// read with an unchecked front(), so the caller must pass an event holding at
// least one hypothesis. Timing information is attached only when the event
// carries both a timing event and hypothesis parts; otherwise the result's
// `timing_information` keeps its default value.
media::SpeechRecognitionResult GetSpeechRecognitionResultFromFinalEvent(
    const chromeos::machine_learning::mojom::FinalResultPtr& final_event) {
  media::SpeechRecognitionResult recognition;
  recognition.is_final = true;
  recognition.transcription = final_event->final_hypotheses.front();

  const bool has_timing_data =
      final_event->timing_event && final_event->hypothesis_part;
  if (has_timing_data) {
    media::TimingInformation info;
    info.audio_start_time = final_event->timing_event->audio_start_time;
    // The end of the audio span is taken from the event's end time.
    info.audio_end_time = final_event->timing_event->event_end_time;
    info.hypothesis_parts = std::vector<media::HypothesisParts>();

    // Mirror every hypothesis part (text + alignment) into the media type.
    for (const auto& piece : final_event->hypothesis_part.value()) {
      info.hypothesis_parts->emplace_back(piece->text, piece->alignment);
    }

    recognition.timing_information = info;
  }

  return recognition;
}

}  // namespace
namespace soda {
// Constructs an uninitialized client. `soda_client_` is handed `this` here and
// is later bound to a new pipe in Reset(); presumably it is the mojo receiver
// that dispatches OnStart/OnStop/OnSpeechRecognizerEvent to this object —
// confirm against the header.
CrosSodaClient::CrosSodaClient() : soda_client_(this) {}
// Nothing beyond member destruction is required.
CrosSodaClient::~CrosSodaClient() = default;

// Returns true when the client has not been initialized yet, or when either
// `sample_rate` or `channel_count` differs from the value recorded by the most
// recent Reset(). Returns false only if the client is initialized and both
// values match.
bool CrosSodaClient::DidAudioPropertyChange(int sample_rate,
                                            int channel_count) {
  // Without a prior Reset() there is nothing to compare against.
  if (!is_initialized_) {
    return true;
  }

  const bool matches_current_config =
      sample_rate_ == sample_rate && channel_count_ == channel_count;
  return !matches_current_config;
}

// Copies `audio_buffer_size` bytes starting at `audio_buffer` into a byte
// vector and forwards it to the recognizer.
//
// audio_buffer: start of the raw audio bytes; reinterpreted as uint8_t.
// audio_buffer_size: number of bytes to read from `audio_buffer`. No range
//     check is performed here, so the caller must pass a size that is valid
//     for the buffer.
//
// Must only be called once IsInitialized() is true (DCHECKed).
void CrosSodaClient::AddAudio(const char* audio_buffer,
                              int audio_buffer_size) const {
  DCHECK(IsInitialized()) << "Unable to add audio before starting.";

  const auto* bytes_begin = reinterpret_cast<const uint8_t*>(audio_buffer);
  const auto* bytes_end = bytes_begin + audio_buffer_size;
  soda_recognizer_->AddAudio(std::vector<uint8_t>(bytes_begin, bytes_end));
}

// Forwards MarkDone() to `soda_recognizer_`. Must only be called once
// IsInitialized() is true; this precondition is checked with a DCHECK.
void CrosSodaClient::MarkDone() {
  DCHECK(IsInitialized()) << "Can't mark as done before starting";
  soda_recognizer_->MarkDone();
}

// (Re)creates the connection to the speech recognizer and starts it.
//
// Records the sample rate and channel count from `soda_config`, stops the
// current recognizer if the client was already initialized, drops the existing
// mojo endpoints, then binds a new Machine Learning Service connection and
// asks it to load a speech recognizer described by `soda_config`. The supplied
// callbacks replace any previously installed ones.
//
// soda_config: recognizer configuration; consumed by this call.
// transcription_callback: run for each non-empty final or partial result.
// stop_callback: run from OnStop().
// langid_callback: run for each language identification event.
//
// Note: a failed load is only reported through LOG(DFATAL); `is_initialized_`
// is set to true regardless of the load result.
void CrosSodaClient::Reset(
    chromeos::machine_learning::mojom::SodaConfigPtr soda_config,
    CrosSodaClient::TranscriptionResultCallback transcription_callback,
    CrosSodaClient::OnStopCallback stop_callback,
    CrosSodaClient::OnLanguageIdentificationEventCallback langid_callback) {
  // Read the audio properties now; `soda_config` is moved away below.
  sample_rate_ = soda_config->sample_rate;
  channel_count_ = soda_config->channel_count;

  // Stop the previous recognizer (if any) before tearing down its endpoints.
  if (is_initialized_) {
    soda_recognizer_->Stop();
  }
  soda_recognizer_.reset();
  soda_client_.reset();
  ml_service_.reset();
  is_initialized_ = true;

  chromeos::machine_learning::ServiceConnection::GetInstance()
      ->BindMachineLearningService(ml_service_.BindNewPipeAndPassReceiver());
  ml_service_->LoadSpeechRecognizer(
      std::move(soda_config), soda_client_.BindNewPipeAndPassRemote(),
      soda_recognizer_.BindNewPipeAndPassReceiver(),
      base::BindOnce(
          [](chromeos::machine_learning::mojom::LoadModelResult result) {
            if (result !=
                chromeos::machine_learning::mojom::LoadModelResult::OK) {
              LOG(DFATAL) << "Could not load recognizer, error: " << result;
            }
          }));

  // The callbacks are sink parameters taken by value, so transfer them into
  // the members instead of making another copy.
  transcription_callback_ = std::move(transcription_callback);
  stop_callback_ = std::move(stop_callback);
  langid_callback_ = std::move(langid_callback);

  // Ensure this one is started.
  soda_recognizer_->Start();
}

// Stop notification handler: runs the stop callback that was installed by the
// most recent Reset().
void CrosSodaClient::OnStop() {
  stop_callback_.Run();
}

// Start notification handler. Intentionally empty: this client takes no
// action on start.
void CrosSodaClient::OnStart() {
  // Do nothing OnStart.
}
// Dispatches a recognizer event to the callbacks installed by Reset().
//
// - Final result: converted to a media::SpeechRecognitionResult and passed to
//   the transcription callback, unless it has no hypotheses.
// - Partial result: the first partial text is passed to the transcription
//   callback as a non-final result, unless there is no partial text.
// - Language identification: forwarded to the langid callback, with the
//   confidence level and ASR switch result cast to their media::mojom enums.
// - Endpointer and audio events are ignored silently; any other event type is
//   logged as an error and dropped.
void CrosSodaClient::OnSpeechRecognizerEvent(
    chromeos::machine_learning::mojom::SpeechRecognizerEventPtr event) {
  if (event->is_final_result()) {
    const auto& final_event = event->get_final_result();
    if (final_event->final_hypotheses.empty()) {
      return;
    }
    transcription_callback_.Run(
        GetSpeechRecognitionResultFromFinalEvent(final_event));
    return;
  }

  if (event->is_partial_result()) {
    const auto& partial_event = event->get_partial_result();
    if (partial_event->partial_text.empty()) {
      return;
    }
    // Only the first partial hypothesis is reported.
    transcription_callback_.Run(media::SpeechRecognitionResult(
        partial_event->partial_text.front(), false));
    return;
  }

  if (event->is_langid_event()) {
    const auto& langid = event->get_langid_event();
    const auto confidence =
        static_cast<media::mojom::ConfidenceLevel>(langid->confidence_level);
    const auto switch_result = static_cast<media::mojom::AsrSwitchResult>(
        langid->asr_switch_result);
    langid_callback_.Run(langid->language, confidence, switch_result);
    return;
  }

  // Endpointer and audio events are expected but carry nothing to forward.
  if (event->is_endpointer_event() || event->is_audio_event()) {
    return;
  }

  LOG(ERROR) << "Some kind of other soda event, ignoring completely. Tag is '"
             << static_cast<uint32_t>(event->which()) << "'";
}

}  // namespace soda