Skip to content

Commit

Permalink
Add support for astra (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
homanp authored Jan 17, 2024
1 parent fed2193 commit be2c833
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 1 deletion.
6 changes: 5 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@ aiohttp==3.9.1
aiosignal==1.3.1
annotated-types==0.6.0
anyio==4.2.0
astrapy==0.7.0
attrs==23.2.0
Authlib==1.3.0
backoff==2.2.1
beautifulsoup4==4.12.2
black==23.12.1
cassandra-driver==3.29.0
cassio==0.1.4
certifi==2023.11.17
cffi==1.16.0
charset-normalizer==3.3.2
Expand All @@ -22,6 +25,7 @@ fastavro==1.9.3
filelock==3.13.1
frozenlist==1.4.1
fsspec==2023.12.2
geomet==0.2.1.post1
greenlet==3.0.3
grpcio==1.60.0
grpcio-tools==1.60.0
Expand All @@ -30,7 +34,7 @@ h2==4.1.0
hpack==4.0.0
httpcore==1.0.2
httptools==0.6.1
httpx==0.26.0
httpx==0.25.2
huggingface-hub==0.20.2
hyperframe==6.0.1
idna==3.6
Expand Down
50 changes: 50 additions & 0 deletions service/vector_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from qdrant_client import QdrantClient
from qdrant_client.http import models as rest
from pinecone import Pinecone, ServerlessSpec
from astrapy.db import AstraDB

from models.vector_database import VectorDatabase

Expand Down Expand Up @@ -242,13 +243,62 @@ async def query(self, input: str, top_k: int = 4) -> List:
return result["data"]["Get"][self.index_name.capitalize()]


class AstraService(VectorService):
def __init__(self, index_name: str, dimension: int, credentials: dict):
super().__init__(
index_name=index_name, dimension=dimension, credentials=credentials
)
self.client = AstraDB(
token=credentials["api_key"],
api_endpoint=credentials["host"],
)
collections = self.client.get_collections()
if self.index_name not in collections["status"]["collections"]:
self.collection = self.client.create_collection(
dimension=dimension, collection_name=index_name
)
self.collection = self.client.collection(collection_name=self.index_name)

async def convert_to_rerank_format(self, chunks: List) -> List:
docs = [
{
"content": chunk.get("text"),
"page_label": chunk.get("page_label"),
"file_url": chunk.get("file_url"),
}
for chunk in chunks
]
return docs

async def upsert(self, embeddings: List[tuple[str, list, dict[str, Any]]]) -> None:
documents = [
{
"_id": _embedding[0],
"text": _embedding[2]["content"],
"$vector": _embedding[1],
**_embedding[2],
}
for _embedding in embeddings
]
for i in range(0, len(documents), 5):
self.collection.insert_many(documents=documents[i : i + 5])

async def query(self, input: str, top_k: int = 4) -> List:
vectors = await self._generate_vectors(input=input)
results = self.collection.vector_find(
vector=vectors, limit=top_k, fields={"text", "page_label", "file_url"}
)
return results


def get_vector_service(
index_name: str, credentials: VectorDatabase, dimension: int = 1024
) -> Type[VectorService]:
services = {
"pinecone": PineconeVectorService,
"qdrant": QdrantService,
"weaviate": WeaviateService,
"astra": AstraService,
# Add other providers here
# e.g "weaviate": WeaviateVectorService,
}
Expand Down

0 comments on commit be2c833

Please sign in to comment.