/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <algorithm>
#include <cmath>
#include <functional>
#include <limits>
#include <stdexcept>
#include <tuple>
#include <type_traits>

#include <folly/lang/Exception.h>
namespace folly {
// Robust and efficient online computation of statistics (count, mean,
// variance, standard deviation, minimum and maximum). Single samples are
// folded in with Welford's method; two accumulators are combined with the
// pairwise update of Chan et al.
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
template <typename SampleDataType, typename StatsType = double>
class StreamingStats final {
  // Calculated statistic results have to be of floating point type.
  static_assert(std::is_floating_point_v<StatsType>);

 public:
  /// Plain snapshot of the accumulator's internal state, as returned by
  /// `state()` and accepted by the `StreamingStats(StreamingState)`
  /// constructor.
  struct StreamingState {
    size_t count = 0;
    StatsType mean = 0;
    StatsType m2 = 0;
    SampleDataType min = std::numeric_limits<SampleDataType>::max();
    SampleDataType max = std::numeric_limits<SampleDataType>::lowest();
  };

  /// Construct from every sample in [first, last).
  template <class Iterator>
  StreamingStats(Iterator first, Iterator last) noexcept {
    add(first, last);
  }

  /// Restore an accumulator from a previously captured `state()`.
  explicit StreamingStats(StreamingState state)
      : count_(state.count),
        mean_(state.mean),
        m2_(state.m2),
        min_(state.min),
        max_(state.max) {}

  StreamingStats() = default;
  ~StreamingStats() = default;

  /// Add sample data via iteration over [first, last).
  template <class Iterator>
  void add(Iterator first, Iterator last) noexcept {
    for (auto it = first; it != last; ++it) {
      add(*it);
    }
  }

  /// Add a single sample (Welford's online update).
  void add(SampleDataType value) noexcept {
    max_ = std::max(max_, value);
    min_ = std::min(min_, value);
    ++count_;
    StatsType const delta = value - mean_;
    mean_ += delta / count_;
    // Using the *updated* mean here is what makes Welford's update exact.
    StatsType const delta2 = value - mean_;
    m2_ += delta * delta2;
  }

  /// Merge the samples summarized by `other` into this object.
  ///
  /// Uses the pairwise combination of Chan et al.:
  ///   delta = mean_b - mean_a
  ///   mean  = mean_a + delta * n_b / n
  ///   m2    = m2_a + m2_b + delta^2 * n_a * n_b / n
  /// Only the (small) difference between the two means is ever squared.
  /// Expanding n_a * (mean - mean_a)^2 into products of the (possibly large)
  /// means instead cancels catastrophically, and can lose all precision or
  /// even drive m2 negative. Merging an object with itself is supported.
  void merge(StreamingStats const& other) noexcept {
    if (other.count_ == 0) {
      return;
    }
    // Compute every result before mutating `*this`, so that a self-merge
    // (`&other == this`) reads consistent inputs.
    size_t const newCount = count_ + other.count_;
    // Work in StatsType so that n_a * n_b cannot overflow size_t.
    StatsType const countA = static_cast<StatsType>(count_);
    StatsType const countB = static_cast<StatsType>(other.count_);
    StatsType const total = static_cast<StatsType>(newCount);
    StatsType const delta = other.mean_ - mean_;
    StatsType const newMean = mean_ + delta * (countB / total);
    StatsType const newM2 =
        m2_ + other.m2_ + delta * delta * (countA * countB / total);
    SampleDataType const newMax = std::max(max_, other.max_);
    SampleDataType const newMin = std::min(min_, other.min_);

    count_ = newCount;
    mean_ = newMean;
    m2_ = newM2;
    min_ = newMin;
    max_ = newMax;
  }

  /// Number of samples added so far.
  size_t count() const noexcept { return count_; }

  /// Smallest sample seen; throws std::logic_error when empty.
  SampleDataType minimum() const {
    checkMinimumDataSize(1);
    return min_;
  }

  /// Largest sample seen; throws std::logic_error when empty.
  SampleDataType maximum() const {
    checkMinimumDataSize(1);
    return max_;
  }

  /// Arithmetic mean of the samples; throws std::logic_error when empty.
  StatsType mean() const {
    checkMinimumDataSize(1);
    return mean_;
  }

  /// Sum of squared deviations from the mean; throws std::logic_error when
  /// empty.
  StatsType m2() const {
    checkMinimumDataSize(1);
    return m2_;
  }

  /// m2 / n; throws std::logic_error with fewer than two samples.
  StatsType populationVariance() const {
    checkMinimumDataSize(2);
    return var_(0);
  }

  /// m2 / (n - 1) (Bessel-corrected); throws std::logic_error with fewer
  /// than two samples.
  StatsType sampleVariance() const {
    checkMinimumDataSize(2);
    return var_(1);
  }

  /// sqrt(populationVariance()); throws std::logic_error with fewer than two
  /// samples.
  StatsType populationStandardDeviation() const {
    checkMinimumDataSize(2);
    return std_(0);
  }

  /// sqrt(sampleVariance()); throws std::logic_error with fewer than two
  /// samples.
  StatsType sampleStandardDeviation() const {
    checkMinimumDataSize(2);
    return std_(1);
  }

  /// Snapshot of the internal state, restorable via the
  /// `StreamingStats(StreamingState)` constructor.
  StreamingState state() const {
    StreamingState state;
    state.count = count_;
    state.m2 = m2_;
    state.max = max_;
    state.mean = mean_;
    state.min = min_;
    return state;
  }

 private:
  // Throws std::logic_error unless at least `minElements` samples were added.
  void checkMinimumDataSize(size_t const minElements) const {
    if (count_ < minElements) {
      throw_exception<std::logic_error>("stats: unavailable with no samples");
    }
  }

  // m2 / (count - bias); callers guarantee count_ > bias.
  StatsType var_(size_t bias) const noexcept { return m2_ / (count_ - bias); }
  StatsType std_(size_t bias) const noexcept { return std::sqrt(var_(bias)); }

  size_t count_ = 0;
  StatsType mean_ = 0;
  StatsType m2_ = 0;
  SampleDataType min_ = std::numeric_limits<SampleDataType>::max();
  SampleDataType max_ = std::numeric_limits<SampleDataType>::lowest();
};
} // namespace folly