From e60b153e84717bdb948a1a12aaed8f6baa8dd6ef Mon Sep 17 00:00:00 2001 From: Adrien Berchet Date: Tue, 17 Jan 2023 09:48:58 +0100 Subject: [PATCH 1/3] Fix schema validation and add custom validators --- luigi/parameter.py | 67 ++++++++++++++++++++++++++++++++++--- test/dict_parameter_test.py | 16 +++++++++ test/list_parameter_test.py | 24 ++++++++++--- 3 files changed, 99 insertions(+), 8 deletions(-) diff --git a/luigi/parameter.py b/luigi/parameter.py index 0ec0dcee0c..fe2e5ac2dd 100644 --- a/luigi/parameter.py +++ b/luigi/parameter.py @@ -1092,6 +1092,26 @@ def run(self): $ luigi --module my_tasks MyTask --tags '{"role": "UNKNOWN_VALUE", "env": "staging"}' + Finally, the provided schema can be a custom validator: + + .. code-block:: python + + custom_validator = jsonschema.Draft4Validator( + schema={ + "type": "object", + "patternProperties": { + ".*": {"type": "string", "enum": ["web", "staging"]}, + } + } + ) + + class MyTask(luigi.Task): + tags = luigi.DictParameter(schema=custom_validator) + + def run(self): + logging.info("Find server with role: %s", self.tags['role']) + server = aws.ec2.find_my_resource(self.tags) + """ def __init__( @@ -1105,7 +1125,9 @@ def __init__( "The 'jsonschema' package is not installed so the parameter can not be validated " "even though a schema is given." ) - self.schema = schema + self.schema = None + else: + self.schema = schema super().__init__( *args, **kwargs, @@ -1117,7 +1139,12 @@ def normalize(self, value): """ frozen_value = recursively_freeze(value) if self.schema is not None: - jsonschema.validate(instance=recursively_unfreeze(frozen_value), schema=self.schema) + unfrozen_value = recursively_unfreeze(frozen_value) + try: + self.schema.validate(unfrozen_value) # Validators may update the instance inplace + frozen_value = super().normalize(unfrozen_value) + except AttributeError: + jsonschema.validate(instance=unfrozen_value, schema=self.schema) return frozen_value def parse(self, source): @@ -1212,6 +1239,31 @@ def run(self): $ luigi --module my_tasks MyTask --numbers '[]' # must have at least 1 element $ luigi --module my_tasks MyTask --numbers '[-999, 999]' # elements must be in [0, 10] + Finally, the provided schema can be a custom validator: + + .. code-block:: python + + custom_validator = jsonschema.Draft4Validator( + schema={ + "type": "array", + "items": { + "type": "number", + "minimum": 0, + "maximum": 10 + }, + "minItems": 1 + } + ) + + class MyTask(luigi.Task): + grades = luigi.ListParameter(schema=custom_validator) + + def run(self): + sum = 0 + for element in self.grades: + sum += element + avg = sum / len(self.grades) + """ def __init__( @@ -1225,7 +1277,9 @@ def __init__( "The 'jsonschema' package is not installed so the parameter can not be validated " "even though a schema is given." ) - self.schema = schema + self.schema = None + else: + self.schema = schema super().__init__( *args, **kwargs, @@ -1240,7 +1294,12 @@ def normalize(self, x): """ frozen_value = recursively_freeze(x) if self.schema is not None: - jsonschema.validate(instance=recursively_unfreeze(frozen_value), schema=self.schema) + unfrozen_value = recursively_unfreeze(frozen_value) + try: + self.schema.validate(unfrozen_value) # Validators may update the instance inplace + frozen_value = super().normalize(unfrozen_value) + except AttributeError: + jsonschema.validate(instance=unfrozen_value, schema=self.schema) return frozen_value def parse(self, x): diff --git a/test/dict_parameter_test.py b/test/dict_parameter_test.py index ae4e1bc07d..6b1d98493d 100644 --- a/test/dict_parameter_test.py +++ b/test/dict_parameter_test.py @@ -15,6 +15,7 @@ # limitations under the License. # +from jsonschema import Draft4Validator from jsonschema.exceptions import ValidationError from helpers import unittest, in_parse @@ -113,6 +114,7 @@ def test_schema(self): with pytest.raises(ValidationError, match=r"'UNKNOWN_VALUE' is not one of \['web', 'staging'\]"): b.normalize({"role": "UNKNOWN_VALUE", "env": "staging"}) + # Check that warnings are properly emitted with mock.patch('luigi.parameter._JSONSCHEMA_ENABLED', False): with pytest.warns( UserWarning, @@ -122,3 +124,17 @@ def test_schema(self): ) ): luigi.ListParameter(schema={"type": "object"}) + + # Test with a custom validator + validator = Draft4Validator( + schema={ + "type": "object", + "patternProperties": { + ".*": {"type": "string", "enum": ["web", "staging"]}, + }, + } + ) + c = luigi.DictParameter(schema=validator) + c.normalize({"role": "web", "env": "staging"}) + with pytest.raises(ValidationError, match=r"'UNKNOWN_VALUE' is not one of \['web', 'staging'\]"): + c.normalize({"role": "UNKNOWN_VALUE", "env": "staging"}) diff --git a/test/list_parameter_test.py b/test/list_parameter_test.py index 51526fb083..9e27876193 100644 --- a/test/list_parameter_test.py +++ b/test/list_parameter_test.py @@ -15,6 +15,7 @@ # limitations under the License. # +from jsonschema import Draft4Validator from jsonschema.exceptions import ValidationError from helpers import unittest, in_parse @@ -76,10 +77,7 @@ def test_schema(self): ) # Check that the default value is validated - with pytest.raises( - ValidationError, - match=r"'INVALID_ATTRIBUTE' is not of type 'number'", - ): + with pytest.raises(ValidationError, match=r"'INVALID_ATTRIBUTE' is not of type 'number'"): a.normalize(["INVALID_ATTRIBUTE"]) # Check that empty list is not valid @@ -100,6 +98,7 @@ def test_schema(self): with pytest.raises(ValidationError, match="-999 is less than the minimum of 0"): a.normalize(invalid_list_value) + # Check that warnings are properly emitted with mock.patch('luigi.parameter._JSONSCHEMA_ENABLED', False): with pytest.warns( UserWarning, @@ -109,3 +108,20 @@ def test_schema(self): ) ): luigi.ListParameter(schema={"type": "array", "items": {"type": "number"}}) + + # Test with a custom validator + validator = Draft4Validator( + schema={ + "type": "array", + "items": { + "type": "number", + "minimum": 0, + "maximum": 10, + }, + "minItems": 1, + } + ) + c = luigi.DictParameter(schema=validator) + c.normalize(valid_list) + with pytest.raises(ValidationError, match=r"'INVALID_ATTRIBUTE' is not of type 'number'",): + c.normalize(["INVALID_ATTRIBUTE"]) From e7b2d91577b9992701d2616790fa88ebff7986a1 Mon Sep 17 00:00:00 2001 From: Adrien Berchet Date: Tue, 17 Jan 2023 10:18:11 +0100 Subject: [PATCH 2/3] Fix CI --- tox.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/tox.ini b/tox.ini index b63cc4e50b..ff75cf6f12 100644 --- a/tox.ini +++ b/tox.ini @@ -40,6 +40,7 @@ deps = postgres: psycopg2<3.0 postgres: pg8000>=1.23.0 mysql-connector-python>=8.0.12 + py35,py36: mysql-connector-python<8.0.32 gcloud: google-api-python-client>=1.6.6,<2.0 avro-python3 gcloud: google-auth==1.4.1 From 53148b62fe1908d831e5ea59320e2b50a86113f1 Mon Sep 17 00:00:00 2001 From: Adrien Berchet Date: Tue, 17 Jan 2023 14:09:56 +0100 Subject: [PATCH 3/3] Simplify freezing processes --- luigi/parameter.py | 18 ++++++++---------- test/dict_parameter_test.py | 4 ++++ test/list_parameter_test.py | 4 ++++ 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/luigi/parameter.py b/luigi/parameter.py index fe2e5ac2dd..3278377f1c 100644 --- a/luigi/parameter.py +++ b/luigi/parameter.py @@ -1137,15 +1137,14 @@ def normalize(self, value): """ Ensure that dictionary parameter is converted to a FrozenOrderedDict so it can be hashed. """ - frozen_value = recursively_freeze(value) if self.schema is not None: - unfrozen_value = recursively_unfreeze(frozen_value) + unfrozen_value = recursively_unfreeze(value) try: - self.schema.validate(unfrozen_value) # Validators may update the instance inplace - frozen_value = super().normalize(unfrozen_value) + self.schema.validate(unfrozen_value) + value = unfrozen_value # Validators may update the instance inplace except AttributeError: jsonschema.validate(instance=unfrozen_value, schema=self.schema) - return frozen_value + return recursively_freeze(value) def parse(self, source): """ @@ -1292,15 +1291,14 @@ def normalize(self, x): :param str x: the value to parse. :return: the normalized (hashable/immutable) value. """ - frozen_value = recursively_freeze(x) if self.schema is not None: - unfrozen_value = recursively_unfreeze(frozen_value) + unfrozen_value = recursively_unfreeze(x) try: - self.schema.validate(unfrozen_value) # Validators may update the instance inplace - frozen_value = super().normalize(unfrozen_value) + self.schema.validate(unfrozen_value) + x = unfrozen_value # Validators may update the instance inplace except AttributeError: jsonschema.validate(instance=unfrozen_value, schema=self.schema) - return frozen_value + return recursively_freeze(x) def parse(self, x): """ diff --git a/test/dict_parameter_test.py b/test/dict_parameter_test.py index 6b1d98493d..3dd3306dc4 100644 --- a/test/dict_parameter_test.py +++ b/test/dict_parameter_test.py @@ -138,3 +138,7 @@ def test_schema(self): c.normalize({"role": "web", "env": "staging"}) with pytest.raises(ValidationError, match=r"'UNKNOWN_VALUE' is not one of \['web', 'staging'\]"): c.normalize({"role": "UNKNOWN_VALUE", "env": "staging"}) + + # Test with frozen data + frozen_data = luigi.freezing.recursively_freeze({"role": "web", "env": "staging"}) + c.normalize(frozen_data) diff --git a/test/list_parameter_test.py b/test/list_parameter_test.py index 9e27876193..26204e48cf 100644 --- a/test/list_parameter_test.py +++ b/test/list_parameter_test.py @@ -125,3 +125,7 @@ def test_schema(self): c.normalize(valid_list) with pytest.raises(ValidationError, match=r"'INVALID_ATTRIBUTE' is not of type 'number'",): c.normalize(["INVALID_ATTRIBUTE"]) + + # Test with frozen data + frozen_data = luigi.freezing.recursively_freeze(valid_list) + c.normalize(frozen_data)