Skip to content

Commit

Permalink
finished
Browse files Browse the repository at this point in the history
  • Loading branch information
jcaip committed Aug 15, 2024
1 parent 1256d54 commit 509d3f8
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 24 deletions.
28 changes: 9 additions & 19 deletions torchao/sparsity/prototype/superblock/blocksparse_subclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,8 @@

@torch.library.custom_op("blocksparse::linear", mutates_args=())
def blocksparse_linear(A: torch.Tensor, crow_indices: torch.Tensor, col_indices: torch.Tensor, values: torch.Tensor, M: int, K: int, bias: torch.Tensor) -> torch.Tensor:
shape = A.shape
A_2d = A.view(-1, shape[-1])
# custom_bias = bias.unsqueeze(1).expand(-1, A_2d.shape[0])
# print(bias.shape)
# print(bias)
# breakpoint()
# weight_bsr_subclass = BlockSparseTensor(
# shape = torch.Size([M, K]),
# bsr_crow_indicies=crow_indices,
# bsr_col_indicies=col_indices,
# bsr_values=values,
# )
weight_bsr = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(M, K))
res_good = torch.nn.functional.linear(A, weight_bsr, bias)
return res_good
# res = bsr_dense_addmm(custom_bias, weight_bsr_subclass, A_2d.t())
# res = res.view(*shape[:-1], -1)
# print(torch.allclose(res_good, res))
# breakpoint()
# return res
return torch.nn.functional.linear(A, weight_bsr, bias)

# # Write the FakeTensor kernel
@torch.library.register_fake("blocksparse::linear")
Expand Down Expand Up @@ -175,3 +157,11 @@ def block_sparse_col_indices(func, types, args, kwargs):
@implements(aten._nnz.default)
def block_sparse__nnz(func, types, args, kwargs):
return args[0].bsr_values.shape[0]

@implements(torch.nn.functional.linear)
def block_sparse_linear(func, types, args, kwargs):
x, w, bias = args
crow_indicies = w.crow_indices()
col_indices = w.col_indices()
values = w.values()
return torch.ops.blocksparse.linear(x, crow_indicies, col_indices, values, w.shape[0], w.shape[1], bias)
6 changes: 1 addition & 5 deletions torchao/sparsity/prototype/superblock/supermask.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,13 +130,9 @@ def forward(self, x):
if not self.sparsify_weights:
subnet = self.get_mask()
w = (self.weight*self.scale+self.shift) * subnet
return F.linear(x, w, self.bias)
else:
w = self.weight
crow_indicies = w.crow_indices()
col_indices = w.col_indices()
values = w.values()
return torch.ops.blocksparse.linear(x, crow_indicies, col_indices, values, w.shape[0], w.shape[1], self.bias)
return F.linear(x, w, self.bias)


class SupermaskConv2d(nn.Conv2d):
Expand Down

0 comments on commit 509d3f8

Please sign in to comment.