Skip to content

Commit

Permalink
[AIRFLOW-5729] Make InputDataConfig optional in Sagemaker's training …
Browse files Browse the repository at this point in the history
…config (#6398)

* [AIRFLOW-5729] Make InputDataConfig optional in Sagemaker's training config

* Added test checking training config without InputDataConfig

(cherry picked from commit bdc5836)
  • Loading branch information
BasPH authored and ashb committed Dec 19, 2019
1 parent b6b2d0d commit fa0e23b
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
5 changes: 3 additions & 2 deletions airflow/contrib/hooks/sagemaker_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,9 @@ def check_training_config(self, training_config):
:type training_config: dict
:return: None
"""
for channel in training_config['InputDataConfig']:
self.check_s3_url(channel['DataSource']['S3DataSource']['S3Uri'])
if "InputDataConfig" in training_config:
for channel in training_config['InputDataConfig']:
self.check_s3_url(channel['DataSource']['S3DataSource']['S3Uri'])

def check_tuning_config(self, tuning_config):
"""
Expand Down
5 changes: 5 additions & 0 deletions tests/contrib/hooks/test_sagemaker_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,11 @@ def test_check_valid_training(self, mock_check_url, mock_client):
hook.check_training_config(create_training_params)
mock_check_url.assert_called_once_with(data_url)

# InputDataConfig is optional, verify if check succeeds without InputDataConfig
create_training_params_no_inputdataconfig = create_training_params.copy()
create_training_params_no_inputdataconfig.pop("InputDataConfig")
hook.check_training_config(create_training_params_no_inputdataconfig)

@mock.patch.object(SageMakerHook, 'get_conn')
@mock.patch.object(SageMakerHook, 'check_s3_url')
def test_check_valid_tuning(self, mock_check_url, mock_client):
Expand Down

0 comments on commit fa0e23b

Please sign in to comment.