chromium/chrome/browser/ash/accessibility/service/speech_recognition_impl.cc

// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/ash/accessibility/service/speech_recognition_impl.h"

#include "base/strings/utf_string_conversions.h"
#include "chrome/browser/ash/extensions/speech/speech_recognition_private_manager.h"
#include "chrome/browser/ash/extensions/speech/speech_recognition_private_recognizer.h"
#include "chrome/browser/profiles/profile.h"
#include "content/public/browser/browser_context.h"

namespace ash {

namespace {
// Translates a browser-side speech::SpeechRecognitionType into the
// ax::mojom::SpeechRecognitionType enumerator of the same name. The switch has
// no default label; every enumerator handled here is listed explicitly.
ax::mojom::SpeechRecognitionType ToMojo(speech::SpeechRecognitionType type) {
  using BrowserType = speech::SpeechRecognitionType;
  using MojoType = ax::mojom::SpeechRecognitionType;

  switch (type) {
    case BrowserType::kOnDevice:
      return MojoType::kOnDevice;
    case BrowserType::kNetwork:
      return MojoType::kNetwork;
  }
}
}  // namespace

//
// SpeechRecognitionEventObserverWrapper implementation.
//

// The wrapper needs no custom construction or teardown logic, so both the
// constructor and the destructor are compiler-generated.
SpeechRecognitionImpl::SpeechRecognitionEventObserverWrapper::
    SpeechRecognitionEventObserverWrapper() = default;
SpeechRecognitionImpl::SpeechRecognitionEventObserverWrapper::
    ~SpeechRecognitionEventObserverWrapper() = default;

// Forwards the "recognition stopped" notification to the remote observer
// through `observer_`.
// NOTE: nothing here checks that `observer_` is bound; StartHelper() calls
// PassReceiver() right after creating the wrapper.
void SpeechRecognitionImpl::SpeechRecognitionEventObserverWrapper::OnStop() {
  observer_->OnStop();
}

// Forwards a recognition result to the remote observer through `observer_`.
// Ownership of `event` is transferred to the outgoing call.
void SpeechRecognitionImpl::SpeechRecognitionEventObserverWrapper::OnResult(
    ax::mojom::SpeechRecognitionResultEventPtr event) {
  observer_->OnResult(std::move(event));
}

// Forwards a recognition error to the remote observer through `observer_`.
// Ownership of `event` is transferred to the outgoing call.
void SpeechRecognitionImpl::SpeechRecognitionEventObserverWrapper::OnError(
    ax::mojom::SpeechRecognitionErrorEventPtr event) {
  observer_->OnError(std::move(event));
}

// Binds `observer_` to a new message pipe and returns the pending receiver for
// that pipe. StartHelper() hands the returned receiver back to the caller of
// Start() so it can receive the events forwarded by this wrapper.
mojo::PendingReceiver<ax::mojom::SpeechRecognitionEventObserver>
SpeechRecognitionImpl::SpeechRecognitionEventObserverWrapper::PassReceiver() {
  return observer_.BindNewPipeAndPassReceiver();
}

//
// SpeechRecognitionImpl implementation.
//

// Stores `profile` for later use when creating speech recognizers (see
// GetSpeechRecognizer()). A null `profile` is a fatal error.
SpeechRecognitionImpl::SpeechRecognitionImpl(content::BrowserContext* profile)
    : profile_(profile) {
  CHECK(profile_);
}

// Compiler-generated destructor; members clean up after themselves.
SpeechRecognitionImpl::~SpeechRecognitionImpl() = default;

// Adds `receiver` to `receivers_`, registering this object as the
// implementation that serves it. May be called more than once; each call adds
// another receiver.
void SpeechRecognitionImpl::Bind(
    mojo::PendingReceiver<ax::mojom::SpeechRecognition> receiver) {
  receivers_.Add(this, std::move(receiver));
}

// Asks the recognizer associated with `options->type` to start speech
// recognition.
//
// `options->locale` and `options->interim_results` are optional; each is
// forwarded to the recognizer only when the caller supplied it. The outcome is
// reported asynchronously: the recognizer runs StartHelper(), which in turn
// runs `callback`. StartHelper() is bound to a weak pointer, so it is skipped
// if this object has been destroyed in the meantime.
void SpeechRecognitionImpl::Start(ax::mojom::StartOptionsPtr options,
                                  StartCallback callback) {
  // Copy over only the options that were actually provided.
  std::optional<std::string> requested_locale;
  if (options->locale) {
    requested_locale.emplace(*options->locale);
  }
  std::optional<bool> wants_interim_results;
  if (options->interim_results) {
    wants_interim_results.emplace(*options->interim_results);
  }

  const std::string key = CreateKey(options->type);
  auto* recognizer = GetSpeechRecognizer(key);
  auto on_started =
      base::BindOnce(&SpeechRecognitionImpl::StartHelper, GetWeakPtr(),
                     std::move(callback), key);
  recognizer->HandleStart(requested_locale, wants_interim_results,
                          std::move(on_started));
}

// Asks the recognizer associated with `options->type` to stop speech
// recognition. The outcome is reported asynchronously through StopHelper(),
// which runs `callback`; StopHelper() is bound to a weak pointer, so it is
// skipped if this object has been destroyed in the meantime.
void SpeechRecognitionImpl::Stop(ax::mojom::StopOptionsPtr options,
                                 StopCallback callback) {
  const std::string key = CreateKey(options->type);
  auto* recognizer = GetSpeechRecognizer(key);
  auto on_stopped =
      base::BindOnce(&SpeechRecognitionImpl::StopHelper, GetWeakPtr(),
                     std::move(callback), key);
  recognizer->HandleStop(std::move(on_stopped));
}

// Called when the recognizer identified by `key` has stopped. If an event
// observer wrapper is registered for `key`, notifies it and then discards it;
// otherwise does nothing.
void SpeechRecognitionImpl::HandleSpeechRecognitionStopped(
    const std::string& key) {
  if (auto* wrapper = GetEventObserverWrapper(key)) {
    wrapper->OnStop();
    // Recognition is over, so the wrapper is no longer needed.
    RemoveEventObserverWrapper(key);
  }
}

// Called when the recognizer identified by `key` produces a result. Packages
// `transcript` (converted from UTF-16 to UTF-8) and `is_final` into a result
// event and sends it to the event observer wrapper registered for `key`.
// Does nothing when no wrapper is registered.
void SpeechRecognitionImpl::HandleSpeechRecognitionResult(
    const std::string& key,
    const std::u16string& transcript,
    bool is_final) {
  auto* wrapper = GetEventObserverWrapper(key);
  if (wrapper == nullptr) {
    return;
  }

  auto event = ax::mojom::SpeechRecognitionResultEvent::New();
  event->is_final = is_final;
  event->transcript = base::UTF16ToUTF8(transcript);
  wrapper->OnResult(std::move(event));
}

// Called when the recognizer identified by `key` reports an error. If an event
// observer wrapper is registered for `key`, sends it an error event carrying
// `error` as the message and then discards the wrapper; otherwise does
// nothing.
void SpeechRecognitionImpl::HandleSpeechRecognitionError(
    const std::string& key,
    const std::string& error) {
  if (auto* wrapper = GetEventObserverWrapper(key)) {
    auto error_event = ax::mojom::SpeechRecognitionErrorEvent::New();
    error_event->message = error;
    wrapper->OnError(std::move(error_event));
    // No further events are delivered after an error.
    RemoveEventObserverWrapper(key);
  }
}

// Completion handler bound by Start(). Builds the SpeechRecognitionStartInfo
// reply and runs `callback` with it.
//
// The reply always carries `type` converted to its mojo equivalent. Its
// `observer_or_error` field holds:
//   - the error message, when `error` has a value; or
//   - a pending receiver for recognition events, backed by an event observer
//     wrapper registered under `key`, when the start succeeded.
void SpeechRecognitionImpl::StartHelper(StartCallback callback,
                                        const std::string& key,
                                        speech::SpeechRecognitionType type,
                                        std::optional<std::string> error) {
  auto start_info = ax::mojom::SpeechRecognitionStartInfo::New();
  start_info->type = ToMojo(type);

  if (error) {
    start_info->observer_or_error =
        ax::mojom::ObserverOrError::NewError(*error);
  } else {
    // Make sure a wrapper exists for `key`, then hand its receiving end to
    // the caller.
    CreateEventObserverWrapper(key);
    auto* wrapper = GetEventObserverWrapper(key);
    start_info->observer_or_error =
        ax::mojom::ObserverOrError::NewObserver(wrapper->PassReceiver());
  }

  std::move(callback).Run(std::move(start_info));
}

// Completion handler bound by Stop(). Discards the event observer wrapper
// registered under `key` (whether or not stopping succeeded) and then runs
// `callback` with `error`, which is empty on success.
void SpeechRecognitionImpl::StopHelper(StopCallback callback,
                                       const std::string& key,
                                       std::optional<std::string> error) {
  RemoveEventObserverWrapper(key);
  // `error` is passed through as-is: an empty optional signals success, an
  // engaged one carries the failure message.
  std::move(callback).Run(std::move(error));
}

// Returns the key under which the recognizer and event observer wrapper for
// `type` are stored.
//
// Dictation is currently the only accessibility feature that uses speech
// recognition, so the key is a fixed string. Support for other features
// (e.g. Voice Access) can be added here in the future.
std::string SpeechRecognitionImpl::CreateKey(
    ax::mojom::AssistiveTechnologyType type) {
  static constexpr char kDictationKey[] = "AtpSpeechRecognition.Dictation";
  DCHECK(type == ax::mojom::AssistiveTechnologyType::kDictation);
  return kDictationKey;
}

// Returns the recognizer stored under `key`, creating it on first use. A newly
// created recognizer is given this object, `profile_` and `key` as its
// constructor arguments. The returned pointer is owned by `recognizers_`.
extensions::SpeechRecognitionPrivateRecognizer*
SpeechRecognitionImpl::GetSpeechRecognizer(const std::string& key) {
  auto& slot = recognizers_[key];
  if (slot) {
    return slot.get();
  }

  slot = std::make_unique<extensions::SpeechRecognitionPrivateRecognizer>(
      this, profile_, key);
  return slot.get();
}

// Ensures an event observer wrapper exists under `key`. An existing wrapper is
// left untouched; otherwise a new one is created.
void SpeechRecognitionImpl::CreateEventObserverWrapper(const std::string& key) {
  auto& slot = event_observer_wrappers_[key];
  if (slot) {
    return;
  }
  slot = std::make_unique<SpeechRecognitionEventObserverWrapper>();
}

// Returns the event observer wrapper stored under `key`, or nullptr when none
// is registered. The returned pointer is owned by `event_observer_wrappers_`.
//
// This is a pure lookup: it uses find() rather than operator[] so that
// querying an absent key does not insert an empty entry into the map.
SpeechRecognitionImpl::SpeechRecognitionEventObserverWrapper*
SpeechRecognitionImpl::GetEventObserverWrapper(const std::string& key) {
  const auto it = event_observer_wrappers_.find(key);
  if (it == event_observer_wrappers_.end()) {
    return nullptr;
  }

  // An entry may exist but hold a null pointer; get() yields nullptr then.
  return it->second.get();
}

// Destroys the event observer wrapper stored under `key`, if any, and removes
// its map entry. Does nothing when `key` is not present.
void SpeechRecognitionImpl::RemoveEventObserverWrapper(const std::string& key) {
  // erase() destroys the owned wrapper together with its entry, so no prior
  // lookup is needed. Looking the key up with operator[] first would insert a
  // placeholder entry for an absent key only to erase it again.
  event_observer_wrappers_.erase(key);
}

}  // namespace ash