Skip to content

Commit

Permalink
#8694: add --collect-only for sweep tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sjameelTT committed May 21, 2024
1 parent d5c7fef commit 0f8da46
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 11 deletions.
28 changes: 17 additions & 11 deletions tests/ttnn/sweep_tests/run_sweeps.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import ttnn


from sweeps import run_sweeps, print_report
from sweeps import run_sweeps, print_report, collect_only


def convert_string_to_list(string):
Expand All @@ -21,16 +21,22 @@ def convert_string_to_list(string):

def main():
parser = argparse.ArgumentParser()
parser.add_argument("--include", type=str)

include = parser.parse_args().include

include = convert_string_to_list(include)

device = ttnn.open_device(device_id=0)
table_names = run_sweeps(device=device, include=include)
ttnn.close_device(device)
print_report(table_names=table_names)
parser.add_argument(
"--include", type=str, help="Comma separated list of sweep names to include eg) --include sweep1.py,sweep2.py"
)
parser.add_argument(
"--collect-only", action="store_true", help="Print the sweeps that will be run but do not run them"
)

args = parser.parse_args()
include = convert_string_to_list(args.include)
if args.collect_only:
collect_only(include=include)
else:
device = ttnn.open_device(device_id=0)
table_names = run_sweeps(device=device, include=include)
ttnn.close_device(device)
print_report(table_names=table_names)


if __name__ == "__main__":
Expand Down
8 changes: 8 additions & 0 deletions tests/ttnn/sweep_tests/sweeps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,14 @@ def name_to_string(name):
return table_name


def collect_only(*, include):
for file_name in sorted(SWEEP_SOURCES_DIR.glob("**/*.py")):
sweep_name = get_sweep_name(file_name)
if include and sweep_name not in include:
continue
print(sweep_name)


def run_sweeps(*, device, include):
table_names = []
for file_name in sorted(SWEEP_SOURCES_DIR.glob("**/*.py")):
Expand Down

0 comments on commit 0f8da46

Please sign in to comment.