This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

[Retiarii] support hypermodule: autoactivation #3868

Merged
merged 10 commits on Jul 15, 2021
9 changes: 9 additions & 0 deletions docs/en_US/NAS/Hypermodules.rst
@@ -0,0 +1,9 @@
Hypermodules
============

A hypermodule is a (PyTorch) module that contains many architecture/hyperparameter candidates for that module. By using hypermodules in a user-defined model, NNI automatically finds the best architecture/hyperparameters of those hypermodules for the model. This follows the Retiarii design philosophy that users write a DNN model as a space.

Several hypermodules have been proposed in the NAS community, such as AutoActivation and AutoDropout. Some of them are implemented in the Retiarii framework.

.. autoclass:: nni.retiarii.nn.pytorch.AutoActivation
:members:
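
A minimal usage sketch is below (the surrounding layers, sizes, and the ``Net`` class name are illustrative only, not part of NNI's API): ``AutoActivation`` is placed in a user-defined model where a fixed activation such as ReLU would normally go.

.. code-block:: python

    import nni.retiarii.nn.pytorch as nn

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(32, 64)
            # AutoActivation is a hypermodule: it holds many candidate activation
            # functions, and NNI searches for the best one for this position.
            self.act = nn.AutoActivation(unit_num=1)
            self.fc2 = nn.Linear(64, 10)

        def forward(self, x):
            return self.fc2(self.act(self.fc1(x)))

Depending on the NNI version, the model may also need the usual Retiarii decorations (for example ``@model_wrapper``) before it can be used in an experiment.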
3 changes: 2 additions & 1 deletion docs/en_US/NAS/construct_space.rst
@@ -8,4 +8,5 @@ NNI provides powerful APIs for users to easily express model space (or search sp
:maxdepth: 1

Mutation Primitives <MutationPrimitives>
Customize Mutators <Mutators>
Hypermodule Lib <Hypermodules>
1 change: 1 addition & 0 deletions nni/retiarii/nn/pytorch/__init__.py
@@ -1,3 +1,4 @@
from .api import *
from .component import *
from .nn import *
from .hypermodule import *
249 changes: 249 additions & 0 deletions nni/retiarii/nn/pytorch/hypermodule.py
@@ -0,0 +1,249 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
import torch.nn as nn

from nni.retiarii.serializer import basic_unit

from .api import LayerChoice
from ...utils import version_larger_equal

__all__ = ['AutoActivation']

# torch.asinh (added in PyTorch 1.7) and torch.sinc (added in 1.8) are only
# registered when the installed PyTorch version provides them.
TorchVersion = '1.8.0'

# ============== unary function modules ==============

@basic_unit
class UnaryIdentity(nn.Module):
def forward(self, x):
return x

@basic_unit
class UnaryNegative(nn.Module):
def forward(self, x):
return -x

@basic_unit
class UnaryAbs(nn.Module):
def forward(self, x):
return torch.abs(x)

@basic_unit
class UnarySquare(nn.Module):
def forward(self, x):
return torch.square(x)

@basic_unit
class UnaryPow(nn.Module):
def forward(self, x):
return torch.pow(x, 3)

@basic_unit
class UnarySqrt(nn.Module):
def forward(self, x):
return torch.sqrt(x)

@basic_unit
class UnaryMul(nn.Module):
def __init__(self):
super().__init__()
# element-wise for now, will change to per-channel trainable parameter
self.beta = torch.nn.Parameter(torch.tensor(1, dtype=torch.float32)) # pylint: disable=not-callable
def forward(self, x):
return x * self.beta

@basic_unit
class UnaryAdd(nn.Module):
def __init__(self):
super().__init__()
# element-wise for now, will change to per-channel trainable parameter
self.beta = torch.nn.Parameter(torch.tensor(1, dtype=torch.float32)) # pylint: disable=not-callable
def forward(self, x):
return x + self.beta

@basic_unit
class UnaryLogAbs(nn.Module):
def forward(self, x):
return torch.log(torch.abs(x) + 1e-7)

@basic_unit
class UnaryExp(nn.Module):
def forward(self, x):
return torch.exp(x)

@basic_unit
class UnarySin(nn.Module):
def forward(self, x):
return torch.sin(x)

@basic_unit
class UnaryCos(nn.Module):
def forward(self, x):
return torch.cos(x)

@basic_unit
class UnarySinh(nn.Module):
def forward(self, x):
return torch.sinh(x)

@basic_unit
class UnaryCosh(nn.Module):
def forward(self, x):
return torch.cosh(x)

@basic_unit
class UnaryTanh(nn.Module):
def forward(self, x):
return torch.tanh(x)

if version_larger_equal(torch.__version__, TorchVersion):
@basic_unit
class UnaryAsinh(nn.Module):
def forward(self, x):
return torch.asinh(x)

@basic_unit
class UnaryAtan(nn.Module):
def forward(self, x):
return torch.atan(x)

if version_larger_equal(torch.__version__, TorchVersion):
@basic_unit
class UnarySinc(nn.Module):
def forward(self, x):
return torch.sinc(x)

@basic_unit
class UnaryMax(nn.Module):
def forward(self, x):
return torch.max(x, torch.zeros_like(x))

@basic_unit
class UnaryMin(nn.Module):
def forward(self, x):
return torch.min(x, torch.zeros_like(x))

@basic_unit
class UnarySigmoid(nn.Module):
def forward(self, x):
return torch.sigmoid(x)

@basic_unit
class UnaryLogExp(nn.Module):
def forward(self, x):
return torch.log(1 + torch.exp(x))

@basic_unit
class UnaryExpSquare(nn.Module):
def forward(self, x):
return torch.exp(-torch.square(x))

@basic_unit
class UnaryErf(nn.Module):
def forward(self, x):
return torch.erf(x)

unary_modules = ['UnaryIdentity', 'UnaryNegative', 'UnaryAbs', 'UnarySquare', 'UnaryPow',
'UnarySqrt', 'UnaryMul', 'UnaryAdd', 'UnaryLogAbs', 'UnaryExp', 'UnarySin', 'UnaryCos',
'UnarySinh', 'UnaryCosh', 'UnaryTanh', 'UnaryAtan', 'UnaryMax',
'UnaryMin', 'UnarySigmoid', 'UnaryLogExp', 'UnaryExpSquare', 'UnaryErf']

if version_larger_equal(torch.__version__, TorchVersion):
    unary_modules.append('UnaryAsinh')
    unary_modules.append('UnarySinc')

# ============== binary function modules ==============

@basic_unit
class BinaryAdd(nn.Module):
def forward(self, x):
return x[0] + x[1]

@basic_unit
class BinaryMul(nn.Module):
def forward(self, x):
return x[0] * x[1]

@basic_unit
class BinaryMinus(nn.Module):
def forward(self, x):
return x[0] - x[1]

@basic_unit
class BinaryDivide(nn.Module):
def forward(self, x):
return x[0] / (x[1] + 1e-7)

@basic_unit
class BinaryMax(nn.Module):
def forward(self, x):
return torch.max(x[0], x[1])

@basic_unit
class BinaryMin(nn.Module):
def forward(self, x):
return torch.min(x[0], x[1])

@basic_unit
class BinarySigmoid(nn.Module):
def forward(self, x):
return torch.sigmoid(x[0]) * x[1]

@basic_unit
class BinaryExpSquare(nn.Module):
def __init__(self):
super().__init__()
self.beta = torch.nn.Parameter(torch.tensor(1, dtype=torch.float32)) # pylint: disable=not-callable
def forward(self, x):
return torch.exp(-self.beta * torch.square(x[0] - x[1]))

@basic_unit
class BinaryExpAbs(nn.Module):
def __init__(self):
super().__init__()
self.beta = torch.nn.Parameter(torch.tensor(1, dtype=torch.float32)) # pylint: disable=not-callable
def forward(self, x):
return torch.exp(-self.beta * torch.abs(x[0] - x[1]))

@basic_unit
class BinaryParamAdd(nn.Module):
def __init__(self):
super().__init__()
self.beta = torch.nn.Parameter(torch.tensor(1, dtype=torch.float32)) # pylint: disable=not-callable
def forward(self, x):
return self.beta * x[0] + (1 - self.beta) * x[1]

binary_modules = ['BinaryAdd', 'BinaryMul', 'BinaryMinus', 'BinaryDivide', 'BinaryMax',
'BinaryMin', 'BinarySigmoid', 'BinaryExpSquare', 'BinaryExpAbs', 'BinaryParamAdd']


class AutoActivation(nn.Module):
"""
This module is an implementation of the paper "Searching for Activation Functions"
(https://arxiv.org/abs/1710.05941).
NOTE: the current `beta` is not a per-channel parameter.

Parameters
----------
unit_num : int
the number of core units
"""
def __init__(self, unit_num=1):
super().__init__()
self.unaries = nn.ModuleList()
self.binaries = nn.ModuleList()
self.first_unary = LayerChoice([eval('{}()'.format(unary)) for unary in unary_modules], label='one_unary')
Contributor: Suggest globals()[unary](), which is semantically clearer than eval.
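A sketch of the reviewer's suggestion (illustrative, not part of this diff): look each candidate class up by name in the module namespace instead of eval-ing a constructor string.

# globals() here is hypermodule.py's own namespace, where the Unary*/Binary*
# classes above are defined, so globals()[name]() instantiates each candidate.
self.first_unary = LayerChoice([globals()[name]() for name in unary_modules], label='one_unary')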

for _ in range(unit_num):
one_unary = LayerChoice([eval('{}()'.format(unary)) for unary in unary_modules], label='one_unary')
Contributor: I don't think they should have the same label?

Contributor (author): Yes, you are right; this was for easy testing. I will remove the label.

self.unaries.append(one_unary)
for _ in range(unit_num):
one_binary = LayerChoice([eval('{}()'.format(binary)) for binary in binary_modules])
self.binaries.append(one_binary)

def forward(self, x):
out = self.first_unary(x)
for unary, binary in zip(self.unaries, self.binaries):
out = binary(torch.stack([out, unary(x)]))
return out
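
To make the composition in `forward` concrete, here is a hand-built instance of a single core unit (`unit_num=1`) using the modules defined above (an illustrative sketch, not part of this diff): choosing UnaryIdentity, UnarySigmoid, and BinaryMul recovers Swish, x * sigmoid(x), one of the candidate activations this space contains.

import torch

# One core unit composed by hand: out = binary(stack([first_unary(x), unary(x)])).
first_unary = UnaryIdentity()
second_unary = UnarySigmoid()
binary_op = BinaryMul()

x = torch.randn(4, 8)
out = binary_op(torch.stack([first_unary(x), second_unary(x)]))
assert torch.allclose(out, x * torch.sigmoid(x))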
4 changes: 2 additions & 2 deletions nni/retiarii/operation_def/torch_op_def.py
@@ -61,7 +61,7 @@ def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_val
# TODO: deal with all the types
if self.parameters['type'] == 'None':
return f'{output} = None'
elif self.parameters['type'] in ('int', 'float', 'bool', 'int[]'): # 'Long()' ???
return f'{output} = {self.parameters["value"]}'
elif self.parameters['type'] == 'str':
str_val = self.parameters["value"]
@@ -171,7 +171,7 @@ class AtenTensors(PyTorchOperation):
'aten::ones_like', 'aten::zeros_like', 'aten::rand',
'aten::randn', 'aten::scalar_tensor', 'aten::new_full',
'aten::new_empty', 'aten::new_zeros', 'aten::arange',
'aten::tensor', 'aten::ones', 'aten::zeros', 'aten::as_tensor']

def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
schemas = torch._C._jit_get_schemas_for_operator(self.type)