Skip to content

Commit

Permalink
store dataset config
Browse files: browse the repository at this point in the history
  • Loading branch information
davidfitzek committed Sep 13, 2023
1 parent 08fe4ad commit ac0c2ef
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 8 deletions.
21 changes: 18 additions & 3 deletions src/rydberggpt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,33 @@
from torch import nn


def save_to_yaml(data: Dict[str, Any], filename: str) -> None:
    """
    Save a dictionary to a file in YAML format.

    Missing parent directories of ``filename`` are created first, so the call
    also works when the target directory does not exist yet. An existing file
    at ``filename`` is overwritten.

    Args:
        data (Dict[str, Any]): The dictionary to be saved.
        filename (str): The path to the file where the dictionary will be saved.

    Returns:
        None
    """
    # Local import: keeps this function self-contained regardless of which
    # modules are imported at the top of the file.
    import os

    parent_dir = os.path.dirname(filename)
    # An empty dirname means "current working directory" - nothing to create.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Explicit encoding so the output does not depend on the platform default.
    with open(filename, "w", encoding="utf-8") as file:
        yaml.dump(data, file)


def to_one_hot(
data: Union[torch.Tensor, List[int], Tuple[int]], num_classes: int
) -> torch.Tensor:
"""
Converts the input data into one-hot representation.
Args:
- data: Input data to be converted into one-hot. It can be a 1D tensor, list or tuple of integers.
- num_classes: Number of classes in the one-hot representation.
data: Input data to be converted into one-hot. It can be a 1D tensor, list or tuple of integers.
num_classes: Number of classes in the one-hot representation.
Returns:
- data: The one-hot representation of the input data.
data: The one-hot representation of the input data.
"""

if isinstance(data, (list, tuple)):
Expand Down
26 changes: 21 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import os
from typing import Optional

import numpy as np
Expand Down Expand Up @@ -28,7 +29,7 @@
)
from rydberggpt.training.trainer import RydbergGPTTrainer
from rydberggpt.training.utils import set_example_input_array
from rydberggpt.utils import create_config_from_yaml, load_yaml_file
from rydberggpt.utils import create_config_from_yaml, load_yaml_file, save_to_yaml
from rydberggpt.utils_ckpt import (
find_best_ckpt,
find_latest_ckpt,
Expand Down Expand Up @@ -128,16 +129,16 @@ def main(config_path: str, config_name: str, dataset_path: str):
# Setup Environment
setup_environment(config)

# Load data
train_loader, val_loader = load_data(config, dataset_path)
input_array = set_example_input_array(train_loader)

# Create Model
model = create_model(config)

# Setup tensorboard logger
logger = TensorBoardLogger(save_dir="logs")
log_path = f"logs/lightning_logs/version_{logger.version}"
print(f"Log path: {log_path}")

# save hyperparams
logger.log_hyperparams(vars(config))

rydberg_gpt_trainer = RydbergGPTTrainer(
model, config, logger=logger # , example_input_array=input_array
Expand Down Expand Up @@ -172,6 +173,21 @@ def main(config_path: str, config_name: str, dataset_path: str):
detect_anomaly=config.detect_anomaly,
)

# store list of datasets used
datasets_used = [
name
for name in os.listdir(dataset_path)
if os.path.isdir(os.path.join(dataset_path, name))
]

save_to_yaml(
{"datasets": datasets_used}, os.path.join(log_path, "datasets_used.yaml")
)

# Load data
train_loader, val_loader = load_data(config, dataset_path)
input_array = set_example_input_array(train_loader)

# Find the latest checkpoint
if config.from_checkpoint is not None:
log_path = get_ckpt_path(from_ckpt=config.from_checkpoint)
Expand Down

0 comments on commit ac0c2ef

Please sign in to comment.