Skip to content

Commit

Permalink
Fix: handle a Move edge case in the semantic differ (#4295)
Browse files: browse the repository at this point in the history
* Fix: handle a Move edge case in the semantic differ

* Fixup
  • Loading branch information
georgesittas authored Oct 28, 2024
1 parent b072366 commit 551afff
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 45 deletions.
53 changes: 42 additions & 11 deletions sqlglot/diff.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -10,6 +10,7 @@
from collections import defaultdict
from dataclasses import dataclass
from heapq import heappop, heappush
from itertools import chain

from sqlglot import Dialect, expressions as exp
from sqlglot.helper import seq_get
Expand Down Expand Up @@ -112,11 +113,19 @@ def diff(
def compute_node_mappings(
original: exp.Expression, copy: exp.Expression
) -> t.Dict[int, exp.Expression]:
return {
id(old_node): new_node
for old_node, new_node in zip(original.walk(), copy.walk())
if id(old_node) in matching_ids
}
node_mapping = {}
for old_node, new_node in zip(
reversed(tuple(original.walk())), reversed(tuple(copy.walk()))
):
# We cache the hash of each new node here to speed up equality comparisons. If the input
# trees aren't copied, these hashes will be evicted before returning the edit script.
new_node._hash = hash(new_node)

old_node_id = id(old_node)
if old_node_id in matching_ids:
node_mapping[old_node_id] = new_node

return node_mapping

source_copy = source.copy() if copy else source
target_copy = target.copy() if copy else target
Expand All @@ -127,13 +136,19 @@ def compute_node_mappings(
}
matchings_copy = [(node_mappings[id(s)], node_mappings[id(t)]) for s, t in matchings]

return ChangeDistiller(**kwargs).diff(
edit_script = ChangeDistiller(**kwargs).diff(
source_copy,
target_copy,
matchings=matchings_copy,
delta_only=delta_only,
)

if not copy:
for node in chain(source.walk(), target.walk()):
node._hash = None

return edit_script


# The expression types for which Update edits are allowed.
UPDATABLE_EXPRESSION_TYPES = (
Expand Down Expand Up @@ -199,11 +214,27 @@ def _generate_edit_script(self, matchings: t.Dict[int, int], delta_only: bool) -
source_node = self._source_index[kept_source_node_id]
target_node = self._target_index[kept_target_node_id]

if (
not isinstance(source_node, UPDATABLE_EXPRESSION_TYPES)
or source_node == target_node
):
edit_script.extend(self._generate_move_edits(source_node, target_node, matchings))
identical_nodes = source_node == target_node

if not isinstance(source_node, UPDATABLE_EXPRESSION_TYPES) or identical_nodes:
if identical_nodes:
source_parent = source_node.parent
target_parent = target_node.parent

if (
(source_parent and not target_parent)
or (not source_parent and target_parent)
or (
source_parent
and target_parent
and matchings.get(id(source_parent)) != id(target_parent)
)
):
edit_script.append(Move(source=source_node, target=target_node))
else:
edit_script.extend(
self._generate_move_edits(source_node, target_node, matchings)
)

source_non_expression_leaves = dict(_get_non_expression_leaves(source_node))
target_non_expression_leaves = dict(_get_non_expression_leaves(target_node))
Expand Down
94 changes: 60 additions & 34 deletions tests/test_diff.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -2,7 +2,6 @@

from sqlglot import exp, parse_one
from sqlglot.diff import Insert, Move, Remove, Update, diff
from sqlglot.expressions import Join, to_table


def diff_delta_only(source, target, matchings=None, **kwargs):
Expand All @@ -14,22 +13,24 @@ def test_simple(self):
self._validate_delta_only(
diff_delta_only(parse_one("SELECT a + b"), parse_one("SELECT a - b")),
[
Remove(parse_one("a + b")), # the Add node
Insert(parse_one("a - b")), # the Sub node
Remove(expression=parse_one("a + b")), # the Add node
Insert(expression=parse_one("a - b")), # the Sub node
Move(source=parse_one("a"), target=parse_one("a")), # the `a` Column node
Move(source=parse_one("b"), target=parse_one("b")), # the `b` Column node
],
)

self._validate_delta_only(
diff_delta_only(parse_one("SELECT a, b, c"), parse_one("SELECT a, c")),
[
Remove(parse_one("b")), # the Column node
Remove(expression=parse_one("b")), # the Column node
],
)

self._validate_delta_only(
diff_delta_only(parse_one("SELECT a, b"), parse_one("SELECT a, b, c")),
[
Insert(parse_one("c")), # the Column node
Insert(expression=parse_one("c")), # the Column node
],
)

Expand All @@ -40,8 +41,8 @@ def test_simple(self):
),
[
Update(
to_table("table_one", quoted=False),
to_table("table_two", quoted=False),
source=exp.to_table("table_one", quoted=False),
target=exp.to_table("table_two", quoted=False),
), # the Table node
],
)
Expand All @@ -53,8 +54,12 @@ def test_lambda(self):
),
[
Update(
exp.Lambda(this=exp.to_identifier("a"), expressions=[exp.to_identifier("a")]),
exp.Lambda(this=exp.to_identifier("b"), expressions=[exp.to_identifier("b")]),
source=exp.Lambda(
this=exp.to_identifier("a"), expressions=[exp.to_identifier("a")]
),
target=exp.Lambda(
this=exp.to_identifier("b"), expressions=[exp.to_identifier("b")]
),
),
],
)
Expand All @@ -65,8 +70,8 @@ def test_udf(self):
parse_one('SELECT a, b, "my.udf1"()'), parse_one('SELECT a, b, "my.udf2"()')
),
[
Insert(parse_one('"my.udf2"()')),
Remove(parse_one('"my.udf1"()')),
Insert(expression=parse_one('"my.udf2"()')),
Remove(expression=parse_one('"my.udf1"()')),
],
)
self._validate_delta_only(
Expand All @@ -75,8 +80,8 @@ def test_udf(self):
parse_one('SELECT a, b, "my.udf"(x, y, w)'),
),
[
Insert(exp.column("w")),
Remove(exp.column("z")),
Insert(expression=exp.column("w")),
Remove(expression=exp.column("z")),
],
)

Expand Down Expand Up @@ -132,6 +137,19 @@ def test_node_position_changed(self):
],
)

expr_src = parse_one("SELECT a as a, b as b FROM t WHERE CONCAT('a', 'b') = 'ab'")
expr_tgt = parse_one("SELECT a as a FROM t WHERE CONCAT('a', 'b', b) = 'ab'")

b_alias = expr_src.selects[1]

self._validate_delta_only(
diff_delta_only(expr_src, expr_tgt),
[
Remove(expression=b_alias),
Move(source=b_alias.this, target=expr_tgt.find(exp.Concat).expressions[-1]),
],
)

def test_cte(self):
expr_src = """
WITH
Expand All @@ -149,23 +167,30 @@ def test_cte(self):
self._validate_delta_only(
diff_delta_only(parse_one(expr_src), parse_one(expr_tgt)),
[
Remove(parse_one("LOWER(c) AS c")), # the Alias node
Remove(parse_one("LOWER(c)")), # the Lower node
Remove(parse_one("'filter'")), # the Literal node
Insert(parse_one("'different_filter'")), # the Literal node
Remove(expression=parse_one("LOWER(c) AS c")), # the Alias node
Remove(expression=parse_one("LOWER(c)")), # the Lower node
Remove(expression=parse_one("'filter'")), # the Literal node
Insert(expression=parse_one("'different_filter'")), # the Literal node
Move(source=parse_one("c"), target=parse_one("c")), # the new Column c
],
)

def test_join(self):
expr_src = "SELECT a, b FROM t1 LEFT JOIN t2 ON t1.key = t2.key"
expr_tgt = "SELECT a, b FROM t1 RIGHT JOIN t2 ON t1.key = t2.key"
expr_src = parse_one("SELECT a, b FROM t1 LEFT JOIN t2 ON t1.key = t2.key")
expr_tgt = parse_one("SELECT a, b FROM t1 RIGHT JOIN t2 ON t1.key = t2.key")

changes = diff_delta_only(parse_one(expr_src), parse_one(expr_tgt))
src_join = expr_src.find(exp.Join)
tgt_join = expr_tgt.find(exp.Join)

self.assertEqual(len(changes), 2)
self.assertTrue(isinstance(changes[0], Remove))
self.assertTrue(isinstance(changes[1], Insert))
self.assertTrue(all(isinstance(c.expression, Join) for c in changes))
self._validate_delta_only(
diff_delta_only(expr_src, expr_tgt),
[
Remove(expression=src_join),
Insert(expression=tgt_join),
Move(source=exp.to_table("t2"), target=exp.to_table("t2")),
Move(source=src_join.args["on"], target=tgt_join.args["on"]),
],
)

def test_window_functions(self):
expr_src = parse_one("SELECT ROW_NUMBER() OVER (PARTITION BY a ORDER BY b)")
Expand All @@ -176,8 +201,8 @@ def test_window_functions(self):
self._validate_delta_only(
diff_delta_only(expr_src, expr_tgt),
[
Remove(parse_one("ROW_NUMBER()")),
Insert(parse_one("RANK()")),
Remove(expression=parse_one("ROW_NUMBER()")),
Insert(expression=parse_one("RANK()")),
Update(source=expr_src.selects[0], target=expr_tgt.selects[0]),
],
)
Expand All @@ -197,20 +222,21 @@ def test_pre_matchings(self):
self._validate_delta_only(
diff_delta_only(expr_src, expr_tgt),
[
Remove(expr_src),
Insert(expr_tgt),
Insert(exp.Literal.number(2)),
Insert(exp.Literal.number(3)),
Insert(exp.Literal.number(4)),
Remove(expression=expr_src),
Insert(expression=expr_tgt),
Insert(expression=exp.Literal.number(2)),
Insert(expression=exp.Literal.number(3)),
Insert(expression=exp.Literal.number(4)),
Move(source=exp.Literal.number(1), target=exp.Literal.number(1)),
],
)

self._validate_delta_only(
diff_delta_only(expr_src, expr_tgt, matchings=[(expr_src, expr_tgt)]),
[
Insert(exp.Literal.number(2)),
Insert(exp.Literal.number(3)),
Insert(exp.Literal.number(4)),
Insert(expression=exp.Literal.number(2)),
Insert(expression=exp.Literal.number(3)),
Insert(expression=exp.Literal.number(4)),
],
)

Expand Down

0 comments on commit 551afff

Please sign in to comment.