Adding Datamodule #66

Open · wants to merge 7 commits into base: main
2 changes: 1 addition & 1 deletion README.md
@@ -353,7 +353,7 @@ options:
A simple example of inference is shown below:

```bash
runai exec_gpu python bin/inference.py --model_path /scratch/shared/py4cast/logs/camp0/poesy/halfunet/sezn_run_dev_30 --date 2021061621 --dataset poesy_infer --infer_steps 2 --grib
runai exec_gpu python bin/inference.py --model_path /scratch/shared/py4cast/logs/camp0/poesy/halfunet/sezn_run_dev_12 --date 2021061621 --dataset poesy_infer --infer_steps 2
```

### Making animated plots comparing multiple models
27 changes: 13 additions & 14 deletions bin/inference.py
@@ -3,10 +3,9 @@

from pytorch_lightning import Trainer

from py4cast.datasets import get_datasets
from py4cast.datasets.base import TorchDataloaderSettings
from py4cast.io.outputs import GribSavingSettings, save_named_tensors_to_grib
from py4cast.lightning import AutoRegressiveLightning
from py4cast.lightning import AutoRegressiveLightning, PlDataModule

default_config_root = Path(__file__).parents[1] / "config/IO/"

@@ -67,22 +66,22 @@
else:
config_override = {"num_inference_pred_steps": args.infer_steps}

# Get dataset for inference
_, _, infer_ds = get_datasets(
args.dataset,
hparams.num_input_steps,
hparams.num_pred_steps_train,
hparams.num_pred_steps_val_test,
args.dataset_conf,
dl_settings = TorchDataloaderSettings(batch_size=hparams.batch_size)

dm = PlDataModule(
dataset=args.dataset,
num_input_steps=hparams.num_input_steps,
num_pred_steps_train=hparams.num_pred_steps_train,
num_pred_steps_val_test=hparams.num_pred_steps_val_test,
dl_settings=dl_settings,
dataset_conf=args.dataset_conf,
config_override=config_override,
)

# Transform into dataloader
dl_settings = TorchDataloaderSettings(batch_size=1)
infer_loader = infer_ds.torch_dataloader(dl_settings)

trainer = Trainer(devices="auto")
preds = trainer.predict(lightning_module, infer_loader)
preds = trainer.predict(lightning_module, dm)

infer_ds = dm.get_infer_ds

if args.grib:
with open(default_config_root / args.saving_conf, "r") as f:
49 changes: 26 additions & 23 deletions bin/train.py
@@ -18,10 +18,13 @@
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

from py4cast.datasets import get_datasets
from py4cast.datasets import registry as dataset_registry
from py4cast.datasets.base import TorchDataloaderSettings
from py4cast.lightning import ArLightningHyperParam, AutoRegressiveLightning
from py4cast.lightning import (
ArLightningHyperParam,
AutoRegressiveLightning,
PlDataModule,
)
from py4cast.models import registry as model_registry
from py4cast.settings import ROOTDIR

@@ -31,7 +34,6 @@
},
}


# Variables for multi-nodes multi-gpu training
nb_nodes = int(os.environ.get("SLURM_NNODES", 1))
if nb_nodes > 1:
@@ -223,36 +225,38 @@
run_id = date.strftime("%b-%d-%Y-%M-%S")
seed.seed_everything(args.seed)


# Init datasets and dataloaders
datasets = get_datasets(
args.dataset,
args.num_input_steps,
args.num_pred_steps_train,
args.num_pred_steps_val_test,
args.dataset_conf,
)
dl_settings = TorchDataloaderSettings(
batch_size=args.batch_size,
num_workers=args.num_workers,
prefetch_factor=args.prefetch_factor,
pin_memory=args.pin_memory,
)
train_ds, val_ds, test_ds = datasets
train_loader = train_ds.torch_dataloader(dl_settings)
val_loader = val_ds.torch_dataloader(dl_settings)
test_loader = test_ds.torch_dataloader(dl_settings)

# Wrap dataset with lightning datamodule
dm = PlDataModule(
dataset=args.dataset,
num_input_steps=args.num_input_steps,
num_pred_steps_train=args.num_pred_steps_train,
num_pred_steps_val_test=args.num_pred_steps_val_test,
dl_settings=dl_settings,
dataset_conf=args.dataset_conf,
config_override=None,
)

# Get essential info to instantiate ArLightningHyperParam
len_loader = dm.get_len_train_dl
dataset_info = dm.get_train_dataset_info

# Setup GPU usage + get len of loader for LR scheduler
if torch.cuda.is_available():
device_name = "cuda"
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("high") # Allows using Tensor Cores on A100s
len_loader = len(train_loader) // (torch.cuda.device_count() * nb_nodes)
len_loader = len_loader // (torch.cuda.device_count() * nb_nodes)
else:
device_name = "cpu"
len_loader = len(train_loader)
len_loader = len_loader

# Get Log folders
log_dir = ROOTDIR / "logs"
@@ -266,9 +270,8 @@
subfolder = f"{run_name}_{version}"
save_path = log_dir / folder / subfolder


hp = ArLightningHyperParam(
dataset_info=datasets[0].dataset_info,
dataset_info=dataset_info,
dataset_name=args.dataset,
dataset_conf=args.dataset_conf,
batch_size=args.batch_size,
@@ -358,6 +361,7 @@
limit_val_batches=args.limit_train_batches, # No reason to spend hours on validation if we limit the training.
limit_test_batches=args.limit_train_batches,
)

if args.load_model_ckpt:
lightning_module = AutoRegressiveLightning.load_from_checkpoint(
args.load_model_ckpt, hparams=hp
@@ -370,8 +374,7 @@
print("Starting training !")
trainer.fit(
model=lightning_module,
train_dataloaders=train_loader,
val_dataloaders=val_loader,
datamodule=dm,
)

if not args.no_log:
@@ -383,4 +386,4 @@
print(
f"Testing using {'best' if best_checkpoint else 'last'} model at {model_to_test}"
)
trainer.test(ckpt_path=model_to_test, dataloaders=test_loader)
trainer.test(ckpt_path=model_to_test, datamodule=dm)
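
For context on the `len_loader` reduction above: the per-rank batch count is what typically sizes the LR schedule. Below is a minimal sketch under assumed numbers; the variable values, warmup fraction, and epoch count are illustrative and not part of this PR.

```python
# Illustrative only: sizing a cosine schedule from the per-rank number of
# batches, mirroring the len_loader reduction in bin/train.py above.
import torch
from transformers import get_cosine_schedule_with_warmup

total_batches = 800          # len(train dataloader) before sharding (assumed)
num_gpus, nb_nodes = 4, 2    # devices per node x number of nodes (assumed)
epochs = 10                  # assumed

# Same reduction as in train.py: each rank only steps once per global batch.
len_loader = total_batches // (num_gpus * nb_nodes)  # 100 optimizer steps per epoch

optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * len_loader * epochs),  # 10% warmup, assumed
    num_training_steps=len_loader * epochs,
)
```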
16 changes: 15 additions & 1 deletion config/datasets/poesy_infer.json
@@ -32,7 +32,21 @@
"arome": {
"members": [
0,
1
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15
],
"term": {
"start": 0,
2 changes: 1 addition & 1 deletion py4cast/datasets/dummy.py
@@ -156,7 +156,7 @@ def torch_dataloader(
) -> DataLoader:
return DataLoader(
self,
tl_settings.batch_size,
batch_size=tl_settings.batch_size,
num_workers=tl_settings.num_workers,
shuffle=self.shuffle,
prefetch_factor=tl_settings.prefetch_factor,
10 changes: 7 additions & 3 deletions py4cast/datasets/poesy.py
@@ -224,7 +224,6 @@ def load_data(self, date: dt.datetime, term: List, member: int) -> np.array:
term : Position of leadtimes in file.
"""
data_array = np.load(self.filename(date=date), mmap_mode="r")

return data_array[
self.grid.subgrid[0] : self.grid.subgrid[1],
self.grid.subgrid[2] : self.grid.subgrid[3],
@@ -579,8 +578,14 @@ def from_json(
num_pred_steps_val_test: int,
config_override: Union[Dict, None] = None,
) -> Tuple["PoesyDataset", "PoesyDataset", "PoesyDataset"]:
"""
Return 3 PoesyDataset.
Override configuration file if needed.
"""
with open(fname, "r") as fp:
conf = json.load(fp)
if config_override is not None:
conf = merge_dicts(conf, config_override)

grid = Grid(**conf["grid"])
param_list = []
@@ -645,7 +650,7 @@ def torch_dataloader(
) -> DataLoader:
return DataLoader(
self,
tl_settings.batch_size,
batch_size=tl_settings.batch_size,
num_workers=tl_settings.num_workers,
shuffle=self.shuffle,
prefetch_factor=tl_settings.prefetch_factor,
@@ -852,7 +857,6 @@ def sample_list(self):
samples.append(samp)
number += 1
print("All samples are now defined")
print(samples)

return samples

9 changes: 8 additions & 1 deletion py4cast/datasets/smeagol.py
@@ -24,6 +24,7 @@
)
from py4cast.forcingutils import generate_toa_radiation_forcing, get_year_hour_forcing
from py4cast.plots import DomainInfo
from py4cast.utils import merge_dicts

# torch.set_num_threads(8)
SCRATCH_PATH = Path(os.environ.get("PY4CAST_SMEAGOL_PATH", "/scratch/shared/smeagol"))
@@ -597,8 +598,14 @@ def from_json(
num_pred_steps_val_test: int,
config_override: Union[Dict, None] = None,
) -> Tuple["SmeagolDataset", "SmeagolDataset", "SmeagolDataset"]:
"""
Return 3 SmeagolDataset.
Override configuration file if needed.
"""
with open(fname, "r") as fp:
conf = json.load(fp)
if config_override is not None:
conf = merge_dicts(conf, config_override)

grid = Grid(**conf["grid"])
param_list = []
@@ -678,7 +685,7 @@ def torch_dataloader(
) -> DataLoader:
return DataLoader(
self,
tl_settings.batch_size,
batch_size=tl_settings.batch_size,
num_workers=tl_settings.num_workers,
shuffle=self.shuffle,
prefetch_factor=tl_settings.prefetch_factor,
2 changes: 1 addition & 1 deletion py4cast/datasets/titan/__init__.py
@@ -711,7 +711,7 @@ def torch_dataloader(
) -> DataLoader:
return DataLoader(
self,
tl_settings.batch_size,
batch_size=tl_settings.batch_size,
num_workers=tl_settings.num_workers,
shuffle=self.shuffle,
prefetch_factor=tl_settings.prefetch_factor,
62 changes: 60 additions & 2 deletions py4cast/lightning.py
@@ -5,7 +5,7 @@
from dataclasses import asdict, dataclass
from functools import cached_property
from pathlib import Path
from typing import List, Tuple, Union
from typing import Dict, List, Tuple, Union

import einops
import matplotlib
@@ -16,7 +16,13 @@
from torchinfo import summary
from transformers import get_cosine_schedule_with_warmup

from py4cast.datasets.base import DatasetInfo, ItemBatch, NamedTensor
from py4cast.datasets import get_datasets
from py4cast.datasets.base import (
DatasetInfo,
ItemBatch,
NamedTensor,
TorchDataloaderSettings,
)
from py4cast.losses import ScaledLoss, WeightedLoss
from py4cast.metrics import MetricACC, MetricPSDK, MetricPSDVar
from py4cast.models import build_model_from_settings, get_model_kls_and_settings
@@ -36,6 +42,58 @@
PLOT_PERIOD: int = 10


@dataclass
class PlDataModule(pl.LightningDataModule):
"""
DataModule to encapsulate data splits and data loading.
"""

dataset: str
num_input_steps: int
num_pred_steps_train: int
num_pred_steps_val_test: int
dl_settings: TorchDataloaderSettings
dataset_conf: Union[Path, None] = None
config_override: Union[Dict, None] = None

def __post_init__(self):
super().__init__()

# Get dataset in initialisation to have access to this attribute before method trainer.fit
self.train_ds, self.val_ds, self.test_ds = get_datasets(
self.dataset,
self.num_input_steps,
self.num_pred_steps_train,
self.num_pred_steps_val_test,
self.dataset_conf,
self.config_override,
)

@property
def get_len_train_dl(self):
return len(self.train_ds.torch_dataloader(self.dl_settings))

@property
def get_train_dataset_info(self):
return self.train_ds.dataset_info

@property
def get_infer_ds(self):
return self.test_ds
Comment on lines +72 to +82
Contributor
You should remove the get prefix; it is not necessary for properties and very "Java-like".
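
For illustration only (not part of this diff), the rename suggested above could look like the sketch below; the surrounding class is a stand-in for `PlDataModule`.

```python
# Hypothetical sketch of the reviewer's suggestion: plain property names
# without the "get_" prefix.
class ExampleDataModule:
    def __init__(self, train_ds, test_ds, dl_settings):
        self.train_ds = train_ds
        self.test_ds = test_ds
        self.dl_settings = dl_settings

    @property
    def len_train_dl(self) -> int:   # was get_len_train_dl
        return len(self.train_ds.torch_dataloader(self.dl_settings))

    @property
    def train_dataset_info(self):    # was get_train_dataset_info
        return self.train_ds.dataset_info

    @property
    def infer_ds(self):              # was get_infer_ds
        return self.test_ds
```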


def train_dataloader(self):
return self.train_ds.torch_dataloader(self.dl_settings)

def val_dataloader(self):
return self.val_ds.torch_dataloader(self.dl_settings)

def test_dataloader(self):
return self.test_ds.torch_dataloader(self.dl_settings)

def predict_dataloader(self):
return self.test_ds.torch_dataloader(self.dl_settings)


@dataclass
class ArLightningHyperParam:
"""
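
For reference, a minimal usage sketch of the new `PlDataModule`, mirroring the calls in `bin/train.py` and `bin/inference.py` above. The dataset name, dataloader settings, and config path are assumptions made for illustration, and the Lightning module is left as a placeholder since its construction is unchanged by this PR.

```python
# Sketch only: drive training, testing and prediction through the new datamodule.
from pathlib import Path

from pytorch_lightning import Trainer

from py4cast.datasets.base import TorchDataloaderSettings
from py4cast.lightning import PlDataModule

dm = PlDataModule(
    dataset="poesy",                                        # assumed dataset name
    num_input_steps=2,
    num_pred_steps_train=1,
    num_pred_steps_val_test=2,
    dl_settings=TorchDataloaderSettings(batch_size=4, num_workers=2),
    dataset_conf=Path("config/datasets/poesy_infer.json"),  # assumed config path
    config_override=None,
)

# lightning_module: an AutoRegressiveLightning instance, built exactly as in
# bin/train.py (placeholder below).
lightning_module = ...

trainer = Trainer(devices="auto", max_epochs=1)
trainer.fit(model=lightning_module, datamodule=dm)           # train/val dataloaders
trainer.test(model=lightning_module, datamodule=dm)          # test_dataloader
preds = trainer.predict(lightning_module, datamodule=dm)     # predict_dataloader
```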