Add length extrapolation #126

Merged
merged 12 commits on Oct 26, 2023
358 changes: 200 additions & 158 deletions README.md

Large diffs are not rendered by default.

226 changes: 54 additions & 172 deletions apps/web_demo.py
@@ -1,172 +1,54 @@
import torch
import os
import sys
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, GenerationConfig
import mdtex2html
from threading import Thread
import gc
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

os.environ["TOKENIZERS_PARALLELISM"] = "false"
max_generate_length: int = 1024
model_path = "TigerResearch/tigerbot-13b-chat"
print(f"loading model: {model_path}...")
device = torch.cuda.current_device()
generation_config = GenerationConfig.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map='auto')

tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    cache_dir=None,
    model_max_length=max_generate_length,
    padding_side="left",
    truncation_side='left',
    padding=True,
    truncation=True
)
if tokenizer.model_max_length is None or tokenizer.model_max_length > 1024:
    tokenizer.model_max_length = 1024

"""Override Chatbot.postprocess"""


def postprocess(self, y):
    if y is None:
        return []
    for i, (message, response) in enumerate(y):
        y[i] = (
            None if message is None else mdtex2html.convert((message)),
            None if response is None else mdtex2html.convert(response),
        )
    return y


gr.Chatbot.postprocess = postprocess


def parse_text(text):
    """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split('`')
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = f'<br></code></pre>'
        else:
            if i > 0:
                if count % 2 == 1:
                    line = line.replace("`", "\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>"+line
    text = "".join(lines)
    return text


def generate_stream(query,
                    history,
                    max_input_length,
                    max_output_length):
    tok_ins = "\n\n### Instruction:\n"
    tok_res = "\n\n### Response:\n"
    prompt_input = tok_ins + "{instruction}" + tok_res

    sess_text = ""
    if history:
        for s in history:
            sess_text += tok_ins + s["human"] + tok_res + s["assistant"]
    history.append({"human": query, "assistant": ""})

    sess_text += tok_ins + query.strip()
    input_text = prompt_input.format_map({'instruction': sess_text.split(tok_ins, 1)[1]})
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_input_length)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    streamer = TextIteratorStreamer(tokenizer,
                                    skip_prompt=True,
                                    skip_special_tokens=True,
                                    spaces_between_special_tokens=False)

    generation_kwargs = generation_config.to_dict()
    generation_kwargs.update(dict(inputs))
    generation_kwargs['streamer'] = streamer
    generation_kwargs['max_new_tokens'] = max_output_length

    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    answer = ""
    for new_text in streamer:
        if len(new_text) == 0:
            continue
        if new_text.endswith(tokenizer.eos_token):
            new_text = new_text.rsplit(tokenizer.eos_token, 1)[0]
        answer += new_text
        history[-1]['assistant'] = answer

        yield answer, history


def predict(input, chatbot, max_input_length, max_generate_length, history):
    chatbot.append((parse_text(input), ""))
    for response, history in generate_stream(
            input,
            history,
            max_input_length=max_input_length,
            max_output_length=max_generate_length,
    ):
        if response is None:
            break
        chatbot[-1] = (parse_text(input), parse_text(response))

        yield chatbot, history


def reset_user_input():
    return gr.update(value='')


def reset_state():
    return [], []


with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">TigerBot</h1>""")

    chatbot = gr.Chatbot()
    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
                    container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit", variant="primary")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            max_input_length = gr.Slider(0, 1024, value=512, step=1.0, label="Maximum input length", interactive=True)
            max_generate_length = gr.Slider(0, 2048, value=1024, step=1.0, label="Maximum generate length", interactive=True)

    history = gr.State([])

    submitBtn.click(predict, [user_input, chatbot, max_input_length, max_generate_length, history], [chatbot, history],
                    show_progress=True)
    submitBtn.click(reset_user_input, [], [user_input])

    emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)

demo.queue().launch(share=True, inbrowser=True)
import streamlit as st
import torch.cuda

from TigerBot.utils.modeling_hack import get_model
from TigerBot.utils.streaming import generate_stream


@st.cache_resource
def cached_get_model(model_path='tigerbot-13b-chat-v4', rope_scaling='yarn', rope_factor=8.0):
    return get_model(model_path=model_path, rope_scaling=rope_scaling, rope_factor=rope_factor)


model, tokenizer, generation_config = cached_get_model()

generation_config.do_sample = False
generation_config.max_length = 16384
generation_config.max_new_tokens = 1024

tok_ins = "\n\n### Instruction:\n"
tok_res = "\n\n### Response:\n"
tok_eos = "</s>"

st.title("TigerBot chat")

device = f"cuda:{torch.cuda.current_device()}"

if "messages" not in st.session_state:
st.session_state.messages = list()

for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.text(message["content"])

if prompt := st.chat_input("Input here"):
with st.chat_message("user"):
st.text(prompt)
st.session_state.messages.append({"role": "user", "content": prompt})
with st.chat_message("assistant"):
message_placeholder = st.empty()
input_text = "".join(
[(tok_ins if msg["role"] == "user" else tok_res) + msg["content"] + (
tok_eos if msg["role"] == "assistant" else "") for msg in st.session_state.messages])
input_text += tok_res
inputs = tokenizer(input_text, return_tensors='pt',
max_length=generation_config.max_length - generation_config.max_new_tokens)
inputs = {k: v.to(device) for k, v in inputs.items()}
full_answer = ""
print(inputs)
for text in generate_stream(model, tokenizer, inputs['input_ids'], inputs['attention_mask'],
generation_config=generation_config):
print(text, end='', flush=True)
full_answer += text
message_placeholder.text(full_answer)
st.session_state.messages.append({"role": "assistant", "content": full_answer})
60 changes: 24 additions & 36 deletions infer.py
@@ -1,10 +1,13 @@
import os
from typing import Tuple, Optional

import fire
import torch
import transformers
import readline
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from utils import compared_version
from utils.modeling_hack import get_model
from utils.streaming import generate_stream

os.environ["TOKENIZERS_PARALLELISM"] = "false"

@@ -18,36 +21,21 @@ def main(
    max_input_length: int = 512,
    max_generate_length: int = 1024,
    model_type: str = 'chat',
    use_flash_attn: bool = False
    rope_scaling: Optional[str] = None,
    rope_factor: float = 8.0,
    streaming: bool = True
):
    if model_type.lower() not in ['chat', 'base']:
        raise ValueError(f"model_type must be one of ['chat', 'base'], got {model_type}")
    if use_flash_attn:
        assert compared_version(transformers.__version__, '4.34.0'), 'Please update transformers version >= 4.34.0'
        print(f"loading model: {model_path}...")
        model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map='auto', use_flash_attention_2=True)
        print("using flash attention...")
    else:
        print(f"loading model: {model_path}...")
        model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map='auto')
    assert transformers.__version__.startswith('4.34')
    assert model_type.lower() in ['chat', 'base'], f"model_type must be one of ['chat', 'base'], got {model_type}"
    assert rope_scaling in [None, 'yarn',
                            'dynamic'], f"rope_scaling must be one of [None, 'yarn', 'dynamic'], got {rope_scaling}"

    generation_config = GenerationConfig.from_pretrained(model_path)
    generation_config.max_length = max_generate_length
    print(generation_config)
    model, tokenizer, generation_config = get_model(model_path=model_path, rope_scaling=rope_scaling,
                                                    rope_factor=rope_factor)
    generation_config.max_new_tokens = max_generate_length
    generation_config.max_length = max_input_length + max_generate_length

    device = torch.cuda.current_device()

    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        model_max_length=max_generate_length,
        padding_side="left",
        truncation_side='left',
        padding=True,
        truncation=True
    )
    if tokenizer.model_max_length is None or tokenizer.model_max_length > max_generate_length:
        tokenizer.model_max_length = max_generate_length

    sess_text = ""
    while True:
        raw_text = input("prompt(\"exit\" to end, \"clear\" to clear session) >>> ")
@@ -70,16 +58,16 @@ def main(
        input_text = query_text
        inputs = tokenizer(input_text, return_tensors='pt', truncation=True, max_length=max_input_length)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        output = model.generate(**inputs, **generation_config.to_dict())
        answer = tokenizer.decode(output[0][inputs['input_ids'].shape[1]:], skip_special_tokens=False,
                                  spaces_between_special_tokens=False)
        if answer.endswith(tokenizer.eos_token):
            answer = answer.rsplit(tokenizer.eos_token, 1)[0].strip()

        sess_text += tok_res + answer

        print("=" * 100)
        print(answer)
        print('=' * 100)
        if streaming:
            for text in generate_stream(model, tokenizer, inputs['input_ids'], inputs['attention_mask'],
                                        generation_config=generation_config):
                print(text, end='', flush=True)
        else:
            output = model.generate(**inputs, **generation_config.to_dict())
            print(tokenizer.decode(output[0][inputs['input_ids'].shape[1]:]))
        print('')
        print("=" * 100)


11 changes: 3 additions & 8 deletions requirements.txt
@@ -8,14 +8,9 @@ torch>=2.0.0
evaluate==0.4.0
texttable==1.6.7
toml==0.10.2
numpy==1.18.5
numpy>=1.22.0
sentencepiece==0.1.98
fire==0.5.0
flash-attn==2.1.1
gradio
mdtex2html
sse-starlette
asyncio
aiohttp_sse_client
sseclient
requests
deepspeed==0.9.5
streamlit==1.24.1
16 changes: 0 additions & 16 deletions utils.py

This file was deleted.

Empty file added utils/__init__.py
Empty file.