Skip to content

Commit

Permalink
feat: add the unzip feature and remove system name restriction
Browse files Browse the repository at this point in the history
  • Loading branch information
OmaymaMahjoub committed Jan 26, 2024
1 parent 46b492a commit 6d69772
Showing 1 changed file with 21 additions and 10 deletions.
31 changes: 21 additions & 10 deletions marl_eval/json_tools/pull_neptune_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,41 +14,52 @@
# limitations under the License.

import os
import zipfile
from typing import List

import neptune
from colorama import Fore, Style
from tqdm import tqdm


def pull_neptune_data(project_name: str, tag: List, store_directory: str = ".") -> None:
    """Pull experiment data from Neptune into a local directory.

    Every inactive run in the project that matches ``tag`` is opened in
    read-only mode and each entry under its ``metrics`` namespace is
    downloaded to ``store_directory``. Downloads that are zip archives are
    extracted into a numbered sub-directory (``store_directory/0``,
    ``store_directory/1``, ...) and the archive is then deleted; downloads
    that are not zip archives are left in place as downloaded.

    Args:
        project_name (str): Name of the Neptune project.
        tag (List): List of tags used to select the runs.
        store_directory (str, optional): Directory to store the data.
            Created if it does not exist. Defaults to the current directory.
    """
    # Get the run ids
    project = neptune.init_project(project=project_name)
    runs_table_df = project.fetch_runs_table(state="inactive", tag=tag).to_pandas()
    run_ids = runs_table_df["sys/id"].values.tolist()

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(store_directory, exist_ok=True)

    # Download and unzip the data
    itr = 0  # To create a unique directory for each unzipped file
    for run_id in tqdm(run_ids, desc="Downloading Neptune Data"):
        run = neptune.init_run(project=project_name, with_id=run_id, mode="read-only")
        try:
            for data_key in run.get_structure()["metrics"]:
                file_path = f"{store_directory}/{data_key}"
                run[f"metrics/{data_key}"].download(destination=file_path)
                try:
                    with zipfile.ZipFile(file_path, "r") as zip_ref:
                        # Create a directory to store the unzipped data
                        unzip_directory = f"{store_directory}/{itr}"
                        os.makedirs(unzip_directory, exist_ok=True)
                        # Unzip the data
                        zip_ref.extractall(unzip_directory)
                    # Remove the zip file only after the archive handle is
                    # closed: deleting a file that is still open fails on
                    # Windows.
                    os.remove(file_path)
                except zipfile.BadZipFile:
                    # If it's not a zip file, just continue to the next file
                    # (no numbered directory was used, so keep the counter).
                    continue
                except Exception as e:
                    print(
                        f"An error occurred while unzipping or storing {file_path}: {e}"
                    )
                itr += 1
        finally:
            # Always release the run handle, even if a download fails.
            run.stop()

    print(f"{Fore.CYAN}{Style.BRIGHT}Data downloaded successfully!{Style.RESET_ALL}")

0 comments on commit 6d69772

Please sign in to comment.