chromium/chromeos/ash/components/enhanced_network_tts/enhanced_network_tts_utils.cc

// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chromeos/ash/components/enhanced_network_tts/enhanced_network_tts_utils.h"

#include <algorithm>
#include <utility>

#include "base/base64.h"
#include "base/json/json_writer.h"
#include "base/logging.h"
#include "base/numerics/safe_conversions.h"
#include "base/strings/string_util.h"
#include "chromeos/ash/components/enhanced_network_tts/enhanced_network_tts_constants.h"
#include "ui/accessibility/ax_text_utils.h"

namespace ash::enhanced_network_tts {
namespace {

// The offsets computed by |ui::GetSentenceEndOffsets| and
// |ui::GetWordEndOffsets| point to the index just past the actual end. This
// function converts each offset to the index of the last character.
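// For example, an end offset of 4 (one past the last character of a text
// piece) becomes index 3.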
void ConvertOffsetsToIndexes(std::vector<int>& vect) {
  for (int& end : vect)
    end -= 1;
}

// The server requires the rate to be between 0.3 and 4.0, in steps of 0.1.
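// For illustration (assuming kMinRate is 0.3f and kMaxRate is 4.0f, per the
// limits above): ClampRateToLimits(4.5f) returns 4.0f, and
// ClampRateToLimits(1.27f) returns 1.2f after truncation.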
float ClampRateToLimits(float rate) {
  float clamped_rate = std::clamp(rate, kMinRate, kMaxRate);
  // Truncate to one decimal place.
  return static_cast<float>(static_cast<int>(clamped_rate * 10) / 10.0f);
}

}  // namespace

std::string FormatJsonRequest(const mojom::TtsRequestPtr tts_request) {
  base::Value::Dict request;

  // The utterance is sent as {'text': {'text_parts': [<utterance>]}}.
  base::Value::List text_parts;
  text_parts.Append(std::move(tts_request->utterance));
  request.SetByDottedPath(kTextPartsPath, std::move(text_parts));

  // Speech rate, voice, and language are sent as
  // {
  //   'advanced_options': {
  //     'audio_generation_options': {'speed_factor': <rate>},
  //     'force_language': <lang>
  //   },
  //   'voice_settings': {
  //     'voice_criteria_and_selections': [{
  //       'selection': {'default_voice': <voice>},
  //       'criteria': {'language': <lang>}
  //     }]
  //   }
  // }
  // See https://goto.google.com/readaloud-proto for more information.

  // Add speech rate.
  const float rate = ClampRateToLimits(tts_request->rate);
  request.SetByDottedPath(kSpeechFactorPath, base::Value(rate));

  // The voice and language have to be set together to be valid.
  if (tts_request->voice.has_value() && tts_request->lang.has_value()) {
    // Force the server to produce audio based on the current lang.
    request.SetByDottedPath(kForceLanguagePath,
                            base::Value(tts_request->lang.value()));

    // Produce 'voice_criteria_and_selections'.
    base::Value::Dict selection;
    selection.Set(kDefaultVoiceKey,
                  base::Value(std::move(tts_request->voice.value())));
    base::Value::Dict criteria;
    criteria.Set(kLanguageKey, base::Value(tts_request->lang.value()));
    base::Value::Dict voice_selection;
    voice_selection.Set(kSelectionKey, std::move(selection));
    voice_selection.Set(kCriteriaKey, std::move(criteria));
    base::Value::List voice_criteria_and_selections;
    voice_criteria_and_selections.Append(std::move(voice_selection));
    request.SetByDottedPath(kVoiceCriteriaAndSelectionsPath,
                            std::move(voice_criteria_and_selections));
  }

  std::string json_request;
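  // For illustration only (the exact dotted paths come from the constants
  // header): a request for "Hello" at rate 1.0 with voice "voice-a" and lang
  // "en-US" would serialize to roughly
  //   {"advanced_options": {"audio_generation_options": {"speed_factor": 1.0},
  //                         "force_language": "en-US"},
  //    "text": {"text_parts": ["Hello"]},
  //    "voice_settings": {"voice_criteria_and_selections":
  //        [{"criteria": {"language": "en-US"},
  //          "selection": {"default_voice": "voice-a"}}]}}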
  base::JSONWriter::Write(request, &json_request);
  return json_request;
}

std::vector<uint16_t> FindTextBreaks(const std::u16string& utterance,
                                     const int length_limit) {
  std::vector<uint16_t> breaks;
  DCHECK_GT(length_limit, 0);

  if (utterance.empty())
    return breaks;

  // The input utterance must be pre-trimmed so that it does not start with
  // whitespace. The ICU break iterator does not work well with text that
  // starts with whitespace.
  DCHECK(!base::IsUnicodeWhitespace(utterance[0]));

  const int utterance_length = utterance.length();
  if (utterance_length <= length_limit) {
    breaks.push_back(utterance_length - 1);
    return breaks;
  }

  if (length_limit == 1) {
    for (int i = 1; i < utterance_length; i++)
      breaks.push_back(base::checked_cast<uint16_t>(i));
    return breaks;
  }

  std::vector<int> sentence_ends = ui::GetSentenceEndOffsets(utterance);
  ConvertOffsetsToIndexes(sentence_ends);
  std::vector<int> word_ends = ui::GetWordEndOffsets(utterance);
  ConvertOffsetsToIndexes(word_ends);

  const int sentence_ends_length = sentence_ends.size();
  const int word_ends_length = word_ends.size();
  int cur_word_end_index = 0;
  int cur_sentence_end_index = 0;

  int text_start = 0;
  int text_end = -1;

  // Search for the end of each text piece as long as |text_end| (i.e., the
  // end of the last text piece) is smaller than the last index of the
  // utterance.
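  // For illustration (the exact boundaries depend on the ICU break
  // iterators): with |length_limit| = 20, "Hello world. Goodbye world."
  // would likely be split at the first sentence boundary and at the end of
  // the utterance, giving breaks of roughly {12, 26}.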
  while (text_end < utterance_length - 1) {
    // The start of the current text piece is the end of the last piece plus
    // one.
    text_start = text_end + 1;

    // Find a sentence end that is within |length_limit| of |text_start|.
    while (cur_sentence_end_index < sentence_ends_length &&
           sentence_ends[cur_sentence_end_index] - text_start < length_limit) {
      // Update the |text_end| if we find a sentence end bigger than the prior
      // |text_end|.
      text_end = std::max(text_end, sentence_ends[cur_sentence_end_index]);
      cur_sentence_end_index++;
    }
    // If we have found a sentence end as the end of the current text piece,
    // continue to the next search.
    if (text_end >= text_start) {
      breaks.push_back(base::checked_cast<uint16_t>(text_end));
      continue;
    }

    // If there is no qualified sentence end, the current sentence is longer
    // than |length_limit|. Keep searching for a word end that is within
    // |length_limit| of |text_start|.
    while (cur_word_end_index < word_ends_length &&
           word_ends[cur_word_end_index] - text_start < length_limit) {
      // Update the |text_end| if we find a word end bigger than the prior
      // |text_end|.
      text_end = std::max(text_end, word_ends[cur_word_end_index]);
      cur_word_end_index++;
    }
    // If we have found a word end as the end of the current text piece,
    // continue to the next search.
    if (text_end >= text_start) {
      breaks.push_back(base::checked_cast<uint16_t>(text_end));
      continue;
    }

    // If there is no qualified sentence end or word end, fall back to the
    // index corresponding to |length_limit| or the end of the utterance,
    // whichever comes first. In practice, this means the current word is
    // longer than |length_limit|.
    text_end = std::min(text_start + length_limit - 1, utterance_length - 1);
    breaks.push_back(base::checked_cast<uint16_t>(text_end));
  }

  return breaks;
}

mojom::TtsResponsePtr GetResultOnError(
    const mojom::TtsRequestError error_code) {
  // TODO(crbug.com/40771006): Log errors.
  return mojom::TtsResponse::NewErrorCode(error_code);
}

mojom::TtsResponsePtr UnpackJsonResponse(const base::Value::List& list_data,
                                         const int start_index,
                                         const bool is_last_request) {
  // For an input with n text pieces, the list size should be 1 + 2n. The
  // first item in the list is "metadata"; then each input text piece has one
  // dictionary for "text" and another dictionary for "audio". Since we only
  // send one input text piece (assuming one paragraph only), the list should
  // have a size of three.
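  // That is, the expected layout is roughly:
  //   [<metadata>, {"text": {"timingInfo": [...]}}, {"audio": {"bytes": ...}}]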
  if (list_data.size() != 3) {
    DVLOG(1)
        << "HTTP response for Enhanced Network TTS has unexpected JSON data.";
    return GetResultOnError(mojom::TtsRequestError::kReceivedUnexpectedData);
  }

  // Decode timing information. Inside the "text" dictionary, the "timingInfo"
  // is encoded as:
  // "timingInfo": [
  //   {
  //     "text": <string>,
  //     "location": {
  //       "textLocation": {"length": <int32>, "offset": <int32>},
  //       "timeLocation": {"timeOffset": <string>, "duration": <string>},
  //       "paragraphTextLocation": {"offset": <int32>, "length": <int32>}
  //     }
  //   },
  //   ...
  // ]
  std::vector<mojom::TimingInfoPtr> timing_infos;
  const base::Value::Dict& text_dict = list_data[1].GetDict();
  const base::Value::List* timing_info_list =
      text_dict.FindListByDottedPath("text.timingInfo");
  if (timing_info_list == nullptr) {
    DVLOG(1) << "HTTP response for Enhanced Network TTS has unexpected timing "
                "info data.";
    return GetResultOnError(mojom::TtsRequestError::kReceivedUnexpectedData);
  }

  for (size_t i = 0; i < timing_info_list->size(); ++i) {
    const base::Value::Dict& timing_info = (*timing_info_list)[i].GetDict();
    const std::string* timing_info_text_ptr = timing_info.FindString("text");
    const std::string* timing_info_timeoffset_ptr =
        timing_info.FindStringByDottedPath("location.timeLocation.timeOffset");
    const std::string* timing_info_duration_ptr =
        timing_info.FindStringByDottedPath("location.timeLocation.duration");
    // If the first item in the timing_info_list does not have a text offset,
    // we default it to 0. If the first item starts with whitespace, the
    // server will send back the text offset for the item.
    std::optional<int> timing_info_text_offset =
        timing_info.FindIntByDottedPath("location.textLocation.offset");
    if (timing_info_text_offset == std::nullopt && i == 0) {
      timing_info_text_offset = 0;
    }

    if (timing_info_text_offset == std::nullopt || !timing_info_text_ptr ||
        !timing_info_timeoffset_ptr || !timing_info_duration_ptr) {
      continue;
    }
    // The text offset needs to be compensated with the start index of this
    // TtsData.
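    // For example, an item with a text offset of 5 in a chunk whose
    // |start_index| is 100 produces a TimingInfo offset of 105 (illustrative).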
    timing_infos.push_back(mojom::TimingInfo::New(
        *timing_info_text_ptr, timing_info_text_offset.value() + start_index,
        *timing_info_timeoffset_ptr, *timing_info_duration_ptr));
  }

  // Decode audio data.
  const base::Value::Dict& audio_dict = list_data[2].GetDict();
  const std::string* audio_bytes_ptr =
      audio_dict.FindStringByDottedPath("audio.bytes");
  if (audio_bytes_ptr == nullptr) {
    DVLOG(1) << "HTTP response for Enhanced Network TTS has unexpected audio "
                "bytes data.";
    return GetResultOnError(mojom::TtsRequestError::kReceivedUnexpectedData);
  }
  std::string audio_bytes = *audio_bytes_ptr;
  if (!base::Base64Decode(audio_bytes, &audio_bytes)) {
    DVLOG(1) << "Failed to decode the audio data for Enhanced Network TTS.";
    return GetResultOnError(mojom::TtsRequestError::kReceivedUnexpectedData);
  }

  std::vector<uint8_t> audio =
      std::vector<uint8_t>(audio_bytes.begin(), audio_bytes.end());
  mojom::TtsDataPtr tts_data = mojom::TtsData::New(
      std::move(audio), std::move(timing_infos), is_last_request);
  // Send the decoded data to the caller.
  return mojom::TtsResponse::NewData(std::move(tts_data));
}

}  // namespace ash::enhanced_network_tts