Skip to content

Commit

Permalink
support sequence of tests and add checkpoint test
Browse files: browse the repository at this point in the history
address comments

ghstack-source-id: 7d6c51a5ef68dea06ba7d64741a554165c79f1d3
Pull Request resolved: #198
  • Loading branch information
wz337 committed Apr 5, 2024
1 parent 2c21f36 commit c526067
Showing 1 changed file with 41 additions and 18 deletions.
59 changes: 41 additions & 18 deletions test/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@ class OverrideDefinitions:
This class is used to define the override definitions for the integration tests.
"""

override_args: Sequence[str] = tuple()
test_descr: str = "default"
override_args: Sequence[Sequence[str]] = tuple(tuple(" "))
test_descr: str = ""


CONFIG_DIR = "./train_configs"
test_checkpoint_dir = "./test_runner_checkpoint"

"""
key is the config file name and value is a list of OverrideDefinitions
Expand All @@ -34,13 +35,47 @@ class OverrideDefinitions:
"""
integration_tests_flavors = defaultdict(list)
integration_tests_flavors["debug_model.toml"] = [
OverrideDefinitions(["--training.compile"], "1D compile"),
OverrideDefinitions(
["--training.tensor_parallel_degree 2"], "Eager mode 2DParallel"
[
["--training.compile"],
],
"1D compile",
),
OverrideDefinitions(
[
["--training.tensor_parallel_degree 2"],
],
"Eager mode 2DParallel",
),
OverrideDefinitions(
[
[f"--checkpoint.folder {test_checkpoint_dir}"],
[f"--checkpoint.folder {test_checkpoint_dir}", "--training.steps 20"],
],
"Checkpoint Integration Test",
),
]


def run_test(test_flavor: "OverrideDefinitions", full_path: str):
    """Run every training command described by one integration-test flavor.

    ``test_flavor.override_args`` is iterated in order. Each item is joined
    with spaces and appended to the base launch command, so a flavor holding
    several items runs several commands back to back for the same config.

    Args:
        test_flavor: Supplies ``override_args`` (the per-run argument lists)
            and ``test_descr`` (the label printed in the banner).
        full_path: Value passed to the command as ``CONFIG_FILE``.

    Raises:
        subprocess.CalledProcessError: If a command exits with a non-zero
            status. Its captured output is printed first; the remaining
            commands of the flavor are not run.
    """
    # run_test supports sequence of tests.
    for override_arg in test_flavor.override_args:
        cmd = f"CONFIG_FILE={full_path} NGPU=4 ./run_llama_train.sh"
        if override_arg:
            cmd += " " + " ".join(override_arg)
        print(
            f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}====="
        )
        # With shell=True the command is a single shell string; stderr is
        # merged into stdout so the printed log keeps its original ordering.
        result = subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            shell=True,
        )
        print(result.stdout)
        # Without this check a failing run would be reported as a success;
        # the output is printed above so the failure remains diagnosable.
        result.check_returncode()


for config_file in os.listdir(CONFIG_DIR):
if config_file.endswith(".toml"):
full_path = os.path.join(CONFIG_DIR, config_file)
Expand All @@ -51,18 +86,6 @@ class OverrideDefinitions:
test_flavors = [OverrideDefinitions()] + integration_tests_flavors[
config_file
]

for test_flavor in test_flavors:
cmd = f"CONFIG_FILE={full_path} NGPU=4 ./run_llama_train.sh"
if test_flavor.override_args:
cmd += " " + " ".join(test_flavor.override_args)
print(
f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}====="
)
result = subprocess.run(
[cmd],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
shell=True,
)
print(result.stdout)
run_test(test_flavor, full_path)

0 comments on commit c526067

Please sign in to comment.