[AIRFLOW-2797] Create Google Dataproc cluster with custom image (apac…
exploy authored and galak75 committed Nov 23, 2018
1 parent cdba019 commit 2d0ab12
Showing 2 changed files with 48 additions and 0 deletions.
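For context, a minimal usage sketch of the new custom_image parameter (the DAG id, project, zone, and image names below are hypothetical placeholders; assumes the airflow.contrib import path used by this commit):

from datetime import datetime

from airflow import DAG
from airflow.contrib.operators.dataproc_operator import DataprocClusterCreateOperator

# Hypothetical DAG wiring; all identifiers here are placeholders.
with DAG(dag_id='example_dataproc_custom_image',
         start_date=datetime(2018, 11, 1),
         schedule_interval=None) as dag:
    create_cluster = DataprocClusterCreateOperator(
        task_id='create_dataproc_cluster',
        project_id='my-gcp-project',
        cluster_name='custom-image-cluster',
        num_workers=2,
        zone='us-central1-a',
        # Mutually exclusive with image_version (enforced by the new assert).
        custom_image='my-custom-dataproc-image',
    )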
14 changes: 14 additions & 0 deletions airflow/contrib/operators/dataproc_operator.py
@@ -68,6 +68,9 @@ class DataprocClusterCreateOperator(BaseOperator):
    :type metadata: dict
    :param image_version: the version of software inside the Dataproc cluster
    :type image_version: string
    :param custom_image: custom Dataproc image; for more info see
        https://cloud.google.com/dataproc/docs/guides/dataproc-images
    :type custom_image: string
    :param properties: dict of properties to set on
        config files (e.g. spark-defaults.conf), see
        https://cloud.google.com/dataproc/docs/reference/rest/v1/ \
@@ -148,6 +151,7 @@ def __init__(self,
                 init_actions_uris=None,
                 init_action_timeout="10m",
                 metadata=None,
                 custom_image=None,
                 image_version=None,
                 properties=None,
                 master_machine_type='n1-standard-4',
@@ -180,6 +184,7 @@ def __init__(self,
        self.init_actions_uris = init_actions_uris
        self.init_action_timeout = init_action_timeout
        self.metadata = metadata
        self.custom_image = custom_image
        self.image_version = image_version
        self.properties = properties
        self.master_machine_type = master_machine_type
@@ -201,6 +206,9 @@ def __init__(self,
        self.auto_delete_time = auto_delete_time
        self.auto_delete_ttl = auto_delete_ttl

        assert not (self.custom_image and self.image_version), \
            "custom_image and image_version can't be both set"

    def _get_cluster_list_for_project(self, service):
        result = service.projects().regions().clusters().list(
            projectId=self.project_id,
@@ -338,6 +346,12 @@ def _build_cluster_data(self):
            cluster_data['config']['gceClusterConfig']['tags'] = self.tags
        if self.image_version:
            cluster_data['config']['softwareConfig']['imageVersion'] = self.image_version
        elif self.custom_image:
            custom_image_url = 'https://www.googleapis.com/compute/beta/projects/' \
                               '{}/global/images/{}'.format(self.project_id,
                                                            self.custom_image)
            cluster_data['config']['masterConfig']['imageUri'] = custom_image_url
            cluster_data['config']['workerConfig']['imageUri'] = custom_image_url
        if self.properties:
            cluster_data['config']['softwareConfig']['properties'] = self.properties
        if self.idle_delete_ttl:
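For reference, with placeholder values project_id='my-gcp-project' and custom_image='my-image', the fragment of the generated cluster config affected by this branch would look like the sketch below (key names taken from the diff above; all other fields elided):

# Illustrative only: placeholder project/image names, surrounding fields omitted.
expected_fragment = {
    'config': {
        'masterConfig': {'imageUri': 'https://www.googleapis.com/compute/beta/'
                                     'projects/my-gcp-project/global/images/my-image'},
        'workerConfig': {'imageUri': 'https://www.googleapis.com/compute/beta/'
                                     'projects/my-gcp-project/global/images/my-image'},
    }
}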
34 changes: 34 additions & 0 deletions tests/contrib/operators/test_dataproc_operator.py
@@ -59,6 +59,7 @@
TAGS = ['tag1', 'tag2']
STORAGE_BUCKET = 'gs://airflow-test-bucket/'
IMAGE_VERSION = '1.1'
CUSTOM_IMAGE = 'test-custom-image'
MASTER_MACHINE_TYPE = 'n1-standard-2'
MASTER_DISK_SIZE = 100
MASTER_DISK_TYPE = 'pd-standard'
@@ -243,6 +244,39 @@ def test_build_cluster_data_with_autoDeleteTime_and_autoDeleteTtl(self):
        self.assertEqual(cluster_data['config']['lifecycleConfig']['autoDeleteTime'],
                         "2017-06-07T00:00:00.000000Z")

    def test_init_with_image_version_and_custom_image_both_set(self):
        with self.assertRaises(AssertionError):
            DataprocClusterCreateOperator(
                task_id=TASK_ID,
                cluster_name=CLUSTER_NAME,
                project_id=PROJECT_ID,
                num_workers=NUM_WORKERS,
                zone=ZONE,
                dag=self.dag,
                image_version=IMAGE_VERSION,
                custom_image=CUSTOM_IMAGE
            )

    def test_init_with_custom_image(self):
        dataproc_operator = DataprocClusterCreateOperator(
            task_id=TASK_ID,
            cluster_name=CLUSTER_NAME,
            project_id=PROJECT_ID,
            num_workers=NUM_WORKERS,
            zone=ZONE,
            dag=self.dag,
            custom_image=CUSTOM_IMAGE
        )

        cluster_data = dataproc_operator._build_cluster_data()
        expected_custom_image_url = \
            'https://www.googleapis.com/compute/beta/projects/' \
            '{}/global/images/{}'.format(PROJECT_ID, CUSTOM_IMAGE)
        self.assertEqual(cluster_data['config']['masterConfig']['imageUri'],
                         expected_custom_image_url)
        self.assertEqual(cluster_data['config']['workerConfig']['imageUri'],
                         expected_custom_image_url)

    def test_cluster_name_log_no_sub(self):
        with patch('airflow.contrib.operators.dataproc_operator.DataProcHook') \
                as mock_hook:
