From 1e5a223399418737c3beb6b43903bbea8fd5b9b8 Mon Sep 17 00:00:00 2001 From: Cameron Smith Date: Wed, 14 Aug 2024 22:14:05 -0400 Subject: [PATCH] feat(workflows): use PreprocessOutputs dataclass Signed-off-by: Cameron Smith --- src/pyrovelocity/workflows/main_workflow.py | 40 +++++++++++---------- 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/src/pyrovelocity/workflows/main_workflow.py b/src/pyrovelocity/workflows/main_workflow.py index ca80dbace..27f9feb9e 100644 --- a/src/pyrovelocity/workflows/main_workflow.py +++ b/src/pyrovelocity/workflows/main_workflow.py @@ -26,6 +26,7 @@ PYROVELOCITY_DATA_SUBSET, PostprocessConfiguration, PostprocessOutputs, + PreprocessOutputs, ResourcesJSON, SummarizeOutputs, TrainingOutputs, @@ -78,7 +79,7 @@ def download_data(download_dataset_args: DownloadDatasetInterface) -> FlyteFile: """ dataset_path = download_dataset(**asdict(download_dataset_args)) print(f"\nFlyte download data path: {dataset_path}\n") - return FlyteFile(path=dataset_path) + return FlyteFile(path=str(dataset_path)) @task( @@ -93,19 +94,22 @@ def download_data(download_dataset_args: DownloadDatasetInterface) -> FlyteFile: ) def preprocess_data( data: FlyteFile, preprocess_data_args: PreprocessDataInterface -) -> FlyteFile: +) -> PreprocessOutputs: """ Download external data. """ data_path = data.download() print(f"\nFlyte preprocess input data path: {data_path}\n") preprocess_data_args.adata = str(data_path) - _, processed_dataset_path = preprocess_dataset( + _, processed_dataset_path, processed_reports_path = preprocess_dataset( **asdict(preprocess_data_args), ) print(f"\nFlyte preprocess output data path: {processed_dataset_path}\n") - return FlyteFile(path=processed_dataset_path) + return PreprocessOutputs( + processed_data=FlyteFile(path=str(processed_dataset_path)), + processed_reports=FlyteDirectory(path=str(processed_reports_path)), + ) @task( @@ -121,13 +125,13 @@ def preprocess_data( enable_deck=False, ) def train_model( - processed_data: FlyteFile, + preprocess_outputs: PreprocessOutputs, train_model_configuration: PyroVelocityTrainInterface, ) -> TrainingOutputs: """ Train model. """ - processed_data_path = processed_data.download() + processed_data_path = preprocess_outputs.processed_data.download() print(f"\nFlyte train model input data path: {processed_data_path}\n") train_model_configuration.adata = str(processed_data_path) ( @@ -158,12 +162,12 @@ def train_model( return TrainingOutputs( data_model=data_model, data_model_path=data_model_path, - trained_data_path=FlyteFile(path=trained_data_path), - model_path=FlyteDirectory(path=model_path), - posterior_samples_path=FlyteFile(path=posterior_samples_path), - metrics_path=FlyteFile(path=metrics_path), - run_info_path=FlyteFile(path=run_info_path), - loss_plot_path=FlyteFile(path=loss_plot_path), + trained_data_path=FlyteFile(path=str(trained_data_path)), + model_path=FlyteDirectory(path=str(model_path)), + posterior_samples_path=FlyteFile(path=str(posterior_samples_path)), + metrics_path=FlyteFile(path=str(metrics_path)), + run_info_path=FlyteFile(path=str(run_info_path)), + loss_plot_path=FlyteFile(path=str(loss_plot_path)), ) @@ -203,8 +207,8 @@ def postprocess_data( ) return PostprocessOutputs( - pyrovelocity_data=FlyteFile(path=pyrovelocity_data_path), - postprocessed_data=FlyteFile(path=postprocessed_data_path), + pyrovelocity_data=FlyteFile(path=str(pyrovelocity_data_path)), + postprocessed_data=FlyteFile(path=str(postprocessed_data_path)), ) @@ -249,8 +253,8 @@ def summarize_data( f"\ndataframe_path: {dataframe_path}\n\n", ) return SummarizeOutputs( - data_model_reports=FlyteDirectory(path=data_model_reports_path), - dataframe=FlyteFile(path=dataframe_path), + data_model_reports=FlyteDirectory(path=str(data_model_reports_path)), + dataframe=FlyteFile(path=str(dataframe_path)), ) @@ -347,7 +351,7 @@ def map_model_configurations_over_data_set( list[FlyteFile]: Workflow outputs as flytekit.types.file.FlyteFile objects. """ data = download_data(download_dataset_args=download_dataset_args) - processed_data = preprocess_data( + processed_outputs = preprocess_data( data=data, preprocess_data_args=preprocess_data_args ) @@ -360,7 +364,7 @@ def map_model_configurations_over_data_set( dataset_summaries: list[SummarizeOutputs] = [] for train_model_configuration in train_model_configurations: model_output = train_model( - processed_data=processed_data, + preprocess_outputs=processed_outputs, train_model_configuration=train_model_configuration, ).with_overrides( requests=Resources(**asdict(train_model_resource_requests)),