diff --git a/torchao/sparsity/prototype/superblock/blocksparse_subclass.py b/torchao/sparsity/prototype/superblock/blocksparse_subclass.py index 501fb3122..ced3fa48f 100644 --- a/torchao/sparsity/prototype/superblock/blocksparse_subclass.py +++ b/torchao/sparsity/prototype/superblock/blocksparse_subclass.py @@ -14,26 +14,8 @@ @torch.library.custom_op("blocksparse::linear", mutates_args=()) def blocksparse_linear(A: torch.Tensor, crow_indices: torch.Tensor, col_indices: torch.Tensor, values: torch.Tensor, M: int, K: int, bias: torch.Tensor) -> torch.Tensor: - shape = A.shape - A_2d = A.view(-1, shape[-1]) - # custom_bias = bias.unsqueeze(1).expand(-1, A_2d.shape[0]) - # print(bias.shape) - # print(bias) - # breakpoint() - # weight_bsr_subclass = BlockSparseTensor( - # shape = torch.Size([M, K]), - # bsr_crow_indicies=crow_indices, - # bsr_col_indicies=col_indices, - # bsr_values=values, - # ) weight_bsr = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(M, K)) - res_good = torch.nn.functional.linear(A, weight_bsr, bias) - return res_good - # res = bsr_dense_addmm(custom_bias, weight_bsr_subclass, A_2d.t()) - # res = res.view(*shape[:-1], -1) - # print(torch.allclose(res_good, res)) - # breakpoint() - # return res + return torch.nn.functional.linear(A, weight_bsr, bias) # # Write the FakeTensor kernel @torch.library.register_fake("blocksparse::linear") @@ -175,3 +157,11 @@ def block_sparse_col_indices(func, types, args, kwargs): @implements(aten._nnz.default) def block_sparse__nnz(func, types, args, kwargs): return args[0].bsr_values.shape[0] + +@implements(torch.nn.functional.linear) +def block_sparse_linear(func, types, args, kwargs): + x, w, bias = args + crow_indicies = w.crow_indices() + col_indices = w.col_indices() + values = w.values() + return torch.ops.blocksparse.linear(x, crow_indicies, col_indices, values, w.shape[0], w.shape[1], bias) diff --git a/torchao/sparsity/prototype/superblock/supermask.py b/torchao/sparsity/prototype/superblock/supermask.py index c053caa3d..50b212818 100644 --- a/torchao/sparsity/prototype/superblock/supermask.py +++ b/torchao/sparsity/prototype/superblock/supermask.py @@ -130,13 +130,9 @@ def forward(self, x): if not self.sparsify_weights: subnet = self.get_mask() w = (self.weight*self.scale+self.shift) * subnet - return F.linear(x, w, self.bias) else: w = self.weight - crow_indicies = w.crow_indices() - col_indices = w.col_indices() - values = w.values() - return torch.ops.blocksparse.linear(x, crow_indicies, col_indices, values, w.shape[0], w.shape[1], self.bias) + return F.linear(x, w, self.bias) class SupermaskConv2d(nn.Conv2d):