Skip to content

Commit

Permalink
added no head (#218)
Browse files Browse the repository at this point in the history
  • Loading branch information
ravinkohli committed Feb 28, 2022
1 parent 7f83ce3 commit a21c2e4
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 1 deletion.
52 changes: 52 additions & 0 deletions autoPyTorch/pipeline/components/setup/network_head/no_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from typing import Any, Dict, Optional, Tuple, Union

from ConfigSpace.configuration_space import ConfigurationSpace
from ConfigSpace.hyperparameters import CategoricalHyperparameter

import numpy as np

from torch import nn

from autoPyTorch.pipeline.components.setup.network_head.base_network_head import NetworkHeadComponent
from autoPyTorch.pipeline.components.setup.network_head.utils import _activations
from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter


class NoHead(NetworkHeadComponent):
"""
Head which only adds a fully connected layer which takes the
output of the backbone as input and outputs the predictions.
Flattens any input in a array of shape [B, prod(input_shape)].
"""

def build_head(self, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...]) -> nn.Module:
layers = [nn.Flatten()]
in_features = np.prod(input_shape).item()
out_features = np.prod(output_shape).item()
layers.append(_activations[self.config["activation"]]())
layers.append(nn.Linear(in_features=in_features,
out_features=out_features))
return nn.Sequential(*layers)

@staticmethod
def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, Union[str, bool]]:
return {
'shortname': 'NoHead',
'name': 'NoHead',
'handles_tabular': True,
'handles_image': True,
'handles_time_series': True,
}

@staticmethod
def get_hyperparameter_search_space(
dataset_properties: Optional[Dict[str, str]] = None,
activation: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter="activation",
value_range=tuple(_activations.keys()),
default_value=list(_activations.keys())[0]),
) -> ConfigurationSpace:
cs = ConfigurationSpace()

add_hyperparameter(cs, activation, CategoricalHyperparameter)

return cs
2 changes: 1 addition & 1 deletion test/test_pipeline/components/setup/test_setup_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def backbone(request):
return request.param


@pytest.fixture(params=['fully_connected'])
@pytest.fixture(params=['fully_connected', 'no_head'])
def head(request):
return request.param

Expand Down

0 comments on commit a21c2e4

Please sign in to comment.