diff --git a/src/halmos/__main__.py b/src/halmos/__main__.py index 357d0f30..cc667a53 100644 --- a/src/halmos/__main__.py +++ b/src/halmos/__main__.py @@ -1020,6 +1020,14 @@ def import_libs(build_out_map: Dict, hexcode: str, linkReferences: Dict) -> Dict return libs +def build_output_iterator(build_out: Dict): + for compiler_version in sorted(build_out): + build_out_map = build_out[compiler_version] + for filename in sorted(build_out_map): + for contract_name in sorted(build_out_map[filename]): + yield (build_out_map, filename, contract_name) + + @dataclass(frozen=True) class MainResult: exitcode: int @@ -1075,6 +1083,7 @@ def _main(_args=None) -> MainResult: print(color_warn(f"Build failed: {build_cmd}")) return MainResult(1) + timer.create_subtimer("load") try: build_out = parse_build_out(args) except Exception as err: @@ -1092,83 +1101,66 @@ def _main(_args=None) -> MainResult: total_passed = 0 total_failed = 0 total_found = 0 - test_results_map = {} - for compiler_version in sorted(build_out): - build_out_map = build_out[compiler_version] - for filename in sorted(build_out_map): - for contract_name in sorted(build_out_map[filename]): - if args.contract and args.contract != contract_name: - continue + for build_out_map, filename, contract_name in build_output_iterator(build_out): + if args.contract and args.contract != contract_name: + continue - (contract_json, contract_type, natspec) = build_out_map[filename][ - contract_name - ] - if contract_type != "contract": - continue + (contract_json, contract_type, natspec) = build_out_map[filename][contract_name] + if contract_type != "contract": + continue - creation_hexcode = contract_json["bytecode"]["object"] - deployed_hexcode = contract_json["deployedBytecode"]["object"] + methodIdentifiers = contract_json["methodIdentifiers"] + funsigs = [f for f in methodIdentifiers if f.startswith(args.function)] + num_found = len(funsigs) - abi = contract_json["abi"] - methodIdentifiers = contract_json["methodIdentifiers"] - linkReferences = contract_json["bytecode"]["linkReferences"] + if num_found == 0: + continue - libs = ( - import_libs(build_out_map, creation_hexcode, linkReferences) - if linkReferences - else {} - ) + contract_timer = NamedTimer("time") - funsigs = [ - funsig - for funsig in methodIdentifiers - if funsig.startswith(args.function) - ] + abi = contract_json["abi"] + creation_hexcode = contract_json["bytecode"]["object"] + deployed_hexcode = contract_json["deployedBytecode"]["object"] + linkReferences = contract_json["bytecode"]["linkReferences"] + libs = import_libs(build_out_map, creation_hexcode, linkReferences) - if funsigs: - total_found += len(funsigs) - contract_path = ( - f"{contract_json['ast']['absolutePath']}:{contract_name}" - ) - print(f"\nRunning {len(funsigs)} tests for {contract_path}") - contract_timer = NamedTimer("time") + contract_path = f"{contract_json['ast']['absolutePath']}:{contract_name}" + print(f"\nRunning {num_found} tests for {contract_path}") + contract_args = extend_args(args, parse_natspec(natspec)) if natspec else args - contract_args = ( - extend_args(args, parse_natspec(natspec)) if natspec else args - ) + run_args = RunArgs( + funsigs, + creation_hexcode, + deployed_hexcode, + abi, + methodIdentifiers, + contract_args, + contract_json, + libs, + ) - run_args = RunArgs( - funsigs, - creation_hexcode, - deployed_hexcode, - abi, - methodIdentifiers, - contract_args, - contract_json, - libs, - ) + enable_parallel = args.test_parallel and num_found > 1 + run_method = run_parallel if enable_parallel else run_sequential + test_results = run_method(run_args) - enable_parallel = args.test_parallel and len(funsigs) > 1 - test_results = ( - run_parallel(run_args) - if enable_parallel - else run_sequential(run_args) - ) + num_passed = sum(r.exitcode == 0 for r in test_results) + num_failed = num_found - num_passed - num_passed = sum(r.exitcode == 0 for r in test_results) - num_failed = len(funsigs) - num_passed + print( + f"Symbolic test result: {num_passed} passed; " + f"{num_failed} failed; {contract_timer.report()}" + ) - print( - f"Symbolic test result: {num_passed} passed; {num_failed} failed; {contract_timer.report()}" - ) - total_passed += num_passed - total_failed += num_failed + total_found += num_found + total_passed += num_passed + total_failed += num_failed + + if contract_path in test_results_map: + raise ValueError("already exists", contract_path) - if contract_path in test_results_map: - raise ValueError("already exists", contract_path) - test_results_map[contract_path] = test_results + test_results_map[contract_path] = test_results if args.statistics: print(f"\n[time] {timer.report()}") diff --git a/src/halmos/utils.py b/src/halmos/utils.py index 78e5e152..7609f4af 100644 --- a/src/halmos/utils.py +++ b/src/halmos/utils.py @@ -688,4 +688,7 @@ def __str__(self): return self.report() def __repr__(self): - return f"NamedTimer(name={self.name}, start_time={self.start_time}, end_time={self.end_time}, sub_timers={self.sub_timers})" + return ( + f"NamedTimer(name={self.name}, start_time={self.start_time}, " + f"end_time={self.end_time}, sub_timers={self.sub_timers})" + )