Skip to content

Commit

Permalink
Adding validation_step and test_step to Model (#1181)
Browse files Browse the repository at this point in the history
  • Loading branch information
marcromeyn committed Jul 5, 2023
1 parent e2930f8 commit a64fd32
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 2 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,6 @@ dmypy.json

# Experiment files
_test.py

# Lightning
**/lightning_logs/
19 changes: 19 additions & 0 deletions merlin/models/torch/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,25 @@ def training_step(self, batch, batch_idx):

return loss_and_metrics["loss"]

def validation_step(self, batch, batch_idx):
return self._val_step(batch, batch_idx, type="val")

def test_step(self, batch, batch_idx):
return self._val_step(batch, batch_idx, type="test")

def _val_step(self, batch, batch_idx, type="val"):
del batch_idx
if not isinstance(batch, Batch):
batch = Batch(features=batch[0], targets=batch[1])

predictions = self(batch.features, batch=batch)

loss_and_metrics = compute_loss(predictions, batch.targets, self.model_outputs())
for name, value in loss_and_metrics.items():
self.log(f"{type}_{name}", value)

return loss_and_metrics

def configure_optimizers(self):
"""Configures the optimizer for the model."""
return self.optimizer(self.parameters())
Expand Down
10 changes: 8 additions & 2 deletions tests/unit/torch/models/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def test_training_step_values(self):
expected_loss = nn.BCELoss()(expected_outputs, targets["target"])
assert torch.allclose(loss, expected_loss)

def test_training_step_with_dataloader(self):
def test_step_with_dataloader(self):
model = mm.Model(
mm.Concat(),
mm.BinaryOutput(ColumnSchema("target")),
Expand All @@ -144,8 +144,11 @@ def test_training_step_with_dataloader(self):

loss = model.training_step(batch, 0)
assert loss > 0.0
assert torch.equal(
model.validation_step(batch, 0)["loss"], model.test_step(batch, 0)["loss"]
)

def test_training_step_with_batch(self):
def test_step_with_batch(self):
model = mm.Model(
mm.Concat(),
mm.BinaryOutput(ColumnSchema("target")),
Expand All @@ -156,6 +159,9 @@ def test_training_step_with_batch(self):
model.initialize(batch)
loss = model.training_step(batch, 0)
assert loss > 0.0
assert torch.equal(
model.validation_step(batch, 0)["loss"], model.test_step(batch, 0)["loss"]
)

def test_training_step_missing_output(self):
model = mm.Model(mm.Block())
Expand Down

0 comments on commit a64fd32

Please sign in to comment.