Skip to content

Commit

Permalink
Remove .nemo instead of renaming (#9281)
Browse files Browse the repository at this point in the history
* Remove .nemo instead of renaming

Signed-off-by: Mikołaj Błaż <[email protected]>

* add ignore_errors=True flag

Signed-off-by: dimapihtar <[email protected]>

* Revert "Remove .nemo instead of renaming"

This reverts commit b836410.

Signed-off-by: Mikołaj Błaż <[email protected]>

* Remove backup .nemo after success

Signed-off-by: Mikołaj Błaż <[email protected]>

* Update tests

Signed-off-by: Mikołaj Błaż <[email protected]>

* Backup .nemo imediately before save_to

Signed-off-by: Mikołaj Błaż <[email protected]>

* Apply isort and black reformatting

Signed-off-by: mikolajblaz <[email protected]>

* Fix CTC import

Signed-off-by: Mikołaj Błaż <[email protected]>

---------

Signed-off-by: Mikołaj Błaż <[email protected]>
Signed-off-by: dimapihtar <[email protected]>
Signed-off-by: mikolajblaz <[email protected]>
Co-authored-by: dimapihtar <[email protected]>
Signed-off-by: Jan Lasek <[email protected]>
  • Loading branch information
2 people authored and janekl committed Jun 12, 2024
1 parent a1ff752 commit bc0c010
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodConfig, ConfidenceMethodMixin
from nemo.core.classes import Typing, typecheck
from nemo.core.neural_types import HypothesisType, LengthsType, LogprobsType, NeuralType
from nemo.utils import logging
from nemo.utils import logging, logging_mode


def pack_hypotheses(
Expand Down
33 changes: 25 additions & 8 deletions nemo/utils/callbacks/nemo_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import pytorch_lightning
import torch
from _weakref import proxy

from lightning_fabric.utilities.cloud_io import get_filesystem
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint, _is_local_file_protocol
from pytorch_lightning.utilities import rank_zero_info

Expand Down Expand Up @@ -198,7 +200,6 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint):
if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1:
logging.warning(f'always_save_nemo will slow down training for model_parallel > 1.')
# since we are creating tarfile artifacts we need to update .nemo path
self._backup_existing_nemo_ckpt(trainer)
app_state.model_restore_path = self._format_nemo_checkpoint_name()
if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1:
maybe_injected_best_model_path = inject_model_parallel_rank(self.best_model_path)
Expand All @@ -222,14 +223,19 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint):
pl_module.load_state_dict(checkpoint, strict=True)
if torch.distributed.is_initialized():
torch.distributed.barrier()
backup_path = self._backup_existing_nemo_ckpt(trainer)
pl_module.save_to(save_path=app_state.model_restore_path)
logging.info(f"New best .nemo model saved to: {app_state.model_restore_path}")
pl_module.load_state_dict(old_state_dict, strict=True)
else:
if torch.distributed.is_initialized():
torch.distributed.barrier()
backup_path = self._backup_existing_nemo_ckpt(trainer)
pl_module.save_to(save_path=app_state.model_restore_path)
logging.info(f"New .nemo model saved to: {app_state.model_restore_path}")
if backup_path is not None and is_global_rank_zero():
logging.info(f'Removing old .nemo backup {backup_path}')
get_filesystem(backup_path).rm(backup_path)
return output

def on_train_end(self, trainer, pl_module):
Expand Down Expand Up @@ -268,16 +274,25 @@ def on_train_end(self, trainer, pl_module):
trainer._checkpoint_connector.restore(self.best_model_path)

if self.save_nemo_on_train_end:
self._backup_existing_nemo_ckpt(trainer)
backup_path = self._backup_existing_nemo_ckpt(trainer)
pl_module.save_to(save_path=self._format_nemo_checkpoint_name())
if backup_path is not None and is_global_rank_zero():
logging.info(f'Removing old .nemo backup {backup_path}')
get_filesystem(backup_path).rm(backup_path)

def _backup_existing_nemo_ckpt(self, trainer) -> str:
def _backup_existing_nemo_ckpt(self, trainer) -> Optional[str]:
"""Search for an available name with version infix and rename existing checkpoint.
NOTE: this behavior is slightly different from regular checkpoints.
PTL creates new regular checkpoint with the first available name.
Here, for backward compatibility, we create .nemo checkpoint as before
and create a backup under the first available name.
Args:
trainer (Trainer): trainer instance.
Returns:
Path to the backup checkpoint or None, if no backup was created
"""
base_path = self._format_nemo_checkpoint_name()
available_path = base_path
Expand All @@ -286,11 +301,13 @@ def _backup_existing_nemo_ckpt(self, trainer) -> str:
while self.file_exists(available_path, trainer, check_dist_ckpt=False):
available_path = self._format_nemo_checkpoint_name(version_cnt)
version_cnt += 1
if available_path != base_path:
if trainer.is_global_zero:
logging.info(f'{base_path} already exists, moving existing checkpoint to {available_path}')
shutil.move(base_path, available_path)
trainer.strategy.barrier()
if available_path == base_path:
# no existing ckpt, no need to backup
return None
if trainer.is_global_zero:
logging.info(f'{base_path} already exists, moving existing checkpoint to {available_path}')
shutil.move(base_path, available_path)
trainer.strategy.barrier()
return available_path

def _format_nemo_checkpoint_name(self, ver: Optional[int] = None) -> str:
Expand Down
65 changes: 50 additions & 15 deletions tests/core/test_exp_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def test_omegaconf(self):

@pytest.mark.unit
def test_trainer_loggers(self, tmp_path):
""" Test that a trainer with logger errors out with a number of arguments. Test that it works with
"""Test that a trainer with logger errors out with a number of arguments. Test that it works with
create_tensorboard_logger set to False
"""
test_trainer = pl.Trainer(accelerator='cpu') # Should create logger and modelcheckpoint
Expand Down Expand Up @@ -235,7 +235,7 @@ def test_trainer_neptune_logger(self, tmp_path):

@pytest.mark.unit
def test_checkpoint_configurations(self):
""" Test that trainer creating modelcheckpoint and asking exp_manager to do it too results in errors, but
"""Test that trainer creating modelcheckpoint and asking exp_manager to do it too results in errors, but
is error free if only one is asked to do so.
"""
disable_tb_logger = {"create_tensorboard_logger": False}
Expand Down Expand Up @@ -297,7 +297,7 @@ def test_log_dir_overrides(self, monkeypatch, tmp_path):

@pytest.mark.unit
def test_resume(self, tmp_path):
""" Tests the resume capabilities of exp_manager"""
"""Tests the resume capabilities of exp_manager"""
test_trainer = pl.Trainer(accelerator='cpu', enable_checkpointing=False, logger=False)

# Error because explicit_log_dir does not exist
Expand Down Expand Up @@ -428,7 +428,8 @@ def test_nemo_checkpoint_save_best_model_1(self, tmp_path):
def test_nemo_checkpoint_save_best_model_2(self, tmp_path):
test_trainer = pl.Trainer(accelerator='cpu', enable_checkpointing=False, logger=False, max_epochs=4)
exp_manager(
test_trainer, {"explicit_log_dir": str(tmp_path / "test")},
test_trainer,
{"explicit_log_dir": str(tmp_path / "test")},
)
model = ExampleModel()
test_trainer.fit(model)
Expand Down Expand Up @@ -456,6 +457,27 @@ def test_nemo_checkpoint_always_save_nemo(self, tmp_path):
model = ExampleModel.restore_from(str(tmp_path / "test" / "checkpoints" / "default.nemo"))
assert float(model(torch.tensor([1.0, 1.0], device=model.device))) == 0.0

@pytest.mark.unit
def test_nemo_checkpoint_doesnt_produce_too_many_nemo_ckpts(self, tmp_path):
test_trainer = pl.Trainer(accelerator='cpu', enable_checkpointing=False, logger=False, max_epochs=4)
exp_manager(
test_trainer,
{
"checkpoint_callback_params": {"save_best_model": True, "always_save_nemo": True, "save_top_k": 2},
"explicit_log_dir": str(tmp_path / "test"),
},
)
model = ExampleModel()
test_trainer.fit(model)

assert Path(str(tmp_path / "test" / "checkpoints" / "default.nemo")).exists()
assert (
len(list((tmp_path / "test" / "checkpoints").glob("default*.nemo"))) == 1
) # check number of `.nemo` checkpoints

model = ExampleModel.restore_from(str(tmp_path / "test" / "checkpoints" / "default.nemo"))
assert float(model(torch.tensor([1.0, 1.0], device=model.device))) == 0.0

@pytest.mark.unit
def test_nemo_checkpoint_make_checkpoint_dir(self, tmp_path):
test_trainer = pl.Trainer(
Expand Down Expand Up @@ -511,8 +533,8 @@ def test_nemo_checkpoint_restore_model(self, tmp_path):

@pytest.mark.run_only_on('GPU')
@pytest.mark.parametrize('test_dist_ckpt', [False, True])
def test_checkpoints_are_not_overwritten(self, tmp_path, test_dist_ckpt):
""" Simulates already existing checkpoints in the ckpt directory and tests ckpt versioning """
def test_base_checkpoints_are_not_overwritten(self, tmp_path, test_dist_ckpt):
"""Simulates already existing checkpoints in the ckpt directory and tests non-nemo ckpt versioning"""
strategy = NLPDDPStrategy() if test_dist_ckpt else 'auto'
test_trainer = pl.Trainer(
accelerator='cpu', enable_checkpointing=False, logger=False, max_epochs=4, strategy=strategy
Expand Down Expand Up @@ -563,7 +585,8 @@ def _get_versioned_name(ckpt_name: Path, nemo: bool = False):

assert _get_versioned_name(ckpt_1).exists(), all_checkpoints
assert not _get_versioned_name(ckpt_2).exists(), all_checkpoints # ckpt2 didn't exist before
assert _get_versioned_name(ckpt_nemo, nemo=True).exists(), all_checkpoints
# .nemo checkpoints are not versioned:
assert not _get_versioned_name(ckpt_nemo, nemo=True).exists(), all_checkpoints

@pytest.mark.unit
def test_last_checkpoint_saved(self, tmp_path):
Expand Down Expand Up @@ -592,6 +615,7 @@ def train_dataloader(self):
model_path = checkpoint_dir / "val_loss=0.0300-epoch=1-step=64-last.ckpt"
last_saved_checkpoint = torch.load(model_path)
assert max_steps == last_saved_checkpoint['global_step']

# restart training, ensure global step starts correctly
class AssertCallback(Callback):
def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
Expand Down Expand Up @@ -681,8 +705,7 @@ def test_warning_validation_skipping_when_custom_epoch_loop(self, tmp_path):
"""
tmp_path = tmp_path / "test_3"

class CustomLoop(_TrainingEpochLoop):
...
class CustomLoop(_TrainingEpochLoop): ...

trainer = pl.Trainer(
accelerator='cpu', enable_checkpointing=False, logger=False, max_epochs=1, val_check_interval=0.33
Expand Down Expand Up @@ -759,7 +782,8 @@ def test_skipped_unfinished_checkpoints_when_restoring(self, tmp_path):

restored_trainer = pl.Trainer(accelerator='cpu', enable_checkpointing=False, logger=False)
exp_manager(
restored_trainer, {"resume_if_exists": True, "explicit_log_dir": str(test_dir)},
restored_trainer,
{"resume_if_exists": True, "explicit_log_dir": str(test_dir)},
)

# Check that last complete (w/o unifinished marker) checkpoint was found
Expand Down Expand Up @@ -803,7 +827,8 @@ def test_skipped_unfinished_dist_checkpoints_when_restoring(self, tmp_path):

restored_trainer = pl.Trainer(accelerator='cpu', enable_checkpointing=False, logger=False)
exp_manager(
restored_trainer, {"resume_if_exists": True, "explicit_log_dir": str(test_dir)},
restored_trainer,
{"resume_if_exists": True, "explicit_log_dir": str(test_dir)},
)

# Check that last complete (w/o unifinished marker) checkpoint was found
Expand Down Expand Up @@ -850,21 +875,28 @@ def test_incomplete_checkpoints_cleanup(self, tmp_path):

# unfinished checkpoint with EMA part, both parts should be removed
self._write_fake_checkpoint(
checkpoints_dir / "incomplete01-EMA.ckpt", isdir=False, add_unfinished_marker=False,
checkpoints_dir / "incomplete01-EMA.ckpt",
isdir=False,
add_unfinished_marker=False,
)
self._write_fake_checkpoint(checkpoints_dir / "incomplete01.ckpt", isdir=False, add_unfinished_marker=True)

# just EMA part - should be removed. NOTE marker path is the same for base part and for EMA part
self._write_fake_checkpoint(
checkpoints_dir / "incomplete02-EMA.ckpt", isdir=False, add_unfinished_marker=False,
checkpoints_dir / "incomplete02-EMA.ckpt",
isdir=False,
add_unfinished_marker=False,
)
(checkpoints_dir / f"incomplete02{NeMoModelCheckpoint.UNFINISHED_CHECKPOINT_SUFFIX}").touch()

test_trainer = pl.Trainer(accelerator='cpu', enable_checkpointing=False, logger=False, max_epochs=1)

exp_manager(
test_trainer,
{"checkpoint_callback_params": {"save_top_k": 0, "save_last": False}, "explicit_log_dir": str(test_dir),},
{
"checkpoint_callback_params": {"save_top_k": 0, "save_last": False},
"explicit_log_dir": str(test_dir),
},
)

model = ExampleModel()
Expand Down Expand Up @@ -909,7 +941,10 @@ def test_incomplete_dist_checkpoints_cleanup(self, tmp_path):

exp_manager(
test_trainer,
{"checkpoint_callback_params": {"save_top_k": 0, "save_last": False}, "explicit_log_dir": str(test_dir),},
{
"checkpoint_callback_params": {"save_top_k": 0, "save_last": False},
"explicit_log_dir": str(test_dir),
},
)

model = ExampleModel()
Expand Down

0 comments on commit bc0c010

Please sign in to comment.