Skip to content

Commit

Permalink
Tests for package behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
omry committed May 12, 2020
1 parent f7ce3f4 commit 0cdf73f
Show file tree
Hide file tree
Showing 16 changed files with 165 additions and 33 deletions.
10 changes: 4 additions & 6 deletions hydra/_internal/config_loader_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,13 +282,11 @@ def record_loading(
)
if not ret.is_schema_source:
try:
schema = ConfigStore.instance().load(
config_path=ConfigSource._normalize_file_name(
filename=input_file
)
)
schema_source = self.repository.get_schema_source()
config_path = ConfigSource._normalize_file_name(filename=input_file)
schema = schema_source.load_config(config_path)

merged = OmegaConf.merge(schema.node, ret.config)
merged = OmegaConf.merge(schema.config, ret.config)
assert isinstance(merged, DictConfig)
return (
merged,
Expand Down
8 changes: 8 additions & 0 deletions hydra/_internal/config_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ def __init__(self, config_search_path: ConfigSearchPath) -> None:
source = source_type(search_path.provider, search_path.path)
self.sources.append(source)

def get_schema_source(self):
source = self.sources[-1] # should always be last
assert (
source.__class__.__name__ == "StructuredConfigSource"
and source.provider == "schema"
)
return source

def load_config(self, config_path: str) -> Optional[ConfigResult]:
source = self._find_config(config_path=config_path)
ret = None
Expand Down
9 changes: 6 additions & 3 deletions hydra/_internal/core_plugins/file_config_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,19 @@ def scheme() -> str:
return "file"

def load_config(self, config_path: str) -> ConfigResult:
config_path = self._normalize_file_name(config_path)
full_path = os.path.realpath(os.path.join(self.path, config_path))
normalized_config_path = self._normalize_file_name(config_path)
full_path = os.path.realpath(os.path.join(self.path, normalized_config_path))
if not os.path.exists(full_path):
raise ConfigLoadError(f"FileConfigSource: Config not found : {full_path}")
with open(full_path) as f:
header_text = f.read(512)
header = ConfigSource._get_header_dict(header_text)
self._update_package_in_header(header, normalized_config_path)
f.seek(0)
cfg = OmegaConf.load(f)
cfg = self._embed_config(cfg, header["package"])
return ConfigResult(
config=OmegaConf.load(f),
config=cfg,
path=f"{self.scheme()}://{self.path}",
provider=self.provider,
header=header,
Expand Down
5 changes: 4 additions & 1 deletion hydra/_internal/core_plugins/package_config_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,12 @@ def load_config(self, config_path: str) -> ConfigResult:
with resource_stream(module_name, resource_name) as stream:
header_text = stream.read(512)
header = ConfigSource._get_header_dict(header_text.decode())
self._update_package_in_header(header, config_path)
stream.seek(0)
cfg = OmegaConf.load(stream)
cfg = self._embed_config(cfg, header["package"])
return ConfigResult(
config=OmegaConf.load(stream),
config=cfg,
path=f"{self.scheme()}://{self.path}",
provider=self.provider,
header=header,
Expand Down
4 changes: 3 additions & 1 deletion hydra/_internal/core_plugins/structured_config_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@ def load_config(self, config_path: str) -> ConfigResult:
header = {}
if ret.package:
header["package"] = ret.package
self._update_package_in_header(header, full_path)
cfg = self._embed_config(ret.node, header["package"])
return ConfigResult(
config=ret.node,
config=cfg,
path=f"{self.scheme()}://{self.path}",
provider=provider,
header=header,
Expand Down
23 changes: 13 additions & 10 deletions hydra/core/config_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,15 @@ def __exit__(self, exc_type: Any, exc_value: Any, exc_traceback: Any) -> Any:
@dataclass
class ConfigNode:
name: str
node: Any
node: Any # TODO: DictConfig or Container?
group: Optional[str]
package: Optional[str]
provider: Optional[str]


NO_DEFAULT_PACKAGE = str("_NO_DEFAULT_PACKAGE_")


class ConfigStore(metaclass=Singleton):
@staticmethod
def instance(*args: Any, **kwargs: Any) -> "ConfigStore":
Expand All @@ -56,7 +59,7 @@ def store(
name: str,
node: Any,
group: Optional[str] = None,
package: Optional[str] = None,
package: Optional[str] = NO_DEFAULT_PACKAGE,
provider: Optional[str] = None,
) -> None:
"""
Expand All @@ -67,25 +70,25 @@ def store(
:param package: Config node parent hierarchy. child separator is '.', for example foo.bar.baz
:param provider: the name of the module/app providing this config. Helps debugging.
"""

if package == NO_DEFAULT_PACKAGE:
package = "_global_"
# TODO: warn the user if we are defaulting
# to _global_ and they should make an explicit selection recommended _group_.

cur = self.repo
if group is not None:
for d in group.split("/"):
if d not in cur:
cur[d] = {}
cur = cur[d]

if package is not None and package != "":
cfg = OmegaConf.create()
OmegaConf.update(cfg, package, OmegaConf.structured(node))
else:
cfg = OmegaConf.structured(node)

if not name.endswith(".yaml"):
name = f"{name}.yaml"
assert isinstance(cur, dict)
cfg_copy = copy.deepcopy(cfg)
cfg = OmegaConf.structured(node)
cur[name] = ConfigNode(
name=name, node=cfg_copy, group=group, package=package, provider=provider
name=name, node=cfg, group=group, package=package, provider=provider,
)

def load(self, config_path: str) -> ConfigNode:
Expand Down
42 changes: 41 additions & 1 deletion hydra/plugins/config_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass
from typing import List, Optional, Dict

from omegaconf import Container
from omegaconf import Container, OmegaConf

from hydra.core.object_type import ObjectType
from hydra.plugins.plugin import Plugin
Expand Down Expand Up @@ -110,6 +110,45 @@ def _normalize_file_name(filename: str) -> str:
filename += ".yaml"
return filename

@staticmethod
def _update_package_in_header(header, normalized_config_path):
normalized_config_path = normalized_config_path[0 : -len(".yaml")]
last = normalized_config_path.rfind("/")
if last == -1:
group = ""
name = normalized_config_path
else:
group = normalized_config_path[0:last]
name = normalized_config_path[last + 1 :]

if "package" not in header:
header["package"] = "_global_"
# TODO: warn the user if we are defaulting
# to _global_ and they should make an explicit selection recommended _group_.

package = header["package"]

if package == "_global_":
# default to the global behavior to remain backward compatible.
package = ""
else:
package = package.replace("_group_", group)
package = package.replace("_name_", name)

header["package"] = package

@staticmethod
def _embed_config(node: Container, package: str):
if package == "_global_":
package = ""

if package is not None and package != "":
cfg = OmegaConf.create()
OmegaConf.update(cfg, package, OmegaConf.structured(node))
else:
cfg = OmegaConf.structured(node)
return cfg

@staticmethod
def _get_header_dict(config_text: str) -> Dict[str, str]:
res = {}
Expand All @@ -134,4 +173,5 @@ def _get_header_dict(config_text: str) -> Dict[str, str]:
else:
# stop parsing header on first non-header line
break

return res
41 changes: 41 additions & 0 deletions hydra/test_utils/config_source_common_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,44 @@ def test_source_load_config(
else:
ret = src.load_config(config_path=config_path)
assert ret.config == expected

@pytest.mark.parametrize(
"config_path, expected_result, expected_package",
[
pytest.param("package_test/none", {"foo": "bar"}, "", id="none"),
pytest.param(
"package_test/explicit",
{"a": {"b": {"foo": "bar"}}},
"a.b",
id="explicit",
),
pytest.param("package_test/global", {"foo": "bar"}, "", id="global"),
pytest.param(
"package_test/group",
{"package_test": {"foo": "bar"}},
"package_test",
id="group",
),
pytest.param(
"package_test/group_name",
{"foo": {"package_test": {"group_name": {"foo": "bar"}}}},
"foo.package_test.group_name",
id="group_name",
),
pytest.param(
"package_test/name", {"name": {"foo": "bar"}}, "name", id="name"
),
],
)
def test_package_behavior(
self,
type_: Type[ConfigSource],
path: str,
config_path: str,
expected_result: Any,
expected_package: str,
) -> None:
src = type_(provider="foo", path=path)
cfg = src.load_config(config_path=config_path)
assert cfg.header["package"] == expected_package
assert cfg.config == expected_result
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# @package a.b
foo: bar
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# @package _global_
foo: bar
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# @package _group_
foo: bar
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# @package foo._group_._name_
foo: bar
2 changes: 2 additions & 0 deletions tests/test_apps/config_source_test/dir/package_test/name.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# @package _name_
foo: bar
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
foo: bar
25 changes: 19 additions & 6 deletions tests/test_apps/config_source_test/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,27 @@ class Optimizer:
s = ConfigStore.instance()
s.store(name="config_without_group", node=ConfigWithoutGroup)
s.store(name="dataset", node={"dataset_yaml": True})
s.store(name="cifar10", node=Cifar10, group="dataset", package="dataset")
s.store(name="imagenet.yaml", node=ImageNet, group="dataset", package="dataset")
s.store(name="adam", node=Adam, group="optimizer", package="optimizer")
s.store(name="nesterov", node=Nesterov, group="optimizer", package="optimizer")
s.store(group="dataset", name="cifar10", node=Cifar10, package="dataset")
s.store(group="dataset", name="imagenet.yaml", node=ImageNet, package="dataset")
s.store(group="optimizer", name="adam", node=Adam, package="optimizer")
s.store(group="optimizer", name="nesterov", node=Nesterov, package="optimizer")
s.store(
name="nested1", node={"l1_l2_n1": True}, group="level1/level2", package="optimizer"
group="level1/level2", name="nested1", node={"l1_l2_n1": True}, package="optimizer"
)

s.store(
name="nested2", node={"l1_l2_n2": True}, group="level1/level2", package="optimizer"
group="level1/level2", name="nested2", node={"l1_l2_n2": True}, package="optimizer"
)


s.store(group="package_test", name="none", node={"foo": "bar"})
s.store(group="package_test", name="explicit", node={"foo": "bar"}, package="a.b")
s.store(group="package_test", name="global", node={"foo": "bar"}, package="_global_")
s.store(group="package_test", name="group", node={"foo": "bar"}, package="_group_")
s.store(
group="package_test",
name="group_name",
node={"foo": "bar"},
package="foo._group_._name_",
)
s.store(group="package_test", name="name", node={"foo": "bar"}, package="_name_")
20 changes: 15 additions & 5 deletions tests/test_config_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,17 @@
pytest.param(
FileConfigSource,
"file://tests/test_apps/config_source_test/dir",
id="file://",
id="FileConfigSource",
),
pytest.param(
PackageConfigSource,
"pkg://tests.test_apps.config_source_test.dir",
id="pkg://",
id="PackageConfigSource",
),
pytest.param(
StructuredConfigSource,
"structured://tests.test_apps.config_source_test.structured",
id="structured://",
id="StructuredConfigSource",
),
],
)
Expand Down Expand Up @@ -77,8 +77,18 @@ def test_config_repository_exists(self, restore_singletons: Any, path: str) -> N
@pytest.mark.parametrize( # type: ignore
"config_path,results_filter,expected",
[
("", None, ["config_without_group", "dataset", "level1", "optimizer"]),
("", ObjectType.GROUP, ["dataset", "level1", "optimizer"]),
(
"",
None,
[
"config_without_group",
"dataset",
"level1",
"optimizer",
"package_test",
],
),
("", ObjectType.GROUP, ["dataset", "level1", "optimizer", "package_test"]),
("", ObjectType.CONFIG, ["config_without_group", "dataset"]),
("dataset", None, ["cifar10", "imagenet"]),
("dataset", ObjectType.GROUP, []),
Expand Down

0 comments on commit 0cdf73f

Please sign in to comment.