Skip to content

Commit

Permalink
Fix formatting
Browse files: browse the repository at this point in the history
  • Loading branch information
liamclarkza committed Mar 12, 2024
1 parent 05f2c45 commit 5869627
Showing 1 changed file with 35 additions and 18 deletions.
53 changes: 35 additions & 18 deletions marl_eval/json_tools/json_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import os
import zipfile
from collections import defaultdict
from concurrent.futures import as_completed, ThreadPoolExecutor
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Dict, List, Tuple

Expand Down Expand Up @@ -113,11 +113,11 @@ def concatenate_json_files(


def pull_neptune_data(
project_name: str,
tags: List[str],
store_directory: str = "./downloaded_json_data",
neptune_data_key: str = "metrics",
disable_progress_bar: bool = False,
project_name: str,
tags: List[str],
store_directory: str = "./downloaded_json_data",
neptune_data_key: str = "metrics",
disable_progress_bar: bool = False,
) -> None:
"""Downloads logs from a Neptune project based on provided tags.
Expand All @@ -128,7 +128,7 @@ def pull_neptune_data(
Default is "./downloaded_json_data".
neptune_data_key (str, optional): Key for the Neptune data to download.
Default is "metrics".
disable_progress_bar (bool, optional): Whether to hide a progress bar during download.
disable_progress_bar (bool, optional): Whether to hide a progress bar.
Default is False.
Raises:
Expand All @@ -150,10 +150,7 @@ def pull_neptune_data(
# Fetch runs based on provided tags
try:
runs_table_df = project.fetch_runs_table(
state="inactive",
columns=['sys/id'],
tag=tags,
sort_by='sys/id'
state="inactive", columns=["sys/id"], tag=tags, sort_by="sys/id"
).to_pandas()
except Exception as e:
raise ValueError(f"Invalid tags {tags}: {e}")
Expand All @@ -162,29 +159,49 @@ def pull_neptune_data(

# Download logs concurrently
with ThreadPoolExecutor() as executor:
futures = [executor.submit(download_and_extract_data, project_name, run_id, store_directory, neptune_data_key) for run_id in run_ids]
for future in tqdm(as_completed(futures), total=len(futures), desc="Downloading JSON logs", disable=disable_progress_bar):
futures = [
executor.submit(
_download_and_extract_data,
project_name,
run_id,
store_directory,
neptune_data_key,
)
for run_id in run_ids
]
for future in tqdm(
as_completed(futures),
total=len(futures),
desc="Downloading JSON logs",
disable=disable_progress_bar,
):
future.result()

# Restore neptune logger level
neptune_logger.setLevel(logging.INFO)
print(f"{Fore.CYAN}{Style.BRIGHT}Data downloaded successfully!{Style.RESET_ALL}")


def download_and_extract_data(project_name, run_id, store_directory, neptune_data_key):
def _download_and_extract_data(
project_name: str, run_id: str, store_directory: str, neptune_data_key: str
) -> None:
try:
with neptune.init_run(project=project_name, with_id=run_id, mode="read-only") as run:
for j, data_key in enumerate(run.get_structure()[neptune_data_key].keys(), start=1):
with neptune.init_run(
project=project_name, with_id=run_id, mode="read-only"
) as run:
for j, data_key in enumerate(
run.get_structure()[neptune_data_key].keys(), start=1
):
file_path = f"{store_directory}/{run_id}"
if j > 1:
file_path += f"_{j}"
run[f"{neptune_data_key}/{data_key}"].download(destination=file_path)
extract_zip_file(file_path)
_extract_zip_file(file_path)
except Exception as e:
print(f"Error downloading data for run {run_id}: {e}")


def extract_zip_file(file_path):
def _extract_zip_file(file_path: str) -> None:
try:
with zipfile.ZipFile(file_path, "r") as zip_ref:
for member in zip_ref.infolist():
Expand Down

0 comments on commit 5869627

Please sign in to comment.