Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fixture for CLI tests requiring sample dags #26536

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions airflow/models/dagbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from airflow.utils.retries import MAX_DB_RETRIES, run_with_db_retries
from airflow.utils.session import provide_session
from airflow.utils.timeout import timeout
from airflow.utils.types import NOTSET, ArgNotSet

if TYPE_CHECKING:
import pathlib
Expand Down Expand Up @@ -92,8 +93,8 @@ class DagBag(LoggingMixin):
def __init__(
self,
dag_folder: str | pathlib.Path | None = None,
include_examples: bool = conf.getboolean('core', 'LOAD_EXAMPLES'),
safe_mode: bool = conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE'),
include_examples: bool | ArgNotSet = NOTSET,
safe_mode: bool | ArgNotSet = NOTSET,
read_dags_from_db: bool = False,
store_serialized_dags: bool | None = None,
load_op_links: bool = True,
Expand All @@ -103,6 +104,15 @@ def __init__(

super().__init__()

include_examples = (
include_examples
if isinstance(include_examples, bool)
else conf.getboolean('core', 'LOAD_EXAMPLES')
)
safe_mode = (
safe_mode if isinstance(safe_mode, bool) else conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE')
)

if store_serialized_dags:
warnings.warn(
"The store_serialized_dags parameter has been deprecated. "
Expand Down
7 changes: 7 additions & 0 deletions tests/cli/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from airflow import models
from airflow.cli import cli_parser
from airflow.executors import celery_executor, celery_kubernetes_executor
from tests.test_utils.config import conf_vars

# Create custom executors here because conftest is imported first
custom_executor_module = type(sys)('custom_executor')
Expand All @@ -36,6 +37,12 @@
sys.modules['custom_executor'] = custom_executor_module


@pytest.fixture(autouse=True)
def load_examples():
with conf_vars({('core', 'load_examples'): 'True'}):
yield


@pytest.fixture(scope="session")
def dagbag():
return models.DagBag(include_examples=True)
Expand Down
10 changes: 8 additions & 2 deletions tests/jobs/test_scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,12 @@ def dagbag():
return DagBag(read_dags_from_db=True)


@pytest.fixture
def load_examples():
with conf_vars({('core', 'load_examples'): 'True'}):
yield


@pytest.mark.usefixtures("disable_load_example")
@pytest.mark.need_serialized_dag
class TestSchedulerJob:
Expand Down Expand Up @@ -4014,7 +4020,7 @@ def test_find_zombies_nothing(self):

self.scheduler_job.executor.callback_sink.send.assert_not_called()

def test_find_zombies(self):
def test_find_zombies(self, load_examples):
dagbag = DagBag(TEST_DAG_FOLDER, read_dags_from_db=False)
with create_session() as session:
session.query(LocalTaskJob).delete()
Expand Down Expand Up @@ -4072,7 +4078,7 @@ def test_find_zombies(self):
session.query(TaskInstance).delete()
session.query(LocalTaskJob).delete()

def test_zombie_message(self):
def test_zombie_message(self, load_examples):
"""
Check that the zombie message comes out as expected
"""
Expand Down