From f4d12421837ea9497b34eb081cc08db95de8b83f Mon Sep 17 00:00:00 2001
From: Beto Dealmeida
Date: Tue, 16 Mar 2021 17:33:58 -0700
Subject: [PATCH] feat: script to benchmark DB migrations

---
 requirements/base.txt                     |   2 +
 scripts/benchmark_migration.py            | 215 +++++++++++++++++++++++
 setup.cfg                                 |   2 +-
 setup.py                                  |   3 +-
 superset/examples/big_data.py             |   2 +-
 superset/utils/{data.py => mock_data.py}  | 148 +++++++++++++++-
 6 files changed, 360 insertions(+), 12 deletions(-)
 create mode 100644 scripts/benchmark_migration.py
 rename superset/utils/{data.py => mock_data.py} (51%)

diff --git a/requirements/base.txt b/requirements/base.txt
index 2c3c3b0664fc0..3596304d98922 100644
--- a/requirements/base.txt
+++ b/requirements/base.txt
@@ -109,6 +109,8 @@ geographiclib==1.50
     # via geopy
 geopy==2.0.0
     # via apache-superset
+graphlib-backport==1.0.3
+    # via apache-superset
 gunicorn==20.0.4
     # via apache-superset
 holidays==0.10.3
diff --git a/scripts/benchmark_migration.py b/scripts/benchmark_migration.py
new file mode 100644
index 0000000000000..0faa92a88552b
--- /dev/null
+++ b/scripts/benchmark_migration.py
@@ -0,0 +1,215 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import importlib.util
+import logging
+import re
+import time
+from collections import defaultdict
+from inspect import getsource
+from pathlib import Path
+from types import ModuleType
+from typing import Dict, List, Set, Type
+
+import click
+from flask_appbuilder import Model
+from flask_migrate import downgrade, upgrade
+from graphlib import TopologicalSorter  # pylint: disable=wrong-import-order
+from sqlalchemy import inspect
+
+from superset import db
+from superset.utils.mock_data import add_sample_rows
+
+logger = logging.getLogger(__name__)
+
+
+def import_migration_script(filepath: Path) -> ModuleType:
+    """
+    Import migration script as if it were a module.
+    """
+    spec = importlib.util.spec_from_file_location(filepath.stem, filepath)
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)  # type: ignore
+    return module
+
+
+def extract_modified_tables(module: ModuleType) -> Set[str]:
+    """
+    Extract the tables being modified by a migration script.
+
+    This function takes a simple approach: it scans the source code of the
+    migration script for known patterns. It could be improved by actually
+    traversing the AST.
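+
+    For example, if upgrade() calls op.add_column("dashboards",
+    sa.Column("uuid", ...)), the returned set will include "dashboards".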
+ """ + + tables: Set[str] = set() + for function in {"upgrade", "downgrade"}: + source = getsource(getattr(module, function)) + tables.update(re.findall(r'alter_table\(\s*"(\w+?)"\s*\)', source, re.DOTALL)) + tables.update(re.findall(r'add_column\(\s*"(\w+?)"\s*,', source, re.DOTALL)) + tables.update(re.findall(r'drop_column\(\s*"(\w+?)"\s*,', source, re.DOTALL)) + + return tables + + +def find_models(module: ModuleType) -> List[Type[Model]]: + """ + Find all models in a migration script. + """ + models: List[Type[Model]] = [] + tables = extract_modified_tables(module) + + # add models defined explicitly in the migration script + queue = list(module.__dict__.values()) + while queue: + obj = queue.pop() + if hasattr(obj, "__tablename__"): + tables.add(obj.__tablename__) + elif isinstance(obj, list): + queue.extend(obj) + elif isinstance(obj, dict): + queue.extend(obj.values()) + + # add implicit models + # pylint: disable=no-member, protected-access + for obj in Model._decl_class_registry.values(): + if hasattr(obj, "__table__") and obj.__table__.fullname in tables: + models.append(obj) + + # sort topologically so we can create entities in order and + # maintain relationships (eg, create a database before creating + # a slice) + sorter = TopologicalSorter() + for model in models: + inspector = inspect(model) + dependent_tables: List[str] = [] + for column in inspector.columns.values(): + for foreign_key in column.foreign_keys: + dependent_tables.append(foreign_key.target_fullname.split(".")[0]) + sorter.add(model.__tablename__, *dependent_tables) + order = list(sorter.static_order()) + models.sort(key=lambda model: order.index(model.__tablename__)) + + return models + + +@click.command() +@click.argument("filepath") +@click.option("--limit", default=1000, help="Maximum number of entities.") +@click.option("--force", is_flag=True, help="Do not prompt for confirmation.") +@click.option("--no-auto-cleanup", is_flag=True, help="Do not remove created models.") +def main( + filepath: str, limit: int = 1000, force: bool = False, no_auto_cleanup: bool = False +) -> None: + auto_cleanup = not no_auto_cleanup + session = db.session() + + print(f"Importing migration script: {filepath}") + module = import_migration_script(Path(filepath)) + + revision: str = getattr(module, "revision", "") + down_revision: str = getattr(module, "down_revision", "") + if not revision or not down_revision: + raise Exception( + "Not a valid migration script, couldn't find down_revision/revision" + ) + + print(f"Migration goes from {down_revision} to {revision}") + current_revision = db.engine.execute( + "SELECT version_num FROM alembic_version" + ).scalar() + print(f"Current version of the DB is {current_revision}") + + print("\nIdentifying models used in the migration:") + models = find_models(module) + model_rows: Dict[Type[Model], int] = {} + for model in models: + rows = session.query(model).count() + print(f"- {model.__name__} ({rows} rows in table {model.__tablename__})") + model_rows[model] = rows + session.close() + + if current_revision != down_revision: + if not force: + click.confirm( + "\nRunning benchmark will downgrade the Superset DB to " + f"{down_revision} and upgrade to {revision} again. There may " + "be data loss in downgrades. 
Continue?", + abort=True, + ) + downgrade(revision=down_revision) + + print("Benchmarking migration") + results: Dict[str, float] = {} + start = time.time() + upgrade(revision=revision) + duration = time.time() - start + results["Current"] = duration + print(f"Migration on current DB took: {duration:.2f} seconds") + + min_entities = 10 + new_models: Dict[Type[Model], List[Model]] = defaultdict(list) + while min_entities <= limit: + downgrade(revision=down_revision) + print(f"Running with at least {min_entities} entities of each model") + for model in models: + missing = min_entities - model_rows[model] + if missing > 0: + print(f"- Adding {missing} entities to the {model.__name__} model") + try: + added_models = add_sample_rows(session, model, missing) + except Exception: + session.rollback() + raise + model_rows[model] = min_entities + session.commit() + + if auto_cleanup: + new_models[model].extend(added_models) + + start = time.time() + upgrade(revision=revision) + duration = time.time() - start + print(f"Migration for {min_entities}+ entities took: {duration:.2f} seconds") + results[f"{min_entities}+"] = duration + min_entities *= 10 + + if auto_cleanup: + print("Cleaning up DB") + # delete in reverse order of creation to handle relationships + for model, entities in list(new_models.items())[::-1]: + session.query(model).filter( + model.id.in_(entity.id for entity in entities) + ).delete(synchronize_session=False) + session.commit() + + if current_revision != revision and not force: + click.confirm(f"\nRevert DB to {revision}?", abort=True) + upgrade(revision=revision) + print("Reverted") + + print("\nResults:\n") + for label, duration in results.items(): + print(f"{label}: {duration:.2f} s") + + +if __name__ == "__main__": + from superset.app import create_app + + app = create_app() + with app.app_context(): + # pylint: disable=no-value-for-parameter + main() diff --git a/setup.cfg b/setup.cfg index d1b2e4db1047c..8fd75a0bd6c56 100644 --- a/setup.cfg +++ b/setup.cfg @@ -30,7 +30,7 @@ combine_as_imports = true include_trailing_comma = true line_length = 88 known_first_party = superset -known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,cron_descriptor,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_jwt_extended,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,holidays,humanize,isodate,jinja2,jwt,markdown,markupsafe,marshmallow,marshmallow_enum,msgpack,numpy,pandas,parameterized,parsedatetime,pathlib2,pgsanity,pkg_resources,polyline,prison,pyarrow,pyhive,pyparsing,pytest,pytz,redis,requests,retry,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,typing_extensions,werkzeug,wtforms,wtforms_json,yaml +known_third_party =alembic,apispec,backoff,bleach,cachelib,celery,click,colorama,contextlib2,cron_descriptor,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_jwt_extended,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,freezegun,geohash,geopy,graphlib,holidays,humanize,isodate,jinja2,jwt,markdown,markupsafe,marshmallow,marshmallow_enum,msgpack,numpy,pandas,parameterized,parsedatetime,pathlib2,pgsanity,pkg_resources,polyline,prison,pyarrow,pyhive,pyparsing,pytest,pytz,redis,requests,retry,selenium,setuptools,simplejson,slack,sqlalchemy,sqlalchemy_utils,sqlparse,typing_extensions,werkzeug,wtforms,wtforms_json,yaml multi_line_output = 3 
diff --git a/superset/examples/big_data.py b/superset/examples/big_data.py
index d46e184903288..f837effd34bc4 100644
--- a/superset/examples/big_data.py
+++ b/superset/examples/big_data.py
@@ -20,7 +20,7 @@
 
 import sqlalchemy.sql.sqltypes
 
-from superset.utils.data import add_data, ColumnInfo
+from superset.utils.mock_data import add_data, ColumnInfo
 
 COLUMN_TYPES = [
     sqlalchemy.sql.sqltypes.INTEGER(),
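
The remaining change renames superset/utils/data.py to superset/utils/mock_data.py and extends its type-based value generators. As a quick sketch of that API (the sampled values are, of course, random):

    from sqlalchemy.sql import sqltypes

    from superset.utils.mock_data import get_type_generator

    gen = get_type_generator(sqltypes.INTEGER())
    sample = [gen() for _ in range(3)]  # e.g. [152, 90114, 1388001]
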
diff --git a/superset/utils/data.py b/superset/utils/mock_data.py
similarity index 51%
rename from superset/utils/data.py
rename to superset/utils/mock_data.py
index 9a1987c41dc85..06327ef89262b 100644
--- a/superset/utils/data.py
+++ b/superset/utils/mock_data.py
@@ -14,17 +14,30 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import decimal
+import json
+import logging
+import os
 import random
 import string
 import sys
 from datetime import date, datetime, time, timedelta
-from typing import Any, Callable, cast, Dict, List, Optional
+from typing import Any, Callable, cast, Dict, List, Optional, Type
+from uuid import uuid4
 
 import sqlalchemy.sql.sqltypes
+import sqlalchemy_utils
+from flask_appbuilder import Model
 from sqlalchemy import Column, inspect, MetaData, Table
+from sqlalchemy.orm import Session
+from sqlalchemy.sql import func
 from sqlalchemy.sql.visitors import VisitableType
 from typing_extensions import TypedDict
 
+from superset import db
+
+logger = logging.getLogger(__name__)
+
 ColumnInfo = TypedDict(
     "ColumnInfo",
     {
@@ -53,24 +66,35 @@ days_range = (MAXIMUM_DATE - MINIMUM_DATE).days
 
 
+# pylint: disable=too-many-return-statements, too-many-branches
 def get_type_generator(sqltype: sqlalchemy.sql.sqltypes) -> Callable[[], Any]:
-    if isinstance(sqltype, sqlalchemy.sql.sqltypes.INTEGER):
+    if isinstance(
+        sqltype, (sqlalchemy.sql.sqltypes.INTEGER, sqlalchemy.sql.sqltypes.Integer)
+    ):
         return lambda: random.randrange(2147483647)
 
     if isinstance(sqltype, sqlalchemy.sql.sqltypes.BIGINT):
         return lambda: random.randrange(sys.maxsize)
 
-    if isinstance(sqltype, sqlalchemy.sql.sqltypes.VARCHAR):
+    if isinstance(
+        sqltype, (sqlalchemy.sql.sqltypes.VARCHAR, sqlalchemy.sql.sqltypes.String)
+    ):
         length = random.randrange(sqltype.length or 255)
+        length = max(8, length)  # for unique values
+        length = min(100, length)  # for FAB perms
         return lambda: "".join(random.choices(string.printable, k=length))
 
-    if isinstance(sqltype, sqlalchemy.sql.sqltypes.TEXT):
+    if isinstance(
+        sqltype, (sqlalchemy.sql.sqltypes.TEXT, sqlalchemy.sql.sqltypes.Text)
+    ):
         length = random.randrange(65535)
         # "practicality beats purity"
         length = max(length, 2048)
         return lambda: "".join(random.choices(string.printable, k=length))
 
-    if isinstance(sqltype, sqlalchemy.sql.sqltypes.BOOLEAN):
+    if isinstance(
+        sqltype, (sqlalchemy.sql.sqltypes.BOOLEAN, sqlalchemy.sql.sqltypes.Boolean)
+    ):
         return lambda: random.choice([True, False])
 
     if isinstance(
@@ -87,13 +111,49 @@ def get_type_generator(sqltype: sqlalchemy.sql.sqltypes) -> Callable[[], Any]:
     )
 
     if isinstance(
-        sqltype, (sqlalchemy.sql.sqltypes.TIMESTAMP, sqlalchemy.sql.sqltypes.DATETIME)
+        sqltype,
+        (
+            sqlalchemy.sql.sqltypes.TIMESTAMP,
+            sqlalchemy.sql.sqltypes.DATETIME,
+            sqlalchemy.sql.sqltypes.DateTime,
+        ),
     ):
         return lambda: datetime.fromordinal(MINIMUM_DATE.toordinal()) + timedelta(
             seconds=random.randrange(days_range * 86400)
         )
 
-    raise Exception(f"Unknown type {sqltype}. Please add it to `get_type_generator`.")
+    if isinstance(sqltype, sqlalchemy.sql.sqltypes.Numeric):
+        # since decimal is used in some models to store time, return a value that
+        # is a reasonable timestamp
+        return lambda: decimal.Decimal(datetime.now().timestamp() * 1000)
+
+    if isinstance(sqltype, sqlalchemy.sql.sqltypes.JSON):
+        return lambda: {
+            "".join(random.choices(string.printable, k=8)): random.randrange(65535)
+            for _ in range(10)
+        }
+
+    if isinstance(
+        sqltype,
+        (
+            sqlalchemy.sql.sqltypes.BINARY,
+            sqlalchemy_utils.types.encrypted.encrypted_type.EncryptedType,
+        ),
+    ):
+        length = random.randrange(sqltype.length or 255)
+        return lambda: os.urandom(length)
+
+    if isinstance(sqltype, sqlalchemy_utils.types.uuid.UUIDType):
+        return uuid4
+
+    if isinstance(sqltype, sqlalchemy.sql.sqltypes.BLOB):
+        length = random.randrange(sqltype.length or 255)
+        return lambda: os.urandom(length)
+
+    logger.warning(
+        "Unknown type %s. Please add it to `get_type_generator`.", type(sqltype)
+    )
+    return lambda: "UNKNOWN TYPE"
 
 
 def add_data(
@@ -161,5 +221,75 @@ def generate_data(columns: List[ColumnInfo], num_rows: int) -> List[Dict[str, An
 
 
 def generate_column_data(column: ColumnInfo, num_rows: int) -> List[Any]:
-    func = get_type_generator(column["type"])
-    return [func() for _ in range(num_rows)]
+    gen = get_type_generator(column["type"])
+    return [gen() for _ in range(num_rows)]
+
+
+def add_sample_rows(session: Session, model: Type[Model], count: int) -> List[Model]:
+    """
+    Add entities of a given model.
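+
+    The new entities are added to the session but not committed; committing
+    (or rolling back) is left to the caller.
+
+    :param Session session: the SQLAlchemy session the entities are added to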
+    :param Model model: a Superset/FAB model
+    :param int count: how many entities to generate and insert
+    """
+    inspector = inspect(model)
+
+    # select samples to copy relationship values
+    relationships = inspector.relationships.items()
+    samples = session.query(model).limit(count).all() if relationships else []
+
+    entities: List[Model] = []
+    max_primary_key: Optional[int] = None
+    for i in range(count):
+        sample = samples[i % len(samples)] if samples else None
+        kwargs = {}
+        for column in inspector.columns.values():
+            # for primary keys, keep incrementing
+            if column.primary_key:
+                if max_primary_key is None:
+                    max_primary_key = (
+                        session.query(func.max(getattr(model, column.name))).scalar()
+                        or 0
+                    )
+                max_primary_key += 1
+                kwargs[column.name] = max_primary_key
+
+            # if the column has a foreign key, copy the value from an existing entity
+            elif column.foreign_keys:
+                if sample:
+                    kwargs[column.name] = getattr(sample, column.name)
+                else:
+                    kwargs[column.name] = get_valid_foreign_key(column)
+
+            # should be an enum but it's not
+            elif column.name == "datasource_type":
+                kwargs[column.name] = "table"
+
+            # otherwise, generate a random value based on the type
+            else:
+                kwargs[column.name] = generate_value(column)
+
+        entities.append(model(**kwargs))
+
+    session.add_all(entities)
+    return entities
+
+
+def get_valid_foreign_key(column: Column) -> Any:
+    foreign_key = list(column.foreign_keys)[0]
+    table_name, column_name = foreign_key.target_fullname.split(".", 1)
+    return db.engine.execute(f"SELECT {column_name} FROM {table_name} LIMIT 1").scalar()
+
+
+def generate_value(column: Column) -> Any:
+    if hasattr(column.type, "enums"):
+        return random.choice(column.type.enums)
+
+    json_as_string = "json" in column.name.lower() and isinstance(
+        column.type, sqlalchemy.sql.sqltypes.Text
+    )
+    type_ = sqlalchemy.sql.sqltypes.JSON() if json_as_string else column.type
+    value = get_type_generator(type_)()
+    if json_as_string:
+        value = json.dumps(value)
+    return value
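
Finally, a minimal sketch of exercising add_sample_rows directly, for example from a Flask shell; the benchmark script above normally drives this, and the choice of Database as the model is purely illustrative:

    from superset import db
    from superset.models.core import Database
    from superset.utils.mock_data import add_sample_rows

    session = db.session()
    entities = add_sample_rows(session, Database, 10)
    session.commit()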