Skip to content

Commit

Permalink
Deprecate double_fission (#299)
Browse files Browse the repository at this point in the history
Co-authored-by: Samir Droubi <[email protected]>
  • Loading branch information
SamirDroubi and SamirDroubi authored Jan 7, 2023
1 parent 68b5c77 commit e4d51dc
Show file tree
Hide file tree
Showing 7 changed files with 4 additions and 175 deletions.
3 changes: 2 additions & 1 deletion apps/x86/sgemm/sgemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ def make_avx512_kernel(p):
p = set_memory(p, "C_reg", AVX512)
p = autolift_alloc(p, "C_reg: _", n_lifts=3, keep_dims=True)
p = autolift_alloc(p, "C_reg: _")
p = double_fission(p, "C_reg[_] = C[_]", "C_reg[_] += _", n_lifts=4)
p = autofission(p, p.find("C_reg[_] = _").after(), n_lifts=4)
p = autofission(p, p.find("C[_] = _").before(), n_lifts=4)
# Stage A & B
def stage_input(p, expr, new_buf):
p = bind_expr(p, expr, new_buf)
Expand Down
13 changes: 0 additions & 13 deletions src/exo/API_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1942,19 +1942,6 @@ def add_unsafe_guard(proc, block_cursor, var_expr):
return Schedules.DoAddUnsafeGuard(proc_c, stmt, var_expr).result()


@sched_op([StmtCursorA, StmtCursorA, PosIntA])
def double_fission(proc, stmt1, stmt2, n_lifts=1):
    """
    DEPRECATED
    This operation is deprecated, and will be removed soon.

    Thin wrapper that forwards to `Schedules.DoDoubleFission`.

    args:
        proc    - procedure to rewrite
        stmt1   - cursor to the first target statement
        stmt2   - cursor to the second target statement
        n_lifts - positive integer forwarded unchanged to the rewrite
                  (defaults to 1)

    returns:
        whatever `.result()` of the underlying rewrite returns
    """
    # Unwrap the public cursors into their internal representations.
    first_impl, second_impl = stmt1._impl, stmt2._impl
    # The rewrite is driven from a cursor rooted at the whole procedure.
    root = ic.Cursor.root(proc)
    rewrite = Schedules.DoDoubleFission(root, first_impl, second_impl, n_lifts)
    return rewrite.result()


@sched_op([ForSeqCursorA])
def bound_and_guard(proc, loop):
"""
Expand Down
141 changes: 0 additions & 141 deletions src/exo/LoopIR_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2261,146 +2261,6 @@ def _stmt(s):
return all(_stmt(s) for s in stmts)


class _DoDoubleFission:
    """
    Rewrite a procedure by partitioning its statements into three groups
    ("pre", "mid", "post") relative to two target statements, splitting
    enclosing `If` / `Seq` scopes into separate copies where needed.

    Classification (see `map_s`):
      * the first target statement itself is placed in "pre";
      * the second target statement itself is placed in "mid";
      * any other statement is placed according to which targets have
        already been visited by the traversal.

    An enclosing scope is only split when all three groups inside it are
    non-empty and `n_lifts` has not been used up.
    """

    def __init__(self, proc_cursor, f_cursor, s_cursor, n_lifts):
        """
        Run the rewrite eagerly; the outcome is retrieved with `result()`.

        proc_cursor - cursor whose `_node()` is the procedure to rewrite
        f_cursor    - cursor to the first target statement
        s_cursor    - cursor to the second target statement
        n_lifts     - positive int; decremented each time a scope is split
        """
        self.tgt_stmt1 = f_cursor._node()
        self.tgt_stmt2 = s_cursor._node()

        assert isinstance(self.tgt_stmt1, LoopIR.stmt)
        assert isinstance(self.tgt_stmt2, LoopIR.stmt)
        assert is_pos_int(n_lifts)
        self.orig_proc = proc_cursor._node()
        self.n_lifts = n_lifts
        self.provenance = proc_cursor.proc()

        # Traversal state: flipped to True once the corresponding target
        # statement has been visited by `map_s`.
        self.hit_fission1 = False
        self.hit_fission2 = False

        pre_body, mid_body, post_body = self.map_stmts(self.orig_proc.body)
        # Rebuild the procedure with the three groups concatenated in order;
        # `instr` is reset to None on the rewritten procedure.
        self.proc = LoopIR.proc(
            name=self.orig_proc.name,
            args=self.orig_proc.args,
            preds=self.orig_proc.preds,
            body=pre_body + mid_body + post_body,
            instr=None,
            eff=self.orig_proc.eff,
            srcinfo=self.orig_proc.srcinfo,
        )

        # Statements rebuilt below are created with eff=None, so effects are
        # re-inferred over the whole procedure here.
        self.proc = InferEffects(self.proc).result()

    def result(self):
        """Wrap the rewritten LoopIR proc in an `api.Procedure`, passing along the recorded provenance."""
        return api.Procedure(self.proc, _provenance_eq_Procedure=self.provenance)

    def alloc_check(self, pre, post):
        """
        Raise `SchedulingError` unless `_is_alloc_free(pre, post)` holds for
        the two statement lists that are about to be separated.
        """
        if not _is_alloc_free(pre, post):
            raise SchedulingError(
                "Will not fission here, because "
                "an allocation might be buried "
                "in a different scope than some use-site"
            )

    def map_stmts(self, stmts):
        """
        Apply `map_s` to every statement in order and concatenate the
        per-statement results into one `(pre, mid, post)` triple of lists.
        """
        pre_stmts = []
        mid_stmts = []
        post_stmts = []
        for orig_s in stmts:
            pre, mid, post = self.map_s(orig_s)
            pre_stmts += pre
            mid_stmts += mid
            post_stmts += post

        return pre_stmts, mid_stmts, post_stmts

    def map_s(self, s):
        """
        Classify (and possibly split) a single statement.

        Returns a `(pre, mid, post)` triple of statement lists. For a
        statement that is not split, exactly one of the lists holds it.
        """
        # Targets are matched by object identity, not structural equality.
        if s is self.tgt_stmt1:
            self.hit_fission1 = True
            return [s], [], []
        elif s is self.tgt_stmt2:
            self.hit_fission2 = True
            return [], [s], []

        elif isinstance(s, LoopIR.If):

            # first, check if we need to split the body
            pre, mid, post = self.map_stmts(s.body)
            fission_body = (
                len(pre) > 0 and len(mid) > 0 and len(post) > 0 and self.n_lifts > 0
            )
            if fission_body:
                self.n_lifts -= 1
                self.alloc_check(pre, mid)
                self.alloc_check(mid, post)
                # Only the middle copy keeps the original or-else branch;
                # the outer two copies get an empty or-else.
                pre = LoopIR.If(s.cond, pre, [], None, s.srcinfo)
                mid = LoopIR.If(s.cond, mid, s.orelse, None, s.srcinfo)
                post = LoopIR.If(s.cond, post, [], None, s.srcinfo)
                return [pre], [mid], [post]

            body = pre + mid + post

            # if we don't, then check if we need to split the or-else
            pre, mid, post = self.map_stmts(s.orelse)
            fission_orelse = (
                len(pre) > 0 and len(post) > 0 and len(mid) > 0 and self.n_lifts > 0
            )
            if fission_orelse:
                self.n_lifts -= 1
                self.alloc_check(pre, mid)
                self.alloc_check(mid, post)
                # Mirror of the body case: only the middle copy keeps the
                # (unsplit) body; the outer two copies get an empty body.
                pre = LoopIR.If(s.cond, [], pre, None, s.srcinfo)
                mid = LoopIR.If(s.cond, body, mid, None, s.srcinfo)
                post = LoopIR.If(s.cond, [], post, None, s.srcinfo)
                return [pre], [mid], [post]

            orelse = pre + mid + post

            # if we neither split the body nor the or-else,
            # then we need to gather together the pre and post.
            single_stmt = LoopIR.If(s.cond, body, orelse, None, s.srcinfo)

        elif isinstance(s, LoopIR.Seq):

            # check if we need to split the loop
            pre, mid, post = self.map_stmts(s.body)
            do_fission = (
                len(pre) > 0 and len(post) > 0 and len(mid) > 0 and self.n_lifts > 0
            )
            if do_fission:
                self.n_lifts -= 1
                self.alloc_check(pre, mid)
                self.alloc_check(mid, post)

                # we can skip the loop iteration if the
                # body doesn't depend on the loop
                # and the body is idempotent
                if s.iter in _FV(pre) or not _is_idempotent(pre):
                    pre = [s.update(body=pre, eff=None)]
                    # since we are copying the binding of s.iter,
                    # we should perform an Alpha_Rename for safety
                    pre = Alpha_Rename(pre).result()
                if s.iter in _FV(mid) or not _is_idempotent(mid):
                    # the middle copy is the one that is not alpha-renamed
                    mid = [s.update(body=mid, eff=None)]
                if s.iter in _FV(post) or not _is_idempotent(post):
                    post = [s.update(body=post, eff=None)]
                    post = Alpha_Rename(post).result()

                # Unlike the If case, these are already lists: a part that
                # did not need the loop stays as bare statements.
                return pre, mid, post

            single_stmt = s.update(body=pre + mid + post, eff=None)

        else:
            # all other statements cannot recursively
            # contain statements, so...
            single_stmt = s

        # Place the unsplit statement using the flags as they stand now.
        # For an If/Seq these flags already reflect any target visited
        # inside it during the recursion above.
        if self.hit_fission1 and not self.hit_fission2:
            return [], [single_stmt], []
        elif self.hit_fission2:
            return [], [], [single_stmt]
        else:
            return [single_stmt], [], []


class _DoRemoveLoop(Cursor_Rewrite):
def __init__(self, proc_cursor, stmt_cursor):
self.stmt = stmt_cursor._node()
Expand Down Expand Up @@ -3990,7 +3850,6 @@ class Schedules:
DoAddLoop = _DoAddLoop
DoDataReuse = _DoDataReuse
DoLiftScope = _DoLiftScope
DoDoubleFission = _DoDoubleFission
DoPartitionLoop = _PartitionLoop
DoAssertIf = _AssertIf
DoSpecialize = _DoSpecialize
Expand Down
1 change: 0 additions & 1 deletion src/exo/stdlib/scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@
#
# deprecated scheduling operations
add_unsafe_guard,
double_fission,
bound_and_guard,
stage_assn,
#
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ def matmul_on_gemmini(A: R[128, 128] @ DRAM, B: R[128, 128] @ DRAM,
C: R[128, 128] @ DRAM):
config_st_acc(stride(C, 0))
config_matmul()
res: R[8, 8, 16, 16]
res: R[8, 8, 16, 16] @ DRAM
a: R[8, 8, 16, 16]
b: R[8, 8, 16, 16]
for io in seq(0, 8):
Expand Down
3 changes: 1 addition & 2 deletions tests/pldi22/test_gemmini_matmul_paper.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,8 +358,7 @@ def test_matmul_paper(golden):
gemmini = gemmini.partial_eval(NN, MM, KK)

# Stage memories, so that we can use gemmini scratchpad & accumulator
gemmini = stage_assn(gemmini, "C[_] += _", "res")
gemmini = double_fission(gemmini, "res = _", "res += _")
gemmini = stage_mem(gemmini, "for k in _: _", "C[i, j]", "res")
gemmini = bind_expr(gemmini, "A[_]", "a")
gemmini = bind_expr(gemmini, "B[_]", "b")

Expand Down
16 changes: 0 additions & 16 deletions tests/test_schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,22 +593,6 @@ def foo(n: size, m: size, A: R[n + m + 12]):
mult_dim(foo, "x", 0, 1)


def test_double_fission(golden):
    """
    Golden-output test for `double_fission`: schedule a small loop nest and
    compare the printed procedure against the `golden` fixture.
    """

    @proc
    def foo(N: size, a: f32[N], b: f32[N], out: f32[N]):
        for i in seq(0, N):
            res: f32
            res = 0.0

            res += a[i] * b[i]

            out[i] = res

    # Apply autolift_alloc to the `res` allocation before fissioning.
    foo = autolift_alloc(foo, "res : _", keep_dims=True)
    # The two patterns name the statements to fission around
    # (presumably `#0` selects the first match of each pattern -- TODO confirm).
    foo = double_fission(foo, "res = _ #0", "res += _ #0")
    assert str(foo) == golden


def test_reuse_buffer(golden):
@proc
def foo(a: f32 @ DRAM, b: f32 @ DRAM):
Expand Down

0 comments on commit e4d51dc

Please sign in to comment.