Add autoresume functionality #169

Merged: 5 commits, Oct 21, 2021
3 changes: 2 additions & 1 deletion bash_files/pretrain/cifar/barlow.sh
@@ -31,4 +31,5 @@ python3 ../../../main_pretrain.py \
     --method barlow_twins \
     --proj_hidden_dim 2048 \
     --proj_output_dim 2048 \
-    --scale_loss 0.1
+    --scale_loss 0.1 \
+    --auto_resume
14 changes: 14 additions & 0 deletions main_pretrain.py
@@ -27,6 +27,7 @@

 from solo.args.setup import parse_args_pretrain
 from solo.methods import METHODS
+from solo.utils.auto_resumer import AutoResumer

 try:
     from solo.methods.dali import PretrainABC
@@ -150,6 +151,19 @@ def main():
         )
         callbacks.append(auto_umap)

+    if args.auto_resume and args.resume_from_checkpoint is None:
+        auto_resumer = AutoResumer(
+            checkpoint_dir=os.path.join(args.checkpoint_dir, args.method),
+            max_hours=args.auto_resumer_max_hours,
+        )
+        resume_from_checkpoint = auto_resumer.find_checkpoint(args)
+        if resume_from_checkpoint is not None:
+            print(
+                "Resuming from previous checkpoint that matches specifications:",
+                f"'{resume_from_checkpoint}'",
+            )
+            args.resume_from_checkpoint = resume_from_checkpoint
+
     trainer = Trainer.from_argparse_args(
         args,
         logger=wandb_logger if args.wandb else None,
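Note that an explicitly supplied --resume_from_checkpoint always takes precedence: the block above only runs when that flag is unset. For context, the discovered path is then handed to PyTorch Lightning through its standard resume mechanism. A minimal sketch of that hand-off, assuming the PL 1.x Trainer API this PR targets (the checkpoint path below is hypothetical, not a real artifact of the repo):

from pytorch_lightning import Trainer

# Sketch only: args.resume_from_checkpoint feeds Lightning's resume kwarg,
# which restores model weights, optimizer state, and the epoch counter.
trainer = Trainer(resume_from_checkpoint="trained_models/barlow_twins/run_id/last.ckpt")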
5 changes: 5 additions & 0 deletions solo/args/setup.py
@@ -28,6 +28,7 @@
 )
 from solo.args.utils import additional_setup_linear, additional_setup_pretrain
 from solo.methods import METHODS
+from solo.utils.auto_resumer import AutoResumer
 from solo.utils.checkpointer import Checkpointer

 try:
@@ -71,6 +72,7 @@ def parse_args_pretrain() -> argparse.Namespace:
     # add auto checkpoint/umap args
     parser.add_argument("--save_checkpoint", action="store_true")
     parser.add_argument("--auto_umap", action="store_true")
+    parser.add_argument("--auto_resume", action="store_true")
    temp_args, _ = parser.parse_known_args()

     # optionally add checkpointer and AutoUMAP args
@@ -80,6 +82,9 @@
     if _umap_available and temp_args.auto_umap:
         parser = AutoUMAP.add_auto_umap_args(parser)

+    if temp_args.auto_resume:
+        parser = AutoResumer.add_autoresumer_args(parser)
+
     # parse args
     args = parser.parse_args()

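The file relies on a two-pass parsing pattern: a first parse_known_args call reads only the toggle flags, and the resumer's own options are registered before the final parse. A minimal, self-contained sketch of the same pattern (names mirror the diff, but this is an illustration, not repo code):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--auto_resume", action="store_true")

# first pass: read only the toggle, ignoring all other arguments
temp_args, _ = parser.parse_known_args()

if temp_args.auto_resume:
    # the second, final pass will now also accept the resumer-specific option
    group = parser.add_argument_group("autoresumer")
    group.add_argument("--auto_resumer_max_hours", default=24, type=int)

args = parser.parse_args()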
95 changes: 95 additions & 0 deletions solo/utils/auto_resumer.py
@@ -0,0 +1,95 @@
import json
import os
from argparse import ArgumentParser, Namespace
from collections import namedtuple
from datetime import datetime, timedelta
from pathlib import Path
from typing import Union

Checkpoint = namedtuple("Checkpoint", ["creation_time", "args", "checkpoint"])


class AutoResumer:
    SHOULD_MATCH = [
        "batch_size",
        "dataset",
        "encoder",
        "max_epochs",
        "method",
        "name",
        "project",
        "entity",
    ]

    def __init__(
        self,
        checkpoint_dir: Union[str, Path] = Path("trained_models"),
        max_hours: int = 30,
    ):
        """Autoresumer object that automatically tries to find a checkpoint
        that is at most max_hours old.

        Args:
            checkpoint_dir (Union[str, Path], optional): base directory to store checkpoints.
                Defaults to "trained_models".
            max_hours (int): maximum elapsed hours for a checkpoint to be considered valid.
        """

        self.checkpoint_dir = checkpoint_dir
        self.max_hours = timedelta(hours=max_hours)

    @staticmethod
    def add_autoresumer_args(parent_parser: ArgumentParser):
        """Adds user-required arguments to a parser.

        Args:
            parent_parser (ArgumentParser): parser to add new args to.
        """

        parser = parent_parser.add_argument_group("autoresumer")
        parser.add_argument("--auto_resumer_max_hours", default=24, type=int)
        return parent_parser

    def find_checkpoint(self, args: Namespace):
        """Finds a valid checkpoint that matches the arguments.

        Args:
            args (Namespace): namespace object containing all settings of the model.

        Returns:
            The path of the newest matching checkpoint, or None if there is no match.
        """

        current_time = datetime.now()

        possible_checkpoints = []
        for rootdir, _, files in os.walk(self.checkpoint_dir):
            rootdir = Path(rootdir)
            if files:
                # skip directories that contain no checkpoint files
                try:
                    checkpoint_file = [rootdir / f for f in files if f.endswith(".ckpt")][0]
                except IndexError:
                    continue

                creation_time = datetime.fromtimestamp(os.path.getctime(checkpoint_file))
                if current_time - creation_time < self.max_hours:
                    ck = Checkpoint(
                        creation_time=creation_time,
                        args=rootdir / "args.json",
                        checkpoint=checkpoint_file,
                    )
                    possible_checkpoints.append(ck)

        if possible_checkpoints:
            # sort by most recent
            possible_checkpoints = sorted(
                possible_checkpoints, key=lambda ck: ck.creation_time, reverse=True
            )

            for checkpoint in possible_checkpoints:
                with open(checkpoint.args) as f:
                    checkpoint_args = Namespace(**json.load(f))
                if all(
                    getattr(checkpoint_args, param) == getattr(args, param)
                    for param in AutoResumer.SHOULD_MATCH
                ):
                    return checkpoint.checkpoint

        return None
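Standalone, the class can be exercised as below. A usage sketch, assuming a previous run already wrote a .ckpt file plus an args.json next to it under the checkpoint directory; the Namespace values here are made up and must match what was stored:

from argparse import Namespace

from solo.utils.auto_resumer import AutoResumer

# Values are illustrative; every key in AutoResumer.SHOULD_MATCH must equal
# the value stored in args.json, or no checkpoint is returned.
args = Namespace(
    batch_size=256,
    dataset="cifar10",
    encoder="resnet18",
    max_epochs=1000,
    method="barlow_twins",
    name="barlow-cifar10",
    project="solo-learn",
    entity="my-entity",
)

resumer = AutoResumer(checkpoint_dir="trained_models/barlow_twins", max_hours=24)
ckpt = resumer.find_checkpoint(args)  # newest matching .ckpt path, or None
if ckpt is not None:
    print(f"Would resume from '{ckpt}'")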
1 change: 1 addition & 0 deletions tests/methods/utils.py
@@ -77,6 +77,7 @@ def gen_base_kwargs(
         "data_dir": "/data/datasets",
         "train_dir": "cifar10/train",
         "val_dir": "cifar10/val",
+        "dataset": "cifar10",
     }
     if momentum:
         BASE_KWARGS["base_tau_momentum"] = 0.99
89 changes: 89 additions & 0 deletions tests/utils/test_auto_resumer.py
@@ -0,0 +1,89 @@
# Copyright 2021 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

import argparse
import json
import shutil

from pytorch_lightning import Trainer
from solo.methods import BarlowTwins
from solo.utils.auto_resumer import AutoResumer
from solo.utils.checkpointer import Checkpointer

from ..methods.utils import DATA_KWARGS, gen_base_kwargs, prepare_dummy_dataloaders


def test_auto_resumer():
    method_kwargs = {
        "name": "barlow_twins",
        "method": "barlow_twins",
        "proj_hidden_dim": 2048,
        "proj_output_dim": 2048,
        "lamb": 5e-3,
        "scale_loss": 0.025,
    }

    # normal training
    BASE_KWARGS = gen_base_kwargs(cifar=False, batch_size=2)
    kwargs = {**BASE_KWARGS, **DATA_KWARGS, **method_kwargs, "project": "test", "entity": "test"}
    model = BarlowTwins(**kwargs, disable_knn_eval=True)

    args = argparse.Namespace(**kwargs)

    # checkpointer
    ckpt_callback = Checkpointer(args)

    trainer = Trainer.from_argparse_args(
        args,
        checkpoint_callback=False,
        limit_train_batches=2,
        limit_val_batches=2,
        callbacks=[ckpt_callback],
    )

    train_dl, val_dl = prepare_dummy_dataloaders(
        "imagenet100",
        num_large_crops=BASE_KWARGS["num_large_crops"],
        num_small_crops=0,
        num_classes=BASE_KWARGS["num_classes"],
        multicrop=False,
        batch_size=BASE_KWARGS["batch_size"],
    )

    trainer.fit(model, train_dl, val_dl)

    # check if checkpointer dumped the args
    args_path = ckpt_callback.path / "args.json"
    assert args_path.exists()

    # check if the args are correct
    with open(args_path) as f:
        loaded_args = json.load(f)
    assert loaded_args == vars(args)

    auto_resumer = AutoResumer(ckpt_callback.logdir, max_hours=1)
    assert auto_resumer.find_checkpoint(args) is not None

    # check arguments
    parser = argparse.ArgumentParser()
    auto_resumer.add_autoresumer_args(parser)
    args = [vars(action)["dest"] for action in vars(parser)["_actions"]]
    assert "auto_resumer_max_hours" in args

    # clean stuff
    shutil.rmtree(ckpt_callback.logdir)
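To run just this test, something like the following should work (a sketch; it assumes pytest and the repo's test dependencies are installed and the repo root is the working directory):

import pytest

# Run only the auto-resumer test module; a zero exit status means it passed.
raise SystemExit(pytest.main(["tests/utils/test_auto_resumer.py", "-v"]))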