Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add example of search app, that uses multi-modal embeddings #156

Open
wants to merge 4 commits into
base: development
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions examples/image_search/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""
A simple application that searches attached images by a text query,
using multi-modal embeddings for the search.
"""

import os
from uuid import uuid4

import uvicorn
from attachment import get_image_attachments
from embeddings import ImageDialEmbeddings
from vector_store import DialImageVectorStore

from aidial_sdk import DIALApp
from aidial_sdk import HTTPException as DIALException
from aidial_sdk.chat_completion import ChatCompletion, Request, Response


def get_env(name: str) -> str:
    """Return the value of a required environment variable.

    Args:
        name: Name of the environment variable to read.

    Returns:
        The non-empty value of the variable.

    Raises:
        ValueError: If the variable is unset or set to an empty string.
    """
    value = os.getenv(name)
    # An empty value is as unusable as a missing one, so reject both.
    if not value:
        raise ValueError(f"Please provide {name!r} environment variable")
    return value


# DIAL URL used both for the application and for embeddings requests.
# Required: get_env raises ValueError at import time when it is missing.
DIAL_URL = get_env("DIAL_URL")
# Name of the embeddings deployment to call; overridable via the environment.
EMBEDDINGS_MODEL = os.getenv("EMBEDDINGS_MODEL", "multimodalembedding@001")
# Dimensions requested for embeddings; an unset or empty variable falls back
# to 1408.
EMBEDDINGS_DIMENSIONS = int(os.getenv("EMBEDDINGS_DIMENSIONS") or "1408")


class ImageSearchApplication(ChatCompletion):
    """Chat completion that answers with the attached image best matching a query.

    The text of the last message is used as the search query. Image
    attachments found in the conversation are embedded into a temporary
    vector store, the store is searched for the single closest image, and
    that image is returned as an attachment of the response.
    """

    async def chat_completion(
        self, request: Request, response: Response
    ) -> None:
        """Handle one chat completion request.

        Args:
            request: Incoming request; its last message carries the query and
                its messages carry the image attachments to search over.
            response: Response that receives the stages and the resulting
                attachment.

        Raises:
            DIALException: With status 400 when the query is empty, 422 when
                no suitable image attachment is found, and 404 when the
                search returns no result.
        """
        with response.create_single_choice() as choice:
            message = request.messages[-1]
            user_query = message.content

            if not user_query:
                raise DIALException(
                    message="Please provide search query", status_code=400
                )

            image_attachments = get_image_attachments(request.messages)
            if not image_attachments:
                msg = "No attachment with DIAL Storage URL was found"
                raise DIALException(
                    status_code=422,
                    message=msg,
                    display_message=msg,
                )

            # For simplicity of example let's take only images,
            # that are uploaded to DIAL Storage already
            uploaded = [att for att in image_attachments if att.url]

            # Create a new local vector store to store image embeddings
            vector_store = DialImageVectorStore(
                collection_name=str(uuid4()),
                embedding_function=ImageDialEmbeddings(
                    dial_url=DIAL_URL,
                    embeddings_model=EMBEDDINGS_MODEL,
                    dimensions=EMBEDDINGS_DIMENSIONS,
                ),
            )
            # The collection is created per request, so it must be removed on
            # every exit path, including failures raised below.
            try:
                # Show user that embeddings of images are being calculated
                with choice.create_stage("Calculating image embeddings"):
                    await vector_store.aadd_images(
                        uris=[att.url for att in uploaded],
                        metadatas=[
                            {
                                "url": att.url,
                                "type": att.type,
                                "title": att.title,
                            }
                            for att in uploaded
                        ],
                    )

                # Show user that the search is being performed
                with choice.create_stage("Searching for most relevant image"):
                    search_result = await vector_store.asimilarity_search(
                        query=user_query, k=1
                    )

                if not search_result:
                    msg = "No relevant image found"
                    raise DIALException(
                        status_code=404,
                        message=msg,
                        display_message=msg,
                    )

                top_result = search_result[0]
                choice.add_attachment(
                    url=top_result.metadata["url"],
                    title=top_result.metadata["title"],
                    type=top_result.metadata["type"],
                )
            finally:
                vector_store.delete_collection()


# Create the DIAL application bound to DIAL_URL.
# NOTE(review): propagate_auth_headers=True presumably forwards the caller's
# auth headers on requests to DIAL_URL — confirm against aidial_sdk docs.
app = DIALApp(DIAL_URL, propagate_auth_headers=True)
# Register the search application under the "image-search" deployment name.
app.add_chat_completion("image-search", ImageSearchApplication())


# Serve the application with uvicorn on port 5000 when run as a script.
if __name__ == "__main__":
    uvicorn.run(app, port=5000)
32 changes: 32 additions & 0 deletions examples/image_search/attachment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import List, Optional

from aidial_sdk.chat_completion import Message
from aidial_sdk.chat_completion.request import Attachment

# MIME types treated as images when the caller does not pass its own list.
DEFAULT_IMAGE_TYPES = ["image/jpeg", "image/png"]


def get_image_attachments(
    messages: List[Message], image_types: Optional[List[str]] = None
) -> List[Attachment]:
    """Collect image attachments that have a URL from all given messages.

    Args:
        messages: Messages whose ``custom_content.attachments`` are scanned.
        image_types: Accepted attachment MIME types; defaults to
            ``DEFAULT_IMAGE_TYPES`` when ``None``.

    Returns:
        A new list with every attachment that has both a URL and an accepted
        type, in message order. The messages themselves are not modified.
    """
    if image_types is None:
        image_types = DEFAULT_IMAGE_TYPES

    # Keep the result in its own list: reusing the name of the message's
    # attachment list would append to the very list being iterated.
    image_attachments = []
    for message in messages:
        custom_content = message.custom_content
        if custom_content is None or custom_content.attachments is None:
            continue
        image_attachments.extend(
            attachment
            for attachment in custom_content.attachments
            # For simplicity of example let's take only images,
            # that are uploaded to DIAL Storage already
            if attachment.url
            and attachment.type
            and attachment.type in image_types
        )

    return image_attachments
57 changes: 57 additions & 0 deletions examples/image_search/embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import mimetypes
from typing import List

import httpx
from langchain_core.embeddings import Embeddings


class ImageDialEmbeddings(Embeddings):
    """Embeddings client for a DIAL multi-modal embeddings deployment.

    Text queries and image URIs are sent to the same embeddings endpoint.
    Embedding plain text documents is intentionally not supported.
    """

    # MIME type sent for an image when it cannot be guessed from its URI.
    _DEFAULT_IMAGE_TYPE = "image/png"

    def __init__(
        self,
        dial_url: str,
        embeddings_model: str,
        dimensions: int,
    ) -> None:
        """Prepare the endpoint URL and the HTTP client.

        Args:
            dial_url: Base DIAL URL; a trailing slash is tolerated.
            embeddings_model: Name of the embeddings deployment.
            dimensions: Value sent as ``dimensions`` with every request.
        """
        self._dial_url = dial_url
        # Strip a trailing slash so the joined URL never contains "//".
        base_url = self._dial_url.rstrip("/")
        self._embeddings_url = (
            f"{base_url}/openai/deployments/{embeddings_model}/embeddings"
        )
        self._dimensions = dimensions
        self._client = httpx.Client()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Always raise: text documents are not supported.

        Raises:
            NotImplementedError: On every call.
        """
        raise NotImplementedError(
            "This embeddings should not be used with text documents"
        )

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding of a text query.

        Args:
            text: Query text to embed.

        Raises:
            httpx.HTTPStatusError: If the endpoint answers with an error status.
            ValueError: If the response does not hold exactly one embedding.
        """
        return self._request_embedding(
            {"input": [text], "dimensions": self._dimensions}
        )

    def embed_image(self, uris: List[str]) -> List[List[float]]:
        """Return one embedding per image URI, in the order given.

        Each image is sent in its own request, referenced by URI.

        Args:
            uris: Image URIs to embed.

        Raises:
            httpx.HTTPStatusError: If the endpoint answers with an error status.
            ValueError: If a response does not hold exactly one embedding.
        """
        return [
            self._request_embedding(
                {
                    "input": [],
                    "dimensions": self._dimensions,
                    "custom_input": [
                        {
                            "type": self._guess_image_type(uri),
                            "url": uri,
                        }
                    ],
                }
            )
            for uri in uris
        ]

    @classmethod
    def _guess_image_type(cls, uri: str) -> str:
        """Guess the image MIME type from the URI, defaulting to PNG."""
        guessed, _ = mimetypes.guess_type(uri)
        if guessed and guessed.startswith("image/"):
            return guessed
        return cls._DEFAULT_IMAGE_TYPE

    def _request_embedding(self, payload: dict) -> List[float]:
        """Post one embeddings request and return its single embedding."""
        # Auth headers are propagated by the DIALApp
        response = self._client.post(self._embeddings_url, json=payload)
        # Fail on HTTP errors here instead of on a confusing body later.
        response.raise_for_status()
        items = response.json().get("data")
        # Explicit checks instead of assert, which is stripped under -O.
        if not items or len(items) != 1:
            raise ValueError(
                "Embeddings response must contain exactly one data item"
            )
        embedding = items[0].get("embedding")
        if embedding is None:
            raise ValueError("Embeddings response item has no embedding")
        return embedding
5 changes: 5 additions & 0 deletions examples/image_search/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
aidial-sdk>=0.10
langchain-core==0.2.9
langchain-community==0.2.9
chromadb==0.5.4
uvicorn==0.30.1
20 changes: 20 additions & 0 deletions examples/image_search/vector_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import List, Optional

from langchain_community.vectorstores import Chroma
from langchain_core.runnables.config import run_in_executor


class DialImageVectorStore(Chroma):
    """Chroma vector store that handles images by URI only."""

    def encode_image(self, uri: str) -> str:
        """
        Override of Chroma encode_image method, that does not download
        image content: the URI is returned unchanged
        """
        return uri

    async def aadd_images(
        self, uris: List[str], metadatas: Optional[List[dict]] = None
    ):
        """
        Async counterpart of Chroma add_images: the synchronous method is
        run in an executor and its result is returned
        """
        added = await run_in_executor(None, self.add_images, uris, metadatas)
        return added
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ exclude = [
".pytest_cache",
"**/__pycache__",
"build",
"examples/langchain_rag"
"examples/langchain_rag",
"examples/image_search"
]

[tool.black]
Expand Down