Skip to content

Commit

Permalink
[AIRFLOW-5045] Add ability to create Google Dataproc cluster with custom image from a different project (#5752)
Browse files Browse the repository at this point in the history
  • Loading branch information
idralyuk authored and potiuk committed Aug 8, 2019
1 parent e07e304 commit cb31d08
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
8 changes: 7 additions & 1 deletion airflow/contrib/operators/dataproc_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ class DataprocClusterCreateOperator(DataprocOperationBaseOperator):
:param custom_image: custom Dataproc image for more info see
https://cloud.google.com/dataproc/docs/guides/dataproc-images
:type custom_image: str
:param custom_image_project_id: project id for the custom Dataproc image, for more info see
https://cloud.google.com/dataproc/docs/guides/dataproc-images
:type custom_image_project_id: str
:param autoscaling_policy: The autoscaling policy used by the cluster. Only resource names
including projectid and location (region) are valid. Example:
``projects/[projectId]/locations/[dataproc_region]/autoscalingPolicies/[policy_id]``
Expand Down Expand Up @@ -199,6 +202,7 @@ def __init__(self,
init_action_timeout="10m",
metadata=None,
custom_image=None,
custom_image_project_id=None,
image_version=None,
autoscaling_policy=None,
properties=None,
Expand Down Expand Up @@ -229,6 +233,7 @@ def __init__(self,
self.init_action_timeout = init_action_timeout
self.metadata = metadata
self.custom_image = custom_image
self.custom_image_project_id = custom_image_project_id
self.image_version = image_version
self.properties = properties or dict()
self.master_machine_type = master_machine_type
Expand Down Expand Up @@ -392,8 +397,9 @@ def _build_cluster_data(self):
cluster_data['config']['softwareConfig']['imageVersion'] = self.image_version

elif self.custom_image:
project_id = self.custom_image_project_id if (self.custom_image_project_id) else self.project_id
custom_image_url = 'https://www.googleapis.com/compute/beta/projects/' \
'{}/global/images/{}'.format(self.project_id,
'{}/global/images/{}'.format(project_id,
self.custom_image)
cluster_data['config']['masterConfig']['imageUri'] = custom_image_url
if not self.single_node:
Expand Down
22 changes: 22 additions & 0 deletions tests/contrib/operators/test_dataproc_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
STORAGE_BUCKET = 'gs://airflow-test-bucket/'
IMAGE_VERSION = '1.1'
CUSTOM_IMAGE = 'test-custom-image'
CUSTOM_IMAGE_PROJECT_ID = 'test-custom-image-project-id'
MASTER_MACHINE_TYPE = 'n1-standard-2'
MASTER_DISK_SIZE = 100
MASTER_DISK_TYPE = 'pd-standard'
Expand Down Expand Up @@ -324,6 +325,27 @@ def test_init_with_custom_image(self):
self.assertEqual(cluster_data['config']['workerConfig']['imageUri'],
expected_custom_image_url)

def test_init_with_custom_image_with_custom_image_project_id(self):
    """When custom_image_project_id is given, the generated image URI must
    reference that project (not the cluster's project_id) for both the
    master and worker node configs."""
    operator = DataprocClusterCreateOperator(
        task_id=TASK_ID,
        cluster_name=CLUSTER_NAME,
        project_id=GCP_PROJECT_ID,
        num_workers=NUM_WORKERS,
        zone=GCE_ZONE,
        dag=self.dag,
        custom_image=CUSTOM_IMAGE,
        custom_image_project_id=CUSTOM_IMAGE_PROJECT_ID
    )

    config = operator._build_cluster_data()['config']
    expected_url = (
        'https://www.googleapis.com/compute/beta/projects/'
        '{}/global/images/{}'.format(CUSTOM_IMAGE_PROJECT_ID, CUSTOM_IMAGE)
    )
    # Both node pools must boot from the same custom image.
    for node_config in ('masterConfig', 'workerConfig'):
        self.assertEqual(config[node_config]['imageUri'], expected_url)

def test_build_single_node_cluster(self):
dataproc_operator = DataprocClusterCreateOperator(
task_id=TASK_ID,
Expand Down

0 comments on commit cb31d08

Please sign in to comment.