From ac0c2ef2a1ee7997eae50bb10e828df71873e56a Mon Sep 17 00:00:00 2001
From: davidfitzek
Date: Wed, 13 Sep 2023 14:07:51 +0100
Subject: [PATCH] store dataset config

---
 src/rydberggpt/utils.py | 21 ++++++++++++++++++---
 train.py                | 26 +++++++++++++++++++++-----
 2 files changed, 39 insertions(+), 8 deletions(-)

diff --git a/src/rydberggpt/utils.py b/src/rydberggpt/utils.py
index 7dabcbf9..a95af10f 100644
--- a/src/rydberggpt/utils.py
+++ b/src/rydberggpt/utils.py
@@ -7,6 +7,21 @@
 from torch import nn
 
 
+def save_to_yaml(data: Dict[str, Any], filename: str) -> None:
+    """
+    Save a dictionary to a file in YAML format.
+
+    Args:
+        data (Dict[str, Any]): The dictionary to be saved.
+        filename (str): The path to the file where the dictionary will be saved.
+
+    Returns:
+        None
+    """
+    with open(filename, "w") as file:
+        yaml.dump(data, file)
+
+
 def to_one_hot(
     data: Union[torch.Tensor, List[int], Tuple[int]], num_classes: int
 ) -> torch.Tensor:
@@ -14,11 +29,11 @@ def to_one_hot(
     Converts the input data into one-hot representation.
 
     Args:
-        - data: Input data to be converted into one-hot. It can be a 1D tensor, list or tuple of integers.
-        - num_classes: Number of classes in the one-hot representation.
+        data: Input data to be converted into one-hot. It can be a 1D tensor, list or tuple of integers.
+        num_classes: Number of classes in the one-hot representation.
 
     Returns:
-        - data: The one-hot representation of the input data.
+        data: The one-hot representation of the input data.
 
     """
     if isinstance(data, (list, tuple)):
diff --git a/train.py b/train.py
index 60c1014d..98f7b4d0 100644
--- a/train.py
+++ b/train.py
@@ -1,4 +1,5 @@
 import argparse
+import os
 from typing import Optional
 
 import numpy as np
@@ -28,7 +29,7 @@
 )
 from rydberggpt.training.trainer import RydbergGPTTrainer
 from rydberggpt.training.utils import set_example_input_array
-from rydberggpt.utils import create_config_from_yaml, load_yaml_file
+from rydberggpt.utils import create_config_from_yaml, load_yaml_file, save_to_yaml
 from rydberggpt.utils_ckpt import (
     find_best_ckpt,
     find_latest_ckpt,
@@ -128,16 +129,16 @@ def main(config_path: str, config_name: str, dataset_path: str):
     # Setup Environment
     setup_environment(config)
 
-    # Load data
-    train_loader, val_loader = load_data(config, dataset_path)
-    input_array = set_example_input_array(train_loader)
-
     # Create Model
     model = create_model(config)
 
     # Setup tensorboard logger
     logger = TensorBoardLogger(save_dir="logs")
     log_path = f"logs/lightning_logs/version_{logger.version}"
+    print(f"Log path: {log_path}")
+
+    # save hyperparams
+    logger.log_hyperparams(vars(config))
 
     rydberg_gpt_trainer = RydbergGPTTrainer(
         model, config, logger=logger  # , example_input_array=input_array
@@ -172,6 +173,21 @@ def main(config_path: str, config_name: str, dataset_path: str):
         detect_anomaly=config.detect_anomaly,
     )
 
+    # store list of datasets used
+    datasets_used = [
+        name
+        for name in os.listdir(dataset_path)
+        if os.path.isdir(os.path.join(dataset_path, name))
+    ]
+
+    save_to_yaml(
+        {"datasets": datasets_used}, os.path.join(log_path, "datasets_used.yaml")
+    )
+
+    # Load data
+    train_loader, val_loader = load_data(config, dataset_path)
+    input_array = set_example_input_array(train_loader)
+
     # Find the latest checkpoint
     if config.from_checkpoint is not None:
         log_path = get_ckpt_path(from_ckpt=config.from_checkpoint)