Skip to content

Commit

Permalink
Update function arguments typing (#603)
Browse files Browse the repository at this point in the history
  • Loading branch information
amorehead committed Sep 9, 2023
1 parent 2654bad commit 07ce4b7
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import warnings
from importlib.util import find_spec
from typing import Any, Callable, Dict, Tuple
from typing import Any, Callable, Dict, Optional, Tuple

from omegaconf import DictConfig

Expand Down Expand Up @@ -95,12 +95,12 @@ def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
return wrap


def get_metric_value(metric_dict: Dict[str, Any], metric_name: str) -> float:
def get_metric_value(metric_dict: Dict[str, Any], metric_name: Optional[str]) -> Optional[float]:
"""Safely retrieves value of the metric logged in LightningModule.
:param metric_dict: A dict containing metric values.
:param metric_name: The name of the metric to retrieve.
:return: The value of the metric.
:param metric_name: If provided, the name of the metric to retrieve.
:return: If a metric name was provided, the value of the metric.
"""
if not metric_name:
log.info("Metric name is None! Skipping metric value retrieval...")
Expand Down

0 comments on commit 07ce4b7

Please sign in to comment.