Fix vllm microservice performance issue. (#731)
* Fix vllm microservice performance issue.

Signed-off-by: Yao, Qing <[email protected]>

* Refine llm generate parameters

Signed-off-by: Yao, Qing <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Yao, Qing <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
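
In short: the VLLMOpenAI client construction moved out of the request handler. Rebuilding the client on every request added per-call overhead, and the synchronous streaming loop blocked the event loop; the fix creates one module-level client, switches streaming to the async API, and passes the per-request sampling options to each generate call instead. A minimal sketch of that pattern follows (not the microservice's actual code; the import path, endpoint, and model name mirror the diff below, and the parameter values are illustrative):

from langchain_community.llms import VLLMOpenAI

# One client, created once at import time and reused across requests.
llm = VLLMOpenAI(
    openai_api_key="EMPTY",  # vLLM's OpenAI-compatible server does not check the key
    openai_api_base="http://localhost:8008/v1",
    model_name="meta-llama/Meta-Llama-3-8B-Instruct",
)

def generate(query: str, max_tokens: int = 1024, temperature: float = 0.01) -> str:
    # Per-request sampling options travel with the call, not the client.
    parameters = {"max_tokens": max_tokens, "temperature": temperature}
    return llm.invoke(query, **parameters)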
yao531441 and pre-commit-ci[bot] authored Sep 25, 2024
1 parent f8f02e2 commit 2159f9a
Showing 1 changed file with 14 additions and 30 deletions.
44 changes: 14 additions & 30 deletions comps/llms/text-generation/vllm/langchain/llm.py
@@ -26,6 +26,7 @@
 
 llm_endpoint = os.getenv("vLLM_ENDPOINT", "http://localhost:8008")
 model_name = os.getenv("LLM_MODEL", "meta-llama/Meta-Llama-3-8B-Instruct")
+llm = VLLMOpenAI(openai_api_key="EMPTY", openai_api_base=llm_endpoint + "/v1", model_name=model_name)
 
 
 @opea_telemetry
@@ -56,6 +57,13 @@ def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, SearchedDoc])
     if not isinstance(input, SearchedDoc) and input.chat_template:
         prompt_template = PromptTemplate.from_template(input.chat_template)
         input_variables = prompt_template.input_variables
+    parameters = {
+        "max_tokens": input.max_tokens,
+        "top_p": input.top_p,
+        "temperature": input.temperature,
+        "frequency_penalty": input.frequency_penalty,
+        "presence_penalty": input.presence_penalty,
+    }
 
     if isinstance(input, SearchedDoc):
         if logflag:
@@ -76,23 +84,11 @@ def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, SearchedDoc])
         if logflag:
             logger.info(f"[ SearchedDoc ] final input: {new_input}")
 
-        llm = VLLMOpenAI(
-            openai_api_key="EMPTY",
-            openai_api_base=llm_endpoint + "/v1",
-            max_tokens=new_input.max_tokens,
-            model_name=model_name,
-            top_p=new_input.top_p,
-            temperature=new_input.temperature,
-            frequency_penalty=new_input.frequency_penalty,
-            presence_penalty=new_input.presence_penalty,
-            streaming=new_input.streaming,
-        )
-
         if new_input.streaming:
 
-            def stream_generator():
+            async def stream_generator():
                 chat_response = ""
-                for text in llm.stream(new_input.query):
+                async for text in llm.astream(new_input.query, **parameters):
                     chat_response += text
                     chunk_repr = repr(text.encode("utf-8"))
                     if logflag:
@@ -105,7 +101,7 @@ def stream_generator():
             return StreamingResponse(stream_generator(), media_type="text/event-stream")
 
         else:
-            response = llm.invoke(new_input.query)
+            response = llm.invoke(new_input.query, **parameters)
             if logflag:
                 logger.info(response)
 
@@ -131,23 +127,11 @@ def stream_generator():
                 # use rag default template
                 prompt = ChatTemplate.generate_rag_prompt(input.query, input.documents)
 
-        llm = VLLMOpenAI(
-            openai_api_key="EMPTY",
-            openai_api_base=llm_endpoint + "/v1",
-            max_tokens=input.max_tokens,
-            model_name=model_name,
-            top_p=input.top_p,
-            temperature=input.temperature,
-            frequency_penalty=input.frequency_penalty,
-            presence_penalty=input.presence_penalty,
-            streaming=input.streaming,
-        )
-
         if input.streaming:
 
-            def stream_generator():
+            async def stream_generator():
                 chat_response = ""
-                for text in llm.stream(input.query):
+                async for text in llm.astream(input.query, **parameters):
                     chat_response += text
                     chunk_repr = repr(text.encode("utf-8"))
                     if logflag:
@@ -160,7 +144,7 @@ def stream_generator():
             return StreamingResponse(stream_generator(), media_type="text/event-stream")
 
         else:
-            response = llm.invoke(input.query)
+            response = llm.invoke(input.query, **parameters)
            if logflag:
                 logger.info(response)
 
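For reference, the new async streaming path can be exercised standalone roughly like this (a sketch assuming a local vLLM server at the service's default endpoint; the prompt and sampling values are illustrative):

import asyncio

from langchain_community.llms import VLLMOpenAI

llm = VLLMOpenAI(
    openai_api_key="EMPTY",
    openai_api_base="http://localhost:8008/v1",
    model_name="meta-llama/Meta-Llama-3-8B-Instruct",
)

async def main():
    # astream yields text chunks as the server produces them, keeping the
    # event loop free between chunks, unlike the blocking llm.stream it replaces.
    async for text in llm.astream("What is vLLM?", max_tokens=128, temperature=0.7):
        print(text, end="", flush=True)

asyncio.run(main())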
