feat(workflows): use PreprocessOutputs dataclass
Signed-off-by: Cameron Smith <[email protected]>
cameronraysmith committed Aug 15, 2024
1 parent ed6a81f commit 1e5a223
Showing 1 changed file with 22 additions and 18 deletions.
src/pyrovelocity/workflows/main_workflow.py: 22 additions & 18 deletions
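
The core change is to stop returning a bare FlyteFile from preprocess_data and instead bundle the processed dataset together with the preprocessing reports directory in a PreprocessOutputs dataclass. The dataclass definition itself is imported (alongside PostprocessOutputs, TrainingOutputs, and SummarizeOutputs) and does not appear in this diff; the sketch below is a hypothetical reconstruction inferred from how its two fields are used here. Recent flytekit releases serialize plain dataclasses directly; older ones need a dataclasses_json or mashumaro mixin.

```python
# Hypothetical reconstruction of PreprocessOutputs, inferred from this diff;
# the actual definition lives in the module it is imported from and may differ.
from dataclasses import dataclass

from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile


@dataclass
class PreprocessOutputs:
    processed_data: FlyteFile          # processed dataset produced by preprocess_dataset
    processed_reports: FlyteDirectory  # directory of preprocessing reports
```
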
@@ -26,6 +26,7 @@
PYROVELOCITY_DATA_SUBSET,
PostprocessConfiguration,
PostprocessOutputs,
+ PreprocessOutputs,
ResourcesJSON,
SummarizeOutputs,
TrainingOutputs,
@@ -78,7 +79,7 @@ def download_data(download_dataset_args: DownloadDatasetInterface) -> FlyteFile:
"""
dataset_path = download_dataset(**asdict(download_dataset_args))
print(f"\nFlyte download data path: {dataset_path}\n")
- return FlyteFile(path=dataset_path)
+ return FlyteFile(path=str(dataset_path))


@task(
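
The download_data hunk above also shows the other recurring edit in this commit: paths are wrapped in str() before being handed to FlyteFile or FlyteDirectory. A plausible reading, not stated on this page, is that the upstream helpers return pathlib.Path objects and the explicit coercion keeps the artifact constructors working with plain string paths, roughly:

```python
# Illustrative only: assumes download_dataset returns a pathlib.Path.
from pathlib import Path

from flytekit.types.file import FlyteFile

dataset_path = Path("data/external/example.h5ad")  # placeholder path
flyte_file = FlyteFile(path=str(dataset_path))     # coerce Path to str before wrapping
```
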
@@ -93,19 +94,22 @@ def download_data(download_dataset_args: DownloadDatasetInterface) -> FlyteFile:
)
def preprocess_data(
data: FlyteFile, preprocess_data_args: PreprocessDataInterface
- ) -> FlyteFile:
+ ) -> PreprocessOutputs:
"""
Download external data.
"""
data_path = data.download()
print(f"\nFlyte preprocess input data path: {data_path}\n")
preprocess_data_args.adata = str(data_path)
- _, processed_dataset_path = preprocess_dataset(
+ _, processed_dataset_path, processed_reports_path = preprocess_dataset(
**asdict(preprocess_data_args),
)

print(f"\nFlyte preprocess output data path: {processed_dataset_path}\n")
- return FlyteFile(path=processed_dataset_path)
+ return PreprocessOutputs(
+ processed_data=FlyteFile(path=str(processed_dataset_path)),
+ processed_reports=FlyteDirectory(path=str(processed_reports_path)),
+ )


@task(
Expand All @@ -121,13 +125,13 @@ def preprocess_data(
enable_deck=False,
)
def train_model(
- processed_data: FlyteFile,
+ preprocess_outputs: PreprocessOutputs,
train_model_configuration: PyroVelocityTrainInterface,
) -> TrainingOutputs:
"""
Train model.
"""
- processed_data_path = processed_data.download()
+ processed_data_path = preprocess_outputs.processed_data.download()
print(f"\nFlyte train model input data path: {processed_data_path}\n")
train_model_configuration.adata = str(processed_data_path)
(
@@ -158,12 +162,12 @@ def train_model(
return TrainingOutputs(
data_model=data_model,
data_model_path=data_model_path,
- trained_data_path=FlyteFile(path=trained_data_path),
- model_path=FlyteDirectory(path=model_path),
- posterior_samples_path=FlyteFile(path=posterior_samples_path),
- metrics_path=FlyteFile(path=metrics_path),
- run_info_path=FlyteFile(path=run_info_path),
- loss_plot_path=FlyteFile(path=loss_plot_path),
+ trained_data_path=FlyteFile(path=str(trained_data_path)),
+ model_path=FlyteDirectory(path=str(model_path)),
+ posterior_samples_path=FlyteFile(path=str(posterior_samples_path)),
+ metrics_path=FlyteFile(path=str(metrics_path)),
+ run_info_path=FlyteFile(path=str(run_info_path)),
+ loss_plot_path=FlyteFile(path=str(loss_plot_path)),
)


@@ -203,8 +207,8 @@ def postprocess_data(
)

return PostprocessOutputs(
- pyrovelocity_data=FlyteFile(path=pyrovelocity_data_path),
- postprocessed_data=FlyteFile(path=postprocessed_data_path),
+ pyrovelocity_data=FlyteFile(path=str(pyrovelocity_data_path)),
+ postprocessed_data=FlyteFile(path=str(postprocessed_data_path)),
)


@@ -249,8 +253,8 @@ def summarize_data(
f"\ndataframe_path: {dataframe_path}\n\n",
)
return SummarizeOutputs(
- data_model_reports=FlyteDirectory(path=data_model_reports_path),
- dataframe=FlyteFile(path=dataframe_path),
+ data_model_reports=FlyteDirectory(path=str(data_model_reports_path)),
+ dataframe=FlyteFile(path=str(dataframe_path)),
)


@@ -347,7 +351,7 @@ def map_model_configurations_over_data_set(
list[FlyteFile]: Workflow outputs as flytekit.types.file.FlyteFile objects.
"""
data = download_data(download_dataset_args=download_dataset_args)
- processed_data = preprocess_data(
+ processed_outputs = preprocess_data(
data=data, preprocess_data_args=preprocess_data_args
)

Expand All @@ -360,7 +364,7 @@ def map_model_configurations_over_data_set(
dataset_summaries: list[SummarizeOutputs] = []
for train_model_configuration in train_model_configurations:
model_output = train_model(
- processed_data=processed_data,
+ preprocess_outputs=processed_outputs,
train_model_configuration=train_model_configuration,
).with_overrides(
requests=Resources(**asdict(train_model_resource_requests)),
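
Taken together, the commit adopts a pattern in which one task returns a dataclass bundling several Flyte artifacts and a downstream task accepts the whole dataclass, downloading only the member it needs. A minimal, self-contained sketch of that pattern, with illustrative names that are not part of the pyrovelocity API:

```python
# Minimal sketch of the produce/consume pattern used in this commit.
# ExampleOutputs, produce, consume, and example_wf are illustrative names.
# Plain dataclasses work with recent flytekit; older versions need dataclasses_json.
import os
from dataclasses import dataclass

from flytekit import task, workflow
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile


@dataclass
class ExampleOutputs:
    processed_data: FlyteFile
    processed_reports: FlyteDirectory


@task
def produce() -> ExampleOutputs:
    # Write artifacts locally, then hand their paths to Flyte.
    with open("processed.txt", "w") as f:
        f.write("placeholder")
    os.makedirs("reports", exist_ok=True)
    return ExampleOutputs(
        processed_data=FlyteFile(path="processed.txt"),
        processed_reports=FlyteDirectory(path="reports"),
    )


@task
def consume(outputs: ExampleOutputs) -> str:
    # Materialize only the file this task actually needs.
    local_path = outputs.processed_data.download()
    return str(local_path)


@workflow
def example_wf() -> str:
    return consume(outputs=produce())
```
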
