Skip to content

Commit

Permalink
nonglobal rules
Browse files Browse the repository at this point in the history
  • Loading branch information
Бекчурин Владислав committed Oct 27, 2020
1 parent e391e56 commit 8d51874
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 80 deletions.
133 changes: 72 additions & 61 deletions test_unify.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,128 +21,138 @@
class TestUnitsSimpleString(unittest.TestCase):

def test_preferred_single(self):
unify.rules['preferred_quote'] = "'"
rules = {'preferred_quote': "'"}

result = unify.format_code('"foo"')
result = unify.format_code('"foo"', rules)
self.assertEqual(result, "'foo'")

result = unify.format_code('f"foo"')
result = unify.format_code('f"foo"', rules)
self.assertEqual(result, "f'foo'")

result = unify.format_code('r"foo"')
result = unify.format_code('r"foo"', rules)
self.assertEqual(result, "r'foo'")

result = unify.format_code('u"foo"')
result = unify.format_code('u"foo"', rules)
self.assertEqual(result, "u'foo'")

result = unify.format_code('b"foo"')
result = unify.format_code('b"foo"', rules)
self.assertEqual(result, "b'foo'")

def test_preferred_double(self):
unify.rules['preferred_quote'] = '"'
rules = {'preferred_quote': '"'}

result = unify.format_code("'foo'")
result = unify.format_code("'foo'", rules)
self.assertEqual(result, '"foo"')

result = unify.format_code("f'foo'")
result = unify.format_code("f'foo'", rules)
self.assertEqual(result, 'f"foo"')

result = unify.format_code("r'foo'")
result = unify.format_code("r'foo'", rules)
self.assertEqual(result, 'r"foo"')

result = unify.format_code("u'foo'")
result = unify.format_code("u'foo'", rules)
self.assertEqual(result, 'u"foo"')

result = unify.format_code("b'foo'")
result = unify.format_code("b'foo'", rules)
self.assertEqual(result, 'b"foo"')

def test_keep_single(self):
unify.rules['preferred_quote'] = "'"
result = unify.format_code("'foo'")
rules = {'preferred_quote': "'"}
result = unify.format_code("'foo'", rules)
self.assertEqual(result, "'foo'")

def test_keep_double(self):
unify.rules['preferred_quote'] = '"'
result = unify.format_code('"foo"')
rules = {'preferred_quote': '"'}
result = unify.format_code('"foo"', rules)
self.assertEqual(result, '"foo"')


class TestUnitsSimpleQuotedString(unittest.TestCase):

def test_opposite(self):
unify.rules['preferred_quote'] = "'"
unify.rules['escape_simple'] = 'opposite'
rules = {
'preferred_quote': "'",
'escape_simple': 'opposite',
}

result = unify.format_code('''"foo's"''')
result = unify.format_code('''"foo's"''', rules)
self.assertEqual(result, '''"foo's"''')

result = unify.format_code("""'foo"s'""")
result = unify.format_code("""'foo"s'""", rules)
self.assertEqual(result, """'foo"s'""")

result = unify.format_code('''"foo\\"s"''')
result = unify.format_code('''"foo\\"s"''', rules)
self.assertEqual(result, """'foo"s'""")

result = unify.format_code("""'foo\\'s'""")
result = unify.format_code("""'foo\\'s'""", rules)
self.assertEqual(result, '''"foo's"''')

def test_backslash(self):
unify.rules['preferred_quote'] = "'"
unify.rules['escape_simple'] = 'backslash'
rules = {
'preferred_quote': "'",
'escape_simple': 'backslash',
}

result = unify.format_code('''"foo's"''')
result = unify.format_code('''"foo's"''', rules)
self.assertEqual(result, """'foo\\'s'""")

result = unify.format_code("""'foo"s'""")
result = unify.format_code("""'foo"s'""", rules)
self.assertEqual(result, """'foo"s'""")

result = unify.format_code('''"foo\\"s"''')
result = unify.format_code('''"foo\\"s"''', rules)
self.assertEqual(result, """'foo"s'""")

result = unify.format_code("""'foo\\'s'""")
result = unify.format_code("""'foo\\'s'""", rules)
self.assertEqual(result, """'foo\\'s'""")

def test_keep_unformatted(self):
unify.rules['preferred_quote'] = "'"
unify.rules['escape_simple'] = 'opposite'
rules = {
'preferred_quote': "'",
'escape_simple': 'opposite',
}

result = unify.format_code('''f"foo's{some_var}"''')
result = unify.format_code('''f"foo's{some_var}"''', rules)
self.assertEqual(result, '''f"foo's{some_var}"''')

result = unify.format_code("""r'foo\\'s'""")
result = unify.format_code("""r'foo\\'s'""", rules)
self.assertEqual(result, """r'foo\\'s'""")

def test_backslash_train(self):
unify.rules['preferred_quote'] = "'"
unify.rules['escape_simple'] = 'opposite'
rules = {
'preferred_quote': "'",
'escape_simple': 'opposite',
}

result = unify.format_code('''"a'b\\'c\\\\'d\\\\\\'e\\\\\\\\'f"''',
rules)

result = unify.format_code('''"a'b\\'c\\\\'d\\\\\\'e\\\\\\\\'f"''')
self.assertEqual(result, '''"a'b'c\\\\'d\\\\'e\\\\\\\\'f"''')

result = unify.format_code('''"\\'a"''')
result = unify.format_code('''"\\'a"''', rules)
self.assertEqual(result, '''"'a"''')

result = unify.format_code('''"\\\\'a"''')
result = unify.format_code('''"\\\\'a"''', rules)
self.assertEqual(result, '''"\\\\'a"''')


class TestUnitsTripleQuote(unittest.TestCase):

def test_no_change(self):
unify.rules['preferred_quote'] = "'"
rules = {'preferred_quote': "'"}

result = unify.format_code('''"""foo"""''')
result = unify.format_code('''"""foo"""''', rules)
self.assertEqual(result, '''"""foo"""''')

result = unify.format_code('''f"""foo"""''')
result = unify.format_code('''f"""foo"""''', rules)
self.assertEqual(result, '''f"""foo"""''')

result = unify.format_code('''r"""\\t"""''')
result = unify.format_code('''r"""\\t"""''', rules)
self.assertEqual(result, '''r"""\\t"""''')

result = unify.format_code('''u"""foo"""''')
result = unify.format_code('''u"""foo"""''', rules)
self.assertEqual(result, '''u"""foo"""''')

result = unify.format_code('''b"""foo"""''')
result = unify.format_code('''b"""foo"""''', rules)
self.assertEqual(result, '''b"""foo"""''')


Expand All @@ -153,58 +163,59 @@ def test_detect_encoding_with_bad_encoding(self):
self.assertEqual('latin-1', unify.detect_encoding(filename))

def test_format_code(self):
unify.rules['preferred_quote'] = "'"
rules = {'preferred_quote': "'"}

self.assertEqual("x = 'abc' \\\n'next'\n",
unify.format_code('x = "abc" \\\n"next"\n'))
unify.format_code('x = "abc" \\\n"next"\n', rules))

self.assertEqual("x = f'abc' \\\nf'next'\n",
unify.format_code('x = f"abc" \\\nf"next"\n'))
unify.format_code('x = f"abc" \\\nf"next"\n', rules))

self.assertEqual("x = u'abc' \\\nu'next'\n",
unify.format_code('x = u"abc" \\\nu"next"\n'))
unify.format_code('x = u"abc" \\\nu"next"\n', rules))

self.assertEqual("x = b'abc' \\\nb'next'\n",
unify.format_code('x = b"abc" \\\nb"next"\n'))
unify.format_code('x = b"abc" \\\nb"next"\n', rules))

def test_format_code_with_backslash_in_comment(self):
unify.rules['preferred_quote'] = "'"
rules = {'preferred_quote': "'"}

self.assertEqual("x = 'abc' #\\\n'next'\n",
unify.format_code('x = "abc" #\\\n"next"\n'))
unify.format_code('x = "abc" #\\\n"next"\n', rules))

self.assertEqual("x = f'abc' #\\\nf'next'\n",
unify.format_code('x = f"abc" #\\\nf"next"\n'))
unify.format_code('x = f"abc" #\\\nf"next"\n', rules))

self.assertEqual("x = r'abc' #\\\nr'next'\n",
unify.format_code('x = r"abc" #\\\nr"next"\n'))
unify.format_code('x = r"abc" #\\\nr"next"\n', rules))

self.assertEqual("x = r'abc' \\\nr'next'\n",
unify.format_code('x = r"abc" \\\nr"next"\n'))
unify.format_code('x = r"abc" \\\nr"next"\n', rules))

self.assertEqual("x = u'abc' #\\\nu'next'\n",
unify.format_code('x = u"abc" #\\\nu"next"\n'))
unify.format_code('x = u"abc" #\\\nu"next"\n', rules))

self.assertEqual("x = b'abc' #\\\nb'next'\n",
unify.format_code('x = b"abc" #\\\nb"next"\n'))
unify.format_code('x = b"abc" #\\\nb"next"\n', rules))

def test_format_code_with_syntax_error(self):
unify.rules['preferred_quote'] = "'"
rules = {'preferred_quote': "'"}

self.assertEqual('foo("abc"\n',
unify.format_code('foo("abc"\n'))
unify.format_code('foo("abc"\n', rules))

self.assertEqual('foo(f"abc"\n',
unify.format_code('foo(f"abc"\n'))
unify.format_code('foo(f"abc"\n', rules))

self.assertEqual('foo(r"Tabs \t, new lines \n."\n',
unify.format_code('foo(r"Tabs \t, new lines \n."\n'))
unify.format_code('foo(r"Tabs \t, new lines \n."\n',
rules))

self.assertEqual('foo(u"abc"\n',
unify.format_code('foo(u"abc"\n'))
unify.format_code('foo(u"abc"\n', rules))

self.assertEqual('foo(b"abc"\n',
unify.format_code('foo(b"abc"\n'))
unify.format_code('foo(b"abc"\n', rules))


class TestSystem(unittest.TestCase):
Expand Down
36 changes: 17 additions & 19 deletions unify.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,12 @@
unicode = str


# dict with transform rules
rules = {}


class AbstractString:
"""Interface to transform strings."""
__metaclass__ = ABCMeta

@abstractmethod
def reformat(self): pass
def reformat(self, rules): pass

@abstractproperty
def token(self): pass
Expand All @@ -75,7 +71,7 @@ class ImmutableString(AbstractString):
def __init__(self, body):
self.body = body

def reformat(self): pass
def reformat(self, rules): pass

@property
def token(self):
Expand All @@ -100,7 +96,7 @@ def __init__(self, prefix, quote, body):
self.old_prefix = prefix
self.old_quote = quote

def reformat(self):
def reformat(self, rules):
preferred_quote = rules['preferred_quote']
self.quote = preferred_quote

Expand Down Expand Up @@ -133,7 +129,7 @@ def __init__(self, prefix, quote, body):
self.old_quote = quote
self.old_body = body

def reformat(self):
def reformat(self, rules):
preferred_quote = rules['preferred_quote']
escape_simple = rules['escape_simple']
quote_in_body = "'" if "'" in self.body else '"'
Expand Down Expand Up @@ -184,25 +180,25 @@ class SimpleEscapeFstring(SimpleEscapeString):
Use escape_simple and preferred_quote rules.
"""

def reformat(self):
def reformat(self, rules):
if any(br in self.body for br in '{}'):
# don't transform since can't use backslashes in bracket area
# TODO add body parsing and handle this case
return

# can treat this case as simple escape
super().reformat()
super().reformat(rules)


def format_code(source):
def format_code(source, rules):
"""Return source code with quotes unified."""
try:
return _format_code(source)
return _format_code(source, rules)
except (tokenize.TokenError, IndentationError):
return source


def _format_code(source):
def _format_code(source, rules):
"""Return source code with quotes unified."""
if not source:
return source
Expand All @@ -217,7 +213,7 @@ def _format_code(source):
line) in tokenize.generate_tokens(sio.readline):

editable_string = get_editable_string(token_type, token_string)
editable_string.reformat()
editable_string.reformat(rules)
token_string = editable_string.token

modified_tokens.append((token_type, token_string, start, end, line))
Expand Down Expand Up @@ -278,7 +274,7 @@ def detect_encoding(filename):
return 'latin-1'


def format_file(filename, args, standard_out):
def format_file(filename, args, standard_out, rules):
"""Run format_code() on a file.
Returns `True` if any changes are needed and they are not being done
Expand All @@ -288,7 +284,7 @@ def format_file(filename, args, standard_out):
encoding = detect_encoding(filename)
with open_with_encoding(filename, encoding=encoding) as input_file:
source = input_file.read()
formatted_source = format_code(source)
formatted_source = format_code(source, rules)

if source != formatted_source:
if args.in_place:
Expand Down Expand Up @@ -339,8 +335,10 @@ def _main(argv, standard_out, standard_error):

args = parser.parse_args(argv[1:])

rules['preferred_quote'] = args.quote
rules['escape_simple'] = args.escape_simple
rules = {
'preferred_quote': args.quote,
'escape_simple': args.escape_simple,
}
filenames = list(set(args.files))
changes_needed = False
failure = False
Expand All @@ -358,7 +356,7 @@ def _main(argv, standard_out, standard_error):
]
else:
try:
if format_file(name, args=args, standard_out=standard_out):
if format_file(name, args=args, standard_out=standard_out, rules=rules):
changes_needed = True
except IOError as exception:
print(unicode(exception), file=standard_error)
Expand Down

0 comments on commit 8d51874

Please sign in to comment.