Update sparse semi-structured linear operator
ghstack-source-id: 21c50a18fea6990437f99a4f0b2aba1cb7532a80
Pull Request resolved: #104608
alexsamardzic committed Jul 6, 2023
1 parent 28363dd commit 5598f45
Showing 8 changed files with 89 additions and 264 deletions.
4 changes: 2 additions & 2 deletions aten/src/ATen/native/native_functions.yaml
@@ -3233,9 +3233,9 @@
MkldnnCPU: mkldnn_linear_backward
autogen: mkldnn_linear_backward.out

- func: _structured_sparse_linear(Tensor input, Tensor weight, Tensor mask_or_meta, *, Tensor? bias=None, str? activation=None) -> (Tensor, Tensor)
- func: _sparse_semi_structured_linear(Tensor input, Tensor weight, Tensor meta, *, Tensor? bias=None, str? activation=None) -> Tensor
dispatch:
CUDA: _structured_sparse_linear
CUDA: _sparse_semi_structured_linear

- func: fbgemm_linear_int8_weight_fp32_activation(Tensor input, Tensor weight, Tensor packed, Tensor col_offsets, Scalar weight_scale, Scalar weight_zero_point, Tensor bias) -> Tensor

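For context on the schema change above: the renamed op now takes precomputed 2:4 sparsity metadata instead of a mask, and returns a single output tensor rather than an (output, metadata) pair. A minimal sketch of the new call convention, assuming a CUDA device, an fp16 weight already in the 2:4 pattern, and that `to_sparse_semi_structured` is importable from `torch.sparse` as the docs below use it; shapes are illustrative:

```python
import torch
from torch.sparse import to_sparse_semi_structured

m, k, n = 64, 128, 64
weight = torch.randn(m, k, dtype=torch.float16, device="cuda")
weight[weight == 0] = 1                       # avoid accidental zeros among kept values
pattern = torch.tensor([1, 1, 0, 0], dtype=torch.float16, device="cuda")
weight = weight * pattern.tile(m, k // 4)     # keep two out of every four elements
x = torch.randn(n, k, dtype=torch.float16, device="cuda")

compressed = weight.masked_select(weight != 0).view(m, k // 2)  # kept values only
meta = to_sparse_semi_structured(weight).indices()              # packed 2:4 metadata

# New schema: a single tensor comes back; the old op also returned metadata
# and accepted a mask (or metadata) in place of `meta`.
out = torch._sparse_semi_structured_linear(x, compressed, meta, bias=None)
assert out.shape == (n, m)
```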

(The diff for one additional changed file is large and is not rendered here.)

8 changes: 4 additions & 4 deletions aten/src/ATen/native/sparse/cuda/cutlass/README.md
@@ -1,14 +1,14 @@
This directory contains files from CUTLASS 3.1 source tree, modified
for the purpose of `_structured_sparse_linear()` implementation. This
method implements linear operator with weight matrix $W$, bias vector
$b$, and input tensor $x$ as arguments:
for the purpose of `_sparse_semi_structured_linear()` implementation.
This method implements linear operator with weight matrix $W$, bias
vector $b$, and input tensor $x$ as arguments:

$$y=xW^{T}+b$$

where the matrix $W$ is a structured sparse matrix. Since CUTLASS
support sparse GEMM operation only when the first operand is in
structured sparse format, the operation above is actually implemented
in `_structured_sparse_linear()` as follows:
in `_sparse_semi_structured_linear()` as follows:

$$y=(Wx^{T}+b^{T})^{T}$$

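A quick dense sanity check of the transposition trick described in the README hunk above; this is just a sketch with small arbitrary shapes, whereas the real kernel additionally requires $W$ to be 2:4 structured sparse:

```python
import torch

x = torch.randn(8, 16)   # input  (n, k)
W = torch.randn(4, 16)   # weight (m, k)
b = torch.randn(4)       # bias   (m,)

y_direct = x @ W.t() + b                         # y = x W^T + b
y_rewritten = (W @ x.t() + b.unsqueeze(1)).t()   # y = (W x^T + b^T)^T

assert torch.allclose(y_direct, y_rewritten, atol=1e-5)
```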
2 changes: 1 addition & 1 deletion docs/source/sparse.rst
@@ -300,7 +300,7 @@ To use these ops, simply pass the output of ``to_sparse_semi_structured(tensor)`
>>> torch.allclose(c, torch.mm(a_sparse, b))
True

Under the hood, SparseSemiStructuredTensor will call ``torch._structured_sparse_linear`` for accelerated inference using CUTLASS sparse kernels.
Under the hood, SparseSemiStructuredTensor will call ``torch._sparse_semi_structured_linear`` for accelerated inference using CUTLASS sparse kernels.

Accelerating nn.Linear with semi-structured sparsity
----------------------------------------------------
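As a usage illustration of the behaviour documented above, here is a minimal sketch of accelerating an `nn.Linear` this way; it assumes a CUDA device, fp16, and a weight pruned to the 2:4 pattern (the pruning line is purely illustrative), mirroring `test_linear` in the test changes below:

```python
import torch
from torch import nn
from torch.sparse import to_sparse_semi_structured

model = nn.Linear(128, 64, dtype=torch.float16, device="cuda")
with torch.no_grad():
    # Illustrative 2:4 pruning: keep two out of every four weight entries.
    pattern = torch.tensor([1, 1, 0, 0], dtype=torch.float16, device="cuda")
    model.weight.mul_(pattern.tile(model.weight.shape[0], model.weight.shape[1] // 4))

# Swapping in the subclass is enough for F.linear to hit the CUTLASS sparse kernel.
model.weight = nn.Parameter(to_sparse_semi_structured(model.weight))

with torch.inference_mode():
    y = model(torch.randn(32, 128, dtype=torch.float16, device="cuda"))
```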
2 changes: 1 addition & 1 deletion test/expect/HasDecompTest.test_has_decomposition.expect
@@ -472,6 +472,7 @@ aten::_sparse_mask_projection
aten::_sparse_mask_projection.out
aten::_sparse_mm_reduce_impl
aten::_sparse_mm_reduce_impl_backward
aten::_sparse_semi_structured_linear
aten::_sparse_softmax
aten::_sparse_softmax.out
aten::_sparse_softmax_backward_data
@@ -490,7 +491,6 @@ aten::_standard_gamma
aten::_standard_gamma.out
aten::_standard_gamma_grad
aten::_standard_gamma_grad.out
aten::_structured_sparse_linear
aten::_test_autograd_multiple_dispatch.fullcoverage
aten::_test_autograd_multiple_dispatch.fullcoverage_out
aten::_test_autograd_multiple_dispatch_view
@@ -310,7 +310,7 @@
("aten::to_sparse_bsr.out", datetime.date(2023, 12, 31)),
("aten::to_sparse_csc.out", datetime.date(2023, 12, 31)),
("aten::to_sparse_csr.out", datetime.date(2023, 12, 31)),
("aten::_structured_sparse_linear", datetime.date(2023, 12, 31)),
("aten::_sparse_semi_structured_linear", datetime.date(2023, 12, 31)),

]

13 changes: 6 additions & 7 deletions test/test_sparse_semi_structured.py
@@ -69,7 +69,7 @@ def rand_dense_2by4(r, c, dtype, device, choice=None):
dense = make_tensor(r, c, dtype=dtype, device=device)
dense[dense == 0] = 1 # To prevent zeros except where mask applied.
dense = dense.masked_fill(~mask, 0)
return (dense, mask)
return dense

def rand_dense_2by4_all_patterns(r, c, dtype, device):
choices = [
@@ -222,7 +222,7 @@ def test_linear(self, inference_mode, device):

dense_result = model(input)

model.weight = nn.Parameter(SparseSemiStructuredTensor(model.weight))
model.weight = nn.Parameter(to_sparse_semi_structured(model.weight))

if inference_mode:
with torch.inference_mode():
@@ -275,7 +275,7 @@ def test_unsupported_dim(self, device):
@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
def test_linear_cutlass(self, device, dtype):
def run_test(batch_shape, m, n, k, device, dtype, dtype_out, add_bias, activation, rtol, atol):
weight, mask = rand_dense_2by4(m, k, dtype, device)
weight = rand_dense_2by4(m, k, dtype, device)
input = make_tensor((*batch_shape, n, k), dtype=dtype, device=device)
bias = make_tensor((m,), dtype=dtype_out, device=device) if add_bias else None

@@ -293,10 +293,9 @@ def run_test(batch_shape, m, n, k, device, dtype, dtype_out, add_bias, activatio

weight_sparse = weight.masked_select(weight != 0).view(m, k // 2)

output1, meta = torch._structured_sparse_linear(input, weight_sparse, mask, bias=bias, activation=activation)
torch.testing.assert_close(output1.to(dtype_dense), output0, rtol=rtol, atol=atol)
meta = to_sparse_semi_structured(weight).indices()

output1, _ = torch._structured_sparse_linear(input, weight_sparse, meta, bias=bias, activation=activation)
output1 = torch._sparse_semi_structured_linear(input, weight_sparse, meta, bias=bias, activation=activation)
torch.testing.assert_close(output1.to(dtype_dense), output0, rtol=rtol, atol=atol)

batch_shapes = [[], [3], [3, 1]]
@@ -320,7 +319,7 @@ def run_test(batch_shape, m, n, k, device, dtype, dtype_out, add_bias, activatio
@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
def test_conversions(self, device, dtype):
def run_test(r, c, device, dtype):
dense_ref, _ = rand_dense_2by4(r, c, dtype, device)
dense_ref = rand_dense_2by4(r, c, dtype, device)

compressed = to_sparse_semi_structured(dense_ref)

12 changes: 6 additions & 6 deletions torch/sparse/semi_structured.py
@@ -40,7 +40,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
For an original tensor of size (m, k) we expect the first m * k // 2 elements to be the kept elements
The rest of the tensor is metadata.
This subclass also overrides __torch_dispatch__ to use _structured_sparse_linear for faster matrix multiplications
This subclass also overrides __torch_dispatch__ to use _sparse_semi_structured_linear for faster matrix multiplications
via sparse CUTLASS kernels. In the future we will also call into cuSPARSELt kernels for more performance gains.
"""

@@ -219,7 +219,7 @@ def __repr__(self) -> str:

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs) -> Any:
"""Overload __torch_dispatch__ to use torch._structured_sparse_linear.
"""Overload __torch_dispatch__ to use torch._sparse_semi_structured_linear.
`torch.structured_sparse_linear` uses accelerated sparse CUTLASS kernels.
In the future we plan to also add in support for cuSPARSELt kernels.
@@ -270,7 +270,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs) -> Any:
# F.linear(x) = addmm(bias, input, weight.t()) = b + xW' = (b + xW')''
# = (W''x' + b')' = (Wx' + b')' = addmm(bias.T, weight, input).T
if isinstance(input_B, cls) and input_B.transposed:
result, _ = torch._structured_sparse_linear(
result = torch._sparse_semi_structured_linear(
input_A, input_B.values(), input_B.indices(), bias=bias
)
return result
@@ -280,13 +280,13 @@ def __torch_dispatch__(cls, func, types, args, kwargs) -> Any:
input_A, input_B = args

if isinstance(input_A, cls) and not input_A.transposed:
transposed_result, _ = torch._structured_sparse_linear(
transposed_result = torch._sparse_semi_structured_linear(
input_B.t(), input_A.values(), input_A.indices()
)
return transposed_result.t()

elif isinstance(input_B, cls) and input_B.transposed:
result, _ = torch._structured_sparse_linear(
result = torch._sparse_semi_structured_linear(
input_A, input_B.values(), input_B.indices()
)
return result
@@ -297,7 +297,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs) -> Any:
if func is torch.ops.aten.linear.default:
input_tensor, weight, bias = args
if isinstance(weight, cls):
result, _ = torch._structured_sparse_linear(
result = torch._sparse_semi_structured_linear(
input_tensor, weight.values(), weight.indices(), bias=bias
)
return result
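The comments in the dispatch branches above rely on the same transposition identity as the CUTLASS README: `F.linear(input, weight, bias) = addmm(bias, input, weight.t())` can instead be computed as `addmm(bias_column, weight, input.t()).t()`, which puts the (sparse) weight in the first GEMM operand position. A dense sanity check of that rewrite; a sketch only, since in the real path `weight` is the 2:4-sparse subclass and the kernel consumes its packed values and metadata:

```python
import torch
import torch.nn.functional as F

inp = torch.randn(8, 16)     # (n, k)
weight = torch.randn(4, 16)  # (m, k)
bias = torch.randn(4)        # (m,)

lhs = torch.addmm(bias, inp, weight.t())                   # F.linear(inp, weight, bias)
rhs = torch.addmm(bias.unsqueeze(1), weight, inp.t()).t()  # weight as first operand

assert torch.allclose(lhs, rhs, atol=1e-5)
assert torch.allclose(lhs, F.linear(inp, weight, bias), atol=1e-5)
```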
