Refactor hook and operator
moiseenkov committed Aug 19, 2024
1 parent 0e12dcc commit e5af5c4
Showing 4 changed files with 44 additions and 137 deletions.
78 changes: 3 additions & 75 deletions airflow/providers/google/cloud/hooks/dataflow.py
@@ -948,8 +948,7 @@ def launch_beam_yaml_job(
options: dict[str, Any] | None,
project_id: str,
location: str = DEFAULT_DATAFLOW_LOCATION,
on_new_job_callback: Callable[[dict], None] | None = None,
) -> dict[str, Any]:
) -> str:
"""
Launch a Dataflow YAML job and run it until completion.
@@ -966,7 +965,7 @@ def launch_beam_yaml_job(
:param project_id: The ID of the GCP project that owns the job.
:param location: Region ID of the job's regional endpoint. Defaults to 'us-central1'.
:param on_new_job_callback: Callback function that passes the job to the operator once known.
:return: Dictionary containing the job's data.
:return: Job ID.
"""
job_name = self.build_dataflow_job_name(job_name, append_job_name)
cmd = self._build_yaml_gcloud_command(
@@ -978,78 +977,7 @@ def launch_beam_yaml_job(
jinja_variables=jinja_variables,
)
job_id = self._create_dataflow_job_with_gcloud(cmd=cmd)
jobs_controller = _DataflowJobsController(
dataflow=self.get_conn(),
project_number=project_id,
job_id=job_id,
location=location,
poll_sleep=self.poll_sleep,
num_retries=self.num_retries,
drain_pipeline=self.drain_pipeline,
wait_until_finished=self.wait_until_finished,
expected_terminal_state=self.expected_terminal_state,
)
job = jobs_controller.fetch_job_by_id(job_id=job_id)

if on_new_job_callback:
on_new_job_callback(job)

jobs_controller.wait_for_done()

return jobs_controller.fetch_job_by_id(job_id=job_id)

@GoogleBaseHook.fallback_to_default_project_id
def launch_beam_yaml_job_deferrable(
self,
*,
job_name: str,
yaml_pipeline_file: str,
append_job_name: bool,
jinja_variables: dict[str, str] | None,
options: dict[str, Any] | None,
project_id: str,
location: str = DEFAULT_DATAFLOW_LOCATION,
) -> dict[str, Any]:
"""
Launch a Dataflow YAML job and exit without waiting for its completion.
:param job_name: The unique name to assign to the Cloud Dataflow job.
:param yaml_pipeline_file: Path to a file defining the YAML pipeline to run.
Must be a local file or a URL beginning with 'gs://'.
:param append_job_name: Set to True if a unique suffix has to be appended to the `job_name`.
:param jinja_variables: A dictionary of Jinja2 variables to be used in reifying the yaml pipeline file.
:param options: Additional gcloud or Beam job parameters.
It must be a dictionary with the keys matching the optional flag names in gcloud.
The list of supported flags can be found at: `https://cloud.google.com/sdk/gcloud/reference/dataflow/yaml/run`.
Note that if a flag does not require a value, then its dictionary value must be either True or None.
For example, the `--log-http` flag can be passed as {'log-http': True}.
:param project_id: The ID of the GCP project that owns the job.
:param location: Region ID of the job's regional endpoint. Defaults to 'us-central1'.
:return: Dictionary containing the job's data.
"""
job_name = self.build_dataflow_job_name(job_name=job_name, append_job_name=append_job_name)
cmd = self._build_yaml_gcloud_command(
job_name=job_name,
yaml_pipeline_file=yaml_pipeline_file,
project_id=project_id,
region=location,
jinja_variables=jinja_variables,
options=options,
)
job_id = self._create_dataflow_job_with_gcloud(cmd=cmd)
jobs_controller = _DataflowJobsController(
dataflow=self.get_conn(),
project_number=project_id,
job_id=job_id,
location=location,
poll_sleep=self.poll_sleep,
num_retries=self.num_retries,
drain_pipeline=self.drain_pipeline,
expected_terminal_state=self.expected_terminal_state,
cancel_timeout=self.cancel_timeout,
)

return jobs_controller.fetch_job_by_id(job_id=job_id)
return job_id

def _build_yaml_gcloud_command(
self,
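The net effect on the hook side: launch_beam_yaml_job now only submits the job through gcloud and returns its ID as a string, and the separate launch_beam_yaml_job_deferrable method is removed, so waiting and fetching the final job resource are left to the caller. A minimal sketch of the new calling pattern, with placeholder project, bucket, and job names (the keyword arguments and the wait_for_done/get_job helpers are the ones used elsewhere in this diff):

from airflow.providers.google.cloud.hooks.dataflow import DataflowHook

hook = DataflowHook(gcp_conn_id="google_cloud_default")

# Submit the YAML pipeline; the refactored method returns the job ID as soon
# as the gcloud command has created the job, instead of a full job dict.
job_id = hook.launch_beam_yaml_job(
    job_name="example-yaml-job",                              # placeholder
    yaml_pipeline_file="gs://example-bucket/pipeline.yaml",   # placeholder
    append_job_name=True,
    jinja_variables=None,
    options=None,
    project_id="example-project",                             # placeholder
    location="us-central1",
)

# Polling no longer happens inside launch_beam_yaml_job; callers (such as the
# operator below) wait for completion and fetch the job themselves.
hook.wait_for_done(
    job_name="example-yaml-job",
    location="us-central1",
    project_id="example-project",
    job_id=job_id,
)
job = hook.get_job(job_id=job_id, location="us-central1", project_id="example-project")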
76 changes: 32 additions & 44 deletions airflow/providers/google/cloud/operators/dataflow.py
@@ -22,7 +22,6 @@
import copy
import re
import uuid
import warnings
from contextlib import ExitStack
from enum import Enum
from functools import cached_property
@@ -949,6 +948,11 @@ def on_kill(self) -> None:
)


@deprecated(
reason="DataflowStartSqlJobOperator is deprecated and will be removed after 31.01.2025. "
"Please use DataflowStartYamlJobOperator instead.",
category=AirflowProviderDeprecationWarning,
)
class DataflowStartSqlJobOperator(GoogleCloudBaseOperator):
"""
Starts Dataflow SQL query.
@@ -1024,12 +1028,6 @@ def __init__(
self.hook: DataflowHook | None = None

def execute(self, context: Context):
warnings.warn(
"DataflowStartSqlJobOperator is deprecated and will be removed after 31.01.2025. Please use DataflowStartYamlJobOperator instead.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)

self.hook = DataflowHook(
gcp_conn_id=self.gcp_conn_id,
drain_pipeline=self.drain_pipeline,
@@ -1158,51 +1156,41 @@ def __init__(
self.options = options
self.jinja_variables = jinja_variables
self.impersonation_chain = impersonation_chain
self.job: dict[str, Any] | None = None
self.job_id: str | None = None

def execute(self, context: Context) -> dict[str, Any]:
def set_current_job(current_job: dict[str, Any]):
self.job = current_job
DataflowJobLink.persist(self, context, self.project_id, self.region, self.job["id"])

if not self.deferrable:
job = self.hook.launch_beam_yaml_job(
job_name=self.job_name,
yaml_pipeline_file=self.yaml_pipeline_file,
append_job_name=self.append_job_name,
options=self.options,
jinja_variables=self.jinja_variables,
project_id=self.project_id,
location=self.region,
on_new_job_callback=set_current_job,
)
return job

self.job = self.hook.launch_beam_yaml_job_deferrable(
self.job_id = self.hook.launch_beam_yaml_job(
job_name=self.job_name,
yaml_pipeline_file=self.yaml_pipeline_file,
append_job_name=self.append_job_name,
jinja_variables=self.jinja_variables,
options=self.options,
jinja_variables=self.jinja_variables,
project_id=self.project_id,
location=self.region,
)

DataflowJobLink.persist(self, context, self.project_id, self.region, self.job["id"])
DataflowJobLink.persist(self, context, self.project_id, self.region, self.job_id)

self.defer(
trigger=DataflowStartYamlJobTrigger(
job_id=self.job["id"],
project_id=self.project_id,
location=self.region,
gcp_conn_id=self.gcp_conn_id,
poll_sleep=self.poll_sleep,
cancel_timeout=self.cancel_timeout,
expected_terminal_state=self.expected_terminal_state,
impersonation_chain=self.impersonation_chain,
),
method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME,
if self.deferrable:
self.defer(
trigger=DataflowStartYamlJobTrigger(
job_id=self.job_id,
project_id=self.project_id,
location=self.region,
gcp_conn_id=self.gcp_conn_id,
poll_sleep=self.poll_sleep,
cancel_timeout=self.cancel_timeout,
expected_terminal_state=self.expected_terminal_state,
impersonation_chain=self.impersonation_chain,
),
method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME,
)

self.hook.wait_for_done(
job_name=self.job_name, location=self.region, project_id=self.project_id, job_id=self.job_id
)
job = self.hook.get_job(job_id=self.job_id, location=self.region, project_id=self.project_id)
return job

def execute_complete(self, context: Context, event: dict) -> dict[str, Any]:
"""Execute after the trigger returns an event."""
@@ -1223,11 +1211,11 @@ def on_kill(self):
state.
"""
self.log.info("On kill called.")
if self.job:
if self.job_id:
self.hook.cancel_job(
job_id=self.job.get("id"),
project_id=self.job.get("projectId"),
location=self.job.get("location"),
job_id=self.job_id,
project_id=self.project_id,
location=self.region,
)

@cached_property
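For reference, a sketch of how the reworked operator would be used in a DAG; all values are placeholders, and the constructor argument names are inferred from the attributes the diff references (job_name, yaml_pipeline_file, region, deferrable, and so on) rather than quoted from the class signature:

from airflow.providers.google.cloud.operators.dataflow import DataflowStartYamlJobOperator

start_yaml_job = DataflowStartYamlJobOperator(
    task_id="start_dataflow_yaml_job",
    job_name="example-yaml-job",                              # placeholder
    yaml_pipeline_file="gs://example-bucket/pipeline.yaml",   # placeholder
    append_job_name=True,
    project_id="example-project",                             # placeholder
    region="us-central1",
    # deferrable=True defers to DataflowStartYamlJobTrigger as soon as the job
    # ID is known; deferrable=False waits in execute() via hook.wait_for_done.
    deferrable=False,
)

Either way the operator tracks only self.job_id, so on_kill cancels the job using the operator's own project_id and region instead of fields read from a cached job dictionary.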
1 change: 1 addition & 0 deletions tests/always/test_example_dags.py
@@ -47,6 +47,7 @@
# If the deprecation is postponed, the item should be added to this tuple,
# and a corresponding Issue should be created on GitHub.
"tests/system/providers/google/cloud/bigquery/example_bigquery_operations.py",
"tests/system/providers/google/cloud/dataflow/example_dataflow_sql.py",
"tests/system/providers/google/cloud/dataproc/example_dataproc_gke.py",
"tests/system/providers/google/cloud/datapipelines/example_datapipeline.py",
"tests/system/providers/google/cloud/gcs/example_gcs_sensor.py",
26 changes: 8 additions & 18 deletions tests/providers/google/cloud/operators/test_dataflow.py
@@ -716,15 +716,15 @@ def test_execute(self, mock_hook):
"DataflowStartSqlJobOperator is deprecated and will be removed after 31.01.2025. "
"Please use DataflowStartYamlJobOperator instead."
)
start_sql = DataflowStartSqlJobOperator(
task_id="start_sql_query",
job_name=TEST_SQL_JOB_NAME,
query=TEST_SQL_QUERY,
options=deepcopy(TEST_SQL_OPTIONS),
location=TEST_LOCATION,
do_xcom_push=True,
)
with pytest.warns(AirflowProviderDeprecationWarning, match=warning_msg):
start_sql = DataflowStartSqlJobOperator(
task_id="start_sql_query",
job_name=TEST_SQL_JOB_NAME,
query=TEST_SQL_QUERY,
options=deepcopy(TEST_SQL_OPTIONS),
location=TEST_LOCATION,
do_xcom_push=True,
)
start_sql.execute(mock.MagicMock())

mock_hook.assert_called_once_with(
@@ -794,7 +794,6 @@ def test_execute(self, mock_hook, sync_operator):
jinja_variables=None,
project_id=TEST_PROJECT,
location=TEST_LOCATION,
on_new_job_callback=mock.ANY,
)

@mock.patch(f"{DATAFLOW_PATH}.DataflowStartYamlJobOperator.defer")
@@ -809,15 +808,6 @@ def test_execute_with_deferrable_mode(self, mock_hook, mock_defer_method, deferr
expected_terminal_state=DataflowJobStatus.JOB_STATE_RUNNING,
gcp_conn_id=GCP_CONN_ID,
)
mock_hook.return_value.launch_beam_yaml_job_deferrable.assert_called_once_with(
job_name=deferrable_operator.job_name,
yaml_pipeline_file=deferrable_operator.yaml_pipeline_file,
append_job_name=False,
options=None,
jinja_variables=None,
project_id=TEST_PROJECT,
location=TEST_LOCATION,
)
mock_defer_method.assert_called_once()

@mock.patch(f"{DATAFLOW_PATH}.DataflowStartYamlJobOperator.xcom_push")
