diff --git a/rllib/common.py b/rllib/common.py
index 2d90aa8820f2..0c3fc96020e4 100644
--- a/rllib/common.py
+++ b/rllib/common.py
@@ -24,10 +24,23 @@ class SupportedFileType(str, Enum):
     """Supported file types for RLlib, used for CLI argument validation."""
 
     yaml = "yaml"
-    json = "json"
     python = "python"
 
 
+def get_file_type(config_file: str) -> SupportedFileType:
+    if ".py" in config_file:
+        file_type = SupportedFileType.python
+    elif ".yaml" in config_file or ".yml" in config_file:
+        file_type = SupportedFileType.yaml
+    else:
+        raise ValueError(
+            "Unknown file type for config "
+            "file: {}. Supported extensions: .py, "
+            ".yml, .yaml.".format(config_file)
+        )
+    return file_type
+
+
 def _create_tune_parser_help():
     """Create a Tune dummy parser to access its 'help' docstrings."""
     parser = _make_parser(
@@ -105,7 +118,7 @@ def get_help(key: str) -> str:
         "`ray.rllib.examples.env.simple_corridor.SimpleCorridor`).",
     config_file="Use the algorithm configuration from this file.",
     filetype="The file type of the config file. Defaults to 'yaml' and can also be "
-    "'json', or 'python'.",
+    "'python'.",
     experiment_name="Name of the subdirectory under `local_dir` to put results in.",
     framework="The identifier of the deep learning framework you want to use."
     "Choose between TensorFlow 1.x ('tf'), TensorFlow 2.x ('tf2'), "
@@ -252,7 +265,7 @@ class CLIArguments:
     },
     "cartpole-a2c": {
         "file": "tuned_examples/a2c/cartpole_a2c.py",
-        "file_type": SupportedFileType.python,
+        "stop": "{'timesteps_total': 50000, 'episode_reward_mean': 200}",
         "description": "Runs A2C on the CartPole-v1 environment.",
     },
     "cartpole-a2c-micro": {
@@ -261,7 +274,8 @@ class CLIArguments:
     },
     # A3C
     "cartpole-a3c": {
-        "file": "tuned_examples/a3c/cartpole-a3c.yaml",
+        "file": "tuned_examples/a3c/cartpole_a3c.py",
+        "stop": "{'timesteps_total': 20000, 'episode_reward_mean': 150}",
         "description": "Runs A3C on the CartPole-v1 environment.",
     },
     "pong-a3c": {
diff --git a/rllib/scripts.py b/rllib/scripts.py
index c7934906f01f..b20b97c9d146 100644
--- a/rllib/scripts.py
+++ b/rllib/scripts.py
@@ -10,7 +10,6 @@
 from ray.rllib.common import (
     EXAMPLES,
     FrameworkEnum,
-    SupportedFileType,
     example_help,
     download_example_file,
 )
@@ -101,11 +100,11 @@ def run(example_id: str = typer.Argument(..., help="Example ID to run.")):
     example = EXAMPLES[example_id]
     example_file = get_example_file(example_id)
     example_file, temp_file = download_example_file(example_file)
-    file_type = example.get("file_type", SupportedFileType.yaml)
+    stop = example.get("stop", "{}")
 
     train_module.file(
         config_file=example_file,
-        file_type=file_type,
+        stop=stop,
         checkpoint_freq=1,
         checkpoint_at_end=True,
         keep_checkpoints_num=None,
diff --git a/rllib/tests/test_rllib_train_and_evaluate.py b/rllib/tests/test_rllib_train_and_evaluate.py
index 6d69a256a3d3..a7bf73b74acf 100644
--- a/rllib/tests/test_rllib_train_and_evaluate.py
+++ b/rllib/tests/test_rllib_train_and_evaluate.py
@@ -289,20 +289,15 @@ def test_yaml_run(self):
             f"cartpole-simpleq-test.yaml"
         ).read()
 
-    def test_json_run(self):
-        assert os.popen(
-            f"python {rllib_dir}/scripts.py train file tuned_examples/simple_q/"
-            f"cartpole-simpleq-test.json --type=json"
-        ).read()
-
     def test_python_run(self):
         assert os.popen(
             f"python {rllib_dir}/scripts.py train file tuned_examples/simple_q/"
-            f"cartpole_simpleq_test.py --type=python"
+            f"cartpole_simpleq_test.py "
+            "--stop='{\"timesteps_total\": 50000, \"episode_reward_mean\": 200}'"
         ).read()
 
     def test_all_example_files_exist(self):
-        """ "The 'example' command now knows about example files,
+        """The 'example' command now knows about example files,
         so we check that they exist."""
 
         from ray.rllib.common import EXAMPLES
diff --git a/rllib/train.py b/rllib/train.py
old mode 100755
new mode 100644
index e32106687dce..bb36a6ee1cfd
--- a/rllib/train.py
+++ b/rllib/train.py
@@ -13,6 +13,7 @@
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
 from ray.rllib.common import CLIArguments as cli
 from ray.rllib.common import FrameworkEnum, SupportedFileType
+from ray.rllib.common import download_example_file, get_file_type
 
 
 def import_backends():
@@ -53,18 +54,16 @@ def _patch_path(path: str):
 def load_experiments_from_file(
     config_file: str,
     file_type: SupportedFileType,
+    stop: Optional[str] = None,
     checkpoint_config: Optional[dict] = None,
 ) -> dict:
-    """Load experiments from a file. Supports YAML, JSON and Python files.
+    """Load experiments from a file. Supports YAML and Python files.
 
     If you want to use a Python file, it has to have a 'config' variable
     that is an AlgorithmConfig object."""
 
     if file_type == SupportedFileType.yaml:
         with open(config_file) as f:
             experiments = yaml.safe_load(f)
-    elif file_type == SupportedFileType.json:
-        with open(config_file) as f:
-            experiments = json.load(f)
     else:  # Python file case (ensured by file type enum)
         import importlib
@@ -90,10 +89,9 @@ def load_experiments_from_file(
             }
         }
 
-        # If there's a "stop" dict, add it to the experiment.
-        if hasattr(module, "stop"):
-            stop = getattr(module, "stop")
-            experiments["default"]["stop"] = stop
+        # Add a stopping condition, if provided.
+        if stop:
+            experiments["default"]["stop"] = json.loads(stop)
 
     for key, val in experiments.items():
         experiments[key]["checkpoint_config"] = checkpoint_config or {}
@@ -105,7 +103,8 @@
 def file(
     # File-based arguments.
     config_file: str = cli.ConfigFile,
-    file_type: SupportedFileType = cli.FileType,
+    # Stopping conditions.
+    stop: str = cli.Stop,
     # Checkpointing
     checkpoint_freq: int = cli.CheckpointFreq,
     checkpoint_at_end: bool = cli.CheckpointAtEnd,
@@ -139,8 +138,6 @@ def file(
         rllib train file https://raw.githubusercontent.com/ray-project/ray/\
 master/rllib/tuned_examples/ppo/cartpole-ppo.yaml
     """
-    from ray.rllib.common import download_example_file
-
     # Attempt to download the file if it's not found locally.
     config_file, temp_file = download_example_file(
         example_file=config_file, base_url=None
@@ -156,7 +153,11 @@ def file(
         "checkpoint_score_attribute": checkpoint_score_attr,
     }
 
-    experiments = load_experiments_from_file(config_file, file_type, checkpoint_config)
+    file_type = get_file_type(config_file)
+
+    experiments = load_experiments_from_file(
+        config_file, file_type, stop, checkpoint_config
+    )
 
     exp_name = list(experiments.keys())[0]
     algo = experiments[exp_name]["run"]
diff --git a/rllib/tuned_examples/a2c/cartpole-a2c.json b/rllib/tuned_examples/a2c/cartpole-a2c.json
deleted file mode 100644
index bfe5781f92ee..000000000000
--- a/rllib/tuned_examples/a2c/cartpole-a2c.json
+++ /dev/null
@@ -1,15 +0,0 @@
-{
-    "cartpole-a2c":{
-        "env":"CartPole-v0",
-        "run":"A2C",
-        "stop":{
-            "episode_reward_mean":150,
-            "timesteps_total":500000
-        },
-        "config":{
-            "framework":"tf",
-            "num_workers":0,
-            "lr":0.001
-        }
-    }
-}
\ No newline at end of file
diff --git a/rllib/tuned_examples/a2c/cartpole_a2c.py b/rllib/tuned_examples/a2c/cartpole_a2c.py
index f456297abd74..e48d8c466e09 100644
--- a/rllib/tuned_examples/a2c/cartpole_a2c.py
+++ b/rllib/tuned_examples/a2c/cartpole_a2c.py
@@ -1,3 +1,6 @@
+# Run with:
+# rllib train file cartpole_a2c.py \
+# --stop="{'timesteps_total': 50000, 'episode_reward_mean': 200}"
 from ray.rllib.algorithms.a2c import A2CConfig
 
 
@@ -8,4 +11,3 @@
     .framework("tf")
     .rollouts(num_rollout_workers=0)
 )
-stop = {"episode_reward_mean": 150, "timesteps_total": 500000}
diff --git a/rllib/tuned_examples/a3c/__init__.py b/rllib/tuned_examples/a3c/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/rllib/tuned_examples/a3c/cartpole_a3c.py b/rllib/tuned_examples/a3c/cartpole_a3c.py
new file mode 100644
index 000000000000..464f1ebcb9f8
--- /dev/null
+++ b/rllib/tuned_examples/a3c/cartpole_a3c.py
@@ -0,0 +1,13 @@
+# Run with:
+# rllib train file cartpole_a3c.py \
+# --stop="{'timesteps_total': 20000, 'episode_reward_mean': 150}"
+from ray.rllib.algorithms.a3c import A3CConfig
+
+
+config = (
+    A3CConfig()
+    .training(gamma=0.95)
+    .environment("CartPole-v1")
+    .framework("tf")
+    .rollouts(num_rollout_workers=0)
+)
diff --git a/rllib/tuned_examples/simple_q/cartpole_simpleq_test.py b/rllib/tuned_examples/simple_q/cartpole_simpleq_test.py
index 885892d5aef2..ba2816968870 100644
--- a/rllib/tuned_examples/simple_q/cartpole_simpleq_test.py
+++ b/rllib/tuned_examples/simple_q/cartpole_simpleq_test.py
@@ -1,3 +1,6 @@
+# Run with:
+# rllib train file cartpole_simpleq_test.py \
+# --stop="{'timesteps_total': 50000, 'episode_reward_mean': 200}"
 from ray.rllib.algorithms.simple_q import SimpleQConfig
 
 
@@ -7,4 +10,3 @@
     .framework("tf")
     .rollouts(num_rollout_workers=0)
 )
-stop = {"episode_reward_mean": 150, "timesteps_total": 50000}
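For reference, a minimal sketch of how the reworked loading path can be exercised once this diff is applied. The config path and stop values below are illustrative assumptions, and the stop string is written as double-quoted JSON because the new code parses it with json.loads:

    # Hypothetical smoke check of the new stop plumbing (not part of the diff above).
    from ray.rllib.common import get_file_type
    from ray.rllib.train import load_experiments_from_file

    config_file = "rllib/tuned_examples/a2c/cartpole_a2c.py"  # assumed repo-relative path
    experiments = load_experiments_from_file(
        config_file,
        get_file_type(config_file),  # infers SupportedFileType.python from ".py"
        stop='{"timesteps_total": 50000, "episode_reward_mean": 200}',
    )
    # The Python-file branch wraps the module's AlgorithmConfig in a single
    # "default" experiment and attaches the parsed stop criteria to it.
    print(experiments["default"]["stop"])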