Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add custom pickling hooks to LazyXComAccess #28191

Merged
merged 1 commit into from
Dec 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions airflow/models/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,21 @@ def __eq__(self, other: Any) -> bool:
return all(x == y for x, y in z)
return NotImplemented

def __getstate__(self) -> Any:
# We don't want to go to the trouble of serializing the entire Query
# object, including its filters, hints, etc. (plus SQLAlchemy does not
# provide a public API to inspect a query's contents). Converting the
# query into a SQL string is the best we can get. Theoratically we can
# do the same for count(), but I think it should be performant enough to
# calculate only that eagerly.
with self._get_bound_query() as query:
statement = query.statement.compile(query.session.get_bind())
return (str(statement), query.count())

def __setstate__(self, state: Any) -> None:
statement, self._len = state
self._query = Query(XCom.value).from_statement(text(statement))

def __len__(self):
if self._len is None:
with self._get_bound_query() as query:
Expand Down
15 changes: 15 additions & 0 deletions tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import operator
import os
import pathlib
import pickle
import signal
import sys
import urllib
Expand Down Expand Up @@ -3591,6 +3592,20 @@ def cmds():
assert out_lines == ["hello FOO", "goodbye FOO", "hello BAR", "goodbye BAR"]


def test_lazy_xcom_access_does_not_pickle_session(dag_maker, session):
with dag_maker(session=session):
EmptyOperator(task_id="t")

run: DagRun = dag_maker.create_dagrun()
run.get_task_instance("t", session=session).xcom_push("xxx", 123, session=session)

original = LazyXComAccess.build_from_xcom_query(session.query(XCom))
processed = pickle.loads(pickle.dumps(original))

assert len(processed) == 1
assert list(processed) == [123]


@mock.patch("airflow.models.taskinstance.XCom.deserialize_value", side_effect=XCom.deserialize_value)
def test_ti_xcom_pull_on_mapped_operator_return_lazy_iterable(mock_deserialize_value, dag_maker, session):
"""Ensure we access XCom lazily when pulling from a mapped operator."""
Expand Down