chromium/third_party/mediapipe/src/mediapipe/util/resource_cache.h

// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MEDIAPIPE_UTIL_RESOURCE_CACHE_H_
#define MEDIAPIPE_UTIL_RESOURCE_CACHE_H_

#include <cstddef>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/functional/function_ref.h"
#include "absl/log/absl_check.h"
#include "mediapipe/framework/port/logging.h"

namespace mediapipe {

// Maintains a cache for resources of type `Value`, where the type of the
// resource (e.g., image dimension for an image pool) is described by the `Key`
// type. The `Value` type must include an unset value, with implicit conversion
// to bool reflecting set/unset state.
//
// Entries are kept in a list ordered by descending request count, so the
// least-requested entries sit at the tail and are the first to be evicted.
// This class performs no locking of its own.
template <typename Key, typename Value,
          typename KeyHash = typename absl::flat_hash_map<Key, int>::hasher>
class ResourceCache {
 public:
  // Records one request for `key` and returns the value cached for it.
  //
  // If the entry for `key` has no set value, `create` is invoked with the key
  // and the entry's updated request count, and its result is stored. `create`
  // may return an unset value; in that case it is invoked again on the next
  // lookup of the same key. The returned value is a copy of the stored one.
  Value Lookup(
      const Key& key,
      absl::FunctionRef<Value(const Key& key, int request_count)> create) {
    Entry* entry;
    auto map_it = map_.find(key);
    if (map_it == map_.end()) {
      map_it = map_.try_emplace(key, std::make_unique<Entry>(key)).first;
      entry = map_it->second.get();
      ABSL_CHECK_EQ(entry->request_count, 0);
      entry->request_count = 1;
      // A new entry has the smallest possible count, so it belongs at the
      // tail of the descending-order list.
      entry_list_.Append(entry);
      if (entry->prev != nullptr) {
        ABSL_CHECK_GE(entry->prev->request_count, 1);
      }
    } else {
      entry = map_it->second.get();
      ++entry->request_count;
      // Walk towards the head past every entry that now has a smaller count;
      // `larger` ends up as the entry we must follow (or nullptr for head).
      Entry* larger = entry->prev;
      while (larger != nullptr &&
             larger->request_count < entry->request_count) {
        larger = larger->prev;
      }
      if (larger != entry->prev) {
        entry_list_.Remove(entry);
        entry_list_.InsertAfter(entry, larger);
      }
    }
    if (!entry->value) {
      entry->value = create(entry->key, entry->request_count);
    }
    ++total_request_count_;
    return entry->value;
  }

  // Shrinks the cache and returns the values of all entries that were removed
  // (the values are moved out; they may be unset if `create` never produced a
  // set value for that entry).
  //
  // `max_count` must be >= 0; entries are dropped from the tail of the list
  // until at most `max_count` remain. Then, if at least
  // `request_count_scrub_interval` lookups happened since the last scrub,
  // every request count is halved and entries whose count reaches 0 are
  // removed as well.
  std::vector<Value> Evict(int max_count, int request_count_scrub_interval) {
    std::vector<Value> evicted;

    // Remove excess entries.
    ABSL_CHECK_GE(max_count, 0);
    while (entry_list_.size() > static_cast<size_t>(max_count)) {
      Entry* victim = entry_list_.tail();
      evicted.emplace_back(std::move(victim->value));
      // Unlink before erasing: erasing from the map destroys the entry.
      entry_list_.Remove(victim);
      map_.erase(victim->key);
    }
    // Every request_count_scrub_interval, halve the request counts, and
    // remove entries which have fallen to 0.
    // This keeps sporadic requests from accumulating and eventually exceeding
    // the minimum request threshold for allocating a pool. Also, it means that
    // if the request regimen changes (e.g. a graph was always requesting a
    // large size, but then switches to a small size to save memory or CPU), the
    // pool can quickly adapt to it.
    const bool scrub = total_request_count_ >= request_count_scrub_interval;
    if (scrub) {
      total_request_count_ = 0;
      for (Entry* entry = entry_list_.head(); entry != nullptr;) {
        entry->request_count /= 2;
        // Capture the successor first; `entry` may be destroyed below.
        Entry* next = entry->next;
        if (entry->request_count == 0) {
          evicted.emplace_back(std::move(entry->value));
          entry_list_.Remove(entry);
          map_.erase(entry->key);
        }
        entry = next;
      }
    }
    return evicted;
  }

 private:
  // One cached resource plus its intrusive list links and usage counter.
  struct Entry {
    explicit Entry(const Key& key) : key(key) {}
    Entry* prev = nullptr;
    Entry* next = nullptr;
    int request_count = 0;
    Key key;
    Value value;
  };

  // Unlike std::list, this is an intrusive list, meaning that the prev and next
  // pointers live inside the element. Apart from not requiring an extra
  // allocation, this means that once we look up an entry by key in `map_` we
  // do not need to look it up separately in the list.
  //
  // The list does not own its entries. Entries passed to Prepend, Append and
  // InsertAfter must be unlinked (prev and next both null).
  class EntryList {
   public:
    // Links `entry` in as the new head.
    void Prepend(Entry* entry) {
      if (head_ == nullptr) {
        head_ = tail_ = entry;
      } else {
        entry->next = head_;
        head_->prev = entry;
        head_ = entry;
      }
      ++size_;
    }
    // Links `entry` in as the new tail.
    void Append(Entry* entry) {
      if (tail_ == nullptr) {
        head_ = tail_ = entry;
      } else {
        tail_->next = entry;
        entry->prev = tail_;
        tail_ = entry;
      }
      ++size_;
    }
    // Unlinks `entry`, which must currently be in this list, and resets its
    // prev/next pointers to null.
    void Remove(Entry* entry) {
      if (entry == head_) {
        head_ = entry->next;
      } else {
        entry->prev->next = entry->next;
      }
      if (entry == tail_) {
        tail_ = entry->prev;
      } else {
        entry->next->prev = entry->prev;
      }
      entry->prev = nullptr;
      entry->next = nullptr;
      --size_;
    }
    // Links `entry` in directly after `after`; a null `after` means "insert
    // at the head".
    void InsertAfter(Entry* entry, Entry* after) {
      if (after == nullptr) {
        Prepend(entry);
        return;
      }
      entry->prev = after;
      entry->next = after->next;
      if (after->next != nullptr) {
        after->next->prev = entry;
      } else {
        // `after` was the tail, so `entry` becomes the new tail.
        tail_ = entry;
      }
      after->next = entry;
      ++size_;
    }

    Entry* head() const { return head_; }
    Entry* tail() const { return tail_; }
    size_t size() const { return size_; }

   private:
    Entry* head_ = nullptr;
    Entry* tail_ = nullptr;
    size_t size_ = 0;
  };

  // Owns the entries; `entry_list_` holds non-owning pointers into it.
  absl::flat_hash_map<Key, std::unique_ptr<Entry>, KeyHash> map_;
  EntryList entry_list_;
  // Number of Lookup calls since the last scrub in Evict.
  int total_request_count_ = 0;
};

}  // namespace mediapipe

#endif  // MEDIAPIPE_UTIL_RESOURCE_CACHE_H_