diff --git a/doc/source/ray-air/examples/huggingface_text_classification.ipynb b/doc/source/ray-air/examples/huggingface_text_classification.ipynb index 1308eb13ccd9..da844f73c124 100644 --- a/doc/source/ray-air/examples/huggingface_text_classification.ipynb +++ b/doc/source/ray-air/examples/huggingface_text_classification.ipynb @@ -600,7 +600,7 @@ "source": [ "With our `trainer_init_per_worker` complete, we can now instantiate the `HuggingFaceTrainer`. Aside from the function, we set the `scaling_config`, controlling the amount of workers and resources used, and the `datasets` we will use for training and evaluation.\n", "\n", - "We specify the `MlflowLoggerCallback` inside the `run_config`, and pass the preprocessor we have defined earlier as an argument. The preprocessor will be included with the returned `Checkpoint`, meaning it will also be applied during inference." + "We specify the `MLflowLoggerCallback` inside the `run_config`, and pass the preprocessor we have defined earlier as an argument. The preprocessor will be included with the returned `Checkpoint`, meaning it will also be applied during inference." ] }, { diff --git a/doc/source/tune/api_docs/integration.rst b/doc/source/tune/api_docs/integration.rst index 4db43ec2cffc..160ea1f1f6c3 100644 --- a/doc/source/tune/api_docs/integration.rst +++ b/doc/source/tune/api_docs/integration.rst @@ -36,7 +36,7 @@ MLflow (tune.integration.mlflow) .. autoclass:: ray.air.integrations.mlflow.MLflowLoggerCallback :noindex: -.. autofunction:: ray.tune.integration.mlflow.mlflow_mixin +.. autofunction:: ray.air.integrations.mlflow.setup_mlflow .. 
_tune-integration-mxnet: diff --git a/doc/source/tune/examples/tune-mlflow.ipynb b/doc/source/tune/examples/tune-mlflow.ipynb index 2cd640d6f17f..23f66d238c30 100644 --- a/doc/source/tune/examples/tune-mlflow.ipynb +++ b/doc/source/tune/examples/tune-mlflow.ipynb @@ -11,7 +11,7 @@ "\n", ":::{warning}\n", "If you are using these MLflow integrations with {ref}`ray-client-ref`, it is recommended that you setup a\n", - "remote Mlflow tracking server instead of one that is backed by the local filesystem.\n", + "remote MLflow tracking server instead of one that is backed by the local filesystem.\n", ":::\n", "\n", "[MLflow](https://mlflow.org/) is an open source platform to manage the ML lifecycle, including experimentation,\n", @@ -29,7 +29,7 @@ "One is the {ref}`MLflowLoggerCallback `, which automatically logs\n", "metrics reported to Tune to the MLflow Tracking API.\n", "\n", - "The other one is the {ref}`@mlflow_mixin ` decorator, which can be\n", + "The other one is the {ref}`setup_mlflow ` function, which can be\n", "used with the function API. 
It automatically\n", "initializes the MLflow API with Tune's training information and creates a run for each Tune trial.\n", "Then within your training function, you can just use the\n", @@ -44,7 +44,7 @@ "## Running an MLflow Example\n", "\n", "In the following example we're going to use both of the above methods, namely the `MLflowLoggerCallback` and\n", - "the `mlflow_mixin` decorator to log metrics.\n", + "the `setup_mlflow` function to log metrics.\n", "Let's start with a few crucial imports:" ] }, @@ -67,8 +67,7 @@ "\n", "from ray import air, tune\n", "from ray.air import session\n", - "from ray.air.integrations.mlflow import MLflowLoggerCallback\n", - "from ray.tune.integration.mlflow import mlflow_mixin" + "from ray.air.integrations.mlflow import MLflowLoggerCallback, setup_mlflow" ] }, { @@ -80,7 +79,7 @@ } }, "source": [ - "Next, let's define an easy objective function (a Tune `Trainable`) that iteratively computes steps and evaluates\n", + "Next, let's define an easy training function (a Tune `Trainable`) that iteratively computes steps and evaluates\n", "intermediate scores that we report to Tune." 
] }, @@ -102,7 +101,7 @@ " return (0.1 + width * step / 100) ** (-1) + height * 0.1\n", "\n", "\n", - "def easy_objective(config):\n", + "def train_function(config):\n", " width, height = config[\"width\"], config[\"height\"]\n", "\n", " for step in range(config.get(\"steps\", 100)):\n", @@ -140,9 +139,9 @@ }, "outputs": [], "source": [ - "def tune_function(mlflow_tracking_uri, finish_fast=False):\n", + "def tune_with_callback(mlflow_tracking_uri, finish_fast=False):\n", " tuner = tune.Tuner(\n", - " easy_objective,\n", + " train_function,\n", " tune_config=tune.TuneConfig(\n", " num_samples=5\n", " ),\n", @@ -170,9 +169,9 @@ "id": "e086f110", "metadata": {}, "source": [ - "To use the `mlflow_mixin` decorator, you can simply decorate the objective function from earlier.\n", + "To use the `setup_mlflow` utility, you simply call this function in your training function.\n", "Note that we also use `mlflow.log_metrics(...)` to log metrics to MLflow.\n", - "Otherwise, the decorated version of our objective is identical to its original." + "Otherwise, this version of our training function is identical to its original." 
] }, { @@ -189,8 +188,9 @@ }, "outputs": [], "source": [ - "@mlflow_mixin\n", - "def decorated_easy_objective(config):\n", + "def train_function_mlflow(config):\n", + " setup_mlflow(config)\n", + " \n", " # Hyperparameters\n", " width, height = config[\"width\"], config[\"height\"]\n", "\n", @@ -226,13 +226,13 @@ }, "outputs": [], "source": [ - "def tune_decorated(mlflow_tracking_uri, finish_fast=False):\n", + "def tune_with_setup(mlflow_tracking_uri, finish_fast=False):\n", " # Set the experiment, or create a new one if does not exist yet.\n", " mlflow.set_tracking_uri(mlflow_tracking_uri)\n", " mlflow.set_experiment(experiment_name=\"mixin_example\")\n", " \n", " tuner = tune.Tuner(\n", - " decorated_easy_objective,\n", + " train_function_mlflow,\n", " tune_config=tune.TuneConfig(\n", " num_samples=5\n", " ),\n", @@ -279,25 +279,78 @@ "name": "stderr", "output_type": "stream", "text": [ - "2022-07-22 16:27:41,371\tINFO services.py:1483 -- View the Ray dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8271\u001b[39m\u001b[22m\n", - "2022-07-22 16:27:43,768\tWARNING function_trainable.py:619 -- Function checkpointing is disabled. This may result in unexpected behavior when using checkpointing features or certain schedulers. To enable, set the train function arguments to be `func(config, checkpoint_dir=None)`.\n" + "2022-12-22 10:37:53,580\tINFO worker.py:1542 -- Started a local Ray instance. View the dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8265 \u001b[39m\u001b[22m\n" ] }, { "data": { "text/html": [ - "== Status ==
Current time: 2022-07-22 16:27:50 (running for 00:00:06.29)
Memory usage on this node: 10.1/16.0 GiB
Using FIFO scheduling algorithm.
Resources requested: 0/16 CPUs, 0/0 GPUs, 0.0/5.63 GiB heap, 0.0/2.0 GiB objects
Result logdir: /Users/kai/ray_results/mlflow
Number of trials: 5/5 (5 TERMINATED)
\n", + "
\n", + "
\n", + "
\n", + "

Tune Status

\n", + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
Current time:2022-12-22 10:38:04
Running for: 00:00:06.73
Memory: 10.4/16.0 GiB
\n", + " \n", + "
\n", + "
\n", + "

System Info

\n", + " Using FIFO scheduling algorithm.
Resources requested: 0/16 CPUs, 0/0 GPUs, 0.0/4.03 GiB heap, 0.0/2.0 GiB objects\n", + "
\n", + " \n", + " \n", + "
\n", + "
\n", + "

Trial Status

\n", + " \n", "\n", - "\n", + "\n", "\n", "\n", - "\n", - "\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", + "\n", + "\n", "\n", - "
Trial name status loc height width loss iter total time (s) iterations neg_mean_loss
Trial name status loc height width loss iter total time (s) iterations neg_mean_loss
easy_objective_d4e29_00000TERMINATED127.0.0.1:52551 38 234.78039 5 0.549093 4 -4.78039
easy_objective_d4e29_00001TERMINATED127.0.0.1:52561 86 888.87624 5 0.548692 4 -8.87624
easy_objective_d4e29_00002TERMINATED127.0.0.1:52562 22 952.45641 5 0.587558 4 -2.45641
easy_objective_d4e29_00003TERMINATED127.0.0.1:52563 11 811.3994 5 0.560393 4 -1.3994
easy_objective_d4e29_00004TERMINATED127.0.0.1:52564 21 272.94746 5 0.534 4 -2.94746
train_function_b275b_00000TERMINATED127.0.0.1:801 66 367.24935 5 0.587302 4 -7.24935
train_function_b275b_00001TERMINATED127.0.0.1:813 33 353.96667 5 0.507423 4 -3.96667
train_function_b275b_00002TERMINATED127.0.0.1:814 75 298.29365 5 0.518995 4 -8.29365
train_function_b275b_00003TERMINATED127.0.0.1:815 28 633.18168 5 0.567739 4 -3.18168
train_function_b275b_00004TERMINATED127.0.0.1:816 20 183.21951 5 0.526536 4 -3.21951


" + "\n", + "
\n", + "\n", + "\n" ], "text/plain": [ "" @@ -307,246 +360,120 @@ "output_type": "display_data" }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "2022-07-22 16:27:44,945\tINFO plugin_schema_manager.py:52 -- Loading the default runtime env schemas: ['/Users/kai/coding/ray/python/ray/_private/runtime_env/../../runtime_env/schemas/working_dir_schema.json', '/Users/kai/coding/ray/python/ray/_private/runtime_env/../../runtime_env/schemas/pip_schema.json'].\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Result for easy_objective_d4e29_00000:\n", - " date: 2022-07-22_16-27-47\n", - " done: false\n", - " experiment_id: 421feb6ca1cb40969430bd0ab995fe37\n", - " hostname: Kais-MacBook-Pro.local\n", - " iterations: 0\n", - " iterations_since_restore: 1\n", - " mean_loss: 13.8\n", - " neg_mean_loss: -13.8\n", - " node_ip: 127.0.0.1\n", - " pid: 52551\n", - " time_since_restore: 0.00015282630920410156\n", - " time_this_iter_s: 0.00015282630920410156\n", - " time_total_s: 0.00015282630920410156\n", - " timestamp: 1658503667\n", - " timesteps_since_restore: 0\n", - " training_iteration: 1\n", - " trial_id: d4e29_00000\n", - " warmup_time: 0.0036253929138183594\n", - " \n", - "Result for easy_objective_d4e29_00000:\n", - " date: 2022-07-22_16-27-48\n", - " done: true\n", - " experiment_id: 421feb6ca1cb40969430bd0ab995fe37\n", - " experiment_tag: 0_height=38,width=23\n", - " hostname: Kais-MacBook-Pro.local\n", - " iterations: 4\n", - " iterations_since_restore: 5\n", - " mean_loss: 4.780392156862745\n", - " neg_mean_loss: -4.780392156862745\n", - " node_ip: 127.0.0.1\n", - " pid: 52551\n", - " time_since_restore: 0.5490927696228027\n", - " time_this_iter_s: 0.12111282348632812\n", - " time_total_s: 0.5490927696228027\n", - " timestamp: 1658503668\n", - " timesteps_since_restore: 0\n", - " training_iteration: 5\n", - " trial_id: d4e29_00000\n", - " warmup_time: 0.0036253929138183594\n", - " \n", - "Result for 
easy_objective_d4e29_00001:\n", - " date: 2022-07-22_16-27-50\n", - " done: false\n", - " experiment_id: 40ac54d80e854437b4126dca98a7f995\n", - " hostname: Kais-MacBook-Pro.local\n", - " iterations: 0\n", - " iterations_since_restore: 1\n", - " mean_loss: 18.6\n", - " neg_mean_loss: -18.6\n", - " node_ip: 127.0.0.1\n", - " pid: 52561\n", - " time_since_restore: 0.00013113021850585938\n", - " time_this_iter_s: 0.00013113021850585938\n", - " time_total_s: 0.00013113021850585938\n", - " timestamp: 1658503670\n", - " timesteps_since_restore: 0\n", - " training_iteration: 1\n", - " trial_id: d4e29_00001\n", - " warmup_time: 0.002991914749145508\n", - " \n", - "Result for easy_objective_d4e29_00002:\n", - " date: 2022-07-22_16-27-50\n", - " done: false\n", - " experiment_id: 23f2d0c4631e4a2abb5449ba68f80e8b\n", - " hostname: Kais-MacBook-Pro.local\n", - " iterations: 0\n", - " iterations_since_restore: 1\n", - " mean_loss: 12.2\n", - " neg_mean_loss: -12.2\n", - " node_ip: 127.0.0.1\n", - " pid: 52562\n", - " time_since_restore: 0.0001289844512939453\n", - " time_this_iter_s: 0.0001289844512939453\n", - " time_total_s: 0.0001289844512939453\n", - " timestamp: 1658503670\n", - " timesteps_since_restore: 0\n", - " training_iteration: 1\n", - " trial_id: d4e29_00002\n", - " warmup_time: 0.002949953079223633\n", - " \n", - "Result for easy_objective_d4e29_00003:\n", - " date: 2022-07-22_16-27-50\n", - " done: false\n", - " experiment_id: 7cb23325d6044f0f995b338d2e15f31e\n", - " hostname: Kais-MacBook-Pro.local\n", - " iterations: 0\n", - " iterations_since_restore: 1\n", - " mean_loss: 11.1\n", - " neg_mean_loss: -11.1\n", - " node_ip: 127.0.0.1\n", - " pid: 52563\n", - " time_since_restore: 0.00010609626770019531\n", - " time_this_iter_s: 0.00010609626770019531\n", - " time_total_s: 0.00010609626770019531\n", - " timestamp: 1658503670\n", - " timesteps_since_restore: 0\n", - " training_iteration: 1\n", - " trial_id: d4e29_00003\n", - " warmup_time: 0.0026869773864746094\n", 
- " \n", - "Result for easy_objective_d4e29_00004:\n", - " date: 2022-07-22_16-27-50\n", - " done: false\n", - " experiment_id: fc3b1add717842f4ae0b4882a1292f93\n", - " hostname: Kais-MacBook-Pro.local\n", - " iterations: 0\n", - " iterations_since_restore: 1\n", - " mean_loss: 12.1\n", - " neg_mean_loss: -12.1\n", - " node_ip: 127.0.0.1\n", - " pid: 52564\n", - " time_since_restore: 0.00011801719665527344\n", - " time_this_iter_s: 0.00011801719665527344\n", - " time_total_s: 0.00011801719665527344\n", - " timestamp: 1658503670\n", - " timesteps_since_restore: 0\n", - " training_iteration: 1\n", - " trial_id: d4e29_00004\n", - " warmup_time: 0.0028209686279296875\n", - " \n", - "Result for easy_objective_d4e29_00001:\n", - " date: 2022-07-22_16-27-50\n", - " done: true\n", - " experiment_id: 40ac54d80e854437b4126dca98a7f995\n", - " experiment_tag: 1_height=86,width=88\n", - " hostname: Kais-MacBook-Pro.local\n", - " iterations: 4\n", - " iterations_since_restore: 5\n", - " mean_loss: 8.876243093922652\n", - " neg_mean_loss: -8.876243093922652\n", - " node_ip: 127.0.0.1\n", - " pid: 52561\n", - " time_since_restore: 0.548691987991333\n", - " time_this_iter_s: 0.12308692932128906\n", - " time_total_s: 0.548691987991333\n", - " timestamp: 1658503670\n", - " timesteps_since_restore: 0\n", - " training_iteration: 5\n", - " trial_id: d4e29_00001\n", - " warmup_time: 0.002991914749145508\n", - " \n", - "Result for easy_objective_d4e29_00004:\n", - " date: 2022-07-22_16-27-50\n", - " done: true\n", - " experiment_id: fc3b1add717842f4ae0b4882a1292f93\n", - " experiment_tag: 4_height=21,width=27\n", - " hostname: Kais-MacBook-Pro.local\n", - " iterations: 4\n", - " iterations_since_restore: 5\n", - " mean_loss: 2.947457627118644\n", - " neg_mean_loss: -2.947457627118644\n", - " node_ip: 127.0.0.1\n", - " pid: 52564\n", - " time_since_restore: 0.5339996814727783\n", - " time_this_iter_s: 0.12359499931335449\n", - " time_total_s: 0.5339996814727783\n", - " timestamp: 
1658503670\n", - " timesteps_since_restore: 0\n", - " training_iteration: 5\n", - " trial_id: d4e29_00004\n", - " warmup_time: 0.0028209686279296875\n", - " \n", - "Result for easy_objective_d4e29_00003:\n", - " date: 2022-07-22_16-27-50\n", - " done: true\n", - " experiment_id: 7cb23325d6044f0f995b338d2e15f31e\n", - " experiment_tag: 3_height=11,width=81\n", - " hostname: Kais-MacBook-Pro.local\n", - " iterations: 4\n", - " iterations_since_restore: 5\n", - " mean_loss: 1.3994011976047904\n", - " neg_mean_loss: -1.3994011976047904\n", - " node_ip: 127.0.0.1\n", - " pid: 52563\n", - " time_since_restore: 0.5603930950164795\n", - " time_this_iter_s: 0.12318706512451172\n", - " time_total_s: 0.5603930950164795\n", - " timestamp: 1658503670\n", - " timesteps_since_restore: 0\n", - " training_iteration: 5\n", - " trial_id: d4e29_00003\n", - " warmup_time: 0.0026869773864746094\n", - " \n", - "Result for easy_objective_d4e29_00002:\n", - " date: 2022-07-22_16-27-50\n", - " done: true\n", - " experiment_id: 23f2d0c4631e4a2abb5449ba68f80e8b\n", - " experiment_tag: 2_height=22,width=95\n", - " hostname: Kais-MacBook-Pro.local\n", - " iterations: 4\n", - " iterations_since_restore: 5\n", - " mean_loss: 2.4564102564102566\n", - " neg_mean_loss: -2.4564102564102566\n", - " node_ip: 127.0.0.1\n", - " pid: 52562\n", - " time_since_restore: 0.5875582695007324\n", - " time_this_iter_s: 0.12340712547302246\n", - " time_total_s: 0.5875582695007324\n", - " timestamp: 1658503670\n", - " timesteps_since_restore: 0\n", - " training_iteration: 5\n", - " trial_id: d4e29_00002\n", - " warmup_time: 0.002949953079223633\n", - " \n" - ] + "data": { + "text/html": [ + "
\n", + "

Trial Progress

\n", + " \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
Trial name date done episodes_total experiment_id experiment_tag hostname iterations iterations_since_restore mean_loss neg_mean_lossnode_ip pid time_since_restore time_this_iter_s time_total_s timestamp timesteps_since_restoretimesteps_total training_iterationtrial_id warmup_time
train_function_b275b_000002022-12-22_10-38-01True 28feaa4dd8ab4edab810e8109e77502e0_height=66,width=36kais-macbook-pro.anyscale.com.beta.tailscale.net 4 5 7.24935 -7.24935127.0.0.1 801 0.587302 0.126818 0.587302 1671705481 0 5b275b_00000 0.00293493
train_function_b275b_000012022-12-22_10-38-04True 245010d0c3d0439ebfb664764ae9db3c1_height=33,width=35kais-macbook-pro.anyscale.com.beta.tailscale.net 4 5 3.96667 -3.96667127.0.0.1 813 0.507423 0.122086 0.507423 1671705484 0 5b275b_00001 0.00553799
train_function_b275b_000022022-12-22_10-38-04True 898afbf9b906448c980f399c72a2324c2_height=75,width=29kais-macbook-pro.anyscale.com.beta.tailscale.net 4 5 8.29365 -8.29365127.0.0.1 814 0.518995 0.123554 0.518995 1671705484 0 5b275b_00002 0.0040431
train_function_b275b_000032022-12-22_10-38-04True 03a4476f82734642b6ab0a5040ca58f83_height=28,width=63kais-macbook-pro.anyscale.com.beta.tailscale.net 4 5 3.18168 -3.18168127.0.0.1 815 0.567739 0.125471 0.567739 1671705484 0 5b275b_00003 0.00406194
train_function_b275b_000042022-12-22_10-38-04True ff8c7c55ce6e404f9b0552c17f7a0c404_height=20,width=18kais-macbook-pro.anyscale.com.beta.tailscale.net 4 5 3.21951 -3.21951127.0.0.1 816 0.526536 0.123327 0.526536 1671705484 0 5b275b_00004 0.00332022
\n", + "
\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ - "2022-07-22 16:27:51,033\tINFO tune.py:738 -- Total run time: 7.27 seconds (6.28 seconds for the tuning loop).\n", - "2022/07/22 16:27:51 INFO mlflow.tracking.fluent: Experiment with name 'mixin_example' does not exist. Creating a new experiment.\n" + "2022-12-22 10:38:04,477\tINFO tune.py:772 -- Total run time: 7.99 seconds (6.71 seconds for the tuning loop).\n" ] }, { "data": { "text/html": [ - "== Status ==
Current time: 2022-07-22 16:27:58 (running for 00:00:07.03)
Memory usage on this node: 10.4/16.0 GiB
Using FIFO scheduling algorithm.
Resources requested: 0/16 CPUs, 0/0 GPUs, 0.0/5.63 GiB heap, 0.0/2.0 GiB objects
Result logdir: /Users/kai/ray_results/mlflow
Number of trials: 5/5 (5 TERMINATED)
\n", + "
\n", + "
\n", + "
\n", + "

Tune Status

\n", + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
Current time:2022-12-22 10:38:11
Running for: 00:00:07.00
Memory: 10.7/16.0 GiB
\n", + " \n", + "
\n", + "
\n", + "

System Info

\n", + " Using FIFO scheduling algorithm.
Resources requested: 0/16 CPUs, 0/0 GPUs, 0.0/4.03 GiB heap, 0.0/2.0 GiB objects\n", + "
\n", + " \n", + " \n", + "
\n", + "
\n", + "

Trial Status

\n", + " \n", "\n", - "\n", + "\n", "\n", "\n", - "\n", - "\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", + "\n", + "\n", "\n", - "
Trial name status loc height width loss iter total time (s) iterations neg_mean_loss
Trial name status loc height width loss iter total time (s) iterations neg_mean_loss
decorated_easy_objective_d93b6_00000TERMINATED127.0.0.1:52581 45 51 4.96729 5 0.460993 4 -4.96729
decorated_easy_objective_d93b6_00001TERMINATED127.0.0.1:52598 44 94 4.65907 5 0.434945 4 -4.65907
decorated_easy_objective_d93b6_00002TERMINATED127.0.0.1:52599 93 2510.2091 5 0.471808 4 -10.2091
decorated_easy_objective_d93b6_00003TERMINATED127.0.0.1:52600 40 26 4.87719 5 0.437302 4 -4.87719
decorated_easy_objective_d93b6_00004TERMINATED127.0.0.1:52601 16 65 1.97037 5 0.468027 4 -1.97037
train_function_mlflow_b73bd_00000TERMINATED127.0.0.1:842 37 684.05461 5 0.750435 4 -4.05461
train_function_mlflow_b73bd_00001TERMINATED127.0.0.1:853 50 206.11111 5 0.652748 4 -6.11111
train_function_mlflow_b73bd_00002TERMINATED127.0.0.1:854 38 834.0924 5 0.6513 4 -4.0924
train_function_mlflow_b73bd_00003TERMINATED127.0.0.1:855 15 931.76178 5 0.650586 4 -1.76178
train_function_mlflow_b73bd_00004TERMINATED127.0.0.1:856 75 438.04945 5 0.656046 4 -8.04945


" + "\n", + "
\n", + "\n", + "\n" ], "text/plain": [ "" @@ -556,221 +483,49 @@ "output_type": "display_data" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Result for decorated_easy_objective_d93b6_00000:\n", - " date: 2022-07-22_16-27-54\n", - " done: false\n", - " experiment_id: 2d0d9fbc13c64acfa27153a5fb9aeb68\n", - " hostname: Kais-MacBook-Pro.local\n", - " iterations: 0\n", - " iterations_since_restore: 1\n", - " mean_loss: 14.5\n", - " neg_mean_loss: -14.5\n", - " node_ip: 127.0.0.1\n", - " pid: 52581\n", - " time_since_restore: 0.001725912094116211\n", - " time_this_iter_s: 0.001725912094116211\n", - " time_total_s: 0.001725912094116211\n", - " timestamp: 1658503674\n", - " timesteps_since_restore: 0\n", - " training_iteration: 1\n", - " trial_id: d93b6_00000\n", - " warmup_time: 0.20471811294555664\n", - " \n", - "Result for decorated_easy_objective_d93b6_00000:\n", - " date: 2022-07-22_16-27-54\n", - " done: true\n", - " experiment_id: 2d0d9fbc13c64acfa27153a5fb9aeb68\n", - " experiment_tag: 0_height=45,width=51\n", - " hostname: Kais-MacBook-Pro.local\n", - " iterations: 4\n", - " iterations_since_restore: 5\n", - " mean_loss: 4.9672897196261685\n", - " neg_mean_loss: -4.9672897196261685\n", - " node_ip: 127.0.0.1\n", - " pid: 52581\n", - " time_since_restore: 0.46099305152893066\n", - " time_this_iter_s: 0.10984206199645996\n", - " time_total_s: 0.46099305152893066\n", - " timestamp: 1658503674\n", - " timesteps_since_restore: 0\n", - " training_iteration: 5\n", - " trial_id: d93b6_00000\n", - " warmup_time: 0.20471811294555664\n", - " \n", - "Result for decorated_easy_objective_d93b6_00001:\n", - " date: 2022-07-22_16-27-57\n", - " done: false\n", - " experiment_id: 4bec5377a38a47d7bae57f7502ff0312\n", - " hostname: Kais-MacBook-Pro.local\n", - " iterations: 0\n", - " iterations_since_restore: 1\n", - " mean_loss: 14.4\n", - " neg_mean_loss: -14.4\n", - " node_ip: 127.0.0.1\n", - " pid: 52598\n", - " time_since_restore: 0.0016498565673828125\n", - 
" time_this_iter_s: 0.0016498565673828125\n", - " time_total_s: 0.0016498565673828125\n", - " timestamp: 1658503677\n", - " timesteps_since_restore: 0\n", - " training_iteration: 1\n", - " trial_id: d93b6_00001\n", - " warmup_time: 0.18288898468017578\n", - " \n", - "Result for decorated_easy_objective_d93b6_00003:\n", - " date: 2022-07-22_16-27-57\n", - " done: false\n", - " experiment_id: 6868d31636df4c4a8e9ed91927120269\n", - " hostname: Kais-MacBook-Pro.local\n", - " iterations: 0\n", - " iterations_since_restore: 1\n", - " mean_loss: 14.0\n", - " neg_mean_loss: -14.0\n", - " node_ip: 127.0.0.1\n", - " pid: 52600\n", - " time_since_restore: 0.0016481876373291016\n", - " time_this_iter_s: 0.0016481876373291016\n", - " time_total_s: 0.0016481876373291016\n", - " timestamp: 1658503677\n", - " timesteps_since_restore: 0\n", - " training_iteration: 1\n", - " trial_id: d93b6_00003\n", - " warmup_time: 0.17208290100097656\n", - " \n", - "Result for decorated_easy_objective_d93b6_00004:\n", - " date: 2022-07-22_16-27-57\n", - " done: false\n", - " experiment_id: f021ddc2dc164413931c17cb593dfa12\n", - " hostname: Kais-MacBook-Pro.local\n", - " iterations: 0\n", - " iterations_since_restore: 1\n", - " mean_loss: 11.6\n", - " neg_mean_loss: -11.6\n", - " node_ip: 127.0.0.1\n", - " pid: 52601\n", - " time_since_restore: 0.0015459060668945312\n", - " time_this_iter_s: 0.0015459060668945312\n", - " time_total_s: 0.0015459060668945312\n", - " timestamp: 1658503677\n", - " timesteps_since_restore: 0\n", - " training_iteration: 1\n", - " trial_id: d93b6_00004\n", - " warmup_time: 0.1808018684387207\n", - " \n", - "Result for decorated_easy_objective_d93b6_00002:\n", - " date: 2022-07-22_16-27-57\n", - " done: false\n", - " experiment_id: a341941781824ea9b1a072b587e42a84\n", - " hostname: Kais-MacBook-Pro.local\n", - " iterations: 0\n", - " iterations_since_restore: 1\n", - " mean_loss: 19.3\n", - " neg_mean_loss: -19.3\n", - " node_ip: 127.0.0.1\n", - " pid: 52599\n", - " 
time_since_restore: 0.0015799999237060547\n", - " time_this_iter_s: 0.0015799999237060547\n", - " time_total_s: 0.0015799999237060547\n", - " timestamp: 1658503677\n", - " timesteps_since_restore: 0\n", - " training_iteration: 1\n", - " trial_id: d93b6_00002\n", - " warmup_time: 0.1837329864501953\n", - " \n", - "Result for decorated_easy_objective_d93b6_00001:\n", - " date: 2022-07-22_16-27-57\n", - " done: true\n", - " experiment_id: 4bec5377a38a47d7bae57f7502ff0312\n", - " experiment_tag: 1_height=44,width=94\n", - " hostname: Kais-MacBook-Pro.local\n", - " iterations: 4\n", - " iterations_since_restore: 5\n", - " mean_loss: 4.659067357512954\n", - " neg_mean_loss: -4.659067357512954\n", - " node_ip: 127.0.0.1\n", - " pid: 52598\n", - " time_since_restore: 0.43494510650634766\n", - " time_this_iter_s: 0.10719513893127441\n", - " time_total_s: 0.43494510650634766\n", - " timestamp: 1658503677\n", - " timesteps_since_restore: 0\n", - " training_iteration: 5\n", - " trial_id: d93b6_00001\n", - " warmup_time: 0.18288898468017578\n", - " \n", - "Result for decorated_easy_objective_d93b6_00003:\n", - " date: 2022-07-22_16-27-57\n", - " done: true\n", - " experiment_id: 6868d31636df4c4a8e9ed91927120269\n", - " experiment_tag: 3_height=40,width=26\n", - " hostname: Kais-MacBook-Pro.local\n", - " iterations: 4\n", - " iterations_since_restore: 5\n", - " mean_loss: 4.87719298245614\n", - " neg_mean_loss: -4.87719298245614\n", - " node_ip: 127.0.0.1\n", - " pid: 52600\n", - " time_since_restore: 0.4373021125793457\n", - " time_this_iter_s: 0.10880899429321289\n", - " time_total_s: 0.4373021125793457\n", - " timestamp: 1658503677\n", - " timesteps_since_restore: 0\n", - " training_iteration: 5\n", - " trial_id: d93b6_00003\n", - " warmup_time: 0.17208290100097656\n", - " \n", - "Result for decorated_easy_objective_d93b6_00004:\n", - " date: 2022-07-22_16-27-57\n", - " done: true\n", - " experiment_id: f021ddc2dc164413931c17cb593dfa12\n", - " experiment_tag: 
4_height=16,width=65\n", - " hostname: Kais-MacBook-Pro.local\n", - " iterations: 4\n", - " iterations_since_restore: 5\n", - " mean_loss: 1.9703703703703703\n", - " neg_mean_loss: -1.9703703703703703\n", - " node_ip: 127.0.0.1\n", - " pid: 52601\n", - " time_since_restore: 0.46802687644958496\n", - " time_this_iter_s: 0.1077277660369873\n", - " time_total_s: 0.46802687644958496\n", - " timestamp: 1658503677\n", - " timesteps_since_restore: 0\n", - " training_iteration: 5\n", - " trial_id: d93b6_00004\n", - " warmup_time: 0.1808018684387207\n", - " \n", - "Result for decorated_easy_objective_d93b6_00002:\n", - " date: 2022-07-22_16-27-57\n", - " done: true\n", - " experiment_id: a341941781824ea9b1a072b587e42a84\n", - " experiment_tag: 2_height=93,width=25\n", - " hostname: Kais-MacBook-Pro.local\n", - " iterations: 4\n", - " iterations_since_restore: 5\n", - " mean_loss: 10.209090909090909\n", - " neg_mean_loss: -10.209090909090909\n", - " node_ip: 127.0.0.1\n", - " pid: 52599\n", - " time_since_restore: 0.47180795669555664\n", - " time_this_iter_s: 0.10791492462158203\n", - " time_total_s: 0.47180795669555664\n", - " timestamp: 1658503677\n", - " timesteps_since_restore: 0\n", - " training_iteration: 5\n", - " trial_id: d93b6_00002\n", - " warmup_time: 0.1837329864501953\n", - " \n" - ] + "data": { + "text/html": [ + "
\n", + "

Trial Progress

\n", + " \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
Trial name date done episodes_total experiment_id experiment_tag hostname iterations iterations_since_restore mean_loss neg_mean_lossnode_ip pid time_since_restore time_this_iter_s time_total_s timestamp timesteps_since_restoretimesteps_total training_iterationtrial_id warmup_time
train_function_mlflow_b73bd_000002022-12-22_10-38-08True 62703cfe82e54d74972377fbb525b0000_height=37,width=68kais-macbook-pro.anyscale.com.beta.tailscale.net 4 5 4.05461 -4.05461127.0.0.1 842 0.750435 0.108625 0.750435 1671705488 0 5b73bd_00000 0.0030272
train_function_mlflow_b73bd_000012022-12-22_10-38-11True 03ea89852115465392ed318db80216141_height=50,width=20kais-macbook-pro.anyscale.com.beta.tailscale.net 4 5 6.11111 -6.11111127.0.0.1 853 0.652748 0.110796 0.652748 1671705491 0 5b73bd_00001 0.00303078
train_function_mlflow_b73bd_000022022-12-22_10-38-11True 3731fc2966f9453ba58c650d89035ab42_height=38,width=83kais-macbook-pro.anyscale.com.beta.tailscale.net 4 5 4.0924 -4.0924 127.0.0.1 854 0.6513 0.108578 0.6513 1671705491 0 5b73bd_00002 0.00310016
train_function_mlflow_b73bd_000032022-12-22_10-38-11True fb35841742b348b9912d10203c730f1e3_height=15,width=93kais-macbook-pro.anyscale.com.beta.tailscale.net 4 5 1.76178 -1.76178127.0.0.1 855 0.650586 0.109097 0.650586 1671705491 0 5b73bd_00003 0.0576491
train_function_mlflow_b73bd_000042022-12-22_10-38-11True 6d3cbf9ecc3446369e607ff78c67bc294_height=75,width=43kais-macbook-pro.anyscale.com.beta.tailscale.net 4 5 8.04945 -8.04945127.0.0.1 856 0.656046 0.109869 0.656046 1671705491 0 5b73bd_00004 0.00265694
\n", + "
\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ - "2022-07-22 16:27:58,211\tINFO tune.py:738 -- Total run time: 7.15 seconds (7.01 seconds for the tuning loop).\n" + "2022-12-22 10:38:11,514\tINFO tune.py:772 -- Total run time: 7.01 seconds (6.98 seconds for the tuning loop).\n" ] } ], @@ -782,14 +537,14 @@ "else:\n", " mlflow_tracking_uri = \"\"\n", "\n", - "tune_function(mlflow_tracking_uri, finish_fast=smoke_test)\n", + "tune_with_callback(mlflow_tracking_uri, finish_fast=smoke_test)\n", "if not smoke_test:\n", " df = mlflow.search_runs(\n", " [mlflow.get_experiment_by_name(\"example\").experiment_id]\n", " )\n", " print(df)\n", "\n", - "tune_decorated(mlflow_tracking_uri, finish_fast=smoke_test)\n", + "tune_with_setup(mlflow_tracking_uri, finish_fast=smoke_test)\n", "if not smoke_test:\n", " df = mlflow.search_runs(\n", " [mlflow.get_experiment_by_name(\"mixin_example\").experiment_id]\n", @@ -819,12 +574,12 @@ " :noindex:\n", "```\n", "\n", - "## MLflow Mixin API\n", + "## MLflow setup API\n", "\n", - "(tune-mlflow-mixin)=\n", + "(tune-mlflow-setup)=\n", "\n", "```{eval-rst}\n", - ".. autofunction:: ray.tune.integration.mlflow.mlflow_mixin\n", + ".. 
autofunction:: ray.air.integrations.mlflow.setup_mlflow\n", " :noindex:\n", "```\n", "\n", diff --git a/python/ray/air/BUILD b/python/ray/air/BUILD index e8d01769c256..062b772d68ce 100644 --- a/python/ray/air/BUILD +++ b/python/ray/air/BUILD @@ -74,18 +74,35 @@ py_test( deps = [":ml_lib"] ) + py_test( - name = "test_keras_callback", + name = "test_integration_comet", size = "small", - srcs = ["tests/test_keras_callback.py"], + srcs = ["tests/test_integration_comet.py"], + deps = [":ml_lib"], tags = ["team:ml", "exclusive"], - deps = [":ml_lib"] ) py_test( - name = "test_mlflow", + name = "test_integration_wandb", size = "small", - srcs = ["tests/test_mlflow.py"], + srcs = ["tests/test_integration_wandb.py"], + deps = [":ml_lib"], + tags = ["team:ml", "exclusive"], +) + +py_test( + name = "test_integration_mlflow", + size = "small", + srcs = ["tests/test_integration_mlflow.py"], + deps = [":ml_lib"], + tags = ["team:ml", "exclusive"] +) + +py_test( + name = "test_keras_callback", + size = "small", + srcs = ["tests/test_keras_callback.py"], tags = ["team:ml", "exclusive"], deps = [":ml_lib"] ) diff --git a/python/ray/air/_internal/mlflow.py b/python/ray/air/_internal/mlflow.py index eb00ab116d7b..8e5bf5a99ea4 100644 --- a/python/ray/air/_internal/mlflow.py +++ b/python/ray/air/_internal/mlflow.py @@ -41,7 +41,7 @@ def setup_mlflow( registry_uri: Optional[str] = None, experiment_id: Optional[str] = None, experiment_name: Optional[str] = None, - tracking_token=None, + tracking_token: Optional[str] = None, create_experiment_if_not_exists: bool = True, ): """ @@ -62,13 +62,13 @@ def setup_mlflow( ``experiment_name`` will be used instead. This argument takes precedence over ``experiment_name`` if both are passed in. experiment_name: The experiment name to use for logging. - If None is passed in here, the - the MLFLOW_EXPERIMENT_NAME environment variables is used to - determine the experiment name. 
+ If None is passed in here, the MLFLOW_EXPERIMENT_NAME environment + variable is used to determine the experiment name. If the experiment with the name already exists with MLflow, it will be reused. If not, a new experiment will be created with the provided name if ``create_experiment_if_not_exists`` is set to True. + tracking_token: Tracking token used to authenticate with MLflow. create_experiment_if_not_exists: Whether to create an experiment with the provided name if it does not already exist. Defaults to True. @@ -182,7 +182,7 @@ def start_run( """Starts a new run and possibly sets it as the active run. Args: - tags (Optional[Dict]): Tags to set for the new run. + tags: Tags to set for the new run. set_active: Whether to set the new run as the active run. If an active run already exists, then that run is returned. diff --git a/python/ray/air/integrations/mlflow.py b/python/ray/air/integrations/mlflow.py index 459b878562a9..e4b3129797d1 100644 --- a/python/ray/air/integrations/mlflow.py +++ b/python/ray/air/integrations/mlflow.py @@ -1,15 +1,193 @@ import logging -from typing import Dict, Optional +import warnings +from types import ModuleType +from typing import Dict, Optional, Union import ray +from ray.air import session + from ray.air._internal.mlflow import _MLflowLoggerUtil from ray.tune.logger import LoggerCallback from ray.tune.result import TIMESTEPS_TOTAL, TRAINING_ITERATION from ray.tune.experiment import Trial +from ray.util.annotations import PublicAPI + +try: + import mlflow +except ImportError: + mlflow = None + logger = logging.getLogger(__name__) +class _NoopModule: + def __getattr__(self, item): + return _NoopModule() + + def __call__(self, *args, **kwargs): + return None + + +@PublicAPI(stability="alpha") +def setup_mlflow( + config: Optional[Dict] = None, + tracking_uri: Optional[str] = None, + registry_uri: Optional[str] = None, + experiment_id: Optional[str] = None, + experiment_name: Optional[str] = None, + tracking_token: Optional[str] 
= None, + create_experiment_if_not_exists: bool = False, + tags: Optional[Dict] = None, + rank_zero_only: bool = True, +) -> Union[ModuleType, _NoopModule]: + """Set up an MLflow session. + + This function can be used to initialize an MLflow session in a + (distributed) training or tuning run. + + By default, the MLflow experiment ID is the Ray trial ID and the + MLflow experiment name is the Ray trial name. These settings can be overwritten by + passing the respective keyword arguments. + + The ``config`` dict is automatically logged as the run parameters (excluding the + mlflow settings). + + In distributed training with Ray Train, only the zero-rank worker will initialize + mlflow. All other workers will return a noop client, so that logging is not + duplicated in a distributed run. This can be disabled by passing + ``rank_zero_only=False``, which will then initialize mlflow in every training + worker. + + This function will return the ``mlflow`` module or a noop module for + non-rank zero workers ``if rank_zero_only=True``. By using + ``mlflow = setup_mlflow(config)`` you can ensure that only the rank zero worker + calls the mlflow API. + + Args: + config: Configuration dict to be logged to mlflow as parameters. + tracking_uri: The tracking URI for MLflow tracking. If using + Tune in a multi-node setting, make sure to use a remote server for + tracking. + registry_uri: The registry URI for the MLflow model registry. + experiment_id: The id of an already created MLflow experiment. + All logs from all trials in ``tune.Tuner()`` will be reported to this + experiment. If this is not provided or the experiment with this + id does not exist, you must provide an ``experiment_name``. This + parameter takes precedence over ``experiment_name``. + experiment_name: The name of an already existing MLflow + experiment. All logs from all trials in ``tune.Tuner()`` will be + reported to this experiment. If this is not provided, you must + provide a valid ``experiment_id``.
+ tracking_token: A token to use for HTTP authentication when + logging to a remote tracking server. This is useful when you + want to log to a Databricks server, for example. This value will + be used to set the MLFLOW_TRACKING_TOKEN environment variable on + all the remote training processes. + create_experiment_if_not_exists: Whether to create an + experiment with the provided name if it does not already + exist. Defaults to False. + tags: Tags to set for the new run. + rank_zero_only: If True, will return an initialized session only for the + rank 0 worker in distributed training. If False, will initialize a + session for all workers. Defaults to True. + + Example: + + By default, you can just call ``setup_mlflow`` and continue to use + MLflow like you would normally do: + + .. code-block:: python + + from ray.air.integrations.mlflow import setup_mlflow + + def training_loop(config): + setup_mlflow(config) + # ... + mlflow.log_metric(key="loss", val=0.123, step=0) + + In distributed data parallel training, you can utilize the return value of + ``setup_mlflow``. This will make sure it is only invoked on the first worker + in distributed training runs. + + .. code-block:: python + + from ray.air.integrations.mlflow import setup_mlflow + + def training_loop(config): + mlflow = setup_mlflow(config) + # ... + mlflow.log_metric(key="loss", val=0.123, step=0) + + + You can also use MLflow's autologging feature if using a training + framework like PyTorch Lightning, XGBoost, etc. More information can be + found here + (https://mlflow.org/docs/latest/tracking.html#automatic-logging). + + .. code-block:: python + + from ray.air.integrations.mlflow import setup_mlflow + + def train_fn(config): + mlflow = setup_mlflow(config) + mlflow.autolog() + xgboost_results = xgb.train(config, ...)
+ + """ + if not mlflow: + raise RuntimeError( + "mlflow was not found - please install with `pip install mlflow`" + ) + + try: + # Do a try-catch here if we are not in a train session + _session = session._get_session(warn=False) + if _session and rank_zero_only and session.get_world_rank() != 0: + return _NoopModule() + + default_trial_id = session.get_trial_id() + default_trial_name = session.get_trial_name() + + except RuntimeError: + default_trial_id = None + default_trial_name = None + + _config = config.copy() if config else {} + mlflow_config = _config.pop("mlflow", {}).copy() + + # Deprecate: 2.4 + if mlflow_config: + warnings.warn( + "Passing a `mlflow` key in the config dict is deprecated and will raise an " + "error in the future. Please pass the actual arguments to `setup_mlflow()` " + "instead.", + DeprecationWarning, + ) + + experiment_id = experiment_id or default_trial_id + experiment_name = experiment_name or default_trial_name + + # Setup mlflow + mlflow_util = _MLflowLoggerUtil() + mlflow_util.setup_mlflow( + tracking_uri=tracking_uri or mlflow_config.get("tracking_uri", None), + registry_uri=registry_uri or mlflow_config.get("registry_uri", None), + experiment_id=experiment_id or mlflow_config.get("experiment_id", None), + experiment_name=experiment_name or mlflow_config.get("experiment_name", None), + tracking_token=tracking_token or mlflow_config.get("tracking_token", None), + create_experiment_if_not_exists=create_experiment_if_not_exists, + ) + + mlflow_util.start_run( + run_name=experiment_name, + tags=tags or mlflow_config.get("tags", None), + set_active=True, + ) + mlflow_util.log_params(_config) + return mlflow_util._mlflow + + class MLflowLoggerCallback(LoggerCallback): """MLflow Logger to automatically log Tune results and config to MLflow. @@ -32,6 +210,7 @@ class MLflowLoggerCallback(LoggerCallback): that name. 
tags: An optional dictionary of string keys and values to set as tags on the run + tracking_token: Tracking token used to authenticate with MLflow. save_artifact: If set to True, automatically save the entire contents of the Tune local_dir as an artifact to the corresponding run in MlFlow. @@ -62,9 +241,11 @@ class MLflowLoggerCallback(LoggerCallback): def __init__( self, tracking_uri: Optional[str] = None, + *, registry_uri: Optional[str] = None, experiment_name: Optional[str] = None, tags: Optional[Dict] = None, + tracking_token: Optional[str] = None, save_artifact: bool = False, ): @@ -72,6 +253,7 @@ def __init__( self.registry_uri = registry_uri self.experiment_name = experiment_name self.tags = tags + self.tracking_token = tracking_token self.should_save_artifact = save_artifact self.mlflow_util = _MLflowLoggerUtil() @@ -92,6 +274,7 @@ def setup(self, *args, **kwargs): tracking_uri=self.tracking_uri, registry_uri=self.registry_uri, experiment_name=self.experiment_name, + tracking_token=self.tracking_token, ) if self.tags is None: diff --git a/python/ray/air/integrations/wandb.py b/python/ray/air/integrations/wandb.py index d07657353381..c1240a94acc7 100644 --- a/python/ray/air/integrations/wandb.py +++ b/python/ray/air/integrations/wandb.py @@ -2,6 +2,7 @@ import os import pickle import urllib +import warnings import numpy as np from numbers import Number @@ -57,7 +58,11 @@ @PublicAPI(stability="alpha") def setup_wandb( - config: Optional[Dict] = None, rank_zero_only: bool = True, **kwargs + config: Optional[Dict] = None, + api_key: Optional[str] = None, + api_key_file: Optional[str] = None, + rank_zero_only: bool = True, + **kwargs, ) -> Union[Run, RunDisabled]: """Set up a Weights & Biases session. @@ -76,23 +81,19 @@ def setup_wandb( worker. The ``config`` argument will be passed to Weights and Biases and will be logged - as the run configuration. If wandb-specific settings are found, they will - be used to initialize the session. 
These settings can be - - - api_key_file: Path to locally available file containing a W&B API key - - api_key: API key to authenticate with W&B + as the run configuration. - If no API information is found in the config, wandb will try to authenticate + If no API key or key file is passed, wandb will try to authenticate using locally stored credentials, created for instance by running ``wandb login``. - All other keys found in the ``wandb`` config parameter will be passed to - ``wandb.init()``. If the same keys are present in multiple locations, the - ``kwargs`` passed to ``setup_wandb()`` will take precedence over those passed - as config keys. + Keyword arguments passed to ``setup_wandb()`` will be passed to + ``wandb.init()`` and take precedence over any potential default settings. Args: - config: Configuration dict to be logged to weights and biases. Can contain + config: Configuration dict to be logged to Weights and Biases. Can contain arguments for ``wandb.init()`` as well as authentication information. + api_key: API key to use for authentication with Weights and Biases. + api_key_file: File containing an API key for use with Weights and Biases. + rank_zero_only: If True, will return an initialized session only for the + rank 0 worker in distributed training. If False, will initialize a + session for all workers.
@@ -120,23 +121,36 @@ def training_loop(config): _session = session._get_session(warn=False) if _session and rank_zero_only and session.get_world_rank() != 0: return RunDisabled() - except RuntimeError: - pass - default_kwargs = { - "trial_id": kwargs.get("trial_id") or session.get_trial_id(), - "trial_name": kwargs.get("trial_name") or session.get_trial_name(), - "group": kwargs.get("group") or session.get_experiment_name(), + default_trial_id = session.get_trial_id() + default_trial_name = session.get_trial_name() + default_experiment_name = session.get_experiment_name() + + except RuntimeError: + default_trial_id = None + default_trial_name = None + default_experiment_name = None + + # Default init kwargs + wandb_init_kwargs = { + "trial_id": kwargs.get("trial_id") or default_trial_id, + "trial_name": kwargs.get("trial_name") or default_trial_name, + "group": kwargs.get("group") or default_experiment_name, } - default_kwargs.update(kwargs) + # Passed kwargs take precedence over default kwargs + wandb_init_kwargs.update(kwargs) - return _setup_wandb(config=config, **default_kwargs) + return _setup_wandb( + config=config, api_key=api_key, api_key_file=api_key_file, **wandb_init_kwargs + ) def _setup_wandb( trial_id: str, trial_name: str, config: Optional[Dict] = None, + api_key: Optional[str] = None, + api_key_file: Optional[str] = None, _wandb: Optional[ModuleType] = None, **kwargs, ) -> Union[Run, RunDisabled]: @@ -144,12 +158,21 @@ def _setup_wandb( wandb_config = _config.pop("wandb", {}).copy() + # Deprecate: 2.4 + if wandb_config: + warnings.warn( + "Passing a `wandb` key in the config dict is deprecated and will raise an " + "error in the future. 
Please pass the actual arguments to `setup_wandb()` " + "instead.", + DeprecationWarning, + ) + # If key file is specified, set - api_key_file = wandb_config.pop("api_key_file", None) + api_key_file = api_key_file or wandb_config.pop("api_key_file", None) if api_key_file: api_key_file = os.path.expanduser(api_key_file) - _set_api_key(api_key_file, wandb_config.pop("api_key", None)) + _set_api_key(api_key_file, api_key or wandb_config.pop("api_key", None)) wandb_config["project"] = _get_wandb_project(wandb_config.get("project")) wandb_config["group"] = ( os.environ.get(WANDB_GROUP_ENV_VAR) diff --git a/python/ray/tune/tests/test_integration_comet.py b/python/ray/air/tests/test_integration_comet.py similarity index 100% rename from python/ray/tune/tests/test_integration_comet.py rename to python/ray/air/tests/test_integration_comet.py diff --git a/python/ray/tune/tests/test_integration_mlflow.py b/python/ray/air/tests/test_integration_mlflow.py similarity index 52% rename from python/ray/tune/tests/test_integration_mlflow.py rename to python/ray/air/tests/test_integration_mlflow.py index c229ca022393..dc08b8255800 100644 --- a/python/ray/tune/tests/test_integration_mlflow.py +++ b/python/ray/air/tests/test_integration_mlflow.py @@ -1,4 +1,5 @@ import os +import shutil import tempfile import unittest from collections import namedtuple @@ -6,14 +7,14 @@ from mlflow.tracking import MlflowClient +from ray.train._internal.session import init_session from ray.tune.trainable import wrap_function +from ray.tune.trainable.session import _shutdown as tune_session_shutdown from ray.tune.integration.mlflow import ( MLflowTrainableMixin, mlflow_mixin, ) -from ray.air.integrations.mlflow import ( - MLflowLoggerCallback, -) +from ray.air.integrations.mlflow import MLflowLoggerCallback, setup_mlflow, _NoopModule from ray.air._internal.mlflow import _MLflowLoggerUtil @@ -34,10 +35,8 @@ def save_artifacts(self, dir, run_id): def clear_env_vars(): - if "MLFLOW_EXPERIMENT_NAME" in 
os.environ: - del os.environ["MLFLOW_EXPERIMENT_NAME"] - if "MLFLOW_EXPERIMENT_ID" in os.environ: - del os.environ["MLFLOW_EXPERIMENT_ID"] + os.environ.pop("MLFLOW_EXPERIMENT_NAME", None) + os.environ.pop("MLFLOW_EXPERIMENT_ID", None) class MLflowTest(unittest.TestCase): @@ -51,6 +50,10 @@ def setUp(self): client.create_experiment(name="existing_experiment") assert client.get_experiment_by_name("existing_experiment").experiment_id == "0" + def tearDown(self) -> None: + # Remove tune session if initialized to clean up for next test + tune_session_shutdown() + def testMlFlowLoggerCallbackConfig(self): # Explicitly pass in all args. logger = MLflowLoggerCallback( @@ -221,11 +224,210 @@ def train_fn(config): wrap_function(train_fn)(trial_config) # Set to experiment that does not already exist. - # New experiment should be created. + # This will fail because the experiment has to be created explicitly first. + trial_config["mlflow"]["tracking_uri"] = self.tracking_uri trial_config["mlflow"]["experiment_name"] = "new_experiment" with self.assertRaises(ValueError): wrap_function(train_fn)(trial_config) + # This should now pass + trial_config["mlflow"]["experiment_name"] = "existing_experiment" + wrap_function(train_fn)(trial_config).stop() + + def testMlFlowSetupConfig(self): + clear_env_vars() + trial_config = {"par1": 4, "par2": 9.0} + + # No MLflow config passed in. + with self.assertRaises(ValueError): + setup_mlflow(trial_config) + + trial_config.update({"mlflow": {}}) + # No tracking uri or experiment_id/name passed in. + with self.assertRaises(ValueError): + setup_mlflow(trial_config) + + # Invalid experiment-id + trial_config["mlflow"].update({"experiment_id": "500"}) + # No tracking uri or experiment_id/name passed in. + with self.assertRaises(ValueError): + setup_mlflow(trial_config) + + # Set to experiment that does not already exist. + # New experiment should be created. 
+ trial_config["mlflow"]["tracking_uri"] = self.tracking_uri + trial_config["mlflow"]["experiment_name"] = "new_experiment" + with self.assertRaises(ValueError): + setup_mlflow(trial_config) + + trial_config["mlflow"]["experiment_name"] = "existing_experiment" + mlflow = setup_mlflow(trial_config) + mlflow.end_run() + + def testMlFlowSetupExplicit(self): + clear_env_vars() + trial_config = {"par1": 4, "par2": 9.0} + + # No MLflow config passed in. + with self.assertRaises(ValueError): + setup_mlflow(trial_config) + + # Invalid experiment-id + with self.assertRaises(ValueError): + setup_mlflow(trial_config, experiment_id="500") + + # Set to experiment that does not already exist. + with self.assertRaises(ValueError): + setup_mlflow( + trial_config, + experiment_id="500", + experiment_name="new_experiment", + tracking_uri=self.tracking_uri, + ) + + mlflow = setup_mlflow( + trial_config, + experiment_id="500", + experiment_name="existing_experiment", + tracking_uri=self.tracking_uri, + ) + mlflow.end_run() + + def testMlFlowSetupRankNonRankZero(self): + """Assert that non-rank-0 workers get a noop module""" + init_session( + training_func=None, + world_rank=1, + local_rank=1, + node_rank=1, + local_world_size=2, + world_size=2, + ) + mlflow = setup_mlflow({}) + assert isinstance(mlflow, _NoopModule) + + mlflow.log_metrics() + mlflow.sklearn.save_model(None, "model_directory") + + +class MLflowUtilTest(unittest.TestCase): + def setUp(self): + self.dirpath = tempfile.mkdtemp() + import mlflow + + mlflow.set_tracking_uri(self.dirpath) + mlflow.create_experiment(name="existing_experiment") + + self.mlflow_util = _MLflowLoggerUtil() + self.tracking_uri = mlflow.get_tracking_uri() + + def tearDown(self): + shutil.rmtree(self.dirpath) + + def test_experiment_id(self): + self.mlflow_util.setup_mlflow(tracking_uri=self.tracking_uri, experiment_id="0") + assert self.mlflow_util.experiment_id == "0" + + def test_experiment_id_env_var(self): + os.environ["MLFLOW_EXPERIMENT_ID"] = 
"0" + self.mlflow_util.setup_mlflow(tracking_uri=self.tracking_uri) + assert self.mlflow_util.experiment_id == "0" + del os.environ["MLFLOW_EXPERIMENT_ID"] + + def test_experiment_name(self): + self.mlflow_util.setup_mlflow( + tracking_uri=self.tracking_uri, experiment_name="existing_experiment" + ) + assert self.mlflow_util.experiment_id == "0" + + def test_run_started_with_correct_experiment(self): + experiment_name = "my_experiment_name" + # Make sure run is started under the correct experiment. + self.mlflow_util.setup_mlflow( + tracking_uri=self.tracking_uri, experiment_name=experiment_name + ) + run = self.mlflow_util.start_run(set_active=True) + assert ( + run.info.experiment_id + == self.mlflow_util._mlflow.get_experiment_by_name( + experiment_name + ).experiment_id + ) + + self.mlflow_util.end_run() + + def test_experiment_name_env_var(self): + os.environ["MLFLOW_EXPERIMENT_NAME"] = "existing_experiment" + self.mlflow_util.setup_mlflow(tracking_uri=self.tracking_uri) + assert self.mlflow_util.experiment_id == "0" + del os.environ["MLFLOW_EXPERIMENT_NAME"] + + def test_id_precedence(self): + os.environ["MLFLOW_EXPERIMENT_ID"] = "0" + self.mlflow_util.setup_mlflow( + tracking_uri=self.tracking_uri, experiment_name="new_experiment" + ) + assert self.mlflow_util.experiment_id == "0" + del os.environ["MLFLOW_EXPERIMENT_ID"] + + def test_new_experiment(self): + self.mlflow_util.setup_mlflow( + tracking_uri=self.tracking_uri, experiment_name="new_experiment" + ) + assert self.mlflow_util.experiment_id == "1" + + def test_setup_fail(self): + with self.assertRaises(ValueError): + self.mlflow_util.setup_mlflow( + tracking_uri=self.tracking_uri, + experiment_name="new_experiment2", + create_experiment_if_not_exists=False, + ) + + def test_log_params(self): + params = {"a": "a"} + self.mlflow_util.setup_mlflow( + tracking_uri=self.tracking_uri, experiment_name="new_experiment" + ) + run = self.mlflow_util.start_run() + run_id = run.info.run_id + 
self.mlflow_util.log_params(params_to_log=params, run_id=run_id) + + run = self.mlflow_util._mlflow.get_run(run_id=run_id) + assert run.data.params == params + + params2 = {"b": "b"} + self.mlflow_util.start_run(set_active=True) + self.mlflow_util.log_params(params_to_log=params2, run_id=run_id) + run = self.mlflow_util._mlflow.get_run(run_id=run_id) + assert run.data.params == { + **params, + **params2, + } + + self.mlflow_util.end_run() + + def test_log_metrics(self): + metrics = {"a": 1.0} + self.mlflow_util.setup_mlflow( + tracking_uri=self.tracking_uri, experiment_name="new_experiment" + ) + run = self.mlflow_util.start_run() + run_id = run.info.run_id + self.mlflow_util.log_metrics(metrics_to_log=metrics, run_id=run_id, step=0) + + run = self.mlflow_util._mlflow.get_run(run_id=run_id) + assert run.data.metrics == metrics + + metrics2 = {"b": 1.0} + self.mlflow_util.start_run(set_active=True) + self.mlflow_util.log_metrics(metrics_to_log=metrics2, run_id=run_id, step=0) + assert self.mlflow_util._mlflow.get_run(run_id=run_id).data.metrics == { + **metrics, + **metrics2, + } + self.mlflow_util.end_run() + if __name__ == "__main__": import sys diff --git a/python/ray/tune/tests/test_integration_wandb.py b/python/ray/air/tests/test_integration_wandb.py similarity index 100% rename from python/ray/tune/tests/test_integration_wandb.py rename to python/ray/air/tests/test_integration_wandb.py diff --git a/python/ray/air/tests/test_mlflow.py b/python/ray/air/tests/test_mlflow.py deleted file mode 100644 index 300aa660ebca..000000000000 --- a/python/ray/air/tests/test_mlflow.py +++ /dev/null @@ -1,133 +0,0 @@ -import os -import shutil -import tempfile -import unittest - -from ray.air._internal.mlflow import _MLflowLoggerUtil - - -class MLflowTest(unittest.TestCase): - def setUp(self): - self.dirpath = tempfile.mkdtemp() - import mlflow - - mlflow.set_tracking_uri(self.dirpath) - mlflow.create_experiment(name="existing_experiment") - - self.mlflow_util = 
_MLflowLoggerUtil() - self.tracking_uri = mlflow.get_tracking_uri() - - def tearDown(self): - shutil.rmtree(self.dirpath) - - def test_experiment_id(self): - self.mlflow_util.setup_mlflow(tracking_uri=self.tracking_uri, experiment_id="0") - assert self.mlflow_util.experiment_id == "0" - - def test_experiment_id_env_var(self): - os.environ["MLFLOW_EXPERIMENT_ID"] = "0" - self.mlflow_util.setup_mlflow(tracking_uri=self.tracking_uri) - assert self.mlflow_util.experiment_id == "0" - del os.environ["MLFLOW_EXPERIMENT_ID"] - - def test_experiment_name(self): - self.mlflow_util.setup_mlflow( - tracking_uri=self.tracking_uri, experiment_name="existing_experiment" - ) - assert self.mlflow_util.experiment_id == "0" - - def test_run_started_with_correct_experiment(self): - experiment_name = "my_experiment_name" - # Make sure run is started under the correct experiment. - self.mlflow_util.setup_mlflow( - tracking_uri=self.tracking_uri, experiment_name=experiment_name - ) - run = self.mlflow_util.start_run(set_active=True) - assert ( - run.info.experiment_id - == self.mlflow_util._mlflow.get_experiment_by_name( - experiment_name - ).experiment_id - ) - - self.mlflow_util.end_run() - - def test_experiment_name_env_var(self): - os.environ["MLFLOW_EXPERIMENT_NAME"] = "existing_experiment" - self.mlflow_util.setup_mlflow(tracking_uri=self.tracking_uri) - assert self.mlflow_util.experiment_id == "0" - del os.environ["MLFLOW_EXPERIMENT_NAME"] - - def test_id_precedence(self): - os.environ["MLFLOW_EXPERIMENT_ID"] = "0" - self.mlflow_util.setup_mlflow( - tracking_uri=self.tracking_uri, experiment_name="new_experiment" - ) - assert self.mlflow_util.experiment_id == "0" - del os.environ["MLFLOW_EXPERIMENT_ID"] - - def test_new_experiment(self): - self.mlflow_util.setup_mlflow( - tracking_uri=self.tracking_uri, experiment_name="new_experiment" - ) - assert self.mlflow_util.experiment_id == "1" - - def test_setup_fail(self): - with self.assertRaises(ValueError): - 
self.mlflow_util.setup_mlflow( - tracking_uri=self.tracking_uri, - experiment_name="new_experiment2", - create_experiment_if_not_exists=False, - ) - - def test_log_params(self): - params = {"a": "a"} - self.mlflow_util.setup_mlflow( - tracking_uri=self.tracking_uri, experiment_name="new_experiment" - ) - run = self.mlflow_util.start_run() - run_id = run.info.run_id - self.mlflow_util.log_params(params_to_log=params, run_id=run_id) - - run = self.mlflow_util._mlflow.get_run(run_id=run_id) - assert run.data.params == params - - params2 = {"b": "b"} - self.mlflow_util.start_run(set_active=True) - self.mlflow_util.log_params(params_to_log=params2, run_id=run_id) - run = self.mlflow_util._mlflow.get_run(run_id=run_id) - assert run.data.params == { - **params, - **params2, - } - - self.mlflow_util.end_run() - - def test_log_metrics(self): - metrics = {"a": 1.0} - self.mlflow_util.setup_mlflow( - tracking_uri=self.tracking_uri, experiment_name="new_experiment" - ) - run = self.mlflow_util.start_run() - run_id = run.info.run_id - self.mlflow_util.log_metrics(metrics_to_log=metrics, run_id=run_id, step=0) - - run = self.mlflow_util._mlflow.get_run(run_id=run_id) - assert run.data.metrics == metrics - - metrics2 = {"b": 1.0} - self.mlflow_util.start_run(set_active=True) - self.mlflow_util.log_metrics(metrics_to_log=metrics2, run_id=run_id, step=0) - assert self.mlflow_util._mlflow.get_run(run_id=run_id).data.metrics == { - **metrics, - **metrics2, - } - self.mlflow_util.end_run() - - -if __name__ == "__main__": - import sys - - import pytest - - sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tune/BUILD b/python/ray/tune/BUILD index c2f6afcde5ae..172888302452 100644 --- a/python/ray/tune/BUILD +++ b/python/ray/tune/BUILD @@ -144,14 +144,6 @@ py_test( tags = ["team:ml", "exclusive"], ) -py_test( - name = "test_integration_comet", - size = "small", - srcs = ["tests/test_integration_comet.py"], - deps = [":tune_lib"], - tags = ["team:ml", "exclusive"], -) - 
py_test(
    name = "test_integration_pytorch_lightning",
    size = "small",
@@ -160,22 +152,6 @@ py_test(
    tags = ["team:ml", "exclusive"],
)

-py_test(
-    name = "test_integration_wandb",
-    size = "small",
-    srcs = ["tests/test_integration_wandb.py"],
-    deps = [":tune_lib"],
-    tags = ["team:ml", "exclusive"],
-)
-
-py_test(
-    name = "test_integration_mlflow",
-    size = "small",
-    srcs = ["tests/test_integration_mlflow.py"],
-    deps = [":tune_lib"],
-    tags = ["team:ml", "exclusive"]
-)
-
py_test(
    name = "test_logger",
    size = "small",
diff --git a/python/ray/tune/examples/mlflow_example.py b/python/ray/tune/examples/mlflow_example.py
index c97164ecd905..68e4434cfbba 100644
--- a/python/ray/tune/examples/mlflow_example.py
+++ b/python/ray/tune/examples/mlflow_example.py
@@ -1,5 +1,5 @@
 #!/usr/bin/env python
-"""Examples using MLfowLoggerCallback and mlflow_mixin.
+"""Examples using MLflowLoggerCallback and setup_mlflow.
"""
import os
import tempfile
@@ -9,15 +9,14 @@
 from ray import air, tune
 from ray.air import session
-from ray.air.integrations.mlflow import MLflowLoggerCallback
-from ray.tune.integration.mlflow import mlflow_mixin
+from ray.air.integrations.mlflow import MLflowLoggerCallback, setup_mlflow


def evaluation_fn(step, width, height):
    return (0.1 + width * step / 100) ** (-1) + height * 0.1


-def easy_objective(config):
+def train_function(config):
    # Hyperparameters
    width, height = config["width"], config["height"]
@@ -29,10 +28,10 @@ def easy_objective(config):
    time.sleep(0.1)


-def tune_function(mlflow_tracking_uri, finish_fast=False):
+def tune_with_callback(mlflow_tracking_uri, finish_fast=False):
    tuner = tune.Tuner(
-        easy_objective,
+        train_function,
        run_config=air.RunConfig(
            name="mlflow",
            callbacks=[
@@ -55,8 +54,9 @@ def tune_function(mlflow_tracking_uri, finish_fast=False):
    tuner.fit()


-@mlflow_mixin
-def decorated_easy_objective(config):
+def train_function_mlflow(config):
+    setup_mlflow(config)
+
    # Hyperparameters
    width, height = config["width"],
config["height"] @@ -70,12 +70,12 @@ def decorated_easy_objective(config): time.sleep(0.1) -def tune_decorated(mlflow_tracking_uri, finish_fast=False): +def tune_with_setup(mlflow_tracking_uri, finish_fast=False): # Set the experiment, or create a new one if does not exist yet. mlflow.set_tracking_uri(mlflow_tracking_uri) mlflow.set_experiment(experiment_name="mixin_example") tuner = tune.Tuner( - decorated_easy_objective, + train_function_mlflow, run_config=air.RunConfig( name="mlflow", ), @@ -107,40 +107,21 @@ def tune_decorated(mlflow_tracking_uri, finish_fast=False): type=str, help="The tracking URI for the MLflow tracking server.", ) - parser.add_argument( - "--server-address", - type=str, - default=None, - required=False, - help="The address of server to connect to if using Ray Client.", - ) args, _ = parser.parse_known_args() - if args.server_address: - import ray - - ray.init(f"ray://{args.server_address}") - - if args.server_address and not args.tracking_uri: - raise RuntimeError( - "If running this example with Ray Client, " - "the tracking URI for your tracking server should" - "be explicitly passed in." 
- )
-
    if args.smoke_test:
        mlflow_tracking_uri = os.path.join(tempfile.gettempdir(), "mlruns")
    else:
        mlflow_tracking_uri = args.tracking_uri

-    tune_function(mlflow_tracking_uri, finish_fast=args.smoke_test)
+    tune_with_callback(mlflow_tracking_uri, finish_fast=args.smoke_test)

    if not args.smoke_test:
        df = mlflow.search_runs(
            [mlflow.get_experiment_by_name("example").experiment_id]
        )
        print(df)

-    tune_decorated(mlflow_tracking_uri, finish_fast=args.smoke_test)
+    tune_with_setup(mlflow_tracking_uri, finish_fast=args.smoke_test)
    if not args.smoke_test:
        df = mlflow.search_runs(
            [mlflow.get_experiment_by_name("mixin_example").experiment_id]
diff --git a/python/ray/tune/examples/mlflow_ptl.py b/python/ray/tune/examples/mlflow_ptl.py
index 3c28f9738dbb..027587782220 100644
--- a/python/ray/tune/examples/mlflow_ptl.py
+++ b/python/ray/tune/examples/mlflow_ptl.py
@@ -9,13 +9,13 @@
import mlflow

from ray import air, tune
-from ray.tune.integration.mlflow import mlflow_mixin
+from ray.air.integrations.mlflow import setup_mlflow
from ray.tune.integration.pytorch_lightning import TuneReportCallback
from ray.tune.examples.mnist_ptl_mini import LightningMNISTClassifier


-@mlflow_mixin
def train_mnist_tune(config, data_dir=None, num_epochs=10, num_gpus=0):
+    setup_mlflow(config)
    model = LightningMNISTClassifier(config, data_dir)
    dm = MNISTDataModule(
        data_dir=data_dir, num_workers=1, batch_size=config["batch_size"]
diff --git a/python/ray/tune/integration/mlflow.py b/python/ray/tune/integration/mlflow.py
index 851ada4bafe5..581f2181234f 100644
--- a/python/ray/tune/integration/mlflow.py
+++ b/python/ray/tune/integration/mlflow.py
@@ -1,3 +1,5 @@
+import warnings
+
from ray.air.integrations.mlflow import MLflowLoggerCallback as _MLflowLoggerCallback

import logging
@@ -34,6 +36,13 @@ def __init__(
    )


+# Deprecate: Remove in 2.4
+@Deprecated(
+    message=(
+        "The mlflow_mixin/MLflowTrainableMixin is deprecated. "
+        "Use `ray.air.integrations.mlflow.setup_mlflow` instead."
+ ) +) def mlflow_mixin(func: Callable): """mlflow_mixin @@ -131,6 +140,12 @@ def train_fn(config): tuner.fit() """ + warnings.warn( + "The mlflow_mixin/MLflowTrainableMixin is deprecated. " + "Use `ray.air.integrations.mlflow.setup_mlflow` instead.", + DeprecationWarning, + ) + if ray.util.client.ray.is_connected(): logger.warning( "When using mlflow_mixin with Ray Client, " @@ -147,6 +162,12 @@ def train_fn(config): return func +@Deprecated( + message=( + "The MLflowTrainableMixin is deprecated. " + "Use `ray.air.integrations.mlflow.setup_mlflow` instead." + ) +) class MLflowTrainableMixin: def __init__(self, config: Dict, *args, **kwargs): self.mlflow_util = _MLflowLoggerUtil() diff --git a/python/ray/tune/integration/wandb.py b/python/ray/tune/integration/wandb.py index ed8481d004d3..f5d3f7902cda 100644 --- a/python/ray/tune/integration/wandb.py +++ b/python/ray/tune/integration/wandb.py @@ -1,4 +1,5 @@ # Deprecate: Remove whole file in 2.4 +import warnings from typing import List, Dict, Callable, Optional from ray.air.integrations.wandb import _setup_wandb @@ -125,6 +126,12 @@ def train_fn(config): tuner.fit() """ + warnings.warn( + "The wandb_mixin/WandbTrainableMixin is deprecated. 
" + "Use `ray.air.integrations.wandb.setup_wandb` instead.", + DeprecationWarning, + ) + if hasattr(func, "__mixins__"): func.__mixins__ = func.__mixins__ + (WandbTrainableMixin,) else: diff --git a/python/ray/tune/tests/test_client.py b/python/ray/tune/tests/test_client.py index ec17da07b379..6252574fead1 100644 --- a/python/ray/tune/tests/test_client.py +++ b/python/ray/tune/tests/test_client.py @@ -89,11 +89,11 @@ def test_xgboost_dynamic_resources_example(start_client_server): def test_mlflow_example(start_client_server): assert ray.util.client.ray.is_connected() - from ray.tune.examples.mlflow_example import tune_function, tune_decorated + from ray.tune.examples.mlflow_example import tune_with_callback, tune_with_setup mlflow_tracking_uri = os.path.join(tempfile.gettempdir(), "mlruns") - tune_function(mlflow_tracking_uri, finish_fast=True) - tune_decorated(mlflow_tracking_uri, finish_fast=True) + tune_with_callback(mlflow_tracking_uri, finish_fast=True) + tune_with_setup(mlflow_tracking_uri, finish_fast=True) def test_pbt_transformers(start_client_server): diff --git a/python/ray/tune/trainable/session.py b/python/ray/tune/trainable/session.py index 8dd75bcf94cc..130259dc3db2 100644 --- a/python/ray/tune/trainable/session.py +++ b/python/ray/tune/trainable/session.py @@ -210,7 +210,9 @@ def _shutdown(): """Cleans up the trial and removes it from the global context.""" global _session + global _session_v2 _session = None + _session_v2 = None @Deprecated(message=_deprecation_msg)