From 08bdb71cb5f946dd637954660765efc4eb7c82eb Mon Sep 17 00:00:00 2001 From: Jemma Daniel Date: Fri, 1 Mar 2024 17:24:47 +0200 Subject: [PATCH] chore: switch counter to enum --- marl_eval/json_tools/json_utils.py | 34 +++++++++++++----------------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/marl_eval/json_tools/json_utils.py b/marl_eval/json_tools/json_utils.py index 38f5bc74..e4bce58c 100644 --- a/marl_eval/json_tools/json_utils.py +++ b/marl_eval/json_tools/json_utils.py @@ -14,6 +14,7 @@ # limitations under the License. import json +import logging import os import zipfile from collections import defaultdict @@ -22,7 +23,6 @@ import neptune from colorama import Fore, Style from tqdm import tqdm -import logging def _read_json_files(directory: str) -> list: @@ -137,40 +137,36 @@ def pull_neptune_data( os.makedirs(store_directory) # Suppress neptune logger - neptune_logger = logging.getLogger('neptune') + neptune_logger = logging.getLogger("neptune") neptune_logger.setLevel(logging.ERROR) - # Initialise a counter to ensure unique file names - counter = 0 - # Download and unzip the data for run_id in tqdm(run_ids, desc="Downloading Neptune Data"): run = neptune.init_run(project=project_name, with_id=run_id, mode="read-only") - for data_key in run.get_structure()[neptune_data_key].keys(): - counter += 1 - file_path = f"{store_directory}/{data_key}" + for j, data_key in enumerate( + run.get_structure()[neptune_data_key].keys(), start=1 + ): + # Create a unique filename + file_path = f"{store_directory}/{data_key}_{run_id}_{j}" run[f"{neptune_data_key}/{data_key}"].download(destination=file_path) - # Try to unzip the file else continue to the next file. + # Try to unzip the file else continue to the next file try: with zipfile.ZipFile(file_path, "r") as zip_ref: - # Create a unique file name - unzipped_filename = f"{file_path}_{counter}_unzip" - # Create a directory with to store unzipped data - os.makedirs(unzipped_filename, exist_ok=True) + # Create a directory to store unzipped data + os.makedirs(f"{file_path}_unzip", exist_ok=True) # Unzip the data zip_ref.extractall(f"{file_path}_unzip") # Remove the zip file os.remove(file_path) except zipfile.BadZipFile: - # If the file is not zipped, it is already downloaded - # and doesn't need to be unzipped. - # Rename the file by appending run_counter to its existing name - renamed_file_path = f"{file_path}_{counter}" - os.rename(file_path, renamed_file_path) + # If the file is not zipped continue to the next file + # as it is already downloaded and doesn't need to be + # unzipped. + continue except Exception as e: print(f"An error occurred while unzipping or storing {file_path}: {e}") run.stop() - + # Restore neptune logger level neptune_logger.setLevel(logging.INFO)