[CUDA][PASS]Legalize tensorcore (apache#7147)
* add pad_to_tensorcore & legalize for dense/bmm/conv2d

* fix pad & slice

* fix comments

* fix comments

* resolve conflict

* resolve conflict

* support only fp16

* add tests/python/relay/test_pass_legalize_tensorcore.py

* add tests for legalize tensorcore

* fix pylint

* fix pylint

* code format

* use_gpu test only; fix conv2d_alter_op

* fix tests params

* revert transform fix
Meteorix authored and electriclilies committed Feb 18, 2021
1 parent 4547d8c commit aa6a4b4
Showing 7 changed files with 582 additions and 0 deletions.
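
In short: on CUDA, fp16 dense, batch_matmul, and conv2d workloads whose shapes are not tensorcore-aligned are now padded up (and sliced back) by the Legalize pass, as long as the estimated extra work stays below a 2x ratio. A minimal sketch of how to observe the new rewrite; the shapes and variable names are hypothetical, and the cuda override only fires when a cuda target is current:

import tvm
from tvm import relay

# Hypothetical fp16 dense with (M, K, N) = (7, 16, 30), which is not tensorcore-aligned.
x = relay.var("x", shape=(7, 16), dtype="float16")
w = relay.var("w", shape=(30, 16), dtype="float16")
mod = tvm.IRModule.from_expr(relay.Function([x, w], relay.nn.dense(x, w)))

with tvm.target.Target("cuda"):
    mod = relay.transform.InferType()(mod)
    mod = relay.transform.Legalize()(mod)

# Expected: nn.pad on both inputs, nn.dense on the padded tensors, and a
# strided_slice back to the original (7, 30) output shape.
print(mod)
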
42 changes: 42 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -52,6 +52,27 @@
reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE)


@reg.register_legalize("nn.dense")
def legalize_dense(attrs, inputs, types):
"""Legalize dense op.
Parameters
----------
attrs : tvm.ir.Attrs
        Attributes of current dense op
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
types : list of types
List of input and output types
Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
return topi.nn.dense_legalize(attrs, inputs, types)


# dense
reg.register_strategy("nn.dense", strategy.dense_strategy)
reg.register_pattern("nn.dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
@@ -67,6 +88,27 @@ def compute_fifo_buffer(attrs, inputs, out_type)
reg.register_pattern("nn.fifo_buffer", OpPattern.OPAQUE)


@reg.register_legalize("nn.batch_matmul")
def legalize_batch_matmul(attrs, inputs, types):
"""Legalize batch_matmul op.
Parameters
----------
attrs : tvm.ir.Attrs
        Attributes of current batch_matmul op
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
types : list of types
List of input and output types
Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
return topi.nn.batch_matmul_legalize(attrs, inputs, types)


# batch_matmul
reg.register_strategy("nn.batch_matmul", strategy.batch_matmul_strategy)
reg.register_pattern("nn.batch_matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
1 change: 1 addition & 0 deletions python/tvm/topi/cuda/__init__.py
@@ -55,5 +55,6 @@
from .conv2d_hwnc_tensorcore import *
from .correlation import *
from .sparse import *
from . import tensorcore_alter_op
from .argwhere import *
from .scan import *
48 changes: 48 additions & 0 deletions python/tvm/topi/cuda/conv2d_alter_op.py
@@ -24,8 +24,10 @@
from .. import nn
from ..utils import get_const_tuple
from .conv2d_winograd import _infer_tile_size
from .tensorcore_alter_op import pad_to_tensorcore
from ..nn import conv2d_legalize


logger = logging.getLogger("topi")


@@ -345,4 +347,50 @@ def _conv2d_legalize(attrs, inputs, arg_types):
else:
out = relay.nn.conv2d(data, kernel, **new_attrs)
return out
elif data_dtype in ["float16"]: # todo: support int8/int4
if data_layout == "NHWC" and kernel_layout == "HWIO":
batch = data_tensor.shape[0].value
in_channel = data_tensor.shape[3].value
out_channel = kernel_tensor.shape[3].value

if (
(batch % 8 == 0 and in_channel % 16 == 0 and out_channel % 32 == 0)
or (batch % 16 == 0 and in_channel % 16 == 0 and out_channel % 16 == 0)
or (batch % 32 == 0 and in_channel % 16 == 0 and out_channel % 8 == 0)
):
# no need to pad
return None

(db, di, do), extra_flops = pad_to_tensorcore(batch, in_channel, out_channel)

if extra_flops > 2:
logger.info("conv2d pad_to_tensorcore skipped, extra_flops %s", extra_flops)
return None

logger.info("conv2d pad_to_tensorcore, extra_flops %s", extra_flops)

# Pad batch size
if db != 0:
data = relay.nn.pad(data, pad_width=((0, db), (0, 0), (0, 0), (0, 0)))

# Pad input channel
if di != 0:
data = relay.nn.pad(data, pad_width=((0, 0), (0, 0), (0, 0), (0, di)))
kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, di), (0, 0)))

            # Pad output channel and update the channels attribute to match
            if do != 0:
                kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, 0), (0, do)))
                new_attrs["channels"] = out_channel + do

out = relay.nn.conv2d(data, kernel, **new_attrs)

if db != 0 or do != 0:
original_out_shape = [x.value for x in output_tensor.shape]
out = relay.strided_slice(out, begin=[0, 0, 0, 0], end=original_out_shape)

return out
return None
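
The NHWC/HWIO branch above can be exercised the same way as the dense case; a hedged sketch with hypothetical shapes, where batch and output channels need padding and input channels are already aligned:

import tvm
from tvm import relay

# Hypothetical NHWC fp16 conv2d: batch=7 and out_channels=30 are not aligned, in_channels=16 is.
data = relay.var("data", shape=(7, 16, 16, 16), dtype="float16")
weight = relay.var("weight", shape=(3, 3, 16, 30), dtype="float16")
conv = relay.nn.conv2d(
    data,
    weight,
    channels=30,
    kernel_size=(3, 3),
    padding=(1, 1),
    data_layout="NHWC",
    kernel_layout="HWIO",
)
mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))

with tvm.target.Target("cuda"):
    mod = relay.transform.InferType()(mod)
    mod = relay.transform.Legalize()(mod)

# Expected: data padded to batch 8, weight padded to 32 output channels,
# conv2d with channels=32, then strided_slice back to (7, 16, 16, 30).
print(mod)
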
204 changes: 204 additions & 0 deletions python/tvm/topi/cuda/tensorcore_alter_op.py
@@ -0,0 +1,204 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-variable,unused-argument
"""Tensorcore alter op and legalize functions for cuda backend"""

import logging
import math
from tvm import relay

from .. import nn

logger = logging.getLogger("topi")


@nn.batch_matmul_legalize.register("cuda")
def _batch_matmul_legalize(attrs, inputs, arg_types):
"""Legalizes batch_matmul op.
Parameters
----------
attrs : tvm.ir.Attrs
        Attributes of current batch_matmul op
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
arg_types : list of types
List of input and output types
Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
# Collect the input tensors.
x_tensor, y_tensor = arg_types[0], arg_types[1]
dtype = x_tensor.dtype

# Collect the output tensor.
output_tensor = arg_types[2]

# Collect the input exprs.
x, y = inputs

# Pad input and output channels to use tensorcore schedule.
if dtype in ["float16"]: # todo: support int8/int4
B, M, K = x_tensor.shape
B, N, K = y_tensor.shape
M = M.value
K = K.value
N = N.value

        # The shape (M, K, N) must be a multiple of (16, 16, 16), (32, 16, 8), or (8, 16, 32)
if (
(M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
):
# no need to pad
return None

(dm, dk, dn), extra_flops = pad_to_tensorcore(M, K, N)

if extra_flops > 2:
logger.info("batch_matmul pad_to_tensorcore skipped, extra_flops %s", extra_flops)
return None

logger.info("batch_matmul pad_to_tensorcore, extra_flops %s", extra_flops)
if dm or dk:
x_ = relay.nn.pad(x, pad_width=((0, 0), (0, dm), (0, dk)))
else:
x_ = x
if dn or dk:
y_ = relay.nn.pad(y, pad_width=((0, 0), (0, dn), (0, dk)))
else:
y_ = y
out_ = relay.nn.batch_matmul(x_, y_)
if dm or dn:
original_out_shape = [x.value for x in output_tensor.shape]
out = relay.strided_slice(out_, begin=[0, 0, 0], end=original_out_shape)
else:
out = out_
return out
return None


@nn.dense_legalize.register("cuda")
def _dense_legalize(attrs, inputs, arg_types):
"""Legalizes dense op.
Parameters
----------
attrs : tvm.ir.Attrs
        Attributes of current dense op
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
    arg_types : list of types
List of input and output types
Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
# Collect the input tensors.
x_tensor, y_tensor = arg_types[0], arg_types[1]
dtype = x_tensor.dtype

# Collect the output tensor.
output_tensor = arg_types[2]

# Collect the input exprs.
x, y = inputs

# Pad input and output channels to use tensorcore schedule.
if dtype in ["float16"]: # todo: support int8/int4
M, K = x_tensor.shape
N, K = y_tensor.shape
try:
M = M.value
K = K.value
N = N.value
except AttributeError:
# todo: deal with unfixed shape when compiling wdl model
return None

        # The shape (M, K, N) must be a multiple of (16, 16, 16), (32, 16, 8), or (8, 16, 32)
if (
(M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
):
# no need to pad
return None

(dm, dk, dn), extra_flops_ratio = pad_to_tensorcore(M, K, N)

if extra_flops_ratio > 2:
logger.info("dense pad_to_tensorcore skipped, extra_flops_ratio %s", extra_flops_ratio)
return None

logger.info("dense pad_to_tensorcore, extra_flops_ratio %s", extra_flops_ratio)

if dm or dk:
x_ = relay.nn.pad(x, pad_width=((0, dm), (0, dk)))
else:
x_ = x
if dn or dk:
y_ = relay.nn.pad(y, pad_width=((0, dn), (0, dk)))
else:
y_ = y
out_ = relay.nn.dense(x_, y_)
if dm or dn:
original_out_shape = [x.value for x in output_tensor.shape]
out = relay.strided_slice(out_, begin=[0, 0], end=original_out_shape)
else:
out = out_
return out
return None


def pad_to_tensorcore(M, K, N):
"""pad shape to enable tensorcore"""
candidates = [(16, 16, 16), (32, 16, 8), (8, 16, 32)]

flops = M * K * N
extra_flops = math.inf
best_pad = (0, 0, 0)
for padding in candidates:
dm, dk, dn = _pad_to(M, K, N, padding)
e = (M + dm) * (N + dn) * (K + dk) - M * N * K
if e < extra_flops:
extra_flops = e
best_pad = (dm, dk, dn)
return best_pad, extra_flops / flops


def _pad_to(M, K, N, PADDING):
dm, dk, dn = 0, 0, 0

if M % PADDING[0] != 0:
M_ = ((M + PADDING[0]) // PADDING[0]) * PADDING[0]
dm = M_ - M
if K % PADDING[1] != 0:
K_ = ((K + PADDING[1]) // PADDING[1]) * PADDING[1]
dk = K_ - K
if N % PADDING[2] != 0:
N_ = ((N + PADDING[2]) // PADDING[2]) * PADDING[2]
dn = N_ - N

return dm, dk, dn
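
A quick worked example of the padding helper, with hypothetical inputs; the numbers simply follow the arithmetic above:

from tvm.topi.cuda.tensorcore_alter_op import pad_to_tensorcore

# (M, K, N) = (7, 16, 30): the (8, 16, 32) candidate needs the least extra work.
(dm, dk, dn), ratio = pad_to_tensorcore(7, 16, 30)
# dm=1, dk=0, dn=2; extra flops = 8*16*32 - 7*16*30 = 736, ratio = 736/3360 ≈ 0.22,
# which is below the 2x threshold used by the legalize functions above.
print(dm, dk, dn, ratio)
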
24 changes: 24 additions & 0 deletions python/tvm/topi/nn/batch_matmul.py
@@ -16,6 +16,7 @@
# under the License.
"""Batch matrix multiplication"""
# pylint: disable=invalid-name
import tvm
from tvm import te, auto_scheduler
from ..utils import get_const_tuple

@@ -77,3 +78,26 @@ def batch_matmul(x, y, oshape=None, auto_scheduler_rewritten_layout=""):
output = auto_scheduler.rewrite_compute_body(output, auto_scheduler_rewritten_layout)

return output


@tvm.target.generic_func
def batch_matmul_legalize(attrs, inputs, types):
"""Legalizes batch_matmul op.
Parameters
----------
attrs : tvm.ir.Attrs
Attributes of current batch_matmul
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
types : list of types
List of input and output types
Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
    # No change by default; targets can register an override of this generic function.
# pylint: disable=unused-argument
return None
24 changes: 24 additions & 0 deletions python/tvm/topi/nn/dense.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""TVM operator fully connected compute."""
import tvm
from tvm import te, auto_scheduler
from .. import tag

@@ -80,3 +81,26 @@ def dense(data, weight, bias=None, out_dtype=None, auto_scheduler_rewritten_layout=""):
matmul = auto_scheduler.rewrite_compute_body(matmul, auto_scheduler_rewritten_layout)

return matmul


@tvm.target.generic_func
def dense_legalize(attrs, inputs, types):
"""Legalizes dense op.
Parameters
----------
attrs : tvm.ir.Attrs
Attributes of current dense
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
types : list of types
List of input and output types
Returns
-------
result : tvm.relay.Expr
The legalized expr
"""
    # No change by default; targets can register an override of this generic function.
# pylint: disable=unused-argument
return None
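
For reference, a backend opts into this hook by registering against the generic function, as the CUDA backend does in tensorcore_alter_op.py above. A minimal sketch with a purely illustrative target key:

from tvm import topi

# Hypothetical registration: return a rewritten relay.Expr, or None to leave the op as-is.
# "mytarget" is an illustrative key, not a real TVM target.
@topi.nn.dense_legalize.register("mytarget")
def _dense_legalize_mytarget(attrs, inputs, types):
    return None
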
239 changes: 239 additions & 0 deletions tests/python/relay/test_pass_legalize_tensorcore.py (diff not shown)
