Skip to content

Commit

Permalink
feat(workflows): add task to collect metrics for all data sets and mo…
Browse files Browse the repository at this point in the history
…dels

Signed-off-by: Cameron Smith <[email protected]>
  • Loading branch information
cameronraysmith committed Aug 15, 2024
1 parent 9582ab2 commit b242f80
Showing 1 changed file with 103 additions and 1 deletion.
104 changes: 103 additions & 1 deletion src/pyrovelocity/workflows/main_workflow.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import json
import os
from dataclasses import asdict
from datetime import timedelta
from pathlib import Path

from beartype.typing import List
from flytekit import Resources, current_context, dynamic, task, workflow
from flytekit.extras.accelerators import T4, GPUAccelerator
from flytekit.types.directory import FlyteDirectory
Expand All @@ -19,7 +21,12 @@
create_tarball_from_filtered_dir,
)
from pyrovelocity.io.gcs import upload_file_concurrently
from pyrovelocity.io.json import add_duration_to_run_info, combine_json_files
from pyrovelocity.io.json import (
add_duration_to_run_info,
combine_json_files,
generate_tables,
load_json,
)
from pyrovelocity.logging import configure_logging
from pyrovelocity.tasks.data import download_dataset
from pyrovelocity.tasks.postprocess import postprocess_dataset
Expand All @@ -29,6 +36,7 @@
from pyrovelocity.utils import str_to_bool
from pyrovelocity.workflows.main_configuration import (
PYROVELOCITY_DATA_SUBSET,
CombinedMetricsOutputs,
PostprocessConfiguration,
PostprocessOutputs,
PreprocessOutputs,
Expand Down Expand Up @@ -67,6 +75,7 @@
POSTPROCESS_CACHE_VERSION = f"{CACHE_VERSION}.0"
SUMMARIZE_CACHE_VERSION = f"{CACHE_VERSION}.0"
UPLOAD_CACHE_VERSION = f"{CACHE_VERSION}.0"
COMBINE_METRICS_CACHE_VERSION = f"{CACHE_VERSION}.0"
PYROVELOCITY_CACHE_FLAG = str_to_bool(
os.getenv("PYROVELOCITY_CACHE_FLAG", "True")
)
Expand Down Expand Up @@ -461,6 +470,97 @@ def map_model_configurations_over_data_set(
return dataset_summaries


@task(
cache=PYROVELOCITY_CACHE_FLAG,
cache_version=COMBINE_METRICS_CACHE_VERSION,
retries=3,
interruptible=True,
timeout=timedelta(minutes=20),
requests=Resources(cpu="2", mem="4Gi", ephemeral_storage="8Gi"),
limits=Resources(cpu="4", mem="8Gi", ephemeral_storage="16Gi"),
)
def combine_all_metrics(
results: List[List[SummarizeOutputs]]
) -> CombinedMetricsOutputs:
combined_metrics = {}

for dataset_results in results:
for model_result in dataset_results:
metrics_path = model_result.combined_metrics_path.download()
metrics_result = load_json(Path(metrics_path))

if isinstance(metrics_result, Success):
metrics = metrics_result.unwrap()
run_name = metrics.get("run_name", "unknown")
combined_metrics[run_name] = metrics
else:
print(
f"Failed to load metrics from {metrics_path}: {metrics_result.failure()}"
)

json_metrics_file = Path("combined_metrics.json")
with json_metrics_file.open("w") as f:
json.dump(combined_metrics, f, indent=2)

latex_table, html_table, markdown_table, _ = generate_tables(
combined_metrics
)

latex_metrics_file = Path("combined_metrics_table.tex")
with latex_metrics_file.open("w") as f:
f.write(latex_table)

html_metrics_file = Path("combined_metrics_table.html")
with html_metrics_file.open("w") as f:
f.write(html_table)

md_metrics_file = Path("combined_metrics_table.md")
with md_metrics_file.open("w") as f:
f.write(markdown_table)

ctx = current_context()
execution_id = ctx.execution_id.name

files_to_upload = [
json_metrics_file,
latex_metrics_file,
html_metrics_file,
md_metrics_file,
]
upload_results = []

for file in files_to_upload:
upload_result = upload_file_concurrently(
bucket_name=f"pyrovelocity/reports/{execution_id}",
source_filename=file,
destination_blob_name=str(file),
)
upload_results.append(upload_result)

if all(isinstance(result, Success) for result in upload_results):
print("\nAll uploads successful.")
return CombinedMetricsOutputs(
json_metrics=FlyteFile(path=str(json_metrics_file)),
latex_metrics=FlyteFile(path=str(latex_metrics_file)),
html_metrics=FlyteFile(path=str(html_metrics_file)),
md_metrics=FlyteFile(path=str(md_metrics_file)),
)
else:
print("\nOne or more uploads failed.")
failed_uploads = [
str(file)
for file, result in zip(files_to_upload, upload_results)
if isinstance(result, Failure)
]
print(f"Failed uploads: {', '.join(failed_uploads)}")
return CombinedMetricsOutputs(
json_metrics=FlyteFile(path=""),
latex_metrics=FlyteFile(path=""),
html_metrics=FlyteFile(path=""),
md_metrics=FlyteFile(path=""),
)


@dynamic
def training_workflow(
simulated_configuration: WorkflowConfiguration = simulated_configuration,
Expand Down Expand Up @@ -503,6 +603,8 @@ def training_workflow(
)
results.append(result)

combine_all_metrics(results=results)

return results


Expand Down

0 comments on commit b242f80

Please sign in to comment.