// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "ash/public/cpp/tab_cluster/clusterer.h"

#include <algorithm>
#include <map>
#include <string>
#include <vector>

#include "ash/public/cpp/tab_cluster/tab_cluster_ui_item.h"
#include "ash/public/cpp/tab_cluster/undirected_graph.h"
#include "base/containers/contains.h"
#include "base/logging.h"
namespace ash {
namespace {
// Returns whether |source| should be ignored by the clusterer: an empty
// source, "blank", or "newtab/". about: and chrome:// URLs are not recorded
// as a host, so these values presumably stand in for pages such as
// about:blank and the new tab page (confirm against where |source| is set).
bool ShouldSkip(const std::string& source) {
  if (source.empty())
    return true;
  return source == "newtab/" || source == "blank";
}
// Maps every node to the id of the cluster that contains it.
//
// |clusters| lists the node ids of each cluster; a cluster's position in
// |clusters| is used as its cluster id. Returns a vector of size |num_nodes|
// whose element i is the cluster id of node i. Nodes that appear in no
// cluster keep the value-initialized id 0.
//
// |clusters| is taken by const reference: the previous by-value parameter
// copied every cluster vector on each call for no benefit.
std::vector<int> GetNodeClusterId(
    const std::vector<std::vector<int>>& clusters,
    int num_nodes) {
  std::vector<int> node_cluster_id(num_nodes);
  for (size_t cluster_id = 0; cluster_id < clusters.size(); ++cluster_id) {
    for (int node_id : clusters[cluster_id]) {
      // Explicit narrowing: cluster ids are stored as int.
      node_cluster_id[node_id] = static_cast<int>(cluster_id);
    }
  }
  return node_cluster_id;
}
double GetBoundaryStrength(UndirectedGraph& graph,
std::vector<int> cluster,
std::vector<int> node_cluster_id) {
int internal_edge_weight = 0;
int external_edge_weight = 0;
for (int node : cluster) {
for (const auto& edge : graph.Neighbors(node)) {
const auto neighbor = edge.first;
const auto weight = edge.second;
if (node_cluster_id[neighbor] == node_cluster_id[node]) {
internal_edge_weight += weight;
} else {
external_edge_weight += weight;
}
}
}
external_edge_weight = std::max(1, external_edge_weight);
// Internal edge weight is accounted twice.
// As internal edge weight and external edge weight are both int, divide
// internal edge weight by 2.0 to get a double result if present.
return (internal_edge_weight / 2.0) / external_edge_weight;
}
} // namespace
// Constructor and destructor are defaulted; they are defined out of line in
// this file rather than inline in the header.
Clusterer::Clusterer() = default;
Clusterer::~Clusterer() = default;
std::vector<TabClusterUIItem*> Clusterer::GetUpdatedClusterInfo(
const TabItems& tab_items,
TabClusterUIItem* old_active_item,
TabClusterUIItem* new_active_item) {
std::vector<TabClusterUIItem*> items;
if (!old_active_item || !new_active_item)
return items;
const std::string& old_source = old_active_item->current_info().source;
const std::string& new_source = new_active_item->current_info().source;
// Ignores irrelevant sources and self-loop.
if (ShouldSkip(old_source) || ShouldSkip(new_source) ||
old_source == new_source)
return items;
AddEdge(old_source, new_source);
// Get cluster results from the current graph.
std::map<std::string, ClusterResult> result_map = Cluster();
for (const auto& tab_item : tab_items) {
const std::string& source = tab_item.get()->current_info().source;
// Tab item source might not be present in result map yet when users have
// yet to navigate between two tabs that are not ignored.
if (!base::Contains(result_map, source)) {
continue;
}
ClusterResult result = result_map.at(source);
bool item_updated = false;
if (tab_item->current_info().cluster_id != result.cluster_id) {
tab_item->SetCurrentClusterId(result.cluster_id);
item_updated = true;
}
if (tab_item->current_info().boundary_strength !=
result.boundary_strength) {
tab_item->SetCurrentBoundaryStrength(result.boundary_strength);
item_updated = true;
}
if (item_updated) {
items.push_back(tab_item.get());
}
}
return items;
}
std::map<std::string, ClusterResult> Clusterer::Cluster() {
std::vector<std::vector<int>> clusters =
correlation_clusterer_.Cluster(graph_);
std::vector<int> node_cluster_id =
GetNodeClusterId(clusters, graph_.NumNodes());
std::vector<double> boundary_strength_by_cluster_id;
for (const auto& cluster : clusters) {
boundary_strength_by_cluster_id.push_back(
GetBoundaryStrength(graph_, cluster, node_cluster_id));
}
std::map<std::string, ClusterResult> result_map;
for (size_t node = 0; node < node_cluster_id.size(); ++node) {
ClusterResult result;
result.cluster_id = node_cluster_id[node];
result.boundary_strength =
boundary_strength_by_cluster_id[result.cluster_id];
result_map[node_to_source_.at(node)] = result;
}
return result_map;
}
// Connects the graph nodes for |from_source| and |to_source| via
// |graph_|.AddUndirectedEdgeAndNodeWeight(), first allocating node ids for
// sources not seen before. The two lookups stay in separate statements so
// |from_source| is always assigned its id before |to_source|.
void Clusterer::AddEdge(const std::string& from_source,
                        const std::string& to_source) {
  const size_t src_id = GetNodeForSource(from_source);
  const size_t dst_id = GetNodeForSource(to_source);
  graph_.AddUndirectedEdgeAndNodeWeight(src_id, dst_id);
}
// Returns the graph node id for |source|. The first time |source| is seen, it
// is assigned the next unused id (the number of sources recorded so far), and
// the mapping is stored in both |node_to_source_| and |source_to_node_|.
size_t Clusterer::GetNodeForSource(const std::string& source) {
  // A single find() replaces the former Contains() + at() double lookup.
  const auto it = source_to_node_.find(source);
  if (it != source_to_node_.end())
    return it->second;

  const size_t curr_size = source_to_node_.size();
  node_to_source_[curr_size] = source;
  source_to_node_[source] = curr_size;
  return curr_size;
}
// Translates the node ids in |cluster| back into their sources, preserving
// order. Looks ids up with |node_to_source_|.at(), so an id with no recorded
// source is a hard failure rather than a silent default.
// (|cluster| stays by value to match the declaration in the header.)
std::vector<std::string> Clusterer::GetSourcesFromCluster(
    std::vector<int> cluster) {
  std::vector<std::string> sources;
  // The output size is known up front; avoid repeated reallocation.
  sources.reserve(cluster.size());
  for (int node : cluster) {
    sources.push_back(node_to_source_.at(node));
  }
  return sources;
}
} // namespace ash