diff --git a/marl_eval/utils/merge_json_files.py b/marl_eval/utils/merge_json_files.py index 43a53f7b..5756c777 100644 --- a/marl_eval/utils/merge_json_files.py +++ b/marl_eval/utils/merge_json_files.py @@ -60,13 +60,17 @@ def _check_seed(concatenated_data: Dict, algo_data: Dict, seed_number: str) -> s return seed_number -def concatenate_files(directory: str, json_path: str = "./concatenation") -> Dict: +def concatenate_files( + input_directory: str, output_json_path: str = "concatenated_json_files/" +) -> Dict: """Concatenate all json files in a directory and save the result in a json file.""" - # Read all json files in a directory - json_data = _read_json_files(directory) -# Create target folder - if not os.path.exists(json_path): - os.makedirs(json_path) + # Read all json files in a input_directory + json_data = _read_json_files(input_directory) + + # Create target folder + if not os.path.exists(output_json_path): + os.makedirs(output_json_path) + # Using defaultdict for automatic handling of missing keys concatenated_data: Dict = defaultdict( lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))) @@ -88,7 +92,9 @@ def concatenate_files(directory: str, json_path: str = "./concatenation") -> Dic ] = algo_data # Save concatenated data in a json file - with open(f"{json_path}.json", "w") as f: + if output_json_path[-1] != "/": + output_json_path += "/" + with open(f"{output_json_path}metrics.json", "w") as f: json.dump(concatenated_data, f, indent=4) return concatenated_data