Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

speed up NewBaseSystem synthesis #115

Merged
merged 1 commit into from
Feb 24, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions _CoqProject
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ src/CompleteEdwardsCurve/Pre.v
src/Encoding/EncodingTheorems.v
src/Encoding/ModularWordEncodingPre.v
src/Encoding/ModularWordEncodingTheorems.v
src/Encoding/PointEncoding.v
src/Encoding/PointEncodingPre.v
src/Experiments/Ed25519.v
src/Experiments/Ed25519Extraction.v
src/Experiments/ExtrHaskellNats.v
Expand Down Expand Up @@ -485,4 +483,5 @@ src/Util/Tactics/RewriteHyp.v
src/Util/Tactics/SpecializeBy.v
src/Util/Tactics/SplitInContext.v
src/Util/Tactics/UniquePose.v
src/Util/Tactics/VM.v
src/WeierstrassCurve/Pre.v
109 changes: 65 additions & 44 deletions src/NewBaseSystem.v
Original file line number Diff line number Diff line change
@@ -1,15 +1,3 @@
Require Import Coq.ZArith.ZArith Coq.micromega.Psatz Coq.omega.Omega.
Require Import Coq.ZArith.BinIntDef.
Local Open Scope Z_scope.

Require Import Crypto.Tactics.Algebra_syntax.Nsatz.
Require Import Crypto.Util.Tactics Crypto.Util.Decidable Crypto.Util.LetIn.
Require Import Crypto.Util.ZUtil Crypto.Util.ListUtil Crypto.Util.Sigma.
Require Import Crypto.Util.CPSUtil Crypto.Util.Prod.

Require Import Coq.Lists.List. Import ListNotations.
Require Crypto.Util.Tuple. Local Notation tuple := Tuple.tuple.

(*****

This file provides a generalized version of arithmetic with "mixed
Expand Down Expand Up @@ -252,6 +240,17 @@ reasonable time, so this is not really an option.

*****)

Require Import Coq.ZArith.ZArith Coq.micromega.Psatz Coq.omega.Omega.
Require Import Coq.ZArith.BinIntDef.
Local Open Scope Z_scope.

Require Import Crypto.Tactics.Algebra_syntax.Nsatz.
Require Import Crypto.Util.Tactics Crypto.Util.Decidable Crypto.Util.LetIn.
Require Import Crypto.Util.ZUtil Crypto.Util.ListUtil Crypto.Util.Sigma.
Require Import Crypto.Util.CPSUtil Crypto.Util.Prod Crypto.Util.Tactics.

Require Import Coq.Lists.List. Import ListNotations.
Require Crypto.Util.Tuple. Local Notation tuple := Tuple.tuple.

Local Ltac prove_id :=
repeat match goal with
Expand Down Expand Up @@ -640,6 +639,10 @@ Module B.
Proof. cbv [carry_cps carry]; intros; eapply @eval_carry; eauto. Qed.
Hint Rewrite @eval_carry : push_basesystem_eval.

(* [chained_carries p idxs f] is shorthand for many carries in a row: it
   folds [carry_cps] over the index list [idxs] with [fold_right_cps2],
   starting from the limb list [p], and hands the final limb list to the
   continuation [f].
   NOTE(review): presumably one carry is performed per index in [idxs];
   the exact traversal order is determined by [fold_right_cps2], which is
   defined elsewhere -- confirm there. *)
(* TODO make a correctness proof for this *)
Definition chained_carries (p:list limb) (idxs : list nat)
{T} (f:list limb->T) :=
fold_right_cps2 carry_cps p idxs f.
End Carries.
End Positional.
End Positional.
Expand Down Expand Up @@ -675,6 +678,27 @@ Local Coercion Z.of_nat : nat >-> Z.
Import Coq.Lists.List.ListNotations. Local Open Scope list_scope.
Import B.

(* [basesystem_partial_evaluation_RHS] partially evaluates the right-hand
   side of a goal of the form [R lhs rhs] while keeping every occurrence of
   [Let_In], [runtime_add] and [runtime_mul] unreduced.  It
   1. delta-unfolds the listed base-system definitions in the right-hand side,
   2. abstracts the result over [runtime_mul], then [runtime_add], then
      [Let_In], so that [Let_In] becomes the outermost binder,
   3. poses the resulting function as [t1], and
   4. goes by transitivity through [t1 Let_In runtime_add runtime_mul]:
      in the first subgoal [t1] is normalized with [replace_with_vm_compute]
      and then cleared; the second subgoal is closed by [reflexivity].
   The first subgoal is left for the caller to finish. *)
Ltac basesystem_partial_evaluation_RHS :=
(* [t0] is the last argument of the goal, i.e. its right-hand side *)
let t0 := match goal with |- _ _ ?t => t end in
let t := (eval cbv delta [
(* this list must contain all definitions referenced by t that reference [Let_In], [runtime_add], or [runtime_mul] *)
Positional.to_associational_cps Positional.to_associational Positional.eval Positional.zeros Positional.add_to_nth_cps Positional.add_to_nth Positional.place_cps Positional.place Positional.from_associational_cps Positional.from_associational Positional.carry_cps Positional.carry Positional.chained_carries
Associational.eval Associational.multerm Associational.mul_cps Associational.mul Associational.split_cps Associational.split Associational.reduce_cps Associational.reduce Associational.carryterm_cps Associational.carryterm Associational.carry_cps Associational.carry Associational.sat_multerm_cps Associational.sat_multerm Associational.sat_mul_cps Associational.sat_mul
] in t0) in
(* each [eval pattern c in t] produces an application [f c]; the match that
   follows keeps only [f], i.e. the term abstracted over [c] *)
let t := (eval pattern @runtime_mul in t) in
let t := match t with ?t _ => t end in
let t := (eval pattern @runtime_add in t) in
let t := match t with ?t _ => t end in
let t := (eval pattern @Let_In in t) in
let t := match t with ?t _ => t end in
let t1 := fresh "t1" in
pose t as t1;
(* the argument order below mirrors the abstraction order above:
   the constant abstracted last is applied first *)
transitivity (t1
(@Let_In)
(@runtime_add)
(@runtime_mul));
[replace_with_vm_compute t1; clear t1|reflexivity].

Ltac assert_preconditions :=
repeat match goal with
| |- context [Positional.from_associational_cps ?wt ?n] =>
Expand All @@ -687,24 +711,6 @@ Ltac assert_preconditions :=
unique assert (wt (S i) / wt i <> 0) by (cbv; congruence)
end.

(* [op_simplify] normalizes the goal in two passes: first [cbv] with
   everything unfolded except [runtime_add], [runtime_mul] and [Let_In],
   then a second [cbv] that unfolds only [runtime_add] and [runtime_mul].
   [Let_In] is never unfolded by this tactic. *)
Ltac op_simplify :=
cbv - [runtime_add runtime_mul Let_In];
cbv [runtime_add runtime_mul].

(* [prove_op sz x] solves a goal asking for an operation together with a
   proof that [Positional.eval] of its result equals [op] applied to the
   evaluations of its two arguments.  [x] is a function taking the weight
   function and the two arguments; the goal is split by transitivity through
   [Positional.eval wt (x wt a b)]:
   - the first subgoal (implementation equals [x wt a b]) is discharged by
     [f_equal], [op_simplify] and [reflexivity];
   - the second subgoal (correctness of [x]) is discharged by
     [assert_preconditions] followed by rewriting with the [uncps],
     [push_id] and [push_basesystem_eval] hint databases.
   NOTE(review): the [sz] argument is not used anywhere in the body. *)
Ltac prove_op sz x :=
(* expose tuples as nested pairs and split them into their components *)
cbv [Tuple.tuple Tuple.tuple'] in *;
repeat match goal with p : _ * Z |- _ => destruct p end;
apply lift2_sig;
eexists; cbv zeta beta; intros;
match goal with |- Positional.eval ?wt _ = ?op (Positional.eval ?wt ?a) (Positional.eval ?wt ?b) =>
transitivity (Positional.eval wt (x wt a b))
end;
[ apply f_equal; op_simplify; reflexivity
| assert_preconditions;
(* [progress] makes this branch fail if no rewrite applies *)
progress autorewrite with uncps push_id push_basesystem_eval;
reflexivity ]
.

Section Ops.
Context
(modulo : Z -> Z -> Z)
Expand All @@ -717,23 +723,26 @@ Section Ops.
Let sz := 10%nat.
Let sz2 := Eval compute in ((sz * 2) - 1)%nat.

(* shorthand for many carries in a row *)
Definition chained_carries (w : nat -> Z) (p:list B.limb) (idxs : list nat)
{T} (f:list B.limb->T) :=
fold_right_cps2 (@Positional.carry_cps w modulo div) p idxs f.

Definition addT :
{ add : (Z^sz -> Z^sz -> Z^sz)%type &
forall a b : Z^sz,
let eval {n} := Positional.eval (n := n) wt in
eval (add a b) = eval a + eval b }.
Proof.
prove_op sz (
fun wt a b =>
let x := constr:(fun wt a b =>
Positional.to_associational_cps (n := sz) wt a
(fun r => Positional.to_associational_cps (n := sz) wt b
(fun r0 => Positional.from_associational_cps wt sz (r ++ r0) id
))).
))) in
apply lift2_sig; eexists;
transitivity (Positional.eval wt (x wt a b));
[|assert_preconditions; autorewrite with uncps push_id push_basesystem_eval; reflexivity].

apply f_equal.

basesystem_partial_evaluation_RHS.

reflexivity.
Defined.


Expand All @@ -743,19 +752,31 @@ Section Ops.
let eval {n} := Positional.eval (n := n) wt in
eval (mul a b) = eval a * eval b }.
Proof.
let x := (eval cbv [chained_carries seq fold_right_cps2 sz2] in
(fun w a b =>
let x := constr:(fun w a b =>
Positional.to_associational_cps (n := sz) w a
(fun r => Positional.to_associational_cps (n := sz) w b
(fun r0 => Associational.mul_cps r r0
(fun r1 => Positional.from_associational_cps w sz2 r1
(fun r2 => Positional.to_associational_cps w r2
(fun r3 => chained_carries w r3 (seq 0 sz2)
(fun r3 => Positional.chained_carries(div:=div)(modulo:=modulo) w r3 (seq 0 sz2)
(fun r13 => Positional.from_associational_cps w sz2 r13 id
)))))))) in
prove_op sz x.
Defined.
))))))) in
apply lift2_sig; eexists;
transitivity (Positional.eval wt (x wt a b));
[|cbv [Positional.chained_carries fold_right_cps2 seq fold_right sz2]; assert_preconditions; autorewrite with uncps push_id push_basesystem_eval; reflexivity].

apply f_equal.

basesystem_partial_evaluation_RHS.

(* rough breakdown of synthesis time *)
(* 1.2s for side conditions -- should improve significantly when [chained_carries] gets a correctness lemma *)
(* basesystem_partial_evaluation_RHS (primarily vm_compute): 1.8s, which gets re-computed during defined *)

(* doing [cbv -[Let_In runtime_add runtime_mul]] took 37s *)

reflexivity.
Defined. (* 3s *)
End Ops.

(*
Expand Down
1 change: 1 addition & 0 deletions src/Util/Tactics.v
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Require Export Crypto.Util.Tactics.RewriteHyp.
Require Export Crypto.Util.Tactics.SpecializeBy.
Require Export Crypto.Util.Tactics.SplitInContext.
Require Export Crypto.Util.Tactics.UniquePose.
Require Export Crypto.Util.Tactics.VM.

(** Test if a tactic succeeds, but always roll-back the results *)
Tactic Notation "test" tactic3(tac) :=
Expand Down
32 changes: 32 additions & 0 deletions src/Util/Tactics/VM.v
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
(* Code by Jason Gross for COQBUG 4637: vm_compute in _ makes Defined slow *)

(** First, work around COQBUG 4494, https://coq.inria.fr/bugs/show_bug.cgi?id=4494 (replace is slow and broken under binders) *)
(* [replace_with_at_by x y set_tac tac] replaces [x] by [y], proving the
   required equality with [tac].
   - [set_tac x' x] is a caller-supplied tactic that must introduce a local
     definition [x'] standing for [x] at the places to be replaced.
   - [tac] must prove [y = x] (it is run after [subst x']).
   Once the equation [H : y = x'] is available, the body of [x'] is
   forgotten with [clearbody] and [induction H] substitutes [y] for [x']. *)
Ltac replace_with_at_by x y set_tac tac :=
let H := fresh in
let x' := fresh in
set_tac x' x;
assert (H : y = x') by (subst x'; tac);
clearbody x'; induction H.

(* [replace_with_at x y set_tac] is the variant of [replace_with_at_by]
   that does not take a proof tactic: after [set_tac x' x] has introduced
   [x'] for [x], it uses [cut (y = x')].  In the first resulting goal the
   equation is introduced and eliminated with [induction]; in the second
   goal [x'] is substituted away, leaving the equation between [y] and [x]
   for the caller to prove. *)
Ltac replace_with_at x y set_tac :=
let H := fresh in
let x' := fresh in
set_tac x' x;
cut (y = x');
[ intro H; induction H
| subst x' ].

(* [replace_with_vm_compute c] replaces [c] in the goal by its [vm_compute]
   normal form [c'].  The equality is justified by
   [vm_cast_no_check (eq_refl c')] in a cleared context rather than by
   running [vm_compute] on the goal.
   NOTE(review): going by the header of this file, the intent is to avoid
   the slow [Defined] of COQBUG 4637 -- confirm against that bug report. *)
Ltac replace_with_vm_compute c :=
let c' := (eval vm_compute in c) in
(* we'd like to just do: *)
(* replace c with c' by (clear; abstract (vm_compute; reflexivity)) *)
(* but [set] is too slow in 8.4, so we write our own version (see COQBUG https://coq.inria.fr/bugs/show_bug.cgi?id=3280#c13) *)
let set_tac := (fun x' x
=> pose x as x';
change x with x') in
replace_with_at_by c c' set_tac ltac:(clear; vm_cast_no_check (eq_refl c')).

(* [replace_with_vm_compute_in c H] is the hypothesis version of
   [replace_with_vm_compute]: [c] is replaced by its [vm_compute] normal
   form, with the local definition introduced by [set ... in H], and the
   equality is again justified by [vm_cast_no_check (eq_refl c')]. *)
Ltac replace_with_vm_compute_in c H :=
let c' := (eval vm_compute in c) in
(* By contrast [set ... in ...] seems faster than [change .. with ... in ...] in 8.4?! *)
replace_with_at_by c c' ltac:(fun x' x => set (x' := x) in H ) ltac:(clear; vm_cast_no_check (eq_refl c')).