diff --git a/marl_eval/json_tools/json_utils.py b/marl_eval/json_tools/json_utils.py index 58203450..93633868 100644 --- a/marl_eval/json_tools/json_utils.py +++ b/marl_eval/json_tools/json_utils.py @@ -18,7 +18,7 @@ import os import zipfile from collections import defaultdict -from concurrent.futures import as_completed, ThreadPoolExecutor +from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path from typing import Dict, List, Tuple @@ -113,11 +113,11 @@ def concatenate_json_files( def pull_neptune_data( - project_name: str, - tags: List[str], - store_directory: str = "./downloaded_json_data", - neptune_data_key: str = "metrics", - disable_progress_bar: bool = False, + project_name: str, + tags: List[str], + store_directory: str = "./downloaded_json_data", + neptune_data_key: str = "metrics", + disable_progress_bar: bool = False, ) -> None: """Downloads logs from a Neptune project based on provided tags. @@ -128,7 +128,7 @@ def pull_neptune_data( Default is "./downloaded_json_data". neptune_data_key (str, optional): Key for the Neptune data to download. Default is "metrics". - disable_progress_bar (bool, optional): Whether to hide a progress bar during download. + disable_progress_bar (bool, optional): Whether to hide a progress bar. Default is False. Raises: @@ -150,10 +150,7 @@ def pull_neptune_data( # Fetch runs based on provided tags try: runs_table_df = project.fetch_runs_table( - state="inactive", - columns=['sys/id'], - tag=tags, - sort_by='sys/id' + state="inactive", columns=["sys/id"], tag=tags, sort_by="sys/id" ).to_pandas() except Exception as e: raise ValueError(f"Invalid tags {tags}: {e}") @@ -162,8 +159,22 @@ def pull_neptune_data( # Download logs concurrently with ThreadPoolExecutor() as executor: - futures = [executor.submit(download_and_extract_data, project_name, run_id, store_directory, neptune_data_key) for run_id in run_ids] - for future in tqdm(as_completed(futures), total=len(futures), desc="Downloading JSON logs", disable=disable_progress_bar): + futures = [ + executor.submit( + _download_and_extract_data, + project_name, + run_id, + store_directory, + neptune_data_key, + ) + for run_id in run_ids + ] + for future in tqdm( + as_completed(futures), + total=len(futures), + desc="Downloading JSON logs", + disable=disable_progress_bar, + ): future.result() # Restore neptune logger level @@ -171,20 +182,26 @@ def pull_neptune_data( print(f"{Fore.CYAN}{Style.BRIGHT}Data downloaded successfully!{Style.RESET_ALL}") -def download_and_extract_data(project_name, run_id, store_directory, neptune_data_key): +def _download_and_extract_data( + project_name: str, run_id: str, store_directory: str, neptune_data_key: str +) -> None: try: - with neptune.init_run(project=project_name, with_id=run_id, mode="read-only") as run: - for j, data_key in enumerate(run.get_structure()[neptune_data_key].keys(), start=1): + with neptune.init_run( + project=project_name, with_id=run_id, mode="read-only" + ) as run: + for j, data_key in enumerate( + run.get_structure()[neptune_data_key].keys(), start=1 + ): file_path = f"{store_directory}/{run_id}" if j > 1: file_path += f"_{j}" run[f"{neptune_data_key}/{data_key}"].download(destination=file_path) - extract_zip_file(file_path) + _extract_zip_file(file_path) except Exception as e: print(f"Error downloading data for run {run_id}: {e}") -def extract_zip_file(file_path): +def _extract_zip_file(file_path: str) -> None: try: with zipfile.ZipFile(file_path, "r") as zip_ref: for member in zip_ref.infolist():