diff --git a/libcst/_nodes/expression.py b/libcst/_nodes/expression.py index 89a9d806b..1a90a5578 100644 --- a/libcst/_nodes/expression.py +++ b/libcst/_nodes/expression.py @@ -656,14 +656,20 @@ def quote(self) -> StringQuoteLiteral: if len(quote) == 2: # Let's assume this is an empty string. quote = quote[:1] - elif len(quote) == 6: - # Let's assume this is an empty triple-quoted string. + elif 3 < len(quote) <= 6: + # Let's assume this can be one of the following: + # >>> """"foo""" + # '"foo' + # >>> """""bar""" + # '""bar' + # >>> """""" + # '' quote = quote[:3] if len(quote) not in {1, 3}: # We shouldn't get here due to construction validation logic, # but handle the case anyway. - raise Exception("Invalid string {self.value}") + raise Exception(f"Invalid string {self.value}") # pyre-ignore We know via the above validation that we will only # ever return one of the four string literals. diff --git a/libcst/_nodes/tests/test_simple_string.py b/libcst/_nodes/tests/test_simple_string.py new file mode 100644 index 000000000..ae020bd25 --- /dev/null +++ b/libcst/_nodes/tests/test_simple_string.py @@ -0,0 +1,31 @@ +import unittest +import libcst as cst + + +class TestSimpleString(unittest.TestCase): + + def test_quote(self) -> None: + test_cases = [ + ('"a"', '"'), + ("'b'", "'"), + + ('""', '"'), + ("''", "'"), + + ('"""c"""', '"""'), + ("'''d'''", "'''"), + + ('""""e"""', '"""'), + ("''''f'''", "'''"), + + ('"""""g"""', '"""'), + ("'''''h'''", "'''"), + + ('""""""', '"""'), + ("''''''", "'''"), + ] + + for s, expected_quote in test_cases: + simple_string = cst.SimpleString(s) + actual = simple_string.quote + self.assertEqual(expected_quote, actual)