Skip to content

Commit

Permalink
Multiple fixes for the server (#93)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy authored Mar 31, 2023
1 parent 10b6aa6 commit 6f42570
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 15 deletions.
5 changes: 2 additions & 3 deletions docs/commands/webserver.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@ pip3 install git+https://github.com/huggingface/transformers
### Launch servers
```
python3 -m fastchat.serve.controller --host 0.0.0.0 --port 21001
python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name https://
python3 -m fastchat.serve.test_message --model vicuna-13b --controller http://localhost:21001
python3 -m fastchat.serve.gradio_web_server --controller http://localhost:21001
export OPENAI_API_KEY=
python3 -m fastchat.serve.gradio_web_server --controller http://localhost:21001 --moderate --concurrency 20
```
26 changes: 19 additions & 7 deletions fastchat/serve/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import argparse
import asyncio
import dataclasses
import json
import logging
import time
from typing import List, Union
Expand All @@ -17,7 +18,7 @@
import uvicorn

from fastchat.constants import CONTROLLER_HEART_BEAT_EXPIRATION
from fastchat.utils import build_logger
from fastchat.utils import build_logger, server_error_msg


logger = build_logger("controller", "controller.log")
Expand Down Expand Up @@ -116,6 +117,7 @@ def get_worker_address(self, model_name: str):
pt = np.random.choice(np.arange(len(worker_names)),
p=worker_speeds)
worker_name = worker_names[pt]
#logger.info(f"speeds: {worker_speeds}, pt: {pt}, worker_name: {worker_name}")
return worker_name

# Check status before returning
Expand Down Expand Up @@ -159,17 +161,27 @@ def remove_stable_workers_by_expiration(self):
def worker_api_generate_stream(self, params):
worker_addr = self.get_worker_address(params["model"])
if not worker_addr:
logger.info(f"no worker: {params['model']}")
ret = {
"text": server_error_msg,
"error_code": 2,
}
yield (json.dumps(ret) + "\0").encode("utf-8")
yield json.dumps(ret).encode() + b"\0"

try:
response = requests.post(worker_addr + "/worker_generate_stream",
json=params, stream=True, timeout=5)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
yield chunk + b"\0"
except requests.exceptions.RequestException as e:
logger.info(f"worker timeout: {worker_addr}")
ret = {
"text": server_error_msg,
"error_code": 3,
}
yield json.dumps(ret).encode() + b"\0"

response = requests.post(worker_addr + "/worker_generate_stream",
json=params, stream=True)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
yield chunk + b"\0"

# Let the controller act as a worker to achieve hierarchical
# management. This can be used to connect isolated sub networks.
Expand Down
17 changes: 13 additions & 4 deletions fastchat/serve/gradio_web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,14 +219,15 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req
if chunk:
data = json.loads(chunk.decode())
if data["error_code"] == 0:
output = data["text"][len(prompt) + 2:]
output = data["text"][len(prompt) + 1:].strip()
output = post_process_code(output)
state.messages[-1][-1] = output + "▌"
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
else:
output = data["text"]
state.messages[-1][-1] = output + "▌"
state.messages[-1][-1] = output
yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
return
time.sleep(0.04)
except requests.exceptions.RequestException as e:
state.messages[-1][-1] = server_error_msg
Expand Down Expand Up @@ -304,8 +305,12 @@ def build_demo():
show_label=False).style(container=False)

chatbot = grChatbot(elem_id="chatbot", visible=False).style(height=550)
textbox = gr.Textbox(show_label=False,
placeholder="Enter text and press ENTER", visible=False).style(container=False)
with gr.Row():
with gr.Column(scale=10):
textbox = gr.Textbox(show_label=False,
placeholder="Enter text and press ENTER", visible=False).style(container=False)
with gr.Column(scale=1, min_width=60):
submit_btn = gr.Button(value="Submit")

with gr.Row(visible=False) as button_row:
upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
Expand Down Expand Up @@ -339,6 +344,9 @@ def build_demo():
textbox.submit(add_text, [state, textbox], [state, chatbot, textbox] + btn_list
).then(http_bot, [state, model_selector, temperature, max_output_tokens],
[state, chatbot] + btn_list)
submit_btn.click(add_text, [state, textbox], [state, chatbot, textbox] + btn_list
).then(http_bot, [state, model_selector, temperature, max_output_tokens],
[state, chatbot] + btn_list)

if args.model_list_mode == "once":
demo.load(load_demo, [url_params], [state, model_selector,
Expand Down Expand Up @@ -367,6 +375,7 @@ def build_demo():

models = get_model_list()

logger.info(args)
demo = build_demo()
demo.queue(concurrency_count=args.concurrency_count, status_update_rate=10,
api_open=False).launch(
Expand Down
4 changes: 3 additions & 1 deletion fastchat/serve/model_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ def register_to_controller(self):
assert r.status_code == 200

def send_heart_beat(self):
logger.info(f"Send heart beat. Models: {[self.model_name]}")
logger.info(f"Send heart beat. Models: {[self.model_name]}. "
f"Semaphore: {model_semaphore}. "
f"global_counter: {global_counter}")

url = self.controller_addr + "/receive_heart_beat"
try:
Expand Down

0 comments on commit 6f42570

Please sign in to comment.