diff --git a/src/autogluon/cloud/job/sagemaker_job.py b/src/autogluon/cloud/job/sagemaker_job.py index 5185167..edc09ae 100644 --- a/src/autogluon/cloud/job/sagemaker_job.py +++ b/src/autogluon/cloud/job/sagemaker_job.py @@ -185,7 +185,7 @@ def run( base_job_name=base_job_name, output_path=output_path, code_location=code_location, - custom_image_uri=custom_image_uri, + image_uri=custom_image_uri, **autogluon_sagemaker_estimator_kwargs, ) logger.log(20, f"Start sagemaker training job `{job_name}`") diff --git a/src/autogluon/cloud/utils/ag_sagemaker.py b/src/autogluon/cloud/utils/ag_sagemaker.py index f44c39a..1a00eb4 100644 --- a/src/autogluon/cloud/utils/ag_sagemaker.py +++ b/src/autogluon/cloud/utils/ag_sagemaker.py @@ -17,19 +17,19 @@ class AutoGluonSagemakerEstimator(Estimator): def __init__( self, - entry_point, region, framework_version, py_version, instance_type, + entry_point=None, source_dir=None, hyperparameters=None, - custom_image_uri=None, + image_uri=None, **kwargs, ): self.framework_version = framework_version self.py_version = py_version - self.image_uri = custom_image_uri + self.image_uri = image_uri if self.image_uri is None: self.image_uri = image_uris.retrieve( "autogluon", @@ -107,9 +107,12 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na init_params = super()._prepare_init_params_from_job_description( job_details, model_channel_name=model_channel_name ) - # This two parameters will not be used, but is required to reattach the job + # These parameters will not be used, but is required to reattach the job init_params["region"] = "us-east-1" - init_params["framework_version"] = retrieve_latest_framework_version() + framework_version, py_version = retrieve_latest_framework_version() + py_version = py_version[0] + init_params["framework_version"] = framework_version + init_params["py_version"] = py_version return init_params diff --git a/src/autogluon/cloud/utils/sagemaker_iam.py b/src/autogluon/cloud/utils/sagemaker_iam.py index 1141e4b..fd1af81 100644 --- a/src/autogluon/cloud/utils/sagemaker_iam.py +++ b/src/autogluon/cloud/utils/sagemaker_iam.py @@ -42,6 +42,7 @@ "sagemaker:UpdateArtifact", "sagemaker:UpdateEndpoint", "sagemaker:InvokeEndpoint", + "sagemaker:ListTags", # Needed for re-attach job ], "Resource": [ f"arn:aws:sagemaker:*:{POLICY_ACCOUNT_PLACE_HOLDER}:endpoint/{SAGEMAKER_RESOURCE_PREFIX}*", diff --git a/tests/conftest.py b/tests/conftest.py index b973749..f3e77dd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -145,11 +145,14 @@ def test_functionality( **fit_kwargs, ) info = cloud_predictor.info() + job_name = info["fit_job"]["name"] assert info["local_output_path"] is not None assert info["cloud_output_path"] is not None - assert info["fit_job"]["name"] is not None + assert job_name is not None assert info["fit_job"]["status"] == "Completed" + cloud_predictor.attach_job(job_name) + if deploy_kwargs is None: deploy_kwargs = dict() if predict_real_time_kwargs is None: