Skip to content

Commit

Permalink
vllm langchain: Add Document Retriever Support (#687)
Browse files Browse the repository at this point in the history
* vllm langchain: Add Document Retriever Support

Include SearchedDoc in /v1/chat/completions endpoint to accept document
data retreived from retriever service to parse into LLM for answer
generation.

Signed-off-by: Yeoh, Hoong Tee <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* vllm: Update README documentation

Signed-off-by: Yeoh, Hoong Tee <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Yeoh, Hoong Tee <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
hteeyeoh and pre-commit-ci[bot] authored Sep 13, 2024
1 parent 574fecf commit 0f2c2b1
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 30 deletions.
41 changes: 39 additions & 2 deletions comps/llms/text-generation/vllm/langchain/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ curl http://${your_ip}:8008/v1/completions \

## 🚀3. Set up LLM microservice

Then we warp the VLLM service into LLM microcervice.
Then we warp the VLLM service into LLM microservice.

### Build docker

Expand All @@ -179,11 +179,48 @@ bash build_docker_microservice.sh
bash launch_microservice.sh
```

### Query the microservice
### Consume the microservice

#### Check microservice status

```bash
curl http://${your_ip}:9000/v1/health_check\
-X GET \
-H 'Content-Type: application/json'

# Output
# {"Service Title":"opea_service@llm_vllm/MicroService","Service Description":"OPEA Microservice Infrastructure"}
```

#### Consume vLLM Service

User can set the following model parameters according to needs:

- max_new_tokens: Total output token
- streaming(true/false): return text response in streaming mode or non-streaming mode

```bash
# 1. Non-streaming mode
curl http://${your_ip}:9000/v1/chat/completions \
-X POST \
-d '{"query":"What is Deep Learning?","max_new_tokens":17,"top_p":0.95,"temperature":0.01,"streaming":false}' \
-H 'Content-Type: application/json'

# 2. Streaming mode
curl http://${your_ip}:9000/v1/chat/completions \
-X POST \
-d '{"query":"What is Deep Learning?","max_new_tokens":17,"top_k":10,"top_p":0.95,"typical_p":0.95,"temperature":0.01,"repetition_penalty":1.03,"streaming":true}' \
-H 'Content-Type: application/json'

# 3. Custom chat template with streaming mode
curl http://${your_ip}:9000/v1/chat/completions \
-X POST \
-d '{"query":"What is Deep Learning?","max_new_tokens":17,"top_k":10,"top_p":0.95,"typical_p":0.95,"temperature":0.01,"repetition_penalty":1.03,"streaming":true, "chat_template":"### You are a helpful, respectful and honest assistant to help the user with questions.\n### Context: {context}\n### Question: {question}\n### Answer:"}' \
-H 'Content-Type: application/json'

4. # Chat with SearchedDoc (Retrieval context)
curl http://${your_ip}:9000/v1/chat/completions \
-X POST \
-d '{"initial_query":"What is Deep Learning?","retrieved_docs":[{"text":"Deep Learning is a ..."},{"text":"Deep Learning is b ..."}]}' \
-H 'Content-Type: application/json'
```
145 changes: 117 additions & 28 deletions comps/llms/text-generation/vllm/langchain/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,31 @@
# SPDX-License-Identifier: Apache-2.0

import os
from typing import Union

from fastapi.responses import StreamingResponse
from langchain_community.llms import VLLMOpenAI
from langchain_core.prompts import PromptTemplate
from template import ChatTemplate

from comps import (
CustomLogger,
GeneratedDoc,
LLMParamsDoc,
SearchedDoc,
ServiceType,
opea_microservices,
opea_telemetry,
register_microservice,
)
from comps.cores.proto.api_protocol import ChatCompletionRequest

logger = CustomLogger("llm_vllm")
logflag = os.getenv("LOGFLAG", False)

llm_endpoint = os.getenv("vLLM_ENDPOINT", "http://localhost:8008")
model_name = os.getenv("LLM_MODEL", "meta-llama/Meta-Llama-3-8B-Instruct")


@opea_telemetry
def post_process_text(text: str):
Expand All @@ -39,39 +47,120 @@ def post_process_text(text: str):
host="0.0.0.0",
port=9000,
)
def llm_generate(input: LLMParamsDoc):
def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, SearchedDoc]):
if logflag:
logger.info(input)
llm_endpoint = os.getenv("vLLM_ENDPOINT", "http://localhost:8008")
model_name = os.getenv("LLM_MODEL", "meta-llama/Meta-Llama-3-8B-Instruct")
llm = VLLMOpenAI(
openai_api_key="EMPTY",
openai_api_base=llm_endpoint + "/v1",
max_tokens=input.max_new_tokens,
model_name=model_name,
top_p=input.top_p,
temperature=input.temperature,
streaming=input.streaming,
)

if input.streaming:

def stream_generator():
chat_response = ""
for text in llm.stream(input.query):
chat_response += text
chunk_repr = repr(text.encode("utf-8"))
yield f"data: {chunk_repr}\n\n"

prompt_template = None

if not isinstance(input, SearchedDoc) and input.chat_template:
prompt_template = PromptTemplate.from_template(input.chat_template)
input_variables = prompt_template.input_variables

if isinstance(input, SearchedDoc):
if logflag:
logger.info("[ SearchedDoc ] input from retriever microservice")

prompt = input.initial_query

if input.retrieved_docs:
docs = [doc.text for doc in input.retrieved_docs]
if logflag:
logger.info(f"[llm - chat_stream] stream response: {chat_response}")
yield "data: [DONE]\n\n"
logger.info(f"[ SearchedDoc ] combined retrieved docs: {docs}")

prompt = ChatTemplate.generate_rag_prompt(input.initial_query, docs)

# use default llm parameter for inference
new_input = LLMParamsDoc(query=prompt)

return StreamingResponse(stream_generator(), media_type="text/event-stream")
else:
response = llm.invoke(input.query)
if logflag:
logger.info(response)
return GeneratedDoc(text=response, prompt=input.query)
logger.info(f"[ SearchedDoc ] final input: {new_input}")

llm = VLLMOpenAI(
openai_api_key="EMPTY",
openai_api_base=llm_endpoint + "/v1",
max_tokens=new_input.max_new_tokens,
model_name=model_name,
top_p=new_input.top_p,
temperature=new_input.temperature,
streaming=new_input.streaming,
)

if new_input.streaming:

def stream_generator():
chat_response = ""
for text in llm.stream(new_input.query):
chat_response += text
chunk_repr = repr(text.encode("utf-8"))
if logflag:
logger.info(f"[ SearchedDoc ] chunk: {chunk_repr}")
yield f"data: {chunk_repr}\n\n"
if logflag:
logger.info(f"[ SearchedDoc ] stream response: {chat_response}")
yield "data: [DONE]\n\n"

return StreamingResponse(stream_generator(), media_type="text/event-stream")

else:
response = llm.invoke(new_input.query)
if logflag:
logger.info(response)

return GeneratedDoc(text=response, prompt=new_input.query)

elif isinstance(input, LLMParamsDoc):
if logflag:
logger.info("[ LLMParamsDoc ] input from rerank microservice")

prompt = input.query

if prompt_template:
if sorted(input_variables) == ["context", "question"]:
prompt = prompt_template.format(question=input.query, context="\n".join(input.documents))
elif input_variables == ["question"]:
prompt = prompt_template.format(question=input.query)
else:
logger.info(
f"[ LLMParamsDoc ] {prompt_template} not used, we only support 2 input variables ['question', 'context']"
)
else:
if input.documents:
# use rag default template
prompt = ChatTemplate.generate_rag_prompt(input.query, input.documents)

llm = VLLMOpenAI(
openai_api_key="EMPTY",
openai_api_base=llm_endpoint + "/v1",
max_tokens=input.max_new_tokens,
model_name=model_name,
top_p=input.top_p,
temperature=input.temperature,
streaming=input.streaming,
)

if input.streaming:

def stream_generator():
chat_response = ""
for text in llm.stream(input.query):
chat_response += text
chunk_repr = repr(text.encode("utf-8"))
if logflag:
logger.info(f"[ LLMParamsDoc ] chunk: {chunk_repr}")
yield f"data: {chunk_repr}\n\n"
if logflag:
logger.info(f"[ LLMParamsDoc ] stream response: {chat_response}")
yield "data: [DONE]\n\n"

return StreamingResponse(stream_generator(), media_type="text/event-stream")

else:
response = llm.invoke(input.query)
if logflag:
logger.info(response)

return GeneratedDoc(text=response, prompt=input.query)


if __name__ == "__main__":
Expand Down
29 changes: 29 additions & 0 deletions comps/llms/text-generation/vllm/langchain/template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import re


class ChatTemplate:
@staticmethod
def generate_rag_prompt(question, documents):
context_str = "\n".join(documents)
if context_str and len(re.findall("[\u4E00-\u9FFF]", context_str)) / len(context_str) >= 0.3:
# chinese context
template = """
### 你将扮演一个乐于助人、尊重他人并诚实的助手,你的目标是帮助用户解答问题。有效地利用来自本地知识库的搜索结果。确保你的回答中只包含相关信息。如果你不确定问题的答案,请避免分享不准确的信息。
### 搜索结果:{context}
### 问题:{question}
### 回答:
"""
else:
template = """
### You are a helpful, respectful and honest assistant to help the user with questions. \
Please refer to the search results obtained from the local knowledge base. \
But be careful to not incorporate the information that you think is not relevant to the question. \
If you don't know the answer to a question, please don't share false information. \n
### Search results: {context} \n
### Question: {question} \n
### Answer:
"""
return template.format(context=context_str, question=question)

0 comments on commit 0f2c2b1

Please sign in to comment.