-
Notifications
You must be signed in to change notification settings - Fork 0
/
pretrain.py
35 lines (29 loc) · 1.17 KB
/
pretrain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from .models import MaskedActionModeling
from .utils import PretrainRandomSequences
from .config import cfg
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
import numpy as np
import os
def main():
    """Pretrain a MaskedActionModeling model on random action sequences.

    Loads the train/val sequence arrays from the paths in ``cfg``, wires up
    a best-validation-loss checkpoint callback, and runs a PyTorch Lightning
    fit loop. All hyperparameters come from the shared ``cfg`` object.
    """
    # Release any cached GPU memory left over from a previous run in this
    # process (harmless no-op when CUDA has not been initialized).
    torch.cuda.empty_cache()

    # allow_pickle=True: presumably the .npy files store object arrays of
    # variable-length sequences — TODO confirm against the data pipeline.
    train_dataset = PretrainRandomSequences(
        np.load(cfg.data.train_path, allow_pickle=True),
        cfg.data.patch_size,
        cfg.pretrain.min_seq_length,
    )
    val_dataset = PretrainRandomSequences(
        np.load(cfg.data.val_path, allow_pickle=True),
        cfg.data.patch_size,
        cfg.pretrain.min_seq_length,
    )

    # Keep the checkpoint with the lowest validation loss.
    checkpoint = ModelCheckpoint(
        monitor='val_loss',
        dirpath=os.path.join(cfg.pretrain.ckpt_folder, cfg.pretrain.expt_folder),
        filename=cfg.pretrain.expt_name,
        mode='min',
    )

    # The LightningModule receives the datasets directly, so it is expected
    # to build its own dataloaders internally.
    model = MaskedActionModeling(train_dataset, val_dataset)

    trainer = pl.Trainer(
        accelerator=cfg.pretrain.hparams.accelerator,
        devices=cfg.pretrain.hparams.devices,
        max_epochs=cfg.pretrain.hparams.epochs,
        num_nodes=1,
        callbacks=[checkpoint],
        # Heavy gradient accumulation — emulates a large effective batch
        # size; logging cadence is matched to one optimizer step.
        accumulate_grad_batches=256,
        log_every_n_steps=256,
    )
    trainer.fit(model)


# NOTE(review): the relative imports at the top of this file mean the script
# must be launched as a module (`python -m <package>.pretrain`), not as a
# bare `python pretrain.py` — verify against the project's run instructions.
if __name__ == '__main__':
    main()