Skip to content

Commit

Permalink
working now
Browse files Browse the repository at this point in the history
  • Loading branch information
jcaip committed Aug 15, 2024
1 parent 1fc9edf commit 1256d54
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 21 deletions.
2 changes: 1 addition & 1 deletion torchao/sparsity/prototype/superblock/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def apply_sparsity(model):
module.sparsify_offline()


def apply_bsr(model, blocksize):
def apply_bsr(model, blocksize=64):
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear) and "mlp" in name:
try:
Expand Down
29 changes: 18 additions & 11 deletions torchao/sparsity/prototype/superblock/blocksparse_subclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,24 @@
def blocksparse_linear(A: torch.Tensor, crow_indices: torch.Tensor, col_indices: torch.Tensor, values: torch.Tensor, M: int, K: int, bias: torch.Tensor) -> torch.Tensor:
shape = A.shape
A_2d = A.view(-1, shape[-1])
bias = bias.unsqueeze(1).expand(-1, A_2d.shape[0])
weight_bsr = BlockSparseTensor(
shape = torch.Size([M, K]),
bsr_crow_indicies=crow_indices,
bsr_col_indicies=col_indices,
bsr_values=values,
)
res = bsr_dense_addmm(bias, weight_bsr, A_2d.t())
res = res.view(*shape[:-1], -1)
return res
# return bsr_dense_addmm(bias, weight, A_2d)
# custom_bias = bias.unsqueeze(1).expand(-1, A_2d.shape[0])
# print(bias.shape)
# print(bias)
# breakpoint()
# weight_bsr_subclass = BlockSparseTensor(
# shape = torch.Size([M, K]),
# bsr_crow_indicies=crow_indices,
# bsr_col_indicies=col_indices,
# bsr_values=values,
# )
weight_bsr = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(M, K))
res_good = torch.nn.functional.linear(A, weight_bsr, bias)
return res_good
# res = bsr_dense_addmm(custom_bias, weight_bsr_subclass, A_2d.t())
# res = res.view(*shape[:-1], -1)
# print(torch.allclose(res_good, res))
# breakpoint()
# return res

# # Write the FakeTensor kernel
@torch.library.register_fake("blocksparse::linear")
Expand Down
18 changes: 9 additions & 9 deletions torchao/sparsity/prototype/superblock/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from supermask import apply_supermask, SupermaskLinear

# from benchmark import apply_sparsity, apply_bsr, verify_sparsity

def apply_sparsity(model):
for name, module in model.named_modules():
Expand Down Expand Up @@ -82,16 +82,16 @@ def load_data(valdir, args):
)

# for META internal
dataset_test = torchvision.datasets.ImageFolder(
valdir,
preprocessing,
)
# for OSS
# dataset_test = torchvision.datasets.ImageNet(
# dataset_test = torchvision.datasets.ImageFolder(
# valdir,
# split='val',
# transform=preprocessing
# preprocessing,
# )
#for OSS
dataset_test = torchvision.datasets.ImageNet(
valdir,
split='val',
transform=preprocessing
)
if args.cache_dataset:
print(f"Saving dataset_test to {cache_path}")
utils.mkdir(os.path.dirname(cache_path))
Expand Down

0 comments on commit 1256d54

Please sign in to comment.