Skip to content

Commit

Permalink
fix failing test
Browse files Browse the repository at this point in the history
  • Loading branch information
edknv committed Jul 3, 2023
1 parent 8744595 commit 8e18051
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions tests/unit/torch/models/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import merlin.models.torch as mm
from merlin.dataloader.torch import Loader
from merlin.io import Dataset
from merlin.models.torch.batch import Batch
from merlin.models.torch.batch import Batch, sample_batch
from merlin.models.torch.models.base import compute_loss
from merlin.models.torch.utils import module_utils
from merlin.schema import ColumnSchema
Expand Down Expand Up @@ -201,8 +201,10 @@ def test_no_output_schema(self):
with pytest.raises(ValueError, match="Could not get output schema of PlusOne()"):
mm.schema.output(model)

def test_train_classification_with_lightning_trainer(self, music_streaming_data):
schema = music_streaming_data.schema.without(["user_genres", "like", "item_genres"])
def test_train_classification_with_lightning_trainer(self, music_streaming_data, batch_size=16):
schema = music_streaming_data.schema.select_by_name(
["item_id", "user_id", "user_age", "item_genres", "click"]
)
music_streaming_data.schema = schema

model = mm.Model(
Expand All @@ -211,15 +213,18 @@ def test_train_classification_with_lightning_trainer(self, music_streaming_data)
mm.BinaryOutput(schema.select_by_name("click").first),
)

trainer = pl.Trainer(max_epochs=1)
trainer = pl.Trainer(max_epochs=1, devices=1)

with Loader(music_streaming_data, batch_size=16) as loader:
with Loader(music_streaming_data, batch_size=batch_size) as loader:
model.initialize(loader)
trainer.fit(model, loader)

assert trainer.logged_metrics["train_loss"] > 0.0
assert trainer.num_training_batches == 7 # 100 rows // 16 per batch + 1 for last batch

batch = sample_batch(music_streaming_data, batch_size)
_ = module_utils.module_test(model, batch)


class TestComputeLoss:
def test_tensor_inputs(self):
Expand Down

0 comments on commit 8e18051

Please sign in to comment.