diff --git a/autoPyTorch/pipeline/components/setup/network_head/no_head.py b/autoPyTorch/pipeline/components/setup/network_head/no_head.py new file mode 100644 index 000000000..f5cadb416 --- /dev/null +++ b/autoPyTorch/pipeline/components/setup/network_head/no_head.py @@ -0,0 +1,52 @@ +from typing import Any, Dict, Optional, Tuple, Union + +from ConfigSpace.configuration_space import ConfigurationSpace +from ConfigSpace.hyperparameters import CategoricalHyperparameter + +import numpy as np + +from torch import nn + +from autoPyTorch.pipeline.components.setup.network_head.base_network_head import NetworkHeadComponent +from autoPyTorch.pipeline.components.setup.network_head.utils import _activations +from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter + + +class NoHead(NetworkHeadComponent): + """ + Head which only adds a fully connected layer which takes the + output of the backbone as input and outputs the predictions. + Flattens any input in a array of shape [B, prod(input_shape)]. + """ + + def build_head(self, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...]) -> nn.Module: + layers = [nn.Flatten()] + in_features = np.prod(input_shape).item() + out_features = np.prod(output_shape).item() + layers.append(_activations[self.config["activation"]]()) + layers.append(nn.Linear(in_features=in_features, + out_features=out_features)) + return nn.Sequential(*layers) + + @staticmethod + def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, Union[str, bool]]: + return { + 'shortname': 'NoHead', + 'name': 'NoHead', + 'handles_tabular': True, + 'handles_image': True, + 'handles_time_series': True, + } + + @staticmethod + def get_hyperparameter_search_space( + dataset_properties: Optional[Dict[str, str]] = None, + activation: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter="activation", + value_range=tuple(_activations.keys()), + default_value=list(_activations.keys())[0]), + ) -> ConfigurationSpace: + cs = ConfigurationSpace() + + add_hyperparameter(cs, activation, CategoricalHyperparameter) + + return cs diff --git a/test/test_pipeline/components/setup/test_setup_networks.py b/test/test_pipeline/components/setup/test_setup_networks.py index 6826d7ef2..8070c0344 100644 --- a/test/test_pipeline/components/setup/test_setup_networks.py +++ b/test/test_pipeline/components/setup/test_setup_networks.py @@ -12,7 +12,7 @@ def backbone(request): return request.param -@pytest.fixture(params=['fully_connected']) +@pytest.fixture(params=['fully_connected', 'no_head']) def head(request): return request.param