diff --git a/internal/db/db_schema.go b/internal/db/db_schema.go index eecf3c79b..af67dbe41 100644 --- a/internal/db/db_schema.go +++ b/internal/db/db_schema.go @@ -262,7 +262,7 @@ func getSearchConfigFromIndex(fullText *input.FullText) string { if fullText.LanguageColumn == "" { return fmt.Sprintf("'%s'", fullText.Language) } else { - return fmt.Sprintf("%s::reconfig", fullText.LanguageColumn) + return fmt.Sprintf("%s::regconfig", fullText.LanguageColumn) } } @@ -1104,6 +1104,7 @@ func (s *dbSchema) addUniqueConstraint(nodeData *schema.NodeData, inputConstrain return nil } +// same logic as parse_db._default_index func (s *dbSchema) getDefaultIndexType(f *field.Field) input.IndexType { // default index type for lists|jsonb when not specified is gin type typ := f.GetFieldType() diff --git a/internal/db/db_schema_test.go b/internal/db/db_schema_test.go index da3a5c7d0..9a344ba6c 100644 --- a/internal/db/db_schema_test.go +++ b/internal/db/db_schema_test.go @@ -1824,7 +1824,7 @@ func TestFullTextIndexMultipleColsLangColumn(t *testing.T) { strconv.Quote("users_name_idx"), getKVDict([]string{ getKVPair("postgresql_using", strconv.Quote("gin")), - getKVPair("postgresql_using_internals", strconv.Quote("to_tsvector(language::reconfig, coalesce(first_name, '') || ' ' || coalesce(last_name, ''))")), + getKVPair("postgresql_using_internals", strconv.Quote("to_tsvector(language::regconfig, coalesce(first_name, '') || ' ' || coalesce(last_name, ''))")), getKVPair("columns", fmt.Sprintf("[%s, %s]", strconv.Quote("first_name"), diff --git a/python/Pipfile b/python/Pipfile index 63ba00c4e..e8397fcef 100644 --- a/python/Pipfile +++ b/python/Pipfile @@ -14,6 +14,7 @@ sqlalchemy = "==1.4.35" psycopg2 = "==2.9.3" autopep8 = "==1.5.4" python-dateutil= "==2.8.2" +inflect= "==6.0.2" [requires] python_version = "3.8" diff --git a/python/auto_schema/auto_schema/clause_text.py b/python/auto_schema/auto_schema/clause_text.py index 4eaea9991..c59b9b177 100644 --- a/python/auto_schema/auto_schema/clause_text.py +++ b/python/auto_schema/auto_schema/clause_text.py @@ -7,7 +7,7 @@ clause_regex = re.compile("(.+)'::(.+)") date_regex = re.compile( - '([0-9]{4})-([0-9]{2})-([0-9]{2})[T| ]([0-9]{2}):([0-9]{2}):([0-9]{2})(\.[0-9]{3})?(.+)?') + r"([0-9]{4})-([0-9]{2})-([0-9]{2})[T| ]([0-9]{2}):([0-9]{2}):([0-9]{2})(\.[0-9]{3})?(.+)?") valid_suffixes = { @@ -69,6 +69,9 @@ def normalize(arg): # return the underlying string instead of quoted arg = str(arg).strip("'") + # condition `price > 0` ends up as `price > (0)::numeric` so we're trying to fix that + arg = re.sub(r"\(([0-9]+)\)::numeric", r'\1', arg) + # strip the extra text padding added so we can compare effectively m = clause_regex.match(arg) if m is None: diff --git a/python/auto_schema/auto_schema/cli/__init__.py b/python/auto_schema/auto_schema/cli/__init__.py index 54b204e9b..e7edc3109 100644 --- a/python/auto_schema/auto_schema/cli/__init__.py +++ b/python/auto_schema/auto_schema/cli/__init__.py @@ -5,7 +5,6 @@ import warnings import alembic -import sqlalchemy # if env variable is set, manipulate the path to put local # current directory over possibly installed auto_schema so that we @@ -17,6 +16,7 @@ # run from auto_schema root. conflicts with pip-installed auto_schema when that exists so can't have # that installed when runnning this... 
from auto_schema.runner import Runner +from auto_schema.parse_db import ParseDB from importlib import import_module @@ -53,6 +53,8 @@ '--changes', help='get changes in schema', action='store_true') parser.add_argument( '--debug', help='if debug flag passed', action='store_true') +parser.add_argument( + '--import_db', help='import given a schema uri', action='store_true') # see https://alembic.sqlalchemy.org/en/latest/offline.html # if true, pased to u @@ -89,6 +91,12 @@ def main(): try: args = parser.parse_args() + + if args.import_db is True: + p = ParseDB(args.engine) + p.parse_and_print() + return + sys.path.append(os.path.relpath(args.schema)) schema = import_module('schema') diff --git a/python/auto_schema/auto_schema/compare.py b/python/auto_schema/auto_schema/compare.py index faee3907c..3e4051081 100644 --- a/python/auto_schema/auto_schema/compare.py +++ b/python/auto_schema/auto_schema/compare.py @@ -3,6 +3,7 @@ from alembic.autogenerate.api import AutogenContext from auto_schema.schema_item import FullTextIndex +from auto_schema.introspection import get_raw_db_indexes from . import ops from alembic.operations import Operations, MigrateOperation import sqlalchemy as sa @@ -444,8 +445,8 @@ def _compare_indexes(autogen_context: AutogenContext, metadata_table: sa.Table, ): - raw_db_indexes = _get_raw_db_indexes( - autogen_context, conn_table) + raw_db_indexes = get_raw_db_indexes( + autogen_context.connection, conn_table) missing_conn_indexes = raw_db_indexes.get('missing') all_conn_indexes = raw_db_indexes.get('all') conn_indexes = {} @@ -524,56 +525,3 @@ def _compare_indexes(autogen_context: AutogenContext, unique=index.unique, info=index.info, ) - - -index_regex = re.compile('CREATE INDEX (.+) USING (gin|btree)(.+)') - - -# sqlalchemy doesn't reflect postgres indexes that have expressions in them so have to manually -# fetch these indices from pg_indices to find them -# warning: "Skipped unsupported reflection of expression-based index accounts_full_text_idx" -def _get_raw_db_indexes(autogen_context: AutogenContext, conn_table: Optional[sa.Table]): - if conn_table is None or _dialect_name(autogen_context) != 'postgresql': - return {'missing': {}, 'all': {}} - - missing = {} - all = {} - # we cache the db hit but the table seems to change across the same call and so we're - # just paying the CPU price. can probably be fixed in some way... - names = set([index.name for index in conn_table.indexes] + - [constraint.name for constraint in conn_table.constraints]) - res = get_db_indexes_for_table(autogen_context.connection, conn_table.name) - - for row in res.fetchall(): - ( - name, - details - ) = row - m = index_regex.match(details) - if m is None: - continue - r = m.groups() - - all[name] = { - 'postgresql_using': r[1], - 'postgresql_using_internals': r[2], - # TODO don't have columns|column to pass to FullTextIndex - } - - # missing! 
- if name not in names: - missing[name] = { - 'postgresql_using': r[1], - 'postgresql_using_internals': r[2], - # TODO don't have columns|column to pass to FullTextIndex - } - - return {'missing': missing, 'all': all} - - -# use a cache so we only hit the db once for each table -# @functools.lru_cache() -def get_db_indexes_for_table(connection: sa.engine.Connection, tname: str): - res = connection.execute( - "SELECT indexname, indexdef from pg_indexes where tablename = '%s'" % tname) - return res diff --git a/python/auto_schema/auto_schema/introspection.py b/python/auto_schema/auto_schema/introspection.py new file mode 100644 index 000000000..5e506d702 --- /dev/null +++ b/python/auto_schema/auto_schema/introspection.py @@ -0,0 +1,84 @@ +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql +import re +from typing import Optional + + +def get_sorted_enum_values(connection: sa.engine.Connection, enum_type: str): + # we gotta go to the db and check the order + db_sorted_enums = [] + # https://www.postgresql.org/docs/9.5/functions-enum.html + query = "select unnest(enum_range(enum_first(null::%s)));" % ( + enum_type) + for row in connection.execute(query): + db_sorted_enums.append(dict(row)['unnest']) + + return db_sorted_enums + + +index_regex = re.compile('CREATE INDEX (.+) USING (gin|btree|gist)(.+)') + + +def _dialect_name(conn: sa.engine.Connection) -> str: + return conn.dialect.name + + +# sqlalchemy doesn't reflect postgres indexes that have expressions in them so have to manually +# fetch these indices from pg_indices to find them +# warning: "Skipped unsupported reflection of expression-based index accounts_full_text_idx" + +# this only returns those that match a using... +# TODO check what happens when this is not all-caps +def get_raw_db_indexes(connection: sa.engine.Connection, table: Optional[sa.Table]): + if table is None or _dialect_name(connection) != 'postgresql': + return {'missing': {}, 'all': {}} + + missing = {} + all = {} + # we cache the db hit but the table seems to change across the same call and so we're + # just paying the CPU price. can probably be fixed in some way... + names = set([index.name for index in table.indexes] + + [constraint.name for constraint in table.constraints]) + res = _get_db_indexes_for_table(connection, table.name) + + for row in res.fetchall(): + ( + name, + details + ) = row + m = index_regex.match(details) + if m is None: + continue + + r = m.groups() + + all[name] = { + 'postgresql_using': r[1], + 'postgresql_using_internals': r[2], + # TODO don't have columns|column to pass to FullTextIndex + } + + # missing! 
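+        # i.e. indexes that exist in pg_indexes but that sqlalchemy reflection did not surface
+        # on the Table, e.g. an expression index such as USING gin (to_tsvector('english'::regconfig, first_name))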
+ if name not in names: + missing[name] = { + 'postgresql_using': r[1], + 'postgresql_using_internals': r[2], + # TODO don't have columns|column to pass to FullTextIndex + } + + return {'missing': missing, 'all': all} + + +# use a cache so we only hit the db once for each table +# @functools.lru_cache() +def _get_db_indexes_for_table(connection: sa.engine.Connection, tname: str): + res = connection.execute( + "SELECT indexname, indexdef from pg_indexes where tablename = '%s'" % tname) + return res + + +def default_index(table: sa.Table, col_name: str): + col = table.columns[col_name] + if isinstance(col.type, postgresql.JSONB) or isinstance(col.type, postgresql.JSON) or isinstance(col.type, postgresql.ARRAY): + return 'gin' + return 'btree' diff --git a/python/auto_schema/auto_schema/parse_db.py b/python/auto_schema/auto_schema/parse_db.py new file mode 100644 index 000000000..c008b0ace --- /dev/null +++ b/python/auto_schema/auto_schema/parse_db.py @@ -0,0 +1,723 @@ +from cgi import test +import sqlalchemy as sa +import json +import re +from sqlalchemy.dialects import postgresql +import inflect +from enum import Enum +from auto_schema.introspection import get_sorted_enum_values, get_raw_db_indexes, default_index +from auto_schema.clause_text import get_clause_text + +# copied from ts/src/schema/schema.ts + + +class DBType(str, Enum): + UUID = "UUID" + Int64ID = "Int64ID" # unsupported right now + Boolean = "Boolean" + Int = "Int" + BigInt = "BigInt" + Float = "Float" + String = "String" + + Timestamp = "Timestamp" + Timestamptz = "Timestamptz" + JSON = "JSON" # JSON type in the database + JSONB = "JSONB" # JSONB type in the database Postgres + Enum = "Enum" # enum type in the database + StringEnum = "StringEnum" # string type in the database + IntEnum = "IntEnum" # int type in the database + + Date = "Date" + Time = "Time" + Timetz = "Timetz" + + List = "List" + + +class ConstraintType(str, Enum): + PrimaryKey = "primary" + ForeignKey = "foreign" + Unique = "unique" + Check = "check" + + +sqltext_regex = re.compile(r"to_tsvector\((.+?), (.+)\)") +edge_name_regex = re.compile('(.+?)To(.+)Edge') + + +class ParseDB(object): + + def __init__(self, engine_conn): + if isinstance(engine_conn, sa.engine.Connection): + self.connection = engine_conn + else: + engine = sa.create_engine(engine_conn) + self.connection = engine.connect() + self.metadata = sa.MetaData() + self.metadata.bind = self.connection + self.metadata.reflect() + + def parse(self): + assoc_edge_config = [ + table for table in self.metadata.sorted_tables if table.name == 'assoc_edge_config'] + existing_edges = {} + + if assoc_edge_config: + for row in self.connection.execute('select * from assoc_edge_config'): + edge = dict(row) + edge_table = edge['edge_table'] + edges = existing_edges.get(edge_table, []) + edges.append(edge) + existing_edges[edge_table] = edges + + nodes = {} + for table in self.metadata.sorted_tables: + if table.name == 'alembic_version' or existing_edges.get(table.name) is not None or table.name == "assoc_edge_config": + continue + + node = self._parse_table(table) + nodes[ParseDB.table_to_node(table.name)] = node + + (unknown_edges, edges_map) = self._parse_edges_info(existing_edges, nodes) + + for (k, v) in edges_map.items(): + node = nodes[k] + node["edges"] = v + + return nodes + + def parse_and_print(self): + print(json.dumps(self.parse())) + + def _parse_edges_info(self, existing_edges: dict, nodes: dict): + unknown_edges = [] + edges_map = {} + + # todo global edge + for item in existing_edges.items(): + 
table_name = item[0] + edges = item[1] + if len(edges) == 1: + self._handle_single_edge_in_table( + edges[0], nodes, edges_map, unknown_edges) + else: + self._handle_multi_edges_in_table( + edges, nodes, edges_map, unknown_edges) + + return (unknown_edges, edges_map) + + def _parse_edge_name(self, edge: dict): + m = edge_name_regex.match(edge['edge_name']) + if m is None: + return None + return m.groups() + + def _handle_single_edge_in_table(self, edge, nodes, edges_map, unknown_edges): + t = self._parse_edge_name(edge) + # unknown edges + if t is None or nodes.get(t[0], None) is None: + print("unknown edge", edge, '\n') + unknown_edges.append(edge) + return + + if edge["symmetric_edge"]: + # symmetric + node_edges = edges_map.get(t[0], []) + node_edges.append({ + "name": t[1], + "schemaName": t[0], + "symmetric": True, + }) + edges_map[t[0]] = node_edges + else: + + res = self.connection.execute( + 'select id2_type, count(id2_type) from %s group by id2_type' % (edge["edge_table"])).fetchall() + + if len(res) != 1: + print("unknown edge can't determine schemaName", edge, '\n') + unknown_edges.append(edge) + return + + toNode = res[0][0].title() + if nodes.get(toNode, None) is None: + print( + "unknown edge can't determine schemaName because toNode is unknown", edge, toNode, '\n') + unknown_edges.append(edge) + return + + node_edges = edges_map.get(t[0], []) + node_edges.append({ + "name": t[1], + "schemaName": toNode, + }) + edges_map[t[0]] = node_edges + + def _handle_multi_edges_in_table(self, edges, nodes, edges_map, unknown_edges): + edge_types = {} + for edge in edges: + edge_types[str(edge['edge_type'])] = edge + + seen = {} + for edge in edges: + edge_type = str(edge["edge_type"]) + if seen.get(edge_type, False): + continue + + seen[edge_type] = True + + # assoc edge group?? + if edge["symmetric_edge"] or edge['inverse_edge_type'] is None: + self._handle_single_edge_in_table( + edge, nodes, edges_map, unknown_edges) + continue + + # for inverse edges, we don't know which schema should be the source of truth + # so it ends up being randomly placed in one or the other + + inverse_edge_type = str(edge['inverse_edge_type']) + if inverse_edge_type not in edge_types: + print('unknown inverse edge', inverse_edge_type) + unknown_edges.append(edge) + continue + + inverse_edge = edge_types[inverse_edge_type] + seen[inverse_edge_type] = True + + t1 = self._parse_edge_name(edge) + t2 = self._parse_edge_name(inverse_edge) + + # pattern or polymorphic edge... 
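+            # i.e. the edge name doesn't parse as <Node>To<EdgeName>Edge, or it doesn't resolve to a
+            # node we parsed above, so we can't attribute it to a schema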
+ if t1 is None or t2 is None or nodes.get(t1[0], None) is None or nodes.get(t2[0], None) is None: + print("unknown edge or inverse edge", edge, inverse_edge, "\n") + + unknown_edges.append(edge) + unknown_edges.append(inverse_edge) + continue + + node = t1[0] + inverseNode = t2[0] + + node_edges = edges_map.get(node, []) + node_edges.append({ + "name": t1[1], + "schemaName": inverseNode, + "inverseEdge": { + "name": t2[1] + }, + }) + edges_map[node] = node_edges + + def _parse_table(self, table: sa.Table): + node = {} + col_indices = {} + col_unique = {} + # parse indices and constraints before columns and get col specific data + indices = self._parse_indices(table, col_indices) + constraints = self._parse_constraints(table, col_unique) + node["fields"] = self._parse_columns(table, col_indices, col_unique) + node["constraints"] = constraints + node["indices"] = indices + + return node + + def _parse_columns(self, table: sa.Table, col_indices: dict, col_unique: dict): + # TODO handle column foreign key so we don't handle them in constraints below... + fields = {} + for col in table.columns: + # we don't return computed fields + if col.computed: + continue + + field = {} + field['storageKey'] = col.name + if col.primary_key: + field['primaryKey'] = True + + if isinstance(col.type, postgresql.ARRAY): + field['type'] = { + "dbType": DBType.List, + "listElemType": self._parse_column_type(col.type.item_type), + } + else: + field["type"] = self._parse_column_type(col.type) + + if col.nullable: + field["nullable"] = True + if col.name in col_indices or col.index: + field["index"] = True + if col.name in col_unique or col.unique: + field["unique"] = True + + fkey = self._parse_foreign_key(col) + if fkey is not None: + field["foreignKey"] = fkey + + server_default = get_clause_text(col.server_default, col.type) + if server_default is not None: + field["serverDefault"] = server_default + + # TODO foreign key, server default + if len(col.constraints) != 0: + raise Exception( + "column %s in table %s has more than one constraint which is not supported" % (col.name, table.name)) + + if col.default: + raise Exception( + "column %s in table %s has default which is not supported" % (col.name, table.name)) + + if col.onupdate: + raise Exception( + "column %s in table %s has onupdate which is not supported" % (col.name, table.name)) + + # if col.key: + # print(col.key) + # raise Exception( + # "column %s in table %s has key which is not supported" % (col.name, table.name)) + + # ignoring comment + fields[col.name] = field + + return fields + + # keep this in sync with testingutils._validate_parsed_data_type + def _parse_column_type(self, col_type): + if isinstance(col_type, sa.TIMESTAMP): + # sqlite doesn't support timestamp with timezone + dialect = self.connection.dialect.name + if col_type.timezone and dialect != 'sqlite': + return { + "dbType": DBType.Timestamptz + } + return { + "dbType": DBType.Timestamp + } + + if isinstance(col_type, sa.Time): + # sqlite doesn't support with timezone + dialect = self.connection.dialect.name + if col_type.timezone and dialect != 'sqlite': + return { + "dbType": DBType.Timetz + } + return { + "dbType": DBType.Time + } + + if isinstance(col_type, sa.Date): + return { + "dbType": DBType.Date + } + + # ignoring precision for now + # TODO + if isinstance(col_type, sa.Numeric): + return { + "dbType": DBType.Float + } + + if isinstance(col_type, postgresql.ENUM): + db_sorted_enums = get_sorted_enum_values( + self.connection, col_type.name) + + return { + "dbType": 
DBType.Enum, + "values": db_sorted_enums + } + + if isinstance(col_type, postgresql.JSONB): + return { + "dbType": DBType.JSONB + } + + if isinstance(col_type, postgresql.JSON): + return { + "dbType": DBType.JSON + } + + if isinstance(col_type, postgresql.UUID): + return { + "dbType": DBType.UUID + } + + if isinstance(col_type, sa.String): + return { + "dbType": DBType.String + } + + if isinstance(col_type, sa.Boolean): + return { + "dbType": DBType.Boolean + } + + if isinstance(col_type, sa.Integer): + if isinstance(col_type, sa.BigInteger) or col_type.__visit_name__ == 'big_integer' or col_type.__visit_name__ == 'BIGINT': + return { + "dbType": DBType.BigInt + } + return { + "dbType": DBType.Int + } + + raise Exception("unsupported type %s" % str(col_type)) + + def _parse_foreign_key(self, col: sa.Column): + if len(col.foreign_keys) > 1: + raise Exception( + "don't currently support multiple foreign keys in a column ") + + for fkey in col.foreign_keys: + return { + "schema": ParseDB.table_to_node(fkey.column.table.name), + "column": fkey.column.name, + } + + return None + + @ classmethod + def _singular(cls, table_name) -> str: + p = inflect.engine() + ret = p.singular_noun(table_name) + # TODO address this for not-tests + # what should the node be called?? + # how does this affect GraphQL/TypeScript names etc? + if ret is False: + return table_name + return ret + + @ classmethod + def table_to_node(cls, table_name) -> str: + return "".join([t.title() + for t in cls._singular(table_name).split("_")]) + + def _parse_constraints(self, table: sa.Table, col_unique: dict): + constraints = [] + for constraint in table.constraints: + constraint_type = None + condition = None + single_col = None + if len(constraint.columns) == 1: + single_col = constraint.columns[0] + + if isinstance(constraint, sa.CheckConstraint): + constraint_type = ConstraintType.Check + condition = constraint.sqltext + + if isinstance(constraint, sa.UniqueConstraint): + if single_col is not None: + col_unique[single_col.name] = True + continue + constraint_type = ConstraintType.Unique + + if isinstance(constraint, sa.ForeignKeyConstraint): + if single_col is not None: + if len(single_col.foreign_keys) == 1: + # handled at the column level + continue + constraint_type = ConstraintType.ForeignKey + + if isinstance(constraint, sa.PrimaryKeyConstraint): + if single_col is not None and single_col.primary_key: + continue + constraint_type = ConstraintType.PrimaryKey + + if not constraint_type: + raise Exception("invalid constraint_type %s" % str(constraint)) + + constraints.append({ + "name": constraint.name, + "type": constraint_type, + "columns": [col.name for col in constraint.columns], + 'condition': condition, + }) + return constraints + + def _parse_indices(self, table: sa.Table, col_indices: dict): + indices = [] + + col_names = set([col.name for col in table.columns]) + generated_columns = self._parse_generated_columns(table, col_names) + + raw_db_indexes = get_raw_db_indexes(self.connection, table) + all_conn_indexes = raw_db_indexes.get('all') + + seen = {} + for name, info in all_conn_indexes.items(): + seen[name] = True + internals = info.get("postgresql_using_internals") + internals = internals.strip() + if internals.startswith("("): + internals = internals[1:] + if internals.endswith(")"): + internals = internals[:-1] + + index_type = info.get("postgresql_using") + + generated_col_info = generated_columns.get(internals, None) + + # nothing to do here. index on a column. 
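+            # e.g. a plain btree index such as accounts_first_name_idx on (first_name): it matches the
+            # column's default index type, so it's recorded as `index: True` on the field in
+            # _parse_columns instead of as a separate entry in `indices`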
+ if internals in col_names and default_index(table, internals) == index_type: + col_indices[internals] = True + continue + + if generated_col_info is not None: + idx = { + "name": name, + "columns": generated_col_info.get("columns"), + "fulltext": { + "language": generated_col_info.get("language"), + "indexType": index_type, + "generatedColumnName": internals, + } + } + if generated_col_info.get('weights', None) is not None: + idx['fulltext']['weights'] = generated_col_info['weights'] + + indices.append(idx) + continue + + if internals in col_names and index_type is not None: + indices.append({ + "name": name, + "columns": [internals], + "indexType": index_type, + }) + continue + + internals_parsed = self._parse_postgres_using_internals( + internals, index_type, col_names) + + if internals_parsed.get('fulltext', None): + indices.append({ + 'columns': internals_parsed['columns'], + 'fulltext': internals_parsed['fulltext'], + 'name': name, + }) + continue + + # TODO we can actually punt this to regular indexes below... + # difference is index_type here vs below... + if internals_parsed.get('columns', None): + indices.append({ + "name": name, + "columns": internals_parsed.get('columns'), + "indexType": index_type, + }) + continue + + raise Exception("unsupported index %s in table %s" % + (name, table.name)) + + for index in table.indexes: + if seen.get(index.name, False): + continue + + # we don't get raw sqlite cols above so need this for sqlite... + single_col = None + if len(index.columns) == 1: + single_col = index.columns[0] + index_type = index.kwargs.get('postgresql_using') + default_index_type = default_index(table, single_col.name) + + if (index_type == False and default_index_type == 'btree') or default_index(table, single_col.name) == index_type: + col_indices[single_col.name] = True + continue + + indices.append({ + "name": index.name, + "unique": index.unique, + "columns": [col.name for col in index.columns], + }) + + return indices + + def _parse_generated_columns(self, table: sa.Table, col_names: set): + generated = {} + for col in table.columns: + def unsupported_col(sqltext): + raise Exception("unsupported sqltext %s for col %s in table %s" % ( + sqltext, col.name, table.name)) + + if not col.computed: + continue + + if not isinstance(col.type, postgresql.TSVECTOR): + raise Exception( + "unsupported computed type %s which isn't a tsvector" % str(col.type)) + + sqltext = str(col.computed.sqltext) + # wrap in () if not wrapped. needed for parsing logic to be consistent + if not sqltext.startswith("("): + sqltext = "(%s)" % sqltext + + # all this logic with no coalesce is what we want i think... + # print('sqltext - m', sqltext, m) + res = self._parse_str_into_parts(sqltext) + if len(res) != 1: + raise Exception('parsed incorrect') + + cols = [] + lang = '' + weights = {} + + for child in res[0].children: + text = sqltext[child.beg_cursor:child.end].strip().strip( + '||').lstrip('(').strip() + + weight = None + if text.startswith('setweight'): + idx = text.rfind(',') + weight = text[idx + + 1:].rstrip(')').rstrip('::"char').replace("'", "").strip() + + text = child.str[1:idx] + + m = sqltext_regex.match(text) + + if not m: + unsupported_col(text) + + groups = m.groups() + lang = groups[0].rstrip("::regconfig").strip("'") + # TODO ensure lang is consistent? + + val = groups[1] + starts = [m.start() + for m in re.finditer('COALESCE', val)] + + # no coalesce... 
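+                    # e.g. to_tsvector('english', first_name || ' ' || last_name) has no COALESCE calls,
+                    # so treat the whole expression as a single segment starting at index 0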
+ if len(starts) == 0: + starts = [0] + + for i in range(len(starts)): + if i + 1 == len(starts): + curr = val[starts[i]: len(val)] + else: + curr = val[starts[i]: starts[i+1]-1] + + cols2 = self._parse_cols_from( + curr, sqltext, col_names, unsupported_col) + cols = cols + cols2 + # This exists for tests + cols.sort() + + if weight is not None: + l = weights.get(weight, []) + l = l + cols2 + # l.append(coll) + weights[weight] = list(set(l)) + # this exists for tests + weights[weight].sort() + + ret = { + "language": lang, + "columns": cols, + } + if weights: + ret['weights'] = weights + + generated[col.name] = ret + + return generated + + def _parse_cols_from(self, curr: str, sqltext: str, col_names, err_fn): + cols = [] + for s in curr.strip().split('||'): + if not s: + continue + s = s.strip().strip('(').strip(')') + if s.startswith('COALESCE'): + s = s[8:] + + for s2 in s.split(','): + s2 = s2.strip().strip('(').strip(')') + if s2 == "''::text" or s2 == "' '::text": + continue + + if s2 in col_names: + cols.append(s2) + + else: + err_fn(sqltext) + + return cols + + def _parse_postgres_using_internals(self, internals: str, index_type: str, col_names): + # single-col to_tsvector('english'::regconfig, first_name) + # multi-col to_tsvector('english'::regconfig, ((first_name || ' '::text) || last_name)) + m = sqltext_regex.match(internals) + if m: + groups = m.groups() + lang = groups[0].rstrip("::regconfig").strip("'") + + def error_fn(s): + raise Exception('error parsing columns' % s) + + cols = self._parse_cols_from( + groups[1], internals, col_names, error_fn) + + if len(cols) > 0: + return { + 'columns': cols, + 'fulltext': { + 'language': lang, + 'indexType': index_type, + } + } + + cols = [col.strip() for col in internals.split(',')] + # multi-column index + if len(cols) > 1: + return { + 'columns': cols, + } + + return {} + + def _parse_str_into_parts(self, s: str): + res = [] + stack = [] + end = -1 + for i in range(0, len(s)): + c = s[i] + if c == '(': + curr = Tree(i, end+1) + + # how to know when to add another top level + if len(stack) == 0: + res.append(curr) + + if len(stack) > 0: + # add as child + stack[-1].append(curr) + + stack.append(curr) + + if c == ')': + end = i + + curr = stack[-1] + l = curr.beg_paren + + if len(stack) == 1: + l = curr.end + 1 + + curr.str = s[l:i] + curr.end = i+1 + + stack.pop() + + return res + + +class Tree: + def __init__(self, beg_paren, beg_cursor) -> None: + self.children = [] + self.str = '' + self.beg_paren = beg_paren + self.beg_cursor = beg_cursor + self.end = -1 + + def append(self, child): + self.children.append(child) diff --git a/python/auto_schema/auto_schema/schema_item.py b/python/auto_schema/auto_schema/schema_item.py index 66ffb8df3..001d940ba 100644 --- a/python/auto_schema/auto_schema/schema_item.py +++ b/python/auto_schema/auto_schema/schema_item.py @@ -19,3 +19,5 @@ def __init__(self, name: str, **kw) -> None: elif 'column' in info: cols = [info.get('column')] super().__init__(name, *cols, **kw) + # can't seem to always access this from the base class and not sure why so storing this here for now + self.kwwwww = kw diff --git a/python/auto_schema/setup.py b/python/auto_schema/setup.py index 3831e1cc1..5a581e85a 100644 --- a/python/auto_schema/setup.py +++ b/python/auto_schema/setup.py @@ -26,7 +26,8 @@ "datetime==4.3", "psycopg2==2.8.6", "autopep8==1.5.4", - "python-dateutil==2.8.2" + "python-dateutil==2.8.2", + "inflect==6.0.2" ], entry_points={'console_scripts': ["auto_schema = auto_schema.cli:main"]}, 
include_package_data=True diff --git a/python/auto_schema/tests/conftest.py b/python/auto_schema/tests/conftest.py index c15349b8f..b790373b4 100644 --- a/python/auto_schema/tests/conftest.py +++ b/python/auto_schema/tests/conftest.py @@ -773,7 +773,10 @@ def metadata_with_generated_col_fulltext_search_index(metadata_with_table): sa.Column('full_name', postgresql.TSVECTOR(), sa.Computed( "to_tsvector('english', first_name || ' ' || last_name)")), sa.Index('accounts_full_text_idx', - 'full_name', postgresql_using='gin'), + 'full_name', postgresql_using='gin', test_data={ + 'language': 'english', + 'columns': ['first_name', 'last_name'], + }), extend_existing=True) @@ -785,7 +788,101 @@ def metadata_with_generated_col_fulltext_search_index_gist(metadata_with_table): sa.Column('full_name', postgresql.TSVECTOR(), sa.Computed( "to_tsvector('english', first_name || ' ' || last_name)")), sa.Index('accounts_full_text_idx', - 'full_name', postgresql_using='gist'), + 'full_name', postgresql_using='gist', test_data={ + 'language': 'english', + 'columns': ['first_name', 'last_name'], + }), + + extend_existing=True) + + return metadata_with_table + + +def metadata_with_generated_col_fulltext_search_index_matched_weights(metadata_with_table): + sa.Table('accounts', metadata_with_table, + sa.Column('full_name', postgresql.TSVECTOR(), sa.Computed( + "setweight(to_tsvector('english', coalesce(first_name, '')), 'A') || setweight(to_tsvector('english', coalesce(last_name, '')), 'A')")), + sa.Index('accounts_full_text_idx', + 'full_name', postgresql_using='gin', test_data={ + 'language': 'english', + 'weights': { + 'A': ['first_name', 'last_name'], + }, + 'columns': ['first_name', 'last_name'], + }), + + extend_existing=True) + + return metadata_with_table + + +def metadata_with_generated_col_fulltext_search_index_matched_weights_no_coalesce(metadata_with_table): + sa.Table('accounts', metadata_with_table, + sa.Column('full_name', postgresql.TSVECTOR(), sa.Computed( + "setweight(to_tsvector('english', first_name), 'A') || setweight(to_tsvector('english', last_name), 'A')")), + sa.Index('accounts_full_text_idx', + 'full_name', postgresql_using='gin', test_data={ + 'language': 'english', + 'weights': { + 'A': ['first_name', 'last_name'], + }, + 'columns': ['first_name', 'last_name'], + }), + + extend_existing=True) + + return metadata_with_table + + +def metadata_with_generated_col_fulltext_search_index_mismatched_weights(metadata_with_table): + sa.Table('accounts', metadata_with_table, + sa.Column('full_name', postgresql.TSVECTOR(), sa.Computed( + "setweight(to_tsvector('english', coalesce(first_name, '')), 'A') || setweight(to_tsvector('english', coalesce(last_name, '')), 'B')")), + sa.Index('accounts_full_text_idx', + 'full_name', postgresql_using='gin', test_data={ + 'language': 'english', + 'weights': { + 'A': ['first_name'], + 'B': ['last_name'], + }, + 'columns': ['first_name', 'last_name'], + }), + + extend_existing=True) + + return metadata_with_table + + +def metadata_with_generated_col_fulltext_search_index_one_weight(metadata_with_table): + sa.Table('accounts', metadata_with_table, + sa.Column('full_name', postgresql.TSVECTOR(), sa.Computed( + "(setweight(to_tsvector('english', coalesce(first_name, '')), 'A') || to_tsvector('english', coalesce(last_name, '')))")), + sa.Index('accounts_full_text_idx', + 'full_name', postgresql_using='gin', test_data={ + 'language': 'english', + 'weights': { + 'A': ['first_name'], + }, + 'columns': ['first_name', 'last_name'], + }), + + extend_existing=True) + + 
return metadata_with_table + + +def metadata_with_generated_col_fulltext_search_index_cols_in_setweight(metadata_with_table): + sa.Table('accounts', metadata_with_table, + sa.Column('full_name', postgresql.TSVECTOR(), sa.Computed( + "(setweight(to_tsvector('simple', coalesce(first_name, '') || ' ' || coalesce(last_name, '')), 'A')) ")), + sa.Index('accounts_full_text_idx', + 'full_name', postgresql_using='gin', test_data={ + 'language': 'simple', + 'weights': { + 'A': ['first_name', 'last_name'], + }, + 'columns': ['first_name', 'last_name'], + }), extend_existing=True) diff --git a/python/auto_schema/tests/runner_test.py b/python/auto_schema/tests/runner_test.py index 8c85eae38..745430342 100644 --- a/python/auto_schema/tests/runner_test.py +++ b/python/auto_schema/tests/runner_test.py @@ -33,7 +33,6 @@ def test_index_added_and_removed(self, new_test_runner, metadata_with_table): conftest.metadata_with_table_with_index, "add index accounts_first_name_idx to accounts", "drop index accounts_first_name_idx from accounts", - validate_schema=False ) @pytest.mark.usefixtures("metadata_with_two_tables") @@ -48,6 +47,9 @@ def test_compute_changes_with_foreign_key_table(self, new_test_runner, metadata_ assert len(r.compute_changes()) == 2 testingutils.assert_no_changes_made(r) + testingutils.run_and_validate_with_standard_metadata_tables( + r, metadata_with_foreign_key, new_table_names=['accounts', 'contacts']) + @pytest.mark.usefixtures("metadata_with_foreign_key_to_same_table") def test_compute_changes_with_foreign_key_to_same_table(self, new_test_runner, metadata_with_foreign_key_to_same_table): r = new_test_runner(metadata_with_foreign_key_to_same_table) @@ -133,7 +135,6 @@ def post_r2_func(r2): conftest.metadata_with_multi_column_index, "add index accounts_first_name_last_name_idx to accounts", "drop index accounts_first_name_last_name_idx from accounts", - validate_schema=False, post_r2_func=post_r2_func ) @@ -1378,8 +1379,6 @@ def test_full_text_index_added_and_removed(self, new_test_runner, metadata_with_ conftest.metadata_with_fulltext_search_index, "add full text index accounts_first_name_idx to accounts", "drop full text index accounts_first_name_idx from accounts", - # skip validation because of complications with idx - validate_schema=False ) @pytest.mark.usefixtures("metadata_with_multicolumn_fulltext_search") @@ -1401,8 +1400,6 @@ def test_multi_col_full_text_index_added_and_removed(self, new_test_runner, meta conftest.metadata_with_multicolumn_fulltext_search_index, "add full text index accounts_full_text_idx to accounts", "drop full text index accounts_full_text_idx from accounts", - # skip validation because of complications with idx - validate_schema=False ) @pytest.mark.usefixtures("metadata_with_table") @@ -1416,8 +1413,6 @@ def test_multi_col_full_text_index_added_and_removed_gist(self, new_test_runner, conftest.metadata_with_multicolumn_fulltext_search_index_gist, "add full text index accounts_full_text_idx to accounts", "drop full text index accounts_full_text_idx from accounts", - # skip validation because of complications with idx - validate_schema=False ) @pytest.mark.usefixtures("metadata_with_table") @@ -1428,20 +1423,16 @@ def test_multi_col_full_text_index_added_and_removed_btree(self, new_test_runner conftest.metadata_with_multicolumn_fulltext_search_index_btree, "add full text index accounts_full_text_idx to accounts", "drop full text index accounts_full_text_idx from accounts", - # skip validation because of complications with idx - validate_schema=False ) 
@pytest.mark.usefixtures("metadata_with_table") - def test_full_text_index_with_generated_column(self, new_test_runner, metadata_with_table): + def test_full_text_index_with_generated_column_gin(self, new_test_runner, metadata_with_table): testingutils.make_changes_and_restore( new_test_runner, metadata_with_table, conftest.metadata_with_generated_col_fulltext_search_index, "add column full_name to table accounts\nadd index accounts_full_text_idx to accounts", "drop index accounts_full_text_idx from accounts\ndrop column full_name from table accounts", - # skip validation because of complications with idx - validate_schema=False ) @pytest.mark.usefixtures("metadata_with_table") @@ -1452,8 +1443,56 @@ def test_full_text_index_with_generated_column_gist(self, new_test_runner, metad conftest.metadata_with_generated_col_fulltext_search_index_gist, "add column full_name to table accounts\nadd index accounts_full_text_idx to accounts", "drop index accounts_full_text_idx from accounts\ndrop column full_name from table accounts", - # skip validation because of complications with idx - validate_schema=False + ) + + @pytest.mark.usefixtures("metadata_with_table") + def test_full_text_index_with_generated_column_matched_weights(self, new_test_runner, metadata_with_table): + testingutils.make_changes_and_restore( + new_test_runner, + metadata_with_table, + conftest.metadata_with_generated_col_fulltext_search_index_matched_weights, + "add column full_name to table accounts\nadd index accounts_full_text_idx to accounts", + "drop index accounts_full_text_idx from accounts\ndrop column full_name from table accounts", + ) + + @pytest.mark.usefixtures("metadata_with_table") + def test_full_text_index_with_generated_column_matched_weights_no_coalesce(self, new_test_runner, metadata_with_table): + testingutils.make_changes_and_restore( + new_test_runner, + metadata_with_table, + conftest.metadata_with_generated_col_fulltext_search_index_matched_weights_no_coalesce, + "add column full_name to table accounts\nadd index accounts_full_text_idx to accounts", + "drop index accounts_full_text_idx from accounts\ndrop column full_name from table accounts", + ) + + @pytest.mark.usefixtures("metadata_with_table") + def test_full_text_index_with_generated_column_mismatched_weights(self, new_test_runner, metadata_with_table): + testingutils.make_changes_and_restore( + new_test_runner, + metadata_with_table, + conftest.metadata_with_generated_col_fulltext_search_index_mismatched_weights, + "add column full_name to table accounts\nadd index accounts_full_text_idx to accounts", + "drop index accounts_full_text_idx from accounts\ndrop column full_name from table accounts", + ) + + @pytest.mark.usefixtures("metadata_with_table") + def test_full_text_index_with_generated_column_one_weight(self, new_test_runner, metadata_with_table): + testingutils.make_changes_and_restore( + new_test_runner, + metadata_with_table, + conftest.metadata_with_generated_col_fulltext_search_index_one_weight, + "add column full_name to table accounts\nadd index accounts_full_text_idx to accounts", + "drop index accounts_full_text_idx from accounts\ndrop column full_name from table accounts", + ) + + @pytest.mark.usefixtures("metadata_with_table") + def test_full_text_index_with_generated_column_one_cols_in_setweight(self, new_test_runner, metadata_with_table): + testingutils.make_changes_and_restore( + new_test_runner, + metadata_with_table, + conftest.metadata_with_generated_col_fulltext_search_index_cols_in_setweight, + "add column full_name to 
table accounts\nadd index accounts_full_text_idx to accounts", + "drop index accounts_full_text_idx from accounts\ndrop column full_name from table accounts", ) diff --git a/python/auto_schema/tests/testingutils.py b/python/auto_schema/tests/testingutils.py index 94366bb5b..226aea711 100644 --- a/python/auto_schema/tests/testingutils.py +++ b/python/auto_schema/tests/testingutils.py @@ -4,10 +4,14 @@ from sqlalchemy.dialects import postgresql from auto_schema.clause_text import get_clause_text from auto_schema import runner +from auto_schema.parse_db import ParseDB, DBType, ConstraintType, sqltext_regex +from auto_schema.schema_item import FullTextIndex from sqlalchemy.sql.sqltypes import String from auto_schema import compare +from auto_schema.introspection import get_sorted_enum_values, default_index from . import conftest +from typing import Optional def assert_num_files(r: runner.Runner, expected_count): @@ -127,13 +131,20 @@ def validate_metadata_after_change(r: runner.Runner, old_metadata: sa.MetaData): # TODO why is this here? # assert(len(old_metadata.sorted_tables)) != len(new_metadata.sorted_tables) - new_metadata.bind = r.get_connection() + conn = r.get_connection() + new_metadata.bind = conn + parse_db = ParseDB(conn) + parsed = parse_db.parse() for db_table in new_metadata.sorted_tables: schema_table = next( (t for t in old_metadata.sorted_tables if db_table.name == t.name), None) if schema_table is not None: - _validate_table(schema_table, db_table, dialect, new_metadata) + # we'll do only nodes for now + node_name = ParseDB.table_to_node(db_table.name) + parsed_data = parsed.get(node_name, None) + _validate_table(schema_table, db_table, dialect, + new_metadata, parsed_data) else: # no need to do too much testing on this since we'll just have to trust that alembic works. 
assert db_table.name == 'alembic_version' @@ -203,42 +214,80 @@ def _get_new_metadata_for_runner(r: runner.Runner) -> sa.MetaData: return new_metadata -def _validate_table(schema_table: sa.Table, db_table: sa.Table, dialect: String, metadata: sa.MetaData): +def _validate_table(schema_table: sa.Table, db_table: sa.Table, dialect: String, metadata: sa.MetaData, parsed_data: Optional[dict]): assert schema_table != db_table assert id(schema_table) != id(db_table) assert schema_table.name == db_table.name - _validate_columns(schema_table, db_table, metadata, dialect) - _validate_constraints(schema_table, db_table, dialect, metadata) - _validate_indexes(schema_table, db_table, metadata, dialect) + _validate_columns(schema_table, db_table, metadata, dialect, parsed_data) + _validate_constraints(schema_table, db_table, + dialect, metadata, parsed_data) + _validate_indexes(schema_table, db_table, metadata, dialect, parsed_data) -def _validate_columns(schema_table: sa.Table, db_table: sa.Table, metadata: sa.MetaData, dialect: String): +def _validate_columns(schema_table: sa.Table, db_table: sa.Table, metadata: sa.MetaData, dialect: String, parsed_data: Optional[dict]): schema_columns = schema_table.columns db_columns = db_table.columns assert len(schema_columns) == len(db_columns) for schema_column, db_column in zip(schema_columns, db_columns): - _validate_column(schema_column, db_column, metadata, dialect) + _validate_column(schema_column, db_column, + metadata, dialect, parsed_data) -def _validate_column(schema_column: sa.Column, db_column: sa.Column, metadata: sa.MetaData, dialect: String): +def _validate_column(schema_column: sa.Column, db_column: sa.Column, metadata: sa.MetaData, dialect: String, parsed_data: Optional[dict] = None): assert schema_column != db_column assert(id(schema_column)) != id(db_column) assert schema_column.name == db_column.name - _validate_column_type(schema_column, db_column, metadata, dialect) + if schema_column.computed is None: + assert db_column.computed == None + + if schema_column.computed is not None: + assert db_column.computed is not None + + parsed_data_column = None + if parsed_data is not None: + parsed_data_fields = parsed_data['fields'] + parsed_data_column = parsed_data_fields.get(schema_column.name, None) + if not schema_column.computed: + assert parsed_data_column is not None + + _validate_column_type(schema_column, db_column, + metadata, dialect, parsed_data_column) + assert schema_column.primary_key == db_column.primary_key + if parsed_data_column: + assert parsed_data_column.get( + "primaryKey", False) == schema_column.primary_key + assert schema_column.nullable == db_column.nullable + if parsed_data_column: + assert parsed_data_column.get( + "nullable", False) == schema_column.nullable - _validate_foreign_key(schema_column, db_column) - _validate_column_server_default(schema_column, db_column) + _validate_foreign_key(schema_column, db_column, parsed_data_column) + + _validate_column_server_default( + schema_column, db_column, parsed_data_column) - # we don't actually support all these below yet but when we do, it should start failing and we should know that - assert schema_column.default == db_column.default assert schema_column.index == db_column.index + # we do sa.Index in all tests + generated code + # this is really handled in _validate_indexes + if parsed_data_column and schema_column.index: + assert parsed_data_column.get( + "index", None) == schema_column.index, schema_column.name + assert schema_column.unique == db_column.unique + # we do 
sa.UniqueConstraint in all tests + generated code + # this is really handled in _validate_constraints + if parsed_data_column and schema_column.unique: + assert parsed_data_column.get( + "unique", None) == schema_column.unique, schema_column.name + # assert schema_column.autoincrement == db_column.autoincrement # ignore autoincrement for now as there's differences btw default behavior and postgres + # we don't actually support all these below yet but when we do, it should start failing and we should know that + assert schema_column.default == db_column.default assert schema_column.key == db_column.key assert schema_column.onupdate == db_column.onupdate assert schema_column.constraints == db_column.constraints @@ -246,7 +295,7 @@ def _validate_column(schema_column: sa.Column, db_column: sa.Column, metadata: s assert schema_column.comment == db_column.comment -def _validate_column_server_default(schema_column: sa.Column, db_column: sa.Column): +def _validate_column_server_default(schema_column: sa.Column, db_column: sa.Column, parsed_data_column: Optional[dict] = None): schema_clause_text = get_clause_text( schema_column.server_default, schema_column.type) db_clause_text = get_clause_text(db_column.server_default, db_column.type) @@ -256,27 +305,48 @@ def _validate_column_server_default(schema_column: sa.Column, db_column: sa.Colu schema_clause_text) db_clause_text = runner.Runner.convert_postgres_boolean(db_clause_text) + if schema_column.computed is not None: + schema_clause_text = schema_column.computed.sqltext + db_clause_text = db_column.computed.sqltext + # TODO ideally refactor logic used in parse_db here and compare it +# to_tsvector('english', first_name || ' ' || last_name) to_tsvector('english'::regconfig, ((first_name || ' '::text) || last_name)) +# or setweight variants + return + if schema_clause_text is None and db_column.autoincrement == True: assert db_clause_text.startswith("nextval") else: assert str(schema_clause_text) == str(db_clause_text) + if parsed_data_column: + # doesn't apply to autoincrement yet so ignoring this here + assert parsed_data_column.get( + "serverDefault", None) == schema_clause_text, schema_column.name -def _validate_column_type(schema_column: sa.Column, db_column: sa.Column, metadata: sa.MetaData, dialect: String): +def _validate_column_type(schema_column: sa.Column, db_column: sa.Column, metadata: sa.MetaData, dialect: String, parsed_data_column: Optional[dict] = None): # array type. 
validate contents if isinstance(schema_column.type, postgresql.ARRAY): assert isinstance(db_column.type, postgresql.ARRAY) + parsed_data_type = None + if parsed_data_column is not None: + assert parsed_data_column.get("type").get("dbType") == DBType.List + parsed_data_type = parsed_data_column.get( + "type").get("listElemType") + _validate_column_type_impl( - schema_column.type.item_type, db_column.type.item_type, metadata, dialect, db_column, schema_column) + schema_column.type.item_type, db_column.type.item_type, metadata, dialect, db_column, schema_column, parsed_data_type) else: + parsed_data_type = None + if parsed_data_column is not None: + parsed_data_type = parsed_data_column.get("type") + _validate_column_type_impl( - schema_column.type, db_column.type, metadata, dialect, db_column, schema_column) - pass + schema_column.type, db_column.type, metadata, dialect, db_column, schema_column, parsed_data_type) -def _validate_column_type_impl(schema_column_type, db_column_type, metadata: sa.MetaData, dialect, db_column: sa.Column, schema_column: sa.Column): +def _validate_column_type_impl(schema_column_type, db_column_type, metadata: sa.MetaData, dialect, db_column: sa.Column, schema_column: sa.Column, parsed_data_type: Optional[dict] = None): if isinstance(schema_column_type, sa.TIMESTAMP): # timezone not supported in sqlite so this is just ignored there @@ -304,6 +374,91 @@ def _validate_column_type_impl(schema_column_type, db_column_type, metadata: sa. assert str(schema_column_type) == str(db_column_type) + if parsed_data_type is not None: + _validate_parsed_data_type( + schema_column_type, parsed_data_type, metadata, dialect) + + +def _validate_parsed_data_type(schema_column_type, parsed_data_type: dict, metadata: sa.MetaData, dialect: str): + + if isinstance(schema_column_type, sa.TIMESTAMP): + # sqlite doesn't support timestamp with timezone + if schema_column_type.timezone and dialect != 'sqlite': + assert parsed_data_type == { + "dbType": DBType.Timestamptz + } + else: + assert parsed_data_type == { + "dbType": DBType.Timestamp + } + + if isinstance(schema_column_type, sa.Time): + # sqlite doesn't support with timezone + if schema_column_type.timezone and dialect != 'sqlite': + assert parsed_data_type == { + "dbType": DBType.Timetz + } + else: + assert parsed_data_type == { + "dbType": DBType.Time + } + + if isinstance(schema_column_type, sa.Date): + assert parsed_data_type == { + "dbType": DBType.Date + } + + if isinstance(schema_column_type, sa.Numeric): + assert parsed_data_type == { + "dbType": DBType.Float + } + + if isinstance(schema_column_type, postgresql.ENUM) or (isinstance(schema_column_type, sa.VARCHAR) and len(schema_column_type.enums) > 0): + db_sorted_enums = get_sorted_enum_values( + metadata.bind, schema_column_type.name) + + assert parsed_data_type == { + "dbType": DBType.Enum, + "values": db_sorted_enums, + } + return + + if isinstance(schema_column_type, postgresql.JSONB): + assert parsed_data_type == { + "dbType": DBType.JSONB + } + return + + if isinstance(schema_column_type, postgresql.JSON): + assert parsed_data_type == { + "dbType": DBType.JSON + } + + if isinstance(schema_column_type, postgresql.UUID): + assert parsed_data_type == { + "dbType": DBType.UUID + } + + if isinstance(schema_column_type, sa.String): + assert parsed_data_type == { + "dbType": DBType.String + } + + if isinstance(schema_column_type, sa.Boolean): + assert parsed_data_type == { + "dbType": DBType.Boolean + } + + if isinstance(schema_column_type, sa.Integer): + if 
isinstance(schema_column_type, sa.BigInteger) or schema_column_type.__visit_name__ == 'big_integer': + assert parsed_data_type == { + "dbType": DBType.BigInt + } + else: + assert parsed_data_type == { + "dbType": DBType.Int + } + def _validate_enum_column_type(metadata: sa.MetaData, db_column: sa.Column, schema_column: sa.Column): # has to be same length @@ -313,13 +468,8 @@ def _validate_enum_column_type(metadata: sa.MetaData, db_column: sa.Column, sche if schema_column.type.enums == db_column.type.enums: return - # we gotta go to the db and check the order - db_sorted_enums = [] - # https://www.postgresql.org/docs/9.5/functions-enum.html - query = "select unnest(enum_range(enum_first(null::%s)));" % ( - db_column.type.name) - for row in metadata.bind.execute(query): - db_sorted_enums.append(dict(row)['unnest']) + db_sorted_enums = get_sorted_enum_values( + metadata.bind, db_column.type.name) assert schema_column.type.enums == db_sorted_enums @@ -332,9 +482,13 @@ def _sort_fn(item): return type(item).__name__ + item.name -def _validate_indexes(schema_table: sa.Table, db_table: sa.Table, metadata: sa.MetaData, dialect: String): +def _validate_indexes(schema_table: sa.Table, db_table: sa.Table, metadata: sa.MetaData, dialect: String, parsed_data: Optional[dict]): # sort indexes so that the order for both are the same - schema_indexes = sorted(schema_table.indexes, key=_sort_fn) + # skip FullTextIndexes because not reflected + # we're ignoring FullTextIndexes for now + # they are tested/confirmed in parsed_data + schema_indexes = sorted([ + idx for idx in schema_table.indexes if not isinstance(idx, FullTextIndex)], key=_sort_fn) db_indexes = sorted(db_table.indexes, key=_sort_fn) assert len(schema_indexes) == len(db_indexes) @@ -347,8 +501,81 @@ def _validate_indexes(schema_table: sa.Table, db_table: sa.Table, metadata: sa.M for schema_column, db_column in zip(schema_index_columns, db_index_columns): _validate_column(schema_column, db_column, metadata, dialect) + if parsed_data: + parsed_indexes = parsed_data["indices"] + + # go through all indexes + for index in schema_table.indexes: + single_col = None + if len(index.columns) == 1: + single_col = index.columns[0] + + parsed_index = [ + i for i in parsed_indexes if i.get("name") == index.name] + + fulltext = None + if len(parsed_index) > 0: + fulltext = parsed_index[0].get('fulltext', None) + + if single_col is not None: + def_index_type = default_index(schema_table, single_col.name) + index_type = index.kwargs.get('postgresql_using') + if fulltext is None and ((index_type == False and def_index_type == 'btree') or def_index_type == index_type): + # when index is on one column, we choose to store it on the column + # in parsed_data since easier to read + assert parsed_data['fields'].get( + single_col.name).get('index', None) == True + continue + + assert len(parsed_index) == 1 + parsed_index = parsed_index[0] + + assert parsed_index.get("unique", False) == index.unique + + if parsed_index.get('fulltext', None) is not None: + fulltext = parsed_index.get('fulltext') + + generated_col = fulltext.get('generatedColumnName', None) + if generated_col: + assert len(index.columns) == 1 + assert index.columns[0].name == generated_col + + test_data = index.kwargs.get('test_data', {}) + + expected = { + 'indexType': index.kwargs.get('postgresql_using'), + 'language': test_data.get('language'), + 'generatedColumnName': generated_col, + } + if test_data.get('weights', None) is not None: + expected['weights'] = test_data.get('weights', None) + + assert 
fulltext == expected -def _validate_constraints(schema_table: sa.Table, db_table: sa.Table, dialect: String, metadata: sa.MetaData): + assert parsed_index.get( + "columns") == test_data.get('columns') + + else: + + info = index.kwwwww['info'] + + m = sqltext_regex.match(info['postgresql_using_internals']) + groups = m.groups() + lang = groups[0].rstrip("::regconfig").strip("'") + + assert fulltext == { + 'indexType': info['postgresql_using'], + 'language': lang, + } + assert parsed_index.get("columns") == info.get( + 'columns', [info.get('column')]) + + else: + assert parsed_index.get("columns") == [ + col.name for col in index.columns] + + +def _validate_constraints(schema_table: sa.Table, db_table: sa.Table, dialect: String, metadata: sa.MetaData, parsed_data: Optional[dict]): # sort constraints so that the order for both are the same schema_constraints = sorted(schema_table.constraints, key=_sort_fn) db_constraints = sorted(db_table.constraints, key=_sort_fn) @@ -396,9 +623,71 @@ def _validate_constraints(schema_table: sa.Table, db_table: sa.Table, dialect: S for schema_column, db_column in zip(schema_constraint_columns, db_constraint_columns): _validate_column(schema_column, db_column, metadata, dialect) + if parsed_data: + parsed_constraints = parsed_data["constraints"] + for constraint in schema_constraints: + single_col = None + if len(constraint.columns) == 1: + single_col = constraint.columns[0] + + constraint_type = None + condition = None + + if isinstance(constraint, sa.PrimaryKeyConstraint): + constraint_type = ConstraintType.PrimaryKey + if single_col is not None: + assert parsed_data["fields"][single_col.name].get( + 'primaryKey', None) == True + continue + + if isinstance(constraint, sa.UniqueConstraint): + constraint_type = ConstraintType.Unique + if single_col is not None: + assert parsed_data["fields"][single_col.name].get( + 'unique', None) == True + continue + + if isinstance(constraint, sa.ForeignKeyConstraint): + constraint_type = ConstraintType.ForeignKey + if single_col is not None: + assert parsed_data["fields"][single_col.name].get( + 'foreignKey', None) is not None + continue + + if isinstance(constraint, sa.CheckConstraint): + constraint_type = ConstraintType.Check + condition = constraint.sqltext + + parsed_constraint = [ + c for c in parsed_constraints if c.get("name") == constraint.name] + + assert len(parsed_constraint) == 1 + parsed_constraint = parsed_constraint[0] + + assert parsed_constraint.get("type") == constraint_type + assert parsed_constraint.get("columns") == [ + col.name for col in constraint.columns] + + assert get_clause_text(parsed_constraint.get( + "condition", None), None) == get_clause_text(condition, None) + + +def _validate_foreign_key(schema_column: sa.Column, db_column: sa.Column, parsed_data_column: Optional[dict]): + assert len(schema_column.foreign_keys) == len(db_column.foreign_keys) + + if parsed_data_column is not None: + fkey = parsed_data_column.get('foreignKey') + if len(schema_column.foreign_keys) == 0: + assert fkey is None + else: + assert fkey is not None + assert len(schema_column.foreign_keys) == 1 -def _validate_foreign_key(schema_column: sa.Column, db_column: sa.Column): - assert len(schema_column.foreign_keys) == len(schema_column.foreign_keys) + fkeyInfo = list(schema_column.foreign_keys)[0] + assert fkey == { + 'schema': ParseDB.table_to_node(fkeyInfo.column.table.name), + 'column': fkeyInfo.column.name, + } for db_fkey, schema_fkey in zip(db_column.foreign_keys, schema_column.foreign_keys): # similar to what we do 
in validate_table on column.type @@ -435,8 +724,9 @@ def make_changes_and_restore( metadata_change_func, r2_message, r3_message, - validate_schema=True, post_r2_func=None, + + ): r = new_test_runner(metadata_with_table) run_and_validate_with_standard_metadata_tables( @@ -450,6 +740,8 @@ def make_changes_and_restore( r2.run() + validate_metadata_after_change(r2, r2.get_metadata()) + # should have the expected files with the expected tables assert_num_files(r2, 2) assert_num_tables(r2, 2, ['accounts', 'alembic_version']) @@ -464,9 +756,6 @@ def make_changes_and_restore( r2.downgrade(delete_files=False, revision='-1') r2.upgrade() - if validate_schema: - validate_metadata_after_change(r2, r2.get_metadata()) - r3 = recreate_metadata_fixture( new_test_runner, conftest.metadata_with_base_table_restored(), r2) diff --git a/ts/src/schema/schema.ts b/ts/src/schema/schema.ts index bc6205a7f..ba2b51ac8 100644 --- a/ts/src/schema/schema.ts +++ b/ts/src/schema/schema.ts @@ -296,6 +296,7 @@ export interface TransformedUpdateOperation< // should eventually generate (boolean | null) etc // supported db types +// copied to auto_schema/auto_schema/parse_db.py export enum DBType { UUID = "UUID", Int64ID = "Int64ID", // unsupported right now