Skip to content

Commit

Permalink
Add upsample_bilinear2d, unify norms, and bump version to 0.0.3
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhijian Liu committed Apr 15, 2021
1 parent a2f3fa3 commit b843c30
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 35 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from setuptools import find_packages, setup

from torchprofile import __version__

setup(
Expand Down
22 changes: 0 additions & 22 deletions test.py

This file was deleted.

32 changes: 20 additions & 12 deletions torchprofile/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,16 @@ def convolution(node):
return math.prod(os) * ic * math.prod(ks)


def batch_norm(node):
# TODO: provide an option to not fuse `batch_norm` into `linear` or `conv`
return 0

def norm(node):
if node.operator in ['aten::batch_norm', 'aten::instance_norm']:
affine = node.inputs[1].shape is not None
elif node.operator in ['aten::layer_norm', 'aten::group_norm']:
affine = node.inputs[2].shape is not None
else:
raise ValueError(node.operator)

def instance_norm_or_layer_norm(node):
os = node.outputs[0].shape
return math.prod(os)
return math.prod(os) if affine else 0


def avg_pool_or_mean(node):
Expand All @@ -91,29 +93,35 @@ def leaky_relu(node):
return math.prod(os)


def upsample_bilinear2d(node):
os = node.outputs[0].shape
return math.prod(os) * 4


handlers = (
('aten::addmm', addmm),
('aten::addmv', addmv),
('aten::bmm', bmm),
('aten::matmul', matmul),
(('aten::mul', 'aten::mul_'), mul),
('aten::_convolution', convolution),
('aten::batch_norm', batch_norm),
(('aten::instance_norm', 'aten::layer_norm'), instance_norm_or_layer_norm),
(('aten::batch_norm', 'aten::instance_norm', 'aten::layer_norm',
'aten::group_norm'), norm),
(('aten::adaptive_avg_pool1d', 'aten::adaptive_avg_pool2d',
'aten::adaptive_avg_pool3d', 'aten::avg_pool1d', 'aten::avg_pool2d',
'aten::avg_pool3d', 'aten::mean'), avg_pool_or_mean),
('aten::leaky_relu', leaky_relu),
('aten::upsample_bilinear2d', upsample_bilinear2d),
(('aten::adaptive_max_pool1d', 'aten::adaptive_max_pool2d',
'aten::adaptive_max_pool3d', 'aten::add', 'aten::add_',
'aten::alpha_dropout', 'aten::cat', 'aten::chunk', 'aten::clamp',
'aten::clone', 'aten::constant_pad_nd', 'aten::contiguous',
'aten::detach', 'aten::div', 'aten::div_', 'aten::dropout',
'aten::dropout_', 'aten::embedding', 'aten::eq', 'aten::feature_dropout',
'aten::flatten', 'aten::floor', 'aten::gt', 'aten::hardtanh_',
'aten::index', 'aten::int', 'aten::log_softmax', 'aten::lt',
'aten::max_pool1d', 'aten::max_pool1d_with_indices', 'aten::max_pool2d',
'aten::max_pool2d_with_indices', 'aten::max_pool3d',
'aten::flatten', 'aten::floor', 'aten::floor_divide', 'aten::gt',
'aten::hardtanh_', 'aten::index', 'aten::int', 'aten::log_softmax',
'aten::lt', 'aten::max_pool1d', 'aten::max_pool1d_with_indices',
'aten::max_pool2d', 'aten::max_pool2d_with_indices', 'aten::max_pool3d',
'aten::max_pool3d_with_indices', 'aten::max_unpool1d',
'aten::max_unpool2d', 'aten::max_unpool3d', 'aten::ne',
'aten::reflection_pad1d', 'aten::reflection_pad2d',
Expand Down
2 changes: 1 addition & 1 deletion torchprofile/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.0.2'
__version__ = '0.0.3'

0 comments on commit b843c30

Please sign in to comment.