-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
0676ef5
commit afb5df9
Showing
1 changed file
with
50 additions
and
37 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,49 +1,62 @@ | ||
# import pytest | ||
# import torch | ||
# from torch import nn | ||
import pytest | ||
|
||
# from rydberggpt.models.transformer_wavefunction import TransformerWavefunction | ||
# from rydberggpt.training.loss import LabelSmoothing | ||
from rydberggpt.data.dataclasses import Batch | ||
from rydberggpt.data.loading import ( | ||
get_chunked_dataloader, | ||
get_chunked_random_dataloader, | ||
get_rydberg_dataloader, | ||
get_streaming_dataloader, | ||
) | ||
from rydberggpt.models.rydberg_encoder_decoder import get_rydberg_graph_encoder_decoder | ||
from rydberggpt.training.trainer import RydbergGPTTrainer | ||
from rydberggpt.utils import create_config_from_yaml, load_yaml_file | ||
|
||
|
||
# def get_dummy_data(): | ||
# num_atoms = 3 # number of atoms | ||
# num_samples = 8 # number of samples | ||
# Define a fixture for common model parameters (if needed) | ||
@pytest.fixture(scope="module") | ||
def config(): | ||
return { | ||
"config_path": "configs/", | ||
} | ||
|
||
# H = torch.rand( | ||
# (num_samples, num_atoms, 4), dtype=torch.float | ||
# ) # [batch_size , num_atoms, 4] | ||
# dataset = torch.randint(0, 2, (num_samples, num_atoms), dtype=torch.int64) | ||
# return H, dataset | ||
|
||
# Define a fixture for common parameters | ||
@pytest.fixture(scope="module") | ||
def config_dataloader(): | ||
return {"data_path": "src/rydberggpt/tests/dataset_test/", "batch_size": 10} | ||
|
||
# def test_model_minimizes_loss(): | ||
# # prepare dummy data | ||
# H, dataset = get_dummy_data() | ||
# inputs = nn.functional.one_hot(dataset, 2) | ||
# inputs = inputs.to(torch.float) | ||
|
||
# # initialize model, criterion, and optimizer | ||
# model = TransformerWavefunction(10, 2, 2) | ||
# criterion = LabelSmoothing() | ||
# optimizer = torch.optim.Adam(model.parameters(), lr=0.05) | ||
def test_rydberg_gpt_with_dataloader(config_dataloader): | ||
yaml_dict = load_yaml_file("config/", "config_small") | ||
config = create_config_from_yaml(yaml_dict) | ||
dataloader, _ = get_chunked_random_dataloader(**config_dataloader) | ||
|
||
# # calculate initial loss | ||
# cond_probs = model.forward([H, inputs]) | ||
# first_loss = criterion(cond_probs, inputs) | ||
# Create Model | ||
model = get_rydberg_graph_encoder_decoder(config) | ||
|
||
# # train the model for 50 iterations | ||
# for _ in range(50): | ||
# optimizer.zero_grad() | ||
# cond_probs = model.forward([H, inputs]) | ||
# loss = criterion(cond_probs, inputs) | ||
# loss.backward() | ||
# optimizer.step() | ||
# Initialize the trainer | ||
rydberg_gpt_trainer = RydbergGPTTrainer(model, config) | ||
|
||
# # check if the final loss is smaller than the initial loss | ||
# assert loss < first_loss, "Final loss is not smaller than initial loss" | ||
for i, batch in enumerate(dataloader): | ||
assert isinstance(batch, Batch), "Batch is not an instance of the Batch class" | ||
|
||
# Perform a forward pass through the model using the trainer | ||
output = rydberg_gpt_trainer.training_step(batch, i) | ||
|
||
# if __name__ == "__main__": | ||
# pytest.main([__file__]) | ||
# # test_model_minimizes_loss() | ||
# Check output shapes, values, or any other property you are interested in | ||
assert output >= 0, "Loss is not non-negative" | ||
|
||
# Exit loop after one batch for quick testing | ||
if i >= 1: | ||
break | ||
|
||
|
||
if __name__ == "__main__": | ||
pytest.main([__file__]) | ||
|
||
# config_dataloader = { | ||
# "data_path": "src/rydberggpt/tests/dataset_test/", | ||
# "batch_size": 10, | ||
# } | ||
|
||
# test_rydberg_gpt_with_dataloader(config_dataloader) |