From a2a86e253de8965379642740e49e865b95b9585e Mon Sep 17 00:00:00 2001 From: Jesse Yang Date: Fri, 9 Oct 2020 16:44:07 -0700 Subject: [PATCH] Batch commit and fix json load error --- ...0de1855_add_uuid_column_to_import_mixin.py | 48 ++++++++++++------- 1 file changed, 32 insertions(+), 16 deletions(-) diff --git a/superset/migrations/versions/b56500de1855_add_uuid_column_to_import_mixin.py b/superset/migrations/versions/b56500de1855_add_uuid_column_to_import_mixin.py index 46ee76e937c62..bffa983c2c492 100644 --- a/superset/migrations/versions/b56500de1855_add_uuid_column_to_import_mixin.py +++ b/superset/migrations/versions/b56500de1855_add_uuid_column_to_import_mixin.py @@ -24,6 +24,7 @@ import json import os import time +from json.decoder import JSONDecodeError from uuid import uuid4 import sqlalchemy as sa @@ -120,8 +121,12 @@ def add_uuids(table_name, session, batch_size=default_batch_size): print(f"Done. Assigned {count} uuids in {time.time() - start_time:.3f}s.") -def update_position_json(dashboard, session, uuid_map={}): - layout = json.loads(dashboard.position_json or "{}") +def update_position_json(dashboard, session, uuid_map): + try: + layout = json.loads(dashboard.position_json or "{}") + except JSONDecodeError: + layout = {} + for object_ in layout.values(): if ( isinstance(object_, dict) @@ -136,7 +141,27 @@ def update_position_json(dashboard, session, uuid_map={}): dashboard.position_json = json.dumps(layout, indent=4) session.merge(dashboard) + + +def update_dashboards(session, uuid_map): + message = ( + "Updating dasboard position json with slice uuid.." + if uuid_map + else "Cleaning up slice uuid from dashboard position json.." + ) + print(f"\n{message}\r", end="") + + query = session.query(models["dashboards"]) + dashboard_count = query.count() + for i, dashboard in enumerate(query.all()): + update_position_json(dashboard, session, uuid_map) + if i and i % default_batch_size == 0: + session.commit() + print(f"{message} {i+1}/{dashboard_count}\r", end="") + session.commit() + # Extra whitespace to override very long numbers, e.g. 99999/99999. + print(f"{message} Done. \n") def upgrade(): @@ -165,21 +190,14 @@ def upgrade(): except OperationalError: pass - message = "Updating dashboard position json with slice uuid.." - print(f"\n{message}\r", end="") - # add UUID to Dashboard.position_json - Dashboard = models["dashboards"] - Slice = models["slices"] slice_uuid_map = { slc.id: slc.uuid - for slc in session.query(Slice).options(load_only("id", "uuid")).all() + for slc in session.query(models["slices"]) + .options(load_only("id", "uuid")) + .all() } - dashboard_count = session.query(Dashboard).count() - for i, dashboard in enumerate(session.query(Dashboard).all()): - update_position_json(dashboard, session, slice_uuid_map) - print(f"{message} {i+1}/{dashboard_count}\r", end="") - print(f"{message} Done.") + update_dashboards(session, slice_uuid_map) def downgrade(): @@ -187,9 +205,7 @@ def downgrade(): session = db.Session(bind=bind) # remove uuid from position_json - Dashboard = models["dashboards"] - for dashboard in session.query(Dashboard).all(): - update_position_json(dashboard, session, {}) + update_dashboards(session, {}) # remove uuid column for table_name, model in models.items():