Skip to content

Commit

Permalink
Merge pull request #81 from SinanAkkoyun/code-chat
Browse files Browse the repository at this point in the history
Chat format: Recognize specified language and offloaded lexguessing to every newline
  • Loading branch information
turboderp authored Oct 7, 2023
2 parents 7b5dccb + fe047c4 commit a9f3f17
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 8 deletions.
27 changes: 22 additions & 5 deletions examples/chat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

import sys, os
import sys, os, re
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from exllamav2 import(
Expand Down Expand Up @@ -155,6 +155,11 @@ def get_tokenized_context(max_len):
codeblock_formatter = None if args.no_code_formatting else CodeBlockFormatter()
in_code_block = False

delim_buffer_array = []
delim_pattern = re.compile(r'(`{1,3})')

delim_overflow = ""

# Main loop

print(f" -- Prompt format: {args.mode}")
Expand Down Expand Up @@ -198,10 +203,17 @@ def get_tokenized_context(max_len):
response_text += chunk
responses_ids[-1] = torch.cat([responses_ids[-1], tokens], dim = -1)

# Check for code block delimiters
# Append chunk to delimiter buffer if contains delimiters
if delim_pattern.search(chunk) and len(delim_buffer_array) < 2: # dirty fix for assumption that codeblock start is never smaller than `` + `
# add chunk
delim_buffer_array.append(chunk)
else:
delim_overflow = "".join(delim_buffer_array)
delim_buffer_array = []

codeblock_delimiter = chunk.startswith("```") and codeblock_formatter is not None
if codeblock_delimiter: chunk = chunk[3:] # Suppress delimiter in output
# Check for code block delimiters
# if delim_buffer_array contains a full delimiter (```), codeblock true
codeblock_delimiter = "".join(delim_buffer_array).find("```") != -1 and (codeblock_formatter is not None)

# Print output

Expand All @@ -212,9 +224,13 @@ def get_tokenized_context(max_len):
codeblock_formatter.begin()
print("\n")
in_code_block = True
delim_buffer_array = []

# Print unformatted
print(chunk, end = "")
# if delim buffer is > 0 do not print for now
if len(delim_buffer_array) == 0:
print(chunk, end = "")

sys.stdout.flush()

else:
Expand All @@ -223,6 +239,7 @@ def get_tokenized_context(max_len):
if codeblock_delimiter:
print("\033[0m", end = "") # Reset block color to be certain
in_code_block = False
delim_buffer_array = []

# Print formatted
codeblock_formatter.print_code_block(chunk)
Expand Down
26 changes: 23 additions & 3 deletions examples/chat_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@

# Code block formatter for black background


class BlackBackgroundTerminalFormatter(TerminalFormatter):

code_pad: int = 2
block_pad_left: int = 1


def __init__(self):
super().__init__(style = "monokai")

Expand Down Expand Up @@ -91,6 +93,7 @@ class CodeBlockFormatter:

code_block_text: str
lines_printed: int
last_lexer: str

formatter = BlackBackgroundTerminalFormatter()

Expand All @@ -100,14 +103,15 @@ def begin(self):

self.code_block_text = ""
self.lines_printed = 0
self.last_lexer = get_lexer_by_name("text")

self.formatter.begin()


# Print a code block, updating the CLI in real-time

def print_code_block(self, chunk):

# Clear previously printed lines
for _ in range(self.lines_printed): # -1 not needed?
# Move cursor up one line
Expand All @@ -126,7 +130,15 @@ def print_code_block(self, chunk):
self.code_block_text += chunk

# Remove language after codeblock start
code_block_text = re.sub(r'```.*?$', '```', self.code_block_text, flags=re.MULTILINE)
code_block_text = '\n'.join([''] + self.code_block_text.split('\n')[1:])

# Handle delim at end
if code_block_text.endswith("```"):
code_block_text = code_block_text[:-3]


# Get specified language
specified_lang = self.code_block_text.split('\n', 1)[0] # Get 1st line (directly after delimiter, can be language)

# Split updated text into lines and find the longest line
lines = code_block_text.split('\n')
Expand All @@ -140,9 +152,17 @@ def print_code_block(self, chunk):

# Try guessing the lexer for syntax highlighting, if we haven't guessed already
try:
lexer = guess_lexer(padded_text)
if bool(specified_lang):
lexer = get_lexer_by_name(specified_lang)
self.last_lexer = lexer
elif '\n' in chunk: # Offload lexguessing to every newline
lexer = guess_lexer(padded_text)
self.last_lexer = lexer
else:
lexer = self.last_lexer
except ClassNotFound:
lexer = get_lexer_by_name("text") # Fallback to plain text if language isn't supported by pygments
self.last_lexer = lexer

# Highlight
highlighted_text = highlight(padded_text, lexer, self.formatter)
Expand Down

0 comments on commit a9f3f17

Please sign in to comment.