Skip to content

Commit

Permalink
Fix attach_job (#57)
Browse files Browse the repository at this point in the history
  • Loading branch information
Weisu Yin authored Mar 22, 2023
1 parent 2bbaa83 commit f9327a1
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/autogluon/cloud/job/sagemaker_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def run(
base_job_name=base_job_name,
output_path=output_path,
code_location=code_location,
custom_image_uri=custom_image_uri,
image_uri=custom_image_uri,
**autogluon_sagemaker_estimator_kwargs,
)
logger.log(20, f"Start sagemaker training job `{job_name}`")
Expand Down
13 changes: 8 additions & 5 deletions src/autogluon/cloud/utils/ag_sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,19 @@
class AutoGluonSagemakerEstimator(Estimator):
def __init__(
self,
entry_point,
region,
framework_version,
py_version,
instance_type,
entry_point=None,
source_dir=None,
hyperparameters=None,
custom_image_uri=None,
image_uri=None,
**kwargs,
):
self.framework_version = framework_version
self.py_version = py_version
self.image_uri = custom_image_uri
self.image_uri = image_uri
if self.image_uri is None:
self.image_uri = image_uris.retrieve(
"autogluon",
Expand Down Expand Up @@ -107,9 +107,12 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
init_params = super()._prepare_init_params_from_job_description(
job_details, model_channel_name=model_channel_name
)
# This two parameters will not be used, but is required to reattach the job
# These parameters will not be used, but are required to reattach the job
init_params["region"] = "us-east-1"
init_params["framework_version"] = retrieve_latest_framework_version()
framework_version, py_version = retrieve_latest_framework_version()
py_version = py_version[0]
init_params["framework_version"] = framework_version
init_params["py_version"] = py_version
return init_params


Expand Down
1 change: 1 addition & 0 deletions src/autogluon/cloud/utils/sagemaker_iam.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
"sagemaker:UpdateArtifact",
"sagemaker:UpdateEndpoint",
"sagemaker:InvokeEndpoint",
"sagemaker:ListTags", # Needed for re-attach job
],
"Resource": [
f"arn:aws:sagemaker:*:{POLICY_ACCOUNT_PLACE_HOLDER}:endpoint/{SAGEMAKER_RESOURCE_PREFIX}*",
Expand Down
5 changes: 4 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,14 @@ def test_functionality(
**fit_kwargs,
)
info = cloud_predictor.info()
job_name = info["fit_job"]["name"]
assert info["local_output_path"] is not None
assert info["cloud_output_path"] is not None
assert info["fit_job"]["name"] is not None
assert job_name is not None
assert info["fit_job"]["status"] == "Completed"

cloud_predictor.attach_job(job_name)

if deploy_kwargs is None:
deploy_kwargs = dict()
if predict_real_time_kwargs is None:
Expand Down

0 comments on commit f9327a1

Please sign in to comment.