Skip to content

Commit

Permalink
Fix crashes with comments in parentheses (#4453)
Browse files Browse the repository at this point in the history
Co-authored-by: Jelle Zijlstra <[email protected]>
  • Loading branch information
hauntsaninja and JelleZijlstra authored Sep 16, 2024
1 parent b4d6d86 commit 2a45cec
Show file tree
Hide file tree
Showing 6 changed files with 185 additions and 34 deletions.
3 changes: 3 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@

<!-- Changes that affect Black's stable style -->

- Fix crashes involving comments in parenthesised return types or `X | Y` style unions.
(#4453)

### Preview style

<!-- Changes that affect Black's preview style -->
Expand Down
82 changes: 49 additions & 33 deletions src/black/linegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,6 +1079,47 @@ def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None
)


def _ensure_trailing_comma(
leaves: List[Leaf], original: Line, opening_bracket: Leaf
) -> bool:
if not leaves:
return False
# Ensure a trailing comma for imports
if original.is_import:
return True
# ...and standalone function arguments
if not original.is_def:
return False
if opening_bracket.value != "(":
return False
# Don't add commas if we already have any commas
if any(
leaf.type == token.COMMA
and (
Preview.typed_params_trailing_comma not in original.mode
or not is_part_of_annotation(leaf)
)
for leaf in leaves
):
return False

# Find a leaf with a parent (comments don't have parents)
leaf_with_parent = next((leaf for leaf in leaves if leaf.parent), None)
if leaf_with_parent is None:
return True
# Don't add commas inside parenthesized return annotations
if get_annotation_type(leaf_with_parent) == "return":
return False
# Don't add commas inside PEP 604 unions
if (
leaf_with_parent.parent
and leaf_with_parent.parent.next_sibling
and leaf_with_parent.parent.next_sibling.type == token.VBAR
):
return False
return True


def bracket_split_build_line(
leaves: List[Leaf],
original: Line,
Expand All @@ -1099,40 +1140,15 @@ def bracket_split_build_line(
if component is _BracketSplitComponent.body:
result.inside_brackets = True
result.depth += 1
if leaves:
no_commas = (
# Ensure a trailing comma for imports and standalone function arguments
original.is_def
# Don't add one after any comments or within type annotations
and opening_bracket.value == "("
# Don't add one if there's already one there
and not any(
leaf.type == token.COMMA
and (
Preview.typed_params_trailing_comma not in original.mode
or not is_part_of_annotation(leaf)
)
for leaf in leaves
)
# Don't add one inside parenthesized return annotations
and get_annotation_type(leaves[0]) != "return"
# Don't add one inside PEP 604 unions
and not (
leaves[0].parent
and leaves[0].parent.next_sibling
and leaves[0].parent.next_sibling.type == token.VBAR
)
)

if original.is_import or no_commas:
for i in range(len(leaves) - 1, -1, -1):
if leaves[i].type == STANDALONE_COMMENT:
continue
if _ensure_trailing_comma(leaves, original, opening_bracket):
for i in range(len(leaves) - 1, -1, -1):
if leaves[i].type == STANDALONE_COMMENT:
continue

if leaves[i].type != token.COMMA:
new_comma = Leaf(token.COMMA, ",")
leaves.insert(i + 1, new_comma)
break
if leaves[i].type != token.COMMA:
new_comma = Leaf(token.COMMA, ",")
leaves.insert(i + 1, new_comma)
break

leaves_to_track: Set[LeafID] = set()
if component is _BracketSplitComponent.head:
Expand Down
1 change: 1 addition & 0 deletions src/black/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,7 @@ def get_annotation_type(leaf: Leaf) -> Literal["return", "param", None]:

def is_part_of_annotation(leaf: Leaf) -> bool:
"""Returns whether this leaf is part of a type annotation."""
assert leaf.parent is not None
return get_annotation_type(leaf) is not None


Expand Down
2 changes: 1 addition & 1 deletion src/black/trans.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ def do_match(self, line: Line) -> TMatchResult:
break
i += 1

if not is_part_of_annotation(leaf) and not contains_comment:
if not contains_comment and not is_part_of_annotation(leaf):
string_indices.append(idx)

# Advance to the next non-STRING leaf.
Expand Down
1 change: 1 addition & 0 deletions tests/data/cases/funcdef_return_type_trailing_comma.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def SimplePyFn(
Buffer[UInt8, 2],
Buffer[UInt8, 2],
]: ...

# output
# normal, short, function definition
def foo(a, b) -> tuple[int, float]: ...
Expand Down
130 changes: 130 additions & 0 deletions tests/data/cases/function_trailing_comma.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,64 @@ def func() -> ((also_super_long_type_annotation_that_may_cause_an_AST_related_cr
argument1, (one, two,), argument4, argument5, argument6
)

def foo() -> (
# comment inside parenthesised return type
int
):
...

def foo() -> (
# comment inside parenthesised return type
# more
int
# another
):
...

def foo() -> (
# comment inside parenthesised new union return type
int | str | bytes
):
...

def foo() -> (
# comment inside plain tuple
):
pass

def foo(arg: (# comment with non-return annotation
int
# comment with non-return annotation
)):
pass

def foo(arg: (# comment with non-return annotation
int | range | memoryview
# comment with non-return annotation
)):
pass

def foo(arg: (# only before
int
)):
pass

def foo(arg: (
int
# only after
)):
pass

variable: ( # annotation
because
# why not
)

variable: (
because
# why not
)

# output

def f(
Expand Down Expand Up @@ -176,3 +234,75 @@ def func() -> (
argument5,
argument6,
)


def foo() -> (
# comment inside parenthesised return type
int
): ...


def foo() -> (
# comment inside parenthesised return type
# more
int
# another
): ...


def foo() -> (
# comment inside parenthesised new union return type
int
| str
| bytes
): ...


def foo() -> (
# comment inside plain tuple
):
pass


def foo(
arg: ( # comment with non-return annotation
int
# comment with non-return annotation
),
):
pass


def foo(
arg: ( # comment with non-return annotation
int
| range
| memoryview
# comment with non-return annotation
),
):
pass


def foo(arg: int): # only before
pass


def foo(
arg: (
int
# only after
),
):
pass


variable: ( # annotation
because
# why not
)

variable: (
because
# why not
)

0 comments on commit 2a45cec

Please sign in to comment.