Skip to content

Commit

Permalink
Support Leaderboard API(#94)
Browse files Browse the repository at this point in the history
Added predictor.leaderboard() API support
  • Loading branch information
YiruMu authored Dec 20, 2023
1 parent a6c3bb5 commit bf85823
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
Empty file modified .github/workflow_scripts/cloud_lint_check.sh
100644 → 100755
Empty file.
20 changes: 20 additions & 0 deletions src/autogluon/cloud/predictor/cloud_predictor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import io
import logging
import os
import tarfile
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Dict, Optional, Tuple, Union
Expand Down Expand Up @@ -132,6 +134,24 @@ def info(self) -> Dict[str, Any]:
)
return info

def leaderboard(self) -> pd.DataFrame:
info = self.backend.get_fit_job_info()
cloud_output_path = self.cloud_output_path
path = os.path.join(cloud_output_path, "model", info["name"], "output/output.tar.gz")
assert is_s3_url(path), "Please provide a valid s3 path to the leaderboard result."
bucket, key = s3_path_to_bucket_prefix(path)
s3 = boto3.client("s3")
try:
wholefile = s3.get_object(Bucket=bucket, Key=key)["Body"].read()
fileobj = io.BytesIO(wholefile)
tarf = tarfile.open(fileobj=fileobj)
leaderboard = tarf.extractfile("leaderboard.csv")
df = pd.read_csv(leaderboard)
return df
except Exception:
empty = pd.DataFrame()
return empty

def _setup_local_output_path(self, path):
if path is None:
utcnow = datetime.utcnow()
Expand Down

0 comments on commit bf85823

Please sign in to comment.