Fix DataprocJobBaseOperator not compatible with TaskGroups #23791

Merged
7 changes: 4 additions & 3 deletions airflow/providers/google/cloud/hooks/dataproc.py
@@ -58,7 +58,7 @@ def __init__(
job_type: str,
properties: Optional[Dict[str, str]] = None,
) -> None:
name = task_id + "_" + str(uuid.uuid4())[:8]
name = f"{task_id.replace('.', '_')}_{uuid.uuid4()!s:.8}"
self.job_type = job_type
self.job = {
"job": {
@@ -175,11 +175,12 @@ def set_python_main(self, main: str) -> None:

def set_job_name(self, name: str) -> None:
"""
Set Dataproc job name.
Set Dataproc job name. Job name is sanitized, replacing dots by underscores.

:param name: Job name.
"""
self.job["job"]["reference"]["job_id"] = name + "_" + str(uuid.uuid4())[:8]
sanitized_name = f"{name.replace('.', '_')}_{uuid.uuid4()!s:.8}"
self.job["job"]["reference"]["job_id"] = sanitized_name

def build(self) -> Dict:
"""
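For reference, here is a minimal standalone sketch of the sanitization the builder now applies when it derives a Dataproc job ID from an Airflow task ID. The helper name `dataproc_job_id` is hypothetical; in the hook the same expression is inlined in `DataProcJobBuilder.__init__` and `set_job_name`:

```python
import uuid


def dataproc_job_id(task_id: str) -> str:
    """Derive a Dataproc-safe job_id from an Airflow task_id (hypothetical helper).

    Inside a TaskGroup the effective task_id becomes "<group_id>.<task_id>".
    Dataproc job IDs do not accept dots, so they are replaced with underscores
    before the 8-character uuid4 suffix is appended.
    """
    return f"{task_id.replace('.', '_')}_{uuid.uuid4()!s:.8}"


# A task named "spark_task" defined inside a TaskGroup "group":
print(dataproc_job_id("group.spark_task"))  # e.g. "group_spark_task_1a2b3c4d"
```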
32 changes: 21 additions & 11 deletions tests/providers/google/cloud/hooks/test_dataproc.py
@@ -23,6 +23,7 @@
import pytest
from google.api_core.gapic_v1.method import DEFAULT
from google.cloud.dataproc_v1 import JobStatus
from parameterized import parameterized

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.dataproc import DataprocHook, DataProcJobBuilder
@@ -472,27 +473,28 @@ def setUp(self) -> None:
properties={"test": "test"},
)

@parameterized.expand([TASK_ID, f"group.{TASK_ID}"])
@mock.patch(DATAPROC_STRING.format("uuid.uuid4"))
def test_init(self, mock_uuid):
def test_init(self, job_name, mock_uuid):
mock_uuid.return_value = "uuid"
properties = {"test": "test"}
job = {
expected_job_id = f"{job_name}_{mock_uuid.return_value}".replace(".", "_")
expected_job = {
"job": {
"labels": {"airflow-version": AIRFLOW_VERSION},
"placement": {"cluster_name": CLUSTER_NAME},
"reference": {"job_id": TASK_ID + "_uuid", "project_id": GCP_PROJECT},
"reference": {"job_id": expected_job_id, "project_id": GCP_PROJECT},
"test": {"properties": properties},
}
}
builder = DataProcJobBuilder(
project_id=GCP_PROJECT,
task_id=TASK_ID,
task_id=job_name,
cluster_name=CLUSTER_NAME,
job_type="test",
properties=properties,
)

assert job == builder.job
assert expected_job == builder.job

def test_add_labels(self):
labels = {"key": "value"}
@@ -559,14 +561,22 @@ def test_set_python_main(self):
self.builder.set_python_main(main)
assert main == self.builder.job["job"][self.job_type]["main_python_file_uri"]

@parameterized.expand(
[
("simple", "name"),
("name with underscores", "name_with_dash"),
("name with dot", "group.name"),
("name with dot and underscores", "group.name_with_dash"),
]
)
@mock.patch(DATAPROC_STRING.format("uuid.uuid4"))
def test_set_job_name(self, mock_uuid):
def test_set_job_name(self, name, job_name, mock_uuid):
uuid = "test_uuid"
expected_job_name = f"{job_name}_{uuid[:8]}".replace(".", "_")
mock_uuid.return_value = uuid
name = "name"
self.builder.set_job_name(name)
name += "_" + uuid[:8]
assert name == self.builder.job["job"]["reference"]["job_id"]
self.builder.set_job_name(job_name)
assert expected_job_name == self.builder.job["job"]["reference"]["job_id"]
assert len(self.builder.job["job"]["reference"]["job_id"]) == len(job_name) + 9

def test_build(self):
assert self.builder.job == self.builder.build()
35 changes: 27 additions & 8 deletions tests/providers/google/cloud/operators/test_dataproc.py
@@ -1204,8 +1204,9 @@ class TestDataProcHiveOperator(unittest.TestCase):
query = "define sin HiveUDF('sin');"
variables = {"key": "value"}
job_id = "uuid_id"
job_name = "simple"
job = {
"reference": {"project_id": GCP_PROJECT, "job_id": "{{task.task_id}}_{{ds_nodash}}_" + job_id},
"reference": {"project_id": GCP_PROJECT, "job_id": f"{job_name}_{job_id}"},
"placement": {"cluster_name": "cluster-1"},
"labels": {"airflow-version": AIRFLOW_VERSION},
"hive_job": {"query_list": {"queries": [query]}, "script_variables": variables},
@@ -1226,6 +1227,7 @@ def test_execute(self, mock_hook, mock_uuid):
mock_hook.return_value.submit_job.return_value.reference.job_id = self.job_id

op = DataprocSubmitHiveJobOperator(
job_name=self.job_name,
task_id=TASK_ID,
region=GCP_LOCATION,
gcp_conn_id=GCP_CONN_ID,
@@ -1249,6 +1251,7 @@ def test_builder(self, mock_hook, mock_uuid):
mock_uuid.return_value = self.job_id

op = DataprocSubmitHiveJobOperator(
job_name=self.job_name,
task_id=TASK_ID,
region=GCP_LOCATION,
gcp_conn_id=GCP_CONN_ID,
@@ -1263,8 +1266,9 @@ class TestDataProcPigOperator(unittest.TestCase):
query = "define sin HiveUDF('sin');"
variables = {"key": "value"}
job_id = "uuid_id"
job_name = "simple"
job = {
"reference": {"project_id": GCP_PROJECT, "job_id": "{{task.task_id}}_{{ds_nodash}}_" + job_id},
"reference": {"project_id": GCP_PROJECT, "job_id": f"{job_name}_{job_id}"},
"placement": {"cluster_name": "cluster-1"},
"labels": {"airflow-version": AIRFLOW_VERSION},
"pig_job": {"query_list": {"queries": [query]}, "script_variables": variables},
@@ -1285,6 +1289,7 @@ def test_execute(self, mock_hook, mock_uuid):
mock_hook.return_value.submit_job.return_value.reference.job_id = self.job_id

op = DataprocSubmitPigJobOperator(
job_name=self.job_name,
task_id=TASK_ID,
region=GCP_LOCATION,
gcp_conn_id=GCP_CONN_ID,
@@ -1308,6 +1313,7 @@ def test_builder(self, mock_hook, mock_uuid):
mock_uuid.return_value = self.job_id

op = DataprocSubmitPigJobOperator(
job_name=self.job_name,
task_id=TASK_ID,
region=GCP_LOCATION,
gcp_conn_id=GCP_CONN_ID,
@@ -1321,15 +1327,16 @@ def test_builder(self, mock_hook, mock_uuid):
class TestDataProcSparkSqlOperator(unittest.TestCase):
query = "SHOW DATABASES;"
variables = {"key": "value"}
job_name = "simple"
job_id = "uuid_id"
job = {
"reference": {"project_id": GCP_PROJECT, "job_id": "{{task.task_id}}_{{ds_nodash}}_" + job_id},
"reference": {"project_id": GCP_PROJECT, "job_id": f"{job_name}_{job_id}"},
"placement": {"cluster_name": "cluster-1"},
"labels": {"airflow-version": AIRFLOW_VERSION},
"spark_sql_job": {"query_list": {"queries": [query]}, "script_variables": variables},
}
other_project_job = {
"reference": {"project_id": "other-project", "job_id": "{{task.task_id}}_{{ds_nodash}}_" + job_id},
"reference": {"project_id": "other-project", "job_id": f"{job_name}_{job_id}"},
"placement": {"cluster_name": "cluster-1"},
"labels": {"airflow-version": AIRFLOW_VERSION},
"spark_sql_job": {"query_list": {"queries": [query]}, "script_variables": variables},
@@ -1350,6 +1357,7 @@ def test_execute(self, mock_hook, mock_uuid):
mock_hook.return_value.submit_job.return_value.reference.job_id = self.job_id

op = DataprocSubmitSparkSqlJobOperator(
job_name=self.job_name,
task_id=TASK_ID,
region=GCP_LOCATION,
gcp_conn_id=GCP_CONN_ID,
@@ -1375,6 +1383,7 @@ def test_execute_override_project_id(self, mock_hook, mock_uuid):
mock_hook.return_value.submit_job.return_value.reference.job_id = self.job_id

op = DataprocSubmitSparkSqlJobOperator(
job_name=self.job_name,
project_id="other-project",
task_id=TASK_ID,
region=GCP_LOCATION,
@@ -1399,6 +1408,7 @@ def test_builder(self, mock_hook, mock_uuid):
mock_uuid.return_value = self.job_id

op = DataprocSubmitSparkSqlJobOperator(
job_name=self.job_name,
task_id=TASK_ID,
region=GCP_LOCATION,
gcp_conn_id=GCP_CONN_ID,
@@ -1412,10 +1422,11 @@ def test_builder(self, mock_hook, mock_uuid):
class TestDataProcSparkOperator(DataprocJobTestBase):
main_class = "org.apache.spark.examples.SparkPi"
jars = ["file:///usr/lib/spark/examples/jars/spark-examples.jar"]
job_name = "simple"
job = {
"reference": {
"project_id": GCP_PROJECT,
"job_id": "{{task.task_id}}_{{ds_nodash}}_" + TEST_JOB_ID,
"job_id": f"{job_name}_{TEST_JOB_ID}",
},
"placement": {"cluster_name": "cluster-1"},
"labels": {"airflow-version": AIRFLOW_VERSION},
@@ -1440,6 +1451,7 @@ def test_execute(self, mock_hook, mock_uuid):
self.extra_links_manager_mock.attach_mock(mock_hook, 'hook')

op = DataprocSubmitSparkJobOperator(
job_name=self.job_name,
task_id=TASK_ID,
region=GCP_LOCATION,
gcp_conn_id=GCP_CONN_ID,
@@ -1505,9 +1517,10 @@ def test_submit_spark_job_operator_extra_links(mock_hook, dag_maker, create_task
class TestDataProcHadoopOperator(unittest.TestCase):
args = ["wordcount", "gs://pub/shakespeare/rose.txt"]
jar = "file:///usr/lib/spark/examples/jars/spark-examples.jar"
job_name = "simple"
job_id = "uuid_id"
job = {
"reference": {"project_id": GCP_PROJECT, "job_id": "{{task.task_id}}_{{ds_nodash}}_" + job_id},
"reference": {"project_id": GCP_PROJECT, "job_id": f"{job_name}_{job_id}"},
"placement": {"cluster_name": "cluster-1"},
"labels": {"airflow-version": AIRFLOW_VERSION},
"hadoop_job": {"main_jar_file_uri": jar, "args": args},
@@ -1529,6 +1542,7 @@ def test_execute(self, mock_hook, mock_uuid):
mock_uuid.return_value = self.job_id

op = DataprocSubmitHadoopJobOperator(
job_name=self.job_name,
task_id=TASK_ID,
region=GCP_LOCATION,
gcp_conn_id=GCP_CONN_ID,
@@ -1542,8 +1556,9 @@ def test_execute(self, mock_hook, mock_uuid):
class TestDataProcPySparkOperator(unittest.TestCase):
uri = "gs://{}/{}"
job_id = "uuid_id"
job_name = "simple"
job = {
"reference": {"project_id": GCP_PROJECT, "job_id": "{{task.task_id}}_{{ds_nodash}}_" + job_id},
"reference": {"project_id": GCP_PROJECT, "job_id": f"{job_name}_{job_id}"},
"placement": {"cluster_name": "cluster-1"},
"labels": {"airflow-version": AIRFLOW_VERSION},
"pyspark_job": {"main_python_file_uri": uri},
@@ -1562,7 +1577,11 @@ def test_execute(self, mock_hook, mock_uuid):
mock_uuid.return_value = self.job_id

op = DataprocSubmitPySparkJobOperator(
task_id=TASK_ID, region=GCP_LOCATION, gcp_conn_id=GCP_CONN_ID, main=self.uri
job_name=self.job_name,
task_id=TASK_ID,
region=GCP_LOCATION,
gcp_conn_id=GCP_CONN_ID,
main=self.uri,
)
job = op.generate_job()
assert self.job == job
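As a usage note, the scenario this change enables looks roughly like the DAG sketch below. The DAG id, bucket URI, cluster name and region are placeholders; the operator arguments are the usual ones for the legacy Dataproc job operators, and the dotted task_id "etl.wordcount" produced by the TaskGroup is what the builder now sanitizes:

```python
from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.dataproc import DataprocSubmitPySparkJobOperator
from airflow.utils.task_group import TaskGroup

with DAG(
    dag_id="dataproc_taskgroup_example",  # placeholder DAG id
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
) as dag:
    with TaskGroup(group_id="etl") as etl:
        # The effective task_id is "etl.wordcount"; the job builder replaces the
        # dot with an underscore when it generates the Dataproc job_id.
        wordcount = DataprocSubmitPySparkJobOperator(
            task_id="wordcount",
            main="gs://my-bucket/wordcount.py",  # placeholder script URI
            cluster_name="cluster-1",
            region="europe-west1",
        )
```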