Fix config structure enforcing and typechecking. Add full Tuner support. (#133)

* Fix config structure enforcing and typechecking. Add full Tuner support.

* Fix config loading in parse_config()

* Improve error message when CLI command is incorrect

* Reduce system's learning_rate property docs
ibro45 authored Aug 1, 2024
1 parent b10b8c6 commit 66fd0c4
Showing 3 changed files with 107 additions and 47 deletions.
14 changes: 14 additions & 0 deletions lighter/system.py
@@ -361,6 +361,20 @@ def setup(self, stage: str) -> None:
         self.predict_dataloader = partial(self._base_dataloader, mode="predict")
         self.predict_step = partial(self._base_step, mode="predict")
 
+    @property
+    def learning_rate(self) -> float:
+        """Get the learning rate of the optimizer. Ensures compatibility with the Tuner's 'lr_find()' method."""
+        if len(self.optimizer.param_groups) > 1:
+            raise ValueError("The learning rate is not available when there are multiple optimizer parameter groups.")
+        return self.optimizer.param_groups[0]["lr"]
+
+    @learning_rate.setter
+    def learning_rate(self, value) -> None:
+        """Set the learning rate of the optimizer. Ensures compatibility with the Tuner's 'lr_find()' method."""
+        if len(self.optimizer.param_groups) > 1:
+            raise ValueError("The learning rate cannot be set when there are multiple optimizer parameter groups.")
+        self.optimizer.param_groups[0]["lr"] = value
+
     def _init_placeholders_for_dataloader_and_step_methods(self) -> None:
         """
         Initializes placeholders for dataloader and step methods.
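
For context, a minimal sketch of how this property is exercised (hedged: `system` stands for any configured LighterSystem instance; PyTorch Lightning 2.x Tuner API):

# Sketch, not part of the commit: lr_find() reads and writes `learning_rate`
# on the module it tunes, which is what the property above exposes.
from pytorch_lightning import Trainer
from pytorch_lightning.tuner import Tuner

trainer = Trainer(max_epochs=1)
tuner = Tuner(trainer)
lr_finder = tuner.lr_find(system)  # probes candidate rates via the setter
system.learning_rate = lr_finder.suggestion()  # adopt the suggested rate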
16 changes: 13 additions & 3 deletions lighter/utils/misc.py
@@ -119,13 +119,23 @@ def apply_fns(data: Any, fns: Union[Callable, List[Callable]]) -> Any:
 
 def get_optimizer_stats(optimizer: Optimizer) -> Dict[str, float]:
     """
-    Extract learning rates and momentum values from each parameter group of the optimizer.
+    Extract learning rates and momentum values from an optimizer into a dictionary.
+
+    This function iterates over the parameter groups of the given optimizer and collects
+    the learning rate and momentum (or beta values) for each group. The collected values
+    are stored in a dictionary with keys formatted to indicate the optimizer type and
+    parameter group index (if multiple groups are present).
 
     Args:
-        optimizer (Optimizer): A PyTorch optimizer.
+        optimizer (Optimizer): A PyTorch optimizer instance.
 
     Returns:
-        Dictionary with formatted keys and values for learning rates and momentum.
+        Dict[str, float]: A dictionary containing the learning rates and momentum values
+            for each parameter group in the optimizer. The keys are formatted as:
+            - "optimizer/{optimizer_class_name}/lr" for learning rates
+            - "optimizer/{optimizer_class_name}/momentum" for momentum values
+            If there are multiple parameter groups, the keys will include the group index, e.g.,
+            "optimizer/{optimizer_class_name}/lr/group1".
     """
     stats_dict = {}
     for group_idx, group in enumerate(optimizer.param_groups):
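
A hedged illustration of the documented key format (single parameter group, so no group suffix, per the docstring above):

import torch
from torch.optim import SGD

from lighter.utils.misc import get_optimizer_stats

model = torch.nn.Linear(4, 2)
optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
print(get_optimizer_stats(optimizer))
# Expected, per the docstring: {"optimizer/SGD/lr": 0.01, "optimizer/SGD/momentum": 0.9}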
124 changes: 80 additions & 44 deletions lighter/utils/runner.py
@@ -1,22 +1,45 @@
 from typing import Any
 
+import copy
 from functools import partial
 
 import fire
 from monai.bundle.config_parser import ConfigParser
-from pytorch_lightning import seed_everything
+from pytorch_lightning import Trainer, seed_everything
+from pytorch_lightning.tuner import Tuner
 
 from lighter.system import LighterSystem
 from lighter.utils.dynamic_imports import import_module_from_path
 
-CONFIG_STRUCTURE = {"project": None, "system": {}, "trainer": {}, "args": {}, "vars": {}}
-TRAINER_METHOD_NAMES = ["fit", "validate", "test", "predict", "lr_find", "scale_batch_size"]
+CONFIG_STRUCTURE = {
+    "project": None,
+    "vars": {},
+    "args": {
+        # Keys - names of the methods; values - arguments passed to them.
+        "fit": {},
+        "validate": {},
+        "test": {},
+        "predict": {},
+        "lr_find": {},
+        "scale_batch_size": {},
+    },
+    "system": {},
+    "trainer": {},
+}
 
 
 def cli() -> None:
     """Defines the command line interface for running lightning trainer's methods."""
-    commands = {method: partial(run, method) for method in TRAINER_METHOD_NAMES}
-    fire.Fire(commands)
+    commands = {method: partial(run, method) for method in CONFIG_STRUCTURE["args"]}
+    try:
+        fire.Fire(commands)
+    except TypeError as e:
+        if "run() takes 1 positional argument but" in str(e):
+            raise ValueError(
+                "Ensure that only one command is run at a time (e.g., 'lighter fit') and that "
+                "other command line arguments start with '--' (e.g., '--config', '--system#batch_size=1')."
+            ) from e
+        raise
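
For reference, the invocation shape this error guards (config path is hypothetical):

# Valid:   lighter fit --config=config.yaml --system#batch_size=1
# Invalid: lighter fit validate  (two commands in one call)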


@@ -29,25 +52,24 @@ def parse_config(**kwargs) -> ConfigParser:
     Returns:
         An instance of ConfigParser with configuration and overrides merged and parsed.
     """
-    # Ensure a config file is specified.
     config = kwargs.pop("config", None)
     if config is None:
         raise ValueError("'--config' not specified. Please provide a valid configuration file.")
 
-    # Read the config file and update it with overrides.
-    parser = ConfigParser(CONFIG_STRUCTURE, globals=False)
-    parser.read_config(config)
+    # Create a deep copy to ensure the original structure remains unaltered by ConfigParser.
+    structure = copy.deepcopy(CONFIG_STRUCTURE)
+    # Initialize the parser with the predefined structure.
+    parser = ConfigParser(structure, globals=False)
+    # Update the parser with the configuration file.
+    parser.update(parser.load_config_files(config))
+    # Update the parser with the provided cli arguments.
     parser.update(kwargs)
     return parser
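
A usage sketch of the new loading path (hypothetical file name; '#' override keys follow MONAI Bundle notation, passed here the same way fire forwards CLI flags):

# Hedged example: programmatic equivalent of `--config=config.yaml --system#batch_size=8`.
from lighter.utils.runner import parse_config

parser = parse_config(**{"config": "config.yaml", "system#batch_size": 8})
print(parser.get("system"))  # raw 'system' section with the override applied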


 def validate_config(parser: ConfigParser) -> None:
     """
-    Validates the configuration parser against predefined structures and allowed method names.
-    This function checks if the keys in the top-level of the configuration parser are valid according to the
-    CONFIG_STRUCTURE. It also verifies that the 'args' section of the configuration only contains keys that
-    correspond to valid trainer method names as defined in TRAINER_METHOD_NAMES.
+    Validates the configuration parser against the predefined structure.
 
     Args:
         parser (ConfigParser): The configuration parser instance to validate.
@@ -56,20 +78,28 @@ def validate_config(parser: ConfigParser) -> None:
         ValueError: If there are invalid keys in the top-level configuration.
         ValueError: If there are invalid method names specified in the 'args' section.
     """
-    # Validate parser keys against structure
-    root_keys = parser.get().keys()
-    invalid_root_keys = set(root_keys) - set(CONFIG_STRUCTURE.keys()) - {"_meta_", "_requires_"}
+    invalid_root_keys = set(parser.get()) - set(CONFIG_STRUCTURE)
     if invalid_root_keys:
-        raise ValueError(f"Invalid top-level config keys: {invalid_root_keys}. Allowed keys: {CONFIG_STRUCTURE.keys()}")
+        raise ValueError(f"Invalid top-level config keys: {invalid_root_keys}. Allowed keys: {list(CONFIG_STRUCTURE)}.")
 
-    # Validate that 'args' contains only valid trainer method names.
-    args_keys = parser.get("args", {}).keys()
-    invalid_args_keys = set(args_keys) - set(TRAINER_METHOD_NAMES)
+    invalid_args_keys = set(parser.get("args")) - set(CONFIG_STRUCTURE["args"])
     if invalid_args_keys:
-        raise ValueError(f"Invalid trainer method in 'args': {invalid_args_keys}. Allowed methods are: {TRAINER_METHOD_NAMES}")
+        raise ValueError(f"Invalid key in 'args': {invalid_args_keys}. Allowed keys: {list(CONFIG_STRUCTURE['args'])}.")
+
+    typechecks = {
+        "project": (str, type(None)),
+        "vars": dict,
+        "system": dict,
+        "trainer": dict,
+        "args": dict,
+        **{f"args#{k}": dict for k in CONFIG_STRUCTURE["args"]},
+    }
+    for key, dtype in typechecks.items():
+        if not isinstance(parser.get(key), dtype):
+            raise ValueError(f"Invalid value for key '{key}'. Expected a {dtype}.")
 
 
-def run(method: str, **kwargs: Any):
+def run(method: str, **kwargs: Any) -> None:
"""Run the trainer method.
Args:
@@ -82,30 +112,36 @@ def run(method: str, **kwargs: Any):
     parser = parse_config(**kwargs)
     validate_config(parser)
 
-    # Import the project folder as a module, if specified.
+    # Project. If specified, the given path is imported as a module.
     project = parser.get_parsed_content("project")
     if project is not None:
         import_module_from_path("project", project)
 
-    # Get the main components from the parsed config.
+    # System
     system = parser.get_parsed_content("system")
+    if not isinstance(system, LighterSystem):
+        raise ValueError("Expected 'system' to be an instance of 'LighterSystem'")
+
+    # Trainer
     trainer = parser.get_parsed_content("trainer")
-    trainer_method_args = parser.get_parsed_content(f"args#{method}", default={})
+    if not isinstance(trainer, Trainer):
+        raise ValueError("Expected 'trainer' to be an instance of PyTorch Lightning 'Trainer'")
 
-    # Checks
-    if not isinstance(system, LighterSystem):
-        raise ValueError(f"Expected 'system' to be an instance of LighterSystem, got {system.__class__.__name__}.")
-    if not hasattr(trainer, method):
-        raise ValueError(f"{trainer.__class__.__name__} has no method named '{method}'.")
-    if any("dataloaders" in key or "datamodule" in key for key in trainer_method_args):
-        raise ValueError("All dataloaders should be defined as part of the LighterSystem, not passed as method arguments.")
+    # Trainer/Tuner method arguments.
+    method_args = parser.get_parsed_content(f"args#{method}")
+    if any("dataloaders" in key or "datamodule" in key for key in method_args):
+        raise ValueError("Datasets are defined within the 'system', not passed in `args`.")
 
-    # Save the config to checkpoints under "hyper_parameters" and log it if a logger is defined.
-    config = parser.get()
-    config.pop("_meta_")  # MONAI Bundle adds this automatically, remove it.
-    system.save_hyperparameters(config)
-    if trainer.logger is not None:
-        trainer.logger.log_hyperparams(config)
+    # Save the config to checkpoints under "hyper_parameters". Log it if a logger is defined.
+    system.save_hyperparameters(parser.get())
+    if trainer.logger is not None:
+        trainer.logger.log_hyperparams(parser.get())
 
-    # Run the trainer method.
-    getattr(trainer, method)(system, **trainer_method_args)
+    # Run the trainer/tuner method.
+    if hasattr(trainer, method):
+        getattr(trainer, method)(system, **method_args)
+    elif hasattr(Tuner, method):
+        tuner = Tuner(trainer)
+        getattr(tuner, method)(system, **method_args)
+    else:
+        raise ValueError(f"Method '{method}' is not a valid Trainer or Tuner method [{list(CONFIG_STRUCTURE['args'])}].")
