Skip to content

Commit

Permalink
global: move to db.session.query syntax
Browse files Browse the repository at this point in the history
* this change is a working solution for sqlalchemy ~= 1.4 but a
  necessity for >= 2.0
  • Loading branch information
utnapischtim committed Oct 3, 2024
1 parent 834c47f commit 7dd7952
Showing 1 changed file with 60 additions and 35 deletions.
95 changes: 60 additions & 35 deletions invenio_files_rest/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# This file is part of Invenio.
# Copyright (C) 2015-2019 CERN.
# Copyright (C) 2024 Graz University of Technology.
#
# Invenio is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.
Expand Down Expand Up @@ -291,22 +292,26 @@ def validate_name(self, key, name):
@classmethod
def get_by_name(cls, name):
"""Fetch a specific location object."""
return cls.query.filter_by(
name=name,
).one_or_none()
return (
db.session.query(cls)
.filter_by(
name=name,
)
.one_or_none()
)

@classmethod
def get_default(cls):
"""Fetch the default location object."""
try:
return cls.query.filter_by(default=True).one_or_none()
return db.session.query(cls).filter_by(default=True).one_or_none()
except MultipleResultsFound:
return None

@classmethod
def all(cls):
"""Return query that fetches all locations."""
return Location.query.all()
return db.session.query(Location).all()

def __repr__(self):
"""Return representation of location."""
Expand Down Expand Up @@ -550,12 +555,14 @@ def get(cls, bucket_id):
:param bucket_id: Bucket identifier.
:returns: Bucket instance.
"""
return cls.query.filter_by(id=bucket_id, deleted=False).one_or_none()
return (
db.session.query(cls).filter_by(id=bucket_id, deleted=False).one_or_none()
)

@classmethod
def all(cls):
"""Return query of all buckets (excluding deleted)."""
return cls.query.filter_by(deleted=False)
return db.session.query(cls).filter_by(deleted=False)

@classmethod
def delete(cls, bucket_id):
Expand Down Expand Up @@ -620,10 +627,14 @@ class BucketTag(db.Model):
@classmethod
def get(cls, bucket, key):
"""Get tag object."""
return cls.query.filter_by(
bucket_id=as_bucket_id(bucket),
key=key,
).one_or_none()
return (
db.session.query(cls)
.filter_by(
bucket_id=as_bucket_id(bucket),
key=key,
)
.one_or_none()
)

@classmethod
def create(cls, bucket, key, value):
Expand Down Expand Up @@ -654,7 +665,7 @@ def get_value(cls, bucket, key):
def delete(cls, bucket, key):
"""Delete a tag."""
with db.session.begin_nested():
cls.query.filter_by(
db.session.query(cls).filter_by(
bucket_id=as_bucket_id(bucket),
key=key,
).delete()
Expand Down Expand Up @@ -723,13 +734,13 @@ def validate_uri(self, key, uri):
@classmethod
def get(cls, file_id):
"""Get a file instance."""
return cls.query.filter_by(id=file_id).one_or_none()
return db.session.query(cls).filter_by(id=file_id).one_or_none()

@classmethod
def get_by_uri(cls, uri):
"""Get a file instance by URI."""
assert uri is not None
return cls.query.filter_by(uri=uri).one_or_none()
return db.session.query(cls).filter_by(uri=uri).one_or_none()

@classmethod
def create(cls):
Expand Down Expand Up @@ -1260,9 +1271,11 @@ def create(
raise BucketLockedError()

with db.session.begin_nested():
latest_obj = cls.query.filter(
cls.bucket == bucket, cls.key == key, cls.is_head.is_(True)
).one_or_none()
latest_obj = (
db.session.query(cls)
.filter(cls.bucket == bucket, cls.key == key, cls.is_head.is_(True))
.one_or_none()
)
if latest_obj is not None:
latest_obj.is_head = False
db.session.add(latest_obj)
Expand Down Expand Up @@ -1310,7 +1323,7 @@ def get(cls, bucket, key, version_id=None):
filters.append(cls.is_head.is_(True))
filters.append(cls.file_id.isnot(None))

return cls.query.filter(*filters).one_or_none()
return db.session.query(cls).filter(*filters).one_or_none()

@classmethod
def get_versions(cls, bucket, key, desc=True):
Expand All @@ -1328,7 +1341,7 @@ def get_versions(cls, bucket, key, desc=True):

order = cls.created.desc() if desc else cls.created.asc()

return cls.query.filter(*filters).order_by(cls.key, order)
return db.session.query(cls).filter(*filters).order_by(cls.key, order)

@classmethod
def delete(cls, bucket, key):
Expand Down Expand Up @@ -1370,7 +1383,9 @@ def get_by_bucket(cls, bucket, versions=False, with_deleted=False):
if not with_deleted:
filters.append(cls.file_id.isnot(None))

return cls.query.filter(*filters).order_by(cls.key, cls.created.desc())
return (
db.session.query(cls).filter(*filters).order_by(cls.key, cls.created.desc())
)

@classmethod
def relink_all(cls, old_file, new_file):
Expand Down Expand Up @@ -1470,10 +1485,14 @@ def copy(self, object_version=None, key=None):
@classmethod
def get(cls, object_version, key):
"""Get the tag object."""
return cls.query.filter_by(
version_id=as_object_version_id(object_version),
key=key,
).one_or_none()
return (
db.session.query(cls)
.filter_by(
version_id=as_object_version_id(object_version),
key=key,
)
.one_or_none()
)

@classmethod
def create(cls, object_version, key, value):
Expand Down Expand Up @@ -1515,7 +1534,9 @@ def delete(cls, object_version, key=None):
Default: delete all tags.
"""
with db.session.begin_nested():
q = cls.query.filter_by(version_id=as_object_version_id(object_version))
q = db.session.query(cls).filter_by(
version_id=as_object_version_id(object_version)
)
if key:
q = q.filter_by(key=key)
q.delete()
Expand Down Expand Up @@ -1710,7 +1731,7 @@ def create(cls, bucket, key, size, chunk_size):
@classmethod
def get(cls, bucket, key, upload_id, with_completed=False):
"""Fetch a specific multipart object."""
q = cls.query.filter_by(
q = db.session.query(cls).filter_by(
upload_id=upload_id,
bucket_id=as_bucket_id(bucket),
key=key,
Expand All @@ -1723,15 +1744,15 @@ def get(cls, bucket, key, upload_id, with_completed=False):
@classmethod
def query_expired(cls, dt, bucket=None):
"""Query all uncompleted multipart uploads."""
q = cls.query.filter(cls.created < dt).filter_by(completed=True)
q = db.session.query(cls).filter(cls.created < dt).filter_by(completed=True)
if bucket:
q = q.filter(cls.bucket_id == as_bucket_id(bucket))
return q

@classmethod
def query_by_bucket(cls, bucket):
"""Query all uncompleted multipart uploads."""
return cls.query.filter(cls.bucket_id == as_bucket_id(bucket))
return db.session.query(cls).filter(cls.bucket_id == as_bucket_id(bucket))


class Part(db.Model, Timestamp):
Expand Down Expand Up @@ -1792,9 +1813,11 @@ def create(cls, mp, part_number, stream=None, **kwargs):
@classmethod
def get_or_none(cls, mp, part_number):
"""Get part number."""
return cls.query.filter_by(
upload_id=mp.upload_id, part_number=part_number
).one_or_none()
return (
db.session.query(cls)
.filter_by(upload_id=mp.upload_id, part_number=part_number)
.one_or_none()
)

@classmethod
def get_or_create(cls, mp, part_number):
Expand All @@ -1807,9 +1830,11 @@ def get_or_create(cls, mp, part_number):
@classmethod
def delete(cls, mp, part_number):
"""Get part number."""
return cls.query.filter_by(
upload_id=mp.upload_id, part_number=part_number
).delete()
return (
db.session.query(cls)
.filter_by(upload_id=mp.upload_id, part_number=part_number)
.delete()
)

@classmethod
def query_by_multipart(cls, multipart):
Expand All @@ -1822,7 +1847,7 @@ def query_by_multipart(cls, multipart):
upload_id = (
multipart.upload_id if isinstance(multipart, MultipartObject) else multipart
)
return cls.query.filter_by(upload_id=upload_id)
return db.session.query(cls).filter_by(upload_id=upload_id)

@classmethod
def count(cls, mp):
Expand Down

0 comments on commit 7dd7952

Please sign in to comment.