
Feat: improve tokenizer perf significantly on sql with many strings
tobymao committed May 8, 2023
1 parent 4f0b3ed commit e173dd5
Showing 1 changed file with 33 additions and 36 deletions.
69 changes: 33 additions & 36 deletions sqlglot/tokens.py
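The change keeps the tokenizer's behaviour but speeds up its hot path in two ways, both visible in the diff below: `_advance` gains an `alnum` fast path that consumes whole runs of alphanumeric characters using local variables, and `_extract_string` builds token text by slicing the original SQL instead of concatenating one character at a time. A rough way to check the performance claim locally (not part of the commit; assumes sqlglot is installed, and the query below is an arbitrary string-heavy example):

import timeit

from sqlglot.tokens import Tokenizer

# Any SQL with many string literals exercises the changed code paths.
sql = "SELECT " + ", ".join(f"'literal value {i}'" for i in range(500))

# Tokenizer and Tokenizer.tokenize() are the objects touched by this commit.
print(timeit.timeit(lambda: Tokenizer().tokenize(sql), number=100))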
@@ -774,8 +774,6 @@ class Tokenizer(metaclass=_Tokenizer):
"_end",
"_peek",
"_prev_token_line",
"_prev_token_comments",
"_prev_token_type",
)

def __init__(self) -> None:
@@ -795,8 +793,6 @@ def reset(self) -> None:
self._end = False
self._peek = ""
self._prev_token_line = -1
self._prev_token_comments: t.List[str] = []
self._prev_token_type: t.Optional[TokenType] = None

def tokenize(self, sql: str) -> t.List[Token]:
"""Returns a list of tokens corresponding to the SQL string `sql`."""
@@ -846,7 +842,7 @@ def _chars(self, size: int) -> str:
return self.sql[start:end]
return ""

def _advance(self, i: int = 1) -> None:
def _advance(self, i: int = 1, alnum=False) -> None:
if self.WHITE_SPACE.get(self._char) is TokenType.BREAK:
self._col = 1
self._line += 1
@@ -858,14 +854,30 @@ def _advance(self, i: int = 1) -> None:
self._char = self.sql[self._current - 1]
self._peek = "" if self._end else self.sql[self._current]

if alnum:
_col = self._col
_current = self._current
_end = self._end
_peek = self._peek

while _peek.isalnum():
_col += 1
_current += 1
_end = _current >= self.size
_peek = "" if _end else self.sql[_current]

self._col = _col
self._current = _current
self._end = _end
self._peek = _peek
self._char = self.sql[_current - 1]

@property
def _text(self) -> str:
return self.sql[self._start : self._current]

def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None:
self._prev_token_line = self._line
self._prev_token_comments = self._comments
self._prev_token_type = token_type
self.tokens.append(
Token(
token_type,
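The `alnum=True` branch added above is the heart of the speed-up: once the tokenizer is inside a run of alphanumeric characters it stops re-entering `_advance` one character at a time and instead walks the run with plain local variables, writing the position back to the instance only once. A minimal standalone sketch of that idea (assumed names, not sqlglot's actual class):

def consume_alnum_run(sql: str, current: int, col: int) -> tuple[int, int]:
    """Advance past a run of alphanumeric characters in one tight loop.

    `current` indexes the character after the one just read, matching the
    tokenizer's convention; the caller re-derives `_char`/`_peek` from the
    returned position.
    """
    size = len(sql)
    end = current >= size
    peek = "" if end else sql[current]

    # Plain locals only: no `self.` attribute lookups, no per-character calls.
    while peek.isalnum():
        col += 1
        current += 1
        end = current >= size
        peek = "" if end else sql[current]

    return current, col


# Example: having just read the 'c' of "col1" (index 7), skip to the run's end.
print(consume_alnum_run("SELECT col1 FROM tbl", 8, 9))  # -> (11, 12)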
@@ -966,13 +978,13 @@ def _scan_comment(self, comment_start: str) -> bool:

comment_end_size = len(comment_end)
while not self._end and self._chars(comment_end_size) != comment_end:
self._advance()
self._advance(alnum=True)

self._comments.append(self._text[comment_start_size : -comment_end_size + 1])
self._advance(comment_end_size - 1)
else:
while not self._end and not self.WHITE_SPACE.get(self._peek) is TokenType.BREAK:
self._advance()
self._advance(alnum=True)
self._comments.append(self._text[comment_start_size:])

# Leading comment is attached to the succeeding token, whilst trailing comment to the preceding.
@@ -1053,7 +1065,7 @@ def _extract_value(self) -> str:
while True:
char = self._peek.strip()
if char and char not in self.SINGLE_TOKENS:
self._advance()
self._advance(alnum=True)
else:
break

@@ -1103,47 +1115,30 @@ def _scan_formatted_string(self, string_start: str) -> bool:
return True

def _scan_identifier(self, identifier_end: str) -> None:
text = ""
identifier_end_is_escape = identifier_end in self._IDENTIFIER_ESCAPES

while True:
if self._end:
raise RuntimeError(f"Missing {identifier_end} from {self._line}:{self._start}")

self._advance()
if self._char == identifier_end:
if identifier_end_is_escape and self._peek == identifier_end:
text += identifier_end
self._advance()
continue

break

text += self._char

self._advance()
text = self._extract_string(identifier_end, self._IDENTIFIER_ESCAPES)
self._add(TokenType.IDENTIFIER, text)

def _scan_var(self) -> None:
while True:
char = self._peek.strip()
if char and (char in self.VAR_SINGLE_TOKENS or char not in self.SINGLE_TOKENS):
self._advance()
self._advance(alnum=True)
else:
break
self._add(
TokenType.VAR
if self._prev_token_type == TokenType.PARAMETER
if self.tokens and self.tokens[-1].token_type == TokenType.PARAMETER
else self.KEYWORDS.get(self._text.upper(), TokenType.VAR)
)

def _extract_string(self, delimiter: str) -> str:
def _extract_string(self, delimiter: str, escapes=None) -> str:
text = ""
delim_size = len(delimiter)
escapes = self._STRING_ESCAPES if escapes is None else escapes

while True:
if self._char in self._STRING_ESCAPES and (
self._peek == delimiter or self._peek in self._STRING_ESCAPES
):
if self._char in escapes and (self._peek == delimiter or self._peek in escapes):
if self._peek == delimiter:
text += self._peek
else:
@@ -1161,7 +1156,9 @@ def _extract_string(self, delimiter: str) -> str:

if self._end:
raise RuntimeError(f"Missing {delimiter} from {self._line}:{self._start}")
text += self._char
self._advance()

current = self._current - 1
self._advance(alnum=True)
text += self.sql[current : self._current - 1]

return text
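The last hunk shows the other half of the optimization: instead of appending `self._char` to `text` on every iteration, the loop remembers where the current run started, advances with `alnum=True`, and appends one slice of `self.sql` per run. A minimal sketch of the same pattern on its own (assumed names, escape handling omitted):

def read_until(sql: str, start: int, delimiter: str) -> tuple[str, int]:
    """Return (text, index after delimiter) for a naive, escape-free string."""
    pieces = []
    i = start
    while True:
        if i >= len(sql):
            raise RuntimeError(f"Missing {delimiter!r} in input")
        if sql.startswith(delimiter, i):
            return "".join(pieces), i + len(delimiter)
        run_start = i
        i += 1
        while i < len(sql) and sql[i].isalnum():  # batch over the alnum run
            i += 1
        pieces.append(sql[run_start:i])  # one slice per run, not one per char


print(read_until("hello world' AND 1 = 1", 0, "'"))  # -> ('hello world', 12)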
