Skip to content

Commit

Permalink
fix(tasks): enable Path objects where applicable
Browse files Browse the repository at this point in the history
Signed-off-by: Cameron Smith <[email protected]>
  • Loading branch information
cameronraysmith committed Aug 2, 2024
1 parent fcbbede commit 4096175
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src/pyrovelocity/tasks/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@

@beartype
def train_dataset(
adata: str | AnnData,
adata: str | Path | AnnData,
data_set_name: str = "simulated",
model_identifier: str = "model2",
models_path: str | Path = "models",
guide_type: str = "auto",
model_type: str = "auto",
batch_size: int = -1,
Expand Down Expand Up @@ -73,6 +74,7 @@ def train_dataset(
Path to a file that can be read to an AnnData object or an AnnData object.
data_set_name (str, optional): Name of the dataset. Default is "simulated".
model_identifier (str, optional): Identifier for the model. Default is "model2".
models_path (str | Path, optional): Path to the models directory. Default is "models".
guide_type (str, optional):
The type of guide function for the Pyro model. Default is "auto".
model_type (str, optional): The type of Pyro model. Default is "auto".
Expand Down Expand Up @@ -114,7 +116,7 @@ def train_dataset(

# load data
data_model = f"{data_set_name}_{model_identifier}"
data_model_path = Path(f"models/{data_model}")
data_model_path = Path(f"{models_path}/{data_model}")

trained_data_path = data_model_path / "trained.h5ad"
model_path = data_model_path / "model"
Expand Down Expand Up @@ -274,7 +276,7 @@ def check_shared_time(posterior_samples, adata):

@beartype
def train_model(
adata: str | AnnData,
adata: str | Path | AnnData,
guide_type: str = "auto",
model_type: str = "auto",
batch_size: int = -1,
Expand Down Expand Up @@ -343,7 +345,7 @@ def train_model(
>>> copy_raw_counts(adata)
>>> _, model, posterior_samples = train_model(adata, use_gpu="auto", seed=99, max_epochs=200, loss_plot_path=loss_plot_path)
"""
if isinstance(adata, str):
if isinstance(adata, str | Path):
adata = load_anndata_from_path(adata)

logger.info(f"AnnData object prior to model training")
Expand Down

0 comments on commit 4096175

Please sign in to comment.