diff --git a/src/utils/utils.py b/src/utils/utils.py index b3b404e36..a4d4eb1e0 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -1,6 +1,6 @@ import warnings from importlib.util import find_spec -from typing import Any, Callable, Dict, Tuple +from typing import Any, Callable, Dict, Optional, Tuple from omegaconf import DictConfig @@ -95,12 +95,12 @@ def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: return wrap -def get_metric_value(metric_dict: Dict[str, Any], metric_name: str) -> float: +def get_metric_value(metric_dict: Dict[str, Any], metric_name: Optional[str]) -> Optional[float]: """Safely retrieves value of the metric logged in LightningModule. :param metric_dict: A dict containing metric values. - :param metric_name: The name of the metric to retrieve. - :return: The value of the metric. + :param metric_name: If provided, the name of the metric to retrieve. + :return: If a metric name was provided, the value of the metric. """ if not metric_name: log.info("Metric name is None! Skipping metric value retrieval...")