Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix schema validation and add custom validators #3220

Merged
merged 3 commits into from
Jan 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 65 additions & 8 deletions luigi/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,6 +1092,26 @@ def run(self):

$ luigi --module my_tasks MyTask --tags '{"role": "UNKNOWN_VALUE", "env": "staging"}'

Finally, the provided schema can be a custom validator:

.. code-block:: python

custom_validator = jsonschema.Draft4Validator(
schema={
"type": "object",
"patternProperties": {
".*": {"type": "string", "enum": ["web", "staging"]},
}
}
)

class MyTask(luigi.Task):
tags = luigi.DictParameter(schema=custom_validator)

def run(self):
logging.info("Find server with role: %s", self.tags['role'])
server = aws.ec2.find_my_resource(self.tags)

"""

def __init__(
Expand All @@ -1105,7 +1125,9 @@ def __init__(
"The 'jsonschema' package is not installed so the parameter can not be validated "
"even though a schema is given."
)
self.schema = schema
self.schema = None
else:
self.schema = schema
super().__init__(
*args,
**kwargs,
Expand All @@ -1115,10 +1137,14 @@ def normalize(self, value):
"""
Ensure that dictionary parameter is converted to a FrozenOrderedDict so it can be hashed.
"""
frozen_value = recursively_freeze(value)
if self.schema is not None:
jsonschema.validate(instance=recursively_unfreeze(frozen_value), schema=self.schema)
return frozen_value
unfrozen_value = recursively_unfreeze(value)
try:
self.schema.validate(unfrozen_value)
value = unfrozen_value # Validators may update the instance inplace
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I defined the unfrozen_value variable for clarity but we could just overwrite the value variable at line 1141. What do you prefer?

except AttributeError:
jsonschema.validate(instance=unfrozen_value, schema=self.schema)
return recursively_freeze(value)

def parse(self, source):
"""
Expand Down Expand Up @@ -1212,6 +1238,31 @@ def run(self):
$ luigi --module my_tasks MyTask --numbers '[]' # must have at least 1 element
$ luigi --module my_tasks MyTask --numbers '[-999, 999]' # elements must be in [0, 10]

Finally, the provided schema can be a custom validator:

.. code-block:: python

custom_validator = jsonschema.Draft4Validator(
schema={
"type": "array",
"items": {
"type": "number",
"minimum": 0,
"maximum": 10
},
"minItems": 1
}
)

class MyTask(luigi.Task):
grades = luigi.ListParameter(schema=custom_validator)

def run(self):
sum = 0
for element in self.grades:
sum += element
avg = sum / len(self.grades)

"""

def __init__(
Expand All @@ -1225,7 +1276,9 @@ def __init__(
"The 'jsonschema' package is not installed so the parameter can not be validated "
"even though a schema is given."
)
self.schema = schema
self.schema = None
else:
self.schema = schema
super().__init__(
*args,
**kwargs,
Expand All @@ -1238,10 +1291,14 @@ def normalize(self, x):
:param str x: the value to parse.
:return: the normalized (hashable/immutable) value.
"""
frozen_value = recursively_freeze(x)
if self.schema is not None:
jsonschema.validate(instance=recursively_unfreeze(frozen_value), schema=self.schema)
return frozen_value
unfrozen_value = recursively_unfreeze(x)
try:
self.schema.validate(unfrozen_value)
x = unfrozen_value # Validators may update the instance inplace
except AttributeError:
jsonschema.validate(instance=unfrozen_value, schema=self.schema)
return recursively_freeze(x)

def parse(self, x):
"""
Expand Down
20 changes: 20 additions & 0 deletions test/dict_parameter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.
#

from jsonschema import Draft4Validator
from jsonschema.exceptions import ValidationError
from helpers import unittest, in_parse

Expand Down Expand Up @@ -113,6 +114,7 @@ def test_schema(self):
with pytest.raises(ValidationError, match=r"'UNKNOWN_VALUE' is not one of \['web', 'staging'\]"):
b.normalize({"role": "UNKNOWN_VALUE", "env": "staging"})

# Check that warnings are properly emitted
with mock.patch('luigi.parameter._JSONSCHEMA_ENABLED', False):
with pytest.warns(
UserWarning,
Expand All @@ -122,3 +124,21 @@ def test_schema(self):
)
):
luigi.ListParameter(schema={"type": "object"})

# Test with a custom validator
validator = Draft4Validator(
schema={
"type": "object",
"patternProperties": {
".*": {"type": "string", "enum": ["web", "staging"]},
},
}
)
c = luigi.DictParameter(schema=validator)
c.normalize({"role": "web", "env": "staging"})
with pytest.raises(ValidationError, match=r"'UNKNOWN_VALUE' is not one of \['web', 'staging'\]"):
c.normalize({"role": "UNKNOWN_VALUE", "env": "staging"})

# Test with frozen data
frozen_data = luigi.freezing.recursively_freeze({"role": "web", "env": "staging"})
c.normalize(frozen_data)
28 changes: 24 additions & 4 deletions test/list_parameter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.
#

from jsonschema import Draft4Validator
from jsonschema.exceptions import ValidationError
from helpers import unittest, in_parse

Expand Down Expand Up @@ -76,10 +77,7 @@ def test_schema(self):
)

# Check that the default value is validated
with pytest.raises(
ValidationError,
match=r"'INVALID_ATTRIBUTE' is not of type 'number'",
):
with pytest.raises(ValidationError, match=r"'INVALID_ATTRIBUTE' is not of type 'number'"):
a.normalize(["INVALID_ATTRIBUTE"])

# Check that empty list is not valid
Expand All @@ -100,6 +98,7 @@ def test_schema(self):
with pytest.raises(ValidationError, match="-999 is less than the minimum of 0"):
a.normalize(invalid_list_value)

# Check that warnings are properly emitted
with mock.patch('luigi.parameter._JSONSCHEMA_ENABLED', False):
with pytest.warns(
UserWarning,
Expand All @@ -109,3 +108,24 @@ def test_schema(self):
)
):
luigi.ListParameter(schema={"type": "array", "items": {"type": "number"}})

# Test with a custom validator
validator = Draft4Validator(
schema={
"type": "array",
"items": {
"type": "number",
"minimum": 0,
"maximum": 10,
},
"minItems": 1,
}
)
c = luigi.DictParameter(schema=validator)
c.normalize(valid_list)
with pytest.raises(ValidationError, match=r"'INVALID_ATTRIBUTE' is not of type 'number'",):
c.normalize(["INVALID_ATTRIBUTE"])

# Test with frozen data
frozen_data = luigi.freezing.recursively_freeze(valid_list)
c.normalize(frozen_data)
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ deps =
postgres: psycopg2<3.0
postgres: pg8000>=1.23.0
mysql-connector-python>=8.0.12
py35,py36: mysql-connector-python<8.0.32
gcloud: google-api-python-client>=1.6.6,<2.0
avro-python3
gcloud: google-auth==1.4.1
Expand Down