diff --git a/airflow/providers/amazon/aws/operators/sagemaker_training.py b/airflow/providers/amazon/aws/operators/sagemaker_training.py index c748666c00573a..db60bde9d78eb6 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker_training.py +++ b/airflow/providers/amazon/aws/operators/sagemaker_training.py @@ -46,8 +46,12 @@ class SageMakerTrainingOperator(SageMakerBaseOperator): doesn't finish within max_ingestion_time seconds. If you set this parameter to None, the operation does not timeout. :type max_ingestion_time: int + :param check_if_job_exists: If set to true, then the operator will check whether a training job + already exists for the name in the config. + :type check_if_job_exists: bool :param action_if_job_exists: Behaviour if the job name already exists. Possible options are "increment" (default) and "fail". + This is only relevant if check_if_job_exists is True. :type action_if_job_exists: str """ @@ -65,6 +69,7 @@ def __init__( print_log: bool = True, check_interval: int = 30, max_ingestion_time: Optional[int] = None, + check_if_job_exists: bool = True, action_if_job_exists: str = "increment", # TODO use typing.Literal for this in Python 3.8 **kwargs, ): @@ -74,6 +79,7 @@ def __init__( self.print_log = print_log self.check_interval = check_interval self.max_ingestion_time = max_ingestion_time + self.check_if_job_exists = check_if_job_exists if action_if_job_exists in ("increment", "fail"): self.action_if_job_exists = action_if_job_exists @@ -90,7 +96,22 @@ def expand_role(self) -> None: def execute(self, context) -> dict: self.preprocess_config() + if self.check_if_job_exists: + self._check_if_job_exists() + self.log.info("Creating SageMaker training job %s.", self.config["TrainingJobName"]) + response = self.hook.create_training_job( + self.config, + wait_for_completion=self.wait_for_completion, + print_log=self.print_log, + check_interval=self.check_interval, + max_ingestion_time=self.max_ingestion_time, + ) + if response['ResponseMetadata']['HTTPStatusCode'] != 200: + raise AirflowException(f'Sagemaker Training Job creation failed: {response}') + else: + return {'Training': self.hook.describe_training_job(self.config['TrainingJobName'])} + def _check_if_job_exists(self) -> None: training_job_name = self.config["TrainingJobName"] training_jobs = self.hook.list_training_jobs(name_contains=training_job_name) @@ -105,16 +126,3 @@ def execute(self, context) -> dict: raise AirflowException( f"A SageMaker training job with name {training_job_name} already exists." ) - - self.log.info("Creating SageMaker training job %s.", self.config["TrainingJobName"]) - response = self.hook.create_training_job( - self.config, - wait_for_completion=self.wait_for_completion, - print_log=self.print_log, - check_interval=self.check_interval, - max_ingestion_time=self.max_ingestion_time, - ) - if response['ResponseMetadata']['HTTPStatusCode'] != 200: - raise AirflowException(f'Sagemaker Training Job creation failed: {response}') - else: - return {'Training': self.hook.describe_training_job(self.config['TrainingJobName'])} diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_training.py b/tests/providers/amazon/aws/operators/test_sagemaker_training.py index 4aeca8c65e0772..8e54533a475072 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_training.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_training.py @@ -86,12 +86,33 @@ def test_parse_config_integers(self): @mock.patch.object(SageMakerHook, 'get_conn') @mock.patch.object(SageMakerHook, 'create_training_job') - def test_execute(self, mock_training, mock_client): + def test_execute_with_check_if_job_exists(self, mock_training, mock_client): mock_training.return_value = { 'TrainingJobArn': 'testarn', 'ResponseMetadata': {'HTTPStatusCode': 200}, } + self.sagemaker._check_if_job_exists = mock.MagicMock() self.sagemaker.execute(None) + self.sagemaker._check_if_job_exists.assert_called_once() + mock_training.assert_called_once_with( + create_training_params, + wait_for_completion=False, + print_log=True, + check_interval=5, + max_ingestion_time=None, + ) + + @mock.patch.object(SageMakerHook, 'get_conn') + @mock.patch.object(SageMakerHook, 'create_training_job') + def test_execute_without_check_if_job_exists(self, mock_training, mock_client): + mock_training.return_value = { + 'TrainingJobArn': 'testarn', + 'ResponseMetadata': {'HTTPStatusCode': 200}, + } + self.sagemaker.check_if_job_exists = False + self.sagemaker._check_if_job_exists = mock.MagicMock() + self.sagemaker.execute(None) + self.sagemaker._check_if_job_exists.assert_not_called() mock_training.assert_called_once_with( create_training_params, wait_for_completion=False, @@ -110,38 +131,24 @@ def test_execute_with_failure(self, mock_training, mock_client): with pytest.raises(AirflowException): self.sagemaker.execute(None) - # pylint: enable=unused-argument - @mock.patch.object(SageMakerHook, "get_conn") @mock.patch.object(SageMakerHook, "list_training_jobs") - @mock.patch.object(SageMakerHook, "create_training_job") - def test_execute_with_existing_job_increment( - self, mock_create_training_job, mock_list_training_jobs, mock_client - ): + def test_check_if_job_exists_increment(self, mock_list_training_jobs, mock_client): + self.sagemaker.check_if_job_exists = True self.sagemaker.action_if_job_exists = "increment" - mock_create_training_job.return_value = {"ResponseMetadata": {"HTTPStatusCode": 200}} mock_list_training_jobs.return_value = [{"TrainingJobName": job_name}] - self.sagemaker.execute(None) + self.sagemaker._check_if_job_exists() expected_config = create_training_params.copy() # Expect to see TrainingJobName suffixed with "-2" because we return one existing job expected_config["TrainingJobName"] = f"{job_name}-2" - mock_create_training_job.assert_called_once_with( - expected_config, - wait_for_completion=False, - print_log=True, - check_interval=5, - max_ingestion_time=None, - ) + assert self.sagemaker.config == expected_config @mock.patch.object(SageMakerHook, "get_conn") @mock.patch.object(SageMakerHook, "list_training_jobs") - @mock.patch.object(SageMakerHook, "create_training_job") - def test_execute_with_existing_job_fail( - self, mock_create_training_job, mock_list_training_jobs, mock_client - ): + def test_check_if_job_exists_fail(self, mock_list_training_jobs, mock_client): + self.sagemaker.check_if_job_exists = True self.sagemaker.action_if_job_exists = "fail" - mock_create_training_job.return_value = {"ResponseMetadata": {"HTTPStatusCode": 200}} mock_list_training_jobs.return_value = [{"TrainingJobName": job_name}] with pytest.raises(AirflowException): - self.sagemaker.execute(None) + self.sagemaker._check_if_job_exists()