From b2142dba21653850a6e64316658289b1bb379bfa Mon Sep 17 00:00:00 2001 From: Eugene Galan Date: Mon, 29 Jul 2024 15:00:32 +0000 Subject: [PATCH 1/2] Add DataflowStartYamlJobOperator --- .../providers/google/cloud/hooks/dataflow.py | 187 ++++++++++++++++- .../google/cloud/operators/dataflow.py | 194 +++++++++++++++++- .../google/cloud/triggers/dataflow.py | 146 ++++++++++++- .../operators/cloud/dataflow.rst | 32 +++ .../google/cloud/operators/test_dataflow.py | 107 +++++++++- .../google/cloud/triggers/test_dataflow.py | 141 ++++++++++++- .../cloud/dataflow/example_dataflow_yaml.py | 172 ++++++++++++++++ 7 files changed, 965 insertions(+), 14 deletions(-) create mode 100644 tests/system/providers/google/cloud/dataflow/example_dataflow_yaml.py diff --git a/airflow/providers/google/cloud/hooks/dataflow.py b/airflow/providers/google/cloud/hooks/dataflow.py index 8f7e2e25499f26..fccc171e09b40f 100644 --- a/airflow/providers/google/cloud/hooks/dataflow.py +++ b/airflow/providers/google/cloud/hooks/dataflow.py @@ -186,9 +186,9 @@ class DataflowJobType: class _DataflowJobsController(LoggingMixin): """ - Interface for communication with Google API. + Interface for communication with Google Cloud Dataflow API. - It's not use Apache Beam, but only Google Dataflow API. + Does not use Apache Beam API. :param dataflow: Discovery resource :param project_number: The Google Cloud Project ID. @@ -271,12 +271,12 @@ def _get_current_jobs(self) -> list[dict]: else: raise ValueError("Missing both dataflow job ID and name.") - def fetch_job_by_id(self, job_id: str) -> dict: + def fetch_job_by_id(self, job_id: str) -> dict[str, str]: """ Fetch the job with the specified Job ID. - :param job_id: Job ID to get. - :return: the Job + :param job_id: ID of the job that needs to be fetched. + :return: Dictionary containing the Job's data """ return ( self._dataflow.projects() @@ -444,7 +444,6 @@ def _check_dataflow_job_state(self, job) -> bool: "Google Cloud Dataflow job's expected terminal state cannot be " "JOB_STATE_DRAINED while it is a batch job" ) - if current_state == current_expected_state: if current_expected_state == DataflowJobStatus.JOB_STATE_RUNNING: return not self._wait_until_finished @@ -938,6 +937,182 @@ def launch_job_with_flex_template( response: dict = request.execute(num_retries=self.num_retries) return response["job"] + @GoogleBaseHook.fallback_to_default_project_id + def launch_beam_yaml_job( + self, + *, + job_name: str, + yaml_pipeline_file: str, + append_job_name: bool, + jinja_variables: dict[str, str] | None, + options: dict[str, Any] | None, + project_id: str, + location: str = DEFAULT_DATAFLOW_LOCATION, + on_new_job_callback: Callable[[dict], None] | None = None, + ) -> dict[str, Any]: + """ + Launch a Dataflow YAML job and run it until completion. + + :param job_name: The unique name to assign to the Cloud Dataflow job. + :param yaml_pipeline_file: Path to a file defining the YAML pipeline to run. + Must be a local file or a URL beginning with 'gs://'. + :param append_job_name: Set to True if a unique suffix has to be appended to the `job_name`. + :param jinja_variables: A dictionary of Jinja2 variables to be used in reifying the yaml pipeline file. + :param options: Additional gcloud or Beam job parameters. + It must be a dictionary with the keys matching the optional flag names in gcloud. + The list of supported flags can be found at: `https://cloud.google.com/sdk/gcloud/reference/dataflow/yaml/run`. 
+ Note that if a flag does not require a value, then its dictionary value must be either True or None. + For example, the `--log-http` flag can be passed as {'log-http': True}. + :param project_id: The ID of the GCP project that owns the job. + :param location: Region ID of the job's regional endpoint. Defaults to 'us-central1'. + :param on_new_job_callback: Callback function that passes the job to the operator once known. + :return: Dictionary containing the job's data. + """ + job_name = self.build_dataflow_job_name(job_name, append_job_name) + cmd = self._build_yaml_gcloud_command( + job_name=job_name, + yaml_pipeline_file=yaml_pipeline_file, + project_id=project_id, + region=location, + options=options, + jinja_variables=jinja_variables, + ) + job_id = self._create_dataflow_job_with_gcloud(cmd=cmd) + jobs_controller = _DataflowJobsController( + dataflow=self.get_conn(), + project_number=project_id, + job_id=job_id, + location=location, + poll_sleep=self.poll_sleep, + num_retries=self.num_retries, + drain_pipeline=self.drain_pipeline, + wait_until_finished=self.wait_until_finished, + expected_terminal_state=self.expected_terminal_state, + ) + job = jobs_controller.fetch_job_by_id(job_id=job_id) + + if on_new_job_callback: + on_new_job_callback(job) + + jobs_controller.wait_for_done() + + return jobs_controller.fetch_job_by_id(job_id=job_id) + + @GoogleBaseHook.fallback_to_default_project_id + def launch_beam_yaml_job_deferrable( + self, + *, + job_name: str, + yaml_pipeline_file: str, + append_job_name: bool, + jinja_variables: dict[str, str] | None, + options: dict[str, Any] | None, + project_id: str, + location: str = DEFAULT_DATAFLOW_LOCATION, + ) -> dict[str, Any]: + """ + Launch a Dataflow YAML job and exit without waiting for its completion. + + :param job_name: The unique name to assign to the Cloud Dataflow job. + :param yaml_pipeline_file: Path to a file defining the YAML pipeline to run. + Must be a local file or a URL beginning with 'gs://'. + :param append_job_name: Set to True if a unique suffix has to be appended to the `job_name`. + :param jinja_variables: A dictionary of Jinja2 variables to be used in reifying the yaml pipeline file. + :param options: Additional gcloud or Beam job parameters. + It must be a dictionary with the keys matching the optional flag names in gcloud. + The list of supported flags can be found at: `https://cloud.google.com/sdk/gcloud/reference/dataflow/yaml/run`. + Note that if a flag does not require a value, then its dictionary value must be either True or None. + For example, the `--log-http` flag can be passed as {'log-http': True}. + :param project_id: The ID of the GCP project that owns the job. + :param location: Region ID of the job's regional endpoint. Defaults to 'us-central1'. + :return: Dictionary containing the job's data. 
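+
+        Illustrative sketch of the ``gcloud`` command this method assembles under the hood; the
+        bucket, project, and job name below are placeholders, and any additional flags from
+        ``options`` and ``jinja_variables`` are appended to it::
+
+            gcloud dataflow yaml run <job-name> \
+                --yaml-pipeline-file=gs://<bucket>/pipeline.yaml \
+                --project=<project-id> --region=us-central1 --format="value(job.id)"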
+ """ + job_name = self.build_dataflow_job_name(job_name=job_name, append_job_name=append_job_name) + cmd = self._build_yaml_gcloud_command( + job_name=job_name, + yaml_pipeline_file=yaml_pipeline_file, + project_id=project_id, + region=location, + jinja_variables=jinja_variables, + options=options, + ) + job_id = self._create_dataflow_job_with_gcloud(cmd=cmd) + jobs_controller = _DataflowJobsController( + dataflow=self.get_conn(), + project_number=project_id, + job_id=job_id, + location=location, + poll_sleep=self.poll_sleep, + num_retries=self.num_retries, + drain_pipeline=self.drain_pipeline, + expected_terminal_state=self.expected_terminal_state, + cancel_timeout=self.cancel_timeout, + ) + + return jobs_controller.fetch_job_by_id(job_id=job_id) + + def _build_yaml_gcloud_command( + self, + job_name: str, + yaml_pipeline_file: str, + project_id: str, + region: str, + options: dict[str, Any] | None, + jinja_variables: dict[str, str] | None, + ) -> list[str]: + gcp_flags = { + "yaml-pipeline-file": yaml_pipeline_file, + "project": project_id, + "format": "value(job.id)", + "region": region, + } + + if jinja_variables: + gcp_flags["jinja-variables"] = json.dumps(jinja_variables) + + if options: + gcp_flags.update(options) + + if self.impersonation_chain: + if isinstance(self.impersonation_chain, str): + impersonation_account = self.impersonation_chain + elif len(self.impersonation_chain) == 1: + impersonation_account = self.impersonation_chain[0] + else: + raise AirflowException( + "Chained list of accounts is not supported, please specify only one service account." + ) + gcp_flags.update({"impersonate-service-account": impersonation_account}) + + return [ + "gcloud", + "dataflow", + "yaml", + "run", + job_name, + *(beam_options_to_args(gcp_flags)), + ] + + def _create_dataflow_job_with_gcloud(self, cmd: list[str]) -> str: + """Create a Dataflow job with a gcloud command and return the job's ID.""" + self.log.info("Executing command: %s", " ".join(shlex.quote(c) for c in cmd)) + success_code = 0 + + with self.provide_authorized_gcloud(): + proc = subprocess.run(cmd, capture_output=True) + + if proc.returncode != success_code: + stderr_last_20_lines = "\n".join(proc.stderr.decode().strip().splitlines()[-20:]) + raise AirflowException( + f"Process exit with non-zero exit code. Exit code: {proc.returncode}. 
Error Details : " + f"{stderr_last_20_lines}" + ) + + job_id = proc.stdout.decode().strip() + self.log.info("Created job's ID: %s", job_id) + + return job_id + @staticmethod def extract_job_id(job: dict) -> str: try: diff --git a/airflow/providers/google/cloud/operators/dataflow.py b/airflow/providers/google/cloud/operators/dataflow.py index 625356d50e0adc..4d47017f889d44 100644 --- a/airflow/providers/google/cloud/operators/dataflow.py +++ b/airflow/providers/google/cloud/operators/dataflow.py @@ -22,6 +22,7 @@ import copy import re import uuid +import warnings from contextlib import ExitStack from enum import Enum from functools import cached_property @@ -40,7 +41,10 @@ from airflow.providers.google.cloud.hooks.gcs import GCSHook from airflow.providers.google.cloud.links.dataflow import DataflowJobLink, DataflowPipelineLink from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator -from airflow.providers.google.cloud.triggers.dataflow import TemplateJobStartTrigger +from airflow.providers.google.cloud.triggers.dataflow import ( + DataflowStartYamlJobTrigger, + TemplateJobStartTrigger, +) from airflow.providers.google.common.consts import GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME from airflow.providers.google.common.deprecated import deprecated from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID @@ -1021,6 +1025,12 @@ def __init__( self.hook: DataflowHook | None = None def execute(self, context: Context): + warnings.warn( + "DataflowStartSqlJobOperator is deprecated and will be removed after 31.01.2025. Please use DataflowStartYamlJobOperator instead.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + self.hook = DataflowHook( gcp_conn_id=self.gcp_conn_id, drain_pipeline=self.drain_pipeline, @@ -1051,6 +1061,188 @@ def on_kill(self) -> None: ) +class DataflowStartYamlJobOperator(GoogleCloudBaseOperator): + """ + Launch a Dataflow YAML job and return the result. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:DataflowStartYamlJobOperator` + + .. warning:: + This operator requires ``gcloud`` command (Google Cloud SDK) must be installed on the Airflow worker + `__ + + :param job_name: Required. The unique name to assign to the Cloud Dataflow job. + :param yaml_pipeline_file: Required. Path to a file defining the YAML pipeline to run. + Must be a local file or a URL beginning with 'gs://'. + :param region: Optional. Region ID of the job's regional endpoint. Defaults to 'us-central1'. + :param project_id: Required. The ID of the GCP project that owns the job. + If set to ``None`` or missing, the default project_id from the GCP connection is used. + :param gcp_conn_id: Optional. The connection ID used to connect to GCP. + :param append_job_name: Optional. Set to True if a unique suffix has to be appended to the `job_name`. + Defaults to True. + :param drain_pipeline: Optional. Set to True if you want to stop a streaming pipeline job by draining it + instead of canceling when killing the task instance. Note that this does not work for batch pipeline jobs + or in the deferrable mode. Defaults to False. + For more info see: https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline + :param deferrable: Optional. Run operator in the deferrable mode. + :param expected_terminal_state: Optional. The expected terminal state of the Dataflow job at which the + operator task is set to succeed. 
Defaults to 'JOB_STATE_DONE' for the batch jobs and 'JOB_STATE_RUNNING' + for the streaming jobs. + :param poll_sleep: Optional. The time in seconds to sleep between polling Google Cloud Platform for the Dataflow job status. + Used both for the sync and deferrable mode. + :param cancel_timeout: Optional. How long (in seconds) operator should wait for the pipeline to be + successfully canceled when the task is being killed. + :param jinja_variables: Optional. A dictionary of Jinja2 variables to be used in reifying the yaml pipeline file. + :param options: Optional. Additional gcloud or Beam job parameters. + It must be a dictionary with the keys matching the optional flag names in gcloud. + The list of supported flags can be found at: `https://cloud.google.com/sdk/gcloud/reference/dataflow/yaml/run`. + Note that if a flag does not require a value, then its dictionary value must be either True or None. + For example, the `--log-http` flag can be passed as {'log-http': True}. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :return: Dictionary containing the job's data. + """ + + template_fields: Sequence[str] = ( + "job_name", + "yaml_pipeline_file", + "jinja_variables", + "options", + "region", + "project_id", + "gcp_conn_id", + ) + template_fields_renderers = { + "jinja_variables": "json", + } + operator_extra_links = (DataflowJobLink(),) + + def __init__( + self, + *, + job_name: str, + yaml_pipeline_file: str, + region: str = DEFAULT_DATAFLOW_LOCATION, + project_id: str = PROVIDE_PROJECT_ID, + gcp_conn_id: str = "google_cloud_default", + append_job_name: bool = True, + drain_pipeline: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + poll_sleep: int = 10, + cancel_timeout: int | None = 5 * 60, + expected_terminal_state: str | None = None, + jinja_variables: dict[str, str] | None = None, + options: dict[str, Any] | None = None, + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.job_name = job_name + self.yaml_pipeline_file = yaml_pipeline_file + self.region = region + self.project_id = project_id + self.gcp_conn_id = gcp_conn_id + self.append_job_name = append_job_name + self.drain_pipeline = drain_pipeline + self.deferrable = deferrable + self.poll_sleep = poll_sleep + self.cancel_timeout = cancel_timeout + self.expected_terminal_state = expected_terminal_state + self.options = options + self.jinja_variables = jinja_variables + self.impersonation_chain = impersonation_chain + self.job: dict[str, Any] | None = None + + def execute(self, context: Context) -> dict[str, Any]: + def set_current_job(current_job: dict[str, Any]): + self.job = current_job + DataflowJobLink.persist(self, context, self.project_id, self.region, self.job["id"]) + + if not self.deferrable: + job = self.hook.launch_beam_yaml_job( + job_name=self.job_name, + yaml_pipeline_file=self.yaml_pipeline_file, + append_job_name=self.append_job_name, + options=self.options, 
+ jinja_variables=self.jinja_variables, + project_id=self.project_id, + location=self.region, + on_new_job_callback=set_current_job, + ) + return job + + self.job = self.hook.launch_beam_yaml_job_deferrable( + job_name=self.job_name, + yaml_pipeline_file=self.yaml_pipeline_file, + append_job_name=self.append_job_name, + jinja_variables=self.jinja_variables, + options=self.options, + project_id=self.project_id, + location=self.region, + ) + + DataflowJobLink.persist(self, context, self.project_id, self.region, self.job["id"]) + + self.defer( + trigger=DataflowStartYamlJobTrigger( + job_id=self.job["id"], + project_id=self.project_id, + location=self.region, + gcp_conn_id=self.gcp_conn_id, + poll_sleep=self.poll_sleep, + cancel_timeout=self.cancel_timeout, + expected_terminal_state=self.expected_terminal_state, + impersonation_chain=self.impersonation_chain, + ), + method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME, + ) + + def execute_complete(self, context: Context, event: dict) -> dict[str, Any]: + """Execute after the trigger returns an event.""" + if event["status"] in ("error", "stopped"): + self.log.info("status: %s, msg: %s", event["status"], event["message"]) + raise AirflowException(event["message"]) + job = event["job"] + self.log.info("Job %s completed with response %s", job["id"], event["message"]) + self.xcom_push(context, key="job_id", value=job["id"]) + + return job + + def on_kill(self): + """ + Cancel the dataflow job if a task instance gets killed. + + This method will not be called if a task instance is killed in a deferred + state. + """ + self.log.info("On kill called.") + if self.job: + self.hook.cancel_job( + job_id=self.job.get("id"), + project_id=self.job.get("projectId"), + location=self.job.get("location"), + ) + + @cached_property + def hook(self) -> DataflowHook: + return DataflowHook( + gcp_conn_id=self.gcp_conn_id, + poll_sleep=self.poll_sleep, + impersonation_chain=self.impersonation_chain, + drain_pipeline=self.drain_pipeline, + cancel_timeout=self.cancel_timeout, + expected_terminal_state=self.expected_terminal_state, + ) + + # TODO: Remove one day @deprecated( planned_removal_date="November 01, 2024", diff --git a/airflow/providers/google/cloud/triggers/dataflow.py b/airflow/providers/google/cloud/triggers/dataflow.py index 01f96b98b78e0f..4b994bf6e4d7d8 100644 --- a/airflow/providers/google/cloud/triggers/dataflow.py +++ b/airflow/providers/google/cloud/triggers/dataflow.py @@ -24,8 +24,10 @@ from google.cloud.dataflow_v1beta3 import JobState from google.cloud.dataflow_v1beta3.types import ( AutoscalingEvent, + Job, JobMessage, JobMetrics, + JobType, MetricUpdate, ) @@ -157,7 +159,7 @@ def _get_async_hook(self) -> AsyncDataflowHook: class DataflowJobStatusTrigger(BaseTrigger): """ - Trigger that checks for metrics associated with a Dataflow job. + Trigger that monitors if a Dataflow job has reached any of the expected statuses. :param job_id: Required. ID of the job. :param expected_statuses: The expected state(s) of the operation. @@ -266,6 +268,148 @@ def async_hook(self) -> AsyncDataflowHook: ) +class DataflowStartYamlJobTrigger(BaseTrigger): + """ + Dataflow trigger that checks the state of a Dataflow YAML job. + + :param job_id: Required. ID of the job. + :param project_id: Required. The Google Cloud project ID in which the job was started. + :param location: The location where job is executed. If set to None then + the value of DEFAULT_DATAFLOW_LOCATION will be used. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. 
+ :param poll_sleep: Optional. The time in seconds to sleep between polling Google Cloud Platform + for the Dataflow job. + :param cancel_timeout: Optional. How long (in seconds) operator should wait for the pipeline to be + successfully cancelled when task is being killed. + :param expected_terminal_state: Optional. The expected terminal state of the Dataflow job at which the + operator task is set to succeed. Defaults to 'JOB_STATE_DONE' for the batch jobs and + 'JOB_STATE_RUNNING' for the streaming jobs. + :param impersonation_chain: Optional. Service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + def __init__( + self, + job_id: str, + project_id: str | None, + location: str = DEFAULT_DATAFLOW_LOCATION, + gcp_conn_id: str = "google_cloud_default", + poll_sleep: int = 10, + cancel_timeout: int | None = 5 * 60, + expected_terminal_state: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, + ): + super().__init__() + self.project_id = project_id + self.job_id = job_id + self.location = location + self.gcp_conn_id = gcp_conn_id + self.poll_sleep = poll_sleep + self.cancel_timeout = cancel_timeout + self.expected_terminal_state = expected_terminal_state + self.impersonation_chain = impersonation_chain + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize class arguments and classpath.""" + return ( + "airflow.providers.google.cloud.triggers.dataflow.DataflowStartYamlJobTrigger", + { + "project_id": self.project_id, + "job_id": self.job_id, + "location": self.location, + "gcp_conn_id": self.gcp_conn_id, + "poll_sleep": self.poll_sleep, + "expected_terminal_state": self.expected_terminal_state, + "impersonation_chain": self.impersonation_chain, + "cancel_timeout": self.cancel_timeout, + }, + ) + + async def run(self): + """ + Fetch job and yield events depending on the job's type and state. + + Yield TriggerEvent if the job reaches a terminal state. + Otherwise awaits for a specified amount of time stored in self.poll_sleep variable. 
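+
+        Each yielded TriggerEvent carries a payload of the form (illustrative)::
+
+            {"job": <Job as dict> | None, "status": "success" | "error" | "stopped", "message": "..."}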
+ """ + hook: AsyncDataflowHook = self._get_async_hook() + try: + while True: + job: Job = await hook.get_job( + job_id=self.job_id, + project_id=self.project_id, + location=self.location, + ) + job_state = job.current_state + job_type = job.type_ + if job_state.name == self.expected_terminal_state: + yield TriggerEvent( + { + "job": Job.to_dict(job), + "status": "success", + "message": f"Job reached the expected terminal state: {self.expected_terminal_state}.", + } + ) + return + elif job_type == JobType.JOB_TYPE_STREAMING and job_state == JobState.JOB_STATE_RUNNING: + yield TriggerEvent( + { + "job": Job.to_dict(job), + "status": "success", + "message": "Streaming job reached the RUNNING state.", + } + ) + return + elif job_type == JobType.JOB_TYPE_BATCH and job_state == JobState.JOB_STATE_DONE: + yield TriggerEvent( + { + "job": Job.to_dict(job), + "status": "success", + "message": "Batch job completed.", + } + ) + return + elif job_state == JobState.JOB_STATE_FAILED: + yield TriggerEvent( + { + "job": Job.to_dict(job), + "status": "error", + "message": "Job failed.", + } + ) + return + elif job_state == JobState.JOB_STATE_STOPPED: + yield TriggerEvent( + { + "job": Job.to_dict(job), + "status": "stopped", + "message": "Job was stopped.", + } + ) + return + else: + self.log.info("Current job status is: %s", job_state.name) + self.log.info("Sleeping for %s seconds.", self.poll_sleep) + await asyncio.sleep(self.poll_sleep) + except Exception as e: + self.log.exception("Exception occurred while checking for job completion.") + yield TriggerEvent({"job": None, "status": "error", "message": str(e)}) + + def _get_async_hook(self) -> AsyncDataflowHook: + return AsyncDataflowHook( + gcp_conn_id=self.gcp_conn_id, + poll_sleep=self.poll_sleep, + impersonation_chain=self.impersonation_chain, + cancel_timeout=self.cancel_timeout, + ) + + class DataflowJobMetricsTrigger(BaseTrigger): """ Trigger that checks for metrics associated with a Dataflow job. diff --git a/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst b/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst index 71fc3275fe99f4..a9eb98ea9a5097 100644 --- a/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst +++ b/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst @@ -306,6 +306,38 @@ Here is an example of running Dataflow SQL job with See the `Dataflow SQL reference `_. +.. _howto/operator:DataflowStartYamlJobOperator: + +Dataflow YAML +"""""""""""""" +Beam YAML is a no-code SDK for configuring Apache Beam pipelines by using YAML files. +You can use Beam YAML to author and run a Beam pipeline without writing any code. +This API can be used to define both streaming and batch pipelines. + +Here is an example of running Dataflow YAML job with +:class:`~airflow.providers.google.cloud.operators.dataflow.DataflowStartYamlJobOperator`: + +.. exampleinclude:: /../../tests/system/providers/google/cloud/dataflow/example_dataflow_yaml.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_dataflow_start_yaml_job] + :end-before: [END howto_operator_dataflow_start_yaml_job] + +This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter. + +.. exampleinclude:: /../../tests/system/providers/google/cloud/dataflow/example_dataflow_yaml.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_dataflow_start_yaml_job_def] + :end-before: [END howto_operator_dataflow_start_yaml_job_def] + +.. 
warning:: + This operator requires ``gcloud`` command (Google Cloud SDK) must be installed on the Airflow worker + `__ + +See the `Dataflow YAML reference +`_. + .. _howto/operator:DataflowStopJobOperator: Stopping a pipeline diff --git a/tests/providers/google/cloud/operators/test_dataflow.py b/tests/providers/google/cloud/operators/test_dataflow.py index a024ffd7a7503a..2e60a2567fd109 100644 --- a/tests/providers/google/cloud/operators/test_dataflow.py +++ b/tests/providers/google/cloud/operators/test_dataflow.py @@ -41,6 +41,7 @@ DataflowRunPipelineOperator, DataflowStartFlexTemplateOperator, DataflowStartSqlJobOperator, + DataflowStartYamlJobOperator, DataflowStopJobOperator, DataflowTemplatedJobStartOperator, ) @@ -711,6 +712,10 @@ def test_execute_with_deferrable_mode(self, mock_hook, mock_defer_method, deferr class TestDataflowStartSqlJobOperator: @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook") def test_execute(self, mock_hook): + warning_msg = ( + "DataflowStartSqlJobOperator is deprecated and will be removed after 31.01.2025. " + "Please use DataflowStartYamlJobOperator instead." + ) start_sql = DataflowStartSqlJobOperator( task_id="start_sql_query", job_name=TEST_SQL_JOB_NAME, @@ -719,8 +724,9 @@ def test_execute(self, mock_hook): location=TEST_LOCATION, do_xcom_push=True, ) + with pytest.warns(AirflowProviderDeprecationWarning, match=warning_msg): + start_sql.execute(mock.MagicMock()) - start_sql.execute(mock.MagicMock()) mock_hook.assert_called_once_with( gcp_conn_id="google_cloud_default", drain_pipeline=False, @@ -741,6 +747,105 @@ def test_execute(self, mock_hook): ) +class TestDataflowStartYamlJobOperator: + @pytest.fixture + def sync_operator(self): + return DataflowStartYamlJobOperator( + task_id="start_dataflow_yaml_job_sync", + job_name="dataflow_yaml_job", + yaml_pipeline_file="test_file_path", + append_job_name=False, + project_id=TEST_PROJECT, + region=TEST_LOCATION, + gcp_conn_id=GCP_CONN_ID, + expected_terminal_state=DataflowJobStatus.JOB_STATE_DONE, + ) + + @pytest.fixture + def deferrable_operator(self): + return DataflowStartYamlJobOperator( + task_id="start_dataflow_yaml_job_def", + job_name="dataflow_yaml_job", + yaml_pipeline_file="test_file_path", + append_job_name=False, + project_id=TEST_PROJECT, + region=TEST_LOCATION, + gcp_conn_id=GCP_CONN_ID, + deferrable=True, + expected_terminal_state=DataflowJobStatus.JOB_STATE_RUNNING, + ) + + @mock.patch(f"{DATAFLOW_PATH}.DataflowHook") + def test_execute(self, mock_hook, sync_operator): + sync_operator.execute(mock.MagicMock()) + mock_hook.assert_called_once_with( + poll_sleep=sync_operator.poll_sleep, + drain_pipeline=False, + impersonation_chain=None, + cancel_timeout=sync_operator.cancel_timeout, + expected_terminal_state=DataflowJobStatus.JOB_STATE_DONE, + gcp_conn_id=GCP_CONN_ID, + ) + mock_hook.return_value.launch_beam_yaml_job.assert_called_once_with( + job_name=sync_operator.job_name, + yaml_pipeline_file=sync_operator.yaml_pipeline_file, + append_job_name=False, + options=None, + jinja_variables=None, + project_id=TEST_PROJECT, + location=TEST_LOCATION, + on_new_job_callback=mock.ANY, + ) + + @mock.patch(f"{DATAFLOW_PATH}.DataflowStartYamlJobOperator.defer") + @mock.patch(f"{DATAFLOW_PATH}.DataflowHook") + def test_execute_with_deferrable_mode(self, mock_hook, mock_defer_method, deferrable_operator): + deferrable_operator.execute(mock.MagicMock()) + mock_hook.assert_called_once_with( + poll_sleep=deferrable_operator.poll_sleep, + drain_pipeline=False, + 
impersonation_chain=None, + cancel_timeout=deferrable_operator.cancel_timeout, + expected_terminal_state=DataflowJobStatus.JOB_STATE_RUNNING, + gcp_conn_id=GCP_CONN_ID, + ) + mock_hook.return_value.launch_beam_yaml_job_deferrable.assert_called_once_with( + job_name=deferrable_operator.job_name, + yaml_pipeline_file=deferrable_operator.yaml_pipeline_file, + append_job_name=False, + options=None, + jinja_variables=None, + project_id=TEST_PROJECT, + location=TEST_LOCATION, + ) + mock_defer_method.assert_called_once() + + @mock.patch(f"{DATAFLOW_PATH}.DataflowStartYamlJobOperator.xcom_push") + @mock.patch(f"{DATAFLOW_PATH}.DataflowHook") + def test_execute_complete_success(self, mock_hook, mock_xcom_push, deferrable_operator): + expected_result = {"id": JOB_ID} + actual_result = deferrable_operator.execute_complete( + context=None, + event={ + "status": "success", + "message": "Batch job completed.", + "job": expected_result, + }, + ) + mock_xcom_push.assert_called_with(None, key="job_id", value=JOB_ID) + assert actual_result == expected_result + + def test_execute_complete_error_status_raises_exception(self, deferrable_operator): + with pytest.raises(AirflowException, match="Job failed."): + deferrable_operator.execute_complete( + context=None, event={"status": "error", "message": "Job failed."} + ) + with pytest.raises(AirflowException, match="Job was stopped."): + deferrable_operator.execute_complete( + context=None, event={"status": "stopped", "message": "Job was stopped."} + ) + + class TestDataflowStopJobOperator: @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook") def test_exec_job_id(self, dataflow_mock): diff --git a/tests/providers/google/cloud/triggers/test_dataflow.py b/tests/providers/google/cloud/triggers/test_dataflow.py index 2b9b63afa44e31..67b2f41009c25a 100644 --- a/tests/providers/google/cloud/triggers/test_dataflow.py +++ b/tests/providers/google/cloud/triggers/test_dataflow.py @@ -22,7 +22,7 @@ from unittest import mock import pytest -from google.cloud.dataflow_v1beta3 import JobState +from google.cloud.dataflow_v1beta3 import Job, JobState, JobType from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus from airflow.providers.google.cloud.triggers.dataflow import ( @@ -30,6 +30,7 @@ DataflowJobMessagesTrigger, DataflowJobMetricsTrigger, DataflowJobStatusTrigger, + DataflowStartYamlJobTrigger, TemplateJobStartTrigger, ) from airflow.triggers.base import TriggerEvent @@ -108,6 +109,24 @@ def dataflow_job_status_trigger(): ) +@pytest.fixture +def dataflow_start_yaml_job_trigger(): + return DataflowStartYamlJobTrigger( + project_id=PROJECT_ID, + job_id=JOB_ID, + location=LOCATION, + gcp_conn_id=GCP_CONN_ID, + poll_sleep=POLL_SLEEP, + impersonation_chain=IMPERSONATION_CHAIN, + cancel_timeout=CANCEL_TIMEOUT, + ) + + +@pytest.fixture +def test_dataflow_batch_job(): + return Job(id=JOB_ID, current_state=JobState.JOB_STATE_DONE, type_=JobType.JOB_TYPE_BATCH) + + class TestTemplateJobStartTrigger: def test_serialize(self, template_job_start_trigger): actual_data = template_job_start_trigger.serialize() @@ -548,13 +567,11 @@ async def test_run_loop_is_still_running_if_fail_on_terminal_state( mock_get_job_metrics, mock_job_status, dataflow_job_metrics_trigger, - caplog, ): """Test that DataflowJobMetricsTrigger is still in loop if the job status is RUNNING.""" dataflow_job_metrics_trigger.fail_on_terminal_state = True mock_job_status.return_value = JobState.JOB_STATE_RUNNING mock_get_job_metrics.return_value = [] - 
caplog.set_level(logging.INFO) task = asyncio.create_task(dataflow_job_metrics_trigger.run().__anext__()) await asyncio.sleep(0.5) assert task.done() is False @@ -703,12 +720,10 @@ async def test_run_loop_is_still_running_if_state_is_not_terminal_or_expected( self, mock_job_status, dataflow_job_status_trigger, - caplog, ): """Test that DataflowJobStatusTrigger is still in loop if the job status neither terminal nor expected.""" dataflow_job_status_trigger.expected_statuses = {DataflowJobStatus.JOB_STATE_DONE} mock_job_status.return_value = JobState.JOB_STATE_RUNNING - caplog.set_level(logging.INFO) task = asyncio.create_task(dataflow_job_status_trigger.run().__anext__()) await asyncio.sleep(0.5) assert task.done() is False @@ -729,3 +744,119 @@ async def test_run_raises_exception(self, mock_job_status, dataflow_job_status_t ) actual_event = await dataflow_job_status_trigger.run().asend(None) assert expected_event == actual_event + + +class TestDataflowStartYamlJobTrigger: + def test_serialize(self, dataflow_start_yaml_job_trigger): + actual_data = dataflow_start_yaml_job_trigger.serialize() + expected_data = ( + "airflow.providers.google.cloud.triggers.dataflow.DataflowStartYamlJobTrigger", + { + "project_id": PROJECT_ID, + "job_id": JOB_ID, + "location": LOCATION, + "gcp_conn_id": GCP_CONN_ID, + "poll_sleep": POLL_SLEEP, + "expected_terminal_state": None, + "impersonation_chain": IMPERSONATION_CHAIN, + "cancel_timeout": CANCEL_TIMEOUT, + }, + ) + assert actual_data == expected_data + + @pytest.mark.parametrize( + "attr, expected", + [ + ("gcp_conn_id", GCP_CONN_ID), + ("poll_sleep", POLL_SLEEP), + ("impersonation_chain", IMPERSONATION_CHAIN), + ("cancel_timeout", CANCEL_TIMEOUT), + ], + ) + def test_get_async_hook(self, dataflow_start_yaml_job_trigger, attr, expected): + hook = dataflow_start_yaml_job_trigger._get_async_hook() + actual = hook._hook_kwargs.get(attr) + assert actual is not None + assert actual == expected + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job") + async def test_run_loop_return_success_event( + self, mock_get_job, dataflow_start_yaml_job_trigger, test_dataflow_batch_job + ): + mock_get_job.return_value = test_dataflow_batch_job + expected_event = TriggerEvent( + { + "job": Job.to_dict(test_dataflow_batch_job), + "status": "success", + "message": "Batch job completed.", + } + ) + actual_event = await dataflow_start_yaml_job_trigger.run().asend(None) + assert actual_event == expected_event + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job") + async def test_run_loop_return_failed_event( + self, mock_get_job, dataflow_start_yaml_job_trigger, test_dataflow_batch_job + ): + test_dataflow_batch_job.current_state = JobState.JOB_STATE_FAILED + mock_get_job.return_value = test_dataflow_batch_job + expected_event = TriggerEvent( + { + "job": Job.to_dict(test_dataflow_batch_job), + "status": "error", + "message": "Job failed.", + } + ) + actual_event = await dataflow_start_yaml_job_trigger.run().asend(None) + assert actual_event == expected_event + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job") + async def test_run_loop_return_stopped_event( + self, mock_get_job, dataflow_start_yaml_job_trigger, test_dataflow_batch_job + ): + test_dataflow_batch_job.current_state = JobState.JOB_STATE_STOPPED + mock_get_job.return_value = test_dataflow_batch_job + expected_event = TriggerEvent( + { + 
"job": Job.to_dict(test_dataflow_batch_job), + "status": "stopped", + "message": "Job was stopped.", + } + ) + actual_event = await dataflow_start_yaml_job_trigger.run().asend(None) + assert actual_event == expected_event + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job") + async def test_run_loop_return_expected_state_event( + self, mock_get_job, dataflow_start_yaml_job_trigger, test_dataflow_batch_job + ): + dataflow_start_yaml_job_trigger.expected_terminal_state = DataflowJobStatus.JOB_STATE_RUNNING + test_dataflow_batch_job.current_state = JobState.JOB_STATE_RUNNING + mock_get_job.return_value = test_dataflow_batch_job + expected_event = TriggerEvent( + { + "job": Job.to_dict(test_dataflow_batch_job), + "status": "success", + "message": f"Job reached the expected terminal state: {DataflowJobStatus.JOB_STATE_RUNNING}.", + } + ) + actual_event = await dataflow_start_yaml_job_trigger.run().asend(None) + assert actual_event == expected_event + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job") + async def test_run_loop_is_still_running( + self, mock_get_job, dataflow_start_yaml_job_trigger, test_dataflow_batch_job + ): + """Test that DataflowStartYamlJobTrigger is still in loop if the job status neither terminal nor expected.""" + dataflow_start_yaml_job_trigger.expected_terminal_state = DataflowJobStatus.JOB_STATE_STOPPED + test_dataflow_batch_job.current_state = JobState.JOB_STATE_RUNNING + mock_get_job.return_value = test_dataflow_batch_job + task = asyncio.create_task(dataflow_start_yaml_job_trigger.run().__anext__()) + await asyncio.sleep(0.5) + assert task.done() is False + task.cancel() diff --git a/tests/system/providers/google/cloud/dataflow/example_dataflow_yaml.py b/tests/system/providers/google/cloud/dataflow/example_dataflow_yaml.py new file mode 100644 index 00000000000000..2bddc1de3fa466 --- /dev/null +++ b/tests/system/providers/google/cloud/dataflow/example_dataflow_yaml.py @@ -0,0 +1,172 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG for Google Cloud Dataflow YAML service. 
+ +Requirements: + This test requires ``gcloud`` command (Google Cloud SDK) to be installed on the Airflow worker + `__ +""" + +from __future__ import annotations + +import os +from datetime import datetime + +from airflow.models.dag import DAG +from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus +from airflow.providers.google.cloud.operators.bigquery import ( + BigQueryCreateEmptyDatasetOperator, + BigQueryCreateEmptyTableOperator, + BigQueryDeleteDatasetOperator, + BigQueryInsertJobOperator, +) +from airflow.providers.google.cloud.operators.dataflow import DataflowStartYamlJobOperator +from airflow.utils.trigger_rule import TriggerRule +from tests.system.providers.google import DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID + +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") or DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default") +DAG_ID = "dataflow_yaml" +REGION = "europe-west2" +DATAFLOW_YAML_JOB_NAME = f"{DAG_ID}_{ENV_ID}".replace("_", "-") +BQ_DATASET = f"{DAG_ID}_{ENV_ID}".replace("-", "_") +BQ_INPUT_TABLE = f"input_{DAG_ID}".replace("-", "_") +BQ_OUTPUT_TABLE = f"output_{DAG_ID}".replace("-", "_") +DATAFLOW_YAML_PIPELINE_FILE_URL = ( + "gs://airflow-system-tests-resources/dataflow/yaml/example_beam_yaml_bq.yaml" +) + +BQ_VARIABLES = { + "project": PROJECT_ID, + "dataset": BQ_DATASET, + "input": BQ_INPUT_TABLE, + "output": BQ_OUTPUT_TABLE, +} + +BQ_VARIABLES_DEF = { + "project": PROJECT_ID, + "dataset": BQ_DATASET, + "input": BQ_INPUT_TABLE, + "output": f"{BQ_OUTPUT_TABLE}_def", +} + +INSERT_ROWS_QUERY = ( + f"INSERT {BQ_DATASET}.{BQ_INPUT_TABLE} VALUES " + "('John Doe', 900, 'USA'), " + "('Alice Storm', 1200, 'Australia')," + "('Bob Max', 1000, 'Australia')," + "('Peter Jackson', 800, 'New Zealand')," + "('Hobby Doyle', 1100, 'USA')," + "('Terrance Phillips', 2222, 'Canada')," + "('Joe Schmoe', 1500, 'Canada')," + "('Dominique Levillaine', 2780, 'France');" +) + + +with DAG( + dag_id=DAG_ID, + start_date=datetime(2021, 1, 1), + catchup=False, + tags=["example", "dataflow", "yaml"], +) as dag: + create_bq_dataset = BigQueryCreateEmptyDatasetOperator( + task_id="create_bq_dataset", + dataset_id=BQ_DATASET, + location=REGION, + ) + + create_bq_input_table = BigQueryCreateEmptyTableOperator( + task_id="create_bq_input_table", + dataset_id=BQ_DATASET, + table_id=BQ_INPUT_TABLE, + schema_fields=[ + {"name": "emp_name", "type": "STRING", "mode": "REQUIRED"}, + {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"}, + {"name": "country", "type": "STRING", "mode": "NULLABLE"}, + ], + ) + + insert_data_into_bq_table = BigQueryInsertJobOperator( + task_id="insert_data_into_bq_table", + configuration={ + "query": { + "query": INSERT_ROWS_QUERY, + "useLegacySql": False, + "priority": "BATCH", + } + }, + location=REGION, + ) + + # [START howto_operator_dataflow_start_yaml_job] + start_dataflow_yaml_job = DataflowStartYamlJobOperator( + task_id="start_dataflow_yaml_job", + job_name=DATAFLOW_YAML_JOB_NAME, + yaml_pipeline_file=DATAFLOW_YAML_PIPELINE_FILE_URL, + append_job_name=True, + deferrable=False, + region=REGION, + project_id=PROJECT_ID, + jinja_variables=BQ_VARIABLES, + ) + # [END howto_operator_dataflow_start_yaml_job] + + # [START howto_operator_dataflow_start_yaml_job_def] + start_dataflow_yaml_job_def = DataflowStartYamlJobOperator( + task_id="start_dataflow_yaml_job_def", + job_name=DATAFLOW_YAML_JOB_NAME, + yaml_pipeline_file=DATAFLOW_YAML_PIPELINE_FILE_URL, + append_job_name=True, + deferrable=True, + region=REGION, + 
project_id=PROJECT_ID, + jinja_variables=BQ_VARIABLES_DEF, + expected_terminal_state=DataflowJobStatus.JOB_STATE_DONE, + ) + # [END howto_operator_dataflow_start_yaml_job_def] + + delete_bq_dataset = BigQueryDeleteDatasetOperator( + task_id="delete_bq_dataset", + dataset_id=BQ_DATASET, + delete_contents=True, + trigger_rule=TriggerRule.ALL_DONE, + ) + + ( + # TEST SETUP + create_bq_dataset + >> create_bq_input_table + >> insert_data_into_bq_table + # TEST BODY + >> [start_dataflow_yaml_job, start_dataflow_yaml_job_def] + # TEST TEARDOWN + >> delete_bq_dataset + ) + + from tests.system.utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + +from tests.system.utils import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag) From 5a4a34d7fec1bade9411d8d41279542a226d270d Mon Sep 17 00:00:00 2001 From: Maksim Moiseenkov Date: Fri, 2 Aug 2024 12:22:38 +0000 Subject: [PATCH 2/2] Refactor hook and operator --- .../providers/google/cloud/hooks/dataflow.py | 156 +++--------------- .../google/cloud/operators/dataflow.py | 76 ++++----- tests/always/test_example_dags.py | 1 + .../google/cloud/operators/test_dataflow.py | 32 +--- 4 files changed, 65 insertions(+), 200 deletions(-) diff --git a/airflow/providers/google/cloud/hooks/dataflow.py b/airflow/providers/google/cloud/hooks/dataflow.py index fccc171e09b40f..97eaa49b36d523 100644 --- a/airflow/providers/google/cloud/hooks/dataflow.py +++ b/airflow/providers/google/cloud/hooks/dataflow.py @@ -948,8 +948,7 @@ def launch_beam_yaml_job( options: dict[str, Any] | None, project_id: str, location: str = DEFAULT_DATAFLOW_LOCATION, - on_new_job_callback: Callable[[dict], None] | None = None, - ) -> dict[str, Any]: + ) -> str: """ Launch a Dataflow YAML job and run it until completion. @@ -966,105 +965,13 @@ def launch_beam_yaml_job( :param project_id: The ID of the GCP project that owns the job. :param location: Region ID of the job's regional endpoint. Defaults to 'us-central1'. :param on_new_job_callback: Callback function that passes the job to the operator once known. - :return: Dictionary containing the job's data. 
- """ - job_name = self.build_dataflow_job_name(job_name, append_job_name) - cmd = self._build_yaml_gcloud_command( - job_name=job_name, - yaml_pipeline_file=yaml_pipeline_file, - project_id=project_id, - region=location, - options=options, - jinja_variables=jinja_variables, - ) - job_id = self._create_dataflow_job_with_gcloud(cmd=cmd) - jobs_controller = _DataflowJobsController( - dataflow=self.get_conn(), - project_number=project_id, - job_id=job_id, - location=location, - poll_sleep=self.poll_sleep, - num_retries=self.num_retries, - drain_pipeline=self.drain_pipeline, - wait_until_finished=self.wait_until_finished, - expected_terminal_state=self.expected_terminal_state, - ) - job = jobs_controller.fetch_job_by_id(job_id=job_id) - - if on_new_job_callback: - on_new_job_callback(job) - - jobs_controller.wait_for_done() - - return jobs_controller.fetch_job_by_id(job_id=job_id) - - @GoogleBaseHook.fallback_to_default_project_id - def launch_beam_yaml_job_deferrable( - self, - *, - job_name: str, - yaml_pipeline_file: str, - append_job_name: bool, - jinja_variables: dict[str, str] | None, - options: dict[str, Any] | None, - project_id: str, - location: str = DEFAULT_DATAFLOW_LOCATION, - ) -> dict[str, Any]: + :return: Job ID. """ - Launch a Dataflow YAML job and exit without waiting for its completion. - - :param job_name: The unique name to assign to the Cloud Dataflow job. - :param yaml_pipeline_file: Path to a file defining the YAML pipeline to run. - Must be a local file or a URL beginning with 'gs://'. - :param append_job_name: Set to True if a unique suffix has to be appended to the `job_name`. - :param jinja_variables: A dictionary of Jinja2 variables to be used in reifying the yaml pipeline file. - :param options: Additional gcloud or Beam job parameters. - It must be a dictionary with the keys matching the optional flag names in gcloud. - The list of supported flags can be found at: `https://cloud.google.com/sdk/gcloud/reference/dataflow/yaml/run`. - Note that if a flag does not require a value, then its dictionary value must be either True or None. - For example, the `--log-http` flag can be passed as {'log-http': True}. - :param project_id: The ID of the GCP project that owns the job. - :param location: Region ID of the job's regional endpoint. Defaults to 'us-central1'. - :return: Dictionary containing the job's data. 
- """ - job_name = self.build_dataflow_job_name(job_name=job_name, append_job_name=append_job_name) - cmd = self._build_yaml_gcloud_command( - job_name=job_name, - yaml_pipeline_file=yaml_pipeline_file, - project_id=project_id, - region=location, - jinja_variables=jinja_variables, - options=options, - ) - job_id = self._create_dataflow_job_with_gcloud(cmd=cmd) - jobs_controller = _DataflowJobsController( - dataflow=self.get_conn(), - project_number=project_id, - job_id=job_id, - location=location, - poll_sleep=self.poll_sleep, - num_retries=self.num_retries, - drain_pipeline=self.drain_pipeline, - expected_terminal_state=self.expected_terminal_state, - cancel_timeout=self.cancel_timeout, - ) - - return jobs_controller.fetch_job_by_id(job_id=job_id) - - def _build_yaml_gcloud_command( - self, - job_name: str, - yaml_pipeline_file: str, - project_id: str, - region: str, - options: dict[str, Any] | None, - jinja_variables: dict[str, str] | None, - ) -> list[str]: gcp_flags = { "yaml-pipeline-file": yaml_pipeline_file, "project": project_id, "format": "value(job.id)", - "region": region, + "region": location, } if jinja_variables: @@ -1073,6 +980,15 @@ def _build_yaml_gcloud_command( if options: gcp_flags.update(options) + job_name = self.build_dataflow_job_name(job_name, append_job_name) + cmd = self._build_gcloud_command( + command=["gcloud", "dataflow", "yaml", "run", job_name], parameters=gcp_flags + ) + job_id = self._create_dataflow_job_with_gcloud(cmd=cmd) + return job_id + + def _build_gcloud_command(self, command: list[str], parameters: dict[str, str]) -> list[str]: + _parameters = deepcopy(parameters) if self.impersonation_chain: if isinstance(self.impersonation_chain, str): impersonation_account = self.impersonation_chain @@ -1082,16 +998,8 @@ def _build_yaml_gcloud_command( raise AirflowException( "Chained list of accounts is not supported, please specify only one service account." ) - gcp_flags.update({"impersonate-service-account": impersonation_account}) - - return [ - "gcloud", - "dataflow", - "yaml", - "run", - job_name, - *(beam_options_to_args(gcp_flags)), - ] + _parameters["impersonate-service-account"] = impersonation_account + return [*command, *(beam_options_to_args(_parameters))] def _create_dataflow_job_with_gcloud(self, cmd: list[str]) -> str: """Create a Dataflow job with a gcloud command and return the job's ID.""" @@ -1314,33 +1222,15 @@ def start_sql_job( :param on_new_job_callback: Callback called when the job is known. 
:return: the new job object """ - gcp_options = [ - f"--project={project_id}", - "--format=value(job.id)", - f"--job-name={job_name}", - f"--region={location}", - ] - - if self.impersonation_chain: - if isinstance(self.impersonation_chain, str): - impersonation_account = self.impersonation_chain - elif len(self.impersonation_chain) == 1: - impersonation_account = self.impersonation_chain[0] - else: - raise AirflowException( - "Chained list of accounts is not supported, please specify only one service account" - ) - gcp_options.append(f"--impersonate-service-account={impersonation_account}") - - cmd = [ - "gcloud", - "dataflow", - "sql", - "query", - query, - *gcp_options, - *(beam_options_to_args(options)), - ] + gcp_options = { + "project": project_id, + "format": "value(job.id)", + "job-name": job_name, + "region": location, + } + cmd = self._build_gcloud_command( + command=["gcloud", "dataflow", "sql", "query", query], parameters={**gcp_options, **options} + ) self.log.info("Executing command: %s", " ".join(shlex.quote(c) for c in cmd)) with self.provide_authorized_gcloud(): proc = subprocess.run(cmd, capture_output=True) diff --git a/airflow/providers/google/cloud/operators/dataflow.py b/airflow/providers/google/cloud/operators/dataflow.py index 4d47017f889d44..fd4d0644a77df8 100644 --- a/airflow/providers/google/cloud/operators/dataflow.py +++ b/airflow/providers/google/cloud/operators/dataflow.py @@ -22,7 +22,6 @@ import copy import re import uuid -import warnings from contextlib import ExitStack from enum import Enum from functools import cached_property @@ -950,6 +949,11 @@ def on_kill(self) -> None: ) +@deprecated( + planned_removal_date="January 31, 2025", + use_instead="DataflowStartYamlJobOperator", + category=AirflowProviderDeprecationWarning, +) class DataflowStartSqlJobOperator(GoogleCloudBaseOperator): """ Starts Dataflow SQL query. @@ -1025,12 +1029,6 @@ def __init__( self.hook: DataflowHook | None = None def execute(self, context: Context): - warnings.warn( - "DataflowStartSqlJobOperator is deprecated and will be removed after 31.01.2025. 
Please use DataflowStartYamlJobOperator instead.", - AirflowProviderDeprecationWarning, - stacklevel=2, - ) - self.hook = DataflowHook( gcp_conn_id=self.gcp_conn_id, drain_pipeline=self.drain_pipeline, @@ -1159,51 +1157,41 @@ def __init__( self.options = options self.jinja_variables = jinja_variables self.impersonation_chain = impersonation_chain - self.job: dict[str, Any] | None = None + self.job_id: str | None = None def execute(self, context: Context) -> dict[str, Any]: - def set_current_job(current_job: dict[str, Any]): - self.job = current_job - DataflowJobLink.persist(self, context, self.project_id, self.region, self.job["id"]) - - if not self.deferrable: - job = self.hook.launch_beam_yaml_job( - job_name=self.job_name, - yaml_pipeline_file=self.yaml_pipeline_file, - append_job_name=self.append_job_name, - options=self.options, - jinja_variables=self.jinja_variables, - project_id=self.project_id, - location=self.region, - on_new_job_callback=set_current_job, - ) - return job - - self.job = self.hook.launch_beam_yaml_job_deferrable( + self.job_id = self.hook.launch_beam_yaml_job( job_name=self.job_name, yaml_pipeline_file=self.yaml_pipeline_file, append_job_name=self.append_job_name, - jinja_variables=self.jinja_variables, options=self.options, + jinja_variables=self.jinja_variables, project_id=self.project_id, location=self.region, ) - DataflowJobLink.persist(self, context, self.project_id, self.region, self.job["id"]) + DataflowJobLink.persist(self, context, self.project_id, self.region, self.job_id) - self.defer( - trigger=DataflowStartYamlJobTrigger( - job_id=self.job["id"], - project_id=self.project_id, - location=self.region, - gcp_conn_id=self.gcp_conn_id, - poll_sleep=self.poll_sleep, - cancel_timeout=self.cancel_timeout, - expected_terminal_state=self.expected_terminal_state, - impersonation_chain=self.impersonation_chain, - ), - method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME, + if self.deferrable: + self.defer( + trigger=DataflowStartYamlJobTrigger( + job_id=self.job_id, + project_id=self.project_id, + location=self.region, + gcp_conn_id=self.gcp_conn_id, + poll_sleep=self.poll_sleep, + cancel_timeout=self.cancel_timeout, + expected_terminal_state=self.expected_terminal_state, + impersonation_chain=self.impersonation_chain, + ), + method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME, + ) + + self.hook.wait_for_done( + job_name=self.job_name, location=self.region, project_id=self.project_id, job_id=self.job_id ) + job = self.hook.get_job(job_id=self.job_id, location=self.region, project_id=self.project_id) + return job def execute_complete(self, context: Context, event: dict) -> dict[str, Any]: """Execute after the trigger returns an event.""" @@ -1224,11 +1212,11 @@ def on_kill(self): state. """ self.log.info("On kill called.") - if self.job: + if self.job_id: self.hook.cancel_job( - job_id=self.job.get("id"), - project_id=self.job.get("projectId"), - location=self.job.get("location"), + job_id=self.job_id, + project_id=self.project_id, + location=self.region, ) @cached_property diff --git a/tests/always/test_example_dags.py b/tests/always/test_example_dags.py index 2dfcfb6e37cdd0..9d10ce5cad1994 100644 --- a/tests/always/test_example_dags.py +++ b/tests/always/test_example_dags.py @@ -51,6 +51,7 @@ # If the deprecation is postponed, the item should be added to this tuple, # and a corresponding Issue should be created on GitHub. 
"tests/system/providers/google/cloud/bigquery/example_bigquery_operations.py", + "tests/system/providers/google/cloud/dataflow/example_dataflow_sql.py", "tests/system/providers/google/cloud/dataproc/example_dataproc_gke.py", "tests/system/providers/google/cloud/datapipelines/example_datapipeline.py", "tests/system/providers/google/cloud/gcs/example_gcs_sensor.py", diff --git a/tests/providers/google/cloud/operators/test_dataflow.py b/tests/providers/google/cloud/operators/test_dataflow.py index 2e60a2567fd109..14787dba19b61a 100644 --- a/tests/providers/google/cloud/operators/test_dataflow.py +++ b/tests/providers/google/cloud/operators/test_dataflow.py @@ -712,19 +712,15 @@ def test_execute_with_deferrable_mode(self, mock_hook, mock_defer_method, deferr class TestDataflowStartSqlJobOperator: @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook") def test_execute(self, mock_hook): - warning_msg = ( - "DataflowStartSqlJobOperator is deprecated and will be removed after 31.01.2025. " - "Please use DataflowStartYamlJobOperator instead." - ) - start_sql = DataflowStartSqlJobOperator( - task_id="start_sql_query", - job_name=TEST_SQL_JOB_NAME, - query=TEST_SQL_QUERY, - options=deepcopy(TEST_SQL_OPTIONS), - location=TEST_LOCATION, - do_xcom_push=True, - ) - with pytest.warns(AirflowProviderDeprecationWarning, match=warning_msg): + with pytest.warns(AirflowProviderDeprecationWarning): + start_sql = DataflowStartSqlJobOperator( + task_id="start_sql_query", + job_name=TEST_SQL_JOB_NAME, + query=TEST_SQL_QUERY, + options=deepcopy(TEST_SQL_OPTIONS), + location=TEST_LOCATION, + do_xcom_push=True, + ) start_sql.execute(mock.MagicMock()) mock_hook.assert_called_once_with( @@ -794,7 +790,6 @@ def test_execute(self, mock_hook, sync_operator): jinja_variables=None, project_id=TEST_PROJECT, location=TEST_LOCATION, - on_new_job_callback=mock.ANY, ) @mock.patch(f"{DATAFLOW_PATH}.DataflowStartYamlJobOperator.defer") @@ -809,15 +804,6 @@ def test_execute_with_deferrable_mode(self, mock_hook, mock_defer_method, deferr expected_terminal_state=DataflowJobStatus.JOB_STATE_RUNNING, gcp_conn_id=GCP_CONN_ID, ) - mock_hook.return_value.launch_beam_yaml_job_deferrable.assert_called_once_with( - job_name=deferrable_operator.job_name, - yaml_pipeline_file=deferrable_operator.yaml_pipeline_file, - append_job_name=False, - options=None, - jinja_variables=None, - project_id=TEST_PROJECT, - location=TEST_LOCATION, - ) mock_defer_method.assert_called_once() @mock.patch(f"{DATAFLOW_PATH}.DataflowStartYamlJobOperator.xcom_push")