diff --git a/tests/ttnn/sweep_tests/run_sweeps.py b/tests/ttnn/sweep_tests/run_sweeps.py index 125ff1f8b08..206fdf758d2 100644 --- a/tests/ttnn/sweep_tests/run_sweeps.py +++ b/tests/ttnn/sweep_tests/run_sweeps.py @@ -16,16 +16,26 @@ def convert_string_to_set(string): def main(): parser = argparse.ArgumentParser() - parser.add_argument("--include", type=str, help="Comma separated list of sweep names to include") + parser.add_argument( + "--include", type=str, help="Comma separated list of sweep names to include eg) --include sweep1.py,sweep2.py" + ) + parser.add_argument( + "--collect-only", action="store_true", help="Print the sweeps that will be run but do not run them" + ) + + args = parser.parse_args() + include = convert_string_to_set(args.include) + device = None + if not args.collect_only: + device = ttnn.open_device(device_id=0) + print("Running sweeps...") + else: + print("Collecting sweeps to run...") - include = parser.parse_args().include - - include = convert_string_to_set(include) - - device = ttnn.open_device(device_id=0) table_names = run_sweeps(device=device, include=include) - ttnn.close_device(device) - print_report(table_names=table_names) + if not args.collect_only: + ttnn.close_device(device) + print_report(table_names=table_names) if __name__ == "__main__": diff --git a/tests/ttnn/sweep_tests/sweeps/__init__.py b/tests/ttnn/sweep_tests/sweeps/__init__.py index 72ca47c2914..c1d2a6e3e8a 100644 --- a/tests/ttnn/sweep_tests/sweeps/__init__.py +++ b/tests/ttnn/sweep_tests/sweeps/__init__.py @@ -199,15 +199,24 @@ def name_to_string(name): return table_name -def run_sweeps(*, device, include): - table_names = [] +def collect_tests(*, include): for file_name in sorted(SWEEP_SOURCES_DIR.glob("**/*.py")): sweep_name = get_sweep_name(file_name) if include and sweep_name not in include: continue - logger.info(f"Running {file_name}") - table_name = run_sweep(file_name, device=device) - table_names.append(table_name) + yield file_name, sweep_name + + +def run_sweeps(*, device, include): + table_names = [] + for file_name, sweep_name in collect_tests(include=include): + if not device: + logger.info(f"Collecting {sweep_name}") + continue + else: + logger.info(f"Running {sweep_name}") + table_name = run_sweep(file_name, device=device) + table_names.append(table_name) return table_names