diff --git a/airflow/contrib/hooks/sagemaker_hook.py b/airflow/contrib/hooks/sagemaker_hook.py index 8b8e2e41e7678..09993f96d8738 100644 --- a/airflow/contrib/hooks/sagemaker_hook.py +++ b/airflow/contrib/hooks/sagemaker_hook.py @@ -59,8 +59,13 @@ def check_for_url(self, s3url): if not s3hook.check_for_bucket(bucket_name=bucket): raise AirflowException( "The input S3 Bucket {} does not exist ".format(bucket)) - if not s3hook.check_for_key(key=key, bucket_name=bucket): - raise AirflowException("The input S3 Key {} does not exist in the Bucket" + if key and not s3hook.check_for_key(key=key, bucket_name=bucket)\ + and not s3hook.check_for_prefix( + prefix=key, bucket_name=bucket, delimiter='/'): + # check if s3 key exists in the case user provides a single file + # or if s3 prefix exists in the case user provides a prefix for files + raise AirflowException("The input S3 Key " + "or Prefix {} does not exist in the Bucket {}" .format(s3url, bucket)) return True @@ -196,11 +201,13 @@ def create_training_job(self, training_job_config, wait_for_completion=True): training_job_config['TrainingJobName']) return response - def create_tuning_job(self, tuning_job_config): + def create_tuning_job(self, tuning_job_config, wait_for_completion=True): """ Create a tuning job :param tuning_job_config: the config for tuning :type tuning_job_config: dict + :param wait_for_completion: if the program should keep running until job finishes + :param wait_for_completion: bool :return: A dict that contains ARN of the tuning job. 
""" if self.use_db_config: @@ -216,13 +223,20 @@ def create_tuning_job(self, tuning_job_config): self.check_valid_tuning_input(tuning_job_config) - return self.conn.create_hyper_parameter_tuning_job( + response = self.conn.create_hyper_parameter_tuning_job( **tuning_job_config) + if wait_for_completion: + self.check_status(['InProgress', 'Stopping', 'Stopped'], + ['Failed'], + 'HyperParameterTuningJobStatus', + self.describe_tuning_job, + tuning_job_config['HyperParameterTuningJobName']) + return response def describe_training_job(self, training_job_name): """ :param training_job_name: the name of the training job - :type train_job_name: string + :type training_job_name: string Return the training job info associated with the current job_name :return: A dict contains all the training job info """ diff --git a/airflow/contrib/operators/sagemaker_create_tuning_job_operator.py b/airflow/contrib/operators/sagemaker_create_tuning_job_operator.py new file mode 100644 index 0000000000000..0c40a9adc93f4 --- /dev/null +++ b/airflow/contrib/operators/sagemaker_create_tuning_job_operator.py @@ -0,0 +1,121 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
class SageMakerCreateTuningJobOperator(BaseOperator):
    """
    Initiate a SageMaker hyperparameter tuning job.

    :meth:`execute` returns the response dict handed back by
    ``SageMakerHook.create_tuning_job`` and raises
    :class:`AirflowException` when that response does not carry
    HTTP status code 200.

    :param sagemaker_conn_id: The SageMaker connection ID to use.
    :type sagemaker_conn_id: string
    :param region_name: The AWS region_name
    :type region_name: string
    :param tuning_job_config:
        The configuration necessary to start a tuning job (templated).
        ``execute`` reads its ``HyperParameterTuningJobName`` key.
    :type tuning_job_config: dict
    :param use_db_config: Whether or not to use db config
        associated with sagemaker_conn_id.
        If set to true, will automatically update the tuning config
        with what's in db, so the db config doesn't need to
        include everything, but what's there does replace the ones
        in the tuning_job_config, so be careful
    :type use_db_config: bool
    :param wait_for_completion: if the operator should block
        until the tuning job finishes
    :type wait_for_completion: bool
    :param check_interval: if wait is set to be true, this is the time interval
        at which the operator will check the status of the tuning job
    :type check_interval: int
    :param max_ingestion_time: if wait is set to be true, the operator will fail
        if the tuning job hasn't finished within the max_ingestion_time
        (Caution: be careful when setting this parameter because tuning
        can take very long)
    :type max_ingestion_time: int

    Any further positional or keyword arguments are forwarded untouched to
    ``BaseOperator.__init__``; none of them is passed on to the hook.

    **Example**:
        The following operator would start a tuning job when executed

         sagemaker_tuning =
            SageMakerCreateTuningJobOperator(
                task_id='sagemaker_tuning',
                sagemaker_conn_id='sagemaker_customers_conn',
                tuning_job_config=config,
                check_interval=2,
                max_ingestion_time=3600,
            )
    """

    template_fields = ['tuning_job_config']
    template_ext = ()
    ui_color = '#ededed'

    @apply_defaults
    def __init__(self,
                 sagemaker_conn_id=None,
                 region_name=None,
                 tuning_job_config=None,
                 use_db_config=False,
                 wait_for_completion=True,
                 check_interval=5,
                 max_ingestion_time=None,
                 *args, **kwargs):
        super(SageMakerCreateTuningJobOperator, self)\
            .__init__(*args, **kwargs)

        self.sagemaker_conn_id = sagemaker_conn_id
        self.region_name = region_name
        self.tuning_job_config = tuning_job_config
        self.use_db_config = use_db_config
        self.wait_for_completion = wait_for_completion
        self.check_interval = check_interval
        self.max_ingestion_time = max_ingestion_time

    def execute(self, context):
        """
        Create the tuning job through a freshly built ``SageMakerHook``.

        :param context: the task context supplied by Airflow (unused here)
        :return: the response dict returned by the hook
        :raises AirflowException: if the response's HTTP status code
            is not 200
        """
        sagemaker = SageMakerHook(
            sagemaker_conn_id=self.sagemaker_conn_id,
            region_name=self.region_name,
            use_db_config=self.use_db_config,
            check_interval=self.check_interval,
            max_ingestion_time=self.max_ingestion_time
        )

        # Pass the job name as a logging argument so the message is only
        # formatted when the record is actually emitted.
        self.log.info(
            "Creating SageMaker Hyper Parameter Tuning Job %s",
            self.tuning_job_config['HyperParameterTuningJobName']
        )

        response = sagemaker.create_tuning_job(
            self.tuning_job_config,
            wait_for_completion=self.wait_for_completion
        )
        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
            raise AirflowException(
                "Sagemaker Tuning Job creation failed: %s" % response)
        return response
class SageMakerTuningSensor(SageMakerBaseSensor):
    """
    Asks for the state of the tuning job until it reaches a terminal state.

    This class only supplies the tuning-specific pieces (which states keep
    the sensor waiting, which count as failure, and how to read state and
    failure reason out of a describe response). The polling loop and the
    raising of ``AirflowException`` on failure are presumably implemented
    by ``SageMakerBaseSensor`` -- verify against that class.

    :param job_name: job_name of the tuning instance to check the state of
    :type job_name: string
    :param region_name: The AWS region_name
    :type region_name: string
    """

    template_fields = ['job_name']
    template_ext = ()

    @apply_defaults
    def __init__(self,
                 job_name,
                 region_name=None,
                 *args,
                 **kwargs):
        super(SageMakerTuningSensor, self).__init__(*args, **kwargs)
        self.job_name = job_name
        self.region_name = region_name

    def non_terminal_states(self):
        """
        Return the job statuses for which the sensor should keep waiting.

        'Stopped' is deliberately absent: the SageMaker API documents it
        as a final status of a tuning job, so waiting on it would keep the
        sensor poking until it times out.
        """
        return ['InProgress', 'Stopping']

    def failed_states(self):
        """Return the job statuses that mean the tuning job failed."""
        return ['Failed']

    def get_sagemaker_response(self):
        """
        Describe the tuning job named ``self.job_name``.

        ``self.aws_conn_id`` is not set in this class; it is assumed to be
        provided by the base sensor -- TODO confirm.

        :return: whatever ``SageMakerHook.describe_tuning_job`` returns
        """
        sagemaker = SageMakerHook(
            aws_conn_id=self.aws_conn_id,
            region_name=self.region_name
        )

        self.log.info('Poking Sagemaker Tuning Job %s', self.job_name)
        return sagemaker.describe_tuning_job(self.job_name)

    def get_failed_reason_from_response(self, response):
        """Return the ``FailureReason`` entry of a describe response."""
        return response['FailureReason']

    def state_from_response(self, response):
        """Return the ``HyperParameterTuningJobStatus`` entry of a describe response."""
        return response['HyperParameterTuningJobStatus']
self.assertRaises(AirflowException, hook.check_for_url, data_url) self.assertRaises(AirflowException, hook.check_for_url, data_url) self.assertEqual(hook.check_for_url(data_url), True) + self.assertEqual(hook.check_for_url(data_url), True) @mock.patch.object(SageMakerHook, 'get_conn') @mock.patch.object(SageMakerHook, 'check_for_url') @@ -362,7 +368,8 @@ def test_create_tuning_job(self, mock_client, mock_check_tuning): mock_session.configure_mock(**attrs) mock_client.return_value = mock_session hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id') - response = hook.create_tuning_job(create_tuning_params) + response = hook.create_tuning_job(create_tuning_params, + wait_for_completion=False) mock_session.create_hyper_parameter_tuning_job.\ assert_called_once_with(**create_tuning_params) self.assertEqual(response, test_arn_return) @@ -378,7 +385,8 @@ def test_create_tuning_job_db_config(self, mock_client, mock_check_tuning): mock_client.return_value = mock_session hook = SageMakerHook(sagemaker_conn_id='sagemaker_test_conn_id', use_db_config=True) - response = hook.create_tuning_job(create_tuning_params) + response = hook.create_tuning_job(create_tuning_params, + wait_for_completion=False) updated_config = copy.deepcopy(create_tuning_params) updated_config.update(db_config) mock_session.create_hyper_parameter_tuning_job. \ diff --git a/tests/contrib/operators/test_sagemaker_create_tuning_job_operator.py b/tests/contrib/operators/test_sagemaker_create_tuning_job_operator.py new file mode 100644 index 0000000000000..d317cff6f2289 --- /dev/null +++ b/tests/contrib/operators/test_sagemaker_create_tuning_job_operator.py @@ -0,0 +1,165 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
role = 'test-role'

bucket = 'test-bucket'

key = 'test/data'
data_url = 's3://{}/{}'.format(bucket, key)

job_name = 'test-job-name'

image = 'test-image'

output_url = 's3://{}/test/output'.format(bucket)

# Tuning job configuration handed to the operator under test.
create_tuning_params = {
    'HyperParameterTuningJobName': job_name,
    'HyperParameterTuningJobConfig': {
        'Strategy': 'Bayesian',
        'HyperParameterTuningJobObjective': {
            'Type': 'Maximize',
            'MetricName': 'test_metric'
        },
        'ResourceLimits': {
            'MaxNumberOfTrainingJobs': 123,
            'MaxParallelTrainingJobs': 123
        },
        'ParameterRanges': {
            'IntegerParameterRanges': [
                {
                    'Name': 'k',
                    'MinValue': '2',
                    'MaxValue': '10'
                },
            ]
        }
    },
    'TrainingJobDefinition': {
        'StaticHyperParameters': {
            'k': '10',
            'feature_dim': '784',
            'mini_batch_size': '500',
            'force_dense': 'True'
        },
        'AlgorithmSpecification': {
            'TrainingImage': image,
            'TrainingInputMode': 'File'
        },
        'RoleArn': 'string',
        'InputDataConfig': [
            {
                'ChannelName': 'train',
                'DataSource': {
                    'S3DataSource': {
                        'S3DataType': 'S3Prefix',
                        'S3Uri': data_url,
                        'S3DataDistributionType': 'FullyReplicated'
                    }
                },
                'CompressionType': 'None',
                'RecordWrapperType': 'None'
            }
        ],
        'OutputDataConfig': {
            'S3OutputPath': output_url
        },
        'ResourceConfig': {
            'InstanceCount': 2,
            'InstanceType': 'ml.c4.8xlarge',
            'VolumeSizeInGB': 50
        },
        'StoppingCondition': dict(MaxRuntimeInSeconds=60 * 60)
    }
}


def _tuning_response(http_status):
    """Build a mocked ``create_tuning_job`` response with the given HTTP status.

    The ARN key mirrors the one the SageMaker create-tuning-job call
    returns, rather than the training-job key.
    """
    return {
        'HyperParameterTuningJobArn': 'testarn',
        'ResponseMetadata': {'HTTPStatusCode': http_status},
    }


class TestSageMakerTrainingOperator(unittest.TestCase):
    """Tests for ``SageMakerCreateTuningJobOperator``.

    The class name is kept as-is so existing test selectors keep working.
    """

    def setUp(self):
        configuration.load_test_config()
        self.sagemaker = SageMakerCreateTuningJobOperator(
            task_id='test_sagemaker_operator',
            sagemaker_conn_id='sagemaker_test_conn',
            tuning_job_config=create_tuning_params,
            region_name='us-east-1',
            use_db_config=False,
            wait_for_completion=False,
            check_interval=5
        )

    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_tuning_job')
    @mock.patch.object(SageMakerHook, '__init__')
    def test_hook_init(self, hook_init, mock_tuning, mock_client):
        """The operator builds the hook with exactly its own settings."""
        mock_tuning.return_value = _tuning_response(200)
        hook_init.return_value = None
        self.sagemaker.execute(None)
        hook_init.assert_called_once_with(
            sagemaker_conn_id='sagemaker_test_conn',
            region_name='us-east-1',
            use_db_config=False,
            check_interval=5,
            max_ingestion_time=None
        )

    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_tuning_job')
    def test_execute_without_failure(self, mock_tuning, mock_client):
        """A 200 response makes ``execute`` call the hook once and succeed."""
        mock_tuning.return_value = _tuning_response(200)
        self.sagemaker.execute(None)
        mock_tuning.assert_called_once_with(create_tuning_params,
                                            wait_for_completion=False)

    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'create_tuning_job')
    def test_execute_with_failure(self, mock_tuning, mock_client):
        """A non-200 response makes ``execute`` raise ``AirflowException``."""
        mock_tuning.return_value = _tuning_response(404)
        self.assertRaises(AirflowException, self.sagemaker.execute, None)


if __name__ == '__main__':
    unittest.main()
# Canned describe-tuning-job responses, one per job status used below.
DESCRIBE_TUNING_INPROGRESS_RETURN = {
    'HyperParameterTuningJobStatus': 'InProgress',
    'ResponseMetadata': {
        'HTTPStatusCode': 200,
    }
}
DESCRIBE_TUNING_COMPLETED_RETURN = {
    'HyperParameterTuningJobStatus': 'Completed',
    'ResponseMetadata': {
        'HTTPStatusCode': 200,
    }
}
DESCRIBE_TUNING_FAILED_RETURN = {
    'HyperParameterTuningJobStatus': 'Failed',
    'ResponseMetadata': {
        'HTTPStatusCode': 200,
    },
    'FailureReason': 'Unknown'
}
DESCRIBE_TUNING_STOPPING_RETURN = {
    'HyperParameterTuningJobStatus': 'Stopping',
    'ResponseMetadata': {
        'HTTPStatusCode': 200,
    }
}


class TestSageMakerTuningSensor(unittest.TestCase):
    """Tests for ``SageMakerTuningSensor``."""

    def setUp(self):
        configuration.load_test_config()

    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, 'describe_tuning_job')
    def test_raises_errors_failed_state(self, mock_describe_job, mock_client):
        """A 'Failed' status makes the sensor raise after a single describe call."""
        mock_describe_job.side_effect = [DESCRIBE_TUNING_FAILED_RETURN]
        sensor = SageMakerTuningSensor(
            task_id='test_task',
            poke_interval=2,
            aws_conn_id='aws_test',
            job_name='test_job_name'
        )
        self.assertRaises(AirflowException, sensor.execute, None)
        mock_describe_job.assert_called_once_with('test_job_name')

    @mock.patch.object(SageMakerHook, 'get_conn')
    @mock.patch.object(SageMakerHook, '__init__')
    @mock.patch.object(SageMakerHook, 'describe_tuning_job')
    def test_calls_until_a_terminal_state(self,
                                          mock_describe_job, hook_init, mock_client):
        """The sensor keeps describing the job until it reports 'Completed'."""
        hook_init.return_value = None

        # Only statuses that are unambiguously non-terminal ('InProgress',
        # 'Stopping') precede the final 'Completed' response.
        mock_describe_job.side_effect = [
            DESCRIBE_TUNING_INPROGRESS_RETURN,
            DESCRIBE_TUNING_STOPPING_RETURN,
            DESCRIBE_TUNING_COMPLETED_RETURN
        ]
        sensor = SageMakerTuningSensor(
            task_id='test_task',
            poke_interval=2,
            aws_conn_id='aws_test',
            job_name='test_job_name',
            region_name='us-east-1'
        )

        sensor.execute(None)

        # make sure we called 3 times (terminated when it's completed)
        self.assertEqual(mock_describe_job.call_count, 3)

        # make sure the hook was initialized with the specific params
        hook_init.assert_called_with(aws_conn_id='aws_test',
                                     region_name='us-east-1')


if __name__ == '__main__':
    unittest.main()