chromium/chrome/browser/ash/guest_os/infra/cached_callback.h

// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef CHROME_BROWSER_ASH_GUEST_OS_INFRA_CACHED_CALLBACK_H_
#define CHROME_BROWSER_ASH_GUEST_OS_INFRA_CACHED_CALLBACK_H_

#include <memory>
#include <type_traits>
#include <utility>
#include <vector>

#include "base/functional/bind.h"
#include "base/functional/callback_forward.h"
#include "base/memory/weak_ptr.h"
#include "base/types/expected.h"

namespace guest_os {

// Manages several racing callbacks which all attempt to access the same
// (logical) object.
//
// This class is used when you have multiple potential callers for an
// asynchronous operation, and you want them all to agree on the outcome of that
// operation. The first callback that comes in triggers the "real" async
// operation, and subsequent callbacks are queued. If the "real" operation
// succeeds, all currently queued and future callbacks are invoked with a handle
// to that success. If it fails, all currently queued callbacks are notified of
// the failure but future ones will retry the "real" async operation.
//
// Internally this class tries to create the "real" result T as a unique_ptr,
// but exposes it to its own clients as a T*. For the errors E it is advisable
// to use something movable and copyable, preferably an enum.
template <typename T, typename E>
class CachedCallback {
 public:
  // As a convenience this class provides a default implementation of the
  // Reject() method, which is used when the cache is deleted while callbacks
  // are in-flight. In this case we default-construct an E for them.
  //
  // We can avoid this requirement using std::enable_if shenanigans but
  // enforcing the below is cleaner.
  static_assert(
      std::is_default_constructible<E>::value,
      "Cached callbacks must have a default constructible error type");

  using Result = base::expected<T*, E>;

  using Callback = base::OnceCallback<void(Result result)>;

  virtual ~CachedCallback() {
    if (!queued_callbacks_.empty()) {
      Finish(base::unexpected(Reject()));
    }
  }

  // Request access to the cached result.
  void Get(Callback callback) {
    if (real_object_) {
      std::move(callback).Run(Result(real_object_.get()));
      return;
    }
    queued_callbacks_.push_back(std::move(callback));
    // If this is the first callback, spawn the real one. This must happen after
    // enqueueing the callback in case the factory returns synchronously.
    if (queued_callbacks_.size() == 1) {
      Build(base::BindOnce(&CachedCallback<T, E>::OnRealResultFound,
                           weak_factory_.GetWeakPtr()));
    }
  }

  // Returns the cached result if it exists, nullptr otherwise.
  T* MaybeGet() const {
    if (real_object_) {
      return real_object_.get();
    }
    return nullptr;
  }

  // Clears the stored real result (if one exists) and returns it (or null).
  std::unique_ptr<T> Invalidate() {
    std::unique_ptr<T> real = std::move(real_object_);
    real_object_.reset();
    return std::move(real);
  }

  void CacheForTesting(std::unique_ptr<T> real) {
    OnRealResultFound(RealResult(std::move(real)));
  }

 protected:
  using RealResult = base::expected<std::unique_ptr<T>, E>;

  using RealCallback = base::OnceCallback<void(RealResult)>;

  // Used to construct the "real" result, which will be owned by this class.
  virtual void Build(RealCallback callback) = 0;

  // In cases where the cache determines that a request can not be fulfilled,
  // this error is returned to callers automatically. For example, it is used if
  // the cache is destroyed while requests are in flight.
  //
  // Unless overridden, a default-constructed E is used.
  virtual E Reject() { return E{}; }

  // Helper template to construct successful RealResults more conveniently. It
  // is static so that lambdas defined by the subclass can use them.
  template <typename... TT>
  static RealResult Success(TT&&... tt) {
    return base::ok(std::make_unique<T>(std::forward<TT>(tt)...));
  }

  // Similar to Success, but for error results.
  template <typename... EE>
  static RealResult Failure(EE&&... ee) {
    return base::unexpected(std::forward<EE>(ee)...);
  }

 private:
  void OnRealResultFound(RealResult real_result) {
    if (!real_result.has_value()) {
      Finish(base::unexpected(real_result.error()));
      return;
    }
    real_object_ = std::move(real_result.value());
    Finish(base::ok(real_object_.get()));
  }

  void Finish(Result res) {
    while (!queued_callbacks_.empty()) {
      std::move(queued_callbacks_.back()).Run(res);
      queued_callbacks_.pop_back();
    }
  }

  std::unique_ptr<T> real_object_;
  std::vector<Callback> queued_callbacks_;
  base::WeakPtrFactory<CachedCallback<T, E>> weak_factory_{this};
};

}  // namespace guest_os

#endif  // CHROME_BROWSER_ASH_GUEST_OS_INFRA_CACHED_CALLBACK_H_