Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds No Head #218

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions autoPyTorch/pipeline/components/setup/network_head/no_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from typing import Any, Dict, Optional, Tuple, Union

from ConfigSpace.configuration_space import ConfigurationSpace
from ConfigSpace.hyperparameters import CategoricalHyperparameter

import numpy as np

from torch import nn

from autoPyTorch.pipeline.components.setup.network_head.base_network_head import NetworkHeadComponent
from autoPyTorch.pipeline.components.setup.network_head.utils import _activations
from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter


class NoHead(NetworkHeadComponent):
"""
Head which only adds a fully connected layer which takes the
output of the backbone as input and outputs the predictions.
Flattens any input in a array of shape [B, prod(input_shape)].
"""

def build_head(self, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...]) -> nn.Module:
layers = [nn.Flatten()]
in_features = np.prod(input_shape).item()
out_features = np.prod(output_shape).item()
layers.append(_activations[self.config["activation"]]())
layers.append(nn.Linear(in_features=in_features,
out_features=out_features))
return nn.Sequential(*layers)

@staticmethod
def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, Union[str, bool]]:
return {
'shortname': 'NoHead',
'name': 'NoHead',
'handles_tabular': True,
'handles_image': True,
'handles_time_series': True,
}

@staticmethod
def get_hyperparameter_search_space(
dataset_properties: Optional[Dict[str, str]] = None,
activation: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter="activation",
value_range=tuple(_activations.keys()),
default_value=list(_activations.keys())[0]),
) -> ConfigurationSpace:
cs = ConfigurationSpace()

add_hyperparameter(cs, activation, CategoricalHyperparameter)

return cs
2 changes: 1 addition & 1 deletion test/test_pipeline/components/setup/test_setup_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def backbone(request):
return request.param


@pytest.fixture(params=['fully_connected'])
@pytest.fixture(params=['fully_connected', 'no_head'])
def head(request):
return request.param

Expand Down