From 1256d549f614cd9d1af5e5f8cae6e44a1690cc38 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Wed, 14 Aug 2024 17:55:41 -0700 Subject: [PATCH] working now --- .../prototype/superblock/benchmark.py | 2 +- .../superblock/blocksparse_subclass.py | 29 ++++++++++++------- .../sparsity/prototype/superblock/evaluate.py | 18 ++++++------ 3 files changed, 28 insertions(+), 21 deletions(-) diff --git a/torchao/sparsity/prototype/superblock/benchmark.py b/torchao/sparsity/prototype/superblock/benchmark.py index 38ec65727..1944bceeb 100644 --- a/torchao/sparsity/prototype/superblock/benchmark.py +++ b/torchao/sparsity/prototype/superblock/benchmark.py @@ -25,7 +25,7 @@ def apply_sparsity(model): module.sparsify_offline() -def apply_bsr(model, blocksize): +def apply_bsr(model, blocksize=64): for name, module in model.named_modules(): if isinstance(module, torch.nn.Linear) and "mlp" in name: try: diff --git a/torchao/sparsity/prototype/superblock/blocksparse_subclass.py b/torchao/sparsity/prototype/superblock/blocksparse_subclass.py index 82d19aa08..501fb3122 100644 --- a/torchao/sparsity/prototype/superblock/blocksparse_subclass.py +++ b/torchao/sparsity/prototype/superblock/blocksparse_subclass.py @@ -16,17 +16,24 @@ def blocksparse_linear(A: torch.Tensor, crow_indices: torch.Tensor, col_indices: torch.Tensor, values: torch.Tensor, M: int, K: int, bias: torch.Tensor) -> torch.Tensor: shape = A.shape A_2d = A.view(-1, shape[-1]) - bias = bias.unsqueeze(1).expand(-1, A_2d.shape[0]) - weight_bsr = BlockSparseTensor( - shape = torch.Size([M, K]), - bsr_crow_indicies=crow_indices, - bsr_col_indicies=col_indices, - bsr_values=values, - ) - res = bsr_dense_addmm(bias, weight_bsr, A_2d.t()) - res = res.view(*shape[:-1], -1) - return res - # return bsr_dense_addmm(bias, weight, A_2d) + # custom_bias = bias.unsqueeze(1).expand(-1, A_2d.shape[0]) + # print(bias.shape) + # print(bias) + # breakpoint() + # weight_bsr_subclass = BlockSparseTensor( + # shape = torch.Size([M, K]), + # bsr_crow_indicies=crow_indices, + # bsr_col_indicies=col_indices, + # bsr_values=values, + # ) + weight_bsr = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(M, K)) + res_good = torch.nn.functional.linear(A, weight_bsr, bias) + return res_good + # res = bsr_dense_addmm(custom_bias, weight_bsr_subclass, A_2d.t()) + # res = res.view(*shape[:-1], -1) + # print(torch.allclose(res_good, res)) + # breakpoint() + # return res # # Write the FakeTensor kernel @torch.library.register_fake("blocksparse::linear") diff --git a/torchao/sparsity/prototype/superblock/evaluate.py b/torchao/sparsity/prototype/superblock/evaluate.py index 23e825f65..8811858b0 100644 --- a/torchao/sparsity/prototype/superblock/evaluate.py +++ b/torchao/sparsity/prototype/superblock/evaluate.py @@ -15,7 +15,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..")) from supermask import apply_supermask, SupermaskLinear - +# from benchmark import apply_sparsity, apply_bsr, verify_sparsity def apply_sparsity(model): for name, module in model.named_modules(): @@ -82,16 +82,16 @@ def load_data(valdir, args): ) # for META internal - dataset_test = torchvision.datasets.ImageFolder( - valdir, - preprocessing, - ) - # for OSS - # dataset_test = torchvision.datasets.ImageNet( + # dataset_test = torchvision.datasets.ImageFolder( # valdir, - # split='val', - # transform=preprocessing + # preprocessing, # ) + #for OSS + dataset_test = torchvision.datasets.ImageNet( + valdir, + split='val', + transform=preprocessing + ) if args.cache_dataset: print(f"Saving dataset_test to {cache_path}") utils.mkdir(os.path.dirname(cache_path))