[Autoscheduler] Add sparse conv2d(1*1) support for auto_scheduler (apache#8065)

* add sparse conv2d support for auto_scheduler

* add description

* fix bug

* fix annotation

* Lint fix

Co-authored-by: laiyin.lyc <[email protected]>
2 people authored and Trevor Morris committed Jun 17, 2021
1 parent 3305106 commit 1f6dc15
Showing 17 changed files with 1,221 additions and 18 deletions.
12 changes: 12 additions & 0 deletions include/tvm/relay/attrs/nn.h
@@ -1010,6 +1010,18 @@ struct SparseTransposeAttrs : public tvm::AttrsNode<SparseTransposeAttrs> {
  TVM_DECLARE_ATTRS(SparseTransposeAttrs, "relay.attrs.SparseTransposeAttrs") {}
};

/*! \brief Attributes for sparse_conv2d operator */
struct SparseConv2DAttrs : public tvm::AttrsNode<SparseConv2DAttrs> {
  std::string layout;

  TVM_DECLARE_ATTRS(SparseConv2DAttrs, "relay.attrs.SparseConv2DAttrs") {
    TVM_ATTR_FIELD(layout).set_default("NHWC").describe(
        "Dimension ordering of input data. Can be 'NCHW', 'NHWC'. "
        "'N', 'C', 'H', 'W' stands for batch, channel, height, and width "
        "dimensions respectively.");
  }
};

/*! \brief Attributes for FIFO buffer operator */
struct FIFOBufferAttrs : public tvm::AttrsNode<FIFOBufferAttrs> {
  int axis;
1 change: 1 addition & 0 deletions python/tvm/relay/analysis/__init__.py
@@ -29,6 +29,7 @@
# Feature
from . import feature
from . import sparse_dense
from . import sparse_conv2d

# Utilities
from .count_layers import count_layers
154 changes: 154 additions & 0 deletions python/tvm/relay/analysis/sparse_conv2d.py
@@ -0,0 +1,154 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return
# pylint: disable=unidiomatic-typecheck
"""
This file contains helper functions for convert dense model
to block sparse model
"""
from collections import namedtuple
import numpy as np
import scipy.sparse as sp
import tvm
from . import _ffi_api


SparseAnalysisResult = namedtuple(
    "SparseAnalysisResult",
    [
        "weight_name",
        "weight_shape",
    ],
)


def _search_conv2d_op_weight(expr):
    """Search the names of the weights in all ``nn.conv2d`` operators.

    This is a helper function to determine which params need
    to be converted to sparse.

    Parameters
    ----------
    expr : relay.Expr
        Expr to be searched

    Returns
    -------
    ret : Array[String]
        Names of the weights in all ``nn.conv2d`` operators
    """
    return _ffi_api.search_conv2d_op_weight(expr)
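
A minimal usage sketch of this analysis, assuming a TVM build that includes this patch; the variable names and shapes are illustrative, and _search_conv2d_op_weight is a module-internal helper:

from tvm import relay
from tvm.relay.analysis.sparse_conv2d import _search_conv2d_op_weight

# build a tiny NHWC network with a single 1x1 conv2d
data = relay.var("data", shape=(1, 56, 56, 64))
weight = relay.var("conv.weight", shape=(1, 1, 64, 64))
out = relay.nn.conv2d(
    data, weight, channels=64, kernel_size=(1, 1),
    data_layout="NHWC", kernel_layout="HWIO",
)
func = relay.Function([data, weight], out)
print(_search_conv2d_op_weight(func))  # expected to contain "conv.weight"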


def process_params(expr, params, block_size, sparsity_threshold, layout):
    """Process parameters of conv2d from dense to sparse.

    Parameters
    ----------
    expr : relay.Expr
        Expr of the network
    params : Dict[String, tvm.nd.array]
        Parameters of the network
    block_size : Tuple(int, int)
        Block size of the BSR matrix
    sparsity_threshold : float
        Minimal sparsity requirement for converting to a sparse operation
    layout : str
        Layout of the network

    Returns
    -------
    ret : Namedtuple[weight_name: Array[String], weight_shape: Array[Array[IntImm]]]
        Names of the qualified conv2d weights and their shapes in BSR format
    """

    # pylint: disable=import-outside-toplevel
    from tvm.auto_scheduler.search_task import (
        register_task_input_buffer,
    )  # lazily import to avoid recursive dependency

    memo = SparseAnalysisResult(weight_name=[], weight_shape=[])
    weight_names = _search_conv2d_op_weight(expr)
    for name in weight_names:
        name = str(name)
        w_np = params[name].asnumpy()
        # currently only 1x1 conv2d kernels are supported
        if not (
            (w_np.shape[0] == 1 and w_np.shape[1] == 1)
            or (w_np.shape[2] == 1 and w_np.shape[3] == 1)
        ):
            continue
        sparsity = 1.0 - (np.count_nonzero(w_np) / w_np.size)
        if sparsity >= sparsity_threshold:
            if layout == "NHWC":
                w_np = w_np.squeeze().T
            elif layout == "NCHW":
                w_np = w_np.squeeze()

            sparse_weight = sp.bsr_matrix(w_np, blocksize=block_size)

            # when bs_c == 1, remove this dim
            if block_size[1] == 1:
                sparse_weight_data = sparse_weight.data.reshape(
                    sparse_weight.data.shape[0], block_size[0]
                )
            else:
                sparse_weight_data = sparse_weight.data

            # remove the dense weight
            del params[name]
            memo.weight_name.append(name)
            memo.weight_shape.append(
                list(sparse_weight_data.shape)
                + list(sparse_weight.indices.shape)
                + list(sparse_weight.indptr.shape)
            )
            params[name + ".data"] = tvm.nd.array(sparse_weight_data)
            params[name + ".indices"] = tvm.nd.array(sparse_weight.indices)
            params[name + ".indptr"] = tvm.nd.array(sparse_weight.indptr)

            prefix = "sparse_conv2d_bsr_%d_%d_%d_%d_%d_%d_" % (
                w_np.shape[0],
                w_np.shape[1],
                block_size[0],
                block_size[1],
                sparse_weight.indices.shape[0],
                sparse_weight.indptr.shape[0],
            )
            register_task_input_buffer(
                "default",
                prefix + "W_data",
                tvm.runtime.ndarray.array(sparse_weight_data),
                overwrite=True,
            )
            register_task_input_buffer(
                "default",
                prefix + "W_indices",
                tvm.runtime.ndarray.array(sparse_weight.indices),
                overwrite=True,
            )
            register_task_input_buffer(
                "default",
                prefix + "W_indptr",
                tvm.runtime.ndarray.array(sparse_weight.indptr),
                overwrite=True,
            )
    ret = SparseAnalysisResult(
        weight_name=tvm.runtime.convert(memo.weight_name),
        weight_shape=tvm.runtime.convert(memo.weight_shape),
    )
    return ret
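
For reference, a minimal numpy/scipy sketch of the BSR conversion performed above for an NHWC 1x1 weight; the channel and block sizes here are made up for illustration:

import numpy as np
import scipy.sparse as sp

ic, oc, bs_r, bs_c = 64, 64, 16, 1           # hypothetical sizes
w = np.random.rand(1, 1, ic, oc)             # 1x1 kernel
w[w < 0.7] = 0.0                             # make most entries zero
w2d = w.squeeze().T                          # NHWC path above: (oc, ic)
bsr = sp.bsr_matrix(w2d, blocksize=(bs_r, bs_c))
# bs_c == 1, so the (nnz, bs_r, 1) blocks collapse to (nnz, bs_r),
# matching the reshape in process_params above
data = bsr.data.reshape(bsr.data.shape[0], bs_r)
print(data.shape, bsr.indices.shape, bsr.indptr.shape)

Note that random element-wise sparsity rarely empties whole 16x1 blocks; in practice the speedup comes from weights pruned with block structure.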
1 change: 1 addition & 0 deletions python/tvm/relay/data_dep_optimization/__init__.py
@@ -19,3 +19,4 @@

from . import bsr_dense
from . import simplify_fc_transpose
from . import bsr_conv2d
58 changes: 58 additions & 0 deletions python/tvm/relay/data_dep_optimization/bsr_conv2d.py
@@ -0,0 +1,58 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument, not-context-manager
"""Automatic convert model from dense to block sparse"""

from tvm import relay
from tvm.relay.analysis.sparse_conv2d import process_params

from .utils import _run_opt_pass


def convert(func, params, blocksize, sparsity_threshold, layout="NHWC"):
    """Convert a dense func and its parameters to block-sparse form.

    Parameters
    ----------
    func : relay.Expr
        Expr to be optimized to sparse operation
    params : Dict[String, tvm.nd.array]
        Parameters of the Expr
    blocksize : Tuple(int, int)
        Block size for the BSR matrix
    sparsity_threshold : float
        Minimal sparsity requirement for converting.
        If weight sparsity is lower than this threshold,
        the dense operation will be kept.
    layout : str
        Layout of the network

    Returns
    -------
    new_func : relay.Expr
        Mutated Expr with sparse operations
    params : Dict[String, tvm.nd.array]
        New params with BSR matrices for the mutated Expr
    """
    weight_info = process_params(func, params, blocksize, sparsity_threshold, layout)
    new_func = _run_opt_pass(
        func,
        relay.transform.Conv2dToSparse(weight_info.weight_name, weight_info.weight_shape, layout),
    )

    return new_func, params
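
A hedged usage sketch of convert: "mod" and "params" are assumed to come from a frontend importer such as relay.frontend.from_tensorflow, and the blocksize/threshold values are illustrative:

import tvm
from tvm.relay import data_dep_optimization as ddo

new_func, new_params = ddo.bsr_conv2d.convert(
    mod["main"], params, blocksize=(16, 1), sparsity_threshold=0.75, layout="NHWC"
)
mod = tvm.IRModule.from_expr(new_func)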
11 changes: 11 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -168,6 +168,17 @@ def compute_sparse_transpose(attrs, inputs, out_type):
reg.register_pattern("nn.sparse_transpose", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


# sparse_conv2d
@reg.register_compute("nn.sparse_conv2d")
def compute_sparse_conv2d(attrs, inputs, out_type):
"""Compute definition of sparse_conv2d"""
return [topi.nn.sparse_conv2d(inputs[0], inputs[1], inputs[2], inputs[3], attrs["layout"])]


reg.register_strategy("nn.sparse_conv2d", strategy.sparse_conv2d_strategy)
reg.register_pattern("nn.sparse_conv2d", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


# conv1d
reg.register_strategy("nn.conv1d", strategy.conv1d_strategy)
reg.register_pattern("nn.conv1d", OpPattern.OUT_ELEMWISE_FUSABLE)
5 changes: 5 additions & 0 deletions python/tvm/relay/op/op_attrs.py
@@ -534,6 +534,11 @@ class SparseTransposeAttrs(Attrs):
"""Attributes used in sparse_transpose operators"""


@tvm._ffi.register_object("relay.attrs.SparseConv2DAttrs")
class SparseConv2DAttrs(Attrs):
"""Attributes used in sparse_conv2d operators"""


@tvm._ffi.register_object("relay.attrs.TopkAttrs")
class TopkAttrs(Attrs):
"""Attributes used in topk operators"""
23 changes: 23 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
@@ -842,6 +842,29 @@ def schedule_sparse_transpose(attrs, outs, target):
    return topi.generic.schedule_sparse_transpose(outs)


# sparse conv2d
def wrap_compute_sparse_conv2d(topi_compute):
    """wrap sparse conv2d topi compute"""

    def _compute_sparse_conv2d(attrs, inputs, out_type):
        return [topi_compute(inputs[0], inputs[1], inputs[2], inputs[3], attrs["layout"])]

    return _compute_sparse_conv2d


@override_native_generic_func("sparse_conv2d_strategy")
def sparse_conv2d_strategy(attrs, inputs, out_type, target):
    """sparse conv2d generic strategy"""
    logger.warning("sparse conv2d is not optimized for this platform.")
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_sparse_conv2d(topi.nn.sparse_conv2d),
        wrap_topi_schedule(topi.generic.schedule_sparse_conv2d),
        name="sparse_conv2d.generic",
    )
    return strategy


# sort
def wrap_compute_sort(topi_compute):
"""Wrap sort topi compute"""
23 changes: 23 additions & 0 deletions python/tvm/relay/transform/transform.py
@@ -1092,6 +1092,29 @@ def DenseToSparse(weight_name, weight_shape):
    return _ffi_api.DenseToSparse(weight_name, weight_shape)


def Conv2dToSparse(weight_name, weight_shape, layout):
    """
    Rewrite qualified ``nn.conv2d`` operations to ``nn.sparse_conv2d``

    Parameters
    ----------
    weight_name: Array[String]
        Names of weights which satisfy the sparse constraints
    weight_shape: Array[Array[IntImm]]
        Weight shapes in BSR format.
    layout : str
        Layout of data

    Returns
    -------
    ret : tvm.transform.Pass
        The registered Conv2dToSparse pass.
    """
    return _ffi_api.Conv2dToSparse(weight_name, weight_shape, layout)


def SimplifyFCTranspose(target_weight_name):
    """
    Rewrite ``y = nn.dense(x, transpose(w, [1, 0]))`` to ``y = nn.dense(x, wt)``
17 changes: 17 additions & 0 deletions python/tvm/topi/generic/nn.py
@@ -730,6 +730,23 @@ def schedule_sparse_transpose(outs):
    return _default_schedule(outs, False)


def schedule_sparse_conv2d(outs):
    """Schedule for sparse_conv2d

    Parameters
    ----------
    outs : Array of Tensor
        The computation graph description of sparse_conv2d
        in the format of an array of tensors.

    Returns
    -------
    sch : Schedule
        The computation schedule for the op.
    """
    return _default_schedule(outs, False)


def schedule_batch_matmul(outs):
"""Schedule for batch_matmul
