
Revert "add best_of and use_beam_search for completions interface" #2401

Merged
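
This revert removes the vLLM-specific best_of and use_beam_search fields from the OpenAI-compatible completions interface and restores the single-valued text and usage fields. For orientation, here is a hedged sketch of a client call against a locally running FastChat openai_api_server; the host, port, and model name are assumptions, not taken from this PR, and the commented-out fields are the ones this revert drops from CompletionRequest.

import requests

# Hedged sketch: assumes a FastChat openai_api_server reachable at localhost:8000
# and a served model named "vicuna-7b-v1.5"; adjust both to your deployment.
resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "vicuna-7b-v1.5",
        "prompt": "Once upon a time",
        "max_tokens": 32,
        "n": 1,
        # Dropped from the request schema by this revert:
        # "best_of": 4,
        # "use_beam_search": True,
    },
    timeout=60,
)
print(resp.json()["choices"][0]["text"])
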
fastchat/protocol/api_protocol.py (2 changes: 1 addition & 1 deletion)

@@ -150,7 +150,7 @@ class CompletionResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[CompletionResponseChoice]
-    usage: Union[UsageInfo, List[UsageInfo]]
+    usage: UsageInfo


 class CompletionResponseStreamChoice(BaseModel):

fastchat/protocol/openai_api_protocol.py (4 changes: 1 addition & 3 deletions)

@@ -151,13 +151,11 @@ class CompletionRequest(BaseModel):
     presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
     user: Optional[str] = None
-    use_beam_search: Optional[bool] = False
-    best_of: Optional[int] = None


 class CompletionResponseChoice(BaseModel):
     index: int
-    text: Union[str, List[str]]
+    text: str
     logprobs: Optional[int] = None
     finish_reason: Optional[Literal["stop", "length"]] = None

fastchat/serve/openai_api_server.py (29 changes: 3 additions & 26 deletions)

@@ -241,9 +241,6 @@ async def get_gen_params(
     max_tokens: Optional[int],
     echo: Optional[bool],
     stop: Optional[Union[str, List[str]]],
-    best_of: Optional[int] = None,
-    n: Optional[int] = 1,
-    use_beam_search: Optional[bool] = None,
 ) -> Dict[str, Any]:
     conv = await get_conv(model_name, worker_addr)
     conv = Conversation(

@@ -290,11 +287,6 @@
         "stop_token_ids": conv.stop_token_ids,
     }

-    if best_of is not None:
-        gen_params.update({"n": n, "best_of": best_of})
-    if use_beam_search is not None:
-        gen_params.update({"use_beam_search": use_beam_search})
-
     new_stop = set()
     _add_to_set(stop, new_stop)
     _add_to_set(conv.stop_str, new_stop)

@@ -502,18 +494,12 @@ async def create_completion(request: CompletionRequest):
                 max_tokens=request.max_tokens,
                 echo=request.echo,
                 stop=request.stop,
-                best_of=request.best_of,
-                n=request.n,
-                use_beam_search=request.use_beam_search,
             )
             for i in range(request.n):
                 content = asyncio.create_task(
                     generate_completion(gen_params, worker_addr)
                 )
                 text_completions.append(content)
-                # when use with best_of, only need send one request
-                if request.best_of:
-                    break

         try:
             all_tasks = await asyncio.gather(*text_completions)

@@ -533,18 +519,9 @@ async def create_completion(request: CompletionRequest):
                     finish_reason=content.get("finish_reason", "stop"),
                 )
             )
-            idx = 0
-            while True:
-                info = content["usage"]
-                if isinstance(info, list):
-                    info = info[idx]
-
-                task_usage = UsageInfo.parse_obj(info)
-
-                for usage_key, usage_value in task_usage.dict().items():
-                    setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
-                idx += 1
-                break
+            task_usage = UsageInfo.parse_obj(content["usage"])
+            for usage_key, usage_value in task_usage.dict().items():
+                setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)

         return CompletionResponse(
             model=request.model, choices=choices, usage=UsageInfo.parse_obj(usage)

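The restored block above folds one UsageInfo per finished task into a single response-level usage object. A minimal, runnable sketch of that aggregation pattern follows; the UsageInfo class here is a stand-in mirroring FastChat's pydantic model, and the sample numbers are illustrative only.

from pydantic import BaseModel


class UsageInfo(BaseModel):
    # Stand-in for fastchat.protocol.openai_api_protocol.UsageInfo
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0


# Pretend worker responses, one usage dict per completed task.
all_tasks = [
    {"usage": {"prompt_tokens": 5, "completion_tokens": 20, "total_tokens": 25}},
    {"usage": {"prompt_tokens": 5, "completion_tokens": 18, "total_tokens": 23}},
]

usage = UsageInfo()
for content in all_tasks:
    task_usage = UsageInfo.parse_obj(content["usage"])
    for usage_key, usage_value in task_usage.dict().items():
        setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)

print(usage)  # totals: prompt_tokens=10, completion_tokens=38, total_tokens=48
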
fastchat/serve/vllm_worker.py (70 changes: 21 additions & 49 deletions)

@@ -18,7 +18,6 @@
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid

-from fastchat.constants import ErrorCode, SERVER_ERROR_MSG
 from fastchat.serve.model_worker import (
     BaseModelWorker,
     logger,

@@ -75,9 +74,6 @@ async def generate_stream(self, params):
         if self.tokenizer.eos_token_id is not None:
             stop_token_ids.append(self.tokenizer.eos_token_id)
         echo = params.get("echo", True)
-        use_beam_search = params.get("use_beam_search", False)
-        best_of = params.get("best_of", None)
-        n = params.get("n", 1)

         # Handle stop_str
         stop = set()

@@ -94,51 +90,27 @@
         top_p = max(top_p, 1e-5)
         if temperature <= 1e-5:
             top_p = 1.0
-        try:
-            sampling_params = SamplingParams(
-                n=n,
-                temperature=temperature,
-                top_p=top_p,
-                use_beam_search=use_beam_search,
-                stop=list(stop),
-                max_tokens=max_new_tokens,
-                best_of=best_of,
-            )
-
-            results_generator = engine.generate(context, sampling_params, request_id)
-
-            async for request_output in results_generator:
-                prompt = request_output.prompt
-                prompt_tokens = len(request_output.prompt_token_ids)
-                output_usage = []
-                for out in request_output.outputs:
-                    completion_tokens = len(out.token_ids)
-                    total_tokens = prompt_tokens + completion_tokens
-                    output_usage.append(
-                        {
-                            "prompt_tokens": prompt_tokens,
-                            "completion_tokens": completion_tokens,
-                            "total_tokens": total_tokens,
-                        }
-                    )
-
-                if echo:
-                    text_outputs = [
-                        prompt + output.text for output in request_output.outputs
-                    ]
-                else:
-                    text_outputs = [output.text for output in request_output.outputs]
-
-                if sampling_params.best_of is None:
-                    text_outputs = [" ".join(text_outputs)]
-                ret = {"text": text_outputs, "error_code": 0, "usage": output_usage}
-                yield (json.dumps(ret) + "\0").encode()
-        except (ValueError, RuntimeError) as e:
-            ret = {
-                "text": f"{e}",
-                "error_code": ErrorCode.PARAM_OUT_OF_RANGE,
-                "usage": {},
-            }
+        sampling_params = SamplingParams(
+            n=1,
+            temperature=temperature,
+            top_p=top_p,
+            use_beam_search=False,
+            stop=list(stop),
+            max_tokens=max_new_tokens,
+        )
+        results_generator = engine.generate(context, sampling_params, request_id)
+
+        async for request_output in results_generator:
+            prompt = request_output.prompt
+            if echo:
+                text_outputs = [
+                    prompt + output.text for output in request_output.outputs
+                ]
+            else:
+                text_outputs = [output.text for output in request_output.outputs]
+            text_outputs = " ".join(text_outputs)
+            # Note: usage is not supported yet
+            ret = {"text": text_outputs, "error_code": 0, "usage": {}}
             yield (json.dumps(ret) + "\0").encode()

     async def generate(self, params):

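The simplified worker above returns an empty usage dict, as the "# Note: usage is not supported yet" comment states. If usage reporting were wanted again, it could be derived from the same RequestOutput fields the reverted code read; the helper below is a hedged sketch of that idea, not code from this PR.

def usage_from_request_output(request_output) -> dict:
    # Hedged sketch: uses only vLLM RequestOutput fields the reverted code
    # already relied on (prompt_token_ids and each output's token_ids).
    prompt_tokens = len(request_output.prompt_token_ids)
    completion_tokens = sum(len(out.token_ids) for out in request_output.outputs)
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }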