chromium/third_party/rust/chromium_crates_io/vendor/fend-core-1.5.1/src/num/dist.rs

use crate::error::{FendError, Interrupt};
use crate::interrupt::{test_int, Never};
use crate::num::bigrat::BigRat;
use crate::num::complex::{self, Complex};
use crate::result::FResult;
use crate::serialize::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::fmt::Write;
use std::ops::Neg;
use std::{fmt, io};

use super::real::Real;
use super::{Base, Exact, FormattingStyle};

#[derive(Clone)]
pub(crate) struct Dist {
	// invariant: probabilities must sum to 1
	parts: Vec<(Complex, BigRat)>,
}

impl Dist {
	pub(crate) fn serialize(&self, write: &mut impl io::Write) -> FResult<()> {
		self.parts.len().serialize(write)?;
		for (a, b) in &self.parts {
			a.serialize(write)?;
			b.serialize(write)?;
		}
		Ok(())
	}

	pub(crate) fn deserialize(read: &mut impl io::Read) -> FResult<Self> {
		let len = usize::deserialize(read)?;
		let mut parts = Vec::with_capacity(len);
		for _ in 0..len {
			let k = Complex::deserialize(read)?;
			let v = BigRat::deserialize(read)?;
			parts.push((k, v));
		}
		Ok(Self { parts })
	}

	pub(crate) fn one_point(self) -> FResult<Complex> {
		if self.parts.len() == 1 {
			Ok(self.parts.into_iter().next().unwrap().0)
		} else {
			Err(FendError::ProbabilityDistributionsNotAllowed)
		}
	}

	pub(crate) fn one_point_ref(&self) -> FResult<&Complex> {
		if self.parts.len() == 1 {
			Ok(&self.parts[0].0)
		} else {
			Err(FendError::ProbabilityDistributionsNotAllowed)
		}
	}

	pub(crate) fn new_die<I: Interrupt>(count: u32, faces: u32, int: &I) -> FResult<Self> {
		assert!(count != 0);
		assert!(faces != 0);
		if count > 1 {
			let mut result = Self::new_die(1, faces, int)?;
			for _ in 1..count {
				test_int(int)?;
				result = Exact::new(result, true)
					.add(&Exact::new(Self::new_die(1, faces, int)?, true), int)?
					.value;
			}
			return Ok(result);
		}
		let mut parts = Vec::new();
		let probability = BigRat::from(1).div(&BigRat::from(u64::from(faces)), int)?;
		for face in 1..=faces {
			test_int(int)?;
			parts.push((Complex::from(u64::from(face)), probability.clone()));
		}
		Ok(Self { parts })
	}

	pub(crate) fn equals_int<I: Interrupt>(&self, val: u64, int: &I) -> FResult<bool> {
		Ok(self.parts.len() == 1
			&& self.parts[0].0.compare(&val.into(), int)? == Some(Ordering::Equal))
	}

	#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
	pub(crate) fn sample<I: Interrupt>(self, ctx: &crate::Context, int: &I) -> FResult<Self> {
		if self.parts.len() == 1 {
			return Ok(self);
		}
		let mut random = ctx.random_u32.ok_or(FendError::RandomNumbersNotAvailable)?();
		let mut res = None;
		for (k, v) in self.parts {
			random = random.saturating_sub((v.into_f64(int)? * f64::from(u32::MAX)) as u32);
			if random == 0 {
				return Ok(Self::from(k));
			}
			res = Some(Self::from(k));
		}
		res.ok_or(FendError::EmptyDistribution)
	}

	pub(crate) fn mean<I: Interrupt>(self, int: &I) -> FResult<Self> {
		if self.parts.is_empty() {
			return Err(FendError::EmptyDistribution);
		} else if self.parts.len() == 1 {
			return Ok(self);
		}

		let mut result = Exact::new(Complex::from(0), true);
		for (k, v) in self.parts {
			result = Exact::new(k, true)
				.mul(&Exact::new(Complex::from(Real::from(v)), true), int)?
				.add(result, int)?;
		}

		Ok(Self::from(result.value))
	}

	#[allow(
		clippy::cast_possible_truncation,
		clippy::cast_sign_loss,
		clippy::too_many_arguments
	)]
	pub(crate) fn format<I: Interrupt>(
		&self,
		exact: bool,
		style: FormattingStyle,
		base: Base,
		use_parentheses: complex::UseParentheses,
		out: &mut String,
		ctx: &crate::Context,
		int: &I,
	) -> FResult<Exact<()>> {
		if self.parts.len() == 1 {
			let res = self.parts[0]
				.0
				.format(exact, style, base, use_parentheses, int)?;
			write!(out, "{}", res.value)?;
			Ok(Exact::new((), res.exact))
		} else {
			let mut ordered_kvs = vec![];
			let mut max_prob = 0.0;
			for (n, prob) in &self.parts {
				let prob_f64 = prob.clone().into_f64(int)?;
				if prob_f64 > max_prob {
					max_prob = prob_f64;
				}
				ordered_kvs.push((n, prob, prob_f64));
			}
			ordered_kvs.sort_unstable_by(|(a, _, _), (b, _, _)| {
				a.compare(b, &Never).unwrap().unwrap_or(Ordering::Equal)
			});
			if ctx.output_mode == crate::OutputMode::SimpleText {
				write!(out, "{{ ")?;
			}
			let mut first = true;
			for (num, _prob, prob_f64) in ordered_kvs {
				let num = num
					.format(exact, style, base, use_parentheses, int)?
					.value
					.to_string();
				let prob_percentage = prob_f64 * 100.0;
				if ctx.output_mode == crate::OutputMode::TerminalFixedWidth {
					if !first {
						writeln!(out)?;
					}
					let mut bar = String::new();
					for _ in 0..(prob_f64 / max_prob * 30.0).min(30.0) as u32 {
						bar.push('#');
					}
					write!(out, "{num:>3}: {prob_percentage:>5.2}%  {bar}")?;
				} else {
					if !first {
						write!(out, ", ")?;
					}
					write!(out, "{num}: {prob_percentage:.2}%")?;
				}
				if first {
					first = false;
				}
			}
			if ctx.output_mode == crate::OutputMode::SimpleText {
				write!(out, " }}")?;
			}
			// TODO check exactness
			Ok(Exact::new((), true))
		}
	}

	fn bop<I: Interrupt>(
		self,
		rhs: &Self,
		mut f: impl FnMut(&Complex, &Complex, &I) -> FResult<Complex>,
		int: &I,
	) -> FResult<Self> {
		let mut parts = Vec::<(Complex, BigRat)>::new();
		for (n1, p1) in &self.parts {
			for (n2, p2) in &rhs.parts {
				let n = f(n1, n2, int)?;
				let p = p1.clone().mul(p2, int)?;
				let mut found = false;
				for (k, prob) in &mut parts {
					if k.compare(&n, int)? == Some(Ordering::Equal) {
						*prob = prob.clone().add(p.clone(), int)?;
						found = true;
						break;
					}
				}
				if !found {
					parts.push((n, p));
				}
			}
		}
		Ok(Self { parts })
	}
}

impl Exact<Dist> {
	pub(crate) fn add<I: Interrupt>(self, rhs: &Self, int: &I) -> FResult<Self> {
		let self_exact = self.exact;
		let rhs_exact = rhs.exact;
		let mut exact = true;
		Ok(Self::new(
			self.value.bop(
				&rhs.value,
				|a, b, int| {
					let sum = Exact::new(a.clone(), self_exact)
						.add(Exact::new(b.clone(), rhs_exact), int)?;
					exact &= sum.exact;
					Ok(sum.value)
				},
				int,
			)?,
			exact,
		))
	}

	pub(crate) fn mul<I: Interrupt>(self, rhs: &Self, int: &I) -> FResult<Self> {
		let self_exact = self.exact;
		let rhs_exact = rhs.exact;
		let mut exact = true;
		Ok(Self::new(
			self.value.bop(
				&rhs.value,
				|a, b, int| {
					let sum = Exact::new(a.clone(), self_exact)
						.mul(&Exact::new(b.clone(), rhs_exact), int)?;
					exact &= sum.exact;
					Ok(sum.value)
				},
				int,
			)?,
			exact,
		))
	}

	pub(crate) fn div<I: Interrupt>(self, rhs: &Self, int: &I) -> FResult<Self> {
		let self_exact = self.exact;
		let rhs_exact = rhs.exact;
		let mut exact = true;
		Ok(Self::new(
			self.value.bop(
				&rhs.value,
				|a, b, int| {
					let sum = Exact::new(a.clone(), self_exact)
						.div(Exact::new(b.clone(), rhs_exact), int)?;
					exact &= sum.exact;
					Ok(sum.value)
				},
				int,
			)?,
			exact,
		))
	}
}

impl From<Complex> for Dist {
	fn from(v: Complex) -> Self {
		Self {
			parts: vec![(v, 1.into())],
		}
	}
}

impl From<Real> for Dist {
	fn from(v: Real) -> Self {
		Self::from(Complex::from(v))
	}
}

impl From<u64> for Dist {
	fn from(i: u64) -> Self {
		Self::from(Complex::from(i))
	}
}

impl fmt::Debug for Dist {
	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
		match self.one_point_ref() {
			Ok(complex) => write!(f, "{complex:?}"),
			Err(_) => write!(f, "dist {:?}", self.parts),
		}
	}
}

impl Neg for Dist {
	type Output = Self;
	fn neg(mut self) -> Self {
		for (k, _) in &mut self.parts {
			*k = -k.clone();
		}
		self
	}
}