Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use generic utils for cleanup db and operate with temp objects in TestXcomObjectStoreBackend #37166

Merged
merged 4 commits into from
Mar 30, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 24 additions & 69 deletions tests/providers/common/io/xcom/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,97 +17,55 @@
# under the License.
from __future__ import annotations

from configparser import DuplicateSectionError
from typing import TYPE_CHECKING

import pytest

import airflow.models.xcom
from airflow import settings
from airflow.configuration import conf
from airflow.io.path import ObjectStoragePath
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.models.xcom import BaseXCom, resolve_xcom_backend
from airflow.operators.empty import EmptyOperator
from airflow.providers.common.io.xcom.backend import XComObjectStorageBackend
from airflow.utils import timezone
from airflow.utils.session import create_session
from airflow.utils.types import DagRunType
from airflow.utils.xcom import XCOM_RETURN_KEY
from tests.test_utils import db
from tests.test_utils.config import conf_vars

if TYPE_CHECKING:
from sqlalchemy.orm import Session
pytestmark = pytest.mark.db_test


@pytest.fixture(autouse=True)
def reset_db():
"""Reset XCom entries."""
with create_session() as session:
session.query(DagRun).delete()
session.query(airflow.models.xcom.XCom).delete()


@pytest.fixture
def task_instance_factory(request, session: Session):
def func(*, dag_id, task_id, execution_date):
run_id = DagRun.generate_run_id(DagRunType.SCHEDULED, execution_date)
run = DagRun(
dag_id=dag_id,
run_type=DagRunType.SCHEDULED,
run_id=run_id,
execution_date=execution_date,
)
session.add(run)
ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id)
ti.dag_id = dag_id
session.add(ti)
session.commit()

def cleanup_database():
# This should also clear task instances by cascading.
session.query(DagRun).filter_by(id=run.id).delete()
session.commit()

request.addfinalizer(cleanup_database)
return ti

return func
db.clear_db_runs()
db.clear_db_xcom()
yield
db.clear_db_runs()
db.clear_db_xcom()


@pytest.fixture
def task_instance(task_instance_factory):
return task_instance_factory(
dag_id="dag",
task_id="task_1",
def task_instance(create_task_instance_of_operator):
return create_task_instance_of_operator(
EmptyOperator,
dag_id="test-dag-id",
task_id="test-task-id",
execution_date=timezone.datetime(2021, 12, 3, 4, 56),
)


class TestXComObjectStorageBackend:
path = "file:/tmp/xcom"

def setup_method(self):
try:
conf.add_section("common.io")
except DuplicateSectionError:
pass
conf.set("core", "xcom_backend", "airflow.providers.common.io.xcom.backend.XComObjectStorageBackend")
conf.set("common.io", "xcom_objectstorage_path", self.path)
conf.set("common.io", "xcom_objectstorage_threshold", "50")
settings.configure_vars()

def teardown_method(self):
conf.remove_option("core", "xcom_backend")
conf.remove_option("common.io", "xcom_objectstorage_path")
conf.remove_option("common.io", "xcom_objectstorage_threshold")
settings.configure_vars()
p = ObjectStoragePath(self.path)
if p.exists():
p.rmdir(recursive=True)
@pytest.fixture(autouse=True)
def setup_test_cases(self, tmp_path):
xcom_path = tmp_path / "xcom"
xcom_path.mkdir()
self.path = f"file://{xcom_path.as_posix()}"
configuration = {
("core", "xcom_backend"): "airflow.providers.common.io.xcom.backend.XComObjectStorageBackend",
("common.io", "xcom_objectstorage_path"): self.path,
("common.io", "xcom_objectstorage_threshold"): "50",
}
with conf_vars(configuration):
yield

@pytest.mark.db_test
def test_value_db(self, task_instance, session):
XCom = resolve_xcom_backend()
airflow.models.xcom.XCom = XCom
Expand Down Expand Up @@ -137,7 +95,6 @@ def test_value_db(self, task_instance, session):
)
assert qry.first().value == {"key": "value"}

@pytest.mark.db_test
def test_value_storage(self, task_instance, session):
XCom = resolve_xcom_backend()
airflow.models.xcom.XCom = XCom
Expand Down Expand Up @@ -183,7 +140,6 @@ def test_value_storage(self, task_instance, session):
)
assert str(p) == qry.first().value

@pytest.mark.db_test
def test_clear(self, task_instance, session):
XCom = resolve_xcom_backend()
airflow.models.xcom.XCom = XCom
Expand Down Expand Up @@ -222,7 +178,6 @@ def test_clear(self, task_instance, session):

assert p.exists() is False

@pytest.mark.db_test
@conf_vars({("common.io", "xcom_objectstorage_compression"): "gzip"})
def test_compression(self, task_instance, session):
XCom = resolve_xcom_backend()
Expand Down