diff --git a/airflow/providers/google/cloud/hooks/dataflow.py b/airflow/providers/google/cloud/hooks/dataflow.py index 8f7e2e25499f26..97eaa49b36d523 100644 --- a/airflow/providers/google/cloud/hooks/dataflow.py +++ b/airflow/providers/google/cloud/hooks/dataflow.py @@ -186,9 +186,9 @@ class DataflowJobType: class _DataflowJobsController(LoggingMixin): """ - Interface for communication with Google API. + Interface for communication with Google Cloud Dataflow API. - It's not use Apache Beam, but only Google Dataflow API. + Does not use Apache Beam API. :param dataflow: Discovery resource :param project_number: The Google Cloud Project ID. @@ -271,12 +271,12 @@ def _get_current_jobs(self) -> list[dict]: else: raise ValueError("Missing both dataflow job ID and name.") - def fetch_job_by_id(self, job_id: str) -> dict: + def fetch_job_by_id(self, job_id: str) -> dict[str, str]: """ Fetch the job with the specified Job ID. - :param job_id: Job ID to get. - :return: the Job + :param job_id: ID of the job that needs to be fetched. + :return: Dictionary containing the Job's data """ return ( self._dataflow.projects() @@ -444,7 +444,6 @@ def _check_dataflow_job_state(self, job) -> bool: "Google Cloud Dataflow job's expected terminal state cannot be " "JOB_STATE_DRAINED while it is a batch job" ) - if current_state == current_expected_state: if current_expected_state == DataflowJobStatus.JOB_STATE_RUNNING: return not self._wait_until_finished @@ -938,6 +937,89 @@ def launch_job_with_flex_template( response: dict = request.execute(num_retries=self.num_retries) return response["job"] + @GoogleBaseHook.fallback_to_default_project_id + def launch_beam_yaml_job( + self, + *, + job_name: str, + yaml_pipeline_file: str, + append_job_name: bool, + jinja_variables: dict[str, str] | None, + options: dict[str, Any] | None, + project_id: str, + location: str = DEFAULT_DATAFLOW_LOCATION, + ) -> str: + """ + Launch a Dataflow YAML job and run it until completion. + + :param job_name: The unique name to assign to the Cloud Dataflow job. + :param yaml_pipeline_file: Path to a file defining the YAML pipeline to run. + Must be a local file or a URL beginning with 'gs://'. + :param append_job_name: Set to True if a unique suffix has to be appended to the `job_name`. + :param jinja_variables: A dictionary of Jinja2 variables to be used in reifying the yaml pipeline file. + :param options: Additional gcloud or Beam job parameters. + It must be a dictionary with the keys matching the optional flag names in gcloud. + The list of supported flags can be found at: `https://cloud.google.com/sdk/gcloud/reference/dataflow/yaml/run`. + Note that if a flag does not require a value, then its dictionary value must be either True or None. + For example, the `--log-http` flag can be passed as {'log-http': True}. + :param project_id: The ID of the GCP project that owns the job. + :param location: Region ID of the job's regional endpoint. Defaults to 'us-central1'. + :return: Job ID.
+ """ + gcp_flags = { + "yaml-pipeline-file": yaml_pipeline_file, + "project": project_id, + "format": "value(job.id)", + "region": location, + } + + if jinja_variables: + gcp_flags["jinja-variables"] = json.dumps(jinja_variables) + + if options: + gcp_flags.update(options) + + job_name = self.build_dataflow_job_name(job_name, append_job_name) + cmd = self._build_gcloud_command( + command=["gcloud", "dataflow", "yaml", "run", job_name], parameters=gcp_flags + ) + job_id = self._create_dataflow_job_with_gcloud(cmd=cmd) + return job_id + + def _build_gcloud_command(self, command: list[str], parameters: dict[str, str]) -> list[str]: + _parameters = deepcopy(parameters) + if self.impersonation_chain: + if isinstance(self.impersonation_chain, str): + impersonation_account = self.impersonation_chain + elif len(self.impersonation_chain) == 1: + impersonation_account = self.impersonation_chain[0] + else: + raise AirflowException( + "Chained list of accounts is not supported, please specify only one service account." + ) + _parameters["impersonate-service-account"] = impersonation_account + return [*command, *(beam_options_to_args(_parameters))] + + def _create_dataflow_job_with_gcloud(self, cmd: list[str]) -> str: + """Create a Dataflow job with a gcloud command and return the job's ID.""" + self.log.info("Executing command: %s", " ".join(shlex.quote(c) for c in cmd)) + success_code = 0 + + with self.provide_authorized_gcloud(): + proc = subprocess.run(cmd, capture_output=True) + + if proc.returncode != success_code: + stderr_last_20_lines = "\n".join(proc.stderr.decode().strip().splitlines()[-20:]) + raise AirflowException( + f"Process exited with a non-zero exit code. Exit code: {proc.returncode}. Error details: " + f"{stderr_last_20_lines}" + ) + + job_id = proc.stdout.decode().strip() + self.log.info("Created job's ID: %s", job_id) + + return job_id + @staticmethod def extract_job_id(job: dict) -> str: try: @@ -1139,33 +1222,15 @@ def start_sql_job( :param on_new_job_callback: Callback called when the job is known.
:return: the new job object """ - gcp_options = [ - f"--project={project_id}", - "--format=value(job.id)", - f"--job-name={job_name}", - f"--region={location}", - ] - - if self.impersonation_chain: - if isinstance(self.impersonation_chain, str): - impersonation_account = self.impersonation_chain - elif len(self.impersonation_chain) == 1: - impersonation_account = self.impersonation_chain[0] - else: - raise AirflowException( - "Chained list of accounts is not supported, please specify only one service account" - ) - gcp_options.append(f"--impersonate-service-account={impersonation_account}") - - cmd = [ - "gcloud", - "dataflow", - "sql", - "query", - query, - *gcp_options, - *(beam_options_to_args(options)), - ] + gcp_options = { + "project": project_id, + "format": "value(job.id)", + "job-name": job_name, + "region": location, + } + cmd = self._build_gcloud_command( + command=["gcloud", "dataflow", "sql", "query", query], parameters={**gcp_options, **options} + ) self.log.info("Executing command: %s", " ".join(shlex.quote(c) for c in cmd)) with self.provide_authorized_gcloud(): proc = subprocess.run(cmd, capture_output=True) diff --git a/airflow/providers/google/cloud/operators/dataflow.py b/airflow/providers/google/cloud/operators/dataflow.py index 625356d50e0adc..fd4d0644a77df8 100644 --- a/airflow/providers/google/cloud/operators/dataflow.py +++ b/airflow/providers/google/cloud/operators/dataflow.py @@ -40,7 +40,10 @@ from airflow.providers.google.cloud.hooks.gcs import GCSHook from airflow.providers.google.cloud.links.dataflow import DataflowJobLink, DataflowPipelineLink from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator -from airflow.providers.google.cloud.triggers.dataflow import TemplateJobStartTrigger +from airflow.providers.google.cloud.triggers.dataflow import ( + DataflowStartYamlJobTrigger, + TemplateJobStartTrigger, +) from airflow.providers.google.common.consts import GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME from airflow.providers.google.common.deprecated import deprecated from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID @@ -946,6 +949,11 @@ def on_kill(self) -> None: ) +@deprecated( + planned_removal_date="January 31, 2025", + use_instead="DataflowStartYamlJobOperator", + category=AirflowProviderDeprecationWarning, +) class DataflowStartSqlJobOperator(GoogleCloudBaseOperator): """ Starts Dataflow SQL query. @@ -1051,6 +1059,178 @@ def on_kill(self) -> None: ) +class DataflowStartYamlJobOperator(GoogleCloudBaseOperator): + """ + Launch a Dataflow YAML job and return the result. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:DataflowStartYamlJobOperator` + + .. warning:: + This operator requires the ``gcloud`` command (Google Cloud SDK) to be installed on the Airflow worker + (see `Install the Google Cloud CLI <https://cloud.google.com/sdk/docs/install>`__). + + :param job_name: Required. The unique name to assign to the Cloud Dataflow job. + :param yaml_pipeline_file: Required. Path to a file defining the YAML pipeline to run. + Must be a local file or a URL beginning with 'gs://'. + :param region: Optional. Region ID of the job's regional endpoint. Defaults to 'us-central1'. + :param project_id: Required. The ID of the GCP project that owns the job. + If set to ``None`` or missing, the default project_id from the GCP connection is used. + :param gcp_conn_id: Optional. The connection ID used to connect to GCP. + :param append_job_name: Optional. Set to True if a unique suffix has to be appended to the `job_name`.
+ Defaults to True. + :param drain_pipeline: Optional. Set to True if you want to stop a streaming pipeline job by draining it + instead of canceling when killing the task instance. Note that this does not work for batch pipeline jobs + or in the deferrable mode. Defaults to False. + For more info see: https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline + :param deferrable: Optional. Run operator in the deferrable mode. + :param expected_terminal_state: Optional. The expected terminal state of the Dataflow job at which the + operator task is set to succeed. Defaults to 'JOB_STATE_DONE' for the batch jobs and 'JOB_STATE_RUNNING' + for the streaming jobs. + :param poll_sleep: Optional. The time in seconds to sleep between polling Google Cloud Platform for the Dataflow job status. + Used both for the sync and deferrable mode. + :param cancel_timeout: Optional. How long (in seconds) operator should wait for the pipeline to be + successfully canceled when the task is being killed. + :param jinja_variables: Optional. A dictionary of Jinja2 variables to be used in reifying the yaml pipeline file. + :param options: Optional. Additional gcloud or Beam job parameters. + It must be a dictionary with the keys matching the optional flag names in gcloud. + The list of supported flags can be found at: `https://cloud.google.com/sdk/gcloud/reference/dataflow/yaml/run`. + Note that if a flag does not require a value, then its dictionary value must be either True or None. + For example, the `--log-http` flag can be passed as {'log-http': True}. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + :return: Dictionary containing the job's data. 
+ """ + + template_fields: Sequence[str] = ( + "job_name", + "yaml_pipeline_file", + "jinja_variables", + "options", + "region", + "project_id", + "gcp_conn_id", + ) + template_fields_renderers = { + "jinja_variables": "json", + } + operator_extra_links = (DataflowJobLink(),) + + def __init__( + self, + *, + job_name: str, + yaml_pipeline_file: str, + region: str = DEFAULT_DATAFLOW_LOCATION, + project_id: str = PROVIDE_PROJECT_ID, + gcp_conn_id: str = "google_cloud_default", + append_job_name: bool = True, + drain_pipeline: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + poll_sleep: int = 10, + cancel_timeout: int | None = 5 * 60, + expected_terminal_state: str | None = None, + jinja_variables: dict[str, str] | None = None, + options: dict[str, Any] | None = None, + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.job_name = job_name + self.yaml_pipeline_file = yaml_pipeline_file + self.region = region + self.project_id = project_id + self.gcp_conn_id = gcp_conn_id + self.append_job_name = append_job_name + self.drain_pipeline = drain_pipeline + self.deferrable = deferrable + self.poll_sleep = poll_sleep + self.cancel_timeout = cancel_timeout + self.expected_terminal_state = expected_terminal_state + self.options = options + self.jinja_variables = jinja_variables + self.impersonation_chain = impersonation_chain + self.job_id: str | None = None + + def execute(self, context: Context) -> dict[str, Any]: + self.job_id = self.hook.launch_beam_yaml_job( + job_name=self.job_name, + yaml_pipeline_file=self.yaml_pipeline_file, + append_job_name=self.append_job_name, + options=self.options, + jinja_variables=self.jinja_variables, + project_id=self.project_id, + location=self.region, + ) + + DataflowJobLink.persist(self, context, self.project_id, self.region, self.job_id) + + if self.deferrable: + self.defer( + trigger=DataflowStartYamlJobTrigger( + job_id=self.job_id, + project_id=self.project_id, + location=self.region, + gcp_conn_id=self.gcp_conn_id, + poll_sleep=self.poll_sleep, + cancel_timeout=self.cancel_timeout, + expected_terminal_state=self.expected_terminal_state, + impersonation_chain=self.impersonation_chain, + ), + method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME, + ) + + self.hook.wait_for_done( + job_name=self.job_name, location=self.region, project_id=self.project_id, job_id=self.job_id + ) + job = self.hook.get_job(job_id=self.job_id, location=self.region, project_id=self.project_id) + return job + + def execute_complete(self, context: Context, event: dict) -> dict[str, Any]: + """Execute after the trigger returns an event.""" + if event["status"] in ("error", "stopped"): + self.log.info("status: %s, msg: %s", event["status"], event["message"]) + raise AirflowException(event["message"]) + job = event["job"] + self.log.info("Job %s completed with response %s", job["id"], event["message"]) + self.xcom_push(context, key="job_id", value=job["id"]) + + return job + + def on_kill(self): + """ + Cancel the dataflow job if a task instance gets killed. + + This method will not be called if a task instance is killed in a deferred + state. 
+ """ + self.log.info("On kill called.") + if self.job_id: + self.hook.cancel_job( + job_id=self.job_id, + project_id=self.project_id, + location=self.region, + ) + + @cached_property + def hook(self) -> DataflowHook: + return DataflowHook( + gcp_conn_id=self.gcp_conn_id, + poll_sleep=self.poll_sleep, + impersonation_chain=self.impersonation_chain, + drain_pipeline=self.drain_pipeline, + cancel_timeout=self.cancel_timeout, + expected_terminal_state=self.expected_terminal_state, + ) + + # TODO: Remove one day @deprecated( planned_removal_date="November 01, 2024", diff --git a/airflow/providers/google/cloud/triggers/dataflow.py b/airflow/providers/google/cloud/triggers/dataflow.py index 01f96b98b78e0f..4b994bf6e4d7d8 100644 --- a/airflow/providers/google/cloud/triggers/dataflow.py +++ b/airflow/providers/google/cloud/triggers/dataflow.py @@ -24,8 +24,10 @@ from google.cloud.dataflow_v1beta3 import JobState from google.cloud.dataflow_v1beta3.types import ( AutoscalingEvent, + Job, JobMessage, JobMetrics, + JobType, MetricUpdate, ) @@ -157,7 +159,7 @@ def _get_async_hook(self) -> AsyncDataflowHook: class DataflowJobStatusTrigger(BaseTrigger): """ - Trigger that checks for metrics associated with a Dataflow job. + Trigger that monitors if a Dataflow job has reached any of the expected statuses. :param job_id: Required. ID of the job. :param expected_statuses: The expected state(s) of the operation. @@ -266,6 +268,148 @@ def async_hook(self) -> AsyncDataflowHook: ) +class DataflowStartYamlJobTrigger(BaseTrigger): + """ + Dataflow trigger that checks the state of a Dataflow YAML job. + + :param job_id: Required. ID of the job. + :param project_id: Required. The Google Cloud project ID in which the job was started. + :param location: The location where job is executed. If set to None then + the value of DEFAULT_DATAFLOW_LOCATION will be used. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param poll_sleep: Optional. The time in seconds to sleep between polling Google Cloud Platform + for the Dataflow job. + :param cancel_timeout: Optional. How long (in seconds) operator should wait for the pipeline to be + successfully cancelled when task is being killed. + :param expected_terminal_state: Optional. The expected terminal state of the Dataflow job at which the + operator task is set to succeed. Defaults to 'JOB_STATE_DONE' for the batch jobs and + 'JOB_STATE_RUNNING' for the streaming jobs. + :param impersonation_chain: Optional. Service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + def __init__( + self, + job_id: str, + project_id: str | None, + location: str = DEFAULT_DATAFLOW_LOCATION, + gcp_conn_id: str = "google_cloud_default", + poll_sleep: int = 10, + cancel_timeout: int | None = 5 * 60, + expected_terminal_state: str | None = None, + impersonation_chain: str | Sequence[str] | None = None, + ): + super().__init__() + self.project_id = project_id + self.job_id = job_id + self.location = location + self.gcp_conn_id = gcp_conn_id + self.poll_sleep = poll_sleep + self.cancel_timeout = cancel_timeout + self.expected_terminal_state = expected_terminal_state + self.impersonation_chain = impersonation_chain + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize class arguments and classpath.""" + return ( + "airflow.providers.google.cloud.triggers.dataflow.DataflowStartYamlJobTrigger", + { + "project_id": self.project_id, + "job_id": self.job_id, + "location": self.location, + "gcp_conn_id": self.gcp_conn_id, + "poll_sleep": self.poll_sleep, + "expected_terminal_state": self.expected_terminal_state, + "impersonation_chain": self.impersonation_chain, + "cancel_timeout": self.cancel_timeout, + }, + ) + + async def run(self): + """ + Fetch job and yield events depending on the job's type and state. + + Yield TriggerEvent if the job reaches a terminal state. + Otherwise awaits for a specified amount of time stored in self.poll_sleep variable. + """ + hook: AsyncDataflowHook = self._get_async_hook() + try: + while True: + job: Job = await hook.get_job( + job_id=self.job_id, + project_id=self.project_id, + location=self.location, + ) + job_state = job.current_state + job_type = job.type_ + if job_state.name == self.expected_terminal_state: + yield TriggerEvent( + { + "job": Job.to_dict(job), + "status": "success", + "message": f"Job reached the expected terminal state: {self.expected_terminal_state}.", + } + ) + return + elif job_type == JobType.JOB_TYPE_STREAMING and job_state == JobState.JOB_STATE_RUNNING: + yield TriggerEvent( + { + "job": Job.to_dict(job), + "status": "success", + "message": "Streaming job reached the RUNNING state.", + } + ) + return + elif job_type == JobType.JOB_TYPE_BATCH and job_state == JobState.JOB_STATE_DONE: + yield TriggerEvent( + { + "job": Job.to_dict(job), + "status": "success", + "message": "Batch job completed.", + } + ) + return + elif job_state == JobState.JOB_STATE_FAILED: + yield TriggerEvent( + { + "job": Job.to_dict(job), + "status": "error", + "message": "Job failed.", + } + ) + return + elif job_state == JobState.JOB_STATE_STOPPED: + yield TriggerEvent( + { + "job": Job.to_dict(job), + "status": "stopped", + "message": "Job was stopped.", + } + ) + return + else: + self.log.info("Current job status is: %s", job_state.name) + self.log.info("Sleeping for %s seconds.", self.poll_sleep) + await asyncio.sleep(self.poll_sleep) + except Exception as e: + self.log.exception("Exception occurred while checking for job completion.") + yield TriggerEvent({"job": None, "status": "error", "message": str(e)}) + + def _get_async_hook(self) -> AsyncDataflowHook: + return AsyncDataflowHook( + gcp_conn_id=self.gcp_conn_id, + poll_sleep=self.poll_sleep, + impersonation_chain=self.impersonation_chain, + cancel_timeout=self.cancel_timeout, + ) + + class DataflowJobMetricsTrigger(BaseTrigger): """ Trigger that checks for metrics associated with a Dataflow job. 
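For illustration, here is a minimal, hypothetical sketch of driving the pieces above directly: the new hook method launches the job via gcloud and returns the job ID, which the trigger (or the operator's synchronous wait) then polls. The project, bucket, and job names are placeholders, not values from this change, and the command shown in the comment is only a rough rendering of what `_build_gcloud_command` and `beam_options_to_args` assemble.

# Hypothetical usage sketch; assumes a configured Google Cloud connection and an installed gcloud CLI.
from airflow.providers.google.cloud.hooks.dataflow import DataflowHook

hook = DataflowHook(gcp_conn_id="google_cloud_default")
job_id = hook.launch_beam_yaml_job(
    job_name="beam-yaml-demo",                                     # placeholder name
    yaml_pipeline_file="gs://my-bucket/pipelines/wordcount.yaml",  # placeholder path
    append_job_name=True,
    jinja_variables={"input": "gs://my-bucket/input.txt"},         # serialized with json.dumps
    options={"log-http": True},                                    # valueless gcloud flags take True or None
    project_id="my-gcp-project",                                   # placeholder project
    location="us-central1",
)
# Roughly equivalent to running:
#   gcloud dataflow yaml run <job-name> --yaml-pipeline-file=... --project=... \
#       --format="value(job.id)" --region=us-central1 --jinja-variables=... --log-http
# The command's stdout (the job ID) is returned; DataflowStartYamlJobTrigger then polls that ID
# until the job reaches a terminal state or the expected terminal state.
print(job_id)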
diff --git a/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst b/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst index 71fc3275fe99f4..a9eb98ea9a5097 100644 --- a/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst +++ b/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst @@ -306,6 +306,38 @@ Here is an example of running Dataflow SQL job with See the `Dataflow SQL reference `_. +.. _howto/operator:DataflowStartYamlJobOperator: + +Dataflow YAML +"""""""""""""" +Beam YAML is a no-code SDK for configuring Apache Beam pipelines by using YAML files. +You can use Beam YAML to author and run a Beam pipeline without writing any code. +This API can be used to define both streaming and batch pipelines. + +Here is an example of running a Dataflow YAML job with +:class:`~airflow.providers.google.cloud.operators.dataflow.DataflowStartYamlJobOperator`: + +.. exampleinclude:: /../../tests/system/providers/google/cloud/dataflow/example_dataflow_yaml.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_dataflow_start_yaml_job] + :end-before: [END howto_operator_dataflow_start_yaml_job] + +This operator can be run in deferrable mode by passing ``deferrable=True`` as a parameter. + +.. exampleinclude:: /../../tests/system/providers/google/cloud/dataflow/example_dataflow_yaml.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_dataflow_start_yaml_job_def] + :end-before: [END howto_operator_dataflow_start_yaml_job_def] + +.. warning:: + This operator requires the ``gcloud`` command (Google Cloud SDK) to be installed on the Airflow worker + (see `Install the Google Cloud CLI <https://cloud.google.com/sdk/docs/install>`__). + +See the `Dataflow YAML reference +`_. + .. _howto/operator:DataflowStopJobOperator: Stopping a pipeline diff --git a/tests/always/test_example_dags.py b/tests/always/test_example_dags.py index 2dfcfb6e37cdd0..9d10ce5cad1994 100644 --- a/tests/always/test_example_dags.py +++ b/tests/always/test_example_dags.py @@ -51,6 +51,7 @@ # If the deprecation is postponed, the item should be added to this tuple, # and a corresponding Issue should be created on GitHub.
"tests/system/providers/google/cloud/bigquery/example_bigquery_operations.py", + "tests/system/providers/google/cloud/dataflow/example_dataflow_sql.py", "tests/system/providers/google/cloud/dataproc/example_dataproc_gke.py", "tests/system/providers/google/cloud/datapipelines/example_datapipeline.py", "tests/system/providers/google/cloud/gcs/example_gcs_sensor.py", diff --git a/tests/providers/google/cloud/operators/test_dataflow.py b/tests/providers/google/cloud/operators/test_dataflow.py index a024ffd7a7503a..14787dba19b61a 100644 --- a/tests/providers/google/cloud/operators/test_dataflow.py +++ b/tests/providers/google/cloud/operators/test_dataflow.py @@ -41,6 +41,7 @@ DataflowRunPipelineOperator, DataflowStartFlexTemplateOperator, DataflowStartSqlJobOperator, + DataflowStartYamlJobOperator, DataflowStopJobOperator, DataflowTemplatedJobStartOperator, ) @@ -711,16 +712,17 @@ def test_execute_with_deferrable_mode(self, mock_hook, mock_defer_method, deferr class TestDataflowStartSqlJobOperator: @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook") def test_execute(self, mock_hook): - start_sql = DataflowStartSqlJobOperator( - task_id="start_sql_query", - job_name=TEST_SQL_JOB_NAME, - query=TEST_SQL_QUERY, - options=deepcopy(TEST_SQL_OPTIONS), - location=TEST_LOCATION, - do_xcom_push=True, - ) + with pytest.warns(AirflowProviderDeprecationWarning): + start_sql = DataflowStartSqlJobOperator( + task_id="start_sql_query", + job_name=TEST_SQL_JOB_NAME, + query=TEST_SQL_QUERY, + options=deepcopy(TEST_SQL_OPTIONS), + location=TEST_LOCATION, + do_xcom_push=True, + ) + start_sql.execute(mock.MagicMock()) - start_sql.execute(mock.MagicMock()) mock_hook.assert_called_once_with( gcp_conn_id="google_cloud_default", drain_pipeline=False, @@ -741,6 +743,95 @@ def test_execute(self, mock_hook): ) +class TestDataflowStartYamlJobOperator: + @pytest.fixture + def sync_operator(self): + return DataflowStartYamlJobOperator( + task_id="start_dataflow_yaml_job_sync", + job_name="dataflow_yaml_job", + yaml_pipeline_file="test_file_path", + append_job_name=False, + project_id=TEST_PROJECT, + region=TEST_LOCATION, + gcp_conn_id=GCP_CONN_ID, + expected_terminal_state=DataflowJobStatus.JOB_STATE_DONE, + ) + + @pytest.fixture + def deferrable_operator(self): + return DataflowStartYamlJobOperator( + task_id="start_dataflow_yaml_job_def", + job_name="dataflow_yaml_job", + yaml_pipeline_file="test_file_path", + append_job_name=False, + project_id=TEST_PROJECT, + region=TEST_LOCATION, + gcp_conn_id=GCP_CONN_ID, + deferrable=True, + expected_terminal_state=DataflowJobStatus.JOB_STATE_RUNNING, + ) + + @mock.patch(f"{DATAFLOW_PATH}.DataflowHook") + def test_execute(self, mock_hook, sync_operator): + sync_operator.execute(mock.MagicMock()) + mock_hook.assert_called_once_with( + poll_sleep=sync_operator.poll_sleep, + drain_pipeline=False, + impersonation_chain=None, + cancel_timeout=sync_operator.cancel_timeout, + expected_terminal_state=DataflowJobStatus.JOB_STATE_DONE, + gcp_conn_id=GCP_CONN_ID, + ) + mock_hook.return_value.launch_beam_yaml_job.assert_called_once_with( + job_name=sync_operator.job_name, + yaml_pipeline_file=sync_operator.yaml_pipeline_file, + append_job_name=False, + options=None, + jinja_variables=None, + project_id=TEST_PROJECT, + location=TEST_LOCATION, + ) + + @mock.patch(f"{DATAFLOW_PATH}.DataflowStartYamlJobOperator.defer") + @mock.patch(f"{DATAFLOW_PATH}.DataflowHook") + def test_execute_with_deferrable_mode(self, mock_hook, mock_defer_method, deferrable_operator): + 
deferrable_operator.execute(mock.MagicMock()) + mock_hook.assert_called_once_with( + poll_sleep=deferrable_operator.poll_sleep, + drain_pipeline=False, + impersonation_chain=None, + cancel_timeout=deferrable_operator.cancel_timeout, + expected_terminal_state=DataflowJobStatus.JOB_STATE_RUNNING, + gcp_conn_id=GCP_CONN_ID, + ) + mock_defer_method.assert_called_once() + + @mock.patch(f"{DATAFLOW_PATH}.DataflowStartYamlJobOperator.xcom_push") + @mock.patch(f"{DATAFLOW_PATH}.DataflowHook") + def test_execute_complete_success(self, mock_hook, mock_xcom_push, deferrable_operator): + expected_result = {"id": JOB_ID} + actual_result = deferrable_operator.execute_complete( + context=None, + event={ + "status": "success", + "message": "Batch job completed.", + "job": expected_result, + }, + ) + mock_xcom_push.assert_called_with(None, key="job_id", value=JOB_ID) + assert actual_result == expected_result + + def test_execute_complete_error_status_raises_exception(self, deferrable_operator): + with pytest.raises(AirflowException, match="Job failed."): + deferrable_operator.execute_complete( + context=None, event={"status": "error", "message": "Job failed."} + ) + with pytest.raises(AirflowException, match="Job was stopped."): + deferrable_operator.execute_complete( + context=None, event={"status": "stopped", "message": "Job was stopped."} + ) + + class TestDataflowStopJobOperator: @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook") def test_exec_job_id(self, dataflow_mock): diff --git a/tests/providers/google/cloud/triggers/test_dataflow.py b/tests/providers/google/cloud/triggers/test_dataflow.py index 2b9b63afa44e31..67b2f41009c25a 100644 --- a/tests/providers/google/cloud/triggers/test_dataflow.py +++ b/tests/providers/google/cloud/triggers/test_dataflow.py @@ -22,7 +22,7 @@ from unittest import mock import pytest -from google.cloud.dataflow_v1beta3 import JobState +from google.cloud.dataflow_v1beta3 import Job, JobState, JobType from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus from airflow.providers.google.cloud.triggers.dataflow import ( @@ -30,6 +30,7 @@ DataflowJobMessagesTrigger, DataflowJobMetricsTrigger, DataflowJobStatusTrigger, + DataflowStartYamlJobTrigger, TemplateJobStartTrigger, ) from airflow.triggers.base import TriggerEvent @@ -108,6 +109,24 @@ def dataflow_job_status_trigger(): ) +@pytest.fixture +def dataflow_start_yaml_job_trigger(): + return DataflowStartYamlJobTrigger( + project_id=PROJECT_ID, + job_id=JOB_ID, + location=LOCATION, + gcp_conn_id=GCP_CONN_ID, + poll_sleep=POLL_SLEEP, + impersonation_chain=IMPERSONATION_CHAIN, + cancel_timeout=CANCEL_TIMEOUT, + ) + + +@pytest.fixture +def test_dataflow_batch_job(): + return Job(id=JOB_ID, current_state=JobState.JOB_STATE_DONE, type_=JobType.JOB_TYPE_BATCH) + + class TestTemplateJobStartTrigger: def test_serialize(self, template_job_start_trigger): actual_data = template_job_start_trigger.serialize() @@ -548,13 +567,11 @@ async def test_run_loop_is_still_running_if_fail_on_terminal_state( mock_get_job_metrics, mock_job_status, dataflow_job_metrics_trigger, - caplog, ): """Test that DataflowJobMetricsTrigger is still in loop if the job status is RUNNING.""" dataflow_job_metrics_trigger.fail_on_terminal_state = True mock_job_status.return_value = JobState.JOB_STATE_RUNNING mock_get_job_metrics.return_value = [] - caplog.set_level(logging.INFO) task = asyncio.create_task(dataflow_job_metrics_trigger.run().__anext__()) await asyncio.sleep(0.5) assert task.done() is False @@ -703,12 
+720,10 @@ async def test_run_loop_is_still_running_if_state_is_not_terminal_or_expected( self, mock_job_status, dataflow_job_status_trigger, - caplog, ): """Test that DataflowJobStatusTrigger is still in loop if the job status neither terminal nor expected.""" dataflow_job_status_trigger.expected_statuses = {DataflowJobStatus.JOB_STATE_DONE} mock_job_status.return_value = JobState.JOB_STATE_RUNNING - caplog.set_level(logging.INFO) task = asyncio.create_task(dataflow_job_status_trigger.run().__anext__()) await asyncio.sleep(0.5) assert task.done() is False @@ -729,3 +744,119 @@ async def test_run_raises_exception(self, mock_job_status, dataflow_job_status_t ) actual_event = await dataflow_job_status_trigger.run().asend(None) assert expected_event == actual_event + + +class TestDataflowStartYamlJobTrigger: + def test_serialize(self, dataflow_start_yaml_job_trigger): + actual_data = dataflow_start_yaml_job_trigger.serialize() + expected_data = ( + "airflow.providers.google.cloud.triggers.dataflow.DataflowStartYamlJobTrigger", + { + "project_id": PROJECT_ID, + "job_id": JOB_ID, + "location": LOCATION, + "gcp_conn_id": GCP_CONN_ID, + "poll_sleep": POLL_SLEEP, + "expected_terminal_state": None, + "impersonation_chain": IMPERSONATION_CHAIN, + "cancel_timeout": CANCEL_TIMEOUT, + }, + ) + assert actual_data == expected_data + + @pytest.mark.parametrize( + "attr, expected", + [ + ("gcp_conn_id", GCP_CONN_ID), + ("poll_sleep", POLL_SLEEP), + ("impersonation_chain", IMPERSONATION_CHAIN), + ("cancel_timeout", CANCEL_TIMEOUT), + ], + ) + def test_get_async_hook(self, dataflow_start_yaml_job_trigger, attr, expected): + hook = dataflow_start_yaml_job_trigger._get_async_hook() + actual = hook._hook_kwargs.get(attr) + assert actual is not None + assert actual == expected + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job") + async def test_run_loop_return_success_event( + self, mock_get_job, dataflow_start_yaml_job_trigger, test_dataflow_batch_job + ): + mock_get_job.return_value = test_dataflow_batch_job + expected_event = TriggerEvent( + { + "job": Job.to_dict(test_dataflow_batch_job), + "status": "success", + "message": "Batch job completed.", + } + ) + actual_event = await dataflow_start_yaml_job_trigger.run().asend(None) + assert actual_event == expected_event + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job") + async def test_run_loop_return_failed_event( + self, mock_get_job, dataflow_start_yaml_job_trigger, test_dataflow_batch_job + ): + test_dataflow_batch_job.current_state = JobState.JOB_STATE_FAILED + mock_get_job.return_value = test_dataflow_batch_job + expected_event = TriggerEvent( + { + "job": Job.to_dict(test_dataflow_batch_job), + "status": "error", + "message": "Job failed.", + } + ) + actual_event = await dataflow_start_yaml_job_trigger.run().asend(None) + assert actual_event == expected_event + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job") + async def test_run_loop_return_stopped_event( + self, mock_get_job, dataflow_start_yaml_job_trigger, test_dataflow_batch_job + ): + test_dataflow_batch_job.current_state = JobState.JOB_STATE_STOPPED + mock_get_job.return_value = test_dataflow_batch_job + expected_event = TriggerEvent( + { + "job": Job.to_dict(test_dataflow_batch_job), + "status": "stopped", + "message": "Job was stopped.", + } + ) + actual_event = await 
dataflow_start_yaml_job_trigger.run().asend(None) + assert actual_event == expected_event + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job") + async def test_run_loop_return_expected_state_event( + self, mock_get_job, dataflow_start_yaml_job_trigger, test_dataflow_batch_job + ): + dataflow_start_yaml_job_trigger.expected_terminal_state = DataflowJobStatus.JOB_STATE_RUNNING + test_dataflow_batch_job.current_state = JobState.JOB_STATE_RUNNING + mock_get_job.return_value = test_dataflow_batch_job + expected_event = TriggerEvent( + { + "job": Job.to_dict(test_dataflow_batch_job), + "status": "success", + "message": f"Job reached the expected terminal state: {DataflowJobStatus.JOB_STATE_RUNNING}.", + } + ) + actual_event = await dataflow_start_yaml_job_trigger.run().asend(None) + assert actual_event == expected_event + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.dataflow.AsyncDataflowHook.get_job") + async def test_run_loop_is_still_running( + self, mock_get_job, dataflow_start_yaml_job_trigger, test_dataflow_batch_job + ): + """Test that DataflowStartYamlJobTrigger is still in loop if the job status neither terminal nor expected.""" + dataflow_start_yaml_job_trigger.expected_terminal_state = DataflowJobStatus.JOB_STATE_STOPPED + test_dataflow_batch_job.current_state = JobState.JOB_STATE_RUNNING + mock_get_job.return_value = test_dataflow_batch_job + task = asyncio.create_task(dataflow_start_yaml_job_trigger.run().__anext__()) + await asyncio.sleep(0.5) + assert task.done() is False + task.cancel() diff --git a/tests/system/providers/google/cloud/dataflow/example_dataflow_yaml.py b/tests/system/providers/google/cloud/dataflow/example_dataflow_yaml.py new file mode 100644 index 00000000000000..2bddc1de3fa466 --- /dev/null +++ b/tests/system/providers/google/cloud/dataflow/example_dataflow_yaml.py @@ -0,0 +1,172 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example Airflow DAG for Google Cloud Dataflow YAML service. 
+ +Requirements: + This test requires ``gcloud`` command (Google Cloud SDK) to be installed on the Airflow worker + `__ +""" + +from __future__ import annotations + +import os +from datetime import datetime + +from airflow.models.dag import DAG +from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus +from airflow.providers.google.cloud.operators.bigquery import ( + BigQueryCreateEmptyDatasetOperator, + BigQueryCreateEmptyTableOperator, + BigQueryDeleteDatasetOperator, + BigQueryInsertJobOperator, +) +from airflow.providers.google.cloud.operators.dataflow import DataflowStartYamlJobOperator +from airflow.utils.trigger_rule import TriggerRule +from tests.system.providers.google import DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID + +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") or DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default") +DAG_ID = "dataflow_yaml" +REGION = "europe-west2" +DATAFLOW_YAML_JOB_NAME = f"{DAG_ID}_{ENV_ID}".replace("_", "-") +BQ_DATASET = f"{DAG_ID}_{ENV_ID}".replace("-", "_") +BQ_INPUT_TABLE = f"input_{DAG_ID}".replace("-", "_") +BQ_OUTPUT_TABLE = f"output_{DAG_ID}".replace("-", "_") +DATAFLOW_YAML_PIPELINE_FILE_URL = ( + "gs://airflow-system-tests-resources/dataflow/yaml/example_beam_yaml_bq.yaml" +) + +BQ_VARIABLES = { + "project": PROJECT_ID, + "dataset": BQ_DATASET, + "input": BQ_INPUT_TABLE, + "output": BQ_OUTPUT_TABLE, +} + +BQ_VARIABLES_DEF = { + "project": PROJECT_ID, + "dataset": BQ_DATASET, + "input": BQ_INPUT_TABLE, + "output": f"{BQ_OUTPUT_TABLE}_def", +} + +INSERT_ROWS_QUERY = ( + f"INSERT {BQ_DATASET}.{BQ_INPUT_TABLE} VALUES " + "('John Doe', 900, 'USA'), " + "('Alice Storm', 1200, 'Australia')," + "('Bob Max', 1000, 'Australia')," + "('Peter Jackson', 800, 'New Zealand')," + "('Hobby Doyle', 1100, 'USA')," + "('Terrance Phillips', 2222, 'Canada')," + "('Joe Schmoe', 1500, 'Canada')," + "('Dominique Levillaine', 2780, 'France');" +) + + +with DAG( + dag_id=DAG_ID, + start_date=datetime(2021, 1, 1), + catchup=False, + tags=["example", "dataflow", "yaml"], +) as dag: + create_bq_dataset = BigQueryCreateEmptyDatasetOperator( + task_id="create_bq_dataset", + dataset_id=BQ_DATASET, + location=REGION, + ) + + create_bq_input_table = BigQueryCreateEmptyTableOperator( + task_id="create_bq_input_table", + dataset_id=BQ_DATASET, + table_id=BQ_INPUT_TABLE, + schema_fields=[ + {"name": "emp_name", "type": "STRING", "mode": "REQUIRED"}, + {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"}, + {"name": "country", "type": "STRING", "mode": "NULLABLE"}, + ], + ) + + insert_data_into_bq_table = BigQueryInsertJobOperator( + task_id="insert_data_into_bq_table", + configuration={ + "query": { + "query": INSERT_ROWS_QUERY, + "useLegacySql": False, + "priority": "BATCH", + } + }, + location=REGION, + ) + + # [START howto_operator_dataflow_start_yaml_job] + start_dataflow_yaml_job = DataflowStartYamlJobOperator( + task_id="start_dataflow_yaml_job", + job_name=DATAFLOW_YAML_JOB_NAME, + yaml_pipeline_file=DATAFLOW_YAML_PIPELINE_FILE_URL, + append_job_name=True, + deferrable=False, + region=REGION, + project_id=PROJECT_ID, + jinja_variables=BQ_VARIABLES, + ) + # [END howto_operator_dataflow_start_yaml_job] + + # [START howto_operator_dataflow_start_yaml_job_def] + start_dataflow_yaml_job_def = DataflowStartYamlJobOperator( + task_id="start_dataflow_yaml_job_def", + job_name=DATAFLOW_YAML_JOB_NAME, + yaml_pipeline_file=DATAFLOW_YAML_PIPELINE_FILE_URL, + append_job_name=True, + deferrable=True, + region=REGION, + 
project_id=PROJECT_ID, + jinja_variables=BQ_VARIABLES_DEF, + expected_terminal_state=DataflowJobStatus.JOB_STATE_DONE, + ) + # [END howto_operator_dataflow_start_yaml_job_def] + + delete_bq_dataset = BigQueryDeleteDatasetOperator( + task_id="delete_bq_dataset", + dataset_id=BQ_DATASET, + delete_contents=True, + trigger_rule=TriggerRule.ALL_DONE, + ) + + ( + # TEST SETUP + create_bq_dataset + >> create_bq_input_table + >> insert_data_into_bq_table + # TEST BODY + >> [start_dataflow_yaml_job, start_dataflow_yaml_job_def] + # TEST TEARDOWN + >> delete_bq_dataset + ) + + from tests.system.utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + +from tests.system.utils import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag)
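As a complement to the system test above, here is a minimal, hypothetical DAG showing the operator parameters the test does not exercise (extra gcloud flags via ``options`` and an explicit ``expected_terminal_state``). The bucket, pipeline file, and project ID are placeholders, not part of the change itself.

# Hypothetical minimal DAG; not a definitive usage pattern.
from __future__ import annotations

from datetime import datetime

from airflow.models.dag import DAG
from airflow.providers.google.cloud.operators.dataflow import DataflowStartYamlJobOperator

with DAG(
    dag_id="dataflow_yaml_minimal",
    start_date=datetime(2024, 1, 1),
    catchup=False,
):
    DataflowStartYamlJobOperator(
        task_id="run_yaml_pipeline",
        job_name="yaml-demo",                                          # placeholder
        yaml_pipeline_file="gs://my-bucket/pipelines/streaming.yaml",  # placeholder
        region="us-central1",
        project_id="my-gcp-project",                                   # placeholder
        deferrable=True,
        # Extra flags are forwarded to `gcloud dataflow yaml run`; flags without a value take True or None.
        options={"log-http": True},
        # For streaming pipelines the job is treated as successful once it is RUNNING;
        # this spells out that default explicitly.
        expected_terminal_state="JOB_STATE_RUNNING",
    )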