diff --git a/python/ray/air/_internal/mlflow.py b/python/ray/air/_internal/mlflow.py index d0fd1168dd60..a7b553100933 100644 --- a/python/ray/air/_internal/mlflow.py +++ b/python/ray/air/_internal/mlflow.py @@ -4,6 +4,8 @@ from copy import deepcopy from typing import TYPE_CHECKING, Dict, Optional +from ray._private.dict import flatten_dict + if TYPE_CHECKING: from mlflow.entities import Run from mlflow.tracking import MlflowClient @@ -262,6 +264,7 @@ def log_params(self, params_to_log: Dict, run_id: Optional[str] = None): params_to_log: Dictionary of parameters to log. run_id (Optional[str]): The ID of the run to log to. """ + params_to_log = flatten_dict(params_to_log) if run_id and self._run_exists(run_id): client = self._get_client() @@ -284,6 +287,7 @@ def log_metrics(self, step, metrics_to_log: Dict, run_id: Optional[str] = None): metrics_to_log: Dictionary of metrics to log. run_id (Optional[str]): The ID of the run to log to. """ + metrics_to_log = flatten_dict(metrics_to_log) metrics_to_log = self._parse_dict(metrics_to_log) if run_id and self._run_exists(run_id): diff --git a/python/ray/air/tests/test_integration_mlflow.py b/python/ray/air/tests/test_integration_mlflow.py index 7b6ea45c0642..85cab89080e1 100644 --- a/python/ray/air/tests/test_integration_mlflow.py +++ b/python/ray/air/tests/test_integration_mlflow.py @@ -7,6 +7,7 @@ from mlflow.tracking import MlflowClient +from ray._private.dict import flatten_dict from ray.train._internal.session import init_session from ray.tune.trainable import wrap_function from ray.tune.trainable.session import _shutdown as tune_session_shutdown @@ -367,7 +368,7 @@ def test_setup_fail(self): ) def test_log_params(self): - params = {"a": "a"} + params = {"a": "a", "x": {"y": "z"}} self.mlflow_util.setup_mlflow( tracking_uri=self.tracking_uri, experiment_name="new_experiment" ) @@ -376,21 +377,23 @@ def test_log_params(self): self.mlflow_util.log_params(params_to_log=params, run_id=run_id) run = self.mlflow_util._mlflow.get_run(run_id=run_id) - assert run.data.params == params + assert run.data.params == flatten_dict(params) params2 = {"b": "b"} self.mlflow_util.start_run(set_active=True) self.mlflow_util.log_params(params_to_log=params2, run_id=run_id) run = self.mlflow_util._mlflow.get_run(run_id=run_id) - assert run.data.params == { - **params, - **params2, - } + assert run.data.params == flatten_dict( + { + **params, + **params2, + } + ) self.mlflow_util.end_run() def test_log_metrics(self): - metrics = {"a": 1.0} + metrics = {"a": 1.0, "x": {"y": 2.0}} self.mlflow_util.setup_mlflow( tracking_uri=self.tracking_uri, experiment_name="new_experiment" ) @@ -399,15 +402,19 @@ def test_log_metrics(self): self.mlflow_util.log_metrics(metrics_to_log=metrics, run_id=run_id, step=0) run = self.mlflow_util._mlflow.get_run(run_id=run_id) - assert run.data.metrics == metrics + assert run.data.metrics == flatten_dict(metrics) metrics2 = {"b": 1.0} self.mlflow_util.start_run(set_active=True) self.mlflow_util.log_metrics(metrics_to_log=metrics2, run_id=run_id, step=0) - assert self.mlflow_util._mlflow.get_run(run_id=run_id).data.metrics == { - **metrics, - **metrics2, - } + assert self.mlflow_util._mlflow.get_run( + run_id=run_id + ).data.metrics == flatten_dict( + { + **metrics, + **metrics2, + } + ) self.mlflow_util.end_run()