diff --git a/airflow/contrib/operators/qubole_operator.py b/airflow/contrib/operators/qubole_operator.py
index 2fab2e6b1f1c78..b0810ea4cb77e9 100644
--- a/airflow/contrib/operators/qubole_operator.py
+++ b/airflow/contrib/operators/qubole_operator.py
@@ -18,7 +18,7 @@
 # under the License.
 """Qubole operator"""
 import re
-from typing import Iterable
+from typing import FrozenSet, Iterable, Optional
 
 from airflow.contrib.hooks.qubole_hook import (
     COMMAND_ARGS, HYPHEN_ARGS, POSITIONAL_ARGS, QuboleHook, flatten_list,
@@ -42,7 +42,8 @@ def get_link(self, operator, dttm):
         :return: url link
         """
         ti = TaskInstance(task=operator, execution_date=dttm)
-        conn = BaseHook.get_connection(operator.kwargs['qubole_conn_id'])
+        conn = BaseHook.get_connection(
+            getattr(operator, "qubole_conn_id", None) or operator.kwargs['qubole_conn_id'])
         if conn and conn.host:
             host = re.sub(r'api$', 'v2/analyze?command_id=', conn.host)
         else:
@@ -181,6 +182,9 @@ class QuboleOperator(BaseOperator):
         QDSLink(),
     )
 
+    # The _serialized_fields are lazily loaded when get_serialized_fields() method is called
+    __serialized_fields: Optional[FrozenSet[str]] = None
+
     @apply_defaults
     def __init__(self, qubole_conn_id="qubole_default", *args, **kwargs):
         self.args = args
@@ -240,3 +244,10 @@ def __setattr__(self, name, value):
             self.kwargs[name] = value
         else:
             object.__setattr__(self, name, value)
+
+    @classmethod
+    def get_serialized_fields(cls):
+        """Serialized QuboleOperator contains exactly these fields."""
+        if not cls.__serialized_fields:
+            cls.__serialized_fields = frozenset(super().get_serialized_fields() | {"qubole_conn_id"})
+        return cls.__serialized_fields
diff --git a/airflow/gcp/operators/bigquery.py b/airflow/gcp/operators/bigquery.py
index f8ab0f04fef9e6..fed0db81b4394f 100644
--- a/airflow/gcp/operators/bigquery.py
+++ b/airflow/gcp/operators/bigquery.py
@@ -24,8 +24,9 @@
 import json
 import warnings
-from typing import Any, Dict, Iterable, List, Optional, SupportsAbs, Union
+from typing import Any, Dict, FrozenSet, Iterable, List, Optional, SupportsAbs, Union
 
+import attr
 from googleapiclient.errors import HttpError
 
 from airflow.exceptions import AirflowException
@@ -337,14 +338,13 @@ def get_link(self, operator, dttm):
         return BIGQUERY_JOB_DETAILS_LINK_FMT.format(job_id=job_id) if job_id else ''
 
 
+@attr.s(auto_attribs=True)
 class BigQueryConsoleIndexableLink(BaseOperatorLink):
     """
     Helper class for constructing BigQuery link.
""" - def __init__(self, index) -> None: - super().__init__() - self.index = index + index: int = attr.ib() @property def name(self) -> str: @@ -459,6 +459,9 @@ class BigQueryOperator(BaseOperator): template_ext = ('.sql', ) ui_color = '#e4f0e8' + # The _serialized_fields are lazily loaded when get_serialized_fields() method is called + __serialized_fields: Optional[FrozenSet[str]] = None + @property def operator_extra_links(self): """ @@ -594,6 +597,13 @@ def on_kill(self): self.log.info('Cancelling running query') self.bq_cursor.cancel_query() + @classmethod + def get_serialized_fields(cls): + """Serialized BigQueryOperator contain exactly these fields.""" + if not cls.__serialized_fields: + cls.__serialized_fields = frozenset(super().get_serialized_fields() | {"sql"}) + return cls.__serialized_fields + class BigQueryCreateEmptyTableOperator(BaseOperator): """ diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 69a708f506e411..b5e92ae81bbb0c 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -26,8 +26,9 @@ import warnings from abc import ABCMeta, abstractmethod from datetime import datetime, timedelta -from typing import Any, Callable, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple, Type, Union +from typing import Any, Callable, ClassVar, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple, Type, Union +import attr import jinja2 from cached_property import cached_property from dateutil.relativedelta import relativedelta @@ -256,7 +257,7 @@ class derived from this one results in the creation of a task object, operator_extra_links: Iterable['BaseOperatorLink'] = () # The _serialized_fields are lazily loaded when get_serialized_fields() method is called - _serialized_fields: Optional[FrozenSet[str]] = None + __serialized_fields: Optional[FrozenSet[str]] = None _comps = { 'task_id', @@ -785,15 +786,15 @@ def prepare_template(self) -> None: def resolve_template_files(self) -> None: """Getting the content of files for template_field / template_ext""" if self.template_ext: # pylint: disable=too-many-nested-blocks - for attr in self.template_fields: - content = getattr(self, attr, None) + for field in self.template_fields: + content = getattr(self, field, None) if content is None: continue elif isinstance(content, str) and \ any([content.endswith(ext) for ext in self.template_ext]): env = self.get_template_env() try: - setattr(self, attr, env.loader.get_source(env, content)[0]) + setattr(self, field, env.loader.get_source(env, content)[0]) except Exception as e: # pylint: disable=broad-except self.log.exception(e) elif isinstance(content, list): @@ -939,10 +940,10 @@ def run( def dry_run(self) -> None: """Performs dry run for the operator - just render template fields.""" self.log.info('Dry run') - for attr in self.template_fields: - content = getattr(self, attr) + for field in self.template_fields: + content = getattr(self, field) if content and isinstance(content, str): - self.log.info('Rendering template for %s', attr) + self.log.info('Rendering template for %s', field) self.log.info(content) def get_direct_relative_ids(self, upstream: bool = False) -> Set[str]: @@ -1101,21 +1102,22 @@ def get_extra_links(self, dttm: datetime, link_name: str) -> Optional[Dict[str, @classmethod def get_serialized_fields(cls): """Stringified DAGs and operators contain exactly these fields.""" - if not cls._serialized_fields: - cls._serialized_fields = frozenset( + if not cls.__serialized_fields: + cls.__serialized_fields = frozenset( 
                 vars(BaseOperator(task_id='test')).keys() - {
                     'inlets', 'outlets', '_upstream_task_ids', 'default_args', 'dag', '_dag'
-                } | {'_task_type', 'subdag', 'ui_color', 'ui_fgcolor', 'template_fields'}
-            )
-        return cls._serialized_fields
+                } | {'_task_type', 'subdag', 'ui_color', 'ui_fgcolor', 'template_fields'})
+
+        return cls.__serialized_fields
 
 
+@attr.s(auto_attribs=True)
 class BaseOperatorLink(metaclass=ABCMeta):
     """
     Abstract base class that defines how we get an operator link.
     """
 
-    operators: List[Type[BaseOperator]] = []
+    operators: ClassVar[List[Type[BaseOperator]]] = []
     """
     This property will be used by Airflow Plugins to find the Operators to which you want
     to assign this Operator Link
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 58a2a529a70ce2..e61b44940a9fa6 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -193,7 +193,7 @@ class DAG(BaseDag, LoggingMixin):
         'last_loaded',
     }
 
-    _serialized_fields: Optional[FrozenSet[str]] = None
+    __serialized_fields: Optional[FrozenSet[str]] = None
 
     def __init__(
         self,
@@ -1508,14 +1508,14 @@ def _test_cycle_helper(self, visit_map, task_id):
     @classmethod
     def get_serialized_fields(cls):
         """Stringified DAGs and operators contain exactly these fields."""
-        if not cls._serialized_fields:
-            cls._serialized_fields = frozenset(vars(DAG(dag_id='test')).keys()) - {
+        if not cls.__serialized_fields:
+            cls.__serialized_fields = frozenset(vars(DAG(dag_id='test')).keys()) - {
                 'parent_dag', '_old_context_manager_dags', 'safe_dag_id', 'last_loaded',
                 '_full_filepath', 'user_defined_filters', 'user_defined_macros',
                 '_schedule_interval', 'partial', '_old_context_manager_dags',
                 '_pickle_id', '_log', 'is_subdag', 'task_dict'
             }
-        return cls._serialized_fields
+        return cls.__serialized_fields
 
 
 class DagModel(Base):
diff --git a/airflow/plugins_manager.py b/airflow/plugins_manager.py
index 3b34bf0ddcdb3e..034c7ea53a74e8 100644
--- a/airflow/plugins_manager.py
+++ b/airflow/plugins_manager.py
@@ -22,7 +22,7 @@
 import os
 import re
 import sys
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Set, Type
 
 import pkg_resources
 
@@ -112,6 +112,33 @@ def load_entrypoint_plugins(entry_points, airflow_plugins):
     return airflow_plugins
 
 
+def register_inbuilt_operator_links() -> None:
+    """
+    Register all the Operator Links that are already defined for the operators
+    in the "airflow" project. Example: QDSLink (Operator Link for Qubole Operator).
+
+    This is required to populate the "whitelist" of allowed classes when deserializing operator links.
+    """
+    inbuilt_operator_links: Set[Type] = set()
+
+    try:
+        from airflow.gcp.operators.bigquery import BigQueryConsoleLink, BigQueryConsoleIndexableLink  # noqa E501 # pylint: disable=R0401,line-too-long
+        inbuilt_operator_links.update([BigQueryConsoleLink, BigQueryConsoleIndexableLink])
+    except ImportError:
+        pass
+
+    try:
+        from airflow.contrib.operators.qubole_operator import QDSLink  # pylint: disable=R0401
+        inbuilt_operator_links.update([QDSLink])
+    except ImportError:
+        pass
+
+    registered_operator_link_classes.update({
+        "{}.{}".format(link.__module__, link.__name__): link
+        for link in inbuilt_operator_links
+    })
+
+
 def is_valid_plugin(plugin_obj, existing_plugins):
     """
     Check whether a potential object is a subclass of
@@ -200,6 +227,12 @@ def make_module(name: str, objects: List[Any]):
 stat_name_handler: Any = None
 global_operator_extra_links: List[Any] = []
 operator_extra_links: List[Any] = []
+registered_operator_link_classes: Dict[str, Type] = {}
+"""Mapping of class names to classes of OperatorLinks registered by plugins.
+
+Used by the DAG serialization code to only allow specific classes to be created
+during deserialization.
+"""
 
 stat_name_handlers = []
 for p in plugins:
@@ -227,10 +260,13 @@ def make_module(name: str, objects: List[Any]):
     if p.stat_name_handler:
         stat_name_handlers.append(p.stat_name_handler)
     global_operator_extra_links.extend(p.global_operator_extra_links)
-    # Only register Operator links if its ``operators`` property is not an empty list
-    # So that we can only attach this links to a specific Operator
-    operator_extra_links.extend([
-        ope for ope in p.operator_extra_links if ope.operators])
+    operator_extra_links.extend([ope for ope in p.operator_extra_links])
+
+    registered_operator_link_classes.update({
+        "{}.{}".format(link.__class__.__module__,
+                       link.__class__.__name__): link.__class__
+        for link in p.operator_extra_links
+    })
 
 if len(stat_name_handlers) > 1:
     raise AirflowPluginException(
@@ -287,3 +323,4 @@ def integrate_plugins() -> None:
     integrate_hook_plugins()
     integrate_executor_plugins()
     integrate_macro_plugins()
+    register_inbuilt_operator_links()
diff --git a/airflow/serialization/schema.json b/airflow/serialization/schema.json
index f2b05d535a63e3..16e14af665e1de 100644
--- a/airflow/serialization/schema.json
+++ b/airflow/serialization/schema.json
@@ -58,6 +58,14 @@
       "type": "string",
       "pattern": "^#[a-fA-F0-9]{3,6}$"
     },
+    "extra_links": {
+      "type": "array",
+      "items": {
+        "type": "object",
+        "minProperties": 1,
+        "maxProperties": 1
+      }
+    },
     "dag": {
       "type": "object",
       "properties": {
@@ -112,6 +120,7 @@
       "properties": {
         "_task_type": { "type": "string" },
         "_task_module": { "type": "string" },
+        "_operator_extra_links": { "$ref": "#/definitions/extra_links" },
         "task_id": { "type": "string" },
         "owner": { "type": "string" },
         "start_date": { "$ref": "#/definitions/datetime" },
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index ce08a8c44bc169..80347570a817b2 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -20,14 +20,15 @@
 import enum
 import logging
 from inspect import Parameter, signature
-from typing import Any, Dict, Optional, Set, Union
+from typing import Any, Dict, Iterable, Optional, Set, Union
 
+import cattr
 import pendulum
 from dateutil import relativedelta
 
 from airflow import DAG, AirflowException, LoggingMixin
 from airflow.models import Connection
-from airflow.models.baseoperator import BaseOperator
+from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
 from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
 from airflow.serialization.json_schema import Validator, load_dag_schema
 from airflow.settings import json
@@ -45,7 +46,6 @@ class BaseSerialization:
     _datetime_types = (datetime.datetime,)
 
     # Object types that are always excluded in serialization.
-    # FIXME: not needed if _included_fields of DAG and operator are customized.
     _excluded_types = (logging.Logger, Connection, type)
 
     _json_schema: Optional[Validator] = None
@@ -299,6 +299,9 @@ def serialize_operator(cls, op: BaseOperator) -> dict:
         serialize_op = cls.serialize_to_json(op, cls._decorated_fields)
         serialize_op['_task_type'] = op.__class__.__name__
         serialize_op['_task_module'] = op.__class__.__module__
+        if op.operator_extra_links:
+            serialize_op['_operator_extra_links'] = \
+                cls._serialize_operator_extra_links(op.operator_extra_links)
         return serialize_op
 
     @classmethod
@@ -309,7 +312,7 @@ def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> BaseOperator:
 
         op = SerializedBaseOperator(task_id=encoded_op['task_id'])
 
-        # Extra Operator Links
+        # Extra Operator Links defined in Plugins
         op_extra_links_from_plugin = {}
 
         for ope in operator_extra_links:
@@ -318,7 +321,12 @@ def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> BaseOperator:
                     operator.__module__ == encoded_op["_task_module"]:
                 op_extra_links_from_plugin.update({ope.name: ope})
 
-        setattr(op, "operator_extra_links", list(op_extra_links_from_plugin.values()))
+        # If OperatorLinks are defined in Plugins but not in the Operator that is being serialized,
+        # set the Operator links attribute.
+        # The case where OperatorLinks are defined in the Operator that is being serialized
+        # is handled in the deserialization loop, where it matches k == "_operator_extra_links"
+        if op_extra_links_from_plugin and "_operator_extra_links" not in encoded_op:
+            setattr(op, "operator_extra_links", list(op_extra_links_from_plugin.values()))
 
         for k, v in encoded_op.items():
 
@@ -330,6 +338,14 @@ def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> BaseOperator:
                 v = cls._deserialize_timedelta(v)
             elif k.endswith("_date"):
                 v = cls._deserialize_datetime(v)
+            elif k == "_operator_extra_links":
+                op_predefined_extra_links = cls._deserialize_operator_extra_links(v)
+
+                # If OperatorLinks with the same name exist, Links via Plugin have higher precedence
+                op_predefined_extra_links.update(op_extra_links_from_plugin)
+
+                v = list(op_predefined_extra_links.values())
+                k = "operator_extra_links"
             elif k in cls._decorated_fields or k not in op.get_serialized_fields():
                 v = cls._deserialize(v)
             # else use v as it is
@@ -354,6 +370,79 @@ def _is_excluded(cls, var: Any, attrname: str, op: BaseOperator):
             return True
         return super()._is_excluded(var, attrname, op)
 
+    @classmethod
+    def _deserialize_operator_extra_links(
+        cls,
+        encoded_op_links: list
+    ) -> Dict[str, BaseOperatorLink]:
+        """
+        Deserialize Operator Links if the Classes are registered in Airflow Plugins.
+        An error is raised if the OperatorLink is not registered in Plugins.
+
+        :param encoded_op_links: Serialized Operator Links
+        :return: Deserialized Operator Links
+        """
+        from airflow.plugins_manager import registered_operator_link_classes
+
+        op_predefined_extra_links = {}
+
+        for _operator_links_source in encoded_op_links:
+            # Get the key, value pair as a tuple, where the key is the OperatorLink class name
+            # and the value is the dictionary containing the arguments passed to the OperatorLink
+            #
+            # Example of a single iteration:
+            #
+            #   _operator_links_source =
+            #   {'airflow.gcp.operators.bigquery.BigQueryConsoleIndexableLink': {'index': 0}},
+            #
+            #   list(_operator_links_source.items()) =
+            #   [('airflow.gcp.operators.bigquery.BigQueryConsoleIndexableLink', {'index': 0})]
+            #
+            #   list(_operator_links_source.items())[0] =
+            #   ('airflow.gcp.operators.bigquery.BigQueryConsoleIndexableLink', {'index': 0})
+
+            _operator_link_class, data = list(_operator_links_source.items())[0]
+
+            if _operator_link_class in registered_operator_link_classes:
+                single_op_link_class = registered_operator_link_classes[_operator_link_class]
+            else:
+                raise KeyError("Operator Link class %r not registered" % _operator_link_class)
+
+            op_predefined_extra_link: BaseOperatorLink = cattr.structure(
+                data, single_op_link_class)
+
+            op_predefined_extra_links.update(
+                {op_predefined_extra_link.name: op_predefined_extra_link}
+            )
+
+        return op_predefined_extra_links
+
+    @classmethod
+    def _serialize_operator_extra_links(
+        cls,
+        operator_extra_links: Iterable[BaseOperatorLink]
+    ):
+        """
+        Serialize Operator Links. Store the import path of the OperatorLink and the arguments
+        passed to it. Example: ``[{'airflow.gcp.operators.bigquery.BigQueryConsoleLink': {}}]``
+
+        :param operator_extra_links: Operator Links
+        :return: Serialized Operator Links
+        """
+        serialize_operator_extra_links = []
+        for operator_extra_link in operator_extra_links:
+            op_link_arguments = cattr.unstructure(operator_extra_link)
+            if not isinstance(op_link_arguments, dict):
+                op_link_arguments = {}
+            serialize_operator_extra_links.append(
+                {
+                    "{}.{}".format(operator_extra_link.__class__.__module__,
+                                   operator_extra_link.__class__.__name__): op_link_arguments
+                }
+            )
+
+        return serialize_operator_extra_links
+
 
 class SerializedDAG(DAG, BaseSerialization):
     """
diff --git a/airflow/utils/tests.py b/airflow/utils/tests.py
index 01bfc8c5bb5c35..939040654b40dc 100644
--- a/airflow/utils/tests.py
+++ b/airflow/utils/tests.py
@@ -19,7 +19,11 @@
 
 import re
 import unittest
+from typing import FrozenSet, Optional
 
+import attr
+
+from airflow.models import TaskInstance
 from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
 from airflow.utils.decorators import apply_defaults
 
@@ -71,15 +75,70 @@ class Dummy3TestOperator(BaseOperator):
     operator_extra_links = ()
 
 
-class CustomBaseOperator(BaseOperator):
-    operator_extra_links = ()
+@attr.s(auto_attribs=True)
+class CustomBaseIndexOpLink(BaseOperatorLink):
+    index: int = attr.ib()
+
+    @property
+    def name(self) -> str:
+        return 'BigQuery Console #{index}'.format(index=self.index + 1)
+
+    def get_link(self, operator, dttm):
+        ti = TaskInstance(task=operator, execution_date=dttm)
+        search_queries = ti.xcom_pull(task_ids=operator.task_id, key='search_query')
+        if not search_queries:
+            return None
+        if len(search_queries) <= self.index:
+            return None
+        search_query = search_queries[self.index]
+        return 'https://console.cloud.google.com/bigquery?j={}'.format(search_query)
+
+
+class CustomOpLink(BaseOperatorLink):
+    """
+    Operator Link for Apache Airflow Website
+    """
+    name = 'Google Custom'
+
+    def get_link(self, operator, dttm):
+        ti = TaskInstance(task=operator, execution_date=dttm)
+        search_query = ti.xcom_pull(task_ids=operator.task_id, key='search_query')
+        return 'http://google.com/custom_base_link?search={}'.format(search_query)
+
+
+class CustomOperator(BaseOperator):
+
+    # The _serialized_fields are lazily loaded when get_serialized_fields() method is called
+    __serialized_fields: Optional[FrozenSet[str]] = None
+
+    @property
+    def operator_extra_links(self):
+        """
+        Return operator extra links
+        """
+        if isinstance(self.bash_command, str) or self.bash_command is None:
+            return (
+                CustomOpLink(),
+            )
+        return (
+            CustomBaseIndexOpLink(i) for i, _ in enumerate(self.bash_command)
+        )
 
     @apply_defaults
-    def __init__(self, *args, **kwargs):
-        super(CustomBaseOperator, self).__init__(*args, **kwargs)
+    def __init__(self, bash_command=None, *args, **kwargs):
+        super(CustomOperator, self).__init__(*args, **kwargs)
+        self.bash_command = bash_command
 
     def execute(self, context):
         self.log.info("Hello World!")
+        context['task_instance'].xcom_push(key='search_query', value="dummy_value")
+
+    @classmethod
+    def get_serialized_fields(cls):
+        """Stringified CustomOperator contains exactly these fields."""
+        if not cls.__serialized_fields:
+            cls.__serialized_fields = frozenset(super().get_serialized_fields() | {"bash_command"})
+        return cls.__serialized_fields
 
 
 class GoogleLink(BaseOperatorLink):
@@ -87,7 +146,7 @@ class GoogleLink(BaseOperatorLink):
     Operator Link for Apache Airflow Website for Google
     """
     name = 'google'
-    operators = [Dummy3TestOperator, CustomBaseOperator]
+    operators = [Dummy3TestOperator, CustomOperator]
 
     def get_link(self, operator, dttm):
         return 'https://www.google.com'
diff --git a/docs/dag-serialization.rst b/docs/dag-serialization.rst
index 8495aa0cd8e0bb..fc03ab3da41746 100644
--- a/docs/dag-serialization.rst
+++ b/docs/dag-serialization.rst
@@ -76,9 +76,6 @@ which is why we said "almost" stateless.
   the execution date and even the data passed by the upstream task using Xcom.
 * **Code View** will read the DAG File & show it using Pygments.
   However, it does not need to Parse the Python file so it is still a small operation.
-* :doc:`Extra Operator Links` for the inbuilt Operators would no longer work.
-  However, you can define your own Operator Links via Airflow plugins.
-
 
 Using a different JSON Library
 ------------------------------
diff --git a/tests/contrib/operators/test_qubole_operator.py b/tests/contrib/operators/test_qubole_operator.py
index 92554932b781d8..7a3ac3240f8ba8 100644
--- a/tests/contrib/operators/test_qubole_operator.py
+++ b/tests/contrib/operators/test_qubole_operator.py
@@ -22,9 +22,10 @@
 
 from airflow import settings
 from airflow.contrib.hooks.qubole_hook import QuboleHook
-from airflow.contrib.operators.qubole_operator import QuboleOperator
+from airflow.contrib.operators.qubole_operator import QDSLink, QuboleOperator
 from airflow.models import DAG, Connection
 from airflow.models.taskinstance import TaskInstance
+from airflow.serialization.serialized_objects import SerializedDAG
 from airflow.utils import db
 from airflow.utils.timezone import datetime
 
@@ -141,3 +142,35 @@ def test_get_redirect_url(self):
         # check for negative case
         url2 = task.get_extra_links(datetime(2017, 1, 2), 'Go to QDS')
         self.assertEqual(url2, '')
+
+    def test_extra_serialized_field(self):
+        dag = DAG(DAG_ID, start_date=DEFAULT_DATE)
+        with dag:
+            QuboleOperator(
+                task_id=TASK_ID,
+                command_type='shellcmd',
+                qubole_conn_id=TEST_CONN,
+            )
+
+        serialized_dag = SerializedDAG.to_dict(dag)
+        self.assertIn("qubole_conn_id", serialized_dag["dag"]["tasks"][0])
+
+        dag = SerializedDAG.from_dict(serialized_dag)
+        simple_task = dag.task_dict[TASK_ID]
+        self.assertEqual(getattr(simple_task, "qubole_conn_id"), TEST_CONN)
+
+        #########################################################
+        # Verify Operator Links work with Serialized Operator
+        #########################################################
+        self.assertIsInstance(list(simple_task.operator_extra_links)[0], QDSLink)
+
+        ti = TaskInstance(task=simple_task, execution_date=DEFAULT_DATE)
+        ti.xcom_push('qbol_cmd_id', 12345)
+
+        # check for positive case
+        url = simple_task.get_extra_links(DEFAULT_DATE, 'Go to QDS')
+        self.assertEqual(url, 'http://localhost/v2/analyze?command_id=12345')
+
+        # check for negative case
+        url2 = simple_task.get_extra_links(datetime(2017, 1, 2), 'Go to QDS')
+        self.assertEqual(url2, '')
diff --git a/tests/gcp/operators/test_bigquery.py b/tests/gcp/operators/test_bigquery.py
index 0a7ebfccd21f5f..d3fce38dadcc6c 100644
--- a/tests/gcp/operators/test_bigquery.py
+++ b/tests/gcp/operators/test_bigquery.py
@@ -24,12 +24,13 @@
 from airflow import models
 from airflow.exceptions import AirflowException
 from airflow.gcp.operators.bigquery import (
-    BigQueryConsoleLink, BigQueryCreateEmptyDatasetOperator, BigQueryCreateEmptyTableOperator,
-    BigQueryCreateExternalTableOperator, BigQueryDeleteDatasetOperator, BigQueryGetDataOperator,
-    BigQueryGetDatasetOperator, BigQueryGetDatasetTablesOperator, BigQueryOperator,
+    BigQueryConsoleIndexableLink, BigQueryConsoleLink, BigQueryCreateEmptyDatasetOperator,
+    BigQueryCreateEmptyTableOperator, BigQueryCreateExternalTableOperator, BigQueryDeleteDatasetOperator,
+    BigQueryGetDataOperator, BigQueryGetDatasetOperator, BigQueryGetDatasetTablesOperator, BigQueryOperator,
     BigQueryPatchDatasetOperator, BigQueryTableDeleteOperator, BigQueryUpdateDatasetOperator,
 )
 from airflow.models import DAG, TaskFail, TaskInstance, XCom
+from airflow.serialization.serialized_objects import SerializedDAG
 from airflow.settings import Session
 from airflow.utils.db import provide_session
 from tests.compat import mock
@@ -432,6 +433,92 @@ def test_bigquery_operator_defaults(self, mock_hook):
         ti.render_templates()
         self.assertTrue(isinstance(ti.task.sql, str))
 
+    def test_bigquery_operator_extra_serialized_field_when_single_query(self):
+        with self.dag:
+            BigQueryOperator(
+                task_id=TASK_ID,
+                sql='SELECT * FROM test_table',
+            )
+        serialized_dag = SerializedDAG.to_dict(self.dag)
+        self.assertIn("sql", serialized_dag["dag"]["tasks"][0])
+
+        dag = SerializedDAG.from_dict(serialized_dag)
+        simple_task = dag.task_dict[TASK_ID]
+        self.assertEqual(getattr(simple_task, "sql"), 'SELECT * FROM test_table')
+
+        #########################################################
+        # Verify Operator Links work with Serialized Operator
+        #########################################################
+
+        # Check Serialized version of operator link
+        self.assertEqual(
+            serialized_dag["dag"]["tasks"][0]["_operator_extra_links"],
+            [{'airflow.gcp.operators.bigquery.BigQueryConsoleLink': {}}]
+        )
+
+        # Check DeSerialized version of operator link
+        self.assertIsInstance(list(simple_task.operator_extra_links)[0], BigQueryConsoleLink)
+
+        ti = TaskInstance(task=simple_task, execution_date=DEFAULT_DATE)
+        ti.xcom_push('job_id', 12345)
+
+        # check for positive case
+        url = simple_task.get_extra_links(DEFAULT_DATE, BigQueryConsoleLink.name)
+        self.assertEqual(url, 'https://console.cloud.google.com/bigquery?j=12345')
+
+        # check for negative case
+        url2 = simple_task.get_extra_links(datetime(2017, 1, 2), BigQueryConsoleLink.name)
+        self.assertEqual(url2, '')
+
+    def test_bigquery_operator_extra_serialized_field_when_multiple_queries(self):
+        with self.dag:
+            BigQueryOperator(
+                task_id=TASK_ID,
+                sql=['SELECT * FROM test_table', 'SELECT * FROM test_table2'],
+            )
+        serialized_dag = SerializedDAG.to_dict(self.dag)
+        self.assertIn("sql", serialized_dag["dag"]["tasks"][0])
+
+        dag = SerializedDAG.from_dict(serialized_dag)
+        simple_task = dag.task_dict[TASK_ID]
+        self.assertEqual(getattr(simple_task, "sql"),
+                         ['SELECT * FROM test_table', 'SELECT * FROM test_table2'])
+
+        #########################################################
+        # Verify Operator Links work with Serialized Operator
+        #########################################################
+
+        # Check Serialized version of operator link
+        self.assertEqual(
+            serialized_dag["dag"]["tasks"][0]["_operator_extra_links"],
+            [
+                {'airflow.gcp.operators.bigquery.BigQueryConsoleIndexableLink': {'index': 0}},
+                {'airflow.gcp.operators.bigquery.BigQueryConsoleIndexableLink': {'index': 1}}
+            ]
+        )
+
+        # Check DeSerialized version of operator link
+        self.assertIsInstance(list(simple_task.operator_extra_links)[0], BigQueryConsoleIndexableLink)
+
+        ti = TaskInstance(task=simple_task, execution_date=DEFAULT_DATE)
+        job_id = ['123', '45']
+        ti.xcom_push(key='job_id', value=job_id)
+
+        self.assertEqual(
+            {'BigQuery Console #1', 'BigQuery Console #2'},
+            simple_task.operator_extra_link_dict.keys()
+        )
+
+        self.assertEqual(
+            'https://console.cloud.google.com/bigquery?j=123',
+            simple_task.get_extra_links(DEFAULT_DATE, 'BigQuery Console #1'),
+        )
+
+        self.assertEqual(
+            'https://console.cloud.google.com/bigquery?j=45',
+            simple_task.get_extra_links(DEFAULT_DATE, 'BigQuery Console #2'),
+        )
+
     @provide_session
     @mock.patch('airflow.gcp.operators.bigquery.BigQueryHook')
     def test_bigquery_operator_extra_link_when_missing_job_id(self, mock_hook, session):
diff --git a/tests/plugins/test_plugin.py b/tests/plugins/test_plugin.py
index 4ec1d7f295d757..a05b6743ab2169 100644
--- a/tests/plugins/test_plugin.py
+++ b/tests/plugins/test_plugin.py
@@ -27,7 +27,9 @@
 # This is the class you derive to create a plugin
 from airflow.plugins_manager import AirflowPlugin
 from airflow.sensors.base_sensor_operator import BaseSensorOperator
-from airflow.utils.tests import AirflowLink, AirflowLink2, GithubLink, GoogleLink
+from airflow.utils.tests import (
+    AirflowLink, AirflowLink2, CustomBaseIndexOpLink, CustomOpLink, GithubLink, GoogleLink,
+)
 
 
 # Will show up under airflow.hooks.test_plugin.PluginHook
@@ -107,7 +109,7 @@ class AirflowTestPlugin(AirflowPlugin):
         GithubLink(),
     ]
     operator_extra_links = [
-        GoogleLink(), AirflowLink2()
+        GoogleLink(), AirflowLink2(), CustomOpLink(), CustomBaseIndexOpLink(1)
     ]
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 89d4fb18fcbd63..288c0ce5eac297 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -31,12 +31,12 @@
 from airflow.contrib import example_dags as contrib_example_dags
 from airflow.gcp import example_dags as gcp_example_dags
 from airflow.hooks.base_hook import BaseHook
-from airflow.models import DAG, Connection, DagBag
+from airflow.models import DAG, Connection, DagBag, TaskInstance
 from airflow.models.baseoperator import BaseOperator
 from airflow.operators.bash_operator import BashOperator
 from airflow.operators.subdag_operator import SubDagOperator
 from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG
-from airflow.utils.tests import CustomBaseOperator, GoogleLink
+from airflow.utils.tests import CustomOperator, CustomOpLink, GoogleLink
 
 serialized_simple_dag_ground_truth = {
     "__version": 1,
@@ -78,10 +78,11 @@
                 "_downstream_task_ids": [],
                 "_inlets": [],
                 "_outlets": [],
+                "_operator_extra_links": [{"airflow.utils.tests.CustomOpLink": {}}],
                 "ui_color": "#fff",
                 "ui_fgcolor": "#000",
                 "template_fields": [],
-                "_task_type": "CustomBaseOperator",
+                "_task_type": "CustomOperator",
                 "_task_module": "airflow.utils.tests",
             },
         ],
@@ -108,7 +109,7 @@ def make_simple_dag():
         start_date=datetime(2019, 8, 1),
     )
     BaseOperator(task_id='simple_task', dag=dag, owner='airflow')
-    CustomBaseOperator(task_id='custom_task', dag=dag)
+    CustomOperator(task_id='custom_task', dag=dag)
     return {'simple_dag': dag}
 
@@ -256,10 +257,6 @@ def test_deserialization(self):
             SubDagOperator.ui_fgcolor
         )
 
-        simple_dag = stringified_dags['simple_dag']
-        custom_task = simple_dag.task_dict['custom_task']
-        self.validate_operator_extra_links(custom_task)
-
     def validate_deserialized_task(self, task, task_type, ui_color, ui_fgcolor):
         """Verify non-airflow operators are casted to BaseOperator."""
         self.assertTrue(isinstance(task, SerializedBaseOperator))
@@ -276,22 +273,6 @@ def validate_deserialized_task(self, task, task_type, ui_color, ui_fgcolor):
         else:
             self.assertIsNone(task.subdag)
 
-    def validate_operator_extra_links(self, task):
-        """
-        This tests also depends on GoogleLink() registered as a plugin
-        in tests/plugins/test_plugin.py
-
-        The function tests that if extra operator links are registered in plugin
-        in ``operator_extra_links`` and the same is also defined in
-        the Operator in ``BaseOperator.operator_extra_links``, it has the correct
-        extra link.
- """ - self.assertEqual( - task.operator_extra_link_dict[GoogleLink.name].get_link( - task, datetime(2019, 8, 1)), - "https://www.google.com" - ) - @parameterized.expand([ (datetime(2019, 8, 1), None, datetime(2019, 8, 1)), (datetime(2019, 8, 1), datetime(2019, 8, 2), datetime(2019, 8, 2)), @@ -381,6 +362,106 @@ def test_roundtrip_relativedelta(self, val, expected): round_tripped = SerializedDAG._deserialize(serialized) self.assertEqual(val, round_tripped) + def test_extra_serialized_field_and_operator_links(self): + """ + Assert extra field exists & OperatorLinks defined in Plugins and inbuilt Operator Links. + + This tests also depends on GoogleLink() registered as a plugin + in tests/plugins/test_plugin.py + + The function tests that if extra operator links are registered in plugin + in ``operator_extra_links`` and the same is also defined in + the Operator in ``BaseOperator.operator_extra_links``, it has the correct + extra link. + """ + test_date = datetime(2019, 8, 1) + dag = DAG(dag_id='simple_dag', start_date=test_date) + CustomOperator(task_id='simple_task', dag=dag, bash_command="true") + + serialized_dag = SerializedDAG.to_dict(dag) + self.assertIn("bash_command", serialized_dag["dag"]["tasks"][0]) + + dag = SerializedDAG.from_dict(serialized_dag) + simple_task = dag.task_dict["simple_task"] + self.assertEqual(getattr(simple_task, "bash_command"), "true") + + ######################################################### + # Verify Operator Links work with Serialized Operator + ######################################################### + # Check Serialized version of operator link only contains the inbuilt Op Link + self.assertEqual( + serialized_dag["dag"]["tasks"][0]["_operator_extra_links"], + [{'airflow.utils.tests.CustomOpLink': {}}] + ) + + # Test all the extra_links are set + self.assertCountEqual(simple_task.extra_links, ['Google Custom', 'airflow', 'github', 'google']) + + ti = TaskInstance(task=simple_task, execution_date=test_date) + ti.xcom_push('search_query', "dummy_value_1") + + # Test Deserialized inbuilt link + custom_inbuilt_link = simple_task.get_extra_links(test_date, CustomOpLink.name) + self.assertEqual('http://google.com/custom_base_link?search=dummy_value_1', custom_inbuilt_link) + + # Test Deserialized link registered via Airflow Plugin + google_link_from_plugin = simple_task.get_extra_links(test_date, GoogleLink.name) + self.assertEqual("https://www.google.com", google_link_from_plugin) + + def test_extra_serialized_field_and_multiple_operator_links(self): + """ + Assert extra field exists & OperatorLinks defined in Plugins and inbuilt Operator Links. + + This tests also depends on GoogleLink() registered as a plugin + in tests/plugins/test_plugin.py + + The function tests that if extra operator links are registered in plugin + in ``operator_extra_links`` and the same is also defined in + the Operator in ``BaseOperator.operator_extra_links``, it has the correct + extra link. 
+ """ + test_date = datetime(2019, 8, 1) + dag = DAG(dag_id='simple_dag', start_date=test_date) + CustomOperator(task_id='simple_task', dag=dag, bash_command=["echo", "true"]) + + serialized_dag = SerializedDAG.to_dict(dag) + self.assertIn("bash_command", serialized_dag["dag"]["tasks"][0]) + + dag = SerializedDAG.from_dict(serialized_dag) + simple_task = dag.task_dict["simple_task"] + self.assertEqual(getattr(simple_task, "bash_command"), ["echo", "true"]) + + ######################################################### + # Verify Operator Links work with Serialized Operator + ######################################################### + # Check Serialized version of operator link only contains the inbuilt Op Link + self.assertEqual( + serialized_dag["dag"]["tasks"][0]["_operator_extra_links"], + [ + {'airflow.utils.tests.CustomBaseIndexOpLink': {'index': 0}}, + {'airflow.utils.tests.CustomBaseIndexOpLink': {'index': 1}}, + ] + ) + + # Test all the extra_links are set + self.assertCountEqual(simple_task.extra_links, [ + 'BigQuery Console #1', 'BigQuery Console #2', 'airflow', 'github', 'google']) + + ti = TaskInstance(task=simple_task, execution_date=test_date) + ti.xcom_push('search_query', ["dummy_value_1", "dummy_value_2"]) + + # Test Deserialized inbuilt link #1 + custom_inbuilt_link = simple_task.get_extra_links(test_date, "BigQuery Console #1") + self.assertEqual('https://console.cloud.google.com/bigquery?j=dummy_value_1', custom_inbuilt_link) + + # Test Deserialized inbuilt link #2 + custom_inbuilt_link = simple_task.get_extra_links(test_date, "BigQuery Console #2") + self.assertEqual('https://console.cloud.google.com/bigquery?j=dummy_value_2', custom_inbuilt_link) + + # Test Deserialized link registered via Airflow Plugin + google_link_from_plugin = simple_task.get_extra_links(test_date, GoogleLink.name) + self.assertEqual("https://www.google.com", google_link_from_plugin) + if __name__ == '__main__': unittest.main()