chromium/ash/system/focus_mode/focus_mode_tasks_provider.cc

// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "ash/system/focus_mode/focus_mode_tasks_provider.h"

#include <algorithm>
#include <map>
#include <optional>
#include <vector>

#include "ash/api/tasks/tasks_controller.h"
#include "ash/api/tasks/tasks_delegate.h"
#include "ash/api/tasks/tasks_types.h"
#include "ash/system/focus_mode/focus_mode_retry_util.h"
#include "base/barrier_closure.h"
#include "base/functional/bind.h"
#include "base/location.h"
#include "base/ranges/algorithm.h"
#include "base/ranges/ranges.h"
#include "base/strings/string_number_conversions.h"
#include "base/task/sequenced_task_runner.h"
#include "base/time/time.h"
#include "google_apis/common/api_error_codes.h"
#include "url/gurl.h"

namespace ash {

namespace {

// Minimum number of tasks the fetcher tries to collect before it stops
// querying further task lists. The tasks UI only has room for a few entries.
constexpr size_t kTasksToFetch{5};

// Task lists are requested first; their tasks are then requested list by list
// until at least `kTasksToFetch` tasks have arrived. To keep latency low, up
// to this many lists are queried in parallel per batch.
constexpr size_t kListFetchBatchSize{8};

// How long a successfully fetched task list may be served from the cache.
constexpr base::TimeDelta kCacheLifetime{base::Seconds(30)};

// Used to sort tasks for the carousel.
struct TaskComparator {
  // Tasks are classified into these groups and within each group sorted by
  // their update time. Tasks that have been created by the user in the focus
  // mode UI appear first, followed by past due tasks and so on.
  enum class TaskGroupOrdering {
    kCreatedInSession,
    kPastDue,
    kDueSoon,
    kDueLater,
  };

  bool operator()(const FocusModeTask& lhs, const FocusModeTask& rhs) const {
    auto lhs_group = GetOrdering(lhs);
    auto rhs_group = GetOrdering(rhs);
    if (lhs_group != rhs_group) {
      return lhs_group < rhs_group;
    }

    return lhs.updated > rhs.updated;
  }

  TaskGroupOrdering GetOrdering(const FocusModeTask& entry) const {
    if (created_task_ids->contains(entry.task_id)) {
      return TaskGroupOrdering::kCreatedInSession;
    }

    auto remaining = entry.due.value_or(base::Time::Max()) - now;
    if (remaining < base::Hours(0)) {
      return TaskGroupOrdering::kPastDue;
    } else if (remaining < base::Hours(24)) {
      return TaskGroupOrdering::kDueSoon;
    }
    return TaskGroupOrdering::kDueLater;
  }

  base::Time now;
  raw_ref<base::flat_set<TaskId>> created_task_ids;
};

}  // namespace

// Orders task IDs. A non-pending ID sorts before a pending one (the natural
// `false < true` ordering of `pending`). Any two pending IDs are equivalent
// regardless of their contents. Two non-pending IDs are compared
// lexicographically by (`list_id`, `id`).
std::strong_ordering TaskId::operator<=>(const TaskId& other) const {
  if (pending != other.pending) {
    return pending <=> other.pending;
  }
  if (pending) {
    // Both are pending.
    return std::strong_ordering::equivalent;
  }

  if (const int list_cmp = list_id.compare(other.list_id); list_cmp != 0) {
    return list_cmp < 0 ? std::strong_ordering::less
                        : std::strong_ordering::greater;
  }
  if (const int id_cmp = id.compare(other.id); id_cmp != 0) {
    return id_cmp < 0 ? std::strong_ordering::less
                      : std::strong_ordering::greater;
  }
  return std::strong_ordering::equivalent;
}

// Out-of-line defaulted special members of `FocusModeTask`: construction,
// copy, move, then destruction.
FocusModeTask::FocusModeTask() = default;
FocusModeTask::FocusModeTask(const FocusModeTask&) = default;
FocusModeTask& FocusModeTask::operator=(const FocusModeTask&) = default;
FocusModeTask::FocusModeTask(FocusModeTask&&) = default;
FocusModeTask& FocusModeTask::operator=(FocusModeTask&&) = default;
FocusModeTask::~FocusModeTask() = default;

// Helper used to fetch tasks from the API. It starts by querying for task
// lists, and then queries tasks from each list.
class TaskFetcher {
 public:
  void Start(base::OnceClosure done) {
    done_ = std::move(done);
    GetTaskListsInternal();
  }

  std::string GetMostRecentlyUpdatedTaskList() const {
    return task_lists_.empty() ? "" : task_lists_[0].first;
  }

  std::vector<FocusModeTask> GetTasks() && { return std::move(tasks_); }

  bool error() const { return error_; }

 private:
  // Invokes API request to get the task lists. It may be retried for certain
  // HTTP errors.
  void GetTaskListsInternal() {
    if (api::TasksDelegate* delegate =
            api::TasksController::Get()->tasks_delegate()) {
      delegate->GetTaskLists(
          /*force_fetch=*/true, base::BindOnce(&TaskFetcher::OnGetTaskLists,
                                               weak_factory_.GetWeakPtr()));
    }
  }

  void GetTasksInternal(const std::string& list_id,
                        base::RepeatingClosure barrier) {
    if (api::TasksDelegate* delegate =
            api::TasksController::Get()->tasks_delegate()) {
      delegate->GetTasks(
          list_id,
          /*force_fetch=*/true,
          base::BindOnce(&TaskFetcher::OnGetTasks, weak_factory_.GetWeakPtr(),
                         list_id, barrier));
    }
  }

  void OnGetTaskLists(bool success,
                      std::optional<google_apis::ApiErrorCode> http_error,
                      const ui::ListModel<api::TaskList>* api_task_lists) {
    // Handle HTTP errors and apply retires.
    if (http_error.has_value() &&
        http_error.value() != google_apis::HTTP_SUCCESS) {
      // Handle too many request error.
      if (http_error == 429) {
        // Retry if needed.
        if (get_task_lists_retry_state_.retry_index <
            kMaxRetryTooManyRequests) {
          get_task_lists_retry_state_.retry_index++;
          get_task_lists_retry_state_.timer.Start(
              FROM_HERE, kWaitTimeTooManyRequests,
              base::BindOnce(&TaskFetcher::GetTaskListsInternal,
                             weak_factory_.GetWeakPtr()));
          return;
        }

        // Max number of retries reached. Bail gracefully.
        error_ = true;
        get_task_lists_retry_state_.Reset();
        std::move(done_).Run();
        return;
      }

      // Handle general HTTP errors.
      if (ShouldRetryHttpError(http_error.value())) {
        // Retry if needed.
        if (get_task_lists_retry_state_.retry_index < kMaxRetryOverall) {
          get_task_lists_retry_state_.retry_index++;
          get_task_lists_retry_state_.timer.Start(
              FROM_HERE,
              GetExponentialBackoffRetryWaitTime(
                  get_task_lists_retry_state_.retry_index),
              base::BindOnce(&TaskFetcher::GetTaskListsInternal,
                             weak_factory_.GetWeakPtr()));
          return;
        }

        // Max number of retries reached. Bail gracefully.
        error_ = true;
        get_task_lists_retry_state_.Reset();
        std::move(done_).Run();
        return;
      }

      // Other unhandled HTTP errors. Bail gracefully.
      error_ = true;
      get_task_lists_retry_state_.Reset();
      std::move(done_).Run();
      return;
    }

    if (!api_task_lists || api_task_lists->item_count() == 0) {
      get_task_lists_retry_state_.Reset();
      std::move(done_).Run();
      return;
    }

    // Collect the task lists and sort them so that the greatest one is first.
    task_lists_.reserve(api_task_lists->item_count());
    for (const auto& list : *api_task_lists) {
      task_lists_.emplace_back(list->id, list->updated);
    }
    base::ranges::sort(task_lists_, std::greater{},
                       &std::pair<std::string, base::Time>::second);

    MaybeFetchMoreTasks();
  }

  // If we haven't yet fetched enough tasks to show *and* there are lists that
  // haven't yet been queried, then try to fetch more tasks. In any other case,
  // we invoke the done callback.
  void MaybeFetchMoreTasks() {
    const auto lists_left = task_lists_.size() - task_list_fetch_index_;
    if (lists_left == 0 || tasks_.size() >= kTasksToFetch) {
      // We are done.
      std::move(done_).Run();
      return;
    }

    const auto batch_size = std::min(lists_left, kListFetchBatchSize);
    auto barrier = base::BarrierClosure(
        batch_size, base::BindOnce(&TaskFetcher::MaybeFetchMoreTasks,
                                   weak_factory_.GetWeakPtr()));

    // The code here is structured so that we don't modify any members after
    // calling `GetTasks`. This is done so that the code still works if
    // `GetTasks` invokes the callback synchronously (which happens in tests).
    auto next_task_list_index = task_list_fetch_index_;
    task_list_fetch_index_ += batch_size;

    for (size_t i = 0; i != batch_size; ++i) {
      const std::string& list_id = task_lists_[next_task_list_index++].first;
      GetTasksInternal(list_id, barrier);
    }
  }

  void OnGetTasks(const std::string& list_id,
                  base::RepeatingClosure barrier,
                  bool success,
                  std::optional<google_apis::ApiErrorCode> http_error,
                  const ui::ListModel<api::Task>* api_tasks) {
    // Handle HTTP errors and apply retires.
    if (http_error.has_value() &&
        http_error.value() != google_apis::HTTP_SUCCESS) {
      // Handle too many request error.
      if (http_error == 429) {
        // Retry if needed.
        if (get_tasks_retry_state_.retry_index < kMaxRetryTooManyRequests) {
          get_tasks_retry_state_.retry_index++;
          get_tasks_retry_state_.timer.Start(
              FROM_HERE, kWaitTimeTooManyRequests,
              base::BindOnce(&TaskFetcher::GetTasksInternal,
                             weak_factory_.GetWeakPtr(), list_id, barrier));
          return;
        }

        // Max number of retries reached. Bail gracefully.
        get_tasks_retry_state_.Reset();
        std::move(barrier).Run();
        return;
      }

      // Handle general HTTP errors.
      if (ShouldRetryHttpError(http_error.value())) {
        // Retry if needed.
        if (get_tasks_retry_state_.retry_index < kMaxRetryOverall) {
          get_tasks_retry_state_.retry_index++;
          get_tasks_retry_state_.timer.Start(
              FROM_HERE,
              GetExponentialBackoffRetryWaitTime(
                  get_tasks_retry_state_.retry_index),
              base::BindOnce(&TaskFetcher::GetTasksInternal,
                             weak_factory_.GetWeakPtr(), list_id, barrier));
          return;
        }

        // Max number of retries reached. Bail gracefully.
        get_tasks_retry_state_.Reset();
        std::move(barrier).Run();
        return;
      }

      // Other unhandled HTTP errors. Bail gracefully.
      get_tasks_retry_state_.Reset();
      std::move(barrier).Run();
      return;
    }

    // NOTE: Completed tasks will not show up in `api_tasks`.
    if (success && api_tasks) {
      for (const auto& api_task : *api_tasks) {
        // Skip tasks with empty titles.
        if (api_task->title.empty()) {
          continue;
        }
        FocusModeTask& task = tasks_.emplace_back();
        task.task_id = {.list_id = list_id, .id = api_task->id};
        task.title = api_task->title;
        task.updated = api_task->updated;
        task.due = api_task->due;
      }
    }

    // Do not do anything with `this` after this line since the fetcher will be
    // deleted after the last list has been queried.
    std::move(barrier).Run();
  }

  // This will only be set after retries if retries are conducted.
  bool error_ = false;

  // Task list IDs, sorted by creation time.
  std::vector<std::pair<std::string, base::Time>> task_lists_;

  // The index of the next task list to fetch tasks for.
  std::size_t task_list_fetch_index_ = 0;

  // Tasks fetched.
  std::vector<FocusModeTask> tasks_;

  // Invoked when the fetcher is complete.
  base::OnceClosure done_;

  FocusModeRetryState get_task_lists_retry_state_;
  FocusModeRetryState get_tasks_retry_state_;

  base::WeakPtrFactory<TaskFetcher> weak_factory_{this};
};

// Defaulted out of line, after the definition of `TaskFetcher`, which the
// provider holds through `task_fetcher_`.
FocusModeTasksProvider::~FocusModeTasksProvider() = default;
FocusModeTasksProvider::FocusModeTasksProvider() = default;

// Starts fetching tasks from the backend unless a fetch is already running.
// When a fetch completes, OnTasksFetched() answers every queued request.
void FocusModeTasksProvider::ScheduleTaskListUpdate() {
  if (task_fetcher_) {
    // A fetch is already in flight; don't start another one.
    return;
  }

  task_fetcher_ = std::make_unique<TaskFetcher>();
  task_fetcher_->Start(base::BindOnce(&FocusModeTasksProvider::OnTasksFetched,
                                      weak_factory_.GetWeakPtr()));
}

void FocusModeTasksProvider::Reset() {
  task_fetcher_ = nullptr;
  task_fetch_time_ = {};
  task_list_for_new_task_ = {};
  tasks_.clear();
  deleted_task_ids_.clear();
}

// Returns a copy of the raw task cache: unsorted, and still including tasks
// that are hidden locally via `deleted_task_ids_`.
const std::vector<FocusModeTask> FocusModeTasksProvider::TasksForTesting()
    const {
  return std::vector<FocusModeTask>(tasks_.begin(), tasks_.end());
}

// Provides the sorted, visible task list to `callback`. If the cache was
// filled within `kCacheLifetime`, it replies from the cache via a posted task.
// Otherwise it queues the request and starts a fetch (if none is running).
void FocusModeTasksProvider::GetSortedTaskList(OnGetTasksCallback callback) {
  // `task_fetch_time_` is wall-clock time. A negative age means the clock
  // moved backwards since the fetch. Treat that as stale rather than serving
  // the cache until the clock catches up.
  const base::TimeDelta cache_age = base::Time::Now() - task_fetch_time_;
  if (!cache_age.is_negative() && cache_age < kCacheLifetime) {
    base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
        FROM_HERE, base::BindOnce(std::move(callback), GetSortedTasksImpl()));
    return;
  }

  get_tasks_requests_.push_back(std::move(callback));
  ScheduleTaskListUpdate();
}

void FocusModeTasksProvider::GetTask(const std::string& task_list_id,
                                     const std::string& task_id,
                                     OnGetTaskCallback callback) {
  CHECK(!task_list_id.empty());
  CHECK(!task_id.empty());

  if (api::TasksDelegate* delegate =
          api::TasksController::Get()->tasks_delegate()) {
    delegate->GetTasks(
        task_list_id, /*force_fetch=*/true,
        base::BindOnce(&FocusModeTasksProvider::OnTasksFetchedForTask,
                       weak_factory_.GetWeakPtr(), task_list_id, task_id,
                       std::move(callback)));
  }
}

// Creates a task titled `title` in the most recently updated task list. Fails
// immediately (empty `FocusModeTask`) if no target list is known yet. That
// list is only set by a successful fetch.
void FocusModeTasksProvider::AddTask(const std::string& title,
                                     OnTaskSavedCallback callback) {
  const bool has_target_list = !task_list_for_new_task_.empty();
  if (has_target_list) {
    // Invalidate the cache so that the next task list request goes to the
    // backend and picks up the real ID of the newly created task.
    task_fetch_time_ = {};
    AddTaskInternal(title, std::move(callback));
    return;
  }

  // TODO(b/339667327): Instead of failing the request, consider queueing it.
  std::move(callback).Run(FocusModeTask{});
}

void FocusModeTasksProvider::UpdateTask(const std::string& task_list_id,
                                        const std::string& task_id,
                                        const std::string& title,
                                        bool completed,
                                        OnTaskSavedCallback callback) {
  CHECK(!task_id.empty());
  CHECK(!task_list_id.empty());

  if (completed) {
    deleted_task_ids_.insert({.list_id = task_list_id, .id = task_id});
  }

  UpdateTaskInternal(task_list_id, task_id, title, completed,
                     std::move(callback));
}

// Completion callback of `task_fetcher_`. Updates the cache from the fetcher's
// results, destroys the fetcher and answers every queued request.
void FocusModeTasksProvider::OnTasksFetched() {
  CHECK(task_fetcher_);

  if (!task_fetcher_->error()) {
    task_fetch_time_ = base::Time::Now();
    task_list_for_new_task_ = task_fetcher_->GetMostRecentlyUpdatedTaskList();
    tasks_ = std::move(*task_fetcher_).GetTasks();
  } else {
    // The cached tasks are discarded, so also invalidate the cache timestamp.
    // Otherwise an earlier, still-fresh timestamp would make
    // GetSortedTaskList() serve the emptied list as if it were current.
    task_fetch_time_ = {};
    tasks_ = {};
    task_list_for_new_task_ = {};
  }
  task_fetcher_ = nullptr;

  // Make sure to clear this in case there are tasks completed through Focus
  // mode that the user then un-completed outside of Focus mode.
  deleted_task_ids_ = {};

  // Move the queue out before running callbacks, so that a callback issuing a
  // new request queues it for the next fetch instead of this loop.
  auto pending = std::move(get_tasks_requests_);
  auto tasks = GetSortedTasksImpl();
  for (auto& callback : pending) {
    std::move(callback).Run(tasks);
  }
}

// Handles the list fetched on behalf of GetTask(). Retries retryable HTTP
// errors and otherwise replies with an empty task on failure. On success it
// finds `task_id` in the list, refreshes the cached copy if one exists, and
// replies with the task.
void FocusModeTasksProvider::OnTasksFetchedForTask(
    const std::string& task_list_id,
    const std::string& task_id,
    OnGetTaskCallback callback,
    bool success,
    std::optional<google_apis::ApiErrorCode> http_error,
    const ui::ListModel<api::Task>* api_tasks) {
  // Handle HTTP errors and apply retries.
  if (http_error.has_value() &&
      http_error.value() != google_apis::HTTP_SUCCESS) {
    // Handle too many request error.
    if (http_error == 429) {
      // Retry if needed.
      if (get_task_retry_state_.retry_index < kMaxRetryTooManyRequests) {
        get_task_retry_state_.retry_index++;
        get_task_retry_state_.timer.Start(
            FROM_HERE, kWaitTimeTooManyRequests,
            base::BindOnce(&FocusModeTasksProvider::GetTask,
                           weak_factory_.GetWeakPtr(), task_list_id, task_id,
                           std::move(callback)));
        return;
      }

      // Max number of retries reached. Bail gracefully.
      std::move(callback).Run(FocusModeTask{});
      get_task_retry_state_.Reset();
      return;
    }

    // Handle general HTTP errors.
    if (ShouldRetryHttpError(http_error.value())) {
      // Retry if needed.
      if (get_task_retry_state_.retry_index < kMaxRetryOverall) {
        get_task_retry_state_.retry_index++;
        get_task_retry_state_.timer.Start(
            FROM_HERE,
            GetExponentialBackoffRetryWaitTime(
                get_task_retry_state_.retry_index),
            base::BindOnce(&FocusModeTasksProvider::GetTask,
                           weak_factory_.GetWeakPtr(), task_list_id, task_id,
                           std::move(callback)));
        return;
      }

      // Max number of retries reached. Bail gracefully.
      std::move(callback).Run(FocusModeTask{});
      get_task_retry_state_.Reset();
      return;
    }

    // Other unhandled HTTP errors. Bail gracefully.
    std::move(callback).Run(FocusModeTask{});
    get_task_retry_state_.Reset();
    return;
  }

  // Without the task list there is nothing to look the task up in. Fail
  // instead of dereferencing a null `api_tasks` below.
  if (!success || !api_tasks) {
    std::move(callback).Run(FocusModeTask{});
    get_task_retry_state_.Reset();
    return;
  }

  TaskId fetched_task_id = {.list_id = task_list_id, .id = task_id};
  auto iter =
      base::ranges::find(tasks_, fetched_task_id, &FocusModeTask::task_id);
  bool task_exists = iter != tasks_.end();

  FocusModeTask temp_local_task;
  if (!task_exists) {
    temp_local_task.task_id = fetched_task_id;
  }
  FocusModeTask& task = task_exists ? *iter : temp_local_task;

  // Make sure that the fetched task is updated in the cache if it exists.
  // NOTE: Completed tasks will not show up in `api_tasks`, so we first assume
  // it's completed and update the state if the task is found in `api_tasks`.
  // TODO: Can we actually verify that the task is complete instead of making
  // this assumption?
  task.completed = true;

  for (const auto& api_task : *api_tasks) {
    if (api_task->id == task_id) {
      task.title = api_task->title;
      task.updated = api_task->updated;
      // Keep the due date in sync too, as the bulk fetch path does.
      task.due = api_task->due;
      task.completed = api_task->completed;
      break;
    }
  }
  if (task.completed && task_exists) {
    // Only mark the task as deleted if it already exists in `tasks_`.
    deleted_task_ids_.insert(fetched_task_id);
  }

  std::move(callback).Run(task);
  get_task_retry_state_.Reset();
}

void FocusModeTasksProvider::OnTaskAdded(const std::string& title,
                                         OnTaskSavedCallback callback,
                                         google_apis::ApiErrorCode http_error,
                                         const api::Task* api_task) {
  if (!api_task || api_task->title.empty()) {
    // When `api_task` is null, `http_error` can be
    // `google_apis::ApiErrorCode::HTTP_SUCCESS` or other error code. Retry some
    // of the error codes as well.
    if (http_error != google_apis::ApiErrorCode::HTTP_SUCCESS) {
      // Handle too many requests error.
      if (http_error == 429 &&
          add_task_retry_state_.retry_index < kMaxRetryTooManyRequests) {
        // Retry if needed.
        add_task_retry_state_.retry_index++;
        add_task_retry_state_.timer.Start(
            FROM_HERE, kWaitTimeTooManyRequests,
            base::BindOnce(&FocusModeTasksProvider::AddTaskInternal,
                           weak_factory_.GetWeakPtr(), title,
                           std::move(callback)));
        return;
      }

      // Handle general HTTP errors.
      if (ShouldRetryHttpError(http_error) &&
          add_task_retry_state_.retry_index < kMaxRetryOverall) {
        // Retry if needed.
        add_task_retry_state_.retry_index++;
        add_task_retry_state_.timer.Start(
            FROM_HERE,
            GetExponentialBackoffRetryWaitTime(
                add_task_retry_state_.retry_index),
            base::BindOnce(&FocusModeTasksProvider::AddTaskInternal,
                           weak_factory_.GetWeakPtr(), title,
                           std::move(callback)));
        return;
      }
    }

    // After all of the retries, if there's still an error, we clear the cache.
    task_fetch_time_ = {};
    std::move(callback).Run(FocusModeTask{});
    add_task_retry_state_.Reset();
    return;
  }

  UpdateOrInsertTask(task_list_for_new_task_, api_task, std::move(callback));
  add_task_retry_state_.Reset();
}

void FocusModeTasksProvider::OnTaskUpdated(const std::string& task_list_id,
                                           const std::string& task_id,
                                           const std::string& title,
                                           bool completed,
                                           OnTaskSavedCallback callback,
                                           google_apis::ApiErrorCode http_error,
                                           const api::Task* api_task) {
  if (!api_task || api_task->title.empty()) {
    // When `api_task` is null, `http_error` can be
    // `google_apis::ApiErrorCode::HTTP_SUCCESS` or other error code. Retry some
    // of the error codes as well.
    if (http_error != google_apis::ApiErrorCode::HTTP_SUCCESS) {
      // Handle too many requests error.
      if (http_error == 429 &&
          update_task_retry_state_.retry_index < kMaxRetryTooManyRequests) {
        // Retry if needed.
        update_task_retry_state_.retry_index++;
        update_task_retry_state_.timer.Start(
            FROM_HERE, kWaitTimeTooManyRequests,
            base::BindOnce(&FocusModeTasksProvider::UpdateTaskInternal,
                           weak_factory_.GetWeakPtr(), task_list_id, task_id,
                           title, completed, std::move(callback)));
        return;
      }

      // Handle general HTTP errors.
      if (ShouldRetryHttpError(http_error) &&
          update_task_retry_state_.retry_index < kMaxRetryOverall) {
        // Retry if needed.
        update_task_retry_state_.retry_index++;
        update_task_retry_state_.timer.Start(
            FROM_HERE,
            GetExponentialBackoffRetryWaitTime(
                update_task_retry_state_.retry_index),
            base::BindOnce(&FocusModeTasksProvider::UpdateTaskInternal,
                           weak_factory_.GetWeakPtr(), task_list_id, task_id,
                           title, completed, std::move(callback)));
        return;
      }
    }

    // After all of the retries, if there's still an error, we clear the cache.
    task_fetch_time_ = {};
    if (completed) {
      deleted_task_ids_.erase({.list_id = task_list_id, .id = task_id});
    }
    std::move(callback).Run(FocusModeTask{});
    update_task_retry_state_.Reset();
    return;
  }

  UpdateOrInsertTask(task_list_id, api_task, std::move(callback));
  update_task_retry_state_.Reset();
}

void FocusModeTasksProvider::AddTaskInternal(const std::string& title,
                                             OnTaskSavedCallback callback) {
  api::TasksController::Get()->tasks_delegate()->AddTask(
      task_list_for_new_task_, title,
      base::BindOnce(&FocusModeTasksProvider::OnTaskAdded,
                     weak_factory_.GetWeakPtr(), title, std::move(callback)));
}

void FocusModeTasksProvider::UpdateTaskInternal(const std::string& task_list_id,
                                                const std::string& task_id,
                                                const std::string& title,
                                                bool completed,
                                                OnTaskSavedCallback callback) {
  api::TasksController::Get()->tasks_delegate()->UpdateTask(
      task_list_id, task_id, title, completed,
      base::BindOnce(&FocusModeTasksProvider::OnTaskUpdated,
                     weak_factory_.GetWeakPtr(), task_list_id, task_id, title,
                     completed, std::move(callback)));
}

// Records `api_task` (from list `task_list_id`) in the cache, inserting it if
// it is not cached yet. Marks its ID in `created_task_ids_` so it sorts into
// the first group, then replies with the cached copy. Used for both added and
// updated tasks.
void FocusModeTasksProvider::UpdateOrInsertTask(const std::string& task_list_id,
                                                const api::Task* api_task,
                                                OnTaskSavedCallback callback) {
  TaskId created_id = {.list_id = task_list_id, .id = api_task->id};
  created_task_ids_.insert(created_id);

  // Try to find the task in the cache or insert it.
  auto iter = base::ranges::find(tasks_, created_id, &FocusModeTask::task_id);

  FocusModeTask& task = (iter != tasks_.end()) ? *iter : tasks_.emplace_back();
  task.task_id = created_id;
  task.title = api_task->title;
  task.updated = api_task->updated;
  // Copy the remaining backend fields too, as the fetch paths do. Otherwise a
  // cached entry, and the task handed to `callback`, keeps a stale due date or
  // completion state.
  task.due = api_task->due;
  task.completed = api_task->completed;

  std::move(callback).Run(task);
}

std::vector<FocusModeTask> FocusModeTasksProvider::GetSortedTasksImpl() {
  std::vector<FocusModeTask> result;
  for (const FocusModeTask& task : tasks_) {
    if (!deleted_task_ids_.contains(task.task_id)) {
      result.push_back(task);
    }
  }

  base::ranges::sort(
      result, TaskComparator{base::Time::Now(), raw_ref<base::flat_set<TaskId>>(
                                                    created_task_ids_)});

  return result;
}

}  // namespace ash