Skip to content

Commit

Permalink
Add string normalization
Browse files Browse the repository at this point in the history
  • Loading branch information
tusharsadhwani committed Sep 23, 2023
1 parent fd3e5e1 commit 0c69069
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 2 deletions.
25 changes: 25 additions & 0 deletions src/black/linegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
from black.strings import (
fix_docstring,
get_string_prefix,
normalize_fstring_quotes,
normalize_string_prefix,
normalize_string_quotes,
normalize_unicode_escape_sequences,
Expand Down Expand Up @@ -480,6 +481,30 @@ def visit_STRING(self, leaf: Leaf) -> Iterator[Line]:

yield from self.visit_default(leaf)

def visit_fstring(self, node: Node) -> Iterator[Line]:
"""Bunch of hacks here. Needs improvement."""
fstring_start = node.children[0]
fstring_end = node.children[-1]

quote_char = fstring_end.value[0]
quote_idx = fstring_start.value.index(quote_char)
prefix, quote = fstring_start.value[:quote_idx], fstring_start.value[quote_idx:]
assert 'f' in prefix or 'F' in prefix
assert quote == fstring_end.value

is_raw_fstring = 'r' in prefix or 'R' in prefix
middles = [node for node in node.children if node.type == token.FSTRING_MIDDLE]
# if ''.join(m.value for m in middles) == 'foo':
# breakpoint()

if self.mode.string_normalization:
middles, quote = normalize_fstring_quotes(quote, middles, is_raw_fstring)

fstring_start.value = prefix + quote
fstring_end.value = quote

yield from self.visit_default(node)

def __post_init__(self) -> None:
"""You are in a twisty little maze of passages."""
self.current_line = Line(mode=self.mode)
Expand Down
69 changes: 67 additions & 2 deletions src/black/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,7 @@ def _cached_compile(pattern: str) -> Pattern[str]:
def normalize_string_quotes(s: str) -> str:
"""Prefer double quotes but only if it doesn't cause more escaping.
Adds or removes backslashes as appropriate. Doesn't parse and fix
strings nested in f-strings.
Adds or removes backslashes as appropriate.
"""
value = s.lstrip(STRING_PREFIX_CHARS)
if value[:3] == '"""':
Expand Down Expand Up @@ -215,6 +214,7 @@ def normalize_string_quotes(s: str) -> str:
s = f"{prefix}{orig_quote}{body}{orig_quote}"
new_body = sub_twice(escaped_orig_quote, rf"\1\2{orig_quote}", new_body)
new_body = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_body)
# TODO: can probably be removed
if "f" in prefix.casefold():
matches = re.findall(
r"""
Expand Down Expand Up @@ -243,6 +243,71 @@ def normalize_string_quotes(s: str) -> str:

return f"{prefix}{new_quote}{new_body}{new_quote}"

def normalize_fstring_quotes(
quote: str,
middles: list[str],
is_raw_fstring: bool
) -> tuple[str, str]:
"""Prefer double quotes but only if it doesn't cause more escaping.
Adds or removes backslashes as appropriate.
"""
if quote == '"""':
return middles, quote

elif quote == "'''":
new_quote = '"""'
elif quote == '"':
new_quote = "'"
else:
new_quote = '"'

unescaped_new_quote = _cached_compile(rf"(([^\\]|^)(\\\\)*){new_quote}")
escaped_new_quote = _cached_compile(rf"([^\\]|^)\\((?:\\\\)*){new_quote}")
escaped_orig_quote = _cached_compile(rf"([^\\]|^)\\((?:\\\\)*){quote}")
if is_raw_fstring:
for middle in middles:
if unescaped_new_quote.search(middle.value):
# There's at least one unescaped new_quote in this raw string
# so converting is impossible
return middles, quote

# Do not introduce or remove backslashes in raw strings
return middles, new_quote

new_segments = []
for middle in middles:
segment = middle.value
# remove unnecessary escapes
new_segment = sub_twice(escaped_new_quote, rf"\1\2{new_quote}", segment)
if segment != new_segment:
# Consider the string without unnecessary escapes as the original
middle.value = new_segment

new_segment = sub_twice(escaped_orig_quote, rf"\1\2{quote}", new_segment)
new_segment = sub_twice(unescaped_new_quote, rf"\1\\{new_quote}", new_segment)
new_segments.append(new_segment)


if new_quote == '"""' and new_segments[-1][-1:] == '"':
# edge case:
new_segments[-1] = new_segments[-1][:-1] + '\\"'

for middle, new_segment in zip(middles, new_segments):
orig_escape_count = middle.value.count("\\")
new_escape_count = new_segment.count("\\")

if new_escape_count > orig_escape_count:
return middles, quote # Do not introduce more escaping

if new_escape_count == orig_escape_count and quote == '"':
return middles, quote # Prefer double quotes

for middle, new_segment in zip(middles, new_segments):
middle.value = new_segment

return middles, new_quote


def normalize_unicode_escape_sequences(leaf: Leaf) -> None:
"""Replace hex codes in Unicode escape sequences with lowercase representation."""
Expand Down

0 comments on commit 0c69069

Please sign in to comment.