Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use TorchDispatchMode for flops calculation in module summary #81

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
318 changes: 158 additions & 160 deletions tests/tools/test_flops.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,59 +9,64 @@

import torch
import torchvision.models as models
from torcheval.tools.flops import FlopTensor, instrument_module, start_counting
from torcheval.tools.flops import FlopTensorDispatchMode


class ModuleSummaryTest(unittest.TestCase):
def test_torch_operations(self) -> None:
"""Make sure FLOPs calculation works for a single operations."""
inp = torch.randn(10, 4, 5)
bmm_mat = torch.randn(10, 5, 7)
mm_mat = torch.randn(7, 3)

inp = FlopTensor(inp)
start_counting()
res = inp.bmm(bmm_mat).matmul(mm_mat)

self.assertEqual(res.shape[0], 10)
self.assertEqual(res.shape[1], 4)
self.assertEqual(res.shape[2], 3)

self.assertEqual(
FlopTensor.flop_counts[""].get("bmm.default", 0)
+ FlopTensor.flop_counts[""].get("bmm", 0),
1400,
)
self.assertEqual(
FlopTensor.flop_counts[""].get("mm.default", 0)
+ FlopTensor.flop_counts[""].get("mm", 0),
840,
)

inp = torch.randn(10, 4, 5)
# pyre-fixme[28]: Unexpected keyword argument `requires_grad`.
inp = torch.autograd.Variable(inp, requires_grad=True)

# pyre-fixme[6]: For 1st param expected `Tensor` but got `Variable`.
inp = FlopTensor(inp)
start_counting()
res = inp.bmm(bmm_mat).matmul(mm_mat)
res.mean().backward()

self.assertEqual(res.shape[0], 10)
self.assertEqual(res.shape[1], 4)
self.assertEqual(res.shape[2], 3)

self.assertEqual(
FlopTensor.flop_counts[""].get("bmm.default", 0)
+ FlopTensor.flop_counts[""].get("bmm", 0),
2800,
)
self.assertEqual(
FlopTensor.flop_counts[""].get("mm.default", 0)
+ FlopTensor.flop_counts[""].get("mm", 0),
1680,
)
class TestModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()

def forward(self, x: torch.Tensor) -> torch.Tensor:
bmm_mat = torch.randn(10, 5, 7)
mm_mat = torch.randn(7, 3)
return x.bmm(bmm_mat).matmul(mm_mat)

test_module = TestModule()
with FlopTensorDispatchMode(test_module) as ftdm:
inp = torch.randn(10, 4, 5)
res = test_module(inp)

self.assertEqual(res.shape[0], 10)
self.assertEqual(res.shape[1], 4)
self.assertEqual(res.shape[2], 3)

self.assertEqual(
ftdm.flop_counts[""].get("bmm.default", 0)
+ ftdm.flop_counts[""].get("bmm", 0),
1400,
)
self.assertEqual(
ftdm.flop_counts[""].get("mm.default", 0)
+ ftdm.flop_counts[""].get("mm", 0),
840,
)

inp = torch.randn(10, 4, 5)
# pyre-fixme[28]: Unexpected keyword argument `requires_grad`.
inp = torch.autograd.Variable(inp, requires_grad=True)

ftdm.reset()
res = test_module(inp)
res.mean().backward()

self.assertEqual(res.shape[0], 10)
self.assertEqual(res.shape[1], 4)
self.assertEqual(res.shape[2], 3)

self.assertEqual(
ftdm.flop_counts[""].get("bmm.default", 0)
+ ftdm.flop_counts[""].get("bmm", 0),
2800,
)
self.assertEqual(
ftdm.flop_counts[""].get("mm.default", 0)
+ ftdm.flop_counts[""].get("mm", 0),
1680,
)

def test_torch_linear_layer(self) -> None:
"""Make sure FLOPs calculation works for a module consists of linear layers."""
Expand All @@ -70,122 +75,115 @@ def test_torch_linear_layer(self) -> None:
torch.nn.Linear(5, 1),
)
inp = torch.randn(1, 10)
inp = FlopTensor(inp)

all_hooks = []
instrument_module(lnn, all_hooks, "")
self.assertEqual(len(all_hooks), 8)

start_counting()
res = lnn(inp)
self.assertEqual(
FlopTensor.flop_counts[""].get("addmm.default", 0)
+ FlopTensor.flop_counts[""].get("addmm", 0),
1055,
)
self.assertEqual(
FlopTensor.flop_counts["0"].get("addmm.default", 0)
+ FlopTensor.flop_counts["0"].get("addmm", 0),
1050,
)
self.assertEqual(
FlopTensor.flop_counts["0.0"].get("addmm.default", 0)
+ FlopTensor.flop_counts["0.0"].get("addmm", 0),
700,
)
self.assertEqual(
FlopTensor.flop_counts["0.1"].get("addmm.default", 0)
+ FlopTensor.flop_counts["0.1"].get("addmm", 0),
350,
)
self.assertEqual(
FlopTensor.flop_counts["1"].get("addmm.default", 0)
+ FlopTensor.flop_counts["1"].get("addmm", 0),
5,
)
start_counting()
res.backward()
self.assertEqual(
FlopTensor.flop_counts[""].get("mm.default", 0)
+ FlopTensor.flop_counts[""].get("mm", 0),
1410,
)
self.assertEqual(
FlopTensor.flop_counts["0"].get("mm.default", 0)
+ FlopTensor.flop_counts["0"].get("mm", 0),
1400,
)
self.assertEqual(
FlopTensor.flop_counts["0.0"].get("mm.default", 0)
+ FlopTensor.flop_counts["0.0"].get("mm", 0),
700,
)
self.assertEqual(
FlopTensor.flop_counts["0.1"].get("mm.default", 0)
+ FlopTensor.flop_counts["0.1"].get("mm", 0),
700,
)
self.assertEqual(
FlopTensor.flop_counts["1"].get("mm.default", 0)
+ FlopTensor.flop_counts["1"].get("mm", 0),
10,
)

with FlopTensorDispatchMode(lnn) as ftdm:
self.assertEqual(len(ftdm._all_hooks), 8)

res = lnn(inp)
self.assertEqual(
ftdm.flop_counts[""].get("addmm.default", 0)
+ ftdm.flop_counts[""].get("addmm", 0),
1055,
)
self.assertEqual(
ftdm.flop_counts["0"].get("addmm.default", 0)
+ ftdm.flop_counts["0"].get("addmm", 0),
1050,
)
self.assertEqual(
ftdm.flop_counts["0.0"].get("addmm.default", 0)
+ ftdm.flop_counts["0.0"].get("addmm", 0),
700,
)
self.assertEqual(
ftdm.flop_counts["0.1"].get("addmm.default", 0)
+ ftdm.flop_counts["0.1"].get("addmm", 0),
350,
)
self.assertEqual(
ftdm.flop_counts["1"].get("addmm.default", 0)
+ ftdm.flop_counts["1"].get("addmm", 0),
5,
)
ftdm.reset()
res.backward()
self.assertEqual(
ftdm.flop_counts[""].get("mm.default", 0)
+ ftdm.flop_counts[""].get("mm", 0),
1410,
)
self.assertEqual(
ftdm.flop_counts["0"].get("mm.default", 0)
+ ftdm.flop_counts["0"].get("mm", 0),
1400,
)
self.assertEqual(
ftdm.flop_counts["0.0"].get("mm.default", 0)
+ ftdm.flop_counts["0.0"].get("mm", 0),
700,
)
self.assertEqual(
ftdm.flop_counts["0.1"].get("mm.default", 0)
+ ftdm.flop_counts["0.1"].get("mm", 0),
700,
)
self.assertEqual(
ftdm.flop_counts["1"].get("mm.default", 0)
+ ftdm.flop_counts["1"].get("mm", 0),
10,
)

def test_torch_pretrained_module(self) -> None:
"""Make sure FLOPs calculation works for a resnet18."""
# pyre-fixme[16]: Module `models` has no attribute `resnet18`.
mod = models.resnet18()
inp = torch.randn(1, 3, 224, 224)
all_hooks = []
instrument_module(mod, all_hooks, "")
# Hooks should be 2 * number of modules minus 2 (2 for the model itself)
self.assertEqual(len(all_hooks), 2 * len(list(mod.modules())) - 2)

inp = FlopTensor(inp)
start_counting()
res = mod(inp)

self.assertEqual(
FlopTensor.flop_counts[""].get("convolution.default", 0)
+ FlopTensor.flop_counts[""].get("convolution", 0),
1813561344,
)
self.assertEqual(
FlopTensor.flop_counts[""].get("addmm.default", 0)
+ FlopTensor.flop_counts[""].get("addmm", 0),
512000,
)
self.assertEqual(
FlopTensor.flop_counts["conv1"].get("convolution.default", 0)
+ FlopTensor.flop_counts["conv1"].get("convolution", 0),
118013952,
)
self.assertEqual(
FlopTensor.flop_counts["fc"].get("addmm.default", 0)
+ FlopTensor.flop_counts["fc"].get("addmm", 0),
512000,
)

start_counting()
res.mean().backward()

self.assertEqual(
FlopTensor.flop_counts[""].get("convolution_backward.default", 0)
+ FlopTensor.flop_counts[""].get("convolution_backward", 0),
3509108736,
)
self.assertEqual(
FlopTensor.flop_counts[""].get("mm.default", 0)
+ FlopTensor.flop_counts[""].get("mm", 0),
1024000,
)
self.assertEqual(
FlopTensor.flop_counts["layer1"].get("convolution_backward.default", 0)
+ FlopTensor.flop_counts["layer1"].get("convolution_backward", 0),
924844032,
)
self.assertEqual(
FlopTensor.flop_counts["fc"].get("mm.default", 0)
+ FlopTensor.flop_counts["fc"].get("mm", 0),
1024000,
)
with FlopTensorDispatchMode(mod) as ftdm:
# Hooks should be 2 * number of modules minus 2 (2 for the model itself)
self.assertEqual(len(ftdm._all_hooks), 2 * len(list(mod.modules())) - 2)
res = mod(inp)

self.assertEqual(
ftdm.flop_counts[""].get("convolution.default", 0)
+ ftdm.flop_counts[""].get("convolution", 0),
1813561344,
)
self.assertEqual(
ftdm.flop_counts[""].get("addmm.default", 0)
+ ftdm.flop_counts[""].get("addmm", 0),
512000,
)
self.assertEqual(
ftdm.flop_counts["conv1"].get("convolution.default", 0)
+ ftdm.flop_counts["conv1"].get("convolution", 0),
118013952,
)
self.assertEqual(
ftdm.flop_counts["fc"].get("addmm.default", 0)
+ ftdm.flop_counts["fc"].get("addmm", 0),
512000,
)

ftdm.reset()
res.mean().backward()

self.assertEqual(
ftdm.flop_counts[""].get("convolution_backward.default", 0)
+ ftdm.flop_counts[""].get("convolution_backward", 0),
3509108736,
)
self.assertEqual(
ftdm.flop_counts[""].get("mm.default", 0)
+ ftdm.flop_counts[""].get("mm", 0),
1024000,
)
self.assertEqual(
ftdm.flop_counts["layer1"].get("convolution_backward.default", 0)
+ ftdm.flop_counts["layer1"].get("convolution_backward", 0),
924844032,
)
self.assertEqual(
ftdm.flop_counts["fc"].get("mm.default", 0)
+ ftdm.flop_counts["fc"].get("mm", 0),
1024000,
)
Loading