Skip to content

Commit

Permalink
speed up NewBaseystem synthesis
Browse files Browse the repository at this point in the history
Use a vm_compute hack fromhttps://arxiv.org/pdf/1305.6543.pdf section 5.5:
pattern terms over what to keep opaque, then reduce the lambda using
vm_compute.
  • Loading branch information
andres-erbsen committed Feb 23, 2017
1 parent 2bda7f1 commit 60c9043
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 46 deletions.
3 changes: 1 addition & 2 deletions _CoqProject
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ src/CompleteEdwardsCurve/Pre.v
src/Encoding/EncodingTheorems.v
src/Encoding/ModularWordEncodingPre.v
src/Encoding/ModularWordEncodingTheorems.v
src/Encoding/PointEncoding.v
src/Encoding/PointEncodingPre.v
src/Experiments/Ed25519.v
src/Experiments/Ed25519Extraction.v
src/Experiments/ExtrHaskellNats.v
Expand Down Expand Up @@ -485,4 +483,5 @@ src/Util/Tactics/RewriteHyp.v
src/Util/Tactics/SpecializeBy.v
src/Util/Tactics/SplitInContext.v
src/Util/Tactics/UniquePose.v
src/Util/Tactics/VM.v
src/WeierstrassCurve/Pre.v
109 changes: 65 additions & 44 deletions src/NewBaseSystem.v
Original file line number Diff line number Diff line change
@@ -1,15 +1,3 @@
Require Import Coq.ZArith.ZArith Coq.micromega.Psatz Coq.omega.Omega.
Require Import Coq.ZArith.BinIntDef.
Local Open Scope Z_scope.

Require Import Crypto.Tactics.Algebra_syntax.Nsatz.
Require Import Crypto.Util.Tactics Crypto.Util.Decidable Crypto.Util.LetIn.
Require Import Crypto.Util.ZUtil Crypto.Util.ListUtil Crypto.Util.Sigma.
Require Import Crypto.Util.CPSUtil Crypto.Util.Prod.

Require Import Coq.Lists.List. Import ListNotations.
Require Crypto.Util.Tuple. Local Notation tuple := Tuple.tuple.

(*****
This file provides a generalized version of arithmetic with "mixed
Expand Down Expand Up @@ -252,6 +240,17 @@ reasonable time, so this is not really an option.
*****)

Require Import Coq.ZArith.ZArith Coq.micromega.Psatz Coq.omega.Omega.
Require Import Coq.ZArith.BinIntDef.
Local Open Scope Z_scope.

Require Import Crypto.Tactics.Algebra_syntax.Nsatz.
Require Import Crypto.Util.Tactics Crypto.Util.Decidable Crypto.Util.LetIn.
Require Import Crypto.Util.ZUtil Crypto.Util.ListUtil Crypto.Util.Sigma.
Require Import Crypto.Util.CPSUtil Crypto.Util.Prod Crypto.Util.Tactics.

Require Import Coq.Lists.List. Import ListNotations.
Require Crypto.Util.Tuple. Local Notation tuple := Tuple.tuple.

Local Ltac prove_id :=
repeat match goal with
Expand Down Expand Up @@ -640,6 +639,10 @@ Module B.
Proof. cbv [carry_cps carry]; intros; eapply @eval_carry; eauto. Qed.
Hint Rewrite @eval_carry : push_basesystem_eval.

(* TODO make a correctness proof for this *)
Definition chained_carries (p:list limb) (idxs : list nat)
{T} (f:list limb->T) :=
fold_right_cps2 carry_cps p idxs f.
End Carries.
End Positional.
End Positional.
Expand Down Expand Up @@ -675,6 +678,27 @@ Local Coercion Z.of_nat : nat >-> Z.
Import Coq.Lists.List.ListNotations. Local Open Scope list_scope.
Import B.

Ltac basesystem_partial_evaluation_RHS :=
let t0 := match goal with |- _ _ ?t => t end in
let t := (eval cbv delta [
(* this list must contain all definitions referenced by t that reference [Let_In], [runtime_add], or [runtime_mul] *)
Positional.to_associational_cps Positional.to_associational Positional.eval Positional.zeros Positional.add_to_nth_cps Positional.add_to_nth Positional.place_cps Positional.place Positional.from_associational_cps Positional.from_associational Positional.carry_cps Positional.carry Positional.chained_carries
Associational.eval Associational.multerm Associational.mul_cps Associational.mul Associational.split_cps Associational.split Associational.reduce_cps Associational.reduce Associational.carryterm_cps Associational.carryterm Associational.carry_cps Associational.carry Associational.sat_multerm_cps Associational.sat_multerm Associational.sat_mul_cps Associational.sat_mul
] in t0) in
let t := (eval pattern @runtime_mul in t) in
let t := match t with ?t _ => t end in
let t := (eval pattern @runtime_add in t) in
let t := match t with ?t _ => t end in
let t := (eval pattern @Let_In in t) in
let t := match t with ?t _ => t end in
let t1 := fresh "t1" in
pose t as t1;
transitivity (t1
(@Let_In)
(@runtime_add)
(@runtime_mul));
[replace_with_vm_compute t1; clear t1|reflexivity].

Ltac assert_preconditions :=
repeat match goal with
| |- context [Positional.from_associational_cps ?wt ?n] =>
Expand All @@ -687,24 +711,6 @@ Ltac assert_preconditions :=
unique assert (wt (S i) / wt i <> 0) by (cbv; congruence)
end.

Ltac op_simplify :=
cbv - [runtime_add runtime_mul Let_In];
cbv [runtime_add runtime_mul].

Ltac prove_op sz x :=
cbv [Tuple.tuple Tuple.tuple'] in *;
repeat match goal with p : _ * Z |- _ => destruct p end;
apply lift2_sig;
eexists; cbv zeta beta; intros;
match goal with |- Positional.eval ?wt _ = ?op (Positional.eval ?wt ?a) (Positional.eval ?wt ?b) =>
transitivity (Positional.eval wt (x wt a b))
end;
[ apply f_equal; op_simplify; reflexivity
| assert_preconditions;
progress autorewrite with uncps push_id push_basesystem_eval;
reflexivity ]
.

Section Ops.
Context
(modulo : Z -> Z -> Z)
Expand All @@ -717,23 +723,26 @@ Section Ops.
Let sz := 10%nat.
Let sz2 := Eval compute in ((sz * 2) - 1)%nat.

(* shorthand for many carries in a row *)
Definition chained_carries (w : nat -> Z) (p:list B.limb) (idxs : list nat)
{T} (f:list B.limb->T) :=
fold_right_cps2 (@Positional.carry_cps w modulo div) p idxs f.

Definition addT :
{ add : (Z^sz -> Z^sz -> Z^sz)%type &
forall a b : Z^sz,
let eval {n} := Positional.eval (n := n) wt in
eval (add a b) = eval a + eval b }.
Proof.
prove_op sz (
fun wt a b =>
let x := constr:(fun wt a b =>
Positional.to_associational_cps (n := sz) wt a
(fun r => Positional.to_associational_cps (n := sz) wt b
(fun r0 => Positional.from_associational_cps wt sz (r ++ r0) id
))).
))) in
apply lift2_sig; eexists;
transitivity (Positional.eval wt (x wt a b));
[|assert_preconditions; autorewrite with uncps push_id push_basesystem_eval; reflexivity].

apply f_equal.

basesystem_partial_evaluation_RHS.

reflexivity.
Defined.


Expand All @@ -743,19 +752,31 @@ Section Ops.
let eval {n} := Positional.eval (n := n) wt in
eval (mul a b) = eval a * eval b }.
Proof.
let x := (eval cbv [chained_carries seq fold_right_cps2 sz2] in
(fun w a b =>
let x := constr:(fun w a b =>
Positional.to_associational_cps (n := sz) w a
(fun r => Positional.to_associational_cps (n := sz) w b
(fun r0 => Associational.mul_cps r r0
(fun r1 => Positional.from_associational_cps w sz2 r1
(fun r2 => Positional.to_associational_cps w r2
(fun r3 => chained_carries w r3 (seq 0 sz2)
(fun r3 => Positional.chained_carries(div:=div)(modulo:=modulo) w r3 (seq 0 sz2)
(fun r13 => Positional.from_associational_cps w sz2 r13 id
)))))))) in
prove_op sz x.
Defined.
))))))) in
apply lift2_sig; eexists;
transitivity (Positional.eval wt (x wt a b));
[|cbv [Positional.chained_carries fold_right_cps2 seq fold_right sz2]; assert_preconditions; autorewrite with uncps push_id push_basesystem_eval; reflexivity].

apply f_equal.

basesystem_partial_evaluation_RHS.

(* rough breakdown of synthesis time *)
(* 1.2s for side conditions -- should improve significantly when [chained_carries] gets a correctness lemma *)
(* basesystem_partial_evaluation_RHS (primarily vm_compute): 1.8s, which gets re-computed during defined *)

(* doing [cbv -[Let_In runtime_add runtime_mul]] took 37s *)

reflexivity.
Defined. (* 3s *)
End Ops.

(*
Expand Down
1 change: 1 addition & 0 deletions src/Util/Tactics.v
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Require Export Crypto.Util.Tactics.RewriteHyp.
Require Export Crypto.Util.Tactics.SpecializeBy.
Require Export Crypto.Util.Tactics.SplitInContext.
Require Export Crypto.Util.Tactics.UniquePose.
Require Export Crypto.Util.Tactics.VM.

(** Test if a tactic succeeds, but always roll-back the results *)
Tactic Notation "test" tactic3(tac) :=
Expand Down
32 changes: 32 additions & 0 deletions src/Util/Tactics/VM.v
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
(* Code by Jason Gross for COQBUG 4637: vm_compute in _ makes Defined slow *)

(** First, work around COQBUG 4494, https://coq.inria.fr/bugs/show_bug.cgi?id=4494 (replace is slow and broken under binders *)
Ltac replace_with_at_by x y set_tac tac :=
let H := fresh in
let x' := fresh in
set_tac x' x;
assert (H : y = x') by (subst x'; tac);
clearbody x'; induction H.

Ltac replace_with_at x y set_tac :=
let H := fresh in
let x' := fresh in
set_tac x' x;
cut (y = x');
[ intro H; induction H
| subst x' ].

Ltac replace_with_vm_compute c :=
let c' := (eval vm_compute in c) in
(* we'd like to just do: *)
(* replace c with c' by (clear; abstract (vm_compute; reflexivity)) *)
(* but [set] is too slow in 8.4, so we write our own version (see COQBUG https://coq.inria.fr/bugs/show_bug.cgi?id=3280#c13 *)
let set_tac := (fun x' x
=> pose x as x';
change x with x') in
replace_with_at_by c c' set_tac ltac:(clear; vm_cast_no_check (eq_refl c')).

Ltac replace_with_vm_compute_in c H :=
let c' := (eval vm_compute in c) in
(* By constrast [set ... in ...] seems faster than [change .. with ... in ...] in 8.4?! *)
replace_with_at_by c c' ltac:(fun x' x => set (x' := x) in H ) ltac:(clear; vm_cast_no_check (eq_refl c')).

0 comments on commit 60c9043

Please sign in to comment.