Skip to content

Commit

Permalink
Add allow list for imports during deserialization (#27887)
Browse files Browse the repository at this point in the history
During deserialization, Airflow can instantiate arbitrary
objects, importing the modules that define them. This can be dangerous,
as it could lead to unwanted side effects. With this change,
administrators can now limit which objects can be deserialized.
The allow list defaults to Airflow's own classes only.
  • Loading branch information
bolkedebruin authored Nov 25, 2022
1 parent 1e73b1c commit 542cfdc
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 3 deletions.
9 changes: 9 additions & 0 deletions airflow/config_templates/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,15 @@
example: ~
default: "False"
see_also: "https://docs.python.org/3/library/pickle.html#comparison-with-json"
- name: allowed_deserialization_classes
description: |
Which classes can be imported during deserialization. This is a multi-line value.
Each item is parsed as a regular expression. Python built-in classes (like dict)
are always allowed.
version_added: 2.5.0
type: string
default: 'airflow\..*'
example: ~
- name: killed_task_cleanup_time
description: |
When a task is killed forcefully, this is the amount of time in seconds that
Expand Down
5 changes: 5 additions & 0 deletions airflow/config_templates/default_airflow.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,11 @@ unit_test_mode = False
# RCE exploits).
enable_xcom_pickling = False

# Which classes can be imported during deserialization. This is a multi-line value.
# Each item is parsed as a regular expression. Python built-in classes (like dict)
# are always allowed.
allowed_deserialization_classes = airflow\..*

# When a task is killed forcefully, this is the amount of time in seconds that
# it has to cleanup after it is sent a SIGTERM, before it is SIGKILLED
killed_task_cleanup_time = 60
Expand Down
2 changes: 2 additions & 0 deletions airflow/config_templates/default_test.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ plugins_folder = {TEST_PLUGINS_FOLDER}
dags_are_paused_at_creation = False
fernet_key = {FERNET_KEY}
killed_task_cleanup_time = 5
allowed_deserialization_classes = airflow\..*
tests\..*

[database]
sql_alchemy_conn = sqlite:///{AIRFLOW_HOME}/unittests.db
Expand Down
24 changes: 21 additions & 3 deletions airflow/utils/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@
import dataclasses
import json
import logging
import re
from datetime import date, datetime
from decimal import Decimal
from typing import Any

import attr
from flask.json.provider import JSONProvider

from airflow.configuration import conf
from airflow.serialization.enums import Encoding
from airflow.utils.module_loading import import_string
from airflow.utils.timezone import convert_to_utc, is_naive
Expand Down Expand Up @@ -181,20 +183,36 @@ class XComDecoder(json.JSONDecoder):
as is.
"""

_pattern: list[re.Pattern] = []

def __init__(self, *args, **kwargs) -> None:
    """Set up the decoder and compile the deserialization allow list.

    Unless the caller passes a truthy ``object_hook`` keyword argument,
    ``self.object_hook`` is installed as the hook. The value of the
    ``[core] allowed_deserialization_classes`` option is split on
    whitespace and every item is compiled as a regular expression.

    All positional and keyword arguments are forwarded to the parent
    class initializer.
    """
    if not kwargs.get("object_hook"):
        kwargs["object_hook"] = self.object_hook

    patterns = conf.get("core", "allowed_deserialization_classes").split()

    # Bind a fresh per-instance list instead of clearing and refilling the
    # class-level list in place. Mutating the shared list would change (and
    # briefly empty) the allow list seen by every other decoder instance,
    # including one that is in the middle of decoding.
    self._pattern = [re.compile(p) for p in patterns]

    super().__init__(*args, **kwargs)

@staticmethod
def object_hook(dct: dict) -> object:
def object_hook(self, dct: dict) -> object:
dct = XComDecoder._convert(dct)

if CLASSNAME in dct and VERSION in dct:
from airflow.serialization.serialized_objects import BaseSerialization

cls = import_string(dct[CLASSNAME])
classname = dct[CLASSNAME]
cls = None

for p in self._pattern:
if p.match(classname):
cls = import_string(classname)
break

if not cls:
raise ImportError(f"{classname} was not found in allow list for import")

if hasattr(cls, "deserialize"):
return getattr(cls, "deserialize")(dct[DATA], dct[VERSION])
Expand Down
16 changes: 16 additions & 0 deletions tests/utils/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

from airflow.datasets import Dataset
from airflow.utils import json as utils_json
from tests.test_utils.config import conf_vars


class Z:
Expand Down Expand Up @@ -215,3 +216,18 @@ def test_orm_custom_deserialize(self):
s = json.dumps(z, cls=utils_json.XComEncoder)
o = json.loads(s, cls=utils_json.XComDecoder, object_hook=utils_json.XComDecoder.orm_object_hook)
assert o == f"{Z.__module__}.{Z.__qualname__}@version={Z.version}(x={x})"

@conf_vars(
    {
        ("core", "allowed_deserialization_classes"): "airflow[.].*",
    }
)
def test_allow_list_for_imports(self):
    """Decoding a class that is not on the allow list raises ImportError."""
    # Z is defined in this test module, so "airflow[.].*" does not match it.
    payload = json.dumps(Z(x=14), cls=utils_json.XComEncoder)
    expected_message = f"{Z.__module__}.{Z.__qualname__} was not found in allow list"

    with pytest.raises(ImportError) as exc_info:
        json.loads(payload, cls=utils_json.XComDecoder)

    assert expected_message in str(exc_info.value)

0 comments on commit 542cfdc

Please sign in to comment.