Align parameters for "max_tokens, repetition_penalty, presence_penalty, frequency_penalty" (#608)

* align max_tokens

Signed-off-by: Xinyao Wang <[email protected]>

* align repetition_penalty

Signed-off-by: Xinyao Wang <[email protected]>

* fix bug

Signed-off-by: Xinyao Wang <[email protected]>

* align penalty parameters

Signed-off-by: Xinyao Wang <[email protected]>

* fix bug

Signed-off-by: Xinyao Wang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix bug

Signed-off-by: Xinyao Wang <[email protected]>

* fix bug

Signed-off-by: Xinyao Wang <[email protected]>

* fix bug

Signed-off-by: Xinyao Wang <[email protected]>

* fix bug

Signed-off-by: Xinyao Wang <[email protected]>

* align max_tokens

Signed-off-by: Xinyao Wang <[email protected]>

* fix bug

Signed-off-by: Xinyao Wang <[email protected]>

* fix bug

Signed-off-by: Xinyao Wang <[email protected]>

* debug

Signed-off-by: Xinyao Wang <[email protected]>

* debug

Signed-off-by: Xinyao Wang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix langchain version bug

Signed-off-by: Xinyao Wang <[email protected]>

* fix langchain version bug

Signed-off-by: Xinyao Wang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Xinyao Wang <[email protected]>
Co-authored-by: kevinintel <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: lvliang-intel <[email protected]>
4 people authored Sep 18, 2024
1 parent 00227b8 commit 3a31295
Showing 30 changed files with 107 additions and 57 deletions.
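
The pattern repeated throughout the gateway hunks below is the heart of the change: OpenAI-style request fields are forwarded under their own names, with `max_tokens` replacing `max_new_tokens` and `frequency_penalty`, `presence_penalty`, and `repetition_penalty` passed individually instead of `presence_penalty` being folded into `repetition_penalty`. As a rough sketch of the resulting mapping (field names, defaults, and import paths are taken from the diffs below; the helper function itself is illustrative, not code from the repository):

```python
# Illustrative sketch -- mirrors the parameter mapping repeated in comps/cores/mega/gateway.py below.
from comps.cores.proto.api_protocol import ChatCompletionRequest  # paths as in the diffed files
from comps.cores.proto.docarray import LLMParams


def to_llm_params(chat_request: ChatCompletionRequest, stream_opt: bool = True) -> LLMParams:
    """Map an OpenAI-style chat request onto the aligned LLMParams fields."""
    return LLMParams(
        max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
        top_k=chat_request.top_k if chat_request.top_k else 10,
        top_p=chat_request.top_p if chat_request.top_p else 0.95,
        temperature=chat_request.temperature if chat_request.temperature else 0.01,
        frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
        presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
        repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
        streaming=stream_opt,
    )
```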
48 changes: 33 additions & 15 deletions comps/cores/mega/gateway.py
@@ -160,11 +160,13 @@ async def handle_request(self, request: Request):
chat_request = ChatCompletionRequest.parse_obj(data)
prompt = self._handle_message(chat_request.messages)
parameters = LLMParams(
- max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
+ max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
- repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
+ frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
+ presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
+ repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
chat_template=chat_request.chat_template if chat_request.chat_template else None,
)
@@ -214,11 +216,13 @@ async def handle_request(self, request: Request):
chat_request = ChatCompletionRequest.parse_obj(data)
prompt = self._handle_message(chat_request.messages)
parameters = LLMParams(
- max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
+ max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
- repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
+ frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
+ presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
+ repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
)
result_dict, runtime_graph = await self.megaservice.schedule(
@@ -350,11 +354,13 @@ async def handle_request(self, request: Request):
chat_request = ChatCompletionRequest.parse_obj(data)
prompt = self._handle_message(chat_request.messages)
parameters = LLMParams(
- max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
+ max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
- repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
+ frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
+ presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
+ repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
)
result_dict, runtime_graph = await self.megaservice.schedule(
@@ -399,11 +405,13 @@ async def handle_request(self, request: Request):
chat_request = AudioChatCompletionRequest.parse_obj(data)
parameters = LLMParams(
# relatively lower max_tokens for audio conversation
- max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 128,
+ max_tokens=chat_request.max_tokens if chat_request.max_tokens else 128,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
- repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
+ frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
+ presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
+ repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=False, # TODO add streaming LLM output as input to TTS
)
result_dict, runtime_graph = await self.megaservice.schedule(
@@ -428,11 +436,13 @@ async def handle_request(self, request: Request):
chat_request = ChatCompletionRequest.parse_obj(data)
prompt = self._handle_message(chat_request.messages)
parameters = LLMParams(
- max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
+ max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
- repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
+ frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
+ presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
+ repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
)
result_dict, runtime_graph = await self.megaservice.schedule(
@@ -472,11 +482,13 @@ async def handle_request(self, request: Request):
chat_request = ChatCompletionRequest.parse_obj(data)
prompt = self._handle_message(chat_request.messages)
parameters = LLMParams(
- max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
+ max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
- repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
+ frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
+ presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
+ repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
)
result_dict, runtime_graph = await self.megaservice.schedule(
@@ -520,7 +532,9 @@ async def handle_request(self, request: Request):
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
- repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
+ frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
+ presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
+ repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
)
result_dict, runtime_graph = await self.megaservice.schedule(
@@ -569,7 +583,9 @@ async def handle_request(self, request: Request):
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
- repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
+ frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
+ presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
+ repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
)
result_dict, runtime_graph = await self.megaservice.schedule(
@@ -758,7 +774,9 @@ async def handle_request(self, request: Request):
top_k=chat_request.top_k if chat_request.top_k else 10,
top_p=chat_request.top_p if chat_request.top_p else 0.95,
temperature=chat_request.temperature if chat_request.temperature else 0.01,
- repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
+ frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
+ presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
+ repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
streaming=stream_opt,
chat_template=chat_request.chat_template if chat_request.chat_template else None,
)
4 changes: 3 additions & 1 deletion comps/cores/proto/api_protocol.py
@@ -285,8 +285,9 @@ class AudioChatCompletionRequest(BaseModel):
max_tokens: Optional[int] = 1024
stop: Optional[Union[str, List[str]]] = None
stream: Optional[bool] = False
- presence_penalty: Optional[float] = 1.03
+ presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
+ repetition_penalty: Optional[float] = 1.03
user: Optional[str] = None


@@ -345,6 +346,7 @@ class CompletionRequest(BaseModel):
echo: Optional[bool] = False
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
+ repetition_penalty: Optional[float] = 1.03
user: Optional[str] = None
use_beam_search: Optional[bool] = False
best_of: Optional[int] = None
6 changes: 6 additions & 0 deletions comps/cores/proto/docarray.py
@@ -145,11 +145,14 @@ class RerankedDoc(BaseDoc):
class LLMParamsDoc(BaseDoc):
model: Optional[str] = None # for openai and ollama
query: str
+ max_tokens: int = 1024
max_new_tokens: int = 1024
top_k: int = 10
top_p: float = 0.95
typical_p: float = 0.95
temperature: float = 0.01
+ frequency_penalty: float = 0.0
+ presence_penalty: float = 0.0
repetition_penalty: float = 1.03
streaming: bool = True

@@ -179,11 +182,14 @@ def chat_template_must_contain_variables(cls, v):


class LLMParams(BaseDoc):
+ max_tokens: int = 1024
max_new_tokens: int = 1024
top_k: int = 10
top_p: float = 0.95
typical_p: float = 0.95
temperature: float = 0.01
+ frequency_penalty: float = 0.0
+ presence_penalty: float = 0.0
repetition_penalty: float = 1.03
streaming: bool = True

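
With the fields above in place, the schema keeps the legacy `max_new_tokens` alongside the new OpenAI-aligned `max_tokens`, so existing consumers keep working while new callers use the aligned names. A small usage sketch (field names and defaults as shown in the diff above; the snippet is illustrative, not repository code):

```python
# Illustrative sketch -- exercises the LLMParamsDoc fields added in comps/cores/proto/docarray.py above.
from comps.cores.proto.docarray import LLMParamsDoc

doc = LLMParamsDoc(
    query="What is Deep Learning?",
    max_tokens=17,            # OpenAI-aligned name introduced by this commit
    frequency_penalty=0.0,    # new OpenAI-style penalty, defaults to 0.0
    presence_penalty=0.0,     # new OpenAI-style penalty, defaults to 0.0
    repetition_penalty=1.03,  # TGI-style penalty, default unchanged
)
print(doc.max_tokens, doc.max_new_tokens)  # 17 1024 -- the legacy field keeps its own default
```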
2 changes: 1 addition & 1 deletion comps/llms/faq-generation/tgi/langchain/llm.py
@@ -40,7 +40,7 @@ def llm_generate(input: LLMParamsDoc):
llm_endpoint = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080")
llm = HuggingFaceEndpoint(
endpoint_url=llm_endpoint,
- max_new_tokens=input.max_new_tokens,
+ max_new_tokens=input.max_tokens,
top_k=input.top_k,
top_p=input.top_p,
typical_p=input.typical_p,
3 changes: 3 additions & 0 deletions comps/llms/faq-generation/tgi/langchain/requirements.txt
@@ -2,7 +2,10 @@ docarray[full]
fastapi
huggingface_hub
langchain
+ langchain-huggingface
+ langchain-openai
+ langchain_community
langchainhub
opentelemetry-api
opentelemetry-exporter-otlp
opentelemetry-sdk
2 changes: 1 addition & 1 deletion comps/llms/summarization/tgi/langchain/llm.py
@@ -39,7 +39,7 @@ def llm_generate(input: LLMParamsDoc):
llm_endpoint = os.getenv("TGI_LLM_ENDPOINT", "http://localhost:8080")
llm = HuggingFaceEndpoint(
endpoint_url=llm_endpoint,
- max_new_tokens=input.max_new_tokens,
+ max_new_tokens=input.max_tokens,
top_k=input.top_k,
top_p=input.top_p,
typical_p=input.typical_p,
6 changes: 3 additions & 3 deletions comps/llms/text-generation/README.md
@@ -374,7 +374,7 @@ curl http://${your_ip}:8008/v1/chat/completions \

### 3.3 Consume LLM Service

- You can set the following model parameters according to your actual needs, such as `max_new_tokens`, `streaming`.
+ You can set the following model parameters according to your actual needs, such as `max_tokens`, `streaming`.

The `streaming` parameter determines the format of the data returned by the API. It will return text string with `streaming=false`, return text streaming flow with `streaming=true`.

@@ -385,7 +385,7 @@ curl http://${your_ip}:9000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"query":"What is Deep Learning?",
- "max_new_tokens":17,
+ "max_tokens":17,
"top_k":10,
"top_p":0.95,
"typical_p":0.95,
@@ -401,7 +401,7 @@ curl http://${your_ip}:9000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"query":"What is Deep Learning?",
- "max_new_tokens":17,
+ "max_tokens":17,
"top_k":10,
"top_p":0.95,
"typical_p":0.95,
2 changes: 1 addition & 1 deletion comps/llms/text-generation/ollama/langchain/README.md
@@ -70,5 +70,5 @@ docker run --network host -e http_proxy=$http_proxy -e https_proxy=$https_proxy
## Consume the Ollama Microservice

```bash
- curl http://127.0.0.1:9000/v1/chat/completions -X POST -d '{"model": "llama3", "query":"What is Deep Learning?","max_new_tokens":32,"top_k":10,"top_p":0.95,"typical_p":0.95,"temperature":0.01,"repetition_penalty":1.03,"streaming":true}' -H 'Content-Type: application/json'
+ curl http://127.0.0.1:9000/v1/chat/completions -X POST -d '{"model": "llama3", "query":"What is Deep Learning?","max_tokens":32,"top_k":10,"top_p":0.95,"typical_p":0.95,"temperature":0.01,"repetition_penalty":1.03,"streaming":true}' -H 'Content-Type: application/json'
```
2 changes: 1 addition & 1 deletion comps/llms/text-generation/ollama/langchain/llm.py
@@ -25,7 +25,7 @@ def llm_generate(input: LLMParamsDoc):
ollama = Ollama(
base_url=ollama_endpoint,
model=input.model if input.model else model_name,
- num_predict=input.max_new_tokens,
+ num_predict=input.max_tokens,
top_k=input.top_k,
top_p=input.top_p,
temperature=input.temperature,
4 changes: 2 additions & 2 deletions comps/llms/text-generation/predictionguard/README.md
@@ -29,7 +29,7 @@ curl -X POST http://localhost:9000/v1/chat/completions \
-d '{
"model": "Hermes-2-Pro-Llama-3-8B",
"query": "Tell me a joke.",
- "max_new_tokens": 100,
+ "max_tokens": 100,
"temperature": 0.7,
"top_p": 0.9,
"top_k": 50,
@@ -45,7 +45,7 @@ curl -N -X POST http://localhost:9000/v1/chat/completions \
-d '{
"model": "Hermes-2-Pro-Llama-3-8B",
"query": "Tell me a joke.",
- "max_new_tokens": 100,
+ "max_tokens": 100,
"temperature": 0.7,
"top_p": 0.9,
"top_k": 50,
@@ -49,7 +49,7 @@ async def stream_generator():
for res in client.chat.completions.create(
model=input.model,
messages=messages,
- max_tokens=input.max_new_tokens,
+ max_tokens=input.max_tokens,
temperature=input.temperature,
top_p=input.top_p,
top_k=input.top_k,
@@ -69,7 +69,7 @@ async def stream_generator():
response = client.chat.completions.create(
model=input.model,
messages=messages,
- max_tokens=input.max_new_tokens,
+ max_tokens=input.max_tokens,
temperature=input.temperature,
top_p=input.top_p,
top_k=input.top_k,
2 changes: 1 addition & 1 deletion comps/llms/text-generation/ray_serve/llm.py
@@ -47,7 +47,7 @@ def llm_generate(input: LLMParamsDoc):
openai_api_base=llm_endpoint + "/v1",
model_name=llm_model,
openai_api_key=os.getenv("OPENAI_API_KEY", "not_needed"),
- max_tokens=input.max_new_tokens,
+ max_tokens=input.max_tokens,
temperature=input.temperature,
streaming=input.streaming,
request_timeout=600,
20 changes: 13 additions & 7 deletions comps/llms/text-generation/tgi/README.md
@@ -88,36 +88,42 @@ curl http://${your_ip}:9000/v1/health_check\

### 3.2 Consume LLM Service

- You can set the following model parameters according to your actual needs, such as `max_new_tokens`, `streaming`.
+ You can set the following model parameters according to your actual needs, such as `max_tokens`, `streaming`.

The `streaming` parameter determines the format of the data returned by the API. It will return text string with `streaming=false`, return text streaming flow with `streaming=true`.

```bash
# non-streaming mode
curl http://${your_ip}:9000/v1/chat/completions \
-X POST \
- -d '{"query":"What is Deep Learning?","max_new_tokens":17,"top_k":10,"top_p":0.95,"typical_p":0.95,"temperature":0.01,"repetition_penalty":1.03,"streaming":false}' \
+ -d '{"query":"What is Deep Learning?","max_tokens":17,"top_k":10,"top_p":0.95,"typical_p":0.95,"temperature":0.01,"repetition_penalty":1.03,"streaming":false}' \
-H 'Content-Type: application/json'

# streaming mode
curl http://${your_ip}:9000/v1/chat/completions \
-X POST \
- -d '{"query":"What is Deep Learning?","max_new_tokens":17,"top_k":10,"top_p":0.95,"typical_p":0.95,"temperature":0.01,"repetition_penalty":1.03,"streaming":true}' \
+ -d '{"query":"What is Deep Learning?","max_tokens":17,"top_k":10,"top_p":0.95,"typical_p":0.95,"temperature":0.01,"repetition_penalty":1.03,"streaming":true}' \
-H 'Content-Type: application/json'

- # custom chat template
+ # consume with SearchedDoc
curl http://${your_ip}:9000/v1/chat/completions \
-X POST \
- -d '{"query":"What is Deep Learning?","max_new_tokens":17,"top_k":10,"top_p":0.95,"typical_p":0.95,"temperature":0.01,"repetition_penalty":1.03,"streaming":true, "chat_template":"### You are a helpful, respectful and honest assistant to help the user with questions.\n### Context: {context}\n### Question: {question}\n### Answer:"}' \
+ -d '{"initial_query":"What is Deep Learning?","retrieved_docs":[{"text":"Deep Learning is a ..."},{"text":"Deep Learning is b ..."}]}' \
-H 'Content-Type: application/json'
```

- # consume with SearchedDoc
+ For parameters in above modes, please refer to [HuggingFace InferenceClient API](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation) (except we rename 'max_new_tokens' to 'max_tokens')

```bash
+ # custom chat template
curl http://${your_ip}:9000/v1/chat/completions \
-X POST \
- -d '{"initial_query":"What is Deep Learning?","retrieved_docs":[{"text":"Deep Learning is a ..."},{"text":"Deep Learning is b ..."}]}' \
+ -d '{"query":"What is Deep Learning?","max_tokens":17,"top_k":10,"top_p":0.95,"typical_p":0.95,"temperature":0.01,"presence_penalty":1.03", frequency_penalty":0.0, "streaming":true, "chat_template":"### You are a helpful, respectful and honest assistant to help the user with questions.\n### Context: {context}\n### Question: {question}\n### Answer:"}' \
-H 'Content-Type: application/json'
```

+ For parameters in Chat mode, please refer to [OpenAI API](https://platform.openai.com/docs/api-reference/chat/create)

### 4. Validated Model

| Model | TGI |
4 changes: 2 additions & 2 deletions comps/llms/text-generation/tgi/llm.py
@@ -69,7 +69,7 @@ async def llm_generate(input: Union[LLMParamsDoc, ChatCompletionRequest, Searche
text_generation = await llm.text_generation(
prompt=prompt,
stream=new_input.streaming,
- max_new_tokens=new_input.max_new_tokens,
+ max_new_tokens=new_input.max_tokens,
repetition_penalty=new_input.repetition_penalty,
temperature=new_input.temperature,
top_k=new_input.top_k,
@@ -119,7 +119,7 @@ async def stream_generator():
text_generation = await llm.text_generation(
prompt=prompt,
stream=input.streaming,
- max_new_tokens=input.max_new_tokens,
+ max_new_tokens=input.max_tokens,
repetition_penalty=input.repetition_penalty,
temperature=input.temperature,
top_k=input.top_k,