Skip to content

Commit

Permalink
rfactor unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
MasterJH5574 committed May 15, 2021
1 parent 377bc91 commit 7911019
Showing 1 changed file with 222 additions and 0 deletions.
222 changes: 222 additions & 0 deletions tests/python/unittest/test_tir_schedule_reduction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-function-docstring,missing-module-docstring
import numpy as np
import tvm
import tvm.testing
from tvm import tir
from tvm.script import ty

# pylint: disable=no-member,invalid-name,unused-variable


@tvm.script.tir
def transformed_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
A = tir.match_buffer(a, [128, 128])
B = tir.match_buffer(b, [128, 128])
C = tir.match_buffer(c, [128, 128])

for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in tir.grid(128, 128, 4, 8, 4):
with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
tir.bind(vi, i0)
tir.bind(vj, i1)
tir.bind(vk, (((i2_outer*32) + (i2_inner_outer*4)) + i2_inner_inner))
tir.reads([C[vi, vj], A[vi, vk], B[vj, vk]])
tir.writes([C[vi, vj]])
with tir.init():
C[vi, vj] = 0.0
C[vi, vj] = (C[vi, vj] + (A[vi, vk]*B[vj, vk]))


@tvm.script.tir
def matmul_rfactor(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
A = tir.match_buffer(a, [128, 128])
B = tir.match_buffer(b, [128, 128])
C = tir.match_buffer(c, [128, 128])
C_rf = tir.alloc_buffer([4, 128, 128])

for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in tir.grid(128, 128, 4, 8, 4):
with tir.block([4, 128, 128, tir.reduce_axis(0, 4), tir.reduce_axis(0, 8)], "update_rf") as [vi2_inner_inner, vi, vj, vi2_outer, vi2_inner_outer]:
tir.bind(vi2_inner_inner, i2_inner_inner)
tir.bind(vi, i0)
tir.bind(vj, i1)
tir.bind(vi2_outer, i2_outer)
tir.bind(vi2_inner_outer, i2_inner_outer)
with tir.init():
C_rf[vi2_inner_inner, vi, vj] = 0.0
C_rf[vi2_inner_inner, vi, vj] = (C_rf[vi2_inner_inner, vi, vj] + (A[vi, (((vi2_outer*32) + (vi2_inner_outer*4)) + vi2_inner_inner)]*B[vj, (((vi2_outer*32) + (vi2_inner_outer*4)) + vi2_inner_inner)]))

for i0_1, i1_1, i2_inner_inner_1 in tir.grid(128, 128, 4):
with tir.block([128, 128, tir.reduce_axis(0, 4)], "update") as [vi_1, vj_1, vi2_inner_inner_1]:
tir.bind(vi_1, i0_1)
tir.bind(vj_1, i1_1)
tir.bind(vi2_inner_inner_1, i2_inner_inner_1)
with tir.init():
C[vi_1, vj_1] = 0.0
C[vi_1, vj_1] = (C[vi_1, vj_1] + C_rf[vi2_inner_inner_1, vi_1, vj_1])


@tvm.script.tir
def square_sum(a: ty.handle, c: ty.handle) -> None:
A = tir.match_buffer(a, [16, 256, 256])
C = tir.match_buffer(c, [16])

with tir.block([16, tir.reduce_axis(0, 256), tir.reduce_axis(0, 256)], "C") as [b, i, j]:
with tir.init():
C[b] = 0.0
C[b] = C[b] + A[b, i, j] * A[b, i, j]


@tvm.script.tir
def square_sum_rfactor(a: ty.handle, c: ty.handle) -> None:
A = tir.match_buffer(a, [16, 256, 256])
C = tir.match_buffer(c, [16])
C_rf = tir.alloc_buffer([16, 256])

for i0, i1, i2 in tir.grid(16, 256, 256):
with tir.block([256, 16, tir.reduce_axis(0, 256)], "C_rf") as [vi2, b, i]:
tir.bind(vi2, i2)
tir.bind(b, i0)
tir.bind(i, i1)
with tir.init():
C_rf[b, vi2] = 0.0
C_rf[b, vi2] = (C_rf[b, vi2] + (A[b, i, vi2]*A[b, i, vi2]))

for i0_1, i2_1 in tir.grid(16, 256):
with tir.block([16, tir.reduce_axis(0, 256)], "C") as [b_1, vi2_1]:
tir.bind(b_1, i0_1)
tir.bind(vi2_1, i2_1)
with tir.init():
C[b_1] = 0.0
C[b_1] = (C[b_1] + C_rf[b_1, vi2_1])


@tvm.script.tir
def transformed_square_sum_square_root(a: ty.handle, d: ty.handle) -> None:
A = tir.match_buffer(a, [16, 256, 256])
D = tir.match_buffer(d, [16])
C = tir.alloc_buffer([16])

for i0, i1_i2_fused_outer, i1_i2_fused_inner in tir.grid(16, 65536, 1):
with tir.block([16, tir.reduce_axis(0, 256), tir.reduce_axis(0, 256)], "C") as [b, i, j]:
tir.bind(b, i0)
tir.bind(i, tir.floordiv(i1_i2_fused_outer, 256))
tir.bind(j, tir.floormod(i1_i2_fused_outer, 256))
tir.reads([C[b], A[b, i, j]])
tir.writes([C[b]])
with tir.init():
C[b] = 0.0
C[b] = (C[b] + (A[b, i, j]*A[b, i, j]))
for i0_1 in tir.serial(0, 16):
with tir.block([16], "D") as [b_1]:
tir.bind(b_1, i0_1)
tir.reads([C[b_1]])
tir.writes([D[b_1]])
D[b_1] = tir.sqrt(C[b_1], dtype="float32")


@tvm.script.tir
def square_sum_square_root_rfactor(a: ty.handle, d: ty.handle) -> None:
A = tir.match_buffer(a, [16, 256, 256])
D = tir.match_buffer(d, [16])
C = tir.alloc_buffer([16])
C_rf = tir.alloc_buffer([1, 16])

for i0, i1_i2_fused_outer, i1_i2_fused_inner in tir.grid(16, 65536, 1):
with tir.block([1, 16, tir.reduce_axis(0, 256), tir.reduce_axis(0, 256)], "C_rf") as [vi1_i2_fused_inner, b, i, j]:
tir.bind(vi1_i2_fused_inner, i1_i2_fused_inner)
tir.bind(b, i0)
tir.bind(i, tir.floordiv(i1_i2_fused_outer, 256))
tir.bind(j, tir.floormod(i1_i2_fused_outer, 256))
with tir.init():
C_rf[vi1_i2_fused_inner, b] = 0.0
C_rf[vi1_i2_fused_inner, b] = (C_rf[vi1_i2_fused_inner, b] + (A[b, i, j]*A[b, i, j]))

for i0_1, i1_i2_fused_inner_1 in tir.grid(16, 1):
with tir.block([16, tir.reduce_axis(0, 1)], "C") as [b_1, vi1_i2_fused_inner_1]:
tir.bind(b_1, i0_1)
tir.bind(vi1_i2_fused_inner_1, i1_i2_fused_inner_1)
with tir.init():
C[b_1] = 0.0
C[b_1] = (C[b_1] + C_rf[vi1_i2_fused_inner_1, b_1])

for i0_2 in tir.serial(0, 16):
with tir.block([16], "D") as [b_2]:
tir.bind(b_2, i0_2)
D[b_2] = tir.sqrt(C[b_2], dtype="float32")


# pylint: enable=no-member,invalid-name,unused-variable


def test_reduction_rfactor_matmul():
s = tir.Schedule(transformed_matmul, debug_mode=True)
C = s.get_block("update")
_, _, _, _, kii = s.get_loops(C)
rf_block = s.rfactor(kii, 0)
tvm.ir.assert_structural_equal(s.mod["main"], matmul_rfactor)
assert s.get(rf_block).same_as(s.get(s.get_block("update_rf")))

f = tvm.build(s.mod["main"], target="llvm")
a_np = np.random.uniform(size=(128, 128)).astype("float32")
b_np = np.random.uniform(size=(128, 128)).astype("float32")
a = tvm.nd.array(a_np)
b = tvm.nd.array(b_np)
c = tvm.nd.array(np.zeros((128, 128), dtype="float32"))
f(a, b, c)
c_np = np.matmul(a_np, b_np.T)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-4, atol=1e-4)


def test_reduction_rfactor_square_sum():
s = tir.Schedule(square_sum, debug_mode=True)
C = s.get_block("C")
_, _, j = s.get_loops(C)
rf_block = s.rfactor(j, 1)
tvm.ir.assert_structural_equal(s.mod["main"], square_sum_rfactor)
assert s.get(rf_block).same_as(s.get(s.get_block("C_rf")))

f = tvm.build(s.mod["main"], target="llvm")
a_np = np.random.uniform(size=(16, 256, 256)).astype("float32")
a = tvm.nd.array(a_np)
c = tvm.nd.array(np.zeros((16,), dtype="float32"))
f(a, c)
c_np = np.sum(a_np * a_np, axis=(1, 2))
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-4, atol=1e-4)


def test_reduction_rfactor_square_sum_square_root():
s = tir.Schedule(transformed_square_sum_square_root, debug_mode=True)
C = s.get_block("C")
_, _, fi = s.get_loops(C)
rf_block = s.rfactor(fi, 0)
tvm.ir.assert_structural_equal(s.mod["main"], square_sum_square_root_rfactor)
assert s.get(rf_block).same_as(s.get(s.get_block("C_rf")))

f = tvm.build(s.mod["main"], target="llvm")
a_np = np.random.uniform(size=(16, 256, 256)).astype("float32")
a = tvm.nd.array(a_np)
c = tvm.nd.array(np.zeros((16,), dtype="float32"))
f(a, c)
c_np = np.sqrt(np.sum(a_np * a_np, axis=(1, 2)))
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-4, atol=1e-4)


if __name__ == "__main__":
test_reduction_rfactor_matmul()
test_reduction_rfactor_square_sum()
test_reduction_rfactor_square_sum_square_root()

0 comments on commit 7911019

Please sign in to comment.