From c52606731275a07da87834248614cf538e856065 Mon Sep 17 00:00:00 2001 From: wz337 Date: Fri, 5 Apr 2024 13:25:44 -0700 Subject: [PATCH] support sequence of tests and add checkpoint test address comments ghstack-source-id: 7d6c51a5ef68dea06ba7d64741a554165c79f1d3 Pull Request resolved: https://github.com/pytorch/torchtrain/pull/198 --- test/test_runner.py | 59 +++++++++++++++++++++++++++++++-------------- 1 file changed, 41 insertions(+), 18 deletions(-) diff --git a/test/test_runner.py b/test/test_runner.py index f3dd8dc0..243aa2d7 100755 --- a/test/test_runner.py +++ b/test/test_runner.py @@ -21,11 +21,12 @@ class OverrideDefinitions: This class is used to define the override definitions for the integration tests. """ - override_args: Sequence[str] = tuple() - test_descr: str = "default" + override_args: Sequence[Sequence[str]] = tuple(tuple(" ")) + test_descr: str = "" CONFIG_DIR = "./train_configs" +test_checkpoint_dir = "./test_runner_checkpoint" """ key is the config file name and value is a list of OverrideDefinitions @@ -34,13 +35,47 @@ class OverrideDefinitions: """ integration_tests_flavors = defaultdict(list) integration_tests_flavors["debug_model.toml"] = [ - OverrideDefinitions(["--training.compile"], "1D compile"), OverrideDefinitions( - ["--training.tensor_parallel_degree 2"], "Eager mode 2DParallel" + [ + ["--training.compile"], + ], + "1D compile", + ), + OverrideDefinitions( + [ + ["--training.tensor_parallel_degree 2"], + ], + "Eager mode 2DParallel", + ), + OverrideDefinitions( + [ + [f"--checkpoint.folder {test_checkpoint_dir}"], + [f"--checkpoint.folder {test_checkpoint_dir}", "--training.steps 20"], + ], + "Checkpoint Integration Test", ), ] +def run_test(test_flavor: OverrideDefinitions, full_path: str): + # run_test supports sequence of tests. + for override_arg in test_flavor.override_args: + cmd = f"CONFIG_FILE={full_path} NGPU=4 ./run_llama_train.sh" + if override_arg: + cmd += " " + " ".join(override_arg) + print( + f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}=====" + ) + result = subprocess.run( + [cmd], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + shell=True, + ) + print(result.stdout) + + for config_file in os.listdir(CONFIG_DIR): if config_file.endswith(".toml"): full_path = os.path.join(CONFIG_DIR, config_file) @@ -51,18 +86,6 @@ class OverrideDefinitions: test_flavors = [OverrideDefinitions()] + integration_tests_flavors[ config_file ] + for test_flavor in test_flavors: - cmd = f"CONFIG_FILE={full_path} NGPU=4 ./run_llama_train.sh" - if test_flavor.override_args: - cmd += " " + " ".join(test_flavor.override_args) - print( - f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}=====" - ) - result = subprocess.run( - [cmd], - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - text=True, - shell=True, - ) - print(result.stdout) + run_test(test_flavor, full_path)