# main.py: FastAPI service exposing the Justicio RAG endpoints.
import asyncio
import ipaddress
import logging as lg
import os
import time
import typing as tp
import uuid

import httpx
from fastapi import FastAPI
from langtrace_python_sdk import SendUserFeedback, langtrace
from langtrace_python_sdk.utils.with_root_span import with_langtrace_root_span

from src.initialize import initialize_app, initialize_logging
from src.utils import inject_additional_attributes, timeit

# Observability must be initialized before any instrumented call is made.
langtrace.init(api_key=os.environ.get("LANGTRACE_API_KEY"))
initialize_logging()

APP = FastAPI()
INIT_OBJECTS = initialize_app()

# Default demo query (Spanish): "Does the law on the comprehensive guarantee of
# sexual freedom apply to male children who are victims of sexual violence, or
# only to girls and women?"
DEFAULT_INPUT_QUERY = (
    "¿Es de aplicación la ley de garantía integral de la libertad sexual a niños (varones) menores de edad "
    "víctimas de violencias sexuales o solo a niñas y mujeres?"
)
DEFAULT_COLLECTION_NAME = "justicio"

@with_langtrace_root_span()
async def call_llm_api(span_id, trace_id, model_name: str, messages: tp.List[tp.Dict[str, str]]):
    """Send `messages` to the chat completions API.

    `span_id` and `trace_id` are injected by the `with_langtrace_root_span`
    decorator and returned so callers can attach feedback to this trace.
    """
    response = await INIT_OBJECTS.openai_client.chat.completions.create(
        model=model_name,
        messages=messages,
        temperature=INIT_OBJECTS.config_loader["temperature"],
        seed=INIT_OBJECTS.config_loader["seed"],
        max_tokens=INIT_OBJECTS.config_loader["max_tokens"],
    )
    return response, span_id, trace_id

@APP.get("/healthcheck")
@timeit
async def healthcheck():
"""Asynchronous Health Check"""
# TODO: healthcheck with embeddings db api and llm api
return {"status": "OK"}
@APP.get("/semantic_search")
@timeit
async def semantic_search(input_query: str = DEFAULT_INPUT_QUERY, collection_name: str = DEFAULT_COLLECTION_NAME):
logger = lg.getLogger(semantic_search.__name__)
logger.info(input_query)
docs = await INIT_OBJECTS.vector_store[collection_name].asimilarity_search_with_score(
query=input_query, k=INIT_OBJECTS.config_loader["top_k_results"]
)
logger.info(docs)
return docs
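
# Example request (hypothetical host/port):
#   GET http://localhost:8000/semantic_search?input_query=...&collection_name=justicio
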
@APP.get("/semantic_search_tavily")
@timeit
async def semantic_search_tavily(input_query: str = DEFAULT_INPUT_QUERY):
logger = lg.getLogger(semantic_search_tavily.__name__)
logger.info(input_query)
docs = INIT_OBJECTS.tavily_client.search(
query=input_query,
search_depth="advanced",
include_domains=["https://www.boe.es/"],
max_results=10,
topic="general",
include_raw_content=False,
include_answer=False,
)
logger.info(docs)
return docs
async def a_request_get(url):
    """Requests for sync/async load tests"""
    async with httpx.AsyncClient(timeout=10.0) as client:
        response = await client.get(url)
        return response.text
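
# Not referenced by any endpoint here; presumably invoked from load-test code, e.g.:
#   asyncio.run(a_request_get("https://www.boe.es/"))
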
@APP.get("/qa_feedback")
@with_langtrace_root_span("Feedback")
@timeit
async def qa_feedback(span_id: str, trace_id: str, user_score: int):
data = {
"spanId": span_id, "traceId": trace_id, "userScore": user_score, "userId": None
}
SendUserFeedback().evaluate(data=data)
return {"feedback": "OK"}
@APP.get("/qa")
@with_langtrace_root_span("RAG Justicio")
@timeit
async def qa(
input_query: str = DEFAULT_INPUT_QUERY,
collection_name: str = DEFAULT_COLLECTION_NAME,
model_name: str = INIT_OBJECTS.config_loader["llm_model_name"],
input_original_query: str | None = None,
ip_request_client: ipaddress.IPv4Address | None = None,
):
logger = lg.getLogger(qa.__name__)
logger.info(input_query)
# Getting context from embedding database (Qdrant)
docs = await INIT_OBJECTS.vector_store[collection_name].asimilarity_search_with_score(
query=input_query, k=INIT_OBJECTS.config_loader["top_k_results"]
)
# Generate response using a LLM (OpenAI)
context_preprocessed = [{"context": doc[0].page_content, "score": doc[1]} for doc in docs]
messages = [
{"role": "system", "content": INIT_OBJECTS.config_loader["prompt_system"]},
{
"role": "system",
"content": INIT_OBJECTS.config_loader["prompt_system_context"],
},
{"role": "system", "content": "A continuación se proporciona el contexto:"},
{"role": "system", "content": str(context_preprocessed)},
{
"role": "system",
"content": "A continuación se proporciona la pregunta del usuario:",
},
{"role": "user", "content": input_query},
]
# logger.info(messages)
additional_attributes = {
"db.collection.name": collection_name,
"service.ip": ip_request_client,
"llm.original_query": input_original_query
}
response, span_id, trace_id = await inject_additional_attributes(
lambda: call_llm_api(model_name=model_name, messages=messages), additional_attributes
)
answer = response.choices[0].message.content
logger.info(answer)
logger.info(response.usage)
response_payload = dict(
scoring_id=str(uuid.uuid4()),
context=docs,
answer=answer,
span_id=str(span_id),
trace_id=str(trace_id),
)
return response_payload
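
# Note: the span_id / trace_id in the payload above are the values that
# /qa_feedback expects, so a client can score this specific answer.
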
@APP.get("/qa_tavily")
@timeit
async def qa_tavily(input_query: str = DEFAULT_INPUT_QUERY):
logger = lg.getLogger(qa_tavily.__name__)
logger.info(input_query)
# Getting context from internet browser (Tavily)
docs = INIT_OBJECTS.tavily_client.search(
query=input_query,
search_depth="advanced",
include_domains=["https://www.boe.es/"],
max_results=10,
topic="general",
include_raw_content=False,
include_answer=False,
)
# Generate response using a LLM (OpenAI)
context_preprocessed = [{"context": doc["content"], "score": doc["score"]} for doc in docs["results"]]
response = await INIT_OBJECTS.openai_client.chat.completions.create(
model=INIT_OBJECTS.config_loader["llm_model_name"],
messages=[
{"role": "system", "content": INIT_OBJECTS.config_loader["prompt_system"]},
{
"role": "system",
"content": INIT_OBJECTS.config_loader["prompt_system_context"],
},
{"role": "system", "content": "A continuación se proporciona el contexto:"},
{"role": "system", "content": str(context_preprocessed)},
{
"role": "system",
"content": "A continuación se proporciona la pregunta del usuario:",
},
{"role": "user", "content": input_query},
],
temperature=INIT_OBJECTS.config_loader["temperature"],
seed=INIT_OBJECTS.config_loader["seed"],
max_tokens=INIT_OBJECTS.config_loader["max_tokens"],
)
answer = response.choices[0].message.content
logger.info(answer)
logger.info(response.usage)
response_payload = dict(
scoring_id=str(uuid.uuid4()),
context=docs,
answer=answer,
)
return response_payload
@APP.get("/sleep")
@timeit
async def sleep():
time.sleep(5)
return {"status": "OK"}
@APP.get("/asleep")
@timeit
async def asleep():
await asyncio.sleep(5)
return {"status": "OK"}