Skip to content

Commit

Permalink
[AIRFLOW-5945] Make inbuilt OperatorLinks work when using Serializati…
Browse files Browse the repository at this point in the history
…on (#6715)
  • Loading branch information
kaxil authored Dec 6, 2019
1 parent 993e105 commit 803a87f
Show file tree
Hide file tree
Showing 13 changed files with 489 additions and 72 deletions.
15 changes: 13 additions & 2 deletions airflow/contrib/operators/qubole_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# under the License.
"""Qubole operator"""
import re
from typing import Iterable
from typing import FrozenSet, Iterable, Optional

from airflow.contrib.hooks.qubole_hook import (
COMMAND_ARGS, HYPHEN_ARGS, POSITIONAL_ARGS, QuboleHook, flatten_list,
Expand All @@ -42,7 +42,8 @@ def get_link(self, operator, dttm):
:return: url link
"""
ti = TaskInstance(task=operator, execution_date=dttm)
conn = BaseHook.get_connection(operator.kwargs['qubole_conn_id'])
conn = BaseHook.get_connection(
getattr(operator, "qubole_conn_id", None) or operator.kwargs['qubole_conn_id'])
if conn and conn.host:
host = re.sub(r'api$', 'v2/analyze?command_id=', conn.host)
else:
Expand Down Expand Up @@ -181,6 +182,9 @@ class QuboleOperator(BaseOperator):
QDSLink(),
)

# The _serialized_fields are lazily loaded when get_serialized_fields() method is called
__serialized_fields: Optional[FrozenSet[str]] = None

@apply_defaults
def __init__(self, qubole_conn_id="qubole_default", *args, **kwargs):
self.args = args
Expand Down Expand Up @@ -240,3 +244,10 @@ def __setattr__(self, name, value):
self.kwargs[name] = value
else:
object.__setattr__(self, name, value)

@classmethod
def get_serialized_fields(cls):
"""Serialized QuboleOperator contain exactly these fields."""
if not cls.__serialized_fields:
cls.__serialized_fields = frozenset(super().get_serialized_fields() | {"qubole_conn_id"})
return cls.__serialized_fields
18 changes: 14 additions & 4 deletions airflow/gcp/operators/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@

import json
import warnings
from typing import Any, Dict, Iterable, List, Optional, SupportsAbs, Union
from typing import Any, Dict, FrozenSet, Iterable, List, Optional, SupportsAbs, Union

import attr
from googleapiclient.errors import HttpError

from airflow.exceptions import AirflowException
Expand Down Expand Up @@ -337,14 +338,13 @@ def get_link(self, operator, dttm):
return BIGQUERY_JOB_DETAILS_LINK_FMT.format(job_id=job_id) if job_id else ''


@attr.s(auto_attribs=True)
class BigQueryConsoleIndexableLink(BaseOperatorLink):
"""
Helper class for constructing BigQuery link.
"""

def __init__(self, index) -> None:
super().__init__()
self.index = index
index: int = attr.ib()

@property
def name(self) -> str:
Expand Down Expand Up @@ -459,6 +459,9 @@ class BigQueryOperator(BaseOperator):
template_ext = ('.sql', )
ui_color = '#e4f0e8'

# The _serialized_fields are lazily loaded when get_serialized_fields() method is called
__serialized_fields: Optional[FrozenSet[str]] = None

@property
def operator_extra_links(self):
"""
Expand Down Expand Up @@ -594,6 +597,13 @@ def on_kill(self):
self.log.info('Cancelling running query')
self.bq_cursor.cancel_query()

@classmethod
def get_serialized_fields(cls):
"""Serialized BigQueryOperator contain exactly these fields."""
if not cls.__serialized_fields:
cls.__serialized_fields = frozenset(super().get_serialized_fields() | {"sql"})
return cls.__serialized_fields


class BigQueryCreateEmptyTableOperator(BaseOperator):
"""
Expand Down
30 changes: 16 additions & 14 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@
import warnings
from abc import ABCMeta, abstractmethod
from datetime import datetime, timedelta
from typing import Any, Callable, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple, Type, Union
from typing import Any, Callable, ClassVar, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple, Type, Union

import attr
import jinja2
from cached_property import cached_property
from dateutil.relativedelta import relativedelta
Expand Down Expand Up @@ -256,7 +257,7 @@ class derived from this one results in the creation of a task object,
operator_extra_links: Iterable['BaseOperatorLink'] = ()

# The _serialized_fields are lazily loaded when get_serialized_fields() method is called
_serialized_fields: Optional[FrozenSet[str]] = None
__serialized_fields: Optional[FrozenSet[str]] = None

_comps = {
'task_id',
Expand Down Expand Up @@ -785,15 +786,15 @@ def prepare_template(self) -> None:
def resolve_template_files(self) -> None:
"""Getting the content of files for template_field / template_ext"""
if self.template_ext: # pylint: disable=too-many-nested-blocks
for attr in self.template_fields:
content = getattr(self, attr, None)
for field in self.template_fields:
content = getattr(self, field, None)
if content is None:
continue
elif isinstance(content, str) and \
any([content.endswith(ext) for ext in self.template_ext]):
env = self.get_template_env()
try:
setattr(self, attr, env.loader.get_source(env, content)[0])
setattr(self, field, env.loader.get_source(env, content)[0])
except Exception as e: # pylint: disable=broad-except
self.log.exception(e)
elif isinstance(content, list):
Expand Down Expand Up @@ -939,10 +940,10 @@ def run(
def dry_run(self) -> None:
"""Performs dry run for the operator - just render template fields."""
self.log.info('Dry run')
for attr in self.template_fields:
content = getattr(self, attr)
for field in self.template_fields:
content = getattr(self, field)
if content and isinstance(content, str):
self.log.info('Rendering template for %s', attr)
self.log.info('Rendering template for %s', field)
self.log.info(content)

def get_direct_relative_ids(self, upstream: bool = False) -> Set[str]:
Expand Down Expand Up @@ -1101,21 +1102,22 @@ def get_extra_links(self, dttm: datetime, link_name: str) -> Optional[Dict[str,
@classmethod
def get_serialized_fields(cls):
"""Stringified DAGs and operators contain exactly these fields."""
if not cls._serialized_fields:
cls._serialized_fields = frozenset(
if not cls.__serialized_fields:
cls.__serialized_fields = frozenset(
vars(BaseOperator(task_id='test')).keys() - {
'inlets', 'outlets', '_upstream_task_ids', 'default_args', 'dag', '_dag'
} | {'_task_type', 'subdag', 'ui_color', 'ui_fgcolor', 'template_fields'}
)
return cls._serialized_fields
} | {'_task_type', 'subdag', 'ui_color', 'ui_fgcolor', 'template_fields'})

return cls.__serialized_fields


@attr.s(auto_attribs=True)
class BaseOperatorLink(metaclass=ABCMeta):
"""
Abstract base class that defines how we get an operator link.
"""

operators: List[Type[BaseOperator]] = []
operators: ClassVar[List[Type[BaseOperator]]] = []
"""
This property will be used by Airflow Plugins to find the Operators to which you want
to assign this Operator Link
Expand Down
8 changes: 4 additions & 4 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ class DAG(BaseDag, LoggingMixin):
'last_loaded',
}

_serialized_fields: Optional[FrozenSet[str]] = None
__serialized_fields: Optional[FrozenSet[str]] = None

def __init__(
self,
Expand Down Expand Up @@ -1508,14 +1508,14 @@ def _test_cycle_helper(self, visit_map, task_id):
@classmethod
def get_serialized_fields(cls):
"""Stringified DAGs and operators contain exactly these fields."""
if not cls._serialized_fields:
cls._serialized_fields = frozenset(vars(DAG(dag_id='test')).keys()) - {
if not cls.__serialized_fields:
cls.__serialized_fields = frozenset(vars(DAG(dag_id='test')).keys()) - {
'parent_dag', '_old_context_manager_dags', 'safe_dag_id', 'last_loaded',
'_full_filepath', 'user_defined_filters', 'user_defined_macros',
'_schedule_interval', 'partial', '_old_context_manager_dags',
'_pickle_id', '_log', 'is_subdag', 'task_dict'
}
return cls._serialized_fields
return cls.__serialized_fields


class DagModel(Base):
Expand Down
47 changes: 42 additions & 5 deletions airflow/plugins_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import os
import re
import sys
from typing import Any, Callable, List, Optional
from typing import Any, Callable, Dict, List, Optional, Set, Type

import pkg_resources

Expand Down Expand Up @@ -112,6 +112,33 @@ def load_entrypoint_plugins(entry_points, airflow_plugins):
return airflow_plugins


def register_inbuilt_operator_links() -> None:
"""
Register all the Operators Links that are already defined for the operators
in the "airflow" project. Example: QDSLink (Operator Link for Qubole Operator)
This is required to populate the "whitelist" of allowed classes when deserializing operator links
"""
inbuilt_operator_links: Set[Type] = set()

try:
from airflow.gcp.operators.bigquery import BigQueryConsoleLink, BigQueryConsoleIndexableLink # noqa E501 # pylint: disable=R0401,line-too-long
inbuilt_operator_links.update([BigQueryConsoleLink, BigQueryConsoleIndexableLink])
except ImportError:
pass

try:
from airflow.contrib.operators.qubole_operator import QDSLink # pylint: disable=R0401
inbuilt_operator_links.update([QDSLink])
except ImportError:
pass

registered_operator_link_classes.update({
"{}.{}".format(link.__module__, link.__name__): link
for link in inbuilt_operator_links
})


def is_valid_plugin(plugin_obj, existing_plugins):
"""
Check whether a potential object is a subclass of
Expand Down Expand Up @@ -200,6 +227,12 @@ def make_module(name: str, objects: List[Any]):
stat_name_handler: Any = None
global_operator_extra_links: List[Any] = []
operator_extra_links: List[Any] = []
registered_operator_link_classes: Dict[str, Type] = {}
"""Mapping of class names to class of OperatorLinks registered by plugins.
Used by the DAG serialization code to only allow specific classes to be created
during deserialization
"""

stat_name_handlers = []
for p in plugins:
Expand Down Expand Up @@ -227,10 +260,13 @@ def make_module(name: str, objects: List[Any]):
if p.stat_name_handler:
stat_name_handlers.append(p.stat_name_handler)
global_operator_extra_links.extend(p.global_operator_extra_links)
# Only register Operator links if its ``operators`` property is not an empty list
# So that we can only attach this links to a specific Operator
operator_extra_links.extend([
ope for ope in p.operator_extra_links if ope.operators])
operator_extra_links.extend([ope for ope in p.operator_extra_links])

registered_operator_link_classes.update({
"{}.{}".format(link.__class__.__module__,
link.__class__.__name__): link.__class__
for link in p.operator_extra_links
})

if len(stat_name_handlers) > 1:
raise AirflowPluginException(
Expand Down Expand Up @@ -287,3 +323,4 @@ def integrate_plugins() -> None:
integrate_hook_plugins()
integrate_executor_plugins()
integrate_macro_plugins()
register_inbuilt_operator_links()
9 changes: 9 additions & 0 deletions airflow/serialization/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,14 @@
"type": "string",
"pattern": "^#[a-fA-F0-9]{3,6}$"
},
"extra_links": {
"type": "array",
"items": {
"type": "object",
"minProperties": 1,
"maxProperties": 1
}
},
"dag": {
"type": "object",
"properties": {
Expand Down Expand Up @@ -112,6 +120,7 @@
"properties": {
"_task_type": { "type": "string" },
"_task_module": { "type": "string" },
"_operator_extra_links": { "$ref": "#/definitions/extra_links" },
"task_id": { "type": "string" },
"owner": { "type": "string" },
"start_date": { "$ref": "#/definitions/datetime" },
Expand Down
Loading

0 comments on commit 803a87f

Please sign in to comment.