Skip to content

Commit

Permalink
feat: rewritePeepholeRecursively (#410)
Browse files Browse the repository at this point in the history
@tobiasgrosser spotted that 'rewritePeephole' only works at one level of
the IR,
and will not recurse into regions.
We write a variant ('rewritePeepholeRecursively').
This first calls 'rewritePeephole' on the 'Com'.
Then, it recurses into each let-binding of the 'Com' to call
'rewritePeepholeRecursively' on all region arguments.
This ensures that the rewrite is applied to all occurrences of the lhs
in all (nested) regions.

This supercedes #408
  • Loading branch information
bollu authored Jun 24, 2024
1 parent b7757a7 commit ae0dd93
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 0 deletions.
71 changes: 71 additions & 0 deletions SSA/Core/Framework.lean
Original file line number Diff line number Diff line change
Expand Up @@ -2516,6 +2516,77 @@ theorem denote_rewritePeephole (fuel : ℕ)
/-- info: 'denote_rewritePeephole' depends on axioms: [propext, Classical.choice, Quot.sound] -/
#guard_msgs in #print axioms denote_rewritePeephole

theorem Expr.denote_eq_of_region_denote_eq (op : d.Op)
(ty_eq : ty = DialectSignature.outTy op)
(eff' : DialectSignature.effectKind op ≤ eff)
(args : HVector (Var Γ) (DialectSignature.sig op))
(regArgs regArgs' : HVector (fun t => Com d t.1 EffectKind.impure t.2) (DialectSignature.regSig op))
(hregArgs' : regArgs'.denote = regArgs.denote) :
(Expr.mk op ty_eq eff' args regArgs').denote = (Expr.mk op ty_eq eff' args regArgs).denote := by
funext Γv
cases eff
case pure =>
subst ty_eq
have heff' : DialectSignature.effectKind op = EffectKind.pure := by simp [eff']
simp [heff', Expr.denote, hregArgs']
case impure =>
subst ty_eq
simp [Expr.denote, hregArgs']

mutual

def rewritePeepholeRecursivelyRegArgs (fuel : ℕ)
(pr : PeepholeRewrite d Γ t) {ts : List (Ctxt d.Ty × d.Ty)}
(args : HVector (fun t => Com d t.1 EffectKind.impure t.2) ts)
: { out : HVector (fun t => Com d t.1 EffectKind.impure t.2) ts // out.denote = args.denote} :=
match ts with
| .nil =>
match args with
| .nil => ⟨HVector.nil, rfl⟩
| .cons .. =>
match args with
| .cons com coms =>
let ⟨com', hcom'⟩ := (rewritePeepholeRecursively fuel pr com)
let ⟨coms', hcoms'⟩ := (rewritePeepholeRecursivelyRegArgs fuel pr coms)
⟨.cons com' coms', by simp [hcom', hcoms']⟩

def rewritePeepholeRecursivelyExpr (fuel : ℕ)
(pr : PeepholeRewrite d Γ t) {ty : d.Ty}
(e : Expr d Γ₂ eff ty) : { out : Expr d Γ₂ eff ty // out.denote = e.denote } :=
match e with
| Expr.mk op ty eff' args regArgs =>
let ⟨regArgs', hregArgs'⟩ := rewritePeepholeRecursivelyRegArgs fuel pr regArgs
⟨Expr.mk op ty eff' args regArgs', by
apply Expr.denote_eq_of_region_denote_eq op ty eff' args regArgs regArgs' hregArgs'⟩

/-- A peephole rewriter that recurses into regions, allowing
peephole rewriting into nested code. -/
def rewritePeepholeRecursively (fuel : ℕ)
(pr : PeepholeRewrite d Γ t) (target : Com d Γ₂ eff t₂) :
{ out : Com d Γ₂ eff t₂ // out.denote = target.denote } :=
match fuel with
| 0 => ⟨target, rfl⟩
| fuel + 1 =>
let target' := rewritePeephole fuel pr target
have htarget'_denote_eq_htarget : target'.denote = target.denote := by apply denote_rewritePeephole
match htarget : target' with
| .ret v => ⟨target', by
simp [htarget, htarget'_denote_eq_htarget]⟩
| .var (α := α) e body =>
let ⟨e', he'⟩ := rewritePeepholeRecursivelyExpr fuel pr e
let ⟨body', hbody'⟩ :=
-- decreases because 'body' is smaller.
rewritePeepholeRecursively fuel pr body
⟨.var e' body', by
rw [← htarget'_denote_eq_htarget]
simp [he', hbody']⟩
end

/--
info: 'rewritePeepholeRecursively' depends on axioms: [propext, Classical.choice, Quot.sound]
-/
#guard_msgs in #print axioms rewritePeepholeRecursively

end SimpPeepholeApplier

section TypeProjections
Expand Down
99 changes: 99 additions & 0 deletions SSA/Projects/PaperExamples/PaperExamples.lean
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,11 @@ instance : TyDenote Ty where
toType
| .int => BitVec 32

instance : Inhabited (TyDenote.toType (t : Ty)) where
default := by
cases t
exact (0#32)

inductive Op : Type
| add : Op
| const : (val : ℤ) → Op
Expand Down Expand Up @@ -285,6 +290,7 @@ def iterate {Γ : Ctxt _} (k : Nat) (input : Var Γ int) (body : Com SimpleReg [

attribute [local simp] Ctxt.snoc

namespace P1
/-- running `f(x) = x + x` 0 times is the identity. -/
def lhs : Com SimpleReg [int] .pure int :=
Com.var (iterate (k := 0) (⟨0, by simp[Ctxt.snoc]⟩) (
Expand Down Expand Up @@ -355,4 +361,97 @@ theorem EX1' : ex1' = (
:= by rfl
-/

end P1

namespace P2

/-- running `f(x) = x + 0` 0 times is the identity. -/
def lhs : Com SimpleReg [int] .pure int :=
Com.var (cst 0) <| -- %c0
Com.var (add ⟨0, by simp[Ctxt.snoc]⟩ ⟨1, by simp[Ctxt.snoc]⟩) <| -- %out = %x + %c0
Com.ret ⟨0, by simp[Ctxt.snoc]⟩

def rhs : Com SimpleReg [int] .pure int :=
Com.ret ⟨0, by simp[Ctxt.snoc]⟩

def p2 : PeepholeRewrite SimpleReg [int] int:=
{ lhs := lhs, rhs := rhs, correct := by
rw [lhs, rhs]
funext Γv
simp_peephole [add, cst] at Γv
/- ∀ (a : BitVec 32), a + BitVec.ofInt 32 0 = a -/
intros a
simp only [ofInt_zero, ofNat_eq_ofNat, BitVec.add_zero, BitVec.zero_add]
}

/--
example program that has the pattern 'x + 0' both at the top level,
and inside a region in an iterate. -/
def egLhs : Com SimpleReg [int] .pure int :=
Com.var (cst 0) <|
Com.var (add ⟨0, by simp[Ctxt.snoc]⟩ ⟨1, by simp[Ctxt.snoc]⟩) <| -- %out = %x + %c0
Com.var (iterate (k := 0) (⟨0, by simp[Ctxt.snoc]⟩) (
Com.letPure (cst 0) <|
Com.letPure (add ⟨0, by simp[Ctxt.snoc]⟩ ⟨1, by simp[Ctxt.snoc]⟩) -- fun x => (x + x)
<| Com.ret ⟨0, by simp[Ctxt.snoc]⟩
)) <|
Com.ret ⟨0, by simp[Ctxt.snoc]⟩

/--
info: {
^entry(%0 : ToyRegion.Ty.int):
%1 = ToyRegion.Op.const 0 : () → (ToyRegion.Ty.int)
%2 = ToyRegion.Op.add (%1, %0) : (ToyRegion.Ty.int, ToyRegion.Ty.int) → (ToyRegion.Ty.int)
%3 = ToyRegion.Op.iterate 0 (%2) ({
^entry(%0 : ToyRegion.Ty.int):
%1 = ToyRegion.Op.const 0 : () → (ToyRegion.Ty.int)
%2 = ToyRegion.Op.add (%1, %0) : (ToyRegion.Ty.int, ToyRegion.Ty.int) → (ToyRegion.Ty.int)
return %2 : (ToyRegion.Ty.int) → ()
}) : (ToyRegion.Ty.int) → (ToyRegion.Ty.int)
return %3 : (ToyRegion.Ty.int) → ()
}
-/
#guard_msgs in #eval egLhs

def runRewriteOnLhs : Com SimpleReg [int] .pure int :=
(rewritePeepholeRecursively (fuel := 100) p2 egLhs).val

/--
info: {
^entry(%0 : ToyRegion.Ty.int):
%1 = ToyRegion.Op.const 0 : () → (ToyRegion.Ty.int)
%2 = ToyRegion.Op.add (%1, %0) : (ToyRegion.Ty.int, ToyRegion.Ty.int) → (ToyRegion.Ty.int)
%3 = ToyRegion.Op.iterate 0 (%0) ({
^entry(%0 : ToyRegion.Ty.int):
%1 = ToyRegion.Op.const 0 : () → (ToyRegion.Ty.int)
%2 = ToyRegion.Op.add (%1, %0) : (ToyRegion.Ty.int, ToyRegion.Ty.int) → (ToyRegion.Ty.int)
return %0 : (ToyRegion.Ty.int) → ()
}) : (ToyRegion.Ty.int) → (ToyRegion.Ty.int)
return %3 : (ToyRegion.Ty.int) → ()
}
-/
#guard_msgs in #eval runRewriteOnLhs

def expectedRhs : Com SimpleReg [int] .pure int :=
Com.var (cst 0) <|
Com.var (add ⟨0, by simp[Ctxt.snoc]⟩ ⟨1, by simp[Ctxt.snoc]⟩) <| -- %out = %x + %c0
-- | Note that the argument to 'iterate' is rewritten.
-- This is a rewrite that fires at the top level.
Com.var (iterate (k := 0) (⟨2, by simp[Ctxt.snoc]⟩) (
Com.letPure (cst 0) <|
Com.letPure (add ⟨0, by simp[Ctxt.snoc]⟩ ⟨1, by simp[Ctxt.snoc]⟩)
-- | See that the rewrite has fired in the nested region for 'iterate',
-- and we directly return the block argument.
<| Com.ret ⟨2, by simp[Ctxt.snoc]⟩
)) <|
Com.ret ⟨0, by simp[Ctxt.snoc]⟩

theorem rewriteDidSomething : runRewriteOnLhs ≠ lhs := by
simp [runRewriteOnLhs, lhs]
native_decide

theorem rewriteCorrect : runRewriteOnLhs = expectedRhs := by rfl

end P2

end ToyRegion

0 comments on commit ae0dd93

Please sign in to comment.