From a64fd32ebe9353ccb5cb2fa12d16f6a49de969b3 Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Wed, 5 Jul 2023 10:36:36 +0200 Subject: [PATCH] Adding validation_step and test_step to Model (#1181) --- .gitignore | 3 +++ merlin/models/torch/models/base.py | 19 +++++++++++++++++++ tests/unit/torch/models/test_base.py | 10 ++++++++-- 3 files changed, 30 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index e5f5cb5bcc..ed7bef7e91 100644 --- a/.gitignore +++ b/.gitignore @@ -105,3 +105,6 @@ dmypy.json # Experiment files _test.py + +# Lightning +**/lightning_logs/ \ No newline at end of file diff --git a/merlin/models/torch/models/base.py b/merlin/models/torch/models/base.py index df1826746c..9bf7271dfc 100644 --- a/merlin/models/torch/models/base.py +++ b/merlin/models/torch/models/base.py @@ -102,6 +102,25 @@ def training_step(self, batch, batch_idx): return loss_and_metrics["loss"] + def validation_step(self, batch, batch_idx): + return self._val_step(batch, batch_idx, type="val") + + def test_step(self, batch, batch_idx): + return self._val_step(batch, batch_idx, type="test") + + def _val_step(self, batch, batch_idx, type="val"): + del batch_idx + if not isinstance(batch, Batch): + batch = Batch(features=batch[0], targets=batch[1]) + + predictions = self(batch.features, batch=batch) + + loss_and_metrics = compute_loss(predictions, batch.targets, self.model_outputs()) + for name, value in loss_and_metrics.items(): + self.log(f"{type}_{name}", value) + + return loss_and_metrics + def configure_optimizers(self): """Configures the optimizer for the model.""" return self.optimizer(self.parameters()) diff --git a/tests/unit/torch/models/test_base.py b/tests/unit/torch/models/test_base.py index 9abfe2eef7..2ee931989d 100644 --- a/tests/unit/torch/models/test_base.py +++ b/tests/unit/torch/models/test_base.py @@ -128,7 +128,7 @@ def test_training_step_values(self): expected_loss = nn.BCELoss()(expected_outputs, targets["target"]) assert torch.allclose(loss, expected_loss) - def test_training_step_with_dataloader(self): + def test_step_with_dataloader(self): model = mm.Model( mm.Concat(), mm.BinaryOutput(ColumnSchema("target")), @@ -144,8 +144,11 @@ def test_training_step_with_dataloader(self): loss = model.training_step(batch, 0) assert loss > 0.0 + assert torch.equal( + model.validation_step(batch, 0)["loss"], model.test_step(batch, 0)["loss"] + ) - def test_training_step_with_batch(self): + def test_step_with_batch(self): model = mm.Model( mm.Concat(), mm.BinaryOutput(ColumnSchema("target")), @@ -156,6 +159,9 @@ def test_training_step_with_batch(self): model.initialize(batch) loss = model.training_step(batch, 0) assert loss > 0.0 + assert torch.equal( + model.validation_step(batch, 0)["loss"], model.test_step(batch, 0)["loss"] + ) def test_training_step_missing_output(self): model = mm.Model(mm.Block())