Skip to content

Commit

Permalink
Update RetrievalTool Gateway to allow user setting retriever/reranker…
Browse files Browse the repository at this point in the history
… params (#730)

* first code of updated retrieval gateway to allow setting retriever/reranker params

Signed-off-by: minmin-intel <[email protected]>

* send chatcompletion request to allow topk params

Signed-off-by: minmin-intel <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* allow multiple input types

Signed-off-by: minmin-intel <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: minmin-intel <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
minmin-intel and pre-commit-ci[bot] committed Sep 26, 2024
1 parent 2f08e5e commit c7c45d6
Showing 1 changed file with 42 additions and 7 deletions.
49 changes: 42 additions & 7 deletions comps/cores/mega/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,31 +622,66 @@ def __init__(self, megaservice, host="0.0.0.0", port=8889):
host,
port,
str(MegaServiceEndpoint.RETRIEVALTOOL),
Union[TextDoc, EmbeddingRequest, ChatCompletionRequest], # ChatCompletionRequest,
Union[RerankedDoc, LLMParamsDoc], # ChatCompletionResponse
Union[TextDoc, EmbeddingRequest, ChatCompletionRequest],
Union[RerankedDoc, LLMParamsDoc],
)

async def handle_request(self, request: Request):
def parser_input(data, TypeClass, key):
chat_request = None
try:
chat_request = TypeClass.parse_obj(data)
query = getattr(chat_request, key)
except:
query = None
return query
return query, chat_request

data = await request.json()
query = None
for key, TypeClass in zip(["text", "input", "input"], [TextDoc, EmbeddingRequest, ChatCompletionRequest]):
query = parser_input(data, TypeClass, key)
for key, TypeClass in zip(["text", "input", "messages"], [TextDoc, EmbeddingRequest, ChatCompletionRequest]):
query, chat_request = parser_input(data, TypeClass, key)
if query is not None:
break
if query is None:
raise ValueError(f"Unknown request type: {data}")
result_dict, runtime_graph = await self.megaservice.schedule(initial_inputs={"text": query})
if chat_request is None:
raise ValueError(f"Unknown request type: {data}")

if isinstance(chat_request, ChatCompletionRequest):
retriever_parameters = RetrieverParms(
search_type=chat_request.search_type if chat_request.search_type else "similarity",
k=chat_request.k if chat_request.k else 4,
distance_threshold=chat_request.distance_threshold if chat_request.distance_threshold else None,
fetch_k=chat_request.fetch_k if chat_request.fetch_k else 20,
lambda_mult=chat_request.lambda_mult if chat_request.lambda_mult else 0.5,
score_threshold=chat_request.score_threshold if chat_request.score_threshold else 0.2,
)
reranker_parameters = RerankerParms(
top_n=chat_request.top_n if chat_request.top_n else 1,
)

initial_inputs = {
"messages": query,
"input": query, # has to be input due to embedding expects either input or text
"search_type": chat_request.search_type if chat_request.search_type else "similarity",
"k": chat_request.k if chat_request.k else 4,
"distance_threshold": chat_request.distance_threshold if chat_request.distance_threshold else None,
"fetch_k": chat_request.fetch_k if chat_request.fetch_k else 20,
"lambda_mult": chat_request.lambda_mult if chat_request.lambda_mult else 0.5,
"score_threshold": chat_request.score_threshold if chat_request.score_threshold else 0.2,
"top_n": chat_request.top_n if chat_request.top_n else 1,
}

result_dict, runtime_graph = await self.megaservice.schedule(
initial_inputs=initial_inputs,
retriever_parameters=retriever_parameters,
reranker_parameters=reranker_parameters,
)
else:
result_dict, runtime_graph = await self.megaservice.schedule(initial_inputs={"text": query})

last_node = runtime_graph.all_leaves()[-1]
response = result_dict[last_node]
print("response is ", response)
return response


Expand Down

0 comments on commit c7c45d6

Please sign in to comment.