Skip to content

Commit

Permalink
Add shape function and dynamic shape checking for mlas_matmul (apache…
Browse files Browse the repository at this point in the history
…#214)

* Add shape function for mlas_matmul

* Fix lint

* Add dynamic shape checking for mlas AlterOpLayout

* Add testing for dynamic shape checking

* Fix dense alterOpLayout
  • Loading branch information
ymwangg authored and trevor-m committed Aug 3, 2021
1 parent 1b83419 commit 179fdec
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 24 deletions.
66 changes: 53 additions & 13 deletions python/tvm/relay/op/_mlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""Strategy and AlterOpLayout functions of MLAS operators"""
import tvm
from tvm import relay, topi
from tvm.te.hybrid import script
from .strategy import wrap_topi_schedule
from . import op as reg

Expand Down Expand Up @@ -74,19 +75,24 @@ def _alter_batch_matmul_layout(attrs, inputs, tinfos, out_type):
and tinfos[1].dtype == "float32"
and out_type.dtype == "float32"
):
# if matrix B is constant, use packed matmul
if isinstance(inputs[1], relay.expr.Constant):
b_shape = inputs[1].data.shape
assert len(b_shape) == 3
batch, N, K = b_shape[0], b_shape[1], b_shape[2]
# batch_B must be 1
if batch == 1:
packed_b = relay.op.mlas_packb(inputs[1], K, N)
output = relay.op.mlas_matmul(inputs[0], packed_b, True, K, N)
return output
# if matrix A, B are not constant and no other libs are enabled, use normal matmul
if not any([item in target.libs for item in ["mkl", "clbas", "mkldnn"]]):
return relay.op.mlas_matmul(inputs[0], inputs[1], False)
# mlas is only used for static tensors
if not (
any([isinstance(dim, tvm.tir.Any) for dim in tinfos[0].shape])
or any([isinstance(dim, tvm.tir.Any) for dim in tinfos[1].shape])
):
# if matrix B is constant, use packed matmul
if isinstance(inputs[1], relay.expr.Constant):
b_shape = inputs[1].data.shape
assert len(b_shape) == 3
batch, N, K = b_shape[0], b_shape[1], b_shape[2]
# batch_B must be 1
if batch == 1:
packed_b = relay.op.mlas_packb(inputs[1], K, N)
output = relay.op.mlas_matmul(inputs[0], packed_b, True, K, N)
return output
# if matrix A, B are not constant and no other libs are enabled, use normal matmul
if not any([item in target.libs for item in ["mkl", "clbas", "mkldnn"]]):
return relay.op.mlas_matmul(inputs[0], inputs[1], False)
return None


Expand Down Expand Up @@ -123,3 +129,37 @@ def _compute_mlas_packb(attrs, inputs, _):

# Dense AlterOpLayout
# See tvm.topi.x86.dense_alter_op


@script
def _mlas_matmul_shape_func(tensor_a_shape, tensor_b_shape):
out = output_tensor((tensor_a_shape.shape[0],), "int64")
if tensor_a_shape.shape[0] == 3:
out[0] = tensor_a_shape[0]
out[1] = tensor_a_shape[1]
out[2] = tensor_b_shape[1]
else:
out[0] = tensor_a_shape[0]
out[1] = tensor_b_shape[0]
return out


@script
def _mlas_matmul_packb_shape_func(tensor_a_shape, N):
out = output_tensor((tensor_a_shape.shape[0],), "int64")
if tensor_a_shape.shape[0] == 3:
out[0] = tensor_a_shape[0]
out[1] = tensor_a_shape[1]
out[2] = N
else:
out[0] = tensor_a_shape[0]
out[1] = N
return out


@reg.register_shape_func("mlas_matmul", False)
def matmul_shape_func(attrs, inputs, _):
"""Shape function for matmul op."""
if attrs.packb:
return [_mlas_matmul_packb_shape_func(inputs[0], tvm.tir.expr.IntImm("int64", attrs.N))]
return [_mlas_matmul_shape_func(inputs[0], inputs[1])]
27 changes: 16 additions & 11 deletions python/tvm/topi/x86/dense_alter_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,22 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type):
and tinfos[1].dtype == "float32"
and out_type.dtype == "float32"
):
# if matrix B is constant, use packed matmul
if isinstance(inputs[1], relay.expr.Constant):
b_shape = inputs[1].data.shape
assert len(b_shape) == 2
N, K = b_shape[0], b_shape[1]
packed_b = relay.op.mlas_packb(inputs[1], K, N)
output = relay.op.mlas_matmul(inputs[0], packed_b, True, K, N)
return output
# if matrix A, B are not constant and no other libs are enabled, use normal matmul
if not any([item in target.libs for item in ["mkl", "clbas", "mkldnn"]]):
return relay.op.mlas_matmul(inputs[0], inputs[1], False)
# mlas is only used for static tensors
if not (
any([isinstance(dim, tvm.tir.Any) for dim in tinfos[0].shape])
or any([isinstance(dim, tvm.tir.Any) for dim in tinfos[1].shape])
):
# if matrix B is constant, use packed matmul
if isinstance(inputs[1], relay.expr.Constant):
b_shape = inputs[1].data.shape
assert len(b_shape) == 2
N, K = b_shape[0], b_shape[1]
packed_b = relay.op.mlas_packb(inputs[1], K, N)
output = relay.op.mlas_matmul(inputs[0], packed_b, True, K, N)
return output
# if matrix A, B are not constant and no other libs are enabled, use normal matmul
if not any([item in target.libs for item in ["mkl", "clbas", "mkldnn"]]):
return relay.op.mlas_matmul(inputs[0], inputs[1], False)

dispatch_ctx = autotvm.task.DispatchContext.current
data_tensor, weight_tensor = tinfos
Expand Down
41 changes: 41 additions & 0 deletions tests/python/contrib/test_mlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,31 @@ def nopack_expected():
b = _run_opt_pass(nopack_expected(), relay.transform.InferType())
assert tvm.ir.structural_equal(a, b)

def dynamic_before():
A = relay.var("A", shape=(relay.Any(), k), dtype="float32")
B = relay.var("B", shape=(n, k), dtype="float32")
C = relay.nn.dense(A, B)
f = relay.Function(relay.analysis.free_vars(C), C)
return f

def dynamic_expected():
A = relay.var("A", shape=(relay.Any(), k), dtype="float32")
B = relay.var("B", shape=(n, k), dtype="float32")
target_layout = "NK16n"
weight_transform = relay.layout_transform(B, "NK", target_layout)
y = relay.nn.contrib_dense_pack(A, weight_transform, units=None, out_dtype="float32")
y = relay.Function(relay.analysis.free_vars(y), y)
return y

with tvm.target.Target(target):
with TempOpAttr(
"nn.dense", "FTVMAlterOpLayout", topi.x86.dense_alter_op._alter_dense_layout
):
a = dynamic_before()
a = _run_opt_pass(a, relay.transform.AlterOpLayout())
b = _run_opt_pass(dynamic_expected(), relay.transform.InferType())
assert tvm.ir.structural_equal(a, b)


def test_alter_op_layout_batch_matmul():
if not get_global_func("tvm.contrib.mlas.batch_sgemm", allow_missing=True):
Expand Down Expand Up @@ -333,6 +358,22 @@ def nopack_expected():
b = _run_opt_pass(nopack_expected(), relay.transform.InferType())
assert tvm.ir.structural_equal(a, b)

def dynamic_expected():
A = relay.var("A", shape=(1, relay.Any(), k), dtype="float32")
B = relay.var("B", shape=(1, n, k), dtype="float32")
C = relay.nn.batch_matmul(A, B)
f = relay.Function(relay.analysis.free_vars(C), C)
return f

with tvm.target.Target(target):
with TempOpAttr(
"nn.batch_matmul", "FTVMAlterOpLayout", relay.op._mlas._alter_batch_matmul_layout
):
a = dynamic_expected()
a = _run_opt_pass(a, relay.transform.AlterOpLayout())
b = _run_opt_pass(dynamic_expected(), relay.transform.InferType())
assert tvm.ir.structural_equal(a, b)


if __name__ == "__main__":
test_topi_mlas_matmul()
Expand Down

0 comments on commit 179fdec

Please sign in to comment.