diff --git a/docs/en_US/NAS/Hypermodules.rst b/docs/en_US/NAS/Hypermodules.rst
new file mode 100644
index 0000000000..e87bf34725
--- /dev/null
+++ b/docs/en_US/NAS/Hypermodules.rst
@@ -0,0 +1,9 @@
+Hypermodules
+============
+
+A hypermodule is a (PyTorch) module that contains many architecture/hyperparameter candidates for the module it implements. By using hypermodules in a user-defined model, NNI automatically finds the best architecture/hyperparameter of each hypermodule for the model. This follows the design philosophy of Retiarii: users write a DNN model as a space.
+
+Several hypermodules have been proposed in the NAS community, such as AutoActivation and AutoDropout. Some of them are implemented in the Retiarii framework.
+
+.. autoclass:: nni.retiarii.nn.pytorch.AutoActivation
+    :members:
\ No newline at end of file
diff --git a/docs/en_US/NAS/construct_space.rst b/docs/en_US/NAS/construct_space.rst
index b32489d4a7..362bb446ea 100644
--- a/docs/en_US/NAS/construct_space.rst
+++ b/docs/en_US/NAS/construct_space.rst
@@ -8,4 +8,5 @@ NNI provides powerful APIs for users to easily express model space (or search sp
     :maxdepth: 1
 
     Mutation Primitives
-    Customize Mutators
\ No newline at end of file
+    Customize Mutators
+    Hypermodule Lib
\ No newline at end of file
diff --git a/nni/retiarii/nn/pytorch/__init__.py b/nni/retiarii/nn/pytorch/__init__.py
index 5c392164b1..bcc8c45f3f 100644
--- a/nni/retiarii/nn/pytorch/__init__.py
+++ b/nni/retiarii/nn/pytorch/__init__.py
@@ -1,3 +1,4 @@
 from .api import *
 from .component import *
 from .nn import *
+from .hypermodule import *
\ No newline at end of file
diff --git a/nni/retiarii/nn/pytorch/hypermodule.py b/nni/retiarii/nn/pytorch/hypermodule.py
new file mode 100644
index 0000000000..86cd00ea93
--- /dev/null
+++ b/nni/retiarii/nn/pytorch/hypermodule.py
@@ -0,0 +1,249 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import torch
+import torch.nn as nn
+
+from nni.retiarii.serializer import basic_unit
+
+from .api import LayerChoice
+from ...utils import version_larger_equal
+
+__all__ = ['AutoActivation']
+
+TorchVersion = '1.5.0'
+
+# ============== unary function modules ==============
+
+@basic_unit
+class UnaryIdentity(nn.Module):
+    def forward(self, x):
+        return x
+
+@basic_unit
+class UnaryNegative(nn.Module):
+    def forward(self, x):
+        return -x
+
+@basic_unit
+class UnaryAbs(nn.Module):
+    def forward(self, x):
+        return torch.abs(x)
+
+@basic_unit
+class UnarySquare(nn.Module):
+    def forward(self, x):
+        return torch.square(x)
+
+@basic_unit
+class UnaryPow(nn.Module):
+    def forward(self, x):
+        return torch.pow(x, 3)
+
+@basic_unit
+class UnarySqrt(nn.Module):
+    def forward(self, x):
+        return torch.sqrt(x)
+
+@basic_unit
+class UnaryMul(nn.Module):
+    def __init__(self):
+        super().__init__()
+        # element-wise for now, will change to per-channel trainable parameter
+        self.beta = torch.nn.Parameter(torch.tensor(1, dtype=torch.float32))  # pylint: disable=not-callable
+    def forward(self, x):
+        return x * self.beta
+
+@basic_unit
+class UnaryAdd(nn.Module):
+    def __init__(self):
+        super().__init__()
+        # element-wise for now, will change to per-channel trainable parameter
+        self.beta = torch.nn.Parameter(torch.tensor(1, dtype=torch.float32))  # pylint: disable=not-callable
+    def forward(self, x):
+        return x + self.beta
+
+@basic_unit
+class UnaryLogAbs(nn.Module):
+    def forward(self, x):
+        return torch.log(torch.abs(x) + 1e-7)
+
+@basic_unit
+class UnaryExp(nn.Module):
+    def forward(self, x):
+        return torch.exp(x)
+
+@basic_unit
+class UnarySin(nn.Module):
+    def forward(self, x):
+        return torch.sin(x)
+
+@basic_unit
+class UnaryCos(nn.Module):
+    def forward(self, x):
+        return torch.cos(x)
+
+@basic_unit
+class UnarySinh(nn.Module):
+    def forward(self, x):
+        return torch.sinh(x)
+
+@basic_unit
+class UnaryCosh(nn.Module):
+    def forward(self, x):
+        return torch.cosh(x)
+
+@basic_unit
+class UnaryTanh(nn.Module):
+    def forward(self, x):
+        return torch.tanh(x)
+
+if not version_larger_equal(torch.__version__, TorchVersion):
+    @basic_unit
+    class UnaryAsinh(nn.Module):
+        def forward(self, x):
+            return torch.asinh(x)
+
+@basic_unit
+class UnaryAtan(nn.Module):
+    def forward(self, x):
+        return torch.atan(x)
+
+if not version_larger_equal(torch.__version__, TorchVersion):
+    @basic_unit
+    class UnarySinc(nn.Module):
+        def forward(self, x):
+            return torch.sinc(x)
+
+@basic_unit
+class UnaryMax(nn.Module):
+    def forward(self, x):
+        return torch.max(x, torch.zeros_like(x))
+
+@basic_unit
+class UnaryMin(nn.Module):
+    def forward(self, x):
+        return torch.min(x, torch.zeros_like(x))
+
+@basic_unit
+class UnarySigmoid(nn.Module):
+    def forward(self, x):
+        return torch.sigmoid(x)
+
+@basic_unit
+class UnaryLogExp(nn.Module):
+    def forward(self, x):
+        return torch.log(1 + torch.exp(x))
+
+@basic_unit
+class UnaryExpSquare(nn.Module):
+    def forward(self, x):
+        return torch.exp(-torch.square(x))
+
+@basic_unit
+class UnaryErf(nn.Module):
+    def forward(self, x):
+        return torch.erf(x)
+
+unary_modules = ['UnaryIdentity', 'UnaryNegative', 'UnaryAbs', 'UnarySquare', 'UnaryPow',
+                 'UnarySqrt', 'UnaryMul', 'UnaryAdd', 'UnaryLogAbs', 'UnaryExp', 'UnarySin', 'UnaryCos',
+                 'UnarySinh', 'UnaryCosh', 'UnaryTanh', 'UnaryAtan', 'UnaryMax',
+                 'UnaryMin', 'UnarySigmoid', 'UnaryLogExp', 'UnaryExpSquare', 'UnaryErf']
+
+if not version_larger_equal(torch.__version__, TorchVersion):
+    unary_modules.append('UnaryAsinh')
+    unary_modules.append('UnarySinc')
+
+# ============== binary function modules ==============
+
+@basic_unit
+class BinaryAdd(nn.Module):
+    def forward(self, x):
+        return x[0] + x[1]
+
+@basic_unit
+class BinaryMul(nn.Module):
+    def forward(self, x):
+        return x[0] * x[1]
+
+@basic_unit
+class BinaryMinus(nn.Module):
+    def forward(self, x):
+        return x[0] - x[1]
+
+@basic_unit
+class BinaryDivide(nn.Module):
+    def forward(self, x):
+        return x[0] / (x[1] + 1e-7)
+
+@basic_unit
+class BinaryMax(nn.Module):
+    def forward(self, x):
+        return torch.max(x[0], x[1])
+
+@basic_unit
+class BinaryMin(nn.Module):
+    def forward(self, x):
+        return torch.min(x[0], x[1])
+
+@basic_unit
+class BinarySigmoid(nn.Module):
+    def forward(self, x):
+        return torch.sigmoid(x[0]) * x[1]
+
+@basic_unit
+class BinaryExpSquare(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.beta = torch.nn.Parameter(torch.tensor(1, dtype=torch.float32))  # pylint: disable=not-callable
+    def forward(self, x):
+        return torch.exp(-self.beta * torch.square(x[0] - x[1]))
+
+@basic_unit
+class BinaryExpAbs(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.beta = torch.nn.Parameter(torch.tensor(1, dtype=torch.float32))  # pylint: disable=not-callable
+    def forward(self, x):
+        return torch.exp(-self.beta * torch.abs(x[0] - x[1]))
+
+@basic_unit
+class BinaryParamAdd(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.beta = torch.nn.Parameter(torch.tensor(1, dtype=torch.float32))  # pylint: disable=not-callable
+    def forward(self, x):
+        return self.beta * x[0] + (1 - self.beta) * x[1]
+
+binary_modules = ['BinaryAdd', 'BinaryMul', 'BinaryMinus', 'BinaryDivide', 'BinaryMax',
+                  'BinaryMin', 'BinarySigmoid', 'BinaryExpSquare', 'BinaryExpAbs', 'BinaryParamAdd']
+
+
+class AutoActivation(nn.Module):
+    """
+    This module is an implementation of the paper "Searching for Activation Functions"
+    (https://arxiv.org/abs/1710.05941).
+
+    NOTE: ``beta`` is currently a single scalar, not a per-channel parameter.
+
+    Parameters
+    ----------
+    unit_num : int
+        The number of core units.
+    """
+    def __init__(self, unit_num=1):
+        super().__init__()
+        self.unaries = nn.ModuleList()
+        self.binaries = nn.ModuleList()
+        self.first_unary = LayerChoice([eval('{}()'.format(unary)) for unary in unary_modules])
+        for _ in range(unit_num):
+            one_unary = LayerChoice([eval('{}()'.format(unary)) for unary in unary_modules])
+            self.unaries.append(one_unary)
+        for _ in range(unit_num):
+            one_binary = LayerChoice([eval('{}()'.format(binary)) for binary in binary_modules])
+            self.binaries.append(one_binary)
+
+    def forward(self, x):
+        out = self.first_unary(x)
+        for unary, binary in zip(self.unaries, self.binaries):
+            out = binary(torch.stack([out, unary(x)]))
+        return out
diff --git a/nni/retiarii/operation_def/torch_op_def.py b/nni/retiarii/operation_def/torch_op_def.py
index 313a5558af..bb3a01e546 100644
--- a/nni/retiarii/operation_def/torch_op_def.py
+++ b/nni/retiarii/operation_def/torch_op_def.py
@@ -61,7 +61,7 @@ def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_val
         # TODO: deal with all the types
         if self.parameters['type'] in ['None', 'NoneType']:
             return f'{output} = None'
-        elif self.parameters['type'] in ('int', 'float', 'bool', 'int[]'):
+        elif self.parameters['type'] in ('int', 'float', 'bool', 'int[]'):  # TODO: should 'Long()' be handled here as well?
             return f'{output} = {self.parameters["value"]}'
         elif self.parameters['type'] == 'str':
             str_val = self.parameters["value"]
@@ -171,7 +171,7 @@ class AtenTensors(PyTorchOperation):
                       'aten::ones_like', 'aten::zeros_like', 'aten::rand', 'aten::randn',
                       'aten::scalar_tensor', 'aten::new_full', 'aten::new_empty', 'aten::new_zeros',
                       'aten::arange',
-                      'aten::tensor', 'aten::ones', 'aten::zeros']
+                      'aten::tensor', 'aten::ones', 'aten::zeros', 'aten::as_tensor']
 
     def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
         schemas = torch._C._jit_get_schemas_for_operator(self.type)
diff --git a/test/ut/retiarii/test_highlevel_apis.py b/test/ut/retiarii/test_highlevel_apis.py
index 49d4944a6b..aabf3deff2 100644
--- a/test/ut/retiarii/test_highlevel_apis.py
+++ b/test/ut/retiarii/test_highlevel_apis.py
@@ -514,6 +514,24 @@ def forward(self, x):
             model = mutator.bind_sampler(sampler).apply(model)
         self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(2, 10)).size() == torch.Size([2, 16]))
 
+    def test_autoactivation(self):
+        @self.get_serializer()
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.act = nn.AutoActivation()
+
+            def forward(self, x):
+                return self.act(x)
+
+        raw_model, mutators = self._get_model_with_mutators(Net())
+        for _ in range(10):
+            sampler = EnumerateSampler()
+            model = raw_model
+            for mutator in mutators:
+                model = mutator.bind_sampler(sampler).apply(model)
+            self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(2, 10)).size() == torch.Size([2, 10]))
+
 
 class Python(GraphIR):
     def _get_converted_pytorch_model(self, model_ir):
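For reviewers, a minimal usage sketch of the new hypermodule. The model below and the `@model_wrapper` decorator are illustrative only and not part of this patch (the top-level model-space decorator differs across NNI releases); the point is simply that `AutoActivation` drops in wherever a fixed activation such as `nn.ReLU` would go, and its internal `LayerChoice` candidates then become part of the Retiarii model space.

```python
# Illustrative sketch only -- not part of this patch. Assumes the high-level
# Retiarii API (nni.retiarii.nn.pytorch as the nn namespace, model_wrapper from
# nni.retiarii); adjust to the API of your NNI version.
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import model_wrapper


@model_wrapper
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 16)
        # Searchable activation: one core unit, i.e. binary(first_unary(x), unary(x)).
        self.act = nn.AutoActivation(unit_num=1)
        self.fc2 = nn.Linear(16, 2)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))
```

An exploration strategy then samples the unary/binary candidates exactly like any other `LayerChoice`, which is what `test_autoactivation` above exercises through the mutator/sampler path.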
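It may also help to spell out what the searched space contains. With `unit_num=1` the forward pass computes `binary(first_unary(x), unary(x))`, so well-known activations are single points in the space; the snippet below, using the module path introduced by this diff, checks that choosing identity, sigmoid and multiplication recovers Swish.

```python
# Sanity check of one point in the AutoActivation space: identity * sigmoid == Swish.
# Uses classes added by this patch; the stacked-input convention matches
# AutoActivation.forward, which calls binary(torch.stack([out, unary(x)])).
import torch
from nni.retiarii.nn.pytorch.hypermodule import UnaryIdentity, UnarySigmoid, BinaryMul

x = torch.randn(8)
out = BinaryMul()(torch.stack([UnaryIdentity()(x), UnarySigmoid()(x)]))
assert torch.allclose(out, x * torch.sigmoid(x))
```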