From 551afff58ea7bc1047775bfcd5d80b812fb3f682 Mon Sep 17 00:00:00 2001 From: Jo <46752250+georgesittas@users.noreply.github.com> Date: Mon, 28 Oct 2024 22:40:19 +0200 Subject: [PATCH] Fix: handle a Move edge case in the semantic differ (#4295) * Fix: handle a Move edge case in the semantic differ * Fixup --- sqlglot/diff.py | 53 ++++++++++++++++++++------ tests/test_diff.py | 94 +++++++++++++++++++++++++++++----------------- 2 files changed, 102 insertions(+), 45 deletions(-) diff --git a/sqlglot/diff.py b/sqlglot/diff.py index b637c5181f..ec1e31e6a7 100644 --- a/sqlglot/diff.py +++ b/sqlglot/diff.py @@ -10,6 +10,7 @@ from collections import defaultdict from dataclasses import dataclass from heapq import heappop, heappush +from itertools import chain from sqlglot import Dialect, expressions as exp from sqlglot.helper import seq_get @@ -112,11 +113,19 @@ def diff( def compute_node_mappings( original: exp.Expression, copy: exp.Expression ) -> t.Dict[int, exp.Expression]: - return { - id(old_node): new_node - for old_node, new_node in zip(original.walk(), copy.walk()) - if id(old_node) in matching_ids - } + node_mapping = {} + for old_node, new_node in zip( + reversed(tuple(original.walk())), reversed(tuple(copy.walk())) + ): + # We cache the hash of each new node here to speed up equality comparisons. If the input + # trees aren't copied, these hashes will be evicted before returning the edit script. + new_node._hash = hash(new_node) + + old_node_id = id(old_node) + if old_node_id in matching_ids: + node_mapping[old_node_id] = new_node + + return node_mapping source_copy = source.copy() if copy else source target_copy = target.copy() if copy else target @@ -127,13 +136,19 @@ def compute_node_mappings( } matchings_copy = [(node_mappings[id(s)], node_mappings[id(t)]) for s, t in matchings] - return ChangeDistiller(**kwargs).diff( + edit_script = ChangeDistiller(**kwargs).diff( source_copy, target_copy, matchings=matchings_copy, delta_only=delta_only, ) + if not copy: + for node in chain(source.walk(), target.walk()): + node._hash = None + + return edit_script + # The expression types for which Update edits are allowed. UPDATABLE_EXPRESSION_TYPES = ( @@ -199,11 +214,27 @@ def _generate_edit_script(self, matchings: t.Dict[int, int], delta_only: bool) - source_node = self._source_index[kept_source_node_id] target_node = self._target_index[kept_target_node_id] - if ( - not isinstance(source_node, UPDATABLE_EXPRESSION_TYPES) - or source_node == target_node - ): - edit_script.extend(self._generate_move_edits(source_node, target_node, matchings)) + identical_nodes = source_node == target_node + + if not isinstance(source_node, UPDATABLE_EXPRESSION_TYPES) or identical_nodes: + if identical_nodes: + source_parent = source_node.parent + target_parent = target_node.parent + + if ( + (source_parent and not target_parent) + or (not source_parent and target_parent) + or ( + source_parent + and target_parent + and matchings.get(id(source_parent)) != id(target_parent) + ) + ): + edit_script.append(Move(source=source_node, target=target_node)) + else: + edit_script.extend( + self._generate_move_edits(source_node, target_node, matchings) + ) source_non_expression_leaves = dict(_get_non_expression_leaves(source_node)) target_non_expression_leaves = dict(_get_non_expression_leaves(target_node)) diff --git a/tests/test_diff.py b/tests/test_diff.py index 65802277f7..440502eaa3 100644 --- a/tests/test_diff.py +++ b/tests/test_diff.py @@ -2,7 +2,6 @@ from sqlglot import exp, parse_one from sqlglot.diff import Insert, Move, Remove, Update, diff -from sqlglot.expressions import Join, to_table def diff_delta_only(source, target, matchings=None, **kwargs): @@ -14,22 +13,24 @@ def test_simple(self): self._validate_delta_only( diff_delta_only(parse_one("SELECT a + b"), parse_one("SELECT a - b")), [ - Remove(parse_one("a + b")), # the Add node - Insert(parse_one("a - b")), # the Sub node + Remove(expression=parse_one("a + b")), # the Add node + Insert(expression=parse_one("a - b")), # the Sub node + Move(source=parse_one("a"), target=parse_one("a")), # the `a` Column node + Move(source=parse_one("b"), target=parse_one("b")), # the `b` Column node ], ) self._validate_delta_only( diff_delta_only(parse_one("SELECT a, b, c"), parse_one("SELECT a, c")), [ - Remove(parse_one("b")), # the Column node + Remove(expression=parse_one("b")), # the Column node ], ) self._validate_delta_only( diff_delta_only(parse_one("SELECT a, b"), parse_one("SELECT a, b, c")), [ - Insert(parse_one("c")), # the Column node + Insert(expression=parse_one("c")), # the Column node ], ) @@ -40,8 +41,8 @@ def test_simple(self): ), [ Update( - to_table("table_one", quoted=False), - to_table("table_two", quoted=False), + source=exp.to_table("table_one", quoted=False), + target=exp.to_table("table_two", quoted=False), ), # the Table node ], ) @@ -53,8 +54,12 @@ def test_lambda(self): ), [ Update( - exp.Lambda(this=exp.to_identifier("a"), expressions=[exp.to_identifier("a")]), - exp.Lambda(this=exp.to_identifier("b"), expressions=[exp.to_identifier("b")]), + source=exp.Lambda( + this=exp.to_identifier("a"), expressions=[exp.to_identifier("a")] + ), + target=exp.Lambda( + this=exp.to_identifier("b"), expressions=[exp.to_identifier("b")] + ), ), ], ) @@ -65,8 +70,8 @@ def test_udf(self): parse_one('SELECT a, b, "my.udf1"()'), parse_one('SELECT a, b, "my.udf2"()') ), [ - Insert(parse_one('"my.udf2"()')), - Remove(parse_one('"my.udf1"()')), + Insert(expression=parse_one('"my.udf2"()')), + Remove(expression=parse_one('"my.udf1"()')), ], ) self._validate_delta_only( @@ -75,8 +80,8 @@ def test_udf(self): parse_one('SELECT a, b, "my.udf"(x, y, w)'), ), [ - Insert(exp.column("w")), - Remove(exp.column("z")), + Insert(expression=exp.column("w")), + Remove(expression=exp.column("z")), ], ) @@ -132,6 +137,19 @@ def test_node_position_changed(self): ], ) + expr_src = parse_one("SELECT a as a, b as b FROM t WHERE CONCAT('a', 'b') = 'ab'") + expr_tgt = parse_one("SELECT a as a FROM t WHERE CONCAT('a', 'b', b) = 'ab'") + + b_alias = expr_src.selects[1] + + self._validate_delta_only( + diff_delta_only(expr_src, expr_tgt), + [ + Remove(expression=b_alias), + Move(source=b_alias.this, target=expr_tgt.find(exp.Concat).expressions[-1]), + ], + ) + def test_cte(self): expr_src = """ WITH @@ -149,23 +167,30 @@ def test_cte(self): self._validate_delta_only( diff_delta_only(parse_one(expr_src), parse_one(expr_tgt)), [ - Remove(parse_one("LOWER(c) AS c")), # the Alias node - Remove(parse_one("LOWER(c)")), # the Lower node - Remove(parse_one("'filter'")), # the Literal node - Insert(parse_one("'different_filter'")), # the Literal node + Remove(expression=parse_one("LOWER(c) AS c")), # the Alias node + Remove(expression=parse_one("LOWER(c)")), # the Lower node + Remove(expression=parse_one("'filter'")), # the Literal node + Insert(expression=parse_one("'different_filter'")), # the Literal node + Move(source=parse_one("c"), target=parse_one("c")), # the new Column c ], ) def test_join(self): - expr_src = "SELECT a, b FROM t1 LEFT JOIN t2 ON t1.key = t2.key" - expr_tgt = "SELECT a, b FROM t1 RIGHT JOIN t2 ON t1.key = t2.key" + expr_src = parse_one("SELECT a, b FROM t1 LEFT JOIN t2 ON t1.key = t2.key") + expr_tgt = parse_one("SELECT a, b FROM t1 RIGHT JOIN t2 ON t1.key = t2.key") - changes = diff_delta_only(parse_one(expr_src), parse_one(expr_tgt)) + src_join = expr_src.find(exp.Join) + tgt_join = expr_tgt.find(exp.Join) - self.assertEqual(len(changes), 2) - self.assertTrue(isinstance(changes[0], Remove)) - self.assertTrue(isinstance(changes[1], Insert)) - self.assertTrue(all(isinstance(c.expression, Join) for c in changes)) + self._validate_delta_only( + diff_delta_only(expr_src, expr_tgt), + [ + Remove(expression=src_join), + Insert(expression=tgt_join), + Move(source=exp.to_table("t2"), target=exp.to_table("t2")), + Move(source=src_join.args["on"], target=tgt_join.args["on"]), + ], + ) def test_window_functions(self): expr_src = parse_one("SELECT ROW_NUMBER() OVER (PARTITION BY a ORDER BY b)") @@ -176,8 +201,8 @@ def test_window_functions(self): self._validate_delta_only( diff_delta_only(expr_src, expr_tgt), [ - Remove(parse_one("ROW_NUMBER()")), - Insert(parse_one("RANK()")), + Remove(expression=parse_one("ROW_NUMBER()")), + Insert(expression=parse_one("RANK()")), Update(source=expr_src.selects[0], target=expr_tgt.selects[0]), ], ) @@ -197,20 +222,21 @@ def test_pre_matchings(self): self._validate_delta_only( diff_delta_only(expr_src, expr_tgt), [ - Remove(expr_src), - Insert(expr_tgt), - Insert(exp.Literal.number(2)), - Insert(exp.Literal.number(3)), - Insert(exp.Literal.number(4)), + Remove(expression=expr_src), + Insert(expression=expr_tgt), + Insert(expression=exp.Literal.number(2)), + Insert(expression=exp.Literal.number(3)), + Insert(expression=exp.Literal.number(4)), + Move(source=exp.Literal.number(1), target=exp.Literal.number(1)), ], ) self._validate_delta_only( diff_delta_only(expr_src, expr_tgt, matchings=[(expr_src, expr_tgt)]), [ - Insert(exp.Literal.number(2)), - Insert(exp.Literal.number(3)), - Insert(exp.Literal.number(4)), + Insert(expression=exp.Literal.number(2)), + Insert(expression=exp.Literal.number(3)), + Insert(expression=exp.Literal.number(4)), ], )