Skip to content

Commit

Permalink
[air/mlflow] Flatten config and metrics before passing to mlflow (ray-project#35074)
Browse files Browse the repository at this point in the history

Metrics and parameters are passed as-is to mlflow, e.g. in the MlFlowCallback. However, mlflow can't deal with nested dicts.

Instead, we should flatten these dicts before passing them to mlflow.

Signed-off-by: Kai Fricke <[email protected]>
  • Loading branch information
krfricke authored and architkulkarni committed May 16, 2023
1 parent 05c1550 commit ba8edfe
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 12 deletions.
4 changes: 4 additions & 0 deletions python/ray/air/_internal/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from copy import deepcopy
from typing import TYPE_CHECKING, Dict, Optional

from ray._private.dict import flatten_dict

if TYPE_CHECKING:
from mlflow.entities import Run
from mlflow.tracking import MlflowClient
Expand Down Expand Up @@ -262,6 +264,7 @@ def log_params(self, params_to_log: Dict, run_id: Optional[str] = None):
params_to_log: Dictionary of parameters to log.
run_id (Optional[str]): The ID of the run to log to.
"""
params_to_log = flatten_dict(params_to_log)

if run_id and self._run_exists(run_id):
client = self._get_client()
Expand All @@ -284,6 +287,7 @@ def log_metrics(self, step, metrics_to_log: Dict, run_id: Optional[str] = None):
metrics_to_log: Dictionary of metrics to log.
run_id (Optional[str]): The ID of the run to log to.
"""
metrics_to_log = flatten_dict(metrics_to_log)
metrics_to_log = self._parse_dict(metrics_to_log)

if run_id and self._run_exists(run_id):
Expand Down
31 changes: 19 additions & 12 deletions python/ray/air/tests/test_integration_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from mlflow.tracking import MlflowClient

from ray._private.dict import flatten_dict
from ray.train._internal.session import init_session
from ray.tune.trainable import wrap_function
from ray.tune.trainable.session import _shutdown as tune_session_shutdown
Expand Down Expand Up @@ -367,7 +368,7 @@ def test_setup_fail(self):
)

def test_log_params(self):
params = {"a": "a"}
params = {"a": "a", "x": {"y": "z"}}
self.mlflow_util.setup_mlflow(
tracking_uri=self.tracking_uri, experiment_name="new_experiment"
)
Expand All @@ -376,21 +377,23 @@ def test_log_params(self):
self.mlflow_util.log_params(params_to_log=params, run_id=run_id)

run = self.mlflow_util._mlflow.get_run(run_id=run_id)
assert run.data.params == params
assert run.data.params == flatten_dict(params)

params2 = {"b": "b"}
self.mlflow_util.start_run(set_active=True)
self.mlflow_util.log_params(params_to_log=params2, run_id=run_id)
run = self.mlflow_util._mlflow.get_run(run_id=run_id)
assert run.data.params == {
**params,
**params2,
}
assert run.data.params == flatten_dict(
{
**params,
**params2,
}
)

self.mlflow_util.end_run()

def test_log_metrics(self):
metrics = {"a": 1.0}
metrics = {"a": 1.0, "x": {"y": 2.0}}
self.mlflow_util.setup_mlflow(
tracking_uri=self.tracking_uri, experiment_name="new_experiment"
)
Expand All @@ -399,15 +402,19 @@ def test_log_metrics(self):
self.mlflow_util.log_metrics(metrics_to_log=metrics, run_id=run_id, step=0)

run = self.mlflow_util._mlflow.get_run(run_id=run_id)
assert run.data.metrics == metrics
assert run.data.metrics == flatten_dict(metrics)

metrics2 = {"b": 1.0}
self.mlflow_util.start_run(set_active=True)
self.mlflow_util.log_metrics(metrics_to_log=metrics2, run_id=run_id, step=0)
assert self.mlflow_util._mlflow.get_run(run_id=run_id).data.metrics == {
**metrics,
**metrics2,
}
assert self.mlflow_util._mlflow.get_run(
run_id=run_id
).data.metrics == flatten_dict(
{
**metrics,
**metrics2,
}
)
self.mlflow_util.end_run()


Expand Down

0 comments on commit ba8edfe

Please sign in to comment.