From 35845fcf461225262cdf476fed561fb4cbda71a1 Mon Sep 17 00:00:00 2001 From: Pablo Saiz Date: Mon, 22 Apr 2024 15:49:58 +0200 Subject: [PATCH] marhsmallow: remove deprecation warning --- invenio_records_rest/loaders/marshmallow.py | 5 +++-- invenio_records_rest/schemas/fields/generated.py | 5 +++-- invenio_records_rest/schemas/json.py | 8 ++++---- invenio_records_rest/utils.py | 5 +++++ tests/test_custom_fields.py | 8 ++++---- tests/test_marshmallow_loader.py | 7 +++---- 6 files changed, 22 insertions(+), 16 deletions(-) diff --git a/invenio_records_rest/loaders/marshmallow.py b/invenio_records_rest/loaders/marshmallow.py index e162614..35e04be 100644 --- a/invenio_records_rest/loaders/marshmallow.py +++ b/invenio_records_rest/loaders/marshmallow.py @@ -18,7 +18,8 @@ from flask import request from invenio_rest.errors import RESTValidationError from marshmallow import ValidationError -from marshmallow import __version_info__ as marshmallow_version + +from ..utils import marshmallow_major_version def _flatten_marshmallow_errors(errors, parents=()): @@ -81,7 +82,7 @@ def json_loader(): pid, record = pid_data.data context["pid"] = pid context["record"] = record - if marshmallow_version[0] < 3: + if marshmallow_major_version < 3: result = schema_class(context=context).load(request_json) if result.errors: raise MarshmallowErrors(result.errors) diff --git a/invenio_records_rest/schemas/fields/generated.py b/invenio_records_rest/schemas/fields/generated.py index c5712c6..d19d6df 100644 --- a/invenio_records_rest/schemas/fields/generated.py +++ b/invenio_records_rest/schemas/fields/generated.py @@ -10,9 +10,10 @@ import warnings -from marshmallow import __version_info__ as marshmallow_version from marshmallow import missing as missing_ +from invenio_records_rest.utils import marshmallow_major_version + from .marshmallow_contrib import Function, Method @@ -25,7 +26,7 @@ class GeneratedValue(object): class ForcedFieldDeserializeMixin(object): """Mixin that forces deserialization of marshmallow fields.""" - if marshmallow_version[0] < 3: + if marshmallow_major_version < 3: def __init__(self, *args, **kwargs): """Override the "missing" parameter.""" diff --git a/invenio_records_rest/schemas/json.py b/invenio_records_rest/schemas/json.py index 9d99f35..20bd404 100644 --- a/invenio_records_rest/schemas/json.py +++ b/invenio_records_rest/schemas/json.py @@ -10,12 +10,12 @@ from flask import current_app from invenio_rest.serializer import BaseSchema as Schema -from marshmallow import ValidationError -from marshmallow import __version_info__ as marshmallow_version -from marshmallow import fields, missing, post_load, validates_schema +from marshmallow import ValidationError, fields, missing, post_load, validates_schema from invenio_records_rest.schemas.fields import PersistentIdentifier +from ..utils import marshmallow_major_version + class StrictKeysMixin(Schema): """Ensure only valid keys exists.""" @@ -74,7 +74,7 @@ def load_unknown_fields(self, data, original_data): return data -if marshmallow_version[0] < 3: +if marshmallow_major_version < 3: class RecordMetadataSchemaJSONV1(OriginalKeysMixin): """Schema for records metadata v1 in JSON with injected PID value.""" diff --git a/invenio_records_rest/utils.py b/invenio_records_rest/utils.py index 0387584..1bcba3a 100644 --- a/invenio_records_rest/utils.py +++ b/invenio_records_rest/utils.py @@ -10,6 +10,7 @@ from functools import partial +import pkg_resources import six from flask import abort, current_app, jsonify, make_response, request, url_for from invenio_pidstore.errors import ( @@ -33,6 +34,10 @@ ) from .proxies import current_records_rest +marshmallow_major_version = int( + pkg_resources.get_distribution("marshmallow").version[0] +) + def build_default_endpoint_prefixes(records_rest_endpoints): """Build the default_endpoint_prefixes map.""" diff --git a/tests/test_custom_fields.py b/tests/test_custom_fields.py index c3cca7a..625988d 100644 --- a/tests/test_custom_fields.py +++ b/tests/test_custom_fields.py @@ -12,7 +12,6 @@ from invenio_pidstore.models import PersistentIdentifier as PIDModel from invenio_records import Record from invenio_rest.serializer import BaseSchema as Schema -from marshmallow import __version_info__ as marshmallow_version from marshmallow import missing from invenio_records_rest.schemas import StrictKeysMixin @@ -25,8 +24,9 @@ SanitizedUnicode, TrimmedString, ) +from invenio_records_rest.utils import marshmallow_major_version -if marshmallow_version[0] >= 3: +if marshmallow_major_version >= 3: schema_to_use = Schema from marshmallow import EXCLUDE else: @@ -36,7 +36,7 @@ class CustomFieldSchema(schema_to_use): """Test schema.""" - if marshmallow_version[0] >= 3: + if marshmallow_major_version >= 3: class Meta: """.""" @@ -107,7 +107,7 @@ def deserialize_func(value, ctx, data): class GeneratedFieldsSchema(schema_to_use): """Test schema.""" - if marshmallow_version[0] >= 3: + if marshmallow_major_version >= 3: class Meta: """Meta attributes for the schema.""" diff --git a/tests/test_marshmallow_loader.py b/tests/test_marshmallow_loader.py index d905d51..5f4ca01 100644 --- a/tests/test_marshmallow_loader.py +++ b/tests/test_marshmallow_loader.py @@ -14,9 +14,7 @@ from helpers import get_json from invenio_records.models import RecordMetadata from invenio_rest.serializer import BaseSchema as Schema -from marshmallow import ValidationError -from marshmallow import __version_info__ as marshmallow_version -from marshmallow import fields +from marshmallow import ValidationError, fields from invenio_records_rest.loaders import json_pid_checker from invenio_records_rest.loaders.marshmallow import ( @@ -25,6 +23,7 @@ ) from invenio_records_rest.schemas import Nested from invenio_records_rest.schemas.fields import PersistentIdentifier +from invenio_records_rest.utils import marshmallow_major_version class _TestSchema(Schema): @@ -166,7 +165,7 @@ def has_error(field, parents): def test_marshmallow_errors(test_data): """Test MarshmallowErrors class.""" incomplete_data = dict(test_data[0]) - if marshmallow_version[0] >= 3: + if marshmallow_major_version >= 3: try: res = _TestSchema(context={}).load(json.dumps(incomplete_data)) except ValidationError as error: