From b0c57df5470a57ea9886d81118b8f9f29c7d8c9c Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Wed, 6 Mar 2019 11:01:25 +0000 Subject: [PATCH] [AIRFLOW-2888] Add deprecation path for task_runner config change This change allows users to seamlessly upgrade without a hard-to-debug error when a task is actually run. This allows us to pull the change in to 1.10.3 --- airflow/configuration.py | 34 ++++++++++++++++++++++++++++++- tests/test_configuration.py | 40 ++++++++++++++++++++++++++++++++++++- 2 files changed, 72 insertions(+), 2 deletions(-) diff --git a/airflow/configuration.py b/airflow/configuration.py index 921d088e06a186..3ae1dc61e2ebd8 100644 --- a/airflow/configuration.py +++ b/airflow/configuration.py @@ -152,6 +152,19 @@ class AirflowConfigParser(ConfigParser): 'setting has been used, but please update your config.' ) + # A mapping of old default values that we want to change and warn the user + # about. Mapping of section -> setting -> { old, replace, by_version } + deprecated_values = { + 'core': { + 'task_runner': ('BashTaskRunner', 'StandardTaskRunner', '2.0'), + }, + } + deprecation_value_format_string = ( + 'The {name} setting in [{section}] has the old default value of {old!r}. This ' + 'value has been changed to {new!r} in the running config, but please ' + 'update your config before Apache Airflow {version}.' + ) + def __init__(self, default_config=None, *args, **kwargs): super(AirflowConfigParser, self).__init__(*args, **kwargs) @@ -169,11 +182,30 @@ def _validate(self): "error: cannot use sqlite with the {}".format( self.get('core', 'executor'))) + for section, replacement in self.deprecated_values.items(): + for name, info in replacement.items(): + old, new, version = info + if self.get(section, name, fallback=None) == old: + # Make sure the env var option is removed, otherwise it + # would be read and used instead of the value we set + env_var = self._env_var_name(section, name) + os.environ.pop(env_var, None) + + self.set(section, name, new) + warnings.warn( + self.deprecation_value_format_string.format(**locals()), + FutureWarning, + ) + self.is_validated = True + @staticmethod + def _env_var_name(section, key): + return 'AIRFLOW__{S}__{K}'.format(S=section.upper(), K=key.upper()) + def _get_env_var_option(self, section, key): # must have format AIRFLOW__{SECTION}__{KEY} (note double underscore) - env_var = 'AIRFLOW__{S}__{K}'.format(S=section.upper(), K=key.upper()) + env_var = self._env_var_name(section, key) if env_var in os.environ: return expand_env_var(os.environ[env_var]) diff --git a/tests/test_configuration.py b/tests/test_configuration.py index 622c3ffc31fe18..a28b64a61ee342 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -20,8 +20,9 @@ from __future__ import print_function from __future__ import unicode_literals -import os import contextlib +import os +import warnings from collections import OrderedDict import six @@ -339,3 +340,40 @@ def test_deprecated_options_cmd(self): self.assertEqual(conf.getint('celery', 'result_backend'), 99) if tmp: os.environ['AIRFLOW__CELERY__RESULT_BACKEND'] = tmp + + def test_deprecated_values(self): + def make_config(): + test_conf = AirflowConfigParser(default_config='') + # Guarantee we have a deprecated setting, so we test the deprecation + # lookup even if we remove this explicit fallback + test_conf.deprecated_values = { + 'core': { + 'task_runner': ('BashTaskRunner', 'StandardTaskRunner', '2.0'), + }, + } + test_conf.read_dict({ + 'core': { + 'executor': 'SequentialExecutor', + 'task_runner': 'BashTaskRunner', + 'sql_alchemy_conn': 'sqlite://', + }, + }) + return test_conf + + with self.assertWarns(FutureWarning): + test_conf = make_config() + self.assertEqual(test_conf.get('core', 'task_runner'), 'StandardTaskRunner') + + with self.assertWarns(FutureWarning): + with env_vars(AIRFLOW__CORE__TASK_RUNNER='BashTaskRunner'): + test_conf = make_config() + + self.assertEqual(test_conf.get('core', 'task_runner'), 'StandardTaskRunner') + + with warnings.catch_warnings(record=True) as w: + with env_vars(AIRFLOW__CORE__TASK_RUNNER='NotBashTaskRunner'): + test_conf = make_config() + + self.assertEqual(test_conf.get('core', 'task_runner'), 'NotBashTaskRunner') + + self.assertListEqual([], w)