// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chromeos/components/mahi/ax_tree_extractor.h"
#include <memory>
#include <queue>
#include <string>
#include "base/containers/contains.h"
#include "base/i18n/break_iterator.h"
#include "base/strings/string_util.h"
#include "base/time/time.h"
#include "chromeos/components/mahi/public/mojom/content_extraction.mojom.h"
#include "mojo/public/cpp/bindings/message.h"
#include "mojo/public/cpp/bindings/pending_receiver.h"
#include "ui/accessibility/ax_node.h"
#include "ui/accessibility/ax_serializable_tree.h"
#include "ui/accessibility/ax_tree.h"
namespace mahi {
namespace {
using ::base::i18n::BreakIterator;
// Time after which an idle connection to the Screen AI service is
// disconnected. Passed to `reset_on_idle_timeout()` when the Screen2x remote
// is bound in `AXTreeExtractor::OnScreen2xReady()`.
constexpr base::TimeDelta kScreenAIIdleDisconnectDelay = base::Minutes(5);
// Roles that mark a node as content: `AddContentNodesToVector()` records the
// id of any node whose role is listed here and does not descend further.
// Declared `constexpr` (and without the redundant `static`, since this lives in
// an unnamed namespace) so the table is guaranteed to be a compile-time
// constant with internal linkage.
constexpr ax::mojom::Role kContentRoles[] = {
    ax::mojom::Role::kHeading,
    ax::mojom::Role::kParagraph,
    ax::mojom::Role::kNote,
};
// Roles whose whole subtree is ignored: `AddContentNodesToVector()` stops at a
// node with one of these roles without recording it or visiting its children.
// Declared `constexpr` (and without the redundant `static`, since this lives in
// an unnamed namespace) so the table is guaranteed to be a compile-time
// constant with internal linkage.
constexpr ax::mojom::Role kRolesToSkip[] = {
    ax::mojom::Role::kAudio,       ax::mojom::Role::kBanner,
    ax::mojom::Role::kButton,      ax::mojom::Role::kComplementary,
    ax::mojom::Role::kContentInfo, ax::mojom::Role::kFooter,
    ax::mojom::Role::kImage,       ax::mojom::Role::kLabelText,
    ax::mojom::Role::kNavigation,  ax::mojom::Role::kSectionFooter,
};
// Recurses from `node`, searching for content nodes (any node whose role is in
// `kContentRoles`). Branches that begin with a node whose role is in
// `kRolesToSkip` are skipped entirely. Once a content node is identified, its
// id is appended to `content_node_ids`, whose pointer is passed through the
// recursion, and that node's subtree is not searched any further. For a node
// that does not fall into either role list, the search continues into its
// unignored children until either an eligible node is found or a leaf of the
// tree is reached.
//
// A null `node` is a no-op. Callers pass `tree->root()` directly, and
// `GetContents()` in this file already treats a null root as possible, so this
// function must not dereference it unconditionally.
void AddContentNodesToVector(const ui::AXNode* node,
                             std::vector<ui::AXNodeID>* content_node_ids) {
  if (!node) {
    return;
  }

  const ax::mojom::Role role = node->GetRole();
  if (base::Contains(kContentRoles, role)) {
    content_node_ids->emplace_back(node->id());
    return;
  }
  if (base::Contains(kRolesToSkip, role)) {
    return;
  }

  // The node's role is in neither kContentRoles nor kRolesToSkip. Check its
  // child nodes.
  for (auto iter = node->UnignoredChildrenBegin();
       iter != node->UnignoredChildrenEnd(); ++iter) {
    AddContentNodesToVector(iter.get(), content_node_ids);
  }
}
// Collects the text of every node listed in `content_node_ids` that is
// reachable from `root`, appending it to `contents`. Entries are separated by
// a blank line ("\n\n"). Does nothing when `root` is null or no content ids
// were supplied.
void GetContents(const ui::AXNode* root,
                 const std::vector<ui::AXNodeID>& content_node_ids,
                 std::u16string* contents) {
  if (content_node_ids.empty() || !root) {
    return;
  }

  const bool is_content_node = base::Contains(content_node_ids, root->id());
  if (!is_content_node) {
    // Not a content node itself: walk the unignored children depth-first, in
    // order, so the collected text follows the order users see on the page.
    // TODO(chenjih): Revisit this if ax tree can be super deep. But this should
    // be quite rare.
    auto child = root->UnignoredChildrenBegin();
    while (child != root->UnignoredChildrenEnd()) {
      GetContents(child.get(), content_node_ids, contents);
      ++child;
    }
    return;
  }

  // Content node: take its full text and do not descend into it.
  if (!contents->empty()) {
    *contents += u"\n\n";
  }
  *contents += root->GetTextContentUTF16();
}
// Returns the number of words in `contents`, as segmented by a
// `BreakIterator` in `BREAK_WORD` mode. Only breaks for which `IsWord()` is
// true are counted. Returns 0 if the iterator fails to initialize.
//
// `contents` is only read, so it is taken by const reference; existing callers
// that pass a mutable lvalue are unaffected.
int GetContentsWordCount(const std::u16string& contents) {
  BreakIterator break_iter(contents, BreakIterator::BREAK_WORD);
  if (!break_iter.Init()) {
    return 0;
  }

  int word_count = 0;
  while (break_iter.Advance()) {
    if (break_iter.IsWord()) {
      ++word_count;
    }
  }
  return word_count;
}
} // namespace
// Construction and destruction are defaulted; this file adds no setup or
// teardown logic of its own.
AXTreeExtractor::AXTreeExtractor() = default;
AXTreeExtractor::~AXTreeExtractor() = default;
// Receives a pending remote for the Screen2x main content extractor and binds
// it, unless a remote is already bound, in which case the incoming one is
// simply dropped.
void AXTreeExtractor::OnScreen2xReady(
    mojo::PendingRemote<screen_ai::mojom::Screen2xMainContentExtractor>
        screen2x_content_extractor) {
  if (!screen2x_main_content_extractor_.is_bound()) {
    // Establish the connection with the Screen AI service. The remote resets
    // itself on disconnect, and also after staying idle for
    // `kScreenAIIdleDisconnectDelay`, to release resources.
    screen2x_main_content_extractor_.Bind(
        std::move(screen2x_content_extractor));
    screen2x_main_content_extractor_.reset_on_disconnect();
    screen2x_main_content_extractor_.reset_on_idle_timeout(
        kScreenAIIdleDisconnectDelay);
  }
}
// Entry point for content extraction. Dispatches on what the request carries:
// a snapshot takes precedence, then a list of tree updates. A request with
// neither is reported as a bad message and answered with a null response.
void AXTreeExtractor::ExtractContent(
    mojom::ExtractionRequestPtr extraction_request,
    ExtractContentCallback callback) {
  if (extraction_request->snapshot.has_value()) {
    ExtractContentFromSnapshot(std::move(extraction_request),
                               std::move(callback));
    return;
  }

  if (extraction_request->updates.has_value()) {
    ExtractContentFromAXTreeUpdates(std::move(extraction_request),
                                    std::move(callback));
    return;
  }

  mojo::ReportBadMessage("No AXTree snapshot or updates were detected.");
  std::move(callback).Run(nullptr);
}
// Computes the word count of the main content in the request's AX tree
// snapshot and reports it through `callback`.
//
// The content nodes are gathered by the rule based algorithm (when
// `use_algorithm` is set) and, when `use_screen2x` is set and the Screen2x
// remote is bound and connected, merged with the nodes Screen2x returns. If
// Screen2x was requested but is not available, the response still carries the
// algorithm's result, with status `kScreen2xNotAvailable`.
void AXTreeExtractor::GetContentSize(
    mojom::ExtractionRequestPtr content_size_request,
    GetContentSizeCallback callback) {
  // `snapshot` is optional in the request (see the `has_value()` checks in
  // ExtractContent / ExtractContentFromSnapshot). Reject a request without one
  // instead of calling `value()` on an empty optional below.
  // NOTE(review): this assumes GetContentSizeCallback accepts a null response
  // the same way ExtractContentCallback does - confirm against the mojom.
  if (!content_size_request->snapshot.has_value()) {
    mojo::ReportBadMessage("No AXTree snapshot was detected.");
    std::move(callback).Run(nullptr);
    return;
  }

  // Deserializes the snapshot.
  std::unique_ptr<ui::AXTree> tree =
      std::make_unique<ui::AXTree>(content_size_request->snapshot.value());

  std::vector<ui::AXNodeID> content_node_ids;
  if (content_size_request->extraction_methods->use_algorithm) {
    DistillViaAlgorithm(tree.get(), &content_node_ids);
  }

  mojom::ResponseStatus error_status = mojom::ResponseStatus::kSuccess;
  if (content_size_request->extraction_methods->use_screen2x &&
      screen2x_main_content_extractor_.is_bound() &&
      screen2x_main_content_extractor_.is_connected()) {
    // Screen2x answers asynchronously: OnGetScreen2xResult merges its node ids
    // with the algorithm's, then hands the merged list to
    // OnDistilledForContentSize, which owns the tree and the reply callback.
    OnAxTreeDistilledCallback on_ax_tree_distilled_callback =
        base::BindOnce(&AXTreeExtractor::OnDistilledForContentSize,
                       weak_ptr_factory_.GetWeakPtr(), std::move(tree),
                       std::move(callback), error_status);
    screen2x_main_content_extractor_->ExtractMainContent(
        content_size_request->snapshot.value(),
        content_size_request->ukm_source_id.value(),
        base::BindOnce(&AXTreeExtractor::OnGetScreen2xResult,
                       weak_ptr_factory_.GetWeakPtr(),
                       std::move(content_node_ids),
                       std::move(on_ax_tree_distilled_callback)));
    return;
  }

  // If screen2x is not available when receiving a request, report the error
  // status. Don't early return here as rule based algorithm may still work.
  if (content_size_request->extraction_methods->use_screen2x) {
    error_status = mojom::ResponseStatus::kScreen2xNotAvailable;
  }
  OnDistilledForContentSize(std::move(tree), std::move(callback), error_status,
                            content_node_ids);
}
// Extracts the main content text from the request's AX tree snapshot and
// replies through `callback`. A request without a snapshot is reported as a
// bad message and answered with a null response.
void AXTreeExtractor::ExtractContentFromSnapshot(
    mojom::ExtractionRequestPtr extraction_request,
    ExtractContentCallback callback) {
  if (!extraction_request->snapshot.has_value()) {
    mojo::ReportBadMessage("No AXTree snapshot were detected.");
    std::move(callback).Run(nullptr);
    return;
  }

  // Rebuild the tree from the serialized snapshot.
  auto tree =
      std::make_unique<ui::AXTree>(extraction_request->snapshot.value());

  // Rule based pass, only if requested.
  std::vector<ui::AXNodeID> algorithm_node_ids;
  if (extraction_request->extraction_methods->use_algorithm) {
    DistillViaAlgorithm(tree.get(), &algorithm_node_ids);
  }

  const bool wants_screen2x =
      extraction_request->extraction_methods->use_screen2x;
  if (!wants_screen2x || !screen2x_main_content_extractor_.is_bound() ||
      !screen2x_main_content_extractor_.is_connected()) {
    // Screen2x is either not requested or not reachable. In the latter case
    // flag it in the status, but still answer with whatever the rule based
    // algorithm found instead of bailing out.
    const mojom::ResponseStatus status =
        wants_screen2x ? mojom::ResponseStatus::kScreen2xNotAvailable
                       : mojom::ResponseStatus::kSuccess;
    OnDistilledForContentExtraction(std::move(tree), std::move(callback),
                                    status, algorithm_node_ids);
    return;
  }

  // Screen2x path: its reply is merged with the algorithm's node ids in
  // OnGetScreen2xResult, which then forwards the merged list to
  // OnDistilledForContentExtraction together with the tree and `callback`.
  OnAxTreeDistilledCallback on_ax_tree_distilled_callback = base::BindOnce(
      &AXTreeExtractor::OnDistilledForContentExtraction,
      weak_ptr_factory_.GetWeakPtr(), std::move(tree), std::move(callback),
      mojom::ResponseStatus::kSuccess);
  screen2x_main_content_extractor_->ExtractMainContent(
      extraction_request->snapshot.value(),
      extraction_request->ukm_source_id.value(),
      base::BindOnce(&AXTreeExtractor::OnGetScreen2xResult,
                     weak_ptr_factory_.GetWeakPtr(),
                     std::move(algorithm_node_ids),
                     std::move(on_ax_tree_distilled_callback)));
}
// Extracts the main content text from a sequence of AX tree updates and
// replies through `callback`. A request without updates is reported as a bad
// message and answered with a null response. Screen2x is not used on this
// path, so the response status is always `kScreen2xNotAvailable`.
// TODO(b:333803190): consider merging this with ExtractContentFromSnapshot if
// possible.
void AXTreeExtractor::ExtractContentFromAXTreeUpdates(
    mojom::ExtractionRequestPtr extraction_request,
    ExtractContentCallback callback) {
  if (!extraction_request->updates.has_value()) {
    mojo::ReportBadMessage("No AXTree updates were detected.");
    std::move(callback).Run(nullptr);
    return;
  }

  // Start from an empty tree and apply every update in the order received.
  // NOTE(review): the result of Unserialize() is not inspected here.
  auto tree = std::make_unique<ui::AXTree>();
  for (const ui::AXTreeUpdate& tree_update :
       extraction_request->updates.value()) {
    tree->Unserialize(tree_update);
  }

  std::vector<ui::AXNodeID> algorithm_node_ids;
  if (extraction_request->extraction_methods->use_algorithm) {
    DistillViaAlgorithm(tree.get(), &algorithm_node_ids);
  }

  // TODO(b:333803190): Figure out how to call screen2x using the tree.
  OnDistilledForContentExtraction(
      std::move(tree), std::move(callback),
      mojom::ResponseStatus::kScreen2xNotAvailable, algorithm_node_ids);
}
// Runs the rule based distillation over `tree`: starting at the tree's root,
// appends the ids of all content nodes found to `content_node_ids`.
void AXTreeExtractor::DistillViaAlgorithm(
    const ui::AXTree* tree,
    std::vector<ui::AXNodeID>* content_node_ids) {
  const ui::AXNode* root_node = tree->root();
  AddContentNodesToVector(root_node, content_node_ids);
}
// Invoked with Screen2x's content node ids. Appends to the algorithm's list
// every Screen2x id it does not already contain (algorithm ids first, new
// Screen2x ids after, in the order received), then runs
// `on_ax_tree_distilled_callback` with the merged list.
void AXTreeExtractor::OnGetScreen2xResult(
    std::vector<ui::AXNodeID> content_node_ids_algorithm,
    OnAxTreeDistilledCallback on_ax_tree_distilled_callback,
    const std::vector<ui::AXNodeID>& content_node_ids_screen2x) {
  for (const ui::AXNodeID screen2x_id : content_node_ids_screen2x) {
    const bool already_present =
        base::Contains(content_node_ids_algorithm, screen2x_id);
    if (already_present) {
      continue;
    }
    content_node_ids_algorithm.push_back(screen2x_id);
  }

  std::move(on_ax_tree_distilled_callback).Run(content_node_ids_algorithm);
}
// Final step of content extraction: gathers the text of the nodes in
// `content_node_ids` from `tree` and replies through `callback` with that text
// and `error_status`.
void AXTreeExtractor::OnDistilledForContentExtraction(
    std::unique_ptr<ui::AXTree> tree,
    ExtractContentCallback callback,
    mojom::ResponseStatus error_status,
    const std::vector<ui::AXNodeID>& content_node_ids) {
  std::u16string extracted_text;
  GetContents(tree->root(), content_node_ids, &extracted_text);

  auto response = mojom::ExtractionResponse::New();
  response->status = error_status;
  response->contents = std::move(extracted_text);
  std::move(callback).Run(std::move(response));
}
// Final step of a content size request: gathers the text of the nodes in
// `content_node_ids` from `tree`, counts its words, and replies through
// `callback` with the count and `error_status`.
void AXTreeExtractor::OnDistilledForContentSize(
    std::unique_ptr<ui::AXTree> tree,
    GetContentSizeCallback callback,
    mojom::ResponseStatus error_status,
    const std::vector<ui::AXNodeID>& content_node_ids) {
  std::u16string extracted_text;
  GetContents(tree->root(), content_node_ids, &extracted_text);

  auto response = mojom::ContentSizeResponse::New();
  response->status = error_status;
  response->word_count = GetContentsWordCount(extracted_text);
  std::move(callback).Run(std::move(response));
}
} // namespace mahi