From 03036d5e97faf885d1b4da92af2b6695a53c9196 Mon Sep 17 00:00:00 2001 From: Adrien Berchet Date: Fri, 3 Jun 2022 10:38:56 +0200 Subject: [PATCH] Issue a warning when a parameter is not consumed by a task --- luigi/parameter.py | 4 ++++ luigi/task.py | 19 ++++++++++++++++++ test/task_test.py | 49 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 72 insertions(+) diff --git a/luigi/parameter.py b/luigi/parameter.py index 9a9de6b441..3f5e73aed7 100644 --- a/luigi/parameter.py +++ b/luigi/parameter.py @@ -95,6 +95,10 @@ class OptionalParameterTypeWarning(UserWarning): pass +class UnconsumedParameterWarning(UserWarning): + """Warning class for parameters that are not consumed by the task.""" + + class Parameter: """ Parameter whose value is a ``str``, and a base class for other parameter types. diff --git a/luigi/task.py b/luigi/task.py index b24cac16e5..81956edb87 100644 --- a/luigi/task.py +++ b/luigi/task.py @@ -32,9 +32,11 @@ import luigi +from luigi import configuration from luigi import parameter from luigi.task_register import Register from luigi.parameter import ParameterVisibility +from luigi.parameter import UnconsumedParameterWarning Parameter = parameter.Parameter logger = logging.getLogger('luigi-interface') @@ -429,6 +431,23 @@ def list_to_tuple(x): return tuple(x) else: return x + + # Check for unconsumed parameters + conf = configuration.get_config() + if not hasattr(cls, "_unconsumed_params"): + cls._unconsumed_params = set() + if task_family in conf.sections(): + for key, value in conf[task_family].items(): + composite_key = f"{task_family}_{key}" + if key not in result and composite_key not in cls._unconsumed_params: + warnings.warn( + "The configuration contains the parameter " + f"'{key}' with value '{value}' that is not consumed by the task " + f"'{task_family}'.", + UnconsumedParameterWarning, + ) + cls._unconsumed_params.add(composite_key) + # Sort it by the correct order and make a list return [(param_name, list_to_tuple(result[param_name])) for param_name, param_obj in params] diff --git a/test/task_test.py b/test/task_test.py index 47822de546..c3f569423b 100644 --- a/test/task_test.py +++ b/test/task_test.py @@ -169,6 +169,55 @@ class ATaskWithBadParam(luigi.Task): with self.assertRaisesRegex(ValueError, r"ATaskWithBadParam\[args=\(\), kwargs={}\]: Error when parsing the default value of 'bad_param'"): ATaskWithBadParam() + @with_config( + { + "TaskA": { + "a": "a", + "b": "b", + "c": "c", + }, + "TaskB": { + "a": "a", + "b": "b", + "c": "c", + }, + } + ) + def test_unconsumed_params(self): + class TaskA(luigi.Task): + a = luigi.Parameter(default="a") + + class TaskB(luigi.Task): + a = luigi.Parameter(default="a") + + with warnings.catch_warnings(record=True) as w: + warnings.filterwarnings( + action="ignore", + category=Warning, + ) + warnings.simplefilter( + action="always", + category=luigi.parameter.UnconsumedParameterWarning, + ) + + TaskA() + TaskB() + + assert len(w) == 4 + expected = [ + ("b", "TaskA"), + ("c", "TaskA"), + ("b", "TaskB"), + ("c", "TaskB"), + ] + for i, (expected_value, task_name) in zip(w, expected): + assert issubclass(i.category, luigi.parameter.UnconsumedParameterWarning) + assert str(i.message) == ( + "The configuration contains the parameter " + f"'{expected_value}' with value '{expected_value}' that is not consumed by " + f"the task '{task_name}'." + ) + class ExternalizeTaskTest(LuigiTestCase):