Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add constraint naming convention configuration option #197

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
26 changes: 25 additions & 1 deletion src/sqlacodegen/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import argparse
import sys
from contextlib import ExitStack
from typing import TextIO
from typing import Sequence, TextIO

from sqlalchemy.engine import create_engine
from sqlalchemy.schema import MetaData
Expand All @@ -14,6 +14,18 @@
from importlib.metadata import entry_points, version


def parse_naming_convs(naming_convs: Sequence[str]) -> dict[str, str]:
d = {}
for naming_conv in naming_convs:
try:
key, value = naming_conv.split("=", 1)
except ValueError:
raise ValueError('Naming convention must be in "key=template" format')

d[key] = value
return d


def main() -> None:
generators = {ep.name: ep for ep in entry_points(group="sqlacodegen.generators")}
parser = argparse.ArgumentParser(
Expand All @@ -40,6 +52,13 @@ def main() -> None:
)
parser.add_argument("--noviews", action="store_true", help="ignore views")
parser.add_argument("--outfile", help="file to write output to (default: stdout)")
parser.add_argument(
"--conv",
nargs="*",
help='constraint naming conventions in "key=template" format \
e.g., --conv "pk=pk_%%(table_name)s" "uq=uq_%%(table_name)s_%%(column_0_name)s"',
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added the option to add the naming convention via command line on top of config file (which hasn't been added yet)

I was thinking, for each config option, we can go like this:

  • Check if it is available as command line arg
  • If not, check if it is available in config file
  • If not, fall back to default value

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How would users know what to enter here, based on this?

Copy link
Contributor Author

@leonarduschen leonarduschen Apr 28, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well we assume users interested in using this option are already familiar with SQLAlchemy's constraint naming convention feature from https://docs.sqlalchemy.org/en/14/core/constraints.html#configuring-constraint-naming-conventions or https://alembic.sqlalchemy.org/en/latest/naming.html. In this case, the help string above would enough in my opinion.

We can also add more details in README:

  • Link to SQLAlchemy/Alembic docs above
  • Maybe also basic info about constraint naming convention and its usage in SQLAlchemy, basically summarizing the docs above
  • Things supported in SQLAlchemy but not in sqlacodegen:
    • Using custom function as naming convention token
    • Using the constraint object name as key instead of mnemonic (e.g., PrimaryKeyConstraint cannot be used as key, must use pk)

)

args = parser.parse_args()

if args.version:
Expand All @@ -58,6 +77,11 @@ def main() -> None:
for schema in schemas:
metadata.reflect(engine, schema, not args.noviews, tables)

# Naming convention must be added after reflection to
# avoid the token %(constraint_name)s duplicating the name
if args.conv:
metadata.naming_convention = parse_naming_convs(args.conv)

# Instantiate the generator
generator_class = generators[args.generator].load()
generator = generator_class(metadata, engine, set(args.option or ()))
Expand Down
76 changes: 58 additions & 18 deletions src/sqlacodegen/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.exc import CompileError
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.elements import TextClause, conv
from sqlalchemy.sql.schema import DEFAULT_NAMING_CONVENTION

from .models import (
ColumnAttribute,
Expand All @@ -56,9 +57,9 @@
get_common_fk_constraints,
get_compiled_expression,
get_constraint_sort_key,
get_explicit_name,
qualified_table_name,
render_callable,
uses_default_name,
)

if sys.version_info < (3, 10):
Expand Down Expand Up @@ -209,22 +210,25 @@ def collect_imports_for_column(self, column: Column[Any]) -> None:

def collect_imports_for_constraint(self, constraint: Constraint | Index) -> None:
if isinstance(constraint, Index):
if len(constraint.columns) > 1 or not uses_default_name(constraint):
if len(constraint.columns) > 1 or get_explicit_name(constraint):
self.add_literal_import("sqlalchemy", "Index")
elif isinstance(constraint, PrimaryKeyConstraint):
if not uses_default_name(constraint):
if get_explicit_name(constraint):
self.add_literal_import("sqlalchemy", "PrimaryKeyConstraint")
elif isinstance(constraint, UniqueConstraint):
if len(constraint.columns) > 1 or not uses_default_name(constraint):
if len(constraint.columns) > 1 or get_explicit_name(constraint):
self.add_literal_import("sqlalchemy", "UniqueConstraint")
elif isinstance(constraint, ForeignKeyConstraint):
if len(constraint.columns) > 1 or not uses_default_name(constraint):
if len(constraint.columns) > 1 or get_explicit_name(constraint):
self.add_literal_import("sqlalchemy", "ForeignKeyConstraint")
else:
self.add_import(ForeignKey)
else:
self.add_import(constraint)

if isinstance(get_explicit_name(constraint), conv):
self.add_literal_import("sqlalchemy.sql.elements", "conv")

def add_import(self, obj: Any) -> None:
# Don't store builtin imports
if getattr(obj, "__module__", "builtins") == "builtins":
Expand Down Expand Up @@ -303,7 +307,13 @@ def generate_model_name(self, model: Model, global_names: set[str]) -> None:
model.name = self.find_free_name(preferred_name, global_names)

def render_module_variables(self, models: list[Model]) -> str:
return "metadata = MetaData()"
module_vars = ["metadata = MetaData()"]
if self.metadata.naming_convention != DEFAULT_NAMING_CONVENTION:
formatted_naming_convention = pformat(self.metadata.naming_convention)
module_vars.append(
f"metadata.naming_convention = {formatted_naming_convention}"
)
return "\n".join(module_vars)

def render_models(self, models: list[Model]) -> str:
rendered = []
Expand All @@ -322,7 +332,7 @@ def render_table(self, table: Table) -> str:
args.append(self.render_column(column, True))

for constraint in sorted(table.constraints, key=get_constraint_sort_key):
if uses_default_name(constraint):
if not get_explicit_name(constraint):
if isinstance(constraint, PrimaryKeyConstraint):
continue
elif isinstance(constraint, (ForeignKeyConstraint, UniqueConstraint)):
Expand All @@ -333,7 +343,7 @@ def render_table(self, table: Table) -> str:

for index in sorted(table.indexes, key=lambda i: i.name):
# One-column indexes should be rendered as index=True on columns
if len(index.columns) > 1 or not uses_default_name(index):
if len(index.columns) > 1 or get_explicit_name(index):
args.append(self.render_index(index))

if table.schema:
Expand Down Expand Up @@ -363,26 +373,26 @@ def render_column(self, column: Column[Any], show_name: bool) -> str:
for c in column.foreign_keys
if c.constraint
and len(c.constraint.columns) == 1
and uses_default_name(c.constraint)
and not get_explicit_name(c.constraint)
]
is_unique = any(
isinstance(c, UniqueConstraint)
and set(c.columns) == {column}
and uses_default_name(c)
and not get_explicit_name(c)
for c in column.table.constraints
)
is_unique = is_unique or any(
i.unique and set(i.columns) == {column} and uses_default_name(i)
i.unique and set(i.columns) == {column} and not get_explicit_name(i)
for i in column.table.indexes
)
is_primary = any(
isinstance(c, PrimaryKeyConstraint)
and column.name in c.columns
and uses_default_name(c)
and not get_explicit_name(c)
for c in column.table.constraints
)
has_index = any(
set(i.columns) == {column} and uses_default_name(i)
set(i.columns) == {column} and not get_explicit_name(i)
for i in column.table.indexes
)

Expand Down Expand Up @@ -522,8 +532,13 @@ def add_fk_options(*opts: Any) -> None:
f"Cannot render constraint of type {constraint.__class__.__name__}"
)

if isinstance(constraint, Constraint) and not uses_default_name(constraint):
kwargs["name"] = repr(constraint.name)
if isinstance(constraint, Constraint):
explicit_name = get_explicit_name(constraint)
if explicit_name:
if isinstance(explicit_name, conv):
kwargs["name"] = render_callable(conv.__name__, repr(explicit_name))
else:
kwargs["name"] = repr(explicit_name)

return render_callable(constraint.__class__.__name__, *args, kwargs=kwargs)

Expand Down Expand Up @@ -1017,6 +1032,13 @@ def render_module_variables(self, models: list[Model]) -> str:
return super().render_module_variables(models)

declarations = [f"{self.base_class_name} = declarative_base()"]

if self.metadata.naming_convention != DEFAULT_NAMING_CONVENTION:
formatted_naming_convention = pformat(self.metadata.naming_convention)
declarations.append(
f"Base.metadata.naming_convention = {formatted_naming_convention}"
)

if any(not isinstance(model, ModelClass) for model in models):
declarations.append(f"metadata = {self.base_class_name}.metadata")

Expand Down Expand Up @@ -1089,7 +1111,7 @@ def render_table_args(self, table: Table) -> str:

# Render constraints
for constraint in sorted(table.constraints, key=get_constraint_sort_key):
if uses_default_name(constraint):
if not get_explicit_name(constraint):
if isinstance(constraint, PrimaryKeyConstraint):
continue
if (
Expand All @@ -1102,7 +1124,7 @@ def render_table_args(self, table: Table) -> str:

# Render indexes
for index in sorted(table.indexes, key=lambda i: i.name):
if len(index.columns) > 1 or not uses_default_name(index):
if len(index.columns) > 1 or get_explicit_name(index):
args.append(self.render_index(index))

if table.schema:
Expand Down Expand Up @@ -1272,6 +1294,15 @@ def render_module_variables(self, models: list[Model]) -> str:
return super().render_module_variables(models)

declarations: list[str] = ["mapper_registry = registry()"]

if self.metadata.naming_convention != DEFAULT_NAMING_CONVENTION:
formatted_naming_convention = pformat(self.metadata.naming_convention)
declarations.append(
"mapper_registry.metadata.naming_convention = {}".format(
formatted_naming_convention
)
)

if any(not isinstance(model, ModelClass) for model in models):
declarations.append("metadata = mapper_registry.metadata")

Expand Down Expand Up @@ -1394,6 +1425,15 @@ def collect_imports_for_column(self, column: Column[Any]) -> None:

def render_module_variables(self, models: list[Model]) -> str:
declarations: list[str] = []

if self.metadata.naming_convention != DEFAULT_NAMING_CONVENTION:
formatted_naming_convention = pformat(self.metadata.naming_convention)
declarations.append(
"{}.metadata.naming_convention = {}".format(
self.base_class_name, formatted_naming_convention
)
)

if any(not isinstance(model, ModelClass) for model in models):
declarations.append(f"metadata = {self.base_class_name}.metadata")

Expand Down
60 changes: 54 additions & 6 deletions src/sqlacodegen/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import re
from collections.abc import Mapping

from sqlalchemy import PrimaryKeyConstraint, UniqueConstraint
from sqlalchemy.engine import Connectable
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.elements import conv
from sqlalchemy.sql.schema import (
CheckConstraint,
ColumnCollectionConstraint,
Expand Down Expand Up @@ -53,12 +55,50 @@ def get_common_fk_constraints(
return c1.union(c2)


def uses_default_name(constraint: Constraint | Index) -> bool:
def _handle_constraint_name_token(
constraint_name: str,
convention: str,
values: dict[str, str],
) -> str | conv:
"""
Get explicit name for conventions with the token `constraint_name` using regex

Replace first occurence of the token with (\\w+) and subsequent ones with (\1),
then add ^ and $ for exact match

:param constraint_name: name of constraint
:param convention: naming convention of the constraint as defined in metadata
:param values: mapping of token key and value

Example:
If `convention` is `abc_%(constraint_name)s_123`, the regex pattern will
be `^abc_(\\w+)_123$`, the first (and only) matched group will then be returned

"""
placeholder = "%(constraint_name)s"
try:
pattern = convention % {**values, **{"constraint_name": placeholder}}
except KeyError:
return conv(constraint_name)

pattern = re.escape(pattern)
escaped_placeholder = re.escape(placeholder)

# Replace first occurence with (\w+) and subsequent ones with (\1), then add ^ and $
pattern = pattern.replace(escaped_placeholder, r"(\w+)", 1)
pattern = pattern.replace(escaped_placeholder, r"(\1)")
pattern = "".join(["^", pattern, "$"])

match = re.match(pattern, constraint_name)
return conv(constraint_name) if match is None else match[1]


def get_explicit_name(constraint: Constraint | Index) -> str | conv:
if not constraint.name or constraint.table is None:
return True
return ""

table = constraint.table
values = {"table_name": table.name, "constraint_name": constraint.name}
values = {"table_name": table.name}
if isinstance(constraint, (Index, ColumnCollectionConstraint)):
values.update(
{
Expand Down Expand Up @@ -130,11 +170,19 @@ def uses_default_name(constraint: Constraint | Index) -> bool:
else:
raise TypeError(f"Unknown constraint type: {constraint.__class__.__qualname__}")

if key not in table.metadata.naming_convention:
return constraint.name

convention: str = table.metadata.naming_convention[key]
if "%(constraint_name)s" in convention:
return _handle_constraint_name_token(constraint.name, convention, values)

try:
convention: str = table.metadata.naming_convention[key]
return constraint.name == (convention % values)
parsed = convention % values
# No explicit name needed if constraint name already follows naming convention
return "" if constraint.name == parsed else constraint.name
except KeyError:
return False
return constraint.name


def render_callable(
Expand Down
33 changes: 33 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,36 @@ def test_main() -> None:
check=True,
)
assert completed.stdout.decode().strip() == expected_version


@pytest.fixture
def empty_db_path(tmp_path: Path) -> Path:
path = tmp_path / "test.db"

return path


def test_naming_convention(empty_db_path: Path, tmp_path: Path) -> None:
output_path = tmp_path / "outfile"
subprocess.run(
[
"sqlacodegen",
f"sqlite:///{empty_db_path}",
"--outfile",
str(output_path),
"--conv",
"pk=pk_%(table_name)s",
],
check=True,
)

assert (
output_path.read_text()
== """\
from sqlalchemy import MetaData

metadata = MetaData()
metadata.naming_convention = {'pk': 'pk_%(table_name)s'}

"""
)
Loading