[Strategy] Support for Int8 schedules - CUDA/x86 #5031

Merged (8 commits) on Mar 12, 2020
Changes from 6 commits
8 changes: 4 additions & 4 deletions python/tvm/relay/frontend/mxnet.py
@@ -1373,8 +1373,8 @@ def _get_sum(_res, _output_scale, out_dtype):

# 3) Clip/cast to change the out dtype.
_res = relay.clip(_res,
-a_min=float(tvm.api.min_value(out_dtype).value),
-a_max=float(tvm.api.max_value(out_dtype).value))
+a_min=float(tvm.tir.op.min_value(out_dtype).value),
+a_max=float(tvm.tir.op.max_value(out_dtype).value))
_res = relay.cast(_res, out_dtype)
return _res

@@ -1647,8 +1647,8 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale):
_op.multiply(_op.cast(bias_data, 'float32'), bias_requantize_scale)
rounded_bias = _op.round(multiplied_bias)
clipped_bias = _op.clip(rounded_bias,
-a_min=tvm.api.min_value('int32').value,
-a_max=tvm.api.max_value('int32').value)
+a_min=tvm.tir.op.min_value('int32').value,
+a_max=tvm.tir.op.max_value('int32').value)
requantized_bias = _op.cast(clipped_bias, 'int32')
res = _op.nn.bias_add(res, requantized_bias, axis=-1)
enable_float_output = attrs.get_bool('enable_float_output', False)
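For reference, a minimal NumPy sketch of the clip-then-cast step in this hunk (an illustration of the saturation behaviour, not the Relay code; the helper name is made up):

import numpy as np

def clip_then_cast(x, out_dtype="int8"):
    # Clamp to the target dtype's representable range before casting so that
    # out-of-range accumulator values saturate instead of wrapping around.
    info = np.iinfo(out_dtype)
    return np.clip(x, info.min, info.max).astype(out_dtype)

acc = np.array([-200, -5, 130, 300], dtype=np.int32)  # hypothetical int32 accumulators
print(clip_then_cast(acc))   # [-128   -5  127  127] -- saturated
print(acc.astype(np.int8))   # [  56   -5 -126   44] -- a bare cast would wrap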
16 changes: 11 additions & 5 deletions python/tvm/relay/op/strategy/cuda.py
@@ -85,12 +85,18 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):

if groups == 1:
if layout == "NCHW":
-# TODO(@vinx13, @icemelon9): Use conv2d_NCHWc_int8 when dtype is int8/uint8.
assert kernel_layout == "OIHW"
-strategy.add_implementation(
-wrap_compute_conv2d(topi.cuda.conv2d_nchw),
-wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw),
-name="conv2d_nchw.cuda")
+if data.dtype in ('int8', 'uint8') and kernel.dtype in ('int8', 'uint8'):
+assert data.dtype == kernel.dtype
+strategy.add_implementation(
+wrap_compute_conv2d(topi.cuda.conv2d_nchw_int8),
+wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_int8),
+name="conv2d_nchw_int8.cuda")
+else:
+strategy.add_implementation(
+wrap_compute_conv2d(topi.cuda.conv2d_nchw),
+wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw),
+name="conv2d_nchw.cuda")
_, _, kh, kw = get_const_tuple(kernel.shape)
if 2 < kh < 8 and 2 < kw < 8 and kh == kw and stride_h == 1 and stride_w == 1 and \
dilation_h == 1 and dilation_w == 1:
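As a hedged sketch (not part of the PR) of how the new branch is exercised: when both data and kernel are int8, conv2d_strategy_cuda should now pick conv2d_nchw_int8.cuda rather than conv2d_nchw.cuda. The snippet below assumes a CUDA-enabled TVM build; shapes and names are illustrative.

import tvm
from tvm import relay

data = relay.var("data", shape=(1, 16, 14, 14), dtype="int8")
weight = relay.var("weight", shape=(32, 16, 3, 3), dtype="int8")
conv = relay.nn.conv2d(data, weight, kernel_size=(3, 3), channels=32,
                       padding=(1, 1), out_dtype="int32")
mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))

# Building for CUDA dispatches through the strategy above: int8/int8 inputs
# take the conv2d_nchw_int8 path, other dtypes fall back to conv2d_nchw.
built = relay.build(mod, target="cuda")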
14 changes: 14 additions & 0 deletions python/tvm/relay/qnn/op/legalizations.py
@@ -264,3 +264,17 @@ def _qnn_dense_legalize_intel_cpu(attrs, inputs, types):
if is_fast_int8_on_intel():
return helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay.qnn.op.dense)
return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense)

#####################
# CUDA legalizations.
#####################

@qnn_conv2d_legalize.register('cuda')
def _qnn_conv2d_legalize_cuda(attrs, inputs, types):
# CUDA prefers the dtypes to be the same.
return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)

@qnn_dense_legalize.register('cuda')
def _qnn_dense_legalize_cuda(attrs, inputs, types):
# CUDA prefers the dtypes to be the same.
return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)
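The idea behind helper_change_dtypes_to_be_same can be illustrated with a small NumPy sketch (an assumption about the mechanism for illustration, not the TVM helper itself): a uint8 tensor shifted by 128, together with a zero point shifted by 128, represents exactly the same real values in int8, so unifying the dtypes does not change the convolution result.

import numpy as np

scale, zero_point = 0.05, 140                     # hypothetical quantization params
q_uint8 = np.array([0, 100, 200, 255], dtype=np.uint8)

real_before = scale * (q_uint8.astype(np.int32) - zero_point)

q_int8 = (q_uint8.astype(np.int32) - 128).astype(np.int8)   # shift into the int8 range
zero_point_int8 = zero_point - 128
real_after = scale * (q_int8.astype(np.int32) - zero_point_int8)

assert np.allclose(real_before, real_after)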
60 changes: 60 additions & 0 deletions topi/python/topi/cuda/conv2d_alter_op.py
@@ -26,6 +26,7 @@
from .. import nn
from ..util import get_const_tuple
from .conv2d_winograd import _infer_tile_size
from ..nn import conv2d_legalize

logger = logging.getLogger('topi')

@@ -135,3 +136,62 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
return relay.nn.conv2d(*inputs, **new_attrs)

return None

@conv2d_legalize.register("cuda")
def _conv2d_legalize(attrs, inputs, arg_types):
"""Legalizes Conv2D op.

Parameters
----------
attrs : tvm.ir.Attrs
Attributes of current convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
arg_types : list of types
List of input and output types

Returns
-------
result : tvm.relay.Expr
The legalized expr
"""

# Dilation not supported yet. Return None if dilation is not (1, 1)
dilation = attrs.get_int_tuple("dilation")
if not (dilation[0] == 1 and dilation[1] == 1):
return None

# No legalization for depthwise convolutions yet.
groups = attrs.get_int("groups")
if groups != 1:
return None

# Collect the input tensors.
data_tensor, kernel_tensor = arg_types[0], arg_types[1]
data_dtype = data_tensor.dtype

# Collect the output tensor.
output_tensor = arg_types[2]

# Collect the input exprs.
data, kernel = inputs

# Get the conv attrs
new_attrs = {k: attrs[k] for k in attrs.keys()}

# Get data layout. Return None if not NCHW
data_layout = attrs['data_layout']
kernel_layout = attrs['kernel_layout']

if data_dtype in ['int8', 'uint8']:
if data_layout == 'NCHW' and kernel_layout == "OIHW":
in_channel = data_tensor.shape[1].value
if in_channel % 4 != 0:
new_in_channel = ((in_channel + 4) // 4) * 4
diff = new_in_channel - in_channel
pad_width = ((0, 0), (0, diff), (0, 0), (0, 0))
data = relay.nn.pad(data, pad_width=pad_width)
kernel = relay.nn.pad(kernel, pad_width=pad_width)
out = relay.nn.conv2d(data, kernel, **new_attrs)
return out
return None
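The channel padding above can be checked with a small NumPy sketch (illustrative only, not the Relay pass): when the input channel count is not a multiple of 4, both the NCHW data and the OIHW kernel are zero-padded along the input-channel axis so the dp4a-based int8 schedule can consume them, and zero padding extra input channels does not change the convolution result.

import numpy as np

in_channel = 3                                   # e.g. an RGB input
new_in_channel = ((in_channel + 4) // 4) * 4     # same rounding as above -> 4
diff = new_in_channel - in_channel

data = np.random.randint(-128, 128, size=(1, in_channel, 8, 8)).astype(np.int8)
kernel = np.random.randint(-128, 128, size=(16, in_channel, 3, 3)).astype(np.int8)

# Pad only the input-channel axis (axis 1 for both NCHW data and OIHW kernel).
data_padded = np.pad(data, ((0, 0), (0, diff), (0, 0), (0, 0)))
kernel_padded = np.pad(kernel, ((0, 0), (0, diff), (0, 0), (0, 0)))
assert data_padded.shape[1] == kernel_padded.shape[1] == 4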
21 changes: 20 additions & 1 deletion topi/python/topi/cuda/conv2d_int8.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
# pylint: disable=no-value-for-parameter
"""Int8 conv2d in NCHWc layout"""
import tvm
from tvm import te
@@ -23,10 +24,22 @@
from .injective import schedule_injective_from_existing
from .tensor_intrin import dp4a
from ..nn.pad import pad
from ..nn.conv2d import unpack_NCHWc_to_nchw
from ..nn.util import get_pad_tuple
from ..util import get_const_tuple, traverse_inline


def conv2d_nchw_int8(data, kernel, strides, padding, dilation, out_dtype='int32'):
"""Compute int8 conv2d in NCHW layout by packing to NCHWc internally and unpacking the result back to NCHW."""
assert data.dtype in ('int8', 'uint8')
assert kernel.dtype in ('int8', 'uint8')
assert data.dtype == kernel.dtype
packed_out = conv2d_NCHWc_int8(data, kernel, strides, padding, dilation, "NCHW", out_dtype)
return unpack_NCHWc_to_nchw(packed_out, out_dtype)

def schedule_conv2d_nchw_int8(outs):
"""Create schedule for tensors"""
return schedule_conv2d_NCHWc_int8(outs)

@autotvm.register_topi_compute("conv2d_NCHWc_int8.cuda")
def conv2d_NCHWc_int8(cfg, data, kernel, stride, padding, dilation, layout, out_dtype):
"""Convolution operator in NCHW[x]c layout for int8.
@@ -205,7 +218,13 @@ def _schedule_conv2d_NCHWc_int8(cfg, s, output):
output = s.outputs[0].output(0)

# tile and bind spatial axes
-n, f, y, x, c = s[output].op.axis
+if len(s[output].op.axis) == 5:
+n, f, y, x, c = s[output].op.axis
+else:
+# During auto-tuning task extraction the output here is 4D. Since tuning tasks are
+# re-created from scratch, the real auto-tuning still happens on the 5D output.
+n, f, y, x = s[output].op.axis

cfg.define_split("tile_n", cfg.axis(n), num_outputs=4)
cfg.define_split("tile_f", cfg.axis(f), num_outputs=4)
cfg.define_split("tile_y", cfg.axis(y), num_outputs=4)
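The schedule above builds on the dp4a tensor intrinsic imported from .tensor_intrin. A short NumPy sketch of its semantics (an illustration of the arithmetic, not the CUDA intrinsic): four int8 pairs are multiplied and accumulated into a single int32, which is why the NCHW[x]c int8 layout packs input channels in groups of four.

import numpy as np

def dp4a_reference(a4, b4, acc=0):
    # int32 accumulation of four int8 products, the behaviour dp4a provides in hardware
    return acc + int(np.dot(a4.astype(np.int32), b4.astype(np.int32)))

a = np.array([1, -2, 3, -4], dtype=np.int8)
b = np.array([5, 6, -7, 8], dtype=np.int8)
print(dp4a_reference(a, b))   # 1*5 - 2*6 - 3*7 - 4*8 = -60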
66 changes: 48 additions & 18 deletions topi/python/topi/generic/conv2d.py
@@ -144,7 +144,8 @@ def schedule_conv_NCHWc_cpu_common_int8(s, cfg, data_vec, kernel_vec, conv_out,
parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih)
s[data_vec].parallel(parallel_axis)

-oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[kernel_vec].op.axis
+# conv2d_NCHWc_int8 has a 7D kernel
+oc_chunk, ic_chunk, oh, ow, ic_block, oc_block, _ = s[kernel_vec].op.axis
s[kernel_vec].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block)
oc_bn = cfg["tile_oc"].size[-1]
if oc_bn > 1:
@@ -189,13 +190,26 @@ def schedule_conv_NCHWc_cpu_common_int8(s, cfg, data_vec, kernel_vec, conv_out,
s[CC].unroll(oc_f_inner)

if C != O:
-batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
-ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
-s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
-parallel_axis = s[O].fuse(batch, oc_chunk, oh)
-s[C].compute_at(s[O], parallel_axis)
-s[O].vectorize(oc_block)
-s[O].parallel(parallel_axis)
+out_ndim = len(s[O].op.axis)
+if out_ndim == 5:
+batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
+ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
+s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
+parallel_axis = s[O].fuse(batch, oc_chunk, oh)
+s[C].compute_at(s[O], parallel_axis)
+s[O].vectorize(oc_block)
+s[O].parallel(parallel_axis)
+elif out_ndim == 4:
+batch, oc, oh, ow = s[O].op.axis
+ow_chunk, ow_block = s[O].split(ow, factor=reg_n)
+oc_chunk, oc_block = s[O].split(oc, factor=oc_bn)
+s[O].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
+parallel_axis = s[O].fuse(batch, oc_chunk, oh)
+s[C].compute_at(s[O], parallel_axis)
+s[O].vectorize(oc_block)
+s[O].parallel(parallel_axis)
+else:
+raise ValueError("Unsupported output ndim: %s" % out_ndim)

return s

@@ -234,7 +248,8 @@ def schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data_vec, kernel_vec, conv_out,
parallel_axis = s[data_vec].fuse(batch, ic_chunk, ih)
s[data_vec].parallel(parallel_axis)

-oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[kernel_vec].op.axis
+# The conv2d int8 schedule has a 7D kernel
+oc_chunk, ic_chunk, oh, ow, ic_block, oc_block, _ = s[kernel_vec].op.axis
s[kernel_vec].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block)
oc_bn = cfg["tile_oc"].size[-1]
if oc_bn > 1:
@@ -277,14 +292,29 @@ def schedule_conv_NCHWc_cpu_1x1_int8(s, cfg, data_vec, kernel_vec, conv_out,
s[CC].unroll(oh_inner)

if C != O:
-batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
-oh_outer, oh_inner = s[O].split(oh, factor=oh_factor)
-ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
-s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
-
-parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
-s[C].compute_at(s[O], parallel_axis)
-s[O].vectorize(oc_block)
-s[O].parallel(parallel_axis)
+out_ndim = len(s[O].op.axis)
+if out_ndim == 5:
+batch, oc_chunk, oh, ow, oc_block = s[O].op.axis
+oh_outer, oh_inner = s[O].split(oh, factor=oh_factor)
+ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
+s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
+
+parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
+s[C].compute_at(s[O], parallel_axis)
+s[O].vectorize(oc_block)
+s[O].parallel(parallel_axis)
+elif out_ndim == 4:
+batch, oc, oh, ow = s[O].op.axis
+oc_chunk, oc_block = s[O].split(oc, factor=oc_bn)
+oh_outer, oh_inner = s[O].split(oh, factor=oh_factor)
+ow_outer, ow_inner = s[O].split(ow, factor=ow_factor)
+s[O].reorder(oc_chunk, oh_outer, ow_outer, oh_inner, ow_inner, oc_block)
+
+parallel_axis = s[O].fuse(batch, oc_chunk, oh_outer)
+s[C].compute_at(s[O], parallel_axis)
+s[O].vectorize(oc_block)
+s[O].parallel(parallel_axis)
+else:
+raise ValueError("Unsupported output ndim: %s" % out_ndim)

return s
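The 4D branches added above split the channel axis by oc_bn, which mirrors the NCHW[x]c blocked layout that the 5D branches schedule directly. A NumPy sketch of that relation (for illustration; shapes are arbitrary):

import numpy as np

N, C, H, W, oc_bn = 1, 32, 7, 7, 16
nchw = np.arange(N * C * H * W).reshape(N, C, H, W)

# NCHW -> NCHW[x]c: split C into (C // oc_bn, oc_bn) and move the block innermost.
nchwc = nchw.reshape(N, C // oc_bn, oc_bn, H, W).transpose(0, 1, 3, 4, 2)

# Undoing the transform recovers the original 4D tensor, so tiling oc by oc_bn
# in the 4D output schedules the same loop structure as the 5D output.
back = nchwc.transpose(0, 1, 4, 2, 3).reshape(N, C, H, W)
assert np.array_equal(nchw, back)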