Skip to content

Commit

Permalink
#8694: add --collect-only for sweep tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sjameelTT committed May 22, 2024
1 parent 432117b commit 7f6b022
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 13 deletions.
26 changes: 18 additions & 8 deletions tests/ttnn/sweep_tests/run_sweeps.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,26 @@ def convert_string_to_set(string):

def main():
parser = argparse.ArgumentParser()
parser.add_argument("--include", type=str, help="Comma separated list of sweep names to include")
parser.add_argument(
"--include", type=str, help="Comma separated list of sweep names to include eg) --include sweep1.py,sweep2.py"
)
parser.add_argument(
"--collect-only", action="store_true", help="Print the sweeps that will be run but do not run them"
)

args = parser.parse_args()
include = convert_string_to_set(args.include)
device = None
if not args.collect_only:
device = ttnn.open_device(device_id=0)
print("Running sweeps...")
else:
print("Collecting sweeps to run...")

include = parser.parse_args().include

include = convert_string_to_set(include)

device = ttnn.open_device(device_id=0)
table_names = run_sweeps(device=device, include=include)
ttnn.close_device(device)
print_report(table_names=table_names)
if not args.collect_only:
ttnn.close_device(device)
print_report(table_names=table_names)


if __name__ == "__main__":
Expand Down
19 changes: 14 additions & 5 deletions tests/ttnn/sweep_tests/sweeps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,15 +199,24 @@ def name_to_string(name):
return table_name


def run_sweeps(*, device, include):
table_names = []
def collect_tests(*, include):
for file_name in sorted(SWEEP_SOURCES_DIR.glob("**/*.py")):
sweep_name = get_sweep_name(file_name)
if include and sweep_name not in include:
continue
logger.info(f"Running {file_name}")
table_name = run_sweep(file_name, device=device)
table_names.append(table_name)
yield file_name, sweep_name


def run_sweeps(*, device, include):
table_names = []
for file_name, sweep_name in collect_tests(include=include):
if not device:
logger.info(f"Collecting {sweep_name}")
continue
else:
logger.info(f"Running {sweep_name}")
table_name = run_sweep(file_name, device=device)
table_names.append(table_name)
return table_names


Expand Down

0 comments on commit 7f6b022

Please sign in to comment.