Skip to content

Commit

Permalink
chore: switch counter to enum
Browse files Browse the repository at this point in the history
  • Loading branch information
JemmaLDaniel committed Mar 1, 2024
1 parent 2d0cddf commit 08bdb71
Showing 1 changed file with 15 additions and 19 deletions.
34 changes: 15 additions & 19 deletions marl_eval/json_tools/json_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import json
import logging
import os
import zipfile
from collections import defaultdict
Expand All @@ -22,7 +23,6 @@
import neptune
from colorama import Fore, Style
from tqdm import tqdm
import logging


def _read_json_files(directory: str) -> list:
Expand Down Expand Up @@ -137,40 +137,36 @@ def pull_neptune_data(
os.makedirs(store_directory)

# Suppress neptune logger
neptune_logger = logging.getLogger('neptune')
neptune_logger = logging.getLogger("neptune")
neptune_logger.setLevel(logging.ERROR)

# Initialise a counter to ensure unique file names
counter = 0

# Download and unzip the data
for run_id in tqdm(run_ids, desc="Downloading Neptune Data"):
run = neptune.init_run(project=project_name, with_id=run_id, mode="read-only")
for data_key in run.get_structure()[neptune_data_key].keys():
counter += 1
file_path = f"{store_directory}/{data_key}"
for j, data_key in enumerate(
run.get_structure()[neptune_data_key].keys(), start=1
):
# Create a unique filename
file_path = f"{store_directory}/{data_key}_{run_id}_{j}"
run[f"{neptune_data_key}/{data_key}"].download(destination=file_path)
# Try to unzip the file else continue to the next file.
# Try to unzip the file else continue to the next file
try:
with zipfile.ZipFile(file_path, "r") as zip_ref:
# Create a unique file name
unzipped_filename = f"{file_path}_{counter}_unzip"
# Create a directory with to store unzipped data
os.makedirs(unzipped_filename, exist_ok=True)
# Create a directory to store unzipped data
os.makedirs(f"{file_path}_unzip", exist_ok=True)
# Unzip the data
zip_ref.extractall(f"{file_path}_unzip")
# Remove the zip file
os.remove(file_path)
except zipfile.BadZipFile:
# If the file is not zipped, it is already downloaded
# and doesn't need to be unzipped.
# Rename the file by appending run_counter to its existing name
renamed_file_path = f"{file_path}_{counter}"
os.rename(file_path, renamed_file_path)
# If the file is not zipped continue to the next file
# as it is already downloaded and doesn't need to be
# unzipped.
continue
except Exception as e:
print(f"An error occurred while unzipping or storing {file_path}: {e}")
run.stop()

# Restore neptune logger level
neptune_logger.setLevel(logging.INFO)

Expand Down

0 comments on commit 08bdb71

Please sign in to comment.