Improve Apache Beam operators - refactor operator - common Dataflow logic #14094

Merged Feb 5, 2021 (2 commits)
135 changes: 74 additions & 61 deletions airflow/providers/apache/beam/operators/beam.py
@@ -16,8 +16,10 @@
 # specific language governing permissions and limitations
 # under the License.
 """This module contains Apache Beam operators."""
+import copy
+from abc import ABCMeta
 from contextlib import ExitStack
-from typing import Callable, List, Optional, Union
+from typing import Callable, List, Optional, Tuple, Union

 from airflow.models import BaseOperator
 from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType
@@ -32,7 +34,68 @@
 from airflow.version import version


-class BeamRunPythonPipelineOperator(BaseOperator):
+class BeamDataflowMixin(metaclass=ABCMeta):
+    """
+    Helper class to store common, Dataflow specific logic for both
+    :class:`~airflow.providers.apache.beam.operators.beam.BeamRunPythonPipelineOperator` and
+    :class:`~airflow.providers.apache.beam.operators.beam.BeamRunJavaPipelineOperator`.
+    """
+
+    dataflow_hook: Optional[DataflowHook]
+    dataflow_config: Optional[DataflowConfiguration]
+
+    def _set_dataflow(
+        self, pipeline_options: dict, job_name_variable_key: Optional[str] = None
+    ) -> Tuple[str, dict, Callable[[str], None]]:
+        self.dataflow_hook = self.__set_dataflow_hook()
+        self.dataflow_config.project_id = self.dataflow_config.project_id or self.dataflow_hook.project_id
+        dataflow_job_name = self.__get_dataflow_job_name()
+        pipeline_options = self.__get_dataflow_pipeline_options(
+            pipeline_options, dataflow_job_name, job_name_variable_key
+        )
+        process_line_callback = self.__get_dataflow_process_callback()
+        return dataflow_job_name, pipeline_options, process_line_callback
+
+    def __set_dataflow_hook(self) -> DataflowHook:
+        self.dataflow_hook = DataflowHook(
+            gcp_conn_id=self.dataflow_config.gcp_conn_id or self.gcp_conn_id,
+            delegate_to=self.dataflow_config.delegate_to or self.delegate_to,
+            poll_sleep=self.dataflow_config.poll_sleep,
+            impersonation_chain=self.dataflow_config.impersonation_chain,
+            drain_pipeline=self.dataflow_config.drain_pipeline,
+            cancel_timeout=self.dataflow_config.cancel_timeout,
+            wait_until_finished=self.dataflow_config.wait_until_finished,
+        )
+        return self.dataflow_hook
+
+    def __get_dataflow_job_name(self) -> str:
+        return DataflowHook.build_dataflow_job_name(
+            self.dataflow_config.job_name, self.dataflow_config.append_job_name
+        )
+
+    def __get_dataflow_pipeline_options(
+        self, pipeline_options: dict, job_name: str, job_name_key: Optional[str] = None
+    ) -> dict:
+        pipeline_options = copy.deepcopy(pipeline_options)
+        if job_name_key is not None:
+            pipeline_options[job_name_key] = job_name
+        pipeline_options["project"] = self.dataflow_config.project_id
+        pipeline_options["region"] = self.dataflow_config.location
+        pipeline_options.setdefault("labels", {}).update(
+            {"airflow-version": "v" + version.replace(".", "-").replace("+", "-")}
+        )
+        return pipeline_options
+
+    def __get_dataflow_process_callback(self) -> Callable[[str], None]:
+        def set_current_dataflow_job_id(job_id):
+            self.dataflow_job_id = job_id
+
+        return process_line_and_extract_dataflow_job_id_callback(
+            on_new_job_id_callback=set_current_dataflow_job_id
+        )
+
+
+class BeamRunPythonPipelineOperator(BaseOperator, BeamDataflowMixin):
     """
     Launching Apache Beam pipelines written in Python. Note that both
     ``default_pipeline_options`` and ``pipeline_options`` will be merged to specify pipeline
@@ -56,8 +119,6 @@ class BeamRunPythonPipelineOperator(BaseOperator):
         See: :class:`~providers.apache.beam.hooks.beam.BeamRunnerType`
         See: https://beam.apache.org/documentation/runners/capability-matrix/

-        If you use Dataflow runner check dedicated operator:
-        :class:`~providers.google.cloud.operators.dataflow.DataflowCreatePythonJobOperator`
     :type runner: str
     :param py_options: Additional python options, e.g., ["-m", "-v"].
     :type py_options: list[str]
@@ -155,37 +216,14 @@ def execute(self, context):
         pipeline_options = self.default_pipeline_options.copy()
         process_line_callback: Optional[Callable] = None
         is_dataflow = self.runner.lower() == BeamRunnerType.DataflowRunner.lower()
+        dataflow_job_name: Optional[str] = None

         if isinstance(self.dataflow_config, dict):
             self.dataflow_config = DataflowConfiguration(**self.dataflow_config)

         if is_dataflow:
-            self.dataflow_hook = DataflowHook(
-                gcp_conn_id=self.dataflow_config.gcp_conn_id or self.gcp_conn_id,
-                delegate_to=self.dataflow_config.delegate_to or self.delegate_to,
-                poll_sleep=self.dataflow_config.poll_sleep,
-                impersonation_chain=self.dataflow_config.impersonation_chain,
-                drain_pipeline=self.dataflow_config.drain_pipeline,
-                cancel_timeout=self.dataflow_config.cancel_timeout,
-                wait_until_finished=self.dataflow_config.wait_until_finished,
-            )
-            self.dataflow_config.project_id = self.dataflow_config.project_id or self.dataflow_hook.project_id
-
-            dataflow_job_name = DataflowHook.build_dataflow_job_name(
-                self.dataflow_config.job_name, self.dataflow_config.append_job_name
-            )
-            pipeline_options["job_name"] = dataflow_job_name
-            pipeline_options["project"] = self.dataflow_config.project_id
-            pipeline_options["region"] = self.dataflow_config.location
-            pipeline_options.setdefault("labels", {}).update(
-                {"airflow-version": "v" + version.replace(".", "-").replace("+", "-")}
-            )
-
-            def set_current_dataflow_job_id(job_id):
-                self.dataflow_job_id = job_id
-
-            process_line_callback = process_line_and_extract_dataflow_job_id_callback(
-                on_new_job_id_callback=set_current_dataflow_job_id
+            dataflow_job_name, pipeline_options, process_line_callback = self._set_dataflow(
+                pipeline_options=pipeline_options, job_name_variable_key="job_name"
             )

         pipeline_options.update(self.pipeline_options)
@@ -233,7 +271,7 @@ def on_kill(self) -> None:


 # pylint: disable=too-many-instance-attributes
-class BeamRunJavaPipelineOperator(BaseOperator):
+class BeamRunJavaPipelineOperator(BaseOperator, BeamDataflowMixin):
     """
     Launching Apache Beam pipelines written in Java.

@@ -261,8 +299,6 @@ class BeamRunJavaPipelineOperator(BaseOperator):
    :param runner: Runner on which pipeline will be run. By default "DirectRunner" is being used.
        See:
        https://beam.apache.org/documentation/runners/capability-matrix/
-        If you use Dataflow runner check dedicated operator:
-        :class:`~providers.google.cloud.operators.dataflow.DataflowCreateJavaJobOperator`
     :type runner: str
     :param job_class: The name of the Apache Beam pipeline class to be executed, it
         is often not the main class configured in the pipeline jar file.
@@ -343,37 +379,14 @@ def execute(self, context):
         pipeline_options = self.default_pipeline_options.copy()
         process_line_callback: Optional[Callable] = None
         is_dataflow = self.runner.lower() == BeamRunnerType.DataflowRunner.lower()
+        dataflow_job_name: Optional[str] = None

         if isinstance(self.dataflow_config, dict):
             self.dataflow_config = DataflowConfiguration(**self.dataflow_config)

         if is_dataflow:
-            self.dataflow_hook = DataflowHook(
-                gcp_conn_id=self.dataflow_config.gcp_conn_id or self.gcp_conn_id,
-                delegate_to=self.dataflow_config.delegate_to or self.delegate_to,
-                poll_sleep=self.dataflow_config.poll_sleep,
-                impersonation_chain=self.dataflow_config.impersonation_chain,
-                drain_pipeline=self.dataflow_config.drain_pipeline,
-                cancel_timeout=self.dataflow_config.cancel_timeout,
-                wait_until_finished=self.dataflow_config.wait_until_finished,
-            )
-            self.dataflow_config.project_id = self.dataflow_config.project_id or self.dataflow_hook.project_id
-
-            self._dataflow_job_name = DataflowHook.build_dataflow_job_name(
-                self.dataflow_config.job_name, self.dataflow_config.append_job_name
-            )
-            pipeline_options["jobName"] = self.dataflow_config.job_name
-            pipeline_options["project"] = self.dataflow_config.project_id
-            pipeline_options["region"] = self.dataflow_config.location
-            pipeline_options.setdefault("labels", {}).update(
-                {"airflow-version": "v" + version.replace(".", "-").replace("+", "-")}
-            )
-
-            def set_current_dataflow_job_id(job_id):
-                self.dataflow_job_id = job_id
-
-            process_line_callback = process_line_and_extract_dataflow_job_id_callback(
-                on_new_job_id_callback=set_current_dataflow_job_id
+            dataflow_job_name, pipeline_options, process_line_callback = self._set_dataflow(
+                pipeline_options=pipeline_options, job_name_variable_key=None
             )

         pipeline_options.update(self.pipeline_options)
@@ -412,15 +425,15 @@ def set_current_dataflow_job_id(job_id):
                 variables=pipeline_options,
             )
             if not is_running:
-                pipeline_options["jobName"] = self._dataflow_job_name
+                pipeline_options["jobName"] = dataflow_job_name
                 self.beam_hook.start_java_pipeline(
                     variables=pipeline_options,
                     jar=self.jar,
                     job_class=self.job_class,
                     process_line_callback=process_line_callback,
                 )
                 self.dataflow_hook.wait_for_done(
-                    job_name=self._dataflow_job_name,
+                    job_name=dataflow_job_name,
                     location=self.dataflow_config.location,
                     job_id=self.dataflow_job_id,
                     multiple_jobs=self.dataflow_config.multiple_jobs,
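Taken together, the refactor replaces two near-identical blocks of Dataflow setup with one call: execute() asks _set_dataflow() for the built job name, the enriched pipeline options, and a callback that captures the Dataflow job id from process output. A minimal, self-contained sketch of that pattern follows; the class and value names are illustrative stand-ins, not the provider's real classes.

import copy
from abc import ABCMeta
from typing import Callable, Optional, Tuple


class DataflowSetupMixin(metaclass=ABCMeta):
    """Simplified stand-in for BeamDataflowMixin: shared Dataflow setup steps."""

    job_name: str
    project_id: str
    region: str
    dataflow_job_id: Optional[str] = None

    def _set_dataflow(
        self, pipeline_options: dict, job_name_variable_key: Optional[str] = None
    ) -> Tuple[str, dict, Callable[[str], None]]:
        # Deep-copy so the caller's default options are never mutated in place,
        # mirroring __get_dataflow_pipeline_options in the diff above.
        pipeline_options = copy.deepcopy(pipeline_options)
        if job_name_variable_key is not None:
            # The Python operator passes "job_name"; the Java operator passes
            # None and sets "jobName" itself only when no job is running yet.
            pipeline_options[job_name_variable_key] = self.job_name
        pipeline_options["project"] = self.project_id
        pipeline_options["region"] = self.region

        def set_current_dataflow_job_id(job_id: str) -> None:
            # In the real code this is wrapped by
            # process_line_and_extract_dataflow_job_id_callback.
            self.dataflow_job_id = job_id

        return self.job_name, pipeline_options, set_current_dataflow_job_id


class DemoPipelineOperator(DataflowSetupMixin):
    """Toy consumer playing the role of BeamRunPythonPipelineOperator."""

    def __init__(self) -> None:
        self.job_name = "example-beam-job"
        self.project_id = "example-project"
        self.region = "us-central1"

    def execute(self) -> dict:
        job_name, options, callback = self._set_dataflow(
            pipeline_options={"labels": {}}, job_name_variable_key="job_name"
        )
        callback("2021-02-05_00_00_00-1234567890")  # simulate job-id extraction
        assert self.dataflow_job_id is not None
        return options


if __name__ == "__main__":
    print(DemoPipelineOperator().execute())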
1 change: 0 additions & 1 deletion tests/providers/apache/beam/operators/test_beam.py
@@ -215,7 +215,6 @@ def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock):
         dataflow_hook_mock.return_value.is_job_dataflow_running.return_value = False
         self.operator.execute(None)
         job_name = dataflow_hook_mock.build_dataflow_job_name.return_value
-        self.assertEqual(job_name, self.operator._dataflow_job_name)
         dataflow_hook_mock.assert_called_once_with(
             gcp_conn_id=dataflow_config.gcp_conn_id,
             delegate_to=dataflow_config.delegate_to,
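For reference, the refactor does not change how the operators are invoked from a DAG. A hedged usage sketch, assuming the provider's documented parameters; the GCS paths, project, and region values below are placeholders.

from airflow.providers.apache.beam.operators.beam import BeamRunPythonPipelineOperator
from airflow.providers.google.cloud.operators.dataflow import DataflowConfiguration

# dataflow_config may also be passed as a plain dict, which execute()
# converts to a DataflowConfiguration, as shown in the diff above.
start_python_job = BeamRunPythonPipelineOperator(
    task_id="start_python_job",
    runner="DataflowRunner",
    py_file="gs://my-bucket/wordcount.py",
    pipeline_options={"output": "gs://my-bucket/output"},
    dataflow_config=DataflowConfiguration(
        job_name="start-python-job",
        project_id="my-gcp-project",
        location="us-central1",
    ),
)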