diff --git a/airflow/plugins_manager.py b/airflow/plugins_manager.py index cf99e6d1906f4b..1926885f3735f8 100644 --- a/airflow/plugins_manager.py +++ b/airflow/plugins_manager.py @@ -184,13 +184,23 @@ def is_valid_plugin(plugin_obj): return False +def register_plugin(plugin_instance): + """ + Start plugin load and register it after success initialization + + :param plugin_instance: subclass of AirflowPlugin + """ + global plugins # pylint: disable=global-statement + plugin_instance.on_load() + plugins.append(plugin_instance) + + def load_entrypoint_plugins(): """ Load and register plugins AirflowPlugin subclasses from the entrypoints. The entry_point group should be 'airflow.plugins'. """ global import_errors # pylint: disable=global-statement - global plugins # pylint: disable=global-statement log.debug("Loading plugins from entrypoints") @@ -202,10 +212,8 @@ def load_entrypoint_plugins(): continue plugin_instance = plugin_class() - if callable(getattr(plugin_instance, 'on_load', None)): - plugin_instance.on_load() - plugin_instance.source = EntryPointSource(entry_point, dist) - plugins.append(plugin_instance) + plugin_instance.source = EntryPointSource(entry_point, dist) + register_plugin(plugin_instance) except Exception as e: # pylint: disable=broad-except log.exception("Failed to import plugin %s", entry_point.name) import_errors[entry_point.module] = str(e) @@ -214,11 +222,9 @@ def load_entrypoint_plugins(): def load_plugins_from_plugin_directory(): """Load and register Airflow Plugins from plugins directory""" global import_errors # pylint: disable=global-statement - global plugins # pylint: disable=global-statement log.debug("Loading plugins from directory: %s", settings.PLUGINS_FOLDER) for file_path in find_path_from_directory(settings.PLUGINS_FOLDER, ".airflowignore"): - if not os.path.isfile(file_path): continue mod_name, file_ext = os.path.splitext(os.path.split(file_path)[-1]) @@ -236,8 +242,7 @@ def load_plugins_from_plugin_directory(): for mod_attr_value in (m for m in mod.__dict__.values() if is_valid_plugin(m)): plugin_instance = mod_attr_value() plugin_instance.source = PluginsDirectorySource(file_path) - plugins.append(plugin_instance) - + register_plugin(plugin_instance) except Exception as e: # pylint: disable=broad-except log.exception(e) log.error('Failed to import plugin %s', file_path) diff --git a/tests/plugins/test_plugin.py b/tests/plugins/test_plugin.py index e233ee95c4db84..63c9a91b814315 100644 --- a/tests/plugins/test_plugin.py +++ b/tests/plugins/test_plugin.py @@ -127,3 +127,10 @@ class MockPluginB(AirflowPlugin): class MockPluginC(AirflowPlugin): name = 'plugin-c' + + +class AirflowTestOnLoadPlugin(AirflowPlugin): + name = 'preload' + + def on_load(self, *args, **kwargs): + self.name = 'postload' diff --git a/tests/plugins/test_plugins_manager.py b/tests/plugins/test_plugins_manager.py index f460b29b583802..778cf080e80f40 100644 --- a/tests/plugins/test_plugins_manager.py +++ b/tests/plugins/test_plugins_manager.py @@ -17,7 +17,9 @@ # under the License. import importlib import logging +import os import sys +import tempfile from unittest import mock import pytest @@ -25,11 +27,22 @@ from airflow.hooks.base import BaseHook from airflow.plugins_manager import AirflowPlugin from airflow.www import app as application +from tests.test_utils.config import conf_vars from tests.test_utils.mock_plugins import mock_plugin_manager py39 = sys.version_info >= (3, 9) importlib_metadata = 'importlib.metadata' if py39 else 'importlib_metadata' +ON_LOAD_EXCEPTION_PLUGIN = """ +from airflow.plugins_manager import AirflowPlugin + +class AirflowTestOnLoadExceptionPlugin(AirflowPlugin): + name = 'preload' + + def on_load(self, *args, **kwargs): + raise Exception("oops") +""" + class TestPluginsRBAC: @pytest.fixture(autouse=True) @@ -146,6 +159,40 @@ class TestPropertyHook(BaseHook): assert caplog.records[-1].levelname == 'DEBUG' assert caplog.records[-1].msg == 'Loading %d plugin(s) took %.2f seconds' + def test_loads_filesystem_plugins(self, caplog): + from airflow import plugins_manager + + with mock.patch('airflow.plugins_manager.plugins', []): + plugins_manager.load_plugins_from_plugin_directory() + + assert 5 == len(plugins_manager.plugins) + for plugin in plugins_manager.plugins: + if 'AirflowTestOnLoadPlugin' not in str(plugin): + continue + assert 'postload' == plugin.name + break + else: + pytest.fail("Wasn't able to find a registered `AirflowTestOnLoadPlugin`") + + assert caplog.record_tuples == [] + + def test_loads_filesystem_plugins_exception(self, caplog): + from airflow import plugins_manager + + with mock.patch('airflow.plugins_manager.plugins', []): + with tempfile.TemporaryDirectory() as tmpdir: + with open(os.path.join(tmpdir, 'testplugin.py'), "w") as f: + f.write(ON_LOAD_EXCEPTION_PLUGIN) + + with conf_vars({('core', 'plugins_folder'): tmpdir}): + plugins_manager.load_plugins_from_plugin_directory() + + assert plugins_manager.plugins == [] + + received_logs = caplog.text + assert 'Failed to import plugin' in received_logs + assert 'testplugin.py' in received_logs + def test_should_warning_about_incompatible_plugins(self, caplog): class AirflowAdminViewsPlugin(AirflowPlugin): name = "test_admin_views_plugin"