
[Retiarii] support hypermodule: autoactivation #3868

Merged: 10 commits, Jul 15, 2021
9 changes: 9 additions & 0 deletions docs/en_US/NAS/Hypermodules.rst
@@ -0,0 +1,9 @@
Hypermodules
============

A hypermodule is a (PyTorch) module that contains many architecture/hyperparameter candidates for that module. When hypermodules are used in a user-defined model, NNI automatically finds the best architecture/hyperparameter of each hypermodule for that model. This follows the design philosophy of Retiarii: users write a DNN model as a space.

Several hypermodules have been proposed in the NAS community, such as AutoActivation and AutoDropout. Some of them are implemented in the Retiarii framework.

.. autoclass:: nni.retiarii.nn.pytorch.AutoActivation
:members:
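
As a quick illustration of the documentation above (a minimal usage sketch, not part of this diff — the model, layer sizes, and tensor shapes below are hypothetical), a hypermodule is dropped into a user-defined model like any other module and Retiarii mutates it during search:

import nni.retiarii.nn.pytorch as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 10)      # ordinary layer
        self.act = nn.AutoActivation()   # hypermodule: the activation function is searched

    def forward(self, x):
        return self.act(self.fc(x))

# During a Retiarii experiment, NNI mutates `act` and evaluates candidate activations.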
3 changes: 2 additions & 1 deletion docs/en_US/NAS/construct_space.rst
@@ -8,4 +8,5 @@ NNI provides powerful APIs for users to easily express model space (or search space)
:maxdepth: 1

Mutation Primitives <MutationPrimitives>
Customize Mutators <Mutators>
Customize Mutators <Mutators>
Hypermodule Lib <Hypermodules>
1 change: 1 addition & 0 deletions nni/retiarii/nn/pytorch/__init__.py
@@ -1,3 +1,4 @@
from .api import *
from .component import *
from .nn import *
from .hypermodule import *
249 changes: 249 additions & 0 deletions nni/retiarii/nn/pytorch/hypermodule.py
@@ -0,0 +1,249 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
import torch.nn as nn

from nni.retiarii.serializer import basic_unit

from .api import LayerChoice
from ...utils import version_larger_equal

__all__ = ['AutoActivation']

TorchVersion = '1.5.0'

# ============== unary function modules ==============

@basic_unit
class UnaryIdentity(nn.Module):
def forward(self, x):
return x

@basic_unit
class UnaryNegative(nn.Module):
def forward(self, x):
return -x

@basic_unit
class UnaryAbs(nn.Module):
def forward(self, x):
return torch.abs(x)

@basic_unit
class UnarySquare(nn.Module):
def forward(self, x):
return torch.square(x)

@basic_unit
class UnaryPow(nn.Module):
def forward(self, x):
return torch.pow(x, 3)

@basic_unit
class UnarySqrt(nn.Module):
def forward(self, x):
return torch.sqrt(x)

@basic_unit
class UnaryMul(nn.Module):
def __init__(self):
super().__init__()
# element-wise for now, will change to per-channel trainable parameter
self.beta = torch.nn.Parameter(torch.tensor(1, dtype=torch.float32)) # pylint: disable=not-callable
def forward(self, x):
return x * self.beta

@basic_unit
class UnaryAdd(nn.Module):
def __init__(self):
super().__init__()
# element-wise for now, will change to per-channel trainable parameter
self.beta = torch.nn.Parameter(torch.tensor(1, dtype=torch.float32)) # pylint: disable=not-callable
def forward(self, x):
return x + self.beta

@basic_unit
class UnaryLogAbs(nn.Module):
def forward(self, x):
return torch.log(torch.abs(x) + 1e-7)

@basic_unit
class UnaryExp(nn.Module):
def forward(self, x):
return torch.exp(x)

@basic_unit
class UnarySin(nn.Module):
def forward(self, x):
return torch.sin(x)

@basic_unit
class UnaryCos(nn.Module):
def forward(self, x):
return torch.cos(x)

@basic_unit
class UnarySinh(nn.Module):
def forward(self, x):
return torch.sinh(x)

@basic_unit
class UnaryCosh(nn.Module):
def forward(self, x):
return torch.cosh(x)

@basic_unit
class UnaryTanh(nn.Module):
def forward(self, x):
return torch.tanh(x)

if not version_larger_equal(torch.__version__, TorchVersion):
@basic_unit
class UnaryAsinh(nn.Module):
def forward(self, x):
return torch.asinh(x)

@basic_unit
class UnaryAtan(nn.Module):
def forward(self, x):
return torch.atan(x)

if not version_larger_equal(torch.__version__, TorchVersion):
@basic_unit
class UnarySinc(nn.Module):
def forward(self, x):
return torch.sinc(x)

@basic_unit
class UnaryMax(nn.Module):
def forward(self, x):
return torch.max(x, torch.zeros_like(x))

@basic_unit
class UnaryMin(nn.Module):
def forward(self, x):
return torch.min(x, torch.zeros_like(x))

@basic_unit
class UnarySigmoid(nn.Module):
def forward(self, x):
return torch.sigmoid(x)

@basic_unit
class UnaryLogExp(nn.Module):
def forward(self, x):
return torch.log(1 + torch.exp(x))

@basic_unit
class UnaryExpSquare(nn.Module):
def forward(self, x):
return torch.exp(-torch.square(x))

@basic_unit
class UnaryErf(nn.Module):
def forward(self, x):
return torch.erf(x)

unary_modules = ['UnaryIdentity', 'UnaryNegative', 'UnaryAbs', 'UnarySquare', 'UnaryPow',
'UnarySqrt', 'UnaryMul', 'UnaryAdd', 'UnaryLogAbs', 'UnaryExp', 'UnarySin', 'UnaryCos',
'UnarySinh', 'UnaryCosh', 'UnaryTanh', 'UnaryAtan', 'UnaryMax',
'UnaryMin', 'UnarySigmoid', 'UnaryLogExp', 'UnaryExpSquare', 'UnaryErf']

if not version_larger_equal(torch.__version__, TorchVersion):
unary_modules.append('UnaryAsinh')
unary_modules.append('UnarySinc')

# ============== binary function modules ==============

@basic_unit
class BinaryAdd(nn.Module):
def forward(self, x):
return x[0] + x[1]

@basic_unit
class BinaryMul(nn.Module):
def forward(self, x):
return x[0] * x[1]

@basic_unit
class BinaryMinus(nn.Module):
def forward(self, x):
return x[0] - x[1]

@basic_unit
class BinaryDivide(nn.Module):
def forward(self, x):
return x[0] / (x[1] + 1e-7)

@basic_unit
class BinaryMax(nn.Module):
def forward(self, x):
return torch.max(x[0], x[1])

@basic_unit
class BinaryMin(nn.Module):
def forward(self, x):
return torch.min(x[0], x[1])

@basic_unit
class BinarySigmoid(nn.Module):
def forward(self, x):
return torch.sigmoid(x[0]) * x[1]

@basic_unit
class BinaryExpSquare(nn.Module):
def __init__(self):
super().__init__()
self.beta = torch.nn.Parameter(torch.tensor(1, dtype=torch.float32)) # pylint: disable=not-callable
def forward(self, x):
return torch.exp(-self.beta * torch.square(x[0] - x[1]))

@basic_unit
class BinaryExpAbs(nn.Module):
def __init__(self):
super().__init__()
self.beta = torch.nn.Parameter(torch.tensor(1, dtype=torch.float32)) # pylint: disable=not-callable
def forward(self, x):
return torch.exp(-self.beta * torch.abs(x[0] - x[1]))

@basic_unit
class BinaryParamAdd(nn.Module):
def __init__(self):
super().__init__()
self.beta = torch.nn.Parameter(torch.tensor(1, dtype=torch.float32)) # pylint: disable=not-callable
def forward(self, x):
return self.beta * x[0] + (1 - self.beta) * x[1]

binary_modules = ['BinaryAdd', 'BinaryMul', 'BinaryMinus', 'BinaryDivide', 'BinaryMax',
'BinaryMin', 'BinarySigmoid', 'BinaryExpSquare', 'BinaryExpAbs', 'BinaryParamAdd']


class AutoActivation(nn.Module):
"""
This module is an implementation of the paper "Searching for Activation Functions"
(https://arxiv.org/abs/1710.05941).
NOTE: currently `beta` is not a per-channel parameter

Parameters
----------
unit_num : int
the number of core units
"""
def __init__(self, unit_num = 1):
super().__init__()
self.unaries = nn.ModuleList()
self.binaries = nn.ModuleList()
self.first_unary = LayerChoice([eval('{}()'.format(unary)) for unary in unary_modules])
for _ in range(unit_num):
one_unary = LayerChoice([eval('{}()'.format(unary)) for unary in unary_modules])
self.unaries.append(one_unary)
for _ in range(unit_num):
one_binary = LayerChoice([eval('{}()'.format(binary)) for binary in binary_modules])
self.binaries.append(one_binary)

def forward(self, x):
out = self.first_unary(x)
for unary, binary in zip(self.unaries, self.binaries):
out = binary(torch.stack([out, unary(x)]))
return out
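
For intuition (an illustrative sketch, not part of the diff): each core unit computes binary(out, unary(x)), with out initialized to first_unary(x). With unit_num=1, fixing both unary slots to UnaryIdentity and the binary slot to BinarySigmoid recovers Swish, x * sigmoid(x):

import torch

x = torch.randn(2, 10)

# One fixed sample from the space with unit_num = 1:
#   first_unary = UnaryIdentity, unaries[0] = UnaryIdentity, binaries[0] = BinarySigmoid
out = x                                   # out = first_unary(x)
pair = torch.stack([out, x])              # binaries consume a stacked pair [out, unary(x)]
out = torch.sigmoid(pair[0]) * pair[1]    # BinarySigmoid: sigmoid(out) * unary(x)
assert torch.allclose(out, x * torch.sigmoid(x))   # equals Swish / SiLU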
4 changes: 2 additions & 2 deletions nni/retiarii/operation_def/torch_op_def.py
@@ -61,7 +61,7 @@ def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
# TODO: deal with all the types
if self.parameters['type'] == 'None':
return f'{output} = None'
elif self.parameters['type'] in ('int', 'float', 'bool', 'int[]'):
elif self.parameters['type'] in ('int', 'float', 'bool', 'int[]'): # 'Long()' ???
return f'{output} = {self.parameters["value"]}'
elif self.parameters['type'] == 'str':
str_val = self.parameters["value"]
@@ -171,7 +171,7 @@ class AtenTensors(PyTorchOperation):
'aten::ones_like', 'aten::zeros_like', 'aten::rand',
'aten::randn', 'aten::scalar_tensor', 'aten::new_full',
'aten::new_empty', 'aten::new_zeros', 'aten::arange',
'aten::tensor', 'aten::ones', 'aten::zeros']
'aten::tensor', 'aten::ones', 'aten::zeros', 'aten::as_tensor']

def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
schemas = torch._C._jit_get_schemas_for_operator(self.type)
35 changes: 35 additions & 0 deletions test/ut/retiarii/test_highlevel_apis.py
@@ -493,6 +493,23 @@ def forward(self, x):
model = mutator.bind_sampler(sampler).apply(model)
self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(1, 16)).size() == torch.Size([1, 64]))

def test_autoactivation(self):
class Net(nn.Module):
def __init__(self):
super().__init__()
self.act = nn.AutoActivation()

def forward(self, x):
return self.act(x)

raw_model, mutators = self._get_model_with_mutators(Net())
for _ in range(10):
sampler = EnumerateSampler()
model = raw_model
for mutator in mutators:
model = mutator.bind_sampler(sampler).apply(model)
self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(2, 10)).size() == torch.Size([2, 10]))


class Python(GraphIR):
def _get_converted_pytorch_model(self, model_ir):
@@ -544,3 +561,21 @@ def forward(self, x):
except InvalidMutation:
continue
self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(2, 10)).size() == torch.Size([2, 16]))

Reviewer comment (Contributor):
You can write one in the base class and Python will automatically inherit it.

def test_autoactivation(self):
@self.get_serializer()
class Net(nn.Module):
def __init__(self):
super().__init__()
self.act = nn.AutoActivation()

def forward(self, x):
return self.act(x)

raw_model, mutators = self._get_model_with_mutators(Net())
for _ in range(10):
sampler = EnumerateSampler()
model = raw_model
for mutator in mutators:
model = mutator.bind_sampler(sampler).apply(model)
self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(2, 10)).size() == torch.Size([2, 10]))
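
A sketch of the reviewer's suggestion above (assumptions: GraphIR is the shared TestCase base class as in this file, and get_serializer() is provided by it — the quoted Python-class test already calls it via self). The test is written once in the base class and both suites run it:

# Sketch only; names follow the quoted diff, imports as in the existing test module.
class GraphIR(unittest.TestCase):
    ...

    def test_autoactivation(self):
        @self.get_serializer()
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.act = nn.AutoActivation()

            def forward(self, x):
                return self.act(x)

        raw_model, mutators = self._get_model_with_mutators(Net())
        for _ in range(10):
            sampler = EnumerateSampler()
            model = raw_model
            for mutator in mutators:
                model = mutator.bind_sampler(sampler).apply(model)
            self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(2, 10)).size() == torch.Size([2, 10]))

# `class Python(GraphIR)` keeps its existing overrides and simply inherits this test,
# so the duplicated test_autoactivation in the Python class can be dropped.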