def leaderboard(self) -> pd.DataFrame:
    """
    Load the leaderboard written by the fit job from S3.

    The archive is expected at
    ``<cloud_output_path>/model/<fit job name>/output/output.tar.gz``, where the
    job name is the ``"name"`` entry of ``self.backend.get_fit_job_info()``.
    The member ``leaderboard.csv`` inside that archive is parsed with pandas.

    Returns
    -------
    pd.DataFrame
        The parsed leaderboard. An empty DataFrame is returned when the archive
        cannot be downloaded, opened, or does not contain a readable
        ``leaderboard.csv``; the reason is logged as a warning.
    """
    info = self.backend.get_fit_job_info()
    path = os.path.join(self.cloud_output_path, "model", info["name"], "output/output.tar.gz")
    assert is_s3_url(path), "Please provide a valid s3 path to the leaderboard result."
    bucket, key = s3_path_to_bucket_prefix(path)
    s3 = boto3.client("s3")
    try:
        # The whole archive is read into memory so tarfile can seek within it.
        archive_bytes = s3.get_object(Bucket=bucket, Key=key)["Body"].read()
        # The context manager guarantees the archive handle is released even if parsing fails.
        with tarfile.open(fileobj=io.BytesIO(archive_bytes)) as archive:
            member = archive.extractfile("leaderboard.csv")
            if member is None:
                # extractfile returns None for members that are not regular files or links.
                raise FileNotFoundError("leaderboard.csv in the output archive is not a regular file.")
            return pd.read_csv(member)
    except Exception as e:
        # Best-effort by design: callers get an empty DataFrame rather than an exception,
        # but the cause is surfaced in the log instead of being silently discarded.
        logging.getLogger(__name__).warning(
            "Failed to load leaderboard from %s (%s: %s). Returning an empty DataFrame.",
            path,
            type(e).__name__,
            e,
        )
        return pd.DataFrame()