Commit

write after each group, not at the end
parasj committed Jan 17, 2022
1 parent 75624c7 commit a6e705b
Showing 1 changed file with 23 additions and 11 deletions.
34 changes: 23 additions & 11 deletions skylark/cli/experiments/throughput.py
@@ -82,6 +82,9 @@ def throughput_grid(
     aws_region_list: List[str] = typer.Option(aws_regions, "-aws"),
     azure_region_list: List[str] = typer.Option(azure_regions, "-azure"),
     gcp_region_list: List[str] = typer.Option(gcp_regions, "-gcp"),
+    enable_aws: bool = typer.Option(True),
+    enable_azure: bool = typer.Option(True),
+    enable_gcp: bool = typer.Option(True),
     # instances to provision
     aws_instance_class: str = typer.Option("m5.8xlarge", help="AWS instance class to use"),
     azure_instance_class: str = typer.Option("Standard_D32_v4", help="Azure instance class to use"),
@@ -117,14 +120,20 @@ def throughput_grid(
     resume_keys = []

     # validate arguments
+    aws_region_list = aws_region_list if enable_aws else []
+    azure_region_list = azure_region_list if enable_azure else []
+    gcp_region_list = gcp_region_list if enable_gcp else []
+    if not enable_aws and not enable_azure and not enable_gcp:
+        log_error("At least one of -aws, -azure, -gcp must be enabled.")
+        typer.Abort()
     if not all(r in aws_regions for r in aws_region_list):
-        typer.secho(f"Invalid AWS region list: {aws_region_list}", fg="red")
+        log_error(f"Invalid AWS region list: {aws_region_list}")
         typer.Abort()
     if not all(r in azure_regions for r in azure_region_list):
-        typer.secho(f"Invalid Azure region list: {azure_region_list}", fg="red")
+        log_error(f"Invalid Azure region list: {azure_region_list}")
         typer.Abort()
     if not all(r in gcp_regions for r in gcp_region_list):
-        typer.secho(f"Invalid GCP region list: {gcp_region_list}", fg="red")
+        log_error(f"Invalid GCP region list: {gcp_region_list}")
         typer.Abort()
     assert not gcp_test_standard_network, "GCP standard network is not supported yet"
     assert not azure_test_standard_network, "Azure standard network is not supported yet"
@@ -228,17 +237,20 @@ def client_fn(instance_pair):

     # run experiments
     new_througput_results = []
+    output_file = log_dir / "throughput.csv"
     with tqdm(total=len(instance_pairs), desc="Total throughput evaluation") as pbar:
         for group_idx, group in enumerate(groups):
             tag_fmt = lambda x: f"{x[0].region_tag}:{x[0].network_tier()} to {x[1].region_tag}:{x[1].network_tier()}"
             results = do_parallel(client_fn, group, progress_bar=True, desc=f"Parallel eval group {group_idx}", n=-1, arg_fmt=tag_fmt)
             new_througput_results.extend([rec for args, rec in results if rec is not None])

-    # build dataframe from results
-    output_file = log_dir / "throughput.csv"
-    log_success(f"Saving intermediate results to {output_file}")
-    df = pd.DataFrame(new_througput_results)
-    if resume_from_file and copy_resume_file:
-        log_info(f"Copying old CSV entries from {resume_from_file}")
-        df = df.append(pd.read_csv(resume_from_file))
-    df.to_csv(output_file, index=False)
+            # build dataframe from results
+            tqdm.write(f"Saving intermediate results to {output_file}")
+            df = pd.DataFrame(new_througput_results)
+            if resume_from_file and copy_resume_file:
+                log_info(f"Copying old CSV entries from {resume_from_file}")
+                df = df.append(pd.read_csv(resume_from_file))
+            df.to_csv(output_file, index=False)

     log_success(f"Experiment complete: {experiment_tag}")
     log_success(f"Results saved to {output_file}")
