Skip to content

Commit

Permalink
add test transformer
Browse files Browse the repository at this point in the history
  • Loading branch information
davidfitzek committed Sep 4, 2023
1 parent 0676ef5 commit afb5df9
Showing 1 changed file with 50 additions and 37 deletions.
87 changes: 50 additions & 37 deletions src/rydberggpt/tests/test_transformer.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,62 @@
# import pytest
# import torch
# from torch import nn
import pytest

# from rydberggpt.models.transformer_wavefunction import TransformerWavefunction
# from rydberggpt.training.loss import LabelSmoothing
from rydberggpt.data.dataclasses import Batch
from rydberggpt.data.loading import (
get_chunked_dataloader,
get_chunked_random_dataloader,
get_rydberg_dataloader,
get_streaming_dataloader,
)
from rydberggpt.models.rydberg_encoder_decoder import get_rydberg_graph_encoder_decoder
from rydberggpt.training.trainer import RydbergGPTTrainer
from rydberggpt.utils import create_config_from_yaml, load_yaml_file


# def get_dummy_data():
# num_atoms = 3 # number of atoms
# num_samples = 8 # number of samples
# Module-scoped fixture exposing shared model-configuration parameters.
@pytest.fixture(scope="module")
def config():
    cfg = {"config_path": "configs/"}
    return cfg

# Module-scoped fixture with the dataloader arguments shared across tests.
@pytest.fixture(scope="module")
def config_dataloader():
    params = dict(
        data_path="src/rydberggpt/tests/dataset_test/",
        batch_size=10,
    )
    return params

def test_rydberg_gpt_with_dataloader(config_dataloader):
    """Smoke-test a Rydberg GPT training step on the small test dataset.

    Loads the small YAML config, builds the graph encoder/decoder model,
    wraps it in ``RydbergGPTTrainer``, then runs ``training_step`` on the
    first two batches from the chunked random dataloader, checking that
    each batch is a ``Batch`` and that the returned loss is non-negative.

    Args:
        config_dataloader: fixture dict with ``data_path`` and ``batch_size``
            forwarded to ``get_chunked_random_dataloader``.
    """
    # NOTE(review): path is "config/" here but the `config` fixture in this
    # module uses "configs/" — confirm which directory holds the YAML files.
    yaml_dict = load_yaml_file("config/", "config_small")
    config = create_config_from_yaml(yaml_dict)
    dataloader, _ = get_chunked_random_dataloader(**config_dataloader)

    # Create the model from the loaded config.
    model = get_rydberg_graph_encoder_decoder(config)

    # Initialize the trainer around the model.
    rydberg_gpt_trainer = RydbergGPTTrainer(model, config)

    for i, batch in enumerate(dataloader):
        assert isinstance(batch, Batch), "Batch is not an instance of the Batch class"

        # Perform a forward/training step through the model via the trainer;
        # the trainer returns the scalar loss for this batch.
        output = rydberg_gpt_trainer.training_step(batch, i)

        # Check output shapes, values, or any other property you are interested in
        assert output >= 0, "Loss is not non-negative"

        # Stop after two batches (i == 0 and i == 1) to keep the test fast.
        if i >= 1:
            break


# Allow running this test module directly (outside a pytest invocation).
if __name__ == "__main__":
    pytest.main([__file__])

# config_dataloader = {
# "data_path": "src/rydberggpt/tests/dataset_test/",
# "batch_size": 10,
# }

# test_rydberg_gpt_with_dataloader(config_dataloader)

0 comments on commit afb5df9

Please sign in to comment.