diff --git a/src/rydberggpt/tests/test_transformer.py b/src/rydberggpt/tests/test_transformer.py
index acf2ec41..9712c63b 100644
--- a/src/rydberggpt/tests/test_transformer.py
+++ b/src/rydberggpt/tests/test_transformer.py
@@ -1,49 +1,62 @@
-# import pytest
-# import torch
-# from torch import nn
+import pytest
 
-# from rydberggpt.models.transformer_wavefunction import TransformerWavefunction
-# from rydberggpt.training.loss import LabelSmoothing
+from rydberggpt.data.dataclasses import Batch
+from rydberggpt.data.loading import (
+    get_chunked_dataloader,
+    get_chunked_random_dataloader,
+    get_rydberg_dataloader,
+    get_streaming_dataloader,
+)
+from rydberggpt.models.rydberg_encoder_decoder import get_rydberg_graph_encoder_decoder
+from rydberggpt.training.trainer import RydbergGPTTrainer
+from rydberggpt.utils import create_config_from_yaml, load_yaml_file
 
-# def get_dummy_data():
-#     num_atoms = 3  # number of atoms
-#     num_samples = 8  # number of samples
+# Fixture for the directory that holds the model config files
+@pytest.fixture(scope="module")
+def config():
+    return {
+        "config_path": "config/",
+    }
 
-#     H = torch.rand(
-#         (num_samples, num_atoms, 4), dtype=torch.float
-#     )  # [batch_size, num_atoms, 4]
-#     dataset = torch.randint(0, 2, (num_samples, num_atoms), dtype=torch.int64)
-#     return H, dataset
+# Fixture for common dataloader parameters
+@pytest.fixture(scope="module")
+def config_dataloader():
+    return {"data_path": "src/rydberggpt/tests/dataset_test/", "batch_size": 10}
 
-# def test_model_minimizes_loss():
-#     # prepare dummy data
-#     H, dataset = get_dummy_data()
-#     inputs = nn.functional.one_hot(dataset, 2)
-#     inputs = inputs.to(torch.float)
-#     # initialize model, criterion, and optimizer
-#     model = TransformerWavefunction(10, 2, 2)
-#     criterion = LabelSmoothing()
-#     optimizer = torch.optim.Adam(model.parameters(), lr=0.05)
+def test_rydberg_gpt_with_dataloader(config, config_dataloader):
+    yaml_dict = load_yaml_file(config["config_path"], "config_small")
+    model_config = create_config_from_yaml(yaml_dict)
+    dataloader, _ = get_chunked_random_dataloader(**config_dataloader)
 
-#     # calculate initial loss
-#     cond_probs = model.forward([H, inputs])
-#     first_loss = criterion(cond_probs, inputs)
+    # Create the model from the loaded config
+    model = get_rydberg_graph_encoder_decoder(model_config)
 
-#     # train the model for 50 iterations
-#     for _ in range(50):
-#         optimizer.zero_grad()
-#         cond_probs = model.forward([H, inputs])
-#         loss = criterion(cond_probs, inputs)
-#         loss.backward()
-#         optimizer.step()
+    # Initialize the trainer
+    rydberg_gpt_trainer = RydbergGPTTrainer(model, model_config)
 
-#     # check if the final loss is smaller than the initial loss
-#     assert loss < first_loss, "Final loss is not smaller than initial loss"
+    for i, batch in enumerate(dataloader):
+        assert isinstance(batch, Batch), "Batch is not an instance of the Batch class"
+        # Perform a forward pass through the model via the trainer
+        output = rydberg_gpt_trainer.training_step(batch, i)
 
-# if __name__ == "__main__":
-#     pytest.main([__file__])
-#     # test_model_minimizes_loss()
+        # The training step returns the loss, which should be non-negative
+        assert output >= 0, "Loss should be non-negative"
+
+        # Only run two batches to keep the test fast
+        if i >= 1:
+            break
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
+
+    # config_dataloader = {
+    #     "data_path": "src/rydberggpt/tests/dataset_test/",
+    #     "batch_size": 10,
+    # }
+
+    # test_rydberg_gpt_with_dataloader(config_dataloader)
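
Note (not part of the patch): a minimal standalone sketch of the dataloader invariant the new test relies on, useful for poking at the data path without the trainer. It assumes the chunked test dataset exists at src/rydberggpt/tests/dataset_test/ and that get_chunked_random_dataloader returns a (dataloader, dataset) pair, as the test unpacks it.

    from rydberggpt.data.dataclasses import Batch
    from rydberggpt.data.loading import get_chunked_random_dataloader

    # Build the dataloader with the same arguments as the config_dataloader fixture.
    dataloader, _ = get_chunked_random_dataloader(
        data_path="src/rydberggpt/tests/dataset_test/", batch_size=10
    )

    # Each yielded batch should be the Batch dataclass the test asserts on.
    batch = next(iter(dataloader))
    assert isinstance(batch, Batch)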