Skip to content

Commit

Permalink
Fix running without args
Browse files Browse the repository at this point in the history
  • Loading branch information
ibro45 committed Jun 6, 2024
1 parent 7c2afdd commit 8e73be8
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions lighter/utils/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,14 @@ def validate_config(parser: ConfigParser) -> None:
ValueError: If there are invalid method names specified in the 'args' section.
"""
# Validate parser keys against structure
invalid_keys = set(parser.get().keys()) - set(CONFIG_STRUCTURE.keys()) - {"_meta_", "_requires_"}
root_keys = parser.get().keys()
invalid_keys = set(root_keys) - set(CONFIG_STRUCTURE.keys()) - {"_meta_", "_requires_"}
if invalid_keys:
raise ValueError(f"Invalid top-level config keys: {list(invalid_keys)}. Allowed keys: {list(CONFIG_STRUCTURE.keys())}")

# Validate that 'args' contains only valid Trainer/Tuner method names.
invalid_keys = set(parser.get("args").keys()) - set(TRAINER_METHOD_NAMES)
args_keys = parser.get("args", {}).keys()
invalid_keys = set(args_keys) - set(TRAINER_METHOD_NAMES)
if invalid_keys:
raise ValueError(f"Invalid method names in 'args': {invalid_keys}. Allowed methods are: {TRAINER_METHOD_NAMES}")

Expand All @@ -88,7 +90,7 @@ def run(method: str, **kwargs: Any):
# Get the main components from the parsed config.
system = parser.get_parsed_content("system")
trainer = parser.get_parsed_content("trainer")
trainer_method_args = parser.get_parsed_content(f"args#{method}")
trainer_method_args = parser.get_parsed_content(f"args#{method}", default={})

# Checks
if not isinstance(system, LighterSystem):
Expand Down

0 comments on commit 8e73be8

Please sign in to comment.