Skip to content

Commit

Permalink
Use generic utils for cleanup db and operate with temp objects in Tes…
Browse files Browse the repository at this point in the history
…tXcomObjectStoreBackend (apache#37166)

* Use generic utils for cleanup db and operate with temp objects in TestXcomObjectStoreBackend

* Return back @pytest.mark.db_test markers

* Use global marker on module

* DRY: Use create_task_instance_of_operator fixture
  • Loading branch information
Taragolis authored and Mathia Haure-Touze committed Apr 4, 2024
1 parent fa3f937 commit 5bafa2e
Showing 1 changed file with 24 additions and 69 deletions.
93 changes: 24 additions & 69 deletions tests/providers/common/io/xcom/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,97 +17,55 @@
# under the License.
from __future__ import annotations

from configparser import DuplicateSectionError
from typing import TYPE_CHECKING

import pytest

import airflow.models.xcom
from airflow import settings
from airflow.configuration import conf
from airflow.io.path import ObjectStoragePath
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.models.xcom import BaseXCom, resolve_xcom_backend
from airflow.operators.empty import EmptyOperator
from airflow.providers.common.io.xcom.backend import XComObjectStorageBackend
from airflow.utils import timezone
from airflow.utils.session import create_session
from airflow.utils.types import DagRunType
from airflow.utils.xcom import XCOM_RETURN_KEY
from tests.test_utils import db
from tests.test_utils.config import conf_vars

if TYPE_CHECKING:
from sqlalchemy.orm import Session
pytestmark = pytest.mark.db_test


@pytest.fixture(autouse=True)
def reset_db():
"""Reset XCom entries."""
with create_session() as session:
session.query(DagRun).delete()
session.query(airflow.models.xcom.XCom).delete()


@pytest.fixture
def task_instance_factory(request, session: Session):
def func(*, dag_id, task_id, execution_date):
run_id = DagRun.generate_run_id(DagRunType.SCHEDULED, execution_date)
run = DagRun(
dag_id=dag_id,
run_type=DagRunType.SCHEDULED,
run_id=run_id,
execution_date=execution_date,
)
session.add(run)
ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id)
ti.dag_id = dag_id
session.add(ti)
session.commit()

def cleanup_database():
# This should also clear task instances by cascading.
session.query(DagRun).filter_by(id=run.id).delete()
session.commit()

request.addfinalizer(cleanup_database)
return ti

return func
db.clear_db_runs()
db.clear_db_xcom()
yield
db.clear_db_runs()
db.clear_db_xcom()


@pytest.fixture
def task_instance(task_instance_factory):
return task_instance_factory(
dag_id="dag",
task_id="task_1",
def task_instance(create_task_instance_of_operator):
return create_task_instance_of_operator(
EmptyOperator,
dag_id="test-dag-id",
task_id="test-task-id",
execution_date=timezone.datetime(2021, 12, 3, 4, 56),
)


class TestXComObjectStorageBackend:
path = "file:/tmp/xcom"

def setup_method(self):
try:
conf.add_section("common.io")
except DuplicateSectionError:
pass
conf.set("core", "xcom_backend", "airflow.providers.common.io.xcom.backend.XComObjectStorageBackend")
conf.set("common.io", "xcom_objectstorage_path", self.path)
conf.set("common.io", "xcom_objectstorage_threshold", "50")
settings.configure_vars()

def teardown_method(self):
conf.remove_option("core", "xcom_backend")
conf.remove_option("common.io", "xcom_objectstorage_path")
conf.remove_option("common.io", "xcom_objectstorage_threshold")
settings.configure_vars()
p = ObjectStoragePath(self.path)
if p.exists():
p.rmdir(recursive=True)
@pytest.fixture(autouse=True)
def setup_test_cases(self, tmp_path):
xcom_path = tmp_path / "xcom"
xcom_path.mkdir()
self.path = f"file://{xcom_path.as_posix()}"
configuration = {
("core", "xcom_backend"): "airflow.providers.common.io.xcom.backend.XComObjectStorageBackend",
("common.io", "xcom_objectstorage_path"): self.path,
("common.io", "xcom_objectstorage_threshold"): "50",
}
with conf_vars(configuration):
yield

@pytest.mark.db_test
def test_value_db(self, task_instance, session):
XCom = resolve_xcom_backend()
airflow.models.xcom.XCom = XCom
Expand Down Expand Up @@ -137,7 +95,6 @@ def test_value_db(self, task_instance, session):
)
assert qry.first().value == {"key": "value"}

@pytest.mark.db_test
def test_value_storage(self, task_instance, session):
XCom = resolve_xcom_backend()
airflow.models.xcom.XCom = XCom
Expand Down Expand Up @@ -183,7 +140,6 @@ def test_value_storage(self, task_instance, session):
)
assert str(p) == qry.first().value

@pytest.mark.db_test
def test_clear(self, task_instance, session):
XCom = resolve_xcom_backend()
airflow.models.xcom.XCom = XCom
Expand Down Expand Up @@ -222,7 +178,6 @@ def test_clear(self, task_instance, session):

assert p.exists() is False

@pytest.mark.db_test
@conf_vars({("common.io", "xcom_objectstorage_compression"): "gzip"})
def test_compression(self, task_instance, session):
XCom = resolve_xcom_backend()
Expand Down

0 comments on commit 5bafa2e

Please sign in to comment.