Vertex AI client #510

Merged 6 commits on Sep 20, 2024
1 change: 0 additions & 1 deletion .env.example
@@ -8,7 +8,6 @@ LITELLM_POSTGRES_PASSWORD=<your_litellm_postgres_password>
LITELLM_MASTER_KEY=<your_litellm_master_key>
LITELLM_SALT_KEY=<your_litellm_salt_key>
LITELLM_REDIS_PASSWORD=<your_litellm_redis_password>
EMBEDDING_SERVICE_BASE=http://text-embeddings-inference-<gpu|cpu> # Use the 'gpu' profile to run on GPU

# Memory Store
# -----------
7 changes: 3 additions & 4 deletions agents-api/agents_api/activities/embed_docs.py
@@ -1,8 +1,7 @@
from beartype import beartype
from temporalio import activity

from ..clients import cozo
from ..clients import embed as embedder
from ..clients import cozo, litellm
from ..env import testing
from ..models.docs.embed_snippets import embed_snippets as embed_snippets_query
from .types import EmbedDocsPayload
@@ -14,8 +13,8 @@ async def embed_docs(payload: EmbedDocsPayload, cozo_client=None) -> None:
embed_instruction: str = payload.embed_instruction or ""
title: str = payload.title or ""

embeddings = await embedder.embed(
[
embeddings = await litellm.aembedding(
inputs=[
(
embed_instruction + (title + "\n\n" + snippet) if title else snippet
).strip()
28 changes: 0 additions & 28 deletions agents-api/agents_api/clients/embed.py

This file was deleted.
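
Taken together, the two changes above route document embedding through the shared LiteLLM client instead of the deleted embed client. As a small illustration of the strings embed_docs now sends for embedding, here is a minimal sketch; the formatting expression is copied from the hunk above, while the concrete payload values are made up:

embed_instruction = "Represent the document for retrieval: "
title = "Quick start"
snippets = ["Install the SDK.", "Create an agent."]

inputs = [
    (
        embed_instruction + (title + "\n\n" + snippet) if title else snippet
    ).strip()
    for snippet in snippets
]

# inputs[0] == "Represent the document for retrieval: Quick start\n\nInstall the SDK."
# inputs[1] == "Represent the document for retrieval: Quick start\n\nCreate an agent."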

63 changes: 59 additions & 4 deletions agents-api/agents_api/clients/litellm.py
@@ -1,21 +1,39 @@
from functools import wraps
from typing import List
from typing import List, Literal

import litellm
from beartype import beartype
from litellm import acompletion as _acompletion
from litellm import get_supported_openai_params
from litellm import (
acompletion as _acompletion,
)
from litellm import (
aembedding as _aembedding,
)
from litellm import (
get_supported_openai_params,
)
from litellm.utils import CustomStreamWrapper, ModelResponse

from ..env import litellm_master_key, litellm_url
from ..env import (
embedding_dimensions,
embedding_model_id,
litellm_master_key,
litellm_url,
)

__all__: List[str] = ["acompletion"]

# TODO: Should check if this is really needed
litellm.drop_params = True


@wraps(_acompletion)
@beartype
async def acompletion(
*, model: str, messages: list[dict], custom_api_key: None | str = None, **kwargs
) -> ModelResponse | CustomStreamWrapper:
if not custom_api_key:
model = f"openai/{model}" # FIXME: This is for litellm

supported_params = get_supported_openai_params(model)
settings = {k: v for k, v in kwargs.items() if k in supported_params}
@@ -27,3 +45,40 @@ async def acompletion(
base_url=None if custom_api_key else litellm_url,
api_key=custom_api_key or litellm_master_key,
)


@wraps(_aembedding)
@beartype
async def aembedding(
*,
inputs: str | list[str],
model: str = embedding_model_id,
dimensions: int = embedding_dimensions,
join_inputs: bool = False,
custom_api_key: None | str = None,
**settings,
) -> list[list[float]]:
if not custom_api_key:
model = f"openai/{model}" # FIXME: This is for litellm

if isinstance(inputs, str):
input = [inputs]

Review comment (Contributor): The variable input should be renamed to inputs to avoid confusion and to match the function parameter name.

else:
input = ["\n\n".join(inputs)] if join_inputs else inputs

response = await _aembedding(
model=model,
input=input,
# dimensions=dimensions, # FIXME: litellm doesn't support dimensions correctly
api_base=None if custom_api_key else litellm_url,
api_key=custom_api_key or litellm_master_key,
drop_params=True,
**settings,
)

embedding_list: list[dict[Literal["embedding"], list[float]]] = response.data

# FIXME: Truncation should be handled by litellm
result = [embedding["embedding"][:dimensions] for embedding in embedding_list]

return result
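
For context, a minimal usage sketch of the new wrapper; it assumes a running LiteLLM proxy reachable via the LITELLM_URL / LITELLM_MASTER_KEY settings below, and the snippet texts are illustrative only:

import asyncio

from agents_api.clients import litellm


async def main() -> None:
    # One vector per input string.
    vectors = await litellm.aembedding(
        inputs=["How do I create an agent?", "How do I delete an agent?"],
    )

    # Join several snippets into a single query and get one vector back.
    [query_embedding, *_] = await litellm.aembedding(
        inputs=["user: hello", "assistant: hi there"],
        join_inputs=True,
    )

    print(len(vectors), len(query_embedding))


asyncio.run(main())
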
10 changes: 4 additions & 6 deletions agents-api/agents_api/env.py
@@ -25,6 +25,7 @@
# -----
task_max_parallelism: int = env.int("AGENTS_API_TASK_MAX_PARALLELISM", default=100)


# Debug
# -----
debug: bool = env.bool("AGENTS_API_DEBUG", default=False)
@@ -51,6 +52,7 @@

api_key_header_name: str = env.str("AGENTS_API_KEY_HEADER_NAME", default="X-Auth-Key")


# Litellm API
# -----------
litellm_url: str = env.str("LITELLM_URL", default="http://0.0.0.0:4000")
@@ -59,13 +61,11 @@

# Embedding service
# -----------------
embedding_service_base: str = env.str(
"EMBEDDING_SERVICE_BASE", default="http://0.0.0.0:8082"
)
embedding_model_id: str = env.str(
"EMBEDDING_MODEL_ID", default="Alibaba-NLP/gte-large-en-v1.5"
)
truncate_embed_text: bool = env.bool("TRUNCATE_EMBED_TEXT", default=True)

embedding_dimensions: int = env.int("EMBEDDING_DIMENSIONS", default=1024)


# Temporal
@@ -91,8 +91,6 @@
api_key_header_name=api_key_header_name,
hostname=hostname,
api_prefix=api_prefix,
embedding_service_base=embedding_service_base,
truncate_embed_text=truncate_embed_text,
temporal_worker_url=temporal_worker_url,
temporal_namespace=temporal_namespace,
embedding_model_id=embedding_model_id,
15 changes: 8 additions & 7 deletions agents-api/agents_api/models/chat/gather_messages.py
@@ -9,7 +9,7 @@
from agents_api.autogen.Chat import ChatInput

from ...autogen.openapi_model import DocReference, History
from ...clients import embed
from ...clients import litellm
from ...common.protocol.developers import Developer
from ...common.protocol.sessions import ChatContext
from ..docs.search_docs_hybrid import search_docs_hybrid
@@ -61,12 +61,13 @@ async def gather_messages(
return past_messages, []

# Search matching docs
[query_embedding, *_] = await embed.embed(
inputs=[
f"{msg.get('name') or msg['role']}: {msg['content']}"
for msg in new_raw_messages
],
join_inputs=True,
[query_embedding, *_] = await litellm.aembedding(
inputs="\n\n".join(
[
f"{msg.get('name') or msg['role']}: {msg['content']}"
for msg in new_raw_messages
]
),
)
query_text = new_raw_messages[-1]["content"]
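
A small illustration of the joined query string built above; the message dictionaries are made up, while the formatting and join come from the diff:

new_raw_messages = [
    {"role": "user", "name": "alice", "content": "Which docs mention billing?"},
    {"role": "assistant", "content": "Let me check the docs."},
]

query = "\n\n".join(
    [
        f"{msg.get('name') or msg['role']}: {msg['content']}"
        for msg in new_raw_messages
    ]
)

# query == "alice: Which docs mention billing?\n\nassistant: Let me check the docs."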

2 changes: 1 addition & 1 deletion agents-api/agents_api/routers/agents/create_agent.py
@@ -1,7 +1,7 @@
from typing import Annotated
from uuid import UUID

from fastapi import Depends
from uuid import UUID
from starlette.status import HTTP_201_CREATED

import agents_api.models as models
1 change: 0 additions & 1 deletion agents-api/agents_api/routers/agents/create_agent_tool.py
@@ -2,7 +2,6 @@
from uuid import UUID

from fastapi import Depends
from uuid import UUID
from starlette.status import HTTP_201_CREATED

import agents_api.models as models
@@ -2,7 +2,6 @@
from uuid import UUID

from fastapi import Depends
from uuid import UUID
from starlette.status import HTTP_201_CREATED

import agents_api.models as models
2 changes: 1 addition & 1 deletion agents-api/agents_api/routers/agents/delete_agent.py
@@ -1,7 +1,7 @@
from typing import Annotated
from uuid import UUID

from fastapi import Depends
from uuid import UUID
from starlette.status import HTTP_202_ACCEPTED

from ...autogen.openapi_model import ResourceDeletedResponse
1 change: 0 additions & 1 deletion agents-api/agents_api/routers/agents/delete_agent_tool.py
@@ -2,7 +2,6 @@
from uuid import UUID

from fastapi import Depends
from uuid import UUID

from ...autogen.openapi_model import ResourceDeletedResponse
from ...dependencies.developer_id import get_developer_id
2 changes: 1 addition & 1 deletion agents-api/agents_api/routers/agents/get_agent_details.py
@@ -1,7 +1,7 @@
from typing import Annotated
from uuid import UUID

from fastapi import Depends
from uuid import UUID

from ...autogen.openapi_model import Agent
from ...dependencies.developer_id import get_developer_id
1 change: 0 additions & 1 deletion agents-api/agents_api/routers/agents/list_agent_tools.py
@@ -2,7 +2,6 @@
from uuid import UUID

from fastapi import Depends
from uuid import UUID

from ...autogen.openapi_model import ListResponse, Tool
from ...dependencies.developer_id import get_developer_id
2 changes: 1 addition & 1 deletion agents-api/agents_api/routers/agents/list_agents.py
@@ -1,9 +1,9 @@
import json
from json import JSONDecodeError
from typing import Annotated, Literal
from uuid import UUID

from fastapi import Depends, HTTPException, status
from uuid import UUID

from ...autogen.openapi_model import Agent, ListResponse
from ...dependencies.developer_id import get_developer_id
2 changes: 1 addition & 1 deletion agents-api/agents_api/routers/agents/patch_agent.py
@@ -1,7 +1,7 @@
from typing import Annotated
from uuid import UUID

from fastapi import Depends
from uuid import UUID
from starlette.status import HTTP_200_OK

from ...autogen.openapi_model import PatchAgentRequest, ResourceUpdatedResponse
1 change: 0 additions & 1 deletion agents-api/agents_api/routers/agents/patch_agent_tool.py
@@ -2,7 +2,6 @@
from uuid import UUID

from fastapi import Depends
from uuid import UUID

from ...autogen.openapi_model import (
PatchToolRequest,
2 changes: 1 addition & 1 deletion agents-api/agents_api/routers/agents/update_agent.py
@@ -1,7 +1,7 @@
from typing import Annotated
from uuid import UUID

from fastapi import Depends
from uuid import UUID
from starlette.status import HTTP_200_OK

from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateAgentRequest
1 change: 0 additions & 1 deletion agents-api/agents_api/routers/agents/update_agent_tool.py
@@ -2,7 +2,6 @@
from uuid import UUID

from fastapi import Depends
from uuid import UUID

from ...autogen.openapi_model import (
ResourceUpdatedResponse,
1 change: 0 additions & 1 deletion agents-api/agents_api/routers/docs/create_doc.py
@@ -2,7 +2,6 @@
from uuid import UUID, uuid4

from fastapi import BackgroundTasks, Depends
from uuid import UUID
from starlette.status import HTTP_201_CREATED
from temporalio.client import Client as TemporalClient

2 changes: 1 addition & 1 deletion agents-api/agents_api/routers/docs/delete_doc.py
@@ -1,7 +1,7 @@
from typing import Annotated
from uuid import UUID

from fastapi import Depends
from uuid import UUID
from starlette.status import HTTP_202_ACCEPTED

from ...autogen.openapi_model import ResourceDeletedResponse
7 changes: 3 additions & 4 deletions agents-api/agents_api/routers/docs/embed.py
@@ -1,14 +1,13 @@
from typing import Annotated

from fastapi import Depends
from uuid import UUID

import agents_api.clients.embed as embedder
from fastapi import Depends

from ...autogen.openapi_model import (
EmbedQueryRequest,
EmbedQueryResponse,
)
from ...clients import litellm
from ...dependencies.developer_id import get_developer_id
from .router import router

@@ -23,6 +22,6 @@ async def embed(
[text_to_embed] if isinstance(text_to_embed, str) else text_to_embed
)

vectors = await embedder.embed(inputs=text_to_embed)
vectors = await litellm.aembedding(inputs=text_to_embed)

return EmbedQueryResponse(vectors=vectors)
2 changes: 1 addition & 1 deletion agents-api/agents_api/routers/docs/get_doc.py
@@ -1,7 +1,7 @@
from typing import Annotated
from uuid import UUID

from fastapi import Depends
from uuid import UUID

from ...autogen.openapi_model import Doc
from ...dependencies.developer_id import get_developer_id
2 changes: 1 addition & 1 deletion agents-api/agents_api/routers/docs/list_docs.py
@@ -1,9 +1,9 @@
import json
from json import JSONDecodeError
from typing import Annotated, Literal
from uuid import UUID

from fastapi import Depends, HTTPException, status
from uuid import UUID

from ...autogen.openapi_model import Doc, ListResponse
from ...dependencies.developer_id import get_developer_id
2 changes: 1 addition & 1 deletion agents-api/agents_api/routers/docs/search_docs.py
@@ -1,8 +1,8 @@
import time
from typing import Annotated, Any, Dict, List, Optional, Tuple, Union
from uuid import UUID

from fastapi import Depends
from uuid import UUID

from ...autogen.openapi_model import (
DocSearchResponse,
2 changes: 1 addition & 1 deletion agents-api/agents_api/routers/jobs/routers.py
@@ -1,7 +1,7 @@
from typing import Literal
from uuid import UUID

from fastapi import APIRouter
from uuid import UUID
from temporalio.client import WorkflowExecutionStatus

from agents_api.autogen.openapi_model import JobStatus