diff --git a/stdlib/src/builtin/string_literal.mojo b/stdlib/src/builtin/string_literal.mojo
index fa698860e0..eeb3c87672 100644
--- a/stdlib/src/builtin/string_literal.mojo
+++ b/stdlib/src/builtin/string_literal.mojo
@@ -23,7 +23,7 @@ from utils import StringRef, Span, StringSlice
 from utils import Formattable, Formatter
 from utils._visualizers import lldb_formatter_wrapping_type
 
-from collections.string import _atol
+from collections.string import _atol, _StringSliceIter
 
 # ===----------------------------------------------------------------------===#
 # StringLiteral
@@ -273,6 +273,30 @@ struct StringLiteral(
         """
         return self.__str__()
 
+    fn __iter__(ref [_]self) -> _StringSliceIter[__lifetime_of(self)]:
+        """Return an iterator over the string literal.
+
+        Returns:
+            An iterator over the string.
+        """
+        return _StringSliceIter[__lifetime_of(self)](
+            unsafe_pointer=self.unsafe_ptr(), length=self.byte_length()
+        )
+
+    fn __getitem__[IndexerType: Indexer](self, idx: IndexerType) -> String:
+        """Gets the character at the specified position.
+
+        Parameters:
+            IndexerType: The inferred type of an indexer argument.
+
+        Args:
+            idx: The index value.
+
+        Returns:
+            A new string containing the character at the specified position.
+        """
+        return str(self)[idx]
+
     # ===-------------------------------------------------------------------===#
     # Methods
     # ===-------------------------------------------------------------------===#
@@ -417,6 +441,88 @@ struct StringLiteral(
 
         return result
 
+    fn split(self, sep: String, maxsplit: Int = -1) raises -> List[String]:
+        """Split the string literal by a separator.
+
+        Args:
+            sep: The string to split on.
+            maxsplit: The maximum amount of items to split from String.
+                Defaults to unlimited.
+
+        Returns:
+            A List of Strings containing the input split by the separator.
+
+        Examples:
+
+        ```mojo
+        # Splitting a space
+        _ = "hello world".split(" ") # ["hello", "world"]
+        # Splitting adjacent separators
+        _ = "hello,,world".split(",") # ["hello", "", "world"]
+        # Splitting with maxsplit
+        _ = "1,2,3".split(",", 1) # ['1', '2,3']
+        ```
+        .
+        """
+        return str(self).split(sep, maxsplit)
+
+    fn split(self, sep: NoneType = None, maxsplit: Int = -1) -> List[String]:
+        """Split the string literal by every whitespace separator.
+
+        Args:
+            sep: None.
+            maxsplit: The maximum amount of items to split from string. Defaults
+                to unlimited.
+
+        Returns:
+            A List of Strings containing the input split by the separator.
+
+        Examples:
+
+        ```mojo
+        # Splitting an empty string or filled with whitespaces
+        _ = " ".split() # []
+        _ = "".split() # []
+
+        # Splitting a string with leading, trailing, and middle whitespaces
+        _ = " hello world ".split() # ["hello", "world"]
+        # Splitting adjacent universal newlines:
+        _ = "hello \\t\\n\\r\\f\\v\\x1c\\x1d\\x1e\\x85\\u2028\\u2029world".split()
+        # ["hello", "world"]
+        ```
+        .
+        """
+        return str(self).split(sep, maxsplit)
+
+    fn splitlines(self, keepends: Bool = False) -> List[String]:
+        """Split the string literal at line boundaries. This corresponds to Python's
+        [universal newlines](
+            https://docs.python.org/3/library/stdtypes.html#str.splitlines)
+        `"\\t\\n\\r\\r\\n\\f\\v\\x1c\\x1d\\x1e\\x85\\u2028\\u2029"`.
+
+        Args:
+            keepends: If True, line breaks are kept in the resulting strings.
+
+        Returns:
+            A List of Strings containing the input split by line boundaries.
+        """
+        return self.as_string_slice().splitlines(keepends)
+
+    fn count(self, substr: String) -> Int:
+        """Return the number of non-overlapping occurrences of substring
+        `substr` in the string literal.
+
+        If `substr` is empty, returns the number of empty strings between
+        characters, which is the length of the string plus one.
+
+        Args:
+            substr: The substring to count.
+
+        Returns:
+            The number of occurrences of `substr`.
+        """
+        return str(self).count(substr)
+
     fn lower(self) -> String:
         """Returns a copy of the string literal with all cased characters
         converted to lowercase.
@@ -436,3 +542,161 @@ struct StringLiteral(
         """
 
         return str(self).upper()
+
+    fn rjust(self, width: Int, fillchar: StringLiteral = " ") -> String:
+        """Returns the string literal right justified in a string of specified width.
+
+        Args:
+            width: The width of the field containing the string.
+            fillchar: Specifies the padding character.
+
+        Returns:
+            Returns right justified string, or self if width is not bigger than self length.
+        """
+        return str(self).rjust(width, fillchar)
+
+    fn ljust(self, width: Int, fillchar: StringLiteral = " ") -> String:
+        """Returns the string literal left justified in a string of specified width.
+
+        Args:
+            width: The width of the field containing the string.
+            fillchar: Specifies the padding character.
+
+        Returns:
+            Returns left justified string, or self if width is not bigger than self length.
+        """
+        return str(self).ljust(width, fillchar)
+
+    fn center(self, width: Int, fillchar: StringLiteral = " ") -> String:
+        """Returns the string literal center justified in a string of specified width.
+
+        Args:
+            width: The width of the field containing the string.
+            fillchar: Specifies the padding character.
+
+        Returns:
+            Returns center justified string, or self if width is not bigger than self length.
+        """
+        return str(self).center(width, fillchar)
+
+    fn startswith(self, prefix: String, start: Int = 0, end: Int = -1) -> Bool:
+        """Checks if the string literal starts with the specified prefix between start
+        and end positions. Returns True if found and False otherwise.
+
+        Args:
+            prefix: The prefix to check.
+            start: The start offset from which to check.
+            end: The end offset up to which to check.
+
+        Returns:
+            True if the self[start:end] is prefixed by the input prefix.
+        """
+        return str(self).startswith(prefix, start, end)
+
+    fn endswith(self, suffix: String, start: Int = 0, end: Int = -1) -> Bool:
+        """Checks if the string literal ends with the specified suffix between start
+        and end positions. Returns True if found and False otherwise.
+
+        Args:
+            suffix: The suffix to check.
+            start: The start offset from which to check.
+            end: The end offset up to which to check.
+
+        Returns:
+            True if the self[start:end] is suffixed by the input suffix.
+        """
+        return str(self).endswith(suffix, start, end)
+
+    fn isdigit(self) -> Bool:
+        """Returns True if all characters in the string literal are digits.
+
+        Note that this currently only works with ASCII strings.
+
+        Returns:
+            True if all characters are digits else False.
+        """
+        return str(self).isdigit()
+
+    fn isupper(self) -> Bool:
+        """Returns True if all cased characters in the string literal are
+        uppercase and there is at least one cased character.
+
+        Note that this currently only works with ASCII strings.
+
+        Returns:
+            True if all cased characters in the string literal are uppercase
+            and there is at least one cased character, False otherwise.
+        """
+        return str(self).isupper()
+
+    fn islower(self) -> Bool:
+        """Returns True if all cased characters in the string literal
+        are lowercase and there is at least one cased character.
+
+        Note that this currently only works with ASCII strings.
+
+        Returns:
+            True if all cased characters in the string literal are lowercase
+            and there is at least one cased character, False otherwise.
+        """
+        return str(self).islower()
+
+    fn strip(self) -> String:
+        """Return a copy of the string literal with leading and trailing whitespaces
+        removed.
+
+        Returns:
+            A string with no leading or trailing whitespaces.
+        """
+        return self.lstrip().rstrip()
+
+    fn strip(self, chars: String) -> String:
+        """Return a copy of the string literal with leading and trailing characters
+        removed.
+
+        Args:
+            chars: A set of characters to be removed.
+
+        Returns:
+            A string with no leading or trailing characters.
+        """
+
+        return self.lstrip(chars).rstrip(chars)
+
+    fn rstrip(self, chars: String) -> String:
+        """Return a copy of the string literal with trailing characters removed.
+
+        Args:
+            chars: A set of characters to be removed.
+
+        Returns:
+            A string with no trailing characters.
+        """
+        return str(self).rstrip(chars)
+
+    fn rstrip(self) -> String:
+        """Return a copy of the string literal with trailing whitespaces removed.
+
+        Returns:
+            A copy of the string with no trailing whitespaces.
+        """
+        return str(self).rstrip()
+
+    fn lstrip(self, chars: String) -> String:
+        """Return a copy of the string literal with leading characters removed.
+
+        Args:
+            chars: A set of characters to be removed.
+
+        Returns:
+            A copy of the string with no leading characters.
+        """
+        return str(self).lstrip(chars)
+
+    fn lstrip(self) -> String:
+        """Return a copy of the string literal with leading whitespaces removed.
+
+        Returns:
+            A copy of the string with no leading whitespaces.
+        """
+        return str(self).lstrip()
diff --git a/stdlib/test/builtin/test_string_literal.mojo b/stdlib/test/builtin/test_string_literal.mojo
index c4960d31f8..3199aa875d 100644
--- a/stdlib/test/builtin/test_string_literal.mojo
+++ b/stdlib/test/builtin/test_string_literal.mojo
@@ -105,6 +105,30 @@ def test_replace():
     )
 
 
+def test_startswith():
+    var str = "Hello world"
+
+    assert_true(str.startswith("Hello"))
+    assert_false(str.startswith("Bye"))
+
+    assert_true(str.startswith("llo", 2))
+    assert_true(str.startswith("llo", 2, -1))
+    assert_false(str.startswith("llo", 2, 3))
+
+
+def test_endswith():
+    var str = "Hello world"
+
+    assert_true(str.endswith(""))
+    assert_true(str.endswith("world"))
+    assert_true(str.endswith("ld"))
+    assert_false(str.endswith("universe"))
+
+    assert_true(str.endswith("ld", 2))
+    assert_true(str.endswith("llo", 2, 5))
+    assert_false(str.endswith("llo", 2, 3))
+
+
 def test_comparison_operators():
     # Test less than and greater than
     assert_true(StringLiteral.__lt__("abc", "def"))
@@ -142,6 +166,13 @@ def test_hash():
     assert_equal(StringLiteral.__hash__("b"), StringLiteral.__hash__("b"))
 
 
+def test_indexing():
+    var s = "hello"
+    assert_equal(s[False], "h")
+    assert_equal(s[int(1)], "e")
+    assert_equal(s[2], "l")
+
+
 def test_intable():
     assert_equal(StringLiteral.__int__("123"), 123)
 
@@ -149,6 +180,46 @@ def test_intable():
         _ = StringLiteral.__int__("hi")
 
 
+def test_isdigit():
+    assert_true("123".isdigit())
+    assert_false("abc".isdigit())
+    assert_false("123abc".isdigit())
+    # TODO: Uncomment this when PR3439 is merged
+    # assert_false("".isdigit())
+
+
+def test_islower():
+    assert_true("hello".islower())
+    assert_false("Hello".islower())
+    assert_false("HELLO".islower())
+    assert_false("123".islower())
+    assert_false("".islower())
+
+
+def test_isupper():
+    assert_true("HELLO".isupper())
+    assert_false("Hello".isupper())
+    assert_false("hello".isupper())
+    assert_false("123".isupper())
+    assert_false("".isupper())
+
+
+def test_iter():
+    # Test that iterating over a string literal yields its characters in order
+    var s = "one"
+    var i = 0
+    for c in s:
+        if i == 0:
+            assert_equal(c, "o")
+        elif i == 1:
+            assert_equal(c, "n")
+        elif i == 2:
+            assert_equal(c, "e")
+        i += 1
+    # The loop must have visited every character exactly once
+    assert_equal(i, 3)
+
+
 def test_layout():
     # Test empty StringLiteral contents
     var empty = "".unsafe_ptr()
@@ -196,6 +267,178 @@ def test_repr():
     assert_equal(StringLiteral.__repr__("\x7f"), r"'\x7f'")
 
 
+def test_strip():
+    assert_equal("".strip(), "")
+    assert_equal(" ".strip(), "")
+    assert_equal(" hello".strip(), "hello")
+    assert_equal("hello ".strip(), "hello")
+    assert_equal(" hello ".strip(), "hello")
+    assert_equal(" hello world ".strip(" "), "hello world")
+    assert_equal("_wrap_hello world_wrap_".strip("_wrap_"), "hello world")
+    assert_equal(" hello world ".strip(" "), "hello world")
+    assert_equal(" hello world ".lstrip(), "hello world ")
+    assert_equal(" hello world ".rstrip(), " hello world")
+    assert_equal(
+        "_wrap_hello world_wrap_".lstrip("_wrap_"), "hello world_wrap_"
+    )
+    assert_equal(
+        "_wrap_hello world_wrap_".rstrip("_wrap_"), "_wrap_hello world"
+    )
+
+
+def test_count():
+    var str = "Hello world"
+
+    assert_equal(12, str.count(""))
+    assert_equal(1, str.count("Hell"))
+    assert_equal(3, str.count("l"))
+    assert_equal(1, str.count("ll"))
+    assert_equal(1, str.count("ld"))
+    assert_equal(0, str.count("universe"))
+
+    assert_equal("aaaaa".count("a"), 5)
+    assert_equal("aaaaaa".count("aa"), 3)
+
+
+def test_rjust():
+    assert_equal("hello".rjust(4), "hello")
+    assert_equal("hello".rjust(8), "   hello")
+    assert_equal("hello".rjust(8, "*"), "***hello")
+
+
+def test_ljust():
+    assert_equal("hello".ljust(4), "hello")
+    assert_equal("hello".ljust(8), "hello   ")
+    assert_equal("hello".ljust(8, "*"), "hello***")
+
+
+def test_center():
+    assert_equal("hello".center(4), "hello")
+    assert_equal("hello".center(8), " hello  ")
+    assert_equal("hello".center(8, "*"), "*hello**")
+
+
+def test_split():
+    var d = "hello world".split()
+    assert_true(len(d) == 2)
+    assert_true(d[0] == "hello")
+    assert_true(d[1] == "world")
+    d = "hello \t\n\n\v\fworld".split("\n")
+    assert_true(len(d) == 3)
+    assert_true(d[0] == "hello \t" and d[1] == "" and d[2] == "\v\fworld")
+
+    # should split into empty strings between separators
+    d = "1,,,3".split(",")
+    assert_true(len(d) == 4)
+    assert_true(d[0] == "1" and d[1] == "" and d[2] == "" and d[3] == "3")
+    d = "abababaaba".split("aba")
+    assert_true(len(d) == 4)
+    assert_true(d[0] == "" and d[1] == "b" and d[2] == "" and d[3] == "")
+
+    # should split into maxsplit + 1 items
+    d = "1,2,3".split(",", 0)
+    assert_true(len(d) == 1)
+    assert_true(d[0] == "1,2,3")
+    d = "1,2,3".split(",", 1)
+    assert_true(len(d) == 2)
+    assert_true(d[0] == "1" and d[1] == "2,3")
+
+    assert_true(len("".split()) == 0)
+    assert_true(len(" ".split()) == 0)
+    assert_true(len("".split(" ")) == 1)
+    assert_true(len(" ".split(" ")) == 2)
+    assert_true(len("  ".split(" ")) == 3)
+    assert_true(len("   ".split(" ")) == 4)
+
+    with assert_raises():
+        _ = "".split("")
+
+    # Matches should be properly split in multiple case
+    var d2 = " "
+    var in2 = "modcon is coming soon"
+    var res2 = in2.split(d2)
+    assert_equal(len(res2), 4)
+    assert_equal(res2[0], "modcon")
+    assert_equal(res2[1], "is")
+    assert_equal(res2[2], "coming")
+    assert_equal(res2[3], "soon")
+
+    # No match from the delimiter
+    var d3 = "x"
+    var in3 = "hello world"
+    var res3 = in3.split(d3)
+    assert_equal(len(res3), 1)
+    assert_equal(res3[0], "hello world")
+
+    # Multiple character delimiter
+    var d4 = "ll"
+    var in4 = "hello"
+    var res4 = in4.split(d4)
+    assert_equal(len(res4), 2)
+    assert_equal(res4[0], "he")
+    assert_equal(res4[1], "o")
+
+
+def test_splitlines():
+    # Test with no line breaks
+    var in1 = "hello world"
+    var res1 = in1.splitlines()
+    assert_equal(len(res1), 1)
+    assert_equal(res1[0], "hello world")
+
+    # Test with \n line break
+    var in2 = "hello\nworld"
+    var res2 = in2.splitlines()
+    assert_equal(len(res2), 2)
+    assert_equal(res2[0], "hello")
+    assert_equal(res2[1], "world")
+
+    # Test with \r\n line break
+    var in3 = "hello\r\nworld"
+    var res3 = in3.splitlines()
+    assert_equal(len(res3), 2)
+    assert_equal(res3[0], "hello")
+    assert_equal(res3[1], "world")
+
+    # Test with \r line break
+    var in4 = "hello\rworld"
+    var res4 = in4.splitlines()
+    assert_equal(len(res4), 2)
+    assert_equal(res4[0], "hello")
+    assert_equal(res4[1], "world")
+
+    # Test with multiple different line breaks
+    var in5 = "hello\nworld\r\nmojo\rlanguage"
+    var res5 = in5.splitlines()
+    assert_equal(len(res5), 4)
+    assert_equal(res5[0], "hello")
+    assert_equal(res5[1], "world")
+    assert_equal(res5[2], "mojo")
+    assert_equal(res5[3], "language")
+
+    # Test with keepends=True
+    var res6 = in5.splitlines(keepends=True)
+    assert_equal(len(res6), 4)
+    assert_equal(res6[0], "hello\n")
+    assert_equal(res6[1], "world\r\n")
+    assert_equal(res6[2], "mojo\r")
+    assert_equal(res6[3], "language")
+
+    # Test with an empty string
+    var in7 = ""
+    var res7 = in7.splitlines()
+    assert_equal(len(res7), 0)
+
+    # Test the remaining line boundaries with keepends=True
+    var in8 = "hello\vworld\fmojo\x1clanguage\x1d"
+    var res8 = in8.splitlines(keepends=True)
+    assert_equal(len(res8), 4)
+    assert_equal(res8[0], "hello\v")
+    assert_equal(res8[1], "world\f")
+    assert_equal(res8[2], "mojo\x1c")
+    assert_equal(res8[3], "language\x1d")
+
+
 def test_float_conversion():
     assert_equal(("4.5").__float__(), 4.5)
     assert_equal(float("4.5"), 4.5)
@@ -213,9 +456,23 @@ def main():
     test_rfind()
     test_replace()
     test_comparison_operators()
+    test_count()
     test_hash()
+    test_indexing()
     test_intable()
+    test_isdigit()
+    test_islower()
+    test_isupper()
+    test_iter()
     test_layout()
     test_lower_upper()
     test_repr()
+    test_rjust()
+    test_ljust()
+    test_center()
+    test_startswith()
+    test_endswith()
+    test_strip()
+    test_split()
+    test_splitlines()
     test_float_conversion()