From 2d0ab12707bf654ea78fa0645fff91ee69d41282 Mon Sep 17 00:00:00 2001
From: Jarek
Date: Wed, 12 Sep 2018 10:29:12 +0200
Subject: [PATCH] [AIRFLOW-2797] Create Google Dataproc cluster with custom image (#3871)

---
 .../contrib/operators/dataproc_operator.py | 14 ++++++++
 .../operators/test_dataproc_operator.py    | 34 +++++++++++++++++++
 2 files changed, 48 insertions(+)

diff --git a/airflow/contrib/operators/dataproc_operator.py b/airflow/contrib/operators/dataproc_operator.py
index 957a1b89e6b074..2638b60e516bd5 100644
--- a/airflow/contrib/operators/dataproc_operator.py
+++ b/airflow/contrib/operators/dataproc_operator.py
@@ -68,6 +68,9 @@ class DataprocClusterCreateOperator(BaseOperator):
     :type metadata: dict
     :param image_version: the version of software inside the Dataproc cluster
     :type image_version: string
+    :param custom_image: custom Dataproc image; for more info see
+        https://cloud.google.com/dataproc/docs/guides/dataproc-images
+    :type custom_image: string
     :param properties: dict of properties to set on
         config files (e.g. spark-defaults.conf), see
         https://cloud.google.com/dataproc/docs/reference/rest/v1/ \
@@ -148,6 +151,7 @@ def __init__(self,
                  init_actions_uris=None,
                  init_action_timeout="10m",
                  metadata=None,
+                 custom_image=None,
                  image_version=None,
                  properties=None,
                  master_machine_type='n1-standard-4',
@@ -180,6 +184,7 @@ def __init__(self,
         self.init_actions_uris = init_actions_uris
         self.init_action_timeout = init_action_timeout
         self.metadata = metadata
+        self.custom_image = custom_image
         self.image_version = image_version
         self.properties = properties
         self.master_machine_type = master_machine_type
@@ -201,6 +206,9 @@ def __init__(self,
         self.auto_delete_time = auto_delete_time
         self.auto_delete_ttl = auto_delete_ttl

+        assert not (self.custom_image and self.image_version), \
+            "custom_image and image_version can't both be set"
+
     def _get_cluster_list_for_project(self, service):
         result = service.projects().regions().clusters().list(
             projectId=self.project_id,
@@ -338,6 +346,12 @@ def _build_cluster_data(self):
             cluster_data['config']['gceClusterConfig']['tags'] = self.tags
         if self.image_version:
             cluster_data['config']['softwareConfig']['imageVersion'] = self.image_version
+        elif self.custom_image:
+            custom_image_url = 'https://www.googleapis.com/compute/beta/projects/' \
+                               '{}/global/images/{}'.format(self.project_id,
+                                                            self.custom_image)
+            cluster_data['config']['masterConfig']['imageUri'] = custom_image_url
+            cluster_data['config']['workerConfig']['imageUri'] = custom_image_url
         if self.properties:
             cluster_data['config']['softwareConfig']['properties'] = self.properties
         if self.idle_delete_ttl:
diff --git a/tests/contrib/operators/test_dataproc_operator.py b/tests/contrib/operators/test_dataproc_operator.py
index ab191f3c327609..bb1e0dfb7eb11f 100644
--- a/tests/contrib/operators/test_dataproc_operator.py
+++ b/tests/contrib/operators/test_dataproc_operator.py
@@ -59,6 +59,7 @@
 TAGS = ['tag1', 'tag2']
 STORAGE_BUCKET = 'gs://airflow-test-bucket/'
 IMAGE_VERSION = '1.1'
+CUSTOM_IMAGE = 'test-custom-image'
 MASTER_MACHINE_TYPE = 'n1-standard-2'
 MASTER_DISK_SIZE = 100
 MASTER_DISK_TYPE = 'pd-standard'
@@ -243,6 +244,39 @@ def test_build_cluster_data_with_autoDeleteTime_and_autoDeleteTtl(self):
         self.assertEqual(cluster_data['config']['lifecycleConfig']['autoDeleteTime'],
                          "2017-06-07T00:00:00.000000Z")

+    def test_init_with_image_version_and_custom_image_both_set(self):
+        with self.assertRaises(AssertionError):
+            DataprocClusterCreateOperator(
+                task_id=TASK_ID,
+                cluster_name=CLUSTER_NAME,
+                project_id=PROJECT_ID,
+                num_workers=NUM_WORKERS,
+                zone=ZONE,
+                dag=self.dag,
+                image_version=IMAGE_VERSION,
+                custom_image=CUSTOM_IMAGE
+            )
+
+    def test_init_with_custom_image(self):
+        dataproc_operator = DataprocClusterCreateOperator(
+            task_id=TASK_ID,
+            cluster_name=CLUSTER_NAME,
+            project_id=PROJECT_ID,
+            num_workers=NUM_WORKERS,
+            zone=ZONE,
+            dag=self.dag,
+            custom_image=CUSTOM_IMAGE
+        )
+
+        cluster_data = dataproc_operator._build_cluster_data()
+        expected_custom_image_url = \
+            'https://www.googleapis.com/compute/beta/projects/' \
+            '{}/global/images/{}'.format(PROJECT_ID, CUSTOM_IMAGE)
+        self.assertEqual(cluster_data['config']['masterConfig']['imageUri'],
+                         expected_custom_image_url)
+        self.assertEqual(cluster_data['config']['workerConfig']['imageUri'],
+                         expected_custom_image_url)
+
     def test_cluster_name_log_no_sub(self):
         with patch('airflow.contrib.operators.dataproc_operator.DataProcHook') \
                 as mock_hook:
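
Usage note (illustrative, not part of the patch): once this change is applied,
the new custom_image parameter can be passed to DataprocClusterCreateOperator
in place of image_version. The sketch below assumes hypothetical project,
cluster, zone, and image names; only the operator, its parameters, and the
URL expansion behavior come from the patch itself.

    from datetime import datetime

    from airflow import DAG
    from airflow.contrib.operators.dataproc_operator import (
        DataprocClusterCreateOperator,
    )

    # Minimal DAG wiring; dag_id and start_date are arbitrary placeholders.
    dag = DAG(
        dag_id='dataproc_custom_image_example',
        start_date=datetime(2018, 9, 1),
        schedule_interval=None,
    )

    create_cluster = DataprocClusterCreateOperator(
        task_id='create_dataproc_cluster',
        project_id='example-project',    # hypothetical GCP project
        cluster_name='example-cluster',  # hypothetical cluster name
        num_workers=2,
        zone='europe-west1-b',           # hypothetical zone
        # custom_image is expanded by _build_cluster_data() into
        # https://www.googleapis.com/compute/beta/projects/
        #     <project_id>/global/images/<custom_image>
        # and set as imageUri on both masterConfig and workerConfig.
        # Passing image_version at the same time raises an AssertionError.
        custom_image='my-custom-dataproc-image',  # hypothetical image name
        dag=dag,
    )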