Commit 3c8cbea

fixed DTOs

dvarelas committed Jan 22, 2024
1 parent 6148fa2
Showing 8 changed files with 25 additions and 42 deletions.
4 changes: 0 additions & 4 deletions llm-gateway/llm_gateway/__init__.py
@@ -1,5 +1 @@
 __version__ = "0.1.0"
-from pathlib import Path
-from dotenv import load_dotenv
-
-load_dotenv(Path(".env"))
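Dropping the import-time side effect means importing llm_gateway no longer reads a .env file implicitly. The diff does not show where the loading moved, so purely as an assumption, a common alternative is to do it once at the application entrypoint:

# Hypothetical entrypoint sketch (not part of this commit): load the
# environment once at startup rather than on package import.
from pathlib import Path

from dotenv import load_dotenv
from fastapi import FastAPI

load_dotenv(Path(".env"))  # explicit, runs exactly once at startup

app = FastAPI()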
17 changes: 4 additions & 13 deletions llm-gateway/llm_gateway/api/v1/endpoints/chat.py
@@ -1,5 +1,4 @@
 from fastapi import APIRouter
-from fastapi.responses import JSONResponse
 
 from llm_gateway.schemas.chat import CreateCompletionDTO, RoleItem, MessageItem, ResponseCompletionDTO
 from llm_gateway.services import chat
@@ -18,19 +17,11 @@ async def create(request: CreateCompletionDTO):
                 role=RoleItem(message.role),
                 content=message.content
             ) for message in request.messages
-        ]
+        ],
+        temperature=request.temperature,
+        max_tokens=request.max_tokens
     )
 
     response = await chat.completions(parsed_message)
 
-    response_dto = {
-        "id": response["id"],
-        "provider": request["provider"],
-        "object": response["object"],
-        "created": response["created"],
-        "model": response["model"],
-        "choices": response["choices"],
-        "usage": response["usage"]
-    }
-
-    return JSONResponse(content=response_dto)
+    return response
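With the hand-built response_dto gone, the endpoint returns the DTO produced by the service layer and lets FastAPI handle serialization. A minimal sketch of that pattern, with a stand-in model and a hypothetical route path, since the full definitions are not shown in this diff:

# Sketch only: CompletionOut stands in for ResponseCompletionDTO.
# FastAPI serializes any Pydantic model returned from a path operation,
# so neither a manual dict nor JSONResponse is needed.
from fastapi import APIRouter
from pydantic import BaseModel

router = APIRouter()


class CompletionOut(BaseModel):
    id: str
    provider: str
    model: str


@router.post("/completions", response_model=CompletionOut)
async def create() -> CompletionOut:
    # In the gateway, this DTO comes back from chat.completions(parsed_message).
    return CompletionOut(id="cmpl-123", provider="openai", model="gpt-4")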
11 changes: 1 addition & 10 deletions llm-gateway/llm_gateway/api/v1/endpoints/embeddings.py
@@ -1,5 +1,4 @@
 from fastapi import APIRouter
-from fastapi.responses import JSONResponse
 
 from llm_gateway.schemas.embeddings import CreateEmbeddingDTO, ResponseEmbeddingDTO
 from llm_gateway.services import embeddings
@@ -18,12 +17,4 @@ async def create(request: CreateEmbeddingDTO):
 
     response = await embeddings.embeddings(parsed_message)
 
-    response_dto = {
-        "data": response["data"],
-        "provider": request["provider"],
-        "model": response["model"],
-        "object": response["object"],
-        "usage": response["usage"]
-    }
-
-    return JSONResponse(content=response_dto)
+    return response
9 changes: 1 addition & 8 deletions llm-gateway/llm_gateway/api/v1/endpoints/images.py
@@ -1,5 +1,4 @@
 from fastapi import APIRouter
-from fastapi.responses import JSONResponse
 
 from llm_gateway.services import images
 from llm_gateway.schemas.images import CreateImageDTO, ResponseImageDTO
@@ -20,10 +19,4 @@ async def create(request: CreateImageDTO):
 
     response = await images.generations(parsed_message)
 
-    response_dto = {
-        "provider": request["provider"],
-        "created": response["created"],
-        "data": response["data"]
-    }
-
-    return JSONResponse(content=response_dto)
+    return response
17 changes: 13 additions & 4 deletions llm-gateway/llm_gateway/providers/openai/client.py
@@ -18,11 +18,15 @@ async def completions(self, message: CreateCompletionDTO):
         request = {
             "model": message.model,
             "messages": message.messages,
-            "max_tokens": 10
+            "temperature": message.temperature,
+            "max_tokens": message.max_tokens,
+            "stream": False
         }
         try:
             response = self.client.chat.completions.create(**request)
-            return ResponseCompletionDTO(**response.to_dict())
+            response_dict = response.model_dump()
+            response_dict["provider"] = message.provider
+            return ResponseCompletionDTO(**response_dict)
         except Exception as e:
             logger.error(e)
             raise InternalServerError()
@@ -34,7 +38,9 @@ async def embeddings(self, embedding_input: CreateEmbeddingDTO):
         }
         try:
             response = self.client.embeddings.create(**request)
-            return ResponseEmbeddingDTO(**response.to_dict())
+            response_dict = response.model_dump()
+            response_dict["provider"] = embedding_input.provider
+            return ResponseEmbeddingDTO(**response_dict)
         except Exception as e:
             logger.error(e)
             raise InternalServerError()
@@ -48,7 +54,10 @@ async def generations(self, image_input: CreateImageDTO):
         }
         try:
             response = self.client.images.generate(**request)
-            return ResponseImageDTO(**response.to_dict())
+            response_dict = response.model_dump()
+            response_dict["provider"] = image_input.provider
+            response_dict["model"] = image_input.model
+            return ResponseImageDTO(**response_dict)
         except Exception as e:
             logger.error(e)
             raise InternalServerError()
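The three client methods now share one pattern: dump the SDK response model to a dict with model_dump() (openai-python v1 responses are Pydantic models), inject the gateway-only fields, and re-validate into the outgoing DTO. A self-contained sketch with stand-in models, assuming ResponseCompletionDTO mirrors the upstream response shape plus the extra "provider" field:

# Sketch of the enrich-then-revalidate step, with stand-in models.
from pydantic import BaseModel


class UpstreamResponse(BaseModel):  # stand-in for the SDK response model
    id: str
    model: str


class ResponseDTO(BaseModel):       # stand-in for ResponseCompletionDTO
    id: str
    model: str
    provider: str                   # gateway-only field, absent upstream


upstream = UpstreamResponse(id="cmpl-123", model="gpt-4")
response_dict = upstream.model_dump()  # plain dict of the upstream fields
response_dict["provider"] = "openai"   # inject the gateway-specific key
print(ResponseDTO(**response_dict))    # id='cmpl-123' model='gpt-4' provider='openai'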
3 changes: 3 additions & 0 deletions llm-gateway/llm_gateway/schemas/chat.py
@@ -8,6 +8,7 @@
 class RoleItem(str, Enum):
     SYSTEM = "system"
     USER = "user"
+    ASSISTANT = "assistant"
 
 
 class MessageItem(BaseModel):
@@ -19,6 +20,8 @@ class CreateCompletionDTO(BaseModel):
     provider: str
     model: str
     messages: List[MessageItem]
+    temperature: float
+    max_tokens: int
 
 
 class ChoiceItem(BaseModel):
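Because temperature and max_tokens are declared without defaults, they are now required on every request. A quick validation sketch; field names come from the diff, MessageItem's fields are inferred from the endpoint code above, and the payload values are illustrative:

# Sketch: validation behaviour of the extended CreateCompletionDTO.
from enum import Enum
from typing import List

from pydantic import BaseModel, ValidationError


class RoleItem(str, Enum):
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"


class MessageItem(BaseModel):
    role: RoleItem
    content: str


class CreateCompletionDTO(BaseModel):
    provider: str
    model: str
    messages: List[MessageItem]
    temperature: float
    max_tokens: int


payload = {
    "provider": "openai",
    "model": "gpt-4",
    "messages": [{"role": "user", "content": "hello"}],
    "temperature": 0.7,
    "max_tokens": 256,
}
print(CreateCompletionDTO(**payload))  # validates cleanly

try:
    CreateCompletionDTO(**{k: v for k, v in payload.items() if k != "temperature"})
except ValidationError:
    # Declared as plain `float`, temperature is mandatory; a default
    # (e.g. temperature: float = 1.0) would keep it optional.
    print("temperature is required")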
5 changes: 2 additions & 3 deletions llm-gateway/llm_gateway/schemas/embeddings.py
@@ -17,10 +17,9 @@ class EmbeddingItem(BaseModel):
 
 
 class DataItem(BaseModel):
-    data: List[EmbeddingItem]
-    model: str
     object: str
-    usage: Usage
+    embedding: List[float]
+    index: int
 
 
 class ResponseEmbeddingDTO(BaseModel):
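The corrected DataItem now matches one entry of the "data" array in OpenAI's embeddings response (object/embedding/index) instead of duplicating the top-level response fields. A small sketch; the surrounding ResponseEmbeddingDTO is not fully shown in the diff:

# Sketch: DataItem as one entry of the embeddings response's "data" array.
from typing import List

from pydantic import BaseModel


class DataItem(BaseModel):
    object: str              # "embedding" for entries of this type
    embedding: List[float]   # the vector itself
    index: int               # which input this vector belongs to


item = DataItem(object="embedding", embedding=[0.01, -0.02, 0.03], index=0)
print(len(item.embedding), "dims for input", item.index)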
1 change: 1 addition & 0 deletions llm-gateway/llm_gateway/schemas/images.py
@@ -16,5 +16,6 @@ class DataItem(BaseModel):
 
 class ResponseImageDTO(BaseModel):
     provider: str
+    model: str
     created: int
     data: List[DataItem]
