Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[rllib] stopper support for python based training #29972

Merged
merged 6 commits into from
Nov 4, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion rllib/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ class CLIArguments:
"cartpole-a2c": {
"file": "tuned_examples/a2c/cartpole_a2c.py",
"file_type": SupportedFileType.python,
"stop": "{'timesteps_total': 50000, 'episode_reward_mean': 200}",
"description": "Runs A2C on the CartPole-v1 environment.",
},
"cartpole-a2c-micro": {
Expand All @@ -261,7 +262,9 @@ class CLIArguments:
},
# A3C
"cartpole-a3c": {
"file": "tuned_examples/a3c/cartpole-a3c.yaml",
"file": "tuned_examples/a3c/cartpole_a3c.py",
"file_type": SupportedFileType.python,
"stop": "{'timesteps_total': 20000, 'episode_reward_mean': 150}",
"description": "Runs A3C on the CartPole-v1 environment.",
},
"pong-a3c": {
Expand Down
2 changes: 2 additions & 0 deletions rllib/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,12 @@ def run(example_id: str = typer.Argument(..., help="Example ID to run.")):
example_file = get_example_file(example_id)
example_file, temp_file = download_example_file(example_file)
file_type = example.get("file_type", SupportedFileType.yaml)
stop = example.get("stop", "{}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Btw, sorry, unrelated to this PR, BUT can we automate the file_type recognition (instead of making it yaml by default, detect type from file extension).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


train_module.file(
config_file=example_file,
file_type=file_type,
stop=stop,
checkpoint_freq=1,
checkpoint_at_end=True,
keep_checkpoints_num=None,
Expand Down
5 changes: 3 additions & 2 deletions rllib/tests/test_rllib_train_and_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,11 +298,12 @@ def test_json_run(self):
def test_python_run(self):
assert os.popen(
f"python {rllib_dir}/scripts.py train file tuned_examples/simple_q/"
f"cartpole_simpleq_test.py --type=python"
f"cartpole_simpleq_test.py --type=python "
maxpumperla marked this conversation as resolved.
Show resolved Hide resolved
f"--stop={'timesteps_total': 50000, 'episode_reward_mean': 200}"
).read()

def test_all_example_files_exist(self):
""" "The 'example' command now knows about example files,
"""The 'example' command now knows about example files,
so we check that they exist."""
from ray.rllib.common import EXAMPLES

Expand Down
18 changes: 9 additions & 9 deletions rllib/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ray.tune.schedulers import create_scheduler
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.common import CLIArguments as cli
from ray.rllib.common import FrameworkEnum, SupportedFileType
from ray.rllib.common import FrameworkEnum, SupportedFileType, download_example_file


def import_backends():
Expand Down Expand Up @@ -51,7 +51,7 @@ def _patch_path(path: str):


def load_experiments_from_file(
config_file: str, file_type: SupportedFileType, checkpoint_config: dict
config_file: str, file_type: SupportedFileType, checkpoint_config: dict, stop: str
) -> dict:
"""Load experiments from a file. Supports YAML, JSON and Python files.
If you want to use a Python file, it has to have a 'config' variable
Expand Down Expand Up @@ -88,10 +88,8 @@ def load_experiments_from_file(
}
}

# If there's a "stop" dict, add it to the experiment.
if hasattr(module, "stop"):
stop = getattr(module, "stop")
experiments["default"]["stop"] = stop
# Add a stopping condition if provided
experiments["default"]["stop"] = json.loads(stop)

for key, val in experiments.items():
experiments[key]["checkpoint_config"] = checkpoint_config
Expand All @@ -104,6 +102,8 @@ def file(
# File-based arguments.
config_file: str = cli.ConfigFile,
file_type: SupportedFileType = cli.FileType,
# stopping conditions
stop: str = cli.Stop,
# Checkpointing
checkpoint_freq: int = cli.CheckpointFreq,
checkpoint_at_end: bool = cli.CheckpointAtEnd,
Expand Down Expand Up @@ -137,8 +137,6 @@ def file(
rllib train file https://raw.githubusercontent.com/ray-project/ray/\
master/rllib/tuned_examples/ppo/cartpole-ppo.yaml
"""
from ray.rllib.common import download_example_file

# Attempt to download the file if it's not found locally.
config_file, temp_file = download_example_file(
example_file=config_file, base_url=None
Expand All @@ -154,7 +152,9 @@ def file(
"checkpoint_score_attribute": checkpoint_score_attr,
}

experiments = load_experiments_from_file(config_file, file_type, checkpoint_config)
experiments = load_experiments_from_file(
config_file, file_type, checkpoint_config, stop
)
exp_name = list(experiments.keys())[0]
algo = experiments[exp_name]["run"]

Expand Down
4 changes: 3 additions & 1 deletion rllib/tuned_examples/a2c/cartpole_a2c.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Run with:
# rllib train -f cartpole_a2c.py --type python \
maxpumperla marked this conversation as resolved.
Show resolved Hide resolved
# --stop={'timesteps_total': 50000, 'episode_reward_mean': 200}"
from ray.rllib.algorithms.a2c import A2CConfig


Expand All @@ -8,4 +11,3 @@
.framework("tf")
.rollouts(num_rollout_workers=0)
)
stop = {"episode_reward_mean": 150, "timesteps_total": 500000}
Empty file.
13 changes: 13 additions & 0 deletions rllib/tuned_examples/a3c/cartpole_a3c.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Run with:
# rllib train -f cartpole_a3c.py --type python \
# --stop={'timesteps_total': 20000, 'episode_reward_mean': 150}"
from ray.rllib.algorithms.a3c import A3CConfig


config = (
A3CConfig()
.training(gamma=0.95)
.environment("CartPole-v1")
.framework("tf")
.rollouts(num_rollout_workers=0)
)
4 changes: 3 additions & 1 deletion rllib/tuned_examples/simple_q/cartpole_simpleq_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Run with:
# rllib train -f cartpole_simpleq_test.py --type python \
maxpumperla marked this conversation as resolved.
Show resolved Hide resolved
# --stop={'timesteps_total': 50000, 'episode_reward_mean': 200}"
from ray.rllib.algorithms.simple_q import SimpleQConfig


Expand All @@ -7,4 +10,3 @@
.framework("tf")
.rollouts(num_rollout_workers=0)
)
stop = {"episode_reward_mean": 150, "timesteps_total": 50000}