/*! \brief Attributes for sparse_conv2d operator */
struct SparseConv2DAttrs : public tvm::AttrsNode {
  // Data layout of the dense input; the declared default is "NHWC".
  std::string layout;

  TVM_DECLARE_ATTRS(SparseConv2DAttrs, "relay.attrs.SparseConv2DAttrs") {
    // NOTE(review): the three adjacent string literals below are concatenated
    // by the compiler without any separator, so the description reads
    // "...'NHWC''N', 'C', ..." and "...widthdimensions...". Consider adding
    // ". " / " " at the literal boundaries.
    TVM_ATTR_FIELD(layout).set_default("NHWC").describe(
        "Dimension ordering of input data. Can be 'NCHW', 'NHWC'"
        "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
        "dimensions respectively.");
  }
};
def _flatten_pointwise_kernel(w_np, layout):
    """Collapse a 1x1 conv2d kernel into the 2-D matrix that is converted to BSR.

    Only the two unit-sized kernel axes are dropped. A plain ``squeeze()``
    would also drop a channel axis of size 1 and leave a 1-D array.
    """
    shape = w_np.shape
    if shape[0] == 1 and shape[1] == 1:
        # Kernel axes come first, channel axes are the trailing two.
        w_2d = w_np.reshape(shape[2], shape[3])
    else:
        # Channel axes come first, kernel axes are the trailing two.
        w_2d = w_np.reshape(shape[0], shape[1])

    if layout == "NHWC":
        return w_2d.T
    if layout == "NCHW":
        return w_2d
    raise ValueError("Unsupported layout %s, expected 'NHWC' or 'NCHW'" % layout)


def process_params(expr, params, block_size, sparsity_threshold, layout):
    """Process parameters of conv2d from dense to sparse.

    Every 1x1 conv2d weight whose sparsity reaches ``sparsity_threshold`` is
    removed from ``params`` and replaced by its BSR components stored under
    ``<name>.data``, ``<name>.indices`` and ``<name>.indptr``. The same three
    arrays are registered as auto-scheduler task input buffers.

    Parameters
    ----------
    expr : Relay.Expr
        Expr of the network
    params : Dict[String, tvm.nd.array]
        parameters of the network; modified in place
    block_size : Tuple(int, int)
        Blocksize in BSR matrix
    sparsity_threshold : float
        Minimal sparsity requirement for converting to sparse operation
    layout : str
        layout of network, "NHWC" or "NCHW"

    Returns
    -------
    ret : Namedtuple[weight_name: Array[String], weight_shape: Array[Array[IntImm]]]
        return names of qualified conv2d weight and the shape in BSR format

    Raises
    ------
    ValueError
        If a weight qualifies for conversion but ``layout`` is neither
        "NHWC" nor "NCHW".
    """

    # pylint: disable=import-outside-toplevel
    from tvm.auto_scheduler.search_task import (
        register_task_input_buffer,
    )  # lazily import to avoid recursive dependency

    memo = SparseAnalysisResult(weight_name=[], weight_shape=[])
    weight_names = _search_conv2d_op_weight(expr)
    for name in weight_names:
        name = str(name)
        w_np = params[name].asnumpy()
        # currently only support conv2d_1*1
        if not (
            (w_np.shape[0] == 1 and w_np.shape[1] == 1)
            or (w_np.shape[2] == 1 and w_np.shape[3] == 1)
        ):
            continue
        sparsity = 1.0 - (np.count_nonzero(w_np) / w_np.size)
        if sparsity < sparsity_threshold:
            continue

        # Validates the layout before params is mutated below.
        w_np = _flatten_pointwise_kernel(w_np, layout)
        sparse_weight = sp.bsr_matrix(w_np, blocksize=block_size)

        # when bs_c=1, remove this dim
        if block_size[1] == 1:
            sparse_weight_data = sparse_weight.data.reshape(
                sparse_weight.data.shape[0], block_size[0]
            )
        else:
            sparse_weight_data = sparse_weight.data

        # remove dense weight
        del params[name]
        memo.weight_name.append(name)
        memo.weight_shape.append(
            list(sparse_weight_data.shape)
            + list(sparse_weight.indices.shape)
            + list(sparse_weight.indptr.shape)
        )
        params[name + ".data"] = tvm.nd.array(sparse_weight_data)
        params[name + ".indices"] = tvm.nd.array(sparse_weight.indices)
        params[name + ".indptr"] = tvm.nd.array(sparse_weight.indptr)

        # Buffer-name prefix built from the 2-D weight shape, the block size
        # and the lengths of the BSR index arrays.
        prefix = "sparse_conv2d_bsr_%d_%d_%d_%d_%d_%d_" % (
            w_np.shape[0],
            w_np.shape[1],
            block_size[0],
            block_size[1],
            sparse_weight.indices.shape[0],
            sparse_weight.indptr.shape[0],
        )
        register_task_input_buffer(
            "default",
            prefix + "W_data",
            tvm.runtime.ndarray.array(sparse_weight_data),
            overwrite=True,
        )
        register_task_input_buffer(
            "default",
            prefix + "W_indices",
            tvm.runtime.ndarray.array(sparse_weight.indices),
            overwrite=True,
        )
        register_task_input_buffer(
            "default",
            prefix + "W_indptr",
            tvm.runtime.ndarray.array(sparse_weight.indptr),
            overwrite=True,
        )
    ret = SparseAnalysisResult(
        weight_name=tvm.runtime.convert(memo.weight_name),
        weight_shape=tvm.runtime.convert(memo.weight_shape),
    )
    return ret
def convert(func, params, blocksize, sparsity_threshold, layout="NHWC"):
    """Rewrite qualifying dense conv2d ops of ``func`` into block sparse ops.

    Parameters
    ----------
    func : relay.Expr
        Expr will be optimized to sparse operation
    params : Dict[String, tvm.nd.array]
        Parameters of the Expr
    blocksize : Tuple(int, int)
        Blocksize for BSR matrix
    sparsity_threshold : float
        Minimal sparsity requirement for converting.
        If weight sparsity is lower than this threshold,
        the dense operation will be kept.
    layout : str
        layout of network

    Returns
    -------
    new_func: relay.Expr
        Mutated Expr with sparse operations

    params: Dict[String, tvm.nd.array]
        New params with BSR matrix for mutated Expr
    """
    # Converts the weights first; the same ``params`` object is returned below.
    sparse_info = process_params(func, params, blocksize, sparsity_threshold, layout)
    to_sparse_pass = relay.transform.Conv2dToSparse(
        sparse_info.weight_name, sparse_info.weight_shape, layout
    )
    return _run_opt_pass(func, to_sparse_pass), params
def wrap_compute_sparse_conv2d(topi_compute):
    """Adapt a sparse conv2d topi compute to the (attrs, inputs, out_type) signature."""

    def _compute_sparse_conv2d(attrs, inputs, out_type):
        # Forward the first four inputs positionally, then the layout attribute.
        dense, w_data, w_indices, w_indptr = (inputs[pos] for pos in range(4))
        result = topi_compute(dense, w_data, w_indices, w_indptr, attrs["layout"])
        return [result]

    return _compute_sparse_conv2d
def _sparse_conv2d_bsr_compute_nhwc(data, weight_data, weight_indices, weight_indptr):
    """Build the TE compute for a 1x1 sparse conv2d on NHWC data with a BSR weight.

    ``data`` is indexed as ``[i, h, w, channel]``. ``weight_data`` is either
    2-D (``[_, bs_r]``, no block-column axis) or 3-D (``[_, bs_r, bs_c]``).
    The result has shape ``(m, h, w, num_blocks * bs_r)`` and carries the
    attribute ``layout="NHWC"``.
    """
    (m, h, w, k) = get_const_tuple(data.shape)  # pylint: disable=C0103
    # NOTE(review): any other weight_data rank leaves bs_r unbound and fails
    # later with a NameError.
    if len(weight_data.shape) == 2:
        _, bs_r = get_const_tuple(weight_data.shape)
    elif len(weight_data.shape) == 3:
        _, bs_r, bs_c = get_const_tuple(weight_data.shape)
    # Despite its name, num_blocks is len(indptr) - 1, i.e. the number of
    # indptr ranges (block rows), not the number of stored blocks.
    (num_blocks_plus_1,) = get_const_tuple(weight_indptr.shape)
    num_blocks = num_blocks_plus_1 - 1

    # NOTE(review): the argument names here and in the lambda below are
    # presumably used by TE as axis names; confirm before renaming them.
    def _compute_block(i, h, w, nb_j, j):  # pylint: disable=C0103
        # indptr[nb_j] .. indptr[nb_j + 1] is the range of blocks stored for
        # block row nb_j; elem_idx reduces over that range.
        row_start = weight_indptr[nb_j]
        row_end = weight_indptr[nb_j + 1]
        row_elems = row_end - row_start
        elem_idx = te.reduce_axis((0, row_elems), name="elem_idx")
        block_offset = row_start + elem_idx
        block_j = weight_indices[block_offset]
        if len(weight_data.shape) == 3:
            # 3-D weight: additionally reduce over the bs_c columns of the block.
            c = te.reduce_axis((0, bs_c), name="c")
            block_ij_val = weight_data[block_offset][j][c]
            x_val = data[i, h, w, bs_c * block_j + c]
            return te.sum(block_ij_val * x_val, axis=[elem_idx, c])
        else:
            # 2-D weight: block_j indexes the input channel directly.
            block_ij_val = weight_data[block_offset][j]
            x_val = data[i, h, w, block_j]
            return te.sum(block_ij_val * x_val, axis=[elem_idx])

    idxd = tvm.tir.indexdiv
    idxm = tvm.tir.indexmod

    # Stage 1: result per (block row, row within block).
    bsrmm_block = te.compute(
        (m, h, w, num_blocks, bs_r),
        _compute_block,
        tag="sparse_conv2d_sp_bsrmm_block",
        attrs={"FLOP": 2 * m * num_blocks * bs_r * k * h * w},
    )
    # Stage 2: flatten (block row, row within block) into one output channel axis.
    return te.compute(
        (m, h, w, num_blocks * bs_r),
        lambda m, h, w, n: bsrmm_block[m, h, w, idxd(n, bs_r), idxm(n, bs_r)],
        tag="sparse_conv2d_sp_bsrmm",
        name="sparse_conv2d",
        attrs={"layout": "NHWC"},
    )
@auto_scheduler.register_task_input_check_func
def try_get_conv2d_sparse_input(args):
    """Analyze the input data from the given args.

    Parameters
    ----------
    args : List[Tensor]
        Input/output Tensor of a TVM subgraph.

    Returns
    -------
    Dict[Tensor, str] :
        Map from the input Tensor to its buffer name. Empty when no sparse
        conv2d is found or when the analysis fails for any reason.

    Notes
    -----
    The buffer name is specially designed, and these buffer should be provided in
    `SearchTask(..., task_inputs={...})`.
    """
    sparse_prefix = sparse_data = sparse_indices = sparse_indptr = None

    def _process_inputs(input_tensors, m, h, w, n, prefix_init, layout):  # pylint: disable=C0103
        nonlocal sparse_prefix
        nonlocal sparse_data
        nonlocal sparse_indices
        nonlocal sparse_indptr

        assert len(input_tensors) == 4
        unsure_tensors = list(input_tensors)
        # Get the Dense data
        dense_data = None
        for tensor in unsure_tensors:
            if len(tensor.shape) == 4:
                assert dense_data is None
                dense_data = tensor
                if layout == "NHWC":
                    assert m == dense_data.shape[0]
                    assert h == dense_data.shape[1]
                    assert w == dense_data.shape[2]
                    k = dense_data.shape[3]
                elif layout == "NCHW":
                    assert m == dense_data.shape[0]
                    assert h == dense_data.shape[2]
                    assert w == dense_data.shape[3]
                    k = dense_data.shape[1]
        unsure_tensors.remove(dense_data)
        # Get the Sparse data
        sparse_data = None
        for tensor in unsure_tensors:
            if len(tensor.shape) == 3:
                assert sparse_data is None
                sparse_data = tensor
                block_size, bs_r, bs_c = sparse_data.shape
            if len(tensor.shape) == 2:
                assert sparse_data is None
                sparse_data = tensor
                block_size, bs_r = sparse_data.shape
                bs_c = 1
        unsure_tensors.remove(sparse_data)
        # Get the Sparse indptr & indices
        sparse_indices = None
        for tensor in unsure_tensors:
            assert len(tensor.shape) == 1
            if tensor.shape[0] == block_size:
                assert sparse_indices is None
                sparse_indices = tensor
        unsure_tensors.remove(sparse_indices)
        assert len(unsure_tensors) == 1
        sparse_indptr = unsure_tensors[0]
        # Generate the sparse_prefix
        # NOTE(review): density does not feed into the prefix; it is kept so
        # that a failure here still aborts the match as before. Confirm
        # whether it can be dropped.
        density = 1.0
        for i in sparse_data.shape:
            density *= i
        density /= k * n
        density = density.value
        sparse_prefix = "%s_%d_%d_%d_%d_%d_%d_" % (
            prefix_init,
            n,
            k,
            bs_r,
            bs_c,
            sparse_indices.shape[0],
            sparse_indptr.shape[0],
        )

    visited = set()

    def _traverse(t):
        # We cannot directly add tensors to the set, because the comparison of
        # two tensors with ndim=0 is ambiguous.
        assert t.handle is not None
        if t.handle.value in visited:
            return

        if isinstance(t.op, te.ComputeOp):
            if t.op.tag == "sparse_conv2d_sp_bsrmm":
                layout = t.op.attrs["layout"]
                # The output shape follows the data layout. Unpacking an NCHW
                # output as (m, h, w, n) would hand the wrong extents to
                # _process_inputs and the workload would never be matched.
                if layout == "NHWC":
                    m, h, w, n = t.shape  # pylint: disable=C0103
                elif layout == "NCHW":
                    m, n, h, w = t.shape  # pylint: disable=C0103
                else:
                    raise ValueError("Unsupport Layout %s" % layout)
                assert len(t.op.input_tensors) == 1
                block_tensor = t.op.input_tensors[0]
                _process_inputs(
                    block_tensor.op.input_tensors,
                    m,
                    h,
                    w,
                    n,
                    "sparse_conv2d_bsr",
                    layout,
                )
            if sparse_prefix is not None:
                # Early stop if we find a sparse_prefix
                # Notice: If any workload has more than one sparse input, this may get problem
                return
            for x in t.op.input_tensors:
                _traverse(x)
        visited.add(t.handle.value)

    # Best effort by design: any failure means "no sparse inputs found".
    try:
        for arg in args:
            _traverse(arg)
    # pylint: disable=broad-except
    except Exception:
        return {}

    if sparse_data is None or sparse_indices is None or sparse_indptr is None:
        return {}

    sparse_input_map = {}
    sparse_input_map[sparse_data] = sparse_prefix + "W_data"
    sparse_input_map[sparse_indices] = sparse_prefix + "W_indices"
    sparse_input_map[sparse_indptr] = sparse_prefix + "W_indptr"

    return sparse_input_map
def random_sparse_conv2d_params(func, params, bs_r, bs_c, density, layout):
    """Replace the dense parameters with random sparse parameters. Mainly used for testing.

    Only 1x1 conv2d weights whose channel extents are divisible by the block
    size are replaced; every other parameter is copied unchanged.

    Parameters
    ----------
    func : tvm.relay.Expr
        Expr will be optimized to sparse operation.
    params : Dict[String, tvm.nd.array]
        Parameters of the Expr.
    bs_r : int
        The row of BSR matrix block.
    bs_c : int
        The column of BSR matrix block.
    density : float
        The density of the random sparse parameters.
    layout : str
        layout of network

    Returns
    -------
    Dict[String, tvm.nd.array]
        The generated random parameters.
    """
    # pylint: disable=import-outside-toplevel
    import numpy as np

    # Copy every array so the caller's params are left untouched.
    new_params = {key: tvm.nd.array(value.asnumpy()) for key, value in params.items()}
    conv2d_weight_names = relay.analysis.sparse_conv2d._search_conv2d_op_weight(func)
    for item in conv2d_weight_names:
        name = str(item)
        shape = new_params[name].shape
        if not ((shape[0] == 1 and shape[1] == 1) or (shape[2] == 1 and shape[3] == 1)):
            continue
        if layout == "NCHW" and shape[0] % bs_r == 0 and shape[1] % bs_c == 0:
            new_w = random_bsr_matrix(shape[0], shape[1], bs_r, bs_c, density, "float32").todense()
            new_params[name] = tvm.nd.array(np.array(new_w).reshape(shape))
        elif layout == "NHWC" and shape[3] % bs_r == 0 and shape[2] % bs_c == 0:
            new_w = random_bsr_matrix(shape[3], shape[2], bs_r, bs_c, density, "float32").todense()
            # new_w is (shape[3], shape[2]) while the stored weight is
            # (..., shape[2], shape[3]). Transpose before reshaping so the
            # block structure is not scrambled by reinterpreting the buffer.
            new_params[name] = tvm.nd.array(np.array(new_w).T.reshape(shape))
    return new_params
def sparse_conv2d_apply_func(search_policy, state, stage_id):
    """Describe how to generate the initial sketch for sparse conv2d

    Parameters
    ----------
    search_policy : the search policy that invokes this rule
    state : the state object to derive the sketch from
    stage_id : int
        Index of the stage the rule is applied to.

    Returns
    -------
    List[List[state_object, int]]
        ``[state, next_stage_id]`` pairs. Both exits return the same shape.
    """
    ret = []
    s_0 = auto_scheduler.loop_state.State(state, search_policy.search_task.compute_dag)
    if s_0.stages[stage_id].op.tag == "sparse_conv2d_sp_bsrmm_block":
        # Nothing to do on the block stage itself. Wrap the pair in a list so
        # this exit matches the shape of the normal return below.
        return [[s_0.state_object, stage_id - 1]]

    sparse_conv2d = s_0.stages[stage_id].op
    sparse_conv2d_block = s_0.stages[stage_id - 1].op
    assert sparse_conv2d.tag == "sparse_conv2d_sp_bsrmm"
    assert sparse_conv2d_block.tag == "sparse_conv2d_sp_bsrmm_block"
    layout = sparse_conv2d.attrs["layout"]

    # Set the default consumer of compute block
    consumer = sparse_conv2d

    # If sparse conv2d has a single elementwise consumer
    # We can compute inline the sparse_conv2d output stage
    consumers = _ffi_api.SearchPolicyUtilsGetConsumers(
        search_policy.search_task, s_0.state_object, stage_id
    )
    if len(consumers) == 1:
        consumer_id = int(consumers.items()[0][0])
        if _ffi_api.SearchPolicyUtilsIsElementwiseMatch(
            search_policy.search_task, s_0.state_object, stage_id, consumer_id
        ):
            consumer = s_0.stages[consumer_id].op
            s_0.compute_inline(sparse_conv2d)

    # c stays None when the block stage has 6 iterators (no block-column axis).
    c = None
    if layout == "NHWC":
        if len(s_0[sparse_conv2d_block].iters) == 6:
            # bs_c = 1
            i, h, w, nb_j, j, row_offset = s_0[  # pylint: disable=invalid-name
                sparse_conv2d_block
            ].iters
        else:
            i, h, w, nb_j, j, row_offset, c = s_0[  # pylint: disable=invalid-name
                sparse_conv2d_block
            ].iters
        m, x, y, n = s_0[consumer].iters
    elif layout == "NCHW":
        if len(s_0[sparse_conv2d_block].iters) == 6:
            # bs_c = 1
            i, nb_j, j, h, w, row_offset = s_0[  # pylint: disable=invalid-name
                sparse_conv2d_block
            ].iters
        else:
            i, nb_j, j, h, w, row_offset, c = s_0[  # pylint: disable=invalid-name
                sparse_conv2d_block
            ].iters
        m, n, x, y = s_0[consumer].iters

    # Each split of the block stage is immediately followed by a follow_split
    # of the consumer that refers to it as the latest transform step.
    i_0, i_1, i_2 = s_0.split(sparse_conv2d_block, i, [None, None])
    m_0, m_1 = s_0.follow_split(consumer, m, len(s_0.transform_steps) - 1, 1)
    h_0, h_1, h_2 = s_0.split(sparse_conv2d_block, h, [None, None])
    x_0, x_1 = s_0.follow_split(consumer, x, len(s_0.transform_steps) - 1, 1)
    w_0, w_1, w_2 = s_0.split(sparse_conv2d_block, w, [None, None])  # pylint: disable=invalid-name
    y_0, y_1 = s_0.follow_split(consumer, y, len(s_0.transform_steps) - 1, 1)
    j_0, j_1 = s_0.split(sparse_conv2d_block, nb_j, [None])
    n_0, n_1 = s_0.follow_split(consumer, n, len(s_0.transform_steps) - 1, 1)
    if layout == "NHWC":
        if c is None:
            s_0.reorder(
                sparse_conv2d_block,
                [i_0, h_0, w_0, j_0, i_1, h_1, w_1, j_1, row_offset, i_2, h_2, w_2, j],
            )
        else:
            s_0.reorder(
                sparse_conv2d_block,
                [i_0, h_0, w_0, j_0, i_1, h_1, w_1, j_1, row_offset, i_2, h_2, w_2, j, c],
            )
        s_0.reorder(consumer, [m_0, x_0, y_0, n_0, m_1, x_1, y_1, n_1])
    elif layout == "NCHW":
        if c is None:
            s_0.reorder(
                sparse_conv2d_block,
                [i_0, j_0, h_0, w_0, i_1, j_1, h_1, w_1, row_offset, i_2, j, h_2, w_2],
            )
        else:
            s_0.reorder(
                sparse_conv2d_block,
                [i_0, j_0, h_0, w_0, i_1, j_1, h_1, w_1, row_offset, i_2, j, c, h_2, w_2],
            )
        s_0.reorder(consumer, [m_0, n_0, x_0, y_0, m_1, n_1, x_1, y_1])
    s_0.compute_at(sparse_conv2d_block, consumer, n_0)

    # Continue two stages below: the block stage has been handled here.
    ret.append([s_0.state_object, stage_id - 2])

    return ret
// Type relation for nn.sparse_conv2d.
// types = {data, weight_data, weight_indices, weight_indptr, output}.
// Assigns the output type from the data shape, the indptr length and
// weight_data->shape[1], ordered according to the layout attribute.
// Returns false while the data type is not yet resolved.
bool SparseConv2dRel(const Array& types, int num_inputs, const Attrs& attrs,
                     const TypeReporter& reporter) {
  ICHECK_EQ(types.size(), 5);
  const auto* param = attrs.as();
  ICHECK(param != nullptr);

  const auto* data = types[0].as();
  const auto* weight_data = types[1].as();
  // NOTE(review): weight_data is dereferenced here before any null check, and
  // weight_indptr below is never null-checked; only data is.
  // NOTE(review): a 1-D weight passes this ICHECK but then falls through to
  // the LOG(FATAL) at the end.
  ICHECK(weight_data->shape.size() == 1 || weight_data->shape.size() == 2 ||
         weight_data->shape.size() == 3);
  const auto* weight_indptr = types[3].as();
  if (data == nullptr) return false;

  if (weight_data->shape.size() == 2 || weight_data->shape.size() == 3) {
    // BSR case.
    // Output channels = (len(indptr) - 1) * weight_data->shape[1].
    if (param->layout == "NHWC") {
      Array oshape({data->shape[0], data->shape[1], data->shape[2],
                    (weight_indptr->shape[0] - 1) * weight_data->shape[1]});
      reporter->Assign(types[4], TensorType(oshape, data->dtype));
      return true;
    } else if (param->layout == "NCHW") {
      Array oshape({data->shape[0],
                    (weight_indptr->shape[0] - 1) * weight_data->shape[1],
                    data->shape[2], data->shape[3]});
      reporter->Assign(types[4], TensorType(oshape, data->dtype));
      return true;
    }
  }
  // NOTE(review): an unsupported layout with a valid 2-D/3-D weight also ends
  // up here, so the "Unknown weight ndim" message can be misleading.
  LOG(FATAL) << "Unknown weight ndim " << weight_data->shape.size()
             << " for nn.sparse_conv2d, should be 2 or 3 (BSR)";
  return false;
}
+ +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(4) + .add_argument("dense_data", "4D Tensor", "Input dense data.") + .add_argument("sparse_data", "2D or 3D Tensor", "Sparse data matrix.") + .add_argument("sparse_indices", "1D Tensor", "Sparse indices matrix.") + .add_argument("sparse_indptr", "1D Tensor", "Sparse indptr matrix.") + .set_support_level(1) + .add_type_rel("SparseConv2d", SparseConv2dRel); + } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/convert_sparse_conv2d.cc b/src/relay/transforms/convert_sparse_conv2d.cc new file mode 100644 index 000000000000..6e4c03b0fcbc --- /dev/null +++ b/src/relay/transforms/convert_sparse_conv2d.cc @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*!
+ * + * \file convert_sparse_conv2d.cc + * + * \brief Mutate conv2d operator to sparse conv2d operator + */ +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace relay { + +// Collect the name hint of the weight (args[1]) of every nn.conv2d call visited in an Expr +class Conv2dOpWeightVisitor : private ExprVisitor { + public: + Conv2dOpWeightVisitor() : conv2d_op_(Op::Get("nn.conv2d")) {} + + Array Search(const Expr& expr) { + VisitExpr(expr); + return memo_; + } + + private: + void VisitExpr_(const CallNode* n) final { + if (n->op == conv2d_op_) { + const auto weight = n->args[1].as(); + if (weight) { + memo_.push_back(weight->name_hint()); + } + } + for (const auto& arg : n->args) { + VisitExpr(arg); + } + } + // Cached op + const Op& conv2d_op_; + + Array memo_; +}; // class Conv2dOpWeightVisitor + +Array SearchConv2dOpWeight(const Expr& e) { return Conv2dOpWeightVisitor().Search(e); } + +TVM_REGISTER_GLOBAL("relay.analysis.search_conv2d_op_weight").set_body_typed(SearchConv2dOpWeight); + +// Mutate ``nn.conv2d`` to ``nn.sparse_conv2d`` for the weights listed in weight_name +class Conv2dToSparseConv2dMutator : public ExprRewriter { + public: + Conv2dToSparseConv2dMutator(const Array& weight_name, + const Array>& weight_shape, const String& layout) + : conv2d_op_(Op::Get("nn.conv2d")), sparse_conv2d_op_(Op::Get("nn.sparse_conv2d")) { + ICHECK_EQ(weight_name.size(), weight_shape.size()); + layout_ = layout; + for (size_t i = 0; i < weight_name.size(); ++i) { + ICHECK(weight_name[i]->IsInstance()); + std::string k = weight_name[i].as()->data; + const auto& ws = weight_shape[i]; + std::vector v(ws.size()); + for (size_t j = 0; j < ws.size(); ++j) { + v[j] = ws[j].as()->value; + } + target_weights_.emplace(k, v); + } + } + + Expr Rewrite_(const CallNode* pre, const Expr& post) override { + if (pre->op == conv2d_op_) { + const auto weight = pre->args[1].as(); + if (weight) { + if (target_weights_.count(weight->name_hint())) { + const auto& prefix = weight->name_hint(); + const auto& ws =
target_weights_.at(prefix); + const auto data = post.as()->args[0]; + relay::TensorType ws_data_type, ws_indices_type, ws_indptr_type; + if (ws.size() == 5) { + ws_data_type = relay::TensorType({ws.at(0), ws.at(1), ws.at(2)}, DataType::Float(32)); + ws_indices_type = relay::TensorType({ws.at(3)}, DataType::Int(32)); + ws_indptr_type = relay::TensorType({ws.at(4)}, DataType::Int(32)); + } else if (ws.size() == 4) { + ws_data_type = relay::TensorType({ws.at(0), ws.at(1)}, DataType::Float(32)); + ws_indices_type = relay::TensorType({ws.at(2)}, DataType::Int(32)); + ws_indptr_type = relay::TensorType({ws.at(3)}, DataType::Int(32)); + } + Var weight_data(prefix + ".data", ws_data_type); + Var weight_indices(prefix + ".indices", ws_indices_type); + Var weight_indptr(prefix + ".indptr", ws_indptr_type); + auto attrs = make_object(); + attrs->layout = layout_; + return Call(sparse_conv2d_op_, {data, weight_data, weight_indices, weight_indptr}, + Attrs(attrs)); + } + } + } + return post; + } + + private: + // Cached op + const Op& conv2d_op_; + const Op& sparse_conv2d_op_; + std::unordered_map> target_weights_; + String layout_; +}; // class Conv2dToSparseConv2dMutator + +Expr Conv2dToSparse(const Expr& e, const Array& weight_name, + const Array>& weight_shape, const String& layout) { + auto rewriter = Conv2dToSparseConv2dMutator(weight_name, weight_shape, layout); + return PostOrderRewrite(e, &rewriter); +} + +namespace transform { + +Pass Conv2dToSparse(const Array& weight_name, const Array>& weight_shape, + const String& layout) { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + // Remove FreeVar warnings + auto f0 = Downcast(Conv2dToSparse(f, weight_name, weight_shape, layout)); + Array sparse_params = FreeVars(f0); + auto f1 = Function(sparse_params, f0->body, f0->ret_type, f0->type_params, f0->attrs); + Array params = FreeVars(f1); + for (const auto& var : sparse_params) { + params.push_back(var); + } + return
Function(params, f1->body, f1->ret_type, f1->type_params, f1->attrs); + }; + return CreateFunctionPass(pass_func, 4, "Conv2dToSparse", {"DeadCodeElimination"}); +} + +TVM_REGISTER_GLOBAL("relay._transform.Conv2dToSparse").set_body_typed(Conv2dToSparse); + +} // namespace transform + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_sparse_conv2d_convert.py b/tests/python/relay/test_sparse_conv2d_convert.py new file mode 100644 index 000000000000..671693cc5827 --- /dev/null +++ b/tests/python/relay/test_sparse_conv2d_convert.py @@ -0,0 +1,105 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import itertools + +import numpy as np +import scipy.sparse as sp + + +import tvm +from tvm.ir import IRModule +from tvm import relay +from tvm.topi.sparse.utils import random_bsr_matrix + + +def run_func(func, params, x): + with tvm.transform.PassContext(opt_level=3): + graph, lib, new_params = relay.build(func, "llvm", params=params) + + from tvm.contrib import graph_executor + + dev = tvm.cpu(0) + dtype = "float32" + m = graph_executor.create(graph, lib, dev) + # set inputs + m.set_input("data", tvm.nd.array(x.astype(dtype))) + m.set_input(**new_params) + # execute + m.run() + # get outputs + tvm_output = m.get_output(0) + return tvm_output.asnumpy() + + +def test_bsr_sparse_conv2d_nchw(): + data = relay.var("data", shape=(1, 64, 32, 32), dtype="float32") + x = relay.nn.relu(data) + w = relay.var("weight", shape=(128, 64, 1, 1), dtype="float32") + y = relay.nn.conv2d(x, w, channels=128, kernel_size=1, data_layout="NCHW", kernel_layout="OIHW") + z = relay.nn.relu(y) + func = relay.Function(relay.analysis.free_vars(z), z) + + params = { + "weight": tvm.nd.array( + np.array(random_bsr_matrix(128, 64, 8, 1, 0.1, "float32").todense()).reshape( + 128, 64, 1, 1 + ) + ) + } + + x_np = np.random.randn(1, 64, 32, 32).astype("float32") + # dense output + dense_output = run_func(func, params, x_np) + # sparse + sparse_func, params = relay.data_dep_optimization.bsr_conv2d.convert( + func, params, (8, 1), 0.2, "NCHW" + ) + sparse_output = run_func(sparse_func, params, x_np) + np.testing.assert_allclose(sparse_output, dense_output, atol=1e-5, rtol=1e-5) + + +def test_bsr_sparse_conv2d_nhwc(): + data = relay.var("data", shape=(1, 32, 32, 64), dtype="float32") + x = relay.nn.relu(data) + w = relay.var("weight", shape=(1, 1, 64, 128), dtype="float32") + y = relay.nn.conv2d(x, w, channels=128, kernel_size=1, data_layout="NHWC", kernel_layout="HWIO") + z = relay.nn.relu(y) + func = relay.Function(relay.analysis.free_vars(z), z) + + params = { + "weight": tvm.nd.array( + 
np.array(random_bsr_matrix(128, 64, 8, 1, 0.1, "float32").todense()).T.reshape( + 1, 1, 64, 128 + ) + ) + } + + x_np = np.random.randn(1, 32, 32, 64).astype("float32") + # dense output + dense_output = run_func(func, params, x_np) + # sparse + sparse_func, params = relay.data_dep_optimization.bsr_conv2d.convert( + func, params, (8, 1), 0.2, "NHWC" + ) + sparse_output = run_func(sparse_func, params, x_np) + np.testing.assert_allclose(sparse_output, dense_output, atol=1e-5, rtol=1e-5) + + +if __name__ == "__main__": + test_bsr_sparse_conv2d_nhwc() + test_bsr_sparse_conv2d_nchw() diff --git a/tests/python/topi/python/test_topi_sparse.py b/tests/python/topi/python/test_topi_sparse.py index 98a3ec86180c..a2aa8fdd9805 100644 --- a/tests/python/topi/python/test_topi_sparse.py +++ b/tests/python/topi/python/test_topi_sparse.py @@ -552,15 +552,69 @@ def test_sparse_add_csr(): tvm.testing.assert_allclose(Z_tvm.asnumpy(), Z_np, atol=1e-4, rtol=1e-4) +def verify_sparse_conv2d_bsr(M, H, W, N, K, BS_R, BS_C, density, layout): + if layout == "NHWC": + X_np = np.random.randn(M, H, W, K).astype("float32") + elif layout == "NCHW": + X_np = np.random.randn(M, K, H, W).astype("float32") + W_sp_np = random_bsr_matrix(N, K, BS_R, BS_C, density=density, dtype="float32") + W_np = W_sp_np.todense() + if layout == "NHWC": + Y_np = tvm.topi.testing.conv2d_nhwc_python(X_np, np.array(W_np).T.reshape(1, 1, K, N), 1, 0) + elif layout == "NCHW": + Y_np = tvm.topi.testing.conv2d_nchw_python(X_np, np.array(W_np).reshape(N, K, 1, 1), 1, 0) + + if BS_C == 1: + W_data = te.placeholder(shape=W_sp_np.data.shape[:-1], dtype=str(W_sp_np.data.dtype)) + W_sp_np_data = W_sp_np.data.reshape(W_sp_np.data.shape[0], BS_R) + else: + W_data = te.placeholder(shape=W_sp_np.data.shape, dtype=str(W_sp_np.data.dtype)) + W_sp_np_data = W_sp_np.data + W_indices = te.placeholder(shape=W_sp_np.indices.shape, dtype=str(W_sp_np.indices.dtype)) + W_indptr = te.placeholder(shape=W_sp_np.indptr.shape, 
dtype=str(W_sp_np.indptr.dtype)) + X = te.placeholder(shape=X_np.shape, dtype=str(X_np.dtype)) + + Y = topi.nn.sparse_conv2d(X, W_data, W_indices, W_indptr, layout) + s = te.create_schedule(Y.op) + + def check_device(device): + dev = tvm.device(device, 0) + if not tvm.testing.device_enabled(device): + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + + func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y]) + Y_tvm = tvm.nd.array(np.zeros(Y_np.shape, dtype="float32")) + func( + tvm.nd.array(X_np, dev), + tvm.nd.array(W_sp_np_data, dev), + tvm.nd.array(W_sp_np.indices, dev), + tvm.nd.array(W_sp_np.indptr, dev), + Y_tvm, + ) + tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np.astype("float32"), atol=1e-4, rtol=1e-4) + + check_device("llvm") + + +def test_sparse_conv2d_bsr(): + M, H, W, N, K, BS_R, BS_C, density = 1, 32, 32, 128, 64, 8, 16, 0.9 + verify_sparse_conv2d_bsr(M, H, W, N, K, BS_R, BS_C, density, "NHWC") + verify_sparse_conv2d_bsr(M, H, W, N, K, BS_R, BS_C, density, "NCHW") + verify_sparse_conv2d_bsr(M, H, W, N, K, BS_R, 1, density, "NHWC") + + if __name__ == "__main__": - test_csrmv() - test_csrmm() - test_dense() - test_sparse_dense_csr() - test_sparse_dense_bsr_randomized() - test_sparse_transpose_csr() - test_sparse_dense_padded_cuda() - test_sparse_dense_padded_alter_op() - test_sparse_dense_csr_reverse() - test_sparse_dense_bsr_reverse() - test_sparse_add_csr() + test_csrmv() + test_csrmm() + test_dense() + test_sparse_dense_csr() + test_sparse_dense_bsr_randomized() + test_sparse_transpose_csr() + test_sparse_dense_padded_cuda() + test_sparse_dense_padded_alter_op() + test_sparse_dense_csr_reverse() + test_sparse_dense_bsr_reverse() + test_sparse_add_csr() + test_sparse_conv2d_bsr() diff --git a/tutorials/auto_scheduler/tune_network_x86.py b/tutorials/auto_scheduler/tune_network_x86.py index 55e9e4e803fe..76068fa79605 100644 --- a/tutorials/auto_scheduler/tune_network_x86.py +++
b/tutorials/auto_scheduler/tune_network_x86.py @@ -138,7 +138,7 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32", use_sparse=Fal if use_sparse: from tvm.topi.sparse.utils import convert_model_dense_to_sparse - mod, params = convert_model_dense_to_sparse(mod, params, random_params=True) + mod, params = convert_model_dense_to_sparse(mod, params, bs_r=4, random_params=True) return mod, params, input_shape, output_shape @@ -168,7 +168,11 @@ def get_network(name, batch_size, layout="NHWC", dtype="float32", use_sparse=Fal # Extract tasks from the network print("Get model...") mod, params, input_shape, output_shape = get_network( - network, batch_size, layout, dtype=dtype, use_sparse=use_sparse + network, + batch_size, + layout, + dtype=dtype, + use_sparse=use_sparse, ) print("Extract tasks...") tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], params, target) @@ -205,7 +209,21 @@ def run_tuning(): measure_callbacks=[auto_scheduler.RecordToFile(log_file)], ) - tuner.tune(tune_option) + if use_sparse: + from tvm.topi.sparse.utils import sparse_sketch_rules + + search_policy = [ + auto_scheduler.SketchPolicy( + task, + program_cost_model=auto_scheduler.XGBModel(), + init_search_callbacks=sparse_sketch_rules(), + ) + for task in tasks + ] + + tuner.tune(tune_option, search_policy=search_policy) + else: + tuner.tune(tune_option) # We do not run the tuning in our webpage server since it takes too long.