chromium/ash/public/cpp/tab_cluster/correlation_clusterer.cc

// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "ash/public/cpp/tab_cluster/correlation_clusterer.h"

#include <map>
#include <optional>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "ash/public/cpp/tab_cluster/undirected_graph.h"
#include "base/containers/contains.h"
#include "base/logging.h"
#include "base/rand_util.h"
#include "base/strings/strcat.h"
#include "base/strings/string_number_conversions.h"

namespace ash {

namespace {

// Number of refinement passes run over the given graph. This value is
// arbitrary and may be subject to further tuning.
constexpr int kNumIterations = 10;

// Converts a node-to-cluster-id assignment into a vector-of-clusters format.
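// For example, the assignment {5, 5, 7} (node i -> cluster id) becomes
// {{0, 1}, {2}}.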
std::vector<std::vector<int>> OutputClusters(
    const std::vector<int>& clustering) {
  std::map<int, std::vector<int>> clusters;
  for (size_t i = 0; i < clustering.size(); ++i) {
    clusters[clustering[i]].push_back(i);
  }
  std::vector<std::vector<int>> output;
  for (auto& key_value : clusters) {
    auto& cluster = key_value.second;
    output.push_back(std::move(cluster));
  }
  return output;
}
}  // namespace

// A helper class that keeps track of the sum of edge weights, accounting
// for missing edges, for best move computations.
class EdgeSum {
 public:
  EdgeSum() = default;
  EdgeSum(const EdgeSum&) = delete;
  EdgeSum& operator=(const EdgeSum&) = delete;
  ~EdgeSum() = default;

  // The edge weight `w` should have the edge weight offset subtracted before
  // calling this function.
  void Add(double w) { weight_ += w; }
  // Should be called at most once, after all edges have been Add()ed.
  void RemoveDoubleCounting() { weight_ /= 2.0; }
  // Retrieve the total weight of all edges seen, correcting for the implicit
  // negative weight of resolution multiplied by the product of the weights of
  // the two nodes incident to each edge.
  double NetWeight(
      double sum_prod_node_weights,
      const CorrelationClusterer::CorrelationClustererConfig& config) const {
    return weight_ - config.resolution * sum_prod_node_weights;
  }

 private:
  double weight_ = 0.0;
};
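
// For reference: the quantity being locally maximized is, up to constant
// terms, the sum over same-cluster node pairs {u, v} of
//   w_uv - config.resolution * node_weight(u) * node_weight(v),
// where w_uv is the edge weight (0 for non-adjacent pairs). NetWeight()
// returns the contribution of a set of edges to this sum, given the
// precomputed sum of node-weight products over the corresponding pairs.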

CorrelationClusterer::CorrelationClusterer() = default;
CorrelationClusterer::~CorrelationClusterer() = default;

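// Illustrative usage sketch. The graph-construction call below is a
// hypothetical placeholder; see ash/public/cpp/tab_cluster/undirected_graph.h
// for the real interface.
//
//   UndirectedGraph graph;
//   graph.AddUndirectedEdgeAndNodeWeight(0, 1, /*weight=*/1.0);  // hypothetical
//   CorrelationClusterer clusterer;
//   std::vector<std::vector<int>> clusters = clusterer.Cluster(graph);
//   // Each inner vector lists the node ids assigned to one cluster.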
std::vector<std::vector<int>> CorrelationClusterer::Cluster(
    const UndirectedGraph& undirected_graph) {
  Reset();

  graph_ = undirected_graph;
  num_nodes_ = graph_.NumNodes();

  // Create the all-singletons initial clusters.
  std::vector<std::vector<int>> clusters;

  for (int i = 0; i < num_nodes_; ++i) {
    clusters.push_back({i});
  }

  // Initialize to all-singletons clustering.
  clustering_.reserve(num_nodes_);
  for (int i = 0; i < num_nodes_; ++i) {
    int cluster = NewClusterId();
    clustering_.push_back(cluster);
    cluster_sizes_[cluster] = 1;
    cluster_weights_[cluster] = graph_.NodeWeight(i);
  }

  // Modularity objective: a resolution of 1 / total node weight corresponds
  // to maximizing modularity (assuming node weights are weighted degrees).
  config_.resolution = 1.0 / graph_.total_node_weight();

  RefineClusters(&clusters);

  return clusters;
}

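// Runs kNumIterations passes of greedy local search. Each pass groups nodes
// by their current cluster and, for each group, applies the best move found
// by BestMove() whenever it improves the objective; since the move sets are
// whole clusters, the passes effectively consider merging clusters.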
void CorrelationClusterer::RefineClusters(
    std::vector<std::vector<int>>* clusters_ptr) {
  std::string error;
  SetClustering(*clusters_ptr, &error);
  if (!error.empty()) {
    LOG(ERROR) << "Failed to set clustering " << error;
    return;
  }

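  // Cumulative improvement in the objective from accepted moves.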
  double objective = 0;

  auto try_moves = [&](std::vector<std::set<int>>* clusters_to_try) {
    base::RandomShuffle(clusters_to_try->begin(), clusters_to_try->end());
    for (const auto& cluster : *clusters_to_try) {
      std::pair<std::optional<int>, double> best_move = BestMove(cluster);
      if (best_move.second > 0) {
        std::optional<int> new_cluster = best_move.first;
        MoveNodesToCluster(cluster, new_cluster);
        objective += best_move.second;
      }
    }
  };

  for (int iter = 0; iter < kNumIterations; ++iter) {
    // Use current clusters as move sets, which means we'll consider
    // merging clusters.
    std::map<int, std::set<int>> node_cluster_map;
    for (int i = 0; i < num_nodes_; ++i) {
      node_cluster_map[ClusterForNode(i)].insert(i);
    }
    std::vector<std::set<int>> temp_clusters;
    for (auto& key_value : node_cluster_map) {
      auto& cluster = key_value.second;
      temp_clusters.push_back(std::move(cluster));
    }
    try_moves(&temp_clusters);
  }

  *clusters_ptr = OutputClusters(clustering_);
}

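// Moves every node into the cluster implied by `clusters`, validating that
// the input is a partition of [0, num_nodes_). On invalid input, fills
// `error` and returns false.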
bool CorrelationClusterer::SetClustering(
    const std::vector<std::vector<int>>& clusters,
    std::string* error) {
  std::vector<bool> seen_nodes(num_nodes_);
  for (const auto& cluster : clusters) {
    int id = NewClusterId();
    for (const auto node : cluster) {
      if (node >= num_nodes_ || node < 0) {
        *error =
            base::StrCat({"Node id ", base::NumberToString(node),
                          " in initial clusters not in expected range [0, ",
                          base::NumberToString(num_nodes_), ")"});
        return false;
      }
      if (seen_nodes[node]) {
        *error = base::StrCat({"Node id ", base::NumberToString(node),
                               " appears in initial clusters more than once."});
        return false;
      }
      seen_nodes[node] = true;
      MoveNodeToCluster(node, id);
    }
  }
  for (int node = 0; node < num_nodes_; ++node) {
    if (!seen_nodes[node]) {
      *error = base::StrCat({"Node id ", base::NumberToString(node),
                             " does not appear in initial clusters."});
      return false;
    }
  }
  return true;
}

void CorrelationClusterer::MoveNodeToCluster(const int node,
                                             const int new_cluster) {
  const int old_cluster = clustering_[node];
  const double weight = graph_.NodeWeight(node);
  cluster_sizes_[old_cluster] -= 1;
  cluster_weights_[old_cluster] -= weight;
  if (cluster_sizes_[old_cluster] == 0) {
    // Erase outside DCHECK_EQ(): DCHECK arguments are not evaluated when
    // DCHECKs are disabled, so the erase() calls must not live inside them.
    const size_t erased_sizes = cluster_sizes_.erase(old_cluster);
    const size_t erased_weights = cluster_weights_.erase(old_cluster);
    DCHECK_EQ(erased_sizes, 1u);
    DCHECK_EQ(erased_weights, 1u);
  }
  clustering_[node] = new_cluster;
  cluster_sizes_[new_cluster] += 1;
  cluster_weights_[new_cluster] += weight;
}

// Moves all of `nodes` into a single cluster. A null `new_cluster` means a
// brand-new cluster is created for them.
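// For example, MoveNodesToCluster({2, 3}, std::nullopt) places nodes 2 and 3
// together in a newly created cluster.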
void CorrelationClusterer::MoveNodesToCluster(const std::set<int>& nodes,
                                              std::optional<int> new_cluster) {
  int actual_new_cluster = new_cluster ? *new_cluster : NewClusterId();
  for (const auto& node : nodes) {
    MoveNodeToCluster(node, actual_new_cluster);
  }
}

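// Returns the best cluster to move all of `moving_nodes` into (std::nullopt
// meaning a brand-new cluster), together with the resulting change in the
// objective.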
std::pair<std::optional<int>, double> CorrelationClusterer::BestMove(
    const std::set<int>& moving_nodes) {
  // In the stats below, a "class 2" edge has both endpoints in `moving_nodes`
  // and a "class 1" edge has exactly one endpoint in `moving_nodes`.
  //
  // Weight of the moving nodes, grouped by their current cluster.
  std::map<int, double> cluster_moving_weights;
  // Class 2 edges whose endpoints are currently in different clusters.
  EdgeSum class_2_currently_separate;
  // Class 1 edges whose endpoints are currently in the same cluster.
  EdgeSum class_1_currently_together;
  // Class 1 edges, grouped by the cluster of the non-moving endpoint.
  std::map<int, EdgeSum> class_1_together_after;

  double moving_nodes_weight = 0;
  for (const auto& node : moving_nodes) {
    const int node_cluster = clustering_[node];
    cluster_moving_weights[node_cluster] += graph_.NodeWeight(node);
    moving_nodes_weight += graph_.NodeWeight(node);
    for (const auto& edge : graph_.Neighbors(node)) {
      const auto neighbor = edge.first;
      const auto weight = edge.second;
      const int neighbor_cluster = clustering_[neighbor];
      if (base::Contains(moving_nodes, neighbor)) {
        // Class 2 edge.
        if (node_cluster != neighbor_cluster) {
          class_2_currently_separate.Add(weight);
        }
      } else {
        // Class 1 edge.
        if (node_cluster == neighbor_cluster) {
          class_1_currently_together.Add(weight);
        }
        class_1_together_after[neighbor_cluster].Add(weight);
      }
    }
  }
  class_2_currently_separate.RemoveDoubleCounting();
  // Now cluster_moving_weights is correct, and class_2_currently_separate,
  // class_1_currently_together, and class_1_together_after are ready for
  // NetWeight() calls.

  return BestMoveFromStats(moving_nodes_weight, cluster_moving_weights,
                           class_2_currently_separate,
                           class_1_currently_together, class_1_together_after);
}

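// Informal sketch of the three `max_edges` terms computed below. Each is the
// sum of node-weight products over the node pairs whose co-clustering status
// changes, which EdgeSum::NetWeight() combines with the matching edge
// weights:
//  * Putting all moving nodes together gains every pair of moving nodes that
//    is currently split across clusters:
//      W^2 / 2 - sum_c W_c^2 / 2,
//    where W is the total moving weight and W_c the moving weight currently
//    in cluster c.
//  * Removing the moving nodes from their current clusters loses every pair
//    of one moving and one non-moving node that currently shares a cluster:
//      sum_c W_c * (GetClusterWeight(c) - W_c).
//  * Additionally placing the moving nodes into an existing cluster c gains
//    every pair of one moving node and one non-moving node of c:
//      W * (GetClusterWeight(c) - W_c).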
std::pair<std::optional<int>, double> CorrelationClusterer::BestMoveFromStats(
    double moving_nodes_weight,
    std::map<int, double>& cluster_moving_weights,
    const EdgeSum& class_2_currently_separate,
    const EdgeSum& class_1_currently_together,
    const std::map<int, EdgeSum>& class_1_together_after) {
  double change_in_objective = 0.0;

  auto half_square = [](double x) { return x * x / 2.0; };
  double max_edges = half_square(moving_nodes_weight);
  for (const auto& cluster_moving_weight : cluster_moving_weights) {
    max_edges -= half_square(cluster_moving_weight.second);
  }
  change_in_objective +=
      class_2_currently_separate.NetWeight(max_edges, config_);

  max_edges = 0;
  for (const auto& cluster_moving_weight : cluster_moving_weights) {
    max_edges += cluster_moving_weight.second *
                 (GetClusterWeight(cluster_moving_weight.first) -
                  cluster_moving_weight.second);
  }
  change_in_objective -=
      class_1_currently_together.NetWeight(max_edges, config_);

  std::pair<std::optional<int>, double> best_move;
  best_move.first = std::nullopt;
  best_move.second = change_in_objective;
  for (const auto& cluster_data : class_1_together_after) {
    int cluster = cluster_data.first;
    const EdgeSum& data = cluster_data.second;
    max_edges = moving_nodes_weight *
                (GetClusterWeight(cluster) - cluster_moving_weights[cluster]);
    // Change in objective if we move the moving nodes to this cluster.
    double overall_change_in_objective =
        change_in_objective + data.NetWeight(max_edges, config_);
    if (overall_change_in_objective > best_move.second ||
        (overall_change_in_objective == best_move.second &&
         cluster < best_move.first)) {
      best_move.first = cluster;
      best_move.second = overall_change_in_objective;
    }
  }
  return best_move;
}

int CorrelationClusterer::NewClusterId() {
  return next_cluster_id_++;
}

int CorrelationClusterer::ClusterForNode(int node) const {
  return clustering_[node];
}

double CorrelationClusterer::GetClusterWeight(int cluster_id) const {
  return cluster_weights_.at(cluster_id);
}

void CorrelationClusterer::Reset() {
  clustering_.clear();
  cluster_sizes_.clear();
  cluster_weights_.clear();
  next_cluster_id_ = 0;
}

}  // namespace ash