Skip to content

Commit

Permalink
We observe multiple groups across a range of domains (ASR, NMT, LM, e…
Browse files Browse the repository at this point in the history
…tc), (apache#3566)

internally and externally, interested in replacing standard dense layers with
block-sparse matrix multiplication layers. The motivations are generally: higher
performance (due to reduction in FLOPs, memory bandwidth/cache footprint),
enabling larger models (e.g. fitting more layers in a given memory budget).

Some public work along these lines:

* https://openai.com/blog/block-sparse-gpu-kernels/
* https://openai.com/blog/sparse-transformer/
* https://arxiv.org/abs/1802.08435
* https://arxiv.org/abs/1711.02782

Various groups have been able to successfully train models with reasonable
levels of sparsity (90%+) with marginal accuracy changes, which suggests
substantial speedups are possible (as this implies a >10x reduction in FLOPs).

It is fairly straightforward to realize these theoretical speedups, see e.g. TVM
benchmarks for Intel CPUs in
https://gist.github.com/ajtulloch/e65f90487bceb8848128e8db582fe902, and CUDA
results in https://github.com/openai/blocksparse, etc.

* https://github.com/openai/blocksparse (CUDA)
* https://software.intel.com/en-us/mkl-developer-reference-c-mkl-bsrmm (MKL BSRM)
* https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.sparse.bsr_matrix.html (SCIPY BSR representation)

This is extracted from an internal patch we've been using internally. There are
various extensions possible (int8/fp16/bf16, CUDA/other GPU architectures), but
this is a reasonable starting point. This needs more thorough unit test coverage
however.

We follow the conventions established by scipy.sparse.bsr_matrix and other
libraries, see the unit tests for details.

For folks interested in experimenting with scheduling/AutoTVM etc,
https://gist.github.com/ajtulloch/e65f90487bceb8848128e8db582fe902 is a useful
starting point.
  • Loading branch information
ajtulloch authored and tqchen committed Jul 23, 2019
1 parent 2ed31b2 commit d6dcd6c
Show file tree
Hide file tree
Showing 9 changed files with 430 additions and 0 deletions.
4 changes: 4 additions & 0 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,10 @@ struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
}
};

/*! \brief Attributes for sparse_dense operator */
struct SparseDenseAttrs : public tvm::AttrsNode<SparseDenseAttrs> {
TVM_DECLARE_ATTRS(SparseDenseAttrs, "relay.attrs.SparseDenseAttrs") {}
};

/*! \brief Attributes for upsampling operator */
struct UpSamplingAttrs : public tvm::AttrsNode<UpSamplingAttrs> {
Expand Down
13 changes: 13 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,19 @@ def schedule_batch_matmul(attrs, outputs, target):

reg.register_pattern("nn.batch_matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE)

# sparse_dense
@reg.register_compute("nn.sparse_dense")
def compute_sparse_dense(attrs, inputs, out_type, target):
"""Compute definition of sparse_dense"""
return [topi.nn.sparse_dense(inputs[0], inputs[1], inputs[2], inputs[3])]

@reg.register_schedule("nn.sparse_dense")
def schedule_sparse_dense(attrs, outputs, target):
"""Schedule definition of batch_matmul"""
with target:
return topi.generic.schedule_sparse_dense(outputs)

reg.register_pattern("nn.sparse_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)

# conv2d
def _find_conv2d_op(op):
Expand Down
33 changes: 33 additions & 0 deletions python/tvm/relay/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,39 @@ def batch_matmul(x, y):
"""
return _make.batch_matmul(x, y)

def sparse_dense(data, weight):
r"""
Computes the matrix multiplication of `data` and `weight`, where `data` is
a dense matrix and `weight` is a sparse (either BSR or CSR) namedtuple with
fields `data`, `indices`, and `indptr`.
.. math::
\mbox{sparse_dense}(data, weight)[m, n] = \mbox{matmul}(x, \mbox{as_dense}(weight)^T)[m, n]
where `as_dense` returns dense equivalent of the given sparse matrix.
See
https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.html
and
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.sparse.bsr_matrix.html
for more detail on the sparse matrix representation.
Parameters
----------
data : tvm.relay.Expr
The input data for the matrix multiplication
weight : namedtuple.
The sparse weight matrix for the matrix multiplication.
Returns
-------
result: tvm.relay.Expr
The computed result.
"""
return _make.sparse_dense(data, weight.data, weight.indices, weight.indptr)


def contrib_conv2d_winograd_without_weight_transform(data,
weight,
Expand Down
97 changes: 97 additions & 0 deletions src/relay/op/nn/sparse.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2018 by Contributors
* \file sparse.cc
* \brief Property def of nn.sparse_dense operator.
*/

#include <tvm/data_layout.h>
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
#include <vector>

#include "../../pass/alter_op_layout.h"

namespace tvm {
namespace relay {

// relay.nn.sparse_dense
TVM_REGISTER_NODE_TYPE(SparseDenseAttrs);

bool SparseDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 5);
const auto* data = types[0].as<TensorTypeNode>();
const auto* weight_data = types[1].as<TensorTypeNode>();
CHECK(weight_data->shape.size() == 1 || weight_data->shape.size() == 3);
const auto* weight_indptr = types[3].as<TensorTypeNode>();
if (data == nullptr) return false;

if (weight_data->shape.size() == 1) {
// CSR case.
Array<IndexExpr> oshape({data->shape[0], weight_indptr->shape[0] - 1});
reporter->Assign(types[4], TensorTypeNode::make(oshape, data->dtype));
return true;
}

if (weight_data->shape.size() == 3) {
// BSR case.
Array<IndexExpr> oshape({
data->shape[0],
(weight_indptr->shape[0] - 1) * weight_data->shape[1]});
reporter->Assign(types[4], TensorTypeNode::make(oshape, data->dtype));
return true;
}
LOG(FATAL) << "Unknown weight ndim for nn.sparse_dense, should be 1 (CSR) or 3 (BSR)";
return false;
}

// Positional relay function to create dense operator used by frontend FFI.
Expr MakeSparseDense(Expr data, Expr weight_data, Expr weight_indices, Expr weight_indptr) {
auto attrs = make_node<SparseDenseAttrs>();
static const Op& op = Op::Get("nn.sparse_dense");
return CallNode::make(op, {data, weight_data, weight_indices, weight_indptr}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op.nn._make.sparse_dense")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 4>(MakeSparseDense, args, rv);
});

RELAY_REGISTER_OP("nn.sparse_dense")
.describe(R"code(Applies a sparse linear transformation: :math:`Y = XW^T` with X sparse.
- **data**: `(x1, x2, ..., xn, input_dim)`
- **weight**: `(units, input_dim)`
- **out**: `(x1, x2, ..., xn, units)`.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.SparseDenseAttrs")
.set_num_inputs(4)
.add_argument("data", "nD Tensor", "Input data.")
.add_argument("weight_data", "1D Tensor", "Weight data matrix.")
.add_argument("weight_indices", "1D Tensor", "Weight indices matrix.")
.add_argument("weight_indptr", "1D Tensor", "Weight indptr matrix.")
.set_support_level(1)
.add_type_rel("SparseDense", SparseDenseRel);

} // namespace relay
} // namespace tvm
17 changes: 17 additions & 0 deletions topi/python/topi/generic/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,23 @@ def schedule_l2_normalize(outs):
cpp_target = cpp.TEST_create_target(target.target_name)
return cpp.generic.default_schedule(cpp_target, outs, False)

@tvm.target.generic_func
def schedule_sparse_dense(outs):
"""Schedule for sparse_dense
Parameters
----------
outs: Array of Tensor
The computation graph description of sparse_dense
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)

@tvm.target.generic_func
def schedule_batch_matmul(outs):
target = tvm.target.current_target(allow_none=False)
Expand Down
1 change: 1 addition & 0 deletions topi/python/topi/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@
from .bitserial_dense import *
from .l2_normalize import *
from .batch_matmul import *
from .sparse import *
103 changes: 103 additions & 0 deletions topi/python/topi/nn/sparse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Sparse operators"""
from __future__ import absolute_import
import tvm

from ..util import get_const_tuple


@tvm.target.generic_func
def sparse_dense(data, weight_data, weight_indices, weight_indptr):
"""
Computes sparse-dense matrix multiplication of `data` and
`(weight_data, weight_indices, weight_indptr).T`
Parameters
----------
x : tvm.Tensor
2-D with shape [M, K], float32
weight_data : tvm.Tensor
1-D with shape [nnz] (CSR) or
3-D with shape [num_blocks, bs_r, bs_c] (BSR)
weight_indices : tvm.Tensor
1-D with shape [nnz] (CSR) or
1-D with shape [num_blocks] (BSR)
weight_indptr : tvm.Tensor
1-D with shape [N + 1] (CSR) or
1-D with shape [(N + 1) // bs_r] (BSR)
Returns
-------
output : tvm.Tensor
2-D with shape [M, N]
"""
assert len(weight_data.shape) in (1, 3)
if len(weight_data.shape) == 1:
func = _sparse_dense_csrmm
if len(weight_data.shape) == 3:
func = _sparse_dense_bsrmm
return func(data, weight_data, weight_indices, weight_indptr)


def _sparse_dense_csrmm(data, weight_data, weight_indices, weight_indptr):
oshape = (
get_const_tuple(data.shape)[0],
get_const_tuple(weight_indptr.shape)[0] - 1)

def f(i, row):
row_start = weight_indptr[row]
row_end = weight_indptr[row + 1]
row_elems = row_end - row_start
elem_idx = tvm.reduce_axis((0, row_elems), name="elem_idx")
elem = row_start + elem_idx
a_val = weight_data[elem]
weight_val = data[i, weight_indices[elem]]
return tvm.sum(a_val * weight_val, axis=elem_idx)
return tvm.compute(oshape, f, tag="sparse_dense_csrmm")


def _sparse_dense_bsrmm(data, weight_data, weight_indices, weight_indptr):
(m, _) = get_const_tuple(data.shape)
(_, bs_r, bs_c) = get_const_tuple(weight_data.shape)
(num_blocks_plus_1, ) = get_const_tuple(weight_indptr.shape)
num_blocks = num_blocks_plus_1 - 1

def _compute_block(i, nb_j, j):
row_start = weight_indptr[nb_j]
row_end = weight_indptr[nb_j + 1]
row_elems = row_end - row_start
elem_idx = tvm.reduce_axis(
(0, row_elems), name="elem_idx")
block_offset = row_start + elem_idx
c = tvm.reduce_axis((0, bs_c), name="c")
block_j = weight_indices[block_offset]
block_ij_val = weight_data[block_offset][j][c]
x_val = data[i, bs_c * block_j + c]
return tvm.sum(block_ij_val * x_val, axis=[elem_idx, c])

bsrmm_block = tvm.compute(
(m, num_blocks, bs_r), _compute_block,
tag="sparse_dense_bsrmm_block")
return tvm.compute(
(m, num_blocks * bs_r),
lambda m, n: bsrmm_block[m, n // bs_r, n % bs_r],
tag="sparse_dense_bsrmm")
63 changes: 63 additions & 0 deletions topi/python/topi/x86/sparse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""sparse_dense schedule on x86"""
import tvm

from .. import generic
from ..util import traverse_inline, get_const_int
from .util import get_fp32_len


@generic.schedule_sparse_dense.register(["cpu"])
def _schedule_sparse_dense(outs):
s = tvm.create_schedule([x.op for x in outs])

def _callback(op):
simd_width = get_fp32_len()
if op.tag == "sparse_dense_csrmm" and op != outs[0].op:
(_, v_i) = s[op].op.axis
s[op].vectorize(v_i)
(y_o, y_i) = s[outs[0].op].split(
s[outs[0].op].op.axis[1], 2 * simd_width)
s[op].compute_at(s[outs[0]], y_o)
s[outs[0].op].vectorize(y_i)
if op.tag == "sparse_dense_bsrmm":
y_bsrmm = op.input_tensors[0]
assert y_bsrmm.op.tag == "sparse_dense_bsrmm_block"
y_reshape = op
(m, num_blocks, b_r) = s[y_bsrmm].op.axis
bs_r = get_const_int(b_r.dom.extent)
(elem_idx, c) = s[y_bsrmm].op.reduce_axis
s[y_bsrmm].reorder(num_blocks, m, elem_idx, b_r, c)
s[y_bsrmm].vectorize(b_r)
(m_o, n_o) = s[y_reshape].op.axis
(noo, noi) = s[y_reshape].split(n_o, bs_r)
s[y_bsrmm].compute_at(s[y_reshape], noi)
s[y_reshape].vectorize(noi)
if op != s[outs[0]].op:
(y_o, y_i) = s[outs[0].op].split(
s[outs[0].op].op.axis[1], 2 * simd_width)
s[y_reshape].compute_at(s[outs[0]], y_o)
s[outs[0].op].parallel(y_o)
s[outs[0].op].vectorize(y_i)
else:
m_o_noo = s[y_reshape].fuse(m_o, noo)
s[y_reshape].parallel(m_o_noo)

traverse_inline(s, outs[0].op, _callback)
return s
Loading

0 comments on commit d6dcd6c

Please sign in to comment.