Skip to content

Commit

Permalink
Merge pull request #541 from davidism/tablename
Browse files Browse the repository at this point in the history
rewrite tablename generation again
  • Loading branch information
davidism authored Sep 26, 2017
2 parents f1e5852 + 7134764 commit f93c737
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 53 deletions.
96 changes: 54 additions & 42 deletions flask_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from sqlalchemy.orm.exc import UnmappedClassError
from sqlalchemy.orm.session import Session as SessionBase

from ._compat import iteritems, itervalues, string_types, xrange
from ._compat import itervalues, string_types, xrange

__version__ = '2.2.1'

Expand Down Expand Up @@ -551,31 +551,39 @@ def get_engine(self):


def _should_set_tablename(cls):
"""Traverse the model's MRO. If a primary key column is found before a
table or tablename, then a new tablename should be generated.
This supports:
* Joined table inheritance without explicitly naming sub-models.
* Single table inheritance.
* Inheriting from mixins or abstract models.
:param cls: model to check
:return: True if tablename should be set
"""Determine whether ``__tablename__`` should be automatically generated
for a model.
* If no class in the MRO sets a name, one should be generated.
* If a declared attr is found, it should be used instead.
* If a name is found, it should be used if the class is a mixin, otherwise
one should be generated.
* Abstract models should not have one generated.
Later, :meth:`._BoundDeclarativeMeta.__table_cls__` will determine if the
model looks like single or joined-table inheritance. If no primary key is
found, the name will be unset.
"""
if (
cls.__dict__.get('__abstract__', False)
or not any(isinstance(b, DeclarativeMeta) for b in cls.__mro__[1:])
):
return False

for base in cls.__mro__:
d = base.__dict__
if '__tablename__' not in base.__dict__:
continue

if '__tablename__' in d or '__table__' in d:
if isinstance(base.__dict__['__tablename__'], declared_attr):
return False

for name, obj in iteritems(d):
if isinstance(obj, declared_attr):
obj = getattr(cls, name)
return not (
base is cls
or base.__dict__.get('__abstract__', False)
or not isinstance(base, DeclarativeMeta)
)

if isinstance(obj, sqlalchemy.Column) and obj.primary_key:
return True
return True


def camel_to_snake_case(name):
Expand All @@ -591,20 +599,36 @@ def _join(match):


class _BoundDeclarativeMeta(DeclarativeMeta):
def __new__(cls, name, bases, d):
# if tablename is set explicitly, move it to the cache attribute so
# that future subclasses still have auto behavior
if '__tablename__' in d:
d['_cached_tablename'] = d.pop('__tablename__')
def __init__(cls, name, bases, d):
if _should_set_tablename(cls):
cls.__tablename__ = camel_to_snake_case(cls.__name__)

bind_key = (
d.pop('__bind_key__', None)
or getattr(cls, '__bind_key__', None)
)

return DeclarativeMeta.__new__(cls, name, bases, d)
super(_BoundDeclarativeMeta, cls).__init__(name, bases, d)

def __init__(self, name, bases, d):
bind_key = d.pop('__bind_key__', None) or getattr(self, '__bind_key__', None)
DeclarativeMeta.__init__(self, name, bases, d)
if bind_key is not None and hasattr(cls, '__table__'):
cls.__table__.info['bind_key'] = bind_key

if bind_key is not None and hasattr(self, '__table__'):
self.__table__.info['bind_key'] = bind_key
def __table_cls__(cls, *args, **kwargs):
"""This is called by SQLAlchemy during mapper setup. It determines the
final table object that the model will use.
If no primary key is found, that indicates single-table inheritance,
so no table will be created and ``__tablename__`` will be unset.
"""
for arg in args:
if (
(isinstance(arg, sqlalchemy.Column) and arg.primary_key)
or isinstance(arg, sqlalchemy.PrimaryKeyConstraint)
):
return sqlalchemy.Table(*args, **kwargs)

if '__tablename__' in cls.__dict__:
del cls.__tablename__


def get_state(app):
Expand Down Expand Up @@ -638,18 +662,6 @@ class Model(object):
#: Equivalent to ``db.session.query(Model)`` unless :attr:`query_class` has been changed.
query = None

_cached_tablename = None

@declared_attr
def __tablename__(cls):
if (
'_cached_tablename' not in cls.__dict__ and
_should_set_tablename(cls)
):
cls._cached_tablename = camel_to_snake_case(cls.__name__)

return cls._cached_tablename


class SQLAlchemy(object):
"""This class is used to control the SQLAlchemy integration to one
Expand Down
92 changes: 81 additions & 11 deletions tests/test_table_name.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import inspect

from sqlalchemy.ext.declarative import declared_attr


Expand Down Expand Up @@ -25,6 +27,7 @@ class Duck(db.Model):
class Mallard(Duck):
pass

assert '__tablename__' not in Mallard.__dict__
assert Mallard.__tablename__ == 'duck'


Expand All @@ -39,8 +42,10 @@ class Donald(Duck):
assert Donald.__tablename__ == 'donald'


def test_mixin_name(db):
"""Primary key provided by mixin should still allow model to set tablename."""
def test_mixin_id(db):
"""Primary key provided by mixin should still allow model to set
tablename.
"""
class Base(object):
id = db.Column(db.Integer, primary_key=True)

Expand All @@ -51,28 +56,57 @@ class Duck(Base, db.Model):
assert Duck.__tablename__ == 'duck'


def test_mixin_attr(db):
"""A declared attr tablename will be used down multiple levels of
inheritance.
"""
class Mixin(object):
@declared_attr
def __tablename__(cls):
return cls.__name__.upper()

class Bird(Mixin, db.Model):
id = db.Column(db.Integer, primary_key=True)

class Duck(Bird):
# object reference
id = db.Column(db.ForeignKey(Bird.id), primary_key=True)

class Mallard(Duck):
# string reference
id = db.Column(db.ForeignKey('DUCK.id'), primary_key=True)

assert Bird.__tablename__ == 'BIRD'
assert Duck.__tablename__ == 'DUCK'
assert Mallard.__tablename__ == 'MALLARD'


def test_abstract_name(db):
"""Abstract model should not set a name. Subclass should set a name."""
"""Abstract model should not set a name. Subclass should set a name."""
class Base(db.Model):
__abstract__ = True
id = db.Column(db.Integer, primary_key=True)

class Duck(Base):
pass

assert Base.__tablename__ == 'base'
assert '__tablename__' not in Base.__dict__
assert Duck.__tablename__ == 'duck'


def test_complex_inheritance(db):
"""Joined table inheritance, but the new primary key is provided by a mixin, not directly on the class."""
"""Joined table inheritance, but the new primary key is provided by a
mixin, not directly on the class.
"""
class Duck(db.Model):
id = db.Column(db.Integer, primary_key=True)

class IdMixin(object):
@declared_attr
def id(cls):
return db.Column(db.Integer, db.ForeignKey(Duck.id), primary_key=True)
return db.Column(
db.Integer, db.ForeignKey(Duck.id), primary_key=True
)

class RubberDuck(IdMixin, Duck):
pass
Expand All @@ -81,18 +115,55 @@ class RubberDuck(IdMixin, Duck):


def test_manual_name(db):
"""Setting a manual name prevents generation for the immediate model. A
name is generated for joined but not single-table inheritance.
"""
class Duck(db.Model):
__tablename__ = 'DUCK'
id = db.Column(db.Integer, primary_key=True)
type = db.Column(db.String)

__mapper_args__ = {
'polymorphic_on': type
}

class Daffy(Duck):
id = db.Column(db.Integer, db.ForeignKey(Duck.id), primary_key=True)

__mapper_args__ = {
'polymorphic_identity': 'Warner'
}

class Donald(Duck):
__mapper_args__ = {
'polymorphic_identity': 'Disney'
}

assert Duck.__tablename__ == 'DUCK'
assert Daffy.__tablename__ == 'daffy'
assert '__tablename__' not in Donald.__dict__
assert Donald.__tablename__ == 'DUCK'
# polymorphic condition for single-table query
assert 'WHERE "DUCK".type' in str(Donald.query)


def test_primary_constraint(db):
"""Primary key will be picked up from table args."""
class Duck(db.Model):
id = db.Column(db.Integer)

__table_args__ = (
db.PrimaryKeyConstraint(id),
)

assert Duck.__table__ is not None
assert Duck.__tablename__ == 'duck'


def test_no_access_to_class_property(db):
"""Ensure the implementation doesn't access class properties or declared
attrs while inspecting the unmapped model.
"""
class class_property(object):
def __init__(self, f):
self.f = f
Expand All @@ -106,14 +177,13 @@ class Duck(db.Model):
class ns(object):
accessed = False

# Since there's no id provided by the following model,
# _should_set_tablename will scan all attributes. If it's working
# properly, it won't access the class property, but will access the
# declared_attr.

class Witch(Duck):
@declared_attr
def is_duck(self):
# declared attrs will be accessed during mapper configuration,
# but make sure they're not accessed before that
info = inspect.getouterframes(inspect.currentframe())[2]
assert info[3] != '_should_set_tablename'
ns.accessed = True

@class_property
Expand Down

0 comments on commit f93c737

Please sign in to comment.