diff --git a/tests/providers/common/io/xcom/test_backend.py b/tests/providers/common/io/xcom/test_backend.py index da2a11950ee06b..008394f365c117 100644 --- a/tests/providers/common/io/xcom/test_backend.py +++ b/tests/providers/common/io/xcom/test_backend.py @@ -17,97 +17,55 @@ # under the License. from __future__ import annotations -from configparser import DuplicateSectionError -from typing import TYPE_CHECKING - import pytest import airflow.models.xcom -from airflow import settings -from airflow.configuration import conf from airflow.io.path import ObjectStoragePath -from airflow.models.dagrun import DagRun -from airflow.models.taskinstance import TaskInstance from airflow.models.xcom import BaseXCom, resolve_xcom_backend from airflow.operators.empty import EmptyOperator from airflow.providers.common.io.xcom.backend import XComObjectStorageBackend from airflow.utils import timezone -from airflow.utils.session import create_session -from airflow.utils.types import DagRunType from airflow.utils.xcom import XCOM_RETURN_KEY +from tests.test_utils import db from tests.test_utils.config import conf_vars -if TYPE_CHECKING: - from sqlalchemy.orm import Session +pytestmark = pytest.mark.db_test @pytest.fixture(autouse=True) def reset_db(): """Reset XCom entries.""" - with create_session() as session: - session.query(DagRun).delete() - session.query(airflow.models.xcom.XCom).delete() - - -@pytest.fixture -def task_instance_factory(request, session: Session): - def func(*, dag_id, task_id, execution_date): - run_id = DagRun.generate_run_id(DagRunType.SCHEDULED, execution_date) - run = DagRun( - dag_id=dag_id, - run_type=DagRunType.SCHEDULED, - run_id=run_id, - execution_date=execution_date, - ) - session.add(run) - ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id) - ti.dag_id = dag_id - session.add(ti) - session.commit() - - def cleanup_database(): - # This should also clear task instances by cascading. - session.query(DagRun).filter_by(id=run.id).delete() - session.commit() - - request.addfinalizer(cleanup_database) - return ti - - return func + db.clear_db_runs() + db.clear_db_xcom() + yield + db.clear_db_runs() + db.clear_db_xcom() @pytest.fixture -def task_instance(task_instance_factory): - return task_instance_factory( - dag_id="dag", - task_id="task_1", +def task_instance(create_task_instance_of_operator): + return create_task_instance_of_operator( + EmptyOperator, + dag_id="test-dag-id", + task_id="test-task-id", execution_date=timezone.datetime(2021, 12, 3, 4, 56), ) class TestXComObjectStorageBackend: - path = "file:/tmp/xcom" - - def setup_method(self): - try: - conf.add_section("common.io") - except DuplicateSectionError: - pass - conf.set("core", "xcom_backend", "airflow.providers.common.io.xcom.backend.XComObjectStorageBackend") - conf.set("common.io", "xcom_objectstorage_path", self.path) - conf.set("common.io", "xcom_objectstorage_threshold", "50") - settings.configure_vars() - - def teardown_method(self): - conf.remove_option("core", "xcom_backend") - conf.remove_option("common.io", "xcom_objectstorage_path") - conf.remove_option("common.io", "xcom_objectstorage_threshold") - settings.configure_vars() - p = ObjectStoragePath(self.path) - if p.exists(): - p.rmdir(recursive=True) + @pytest.fixture(autouse=True) + def setup_test_cases(self, tmp_path): + xcom_path = tmp_path / "xcom" + xcom_path.mkdir() + self.path = f"file://{xcom_path.as_posix()}" + configuration = { + ("core", "xcom_backend"): "airflow.providers.common.io.xcom.backend.XComObjectStorageBackend", + ("common.io", "xcom_objectstorage_path"): self.path, + ("common.io", "xcom_objectstorage_threshold"): "50", + } + with conf_vars(configuration): + yield - @pytest.mark.db_test def test_value_db(self, task_instance, session): XCom = resolve_xcom_backend() airflow.models.xcom.XCom = XCom @@ -137,7 +95,6 @@ def test_value_db(self, task_instance, session): ) assert qry.first().value == {"key": "value"} - @pytest.mark.db_test def test_value_storage(self, task_instance, session): XCom = resolve_xcom_backend() airflow.models.xcom.XCom = XCom @@ -183,7 +140,6 @@ def test_value_storage(self, task_instance, session): ) assert str(p) == qry.first().value - @pytest.mark.db_test def test_clear(self, task_instance, session): XCom = resolve_xcom_backend() airflow.models.xcom.XCom = XCom @@ -222,7 +178,6 @@ def test_clear(self, task_instance, session): assert p.exists() is False - @pytest.mark.db_test @conf_vars({("common.io", "xcom_objectstorage_compression"): "gzip"}) def test_compression(self, task_instance, session): XCom = resolve_xcom_backend()