Skip to content

Commit

Permalink
update dataloader_test
Browse files Browse the repository at this point in the history
  • Loading branch information
davidfitzek committed Sep 4, 2023
1 parent e824610 commit 3e4b638
Showing 1 changed file with 63 additions and 43 deletions.
106 changes: 63 additions & 43 deletions src/rydberggpt/tests/test_dataloader.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,69 @@
import pytest

from rydberggpt.data.dataclasses import Batch # Assuming Batch is imported from here
from rydberggpt.data.loading.rydberg_dataset import get_rydberg_dataloader
from rydberggpt.data.loading.rydberg_dataset_chunked import get_chunked_dataloader
from rydberggpt.data.loading.rydberg_dataset_streaming import get_streaming_dataloader


class TestRydbergDatasets:
    """Smoke tests for the three dataloader factory functions.

    NOTE: pytest never collects test classes that define ``__init__`` (it
    emits ``PytestCollectionWarning`` and skips every test inside), so the
    shared configuration is stored in class attributes instead.
    """

    # Shared test configuration, visible to every test via ``self``.
    data_path = "src/rydberggpt/tests/dataset_test/"
    batch_size = 12

    def _check_dataloader(self, dataloader):
        """Assert shape and attribute invariants on the first batch only.

        Leading underscore keeps pytest from collecting this helper as a
        test (its ``dataloader`` argument would otherwise be interpreted
        as a fixture request).
        """
        for batch in dataloader:
            assert (
                batch.m_onehot.shape[0] == self.batch_size
            ), f"Batch size of m_onehot is not {self.batch_size}."
            assert batch.m_onehot.shape[2] == 2, "Dimension of m_onehot is not 2."
            assert (
                batch.m_onehot.shape == batch.m_shifted_onehot.shape
            ), "Shapes of m_onehot and m_shifted_onehot are not the same."
            assert isinstance(batch, Batch)
            assert hasattr(batch, "graph")
            assert hasattr(batch, "m_onehot")
            assert hasattr(batch, "m_shifted_onehot")
            break  # one batch is enough for a smoke test

    def test_get_rydberg_dataloader(self):
        dataloader, _ = get_rydberg_dataloader(
            data_path=self.data_path, batch_size=self.batch_size
        )
        self._check_dataloader(dataloader)

    def test_get_chunked_dataloader(self):
        dataloader, _ = get_chunked_dataloader(
            data_path=self.data_path, batch_size=self.batch_size
        )
        self._check_dataloader(dataloader)

    def test_get_streaming_dataloader(self):
        dataloader, _ = get_streaming_dataloader(
            data_path=self.data_path, batch_size=self.batch_size
        )
        self._check_dataloader(dataloader)
from rydberggpt.data.loading import (
get_chunked_dataloader,
get_chunked_random_dataloader,
get_rydberg_dataloader,
get_streaming_dataloader,
)


# Define a fixture for common parameters
@pytest.fixture(scope="module")
def common_parameters():
    """Shared keyword arguments (dataset path, batch size) for all dataloader fixtures."""
    return {
        "data_path": "src/rydberggpt/tests/dataset_test/",
        "batch_size": 10,
    }


# Define your fixtures for each dataloader type, and use the common_parameters fixture as an argument
@pytest.fixture(scope="module")
def rydberg_dataloader(common_parameters):
    """Module-scoped dataloader built by ``get_rydberg_dataloader``."""
    loader, _dataset = get_rydberg_dataloader(**common_parameters)
    return loader


@pytest.fixture(scope="module")
def chunked_dataloader(common_parameters):
    """Module-scoped dataloader built by ``get_chunked_dataloader``."""
    loader, _dataset = get_chunked_dataloader(**common_parameters)
    return loader


@pytest.fixture(scope="module")
def chunked_random_dataloader(common_parameters):
    """Module-scoped dataloader built by ``get_chunked_random_dataloader``."""
    loader, _dataset = get_chunked_random_dataloader(**common_parameters)
    return loader


@pytest.fixture(scope="module")
def streaming_dataloader(common_parameters):
    """Module-scoped dataloader built by ``get_streaming_dataloader``."""
    loader, _dataset = get_streaming_dataloader(**common_parameters)
    return loader


# Parameterize the common test to run once per dataloader fixture
@pytest.mark.parametrize(
    "dataloader",
    [
        "rydberg_dataloader",
        "chunked_dataloader",
        "chunked_random_dataloader",
        "streaming_dataloader",
    ],
)
def test_dataloader_common(request, dataloader, common_parameters):
    """Smoke-test the first batch produced by every dataloader variant.

    The expected batch size is read from the ``common_parameters`` fixture
    rather than hard-coded, so the test cannot drift out of sync with the
    fixture configuration.
    """
    dataloader_instance = request.getfixturevalue(dataloader)
    batch_size = common_parameters["batch_size"]

    # Inspect only the first batch; an exhausted/empty loader fails loudly
    # here instead of silently passing (as a for-loop with break would).
    batch = next(iter(dataloader_instance))

    assert isinstance(batch, Batch)
    assert hasattr(batch, "graph")
    assert hasattr(batch, "m_onehot")
    assert hasattr(batch, "m_shifted_onehot")
    assert (
        batch.m_onehot.shape[0] == batch_size
    ), f"Batch size of m_onehot is not {batch_size}."
    assert batch.m_onehot.shape[2] == 2, "Dimension of m_onehot is not 2."
    assert (
        batch.m_onehot.shape == batch.m_shifted_onehot.shape
    ), "Shapes of m_onehot and m_shifted_onehot are not the same."


if __name__ == "__main__":
Expand Down

0 comments on commit 3e4b638

Please sign in to comment.