cpython/Modules/_decimal/libmpdec/literature/mulmod-ppro.txt



(* Copyright (c) 2011-2020 Stefan Krah. All rights reserved. *)


========================================================================
             Calculate (a * b) % p using the 80-bit x87 FPU
========================================================================

A description of the algorithm can be found in the apfloat manual by
Tommila [1].

The proof follows an argument made by Granlund/Montgomery in [2].


Definitions and assumptions:
----------------------------

The 80-bit extended precision format uses 64 bits for the significand:

  (1) F = 64

The modulus is prime and less than 2**31:

  (2) 2 <= p < 2**31

The factors are less than p:

  (3) 0 <= a < p
  (4) 0 <= b < p

The product a * b is less than 2**62 and is thus exact in 64 bits:

  (5) n = a * b

The product can be represented in terms of quotient and remainder:

  (6) n = q * p + r

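Using (3), (4) and the fact that p is prime, the remainder is always
greater than zero, provided that a and b are both nonzero. (If a = 0 or
b = 0, then n = 0, qest = 0 and q = r = 0, so the algorithm below is
trivially correct.) For nonzero a and b:

  (7) 0 <= q < p  /\  1 <= r < p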


Strategy:
---------

Precalculate the 80-bit long double inverse of p, with a maximum
relative error of 2**(1-F):

  (8) pinv = (long double)1.0 / p

Calculate an estimate for q = floor(n/p). The multiplication has another
maximum relative error of 2**(1-F):

  (9) qest = n * pinv

If we can show that q < qest < q+1, then trunc(qest) = q. It is then
easy to recover the remainder r. The complete algorithm is:

  a) Set the control word to 64-bit precision and truncation mode.

  b) n = a * b              # Calculate exact product.

  c) qest = n * pinv        # Calculate estimate for the quotient.

  d) q = (qest+2**63)-2**63 # Truncate qest to the exact quotient.

  e) r = n - q * p          # Calculate remainder.


Proof for q < qest < q+1:
-------------------------

Using the cumulative error, the error bounds for qest are:

                 n                       n * (1 + 2**(1-F))**2
  (9') --------------------- <= qest <=  ---------------------
       p * (1 + 2**(1-F))**2                       p


  Lemma 1:
  --------
                       n                   q * p + r
    (10) q < --------------------- = ---------------------
             p * (1 + 2**(1-F))**2   p * (1 + 2**(1-F))**2


    Proof:
    ~~~~~~

    (I)     q * p * (1 + 2**(1-F))**2 < q * p + r

    (II)    q * p * 2**(2-F) + q * p * 2**(2-2*F) < r

    Using (1) and (7), it is sufficient to show that:

    (III)   q * p * 2**(-62) + q * p * 2**(-126) < 1 <= r

    (III) can easily be verified by substituting the largest possible
    values p = 2**31-1 and q = 2**31-2.

    The critical cases occur when r = 1, i.e. n = m * p + 1 with m = q.
    These cases can be exhaustively verified with a test program.


  Lemma 2:
  --------

           n * (1 + 2**(1-F))**2     (q * p + r) * (1 + 2**(1-F))**2
    (11)   ---------------------  =  -------------------------------  <  q + 1
                     p                              p

    Proof:
    ~~~~~~

    (I)  (q * p + r) + (q * p + r) * 2**(2-F) + (q * p + r) * 2**(2-2*F) < q * p + p

    (II)  (q * p + r) * 2**(2-F) + (q * p + r) * 2**(2-2*F) < p - r

    Using (1) and (7), it is sufficient to show that:

    (III) (q * p + r) * 2**(-62) + (q * p + r) * 2**(-126) < 1 <= p - r

    (III) can easily be verified by substituting the largest possible
    values p = 2**31-1, q = 2**31-2 and r = 2**31-2.

    The critical cases occur when r = (p - 1), i.e. n = m * p - 1 with
    m = q + 1. These cases can be exhaustively verified with a test program.


[1]  http://www.apfloat.org/apfloat/2.40/apfloat.pdf
[2]  http://gmplib.org/~tege/divcnst-pldi94.pdf
     [Section 7: "Use of floating point"]



(* Coq proof for (10) and (11) *)

Require Import ZArith.
Require Import QArith.
Require Import Qpower.
Require Import Qabs.
Require Import Psatz.

Open Scope Q_scope.


(* qreduce e: replace the rational expression e in the goal by its
   reduced fraction.  The goal is first rewritten right-to-left with
   Qred_correct, turning e into Qred e; that Qred application is then
   evaluated with simpl. *)
Ltac qreduce e :=
  rewrite <- (Qred_correct e);
  simpl (Qred e).

(* A summand on the left of a strict inequality can be moved to the
   right-hand side as a subtraction, and back again. *)
Theorem Qlt_move_right :
  forall x y z:Q, x + z < y <-> x < y - z.
Proof.
  intros x y z.
  (* Both directions are linear rational arithmetic. *)
  split; intro Hlt; psatzl Q.
Qed.

(* Multiplying both sides of a strict inequality by a positive rational z
   gives an equivalent inequality.  The main theorems below rewrite with
   this lemma to clear a denominator. *)
Theorem Qlt_mult_by_z :
  forall x y z:Q, 0 < z -> (x < y <-> x * z < y * z).
Proof.
  intros.
  split.
    (* Forward direction: multiplication by a positive factor preserves
       the strict order. *)
    intros.
    apply Qmult_lt_compat_r. trivial. trivial.
    (* Backward direction: express x and y as x * z / z and y * z / z,
       then compare the two quotients via Qmult_lt_compat_r with the
       factor / z. *)
    intros.
    rewrite <- (Qdiv_mult_l x z). rewrite <- (Qdiv_mult_l y z).
    apply Qmult_lt_compat_r.
    (* Positivity of / z is obtained from 0 < z. *)
    apply Qlt_shift_inv_l.
    (* The remaining goals are closed in the order they were generated;
       presumably the last two are the nonzero side conditions of the two
       Qdiv_mult_l rewrites -- confirm against the QArith statements. *)
    trivial. psatzl Q. trivial. psatzl Q. psatzl Q.
Qed.

(* Monotonicity of multiplication on non-negative rationals: if
   0 <= a <= c and 0 <= b <= d, then a * b <= c * d.  The main theorems
   below use it to bound q * p by a product of numeric constants. *)
Theorem Qle_mult_quad :
  forall (a b c d:Q),
    0 <= a -> a <= c ->
    0 <= b -> b <= d ->
      a * b <= c * d.
Proof.
  intros.
  (* The goal is nonlinear (products of variables), hence psatz rather
     than the linear psatzl used elsewhere in this file. *)
  psatz Q.
Qed.


(* Formal counterpart of inequality (10) (Lemma 1) in the text above,
   with 2^(1-F) instantiated to 2^(-63), i.e. F = 64.

   The statement is over the rationals.  p is only assumed to be positive
   and at most 2^31 - 1; primality is not used.  The bound 1 <= r is the
   property (7) of the remainder. *)
Theorem q_lt_qest:
  forall (p q r:Q),
    (0 < p) -> (p <= (2#1)^31 - 1) ->
    (0 <= q) -> (q <= p - 1) ->
    (1 <= r) -> (r <= p - 1) ->
      q < (q * p + r) / (p * (1 + (2#1)^(-63))^2).
Proof.
  intros.
  (* Multiply both sides by the (positive) denominator.  The positivity
     side condition is left as a later goal. *)
  rewrite Qlt_mult_by_z with (z := (p * (1 + (2#1)^(-63))^2)).

  (* Cancel the denominator on the right-hand side: turn the division
     into a multiplication by the inverse, regroup, and use x * / x == 1.
     Qmult_inv_r leaves a nonzero side condition as a later goal. *)
  unfold Qdiv.
  rewrite <- Qmult_assoc.
  rewrite (Qmult_comm (/ (p * (1 + (2 # 1) ^ (-63)) ^ 2)) (p * (1 + (2 # 1) ^ (-63)) ^ 2)).
  rewrite Qmult_inv_r.
  rewrite Qmult_1_r.

  (* Expand the squared error factor:
     (1 + 2^(-63))^2 == 1 + 2^(-62) + 2^(-126).
     The assertion receives the auto-generated name H5, which the
     rewrite below relies on. *)
  assert (q * (p * (1 + (2 # 1) ^ (-63)) ^ 2) == q * p + (q * p) * ((2 # 1) ^ (-62) + (2 # 1) ^ (-126))).
  qreduce ((1 + (2 # 1) ^ (-63)) ^ 2).
  qreduce ((2 # 1) ^ (-62) + (2 # 1) ^ (-126)).
  ring_simplify.
  reflexivity.
  rewrite H5.

  (* Move q * p to the right-hand side so that only r remains there;
     this corresponds to step (II) of the informal proof. *)
  rewrite Qplus_comm.
  rewrite Qlt_move_right.
  ring_simplify (q * p + r - q * p).
  qreduce ((2 # 1) ^ (-62) + (2 # 1) ^ (-126)).

  (* Step (III): show that the left-hand side is below 1, and 1 <= r. *)
  apply Qlt_le_trans with (y := 1).
  (* Multiply by the reciprocal of 2^(-62) + 2^(-126), that is
     2^126 / (2^64 + 1), to isolate q * p on the left. *)
  rewrite Qlt_mult_by_z with (z := 85070591730234615865843651857942052864 # 18446744073709551617).
  ring_simplify.

  (* Bound q * p by its largest possible value (2^31 - 2) * (2^31 - 1). *)
  apply Qle_lt_trans with (y := ((2 # 1) ^ 31 - (2#1)) * ((2 # 1) ^ 31 - 1)).
  apply Qle_mult_quad.
  (* Close all remaining goals in generation order: the premises of
     Qle_mult_quad, the numeric comparison, and the side conditions
     deferred by the rewrites above (including 1 <= r). *)
  assumption. psatzl Q. psatzl Q. psatzl Q. psatzl Q. psatzl Q. assumption. psatzl Q. psatzl Q.
Qed.

(* Formal counterpart of inequality (11) (Lemma 2) in the text above,
   with 2^(1-F) instantiated to 2^(-63), i.e. F = 64.

   The statement is over the rationals.  p is only assumed to be positive
   and at most 2^31 - 1; primality is not used.  The bound r <= p - 1
   provides 1 <= p - r. *)
Theorem qest_lt_qplus1:
  forall (p q r:Q),
    (0 < p) -> (p <= (2#1)^31 - 1) ->
    (0 <= q) -> (q <= p - 1) ->
    (1 <= r) -> (r <= p - 1) ->
      ((q * p + r) * (1 + (2#1)^(-63))^2) / p < q + 1.
Proof.
  intros.
  (* Multiply both sides by p.  The side condition 0 < p is left as a
     later goal. *)
  rewrite Qlt_mult_by_z with (z := p).

  (* Cancel the division by p on the left-hand side.  Qmult_inv_r leaves
     a nonzero side condition as a later goal. *)
  unfold Qdiv.
  rewrite <- Qmult_assoc.
  rewrite (Qmult_comm (/ p) p).
  rewrite Qmult_inv_r.
  rewrite Qmult_1_r.

  (* Expand the squared error factor:
     (1 + 2^(-63))^2 == 1 + 2^(-62) + 2^(-126).
     The assertion receives the auto-generated name H5, which the
     rewrite below relies on. *)
  assert ((q * p + r) * (1 + (2 # 1) ^ (-63)) ^ 2 == q * p + r + (q * p + r) * ((2 # 1) ^ (-62) + (2 # 1) ^ (-126))).
  qreduce ((1 + (2 # 1) ^ (-63)) ^ 2).
  qreduce ((2 # 1) ^ (-62) + (2 # 1) ^ (-126)).
  ring_simplify. reflexivity.
  rewrite H5.

  (* Move q * p to the right-hand side; (q + 1) * p - q * p simplifies
     to p. *)
  rewrite <- Qplus_assoc. rewrite <- Qplus_comm. rewrite Qlt_move_right.
  ring_simplify ((q + 1) * p - q * p).

  (* Move r to the right-hand side as well, leaving the error term alone
     on the left; this corresponds to step (II) of the informal proof. *)
  rewrite <- Qplus_comm. rewrite Qlt_move_right.

  (* Step (III): show that the error term is below 1, and 1 <= p - r. *)
  apply Qlt_le_trans with (y := 1).
  qreduce ((2 # 1) ^ (-62) + (2 # 1) ^ (-126)).

  (* Multiply by the reciprocal of 2^(-62) + 2^(-126), that is
     2^126 / (2^64 + 1), to isolate q * p + r on the left. *)
  rewrite Qlt_mult_by_z with (z := 85070591730234615865843651857942052864 # 18446744073709551617).
  ring_simplify.

  (* Normalise the hypothesis on p so that its bound is the numeral
     2147483647 (= 2^31 - 1) and can be used by assumption. *)
  ring_simplify in H0.
  (* Bound q * p + r by its largest possible value
     (2^31 - 2) * (2^31 - 1) + (2^31 - 2). *)
  apply Qle_lt_trans with (y := (2147483646 # 1) * (2147483647 # 1) + (2147483646 # 1)).

  (* Bound the product and the remainder separately. *)
  apply Qplus_le_compat.
  apply Qle_mult_quad.
  (* Close all remaining goals in generation order: the premises of
     Qle_mult_quad, the bound on r, the numeric comparison, and the side
     conditions deferred by the rewrites above. *)
  assumption. psatzl Q. auto with qarith. assumption. psatzl Q.
  auto with qarith. auto with qarith.
  psatzl Q. psatzl Q. assumption.
Qed.