Skip to content

Commit

Permalink
Merge pull request #1 from superagent-ai/vectorstore
Browse files Browse the repository at this point in the history
Vectorstore
  • Loading branch information
homanp committed Jan 16, 2024
2 parents ab84f29 + 8893c74 commit 4d7e45a
Show file tree
Hide file tree
Showing 15 changed files with 436 additions and 25 deletions.
3 changes: 3 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
API_BASE_URL=https://rag.superagent.sh
COHERE_API_KEY=
HUGGINGFACE_API_KEY=
1 change: 0 additions & 1 deletion .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ jobs:
strategy:
matrix:
python-version:
- "3.8"
- "3.9"
- "3.10"
- "3.11"
Expand Down
Empty file added api/__init__.py
Empty file.
19 changes: 19 additions & 0 deletions api/ingest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from typing import Dict
from fastapi import APIRouter
from models.ingest import RequestPayload
from service.embedding import EmbeddingService

# Sub-router for ingestion endpoints; mounted under /api/v1 in router.py.
router = APIRouter()


@router.post("/ingest")
async def ingest(payload: RequestPayload) -> Dict:
    """Ingest the payload's files into the named vector index.

    Pipeline: download the files as documents, split them into chunks,
    embed every chunk, and upsert the vectors into ``payload.index_name``.
    Returns ``{"success": True}`` once the upsert has completed.
    """
    service = EmbeddingService(
        files=payload.files,
        index_name=payload.index_name,
        vector_credentials=payload.vector_database,
    )
    docs = await service.generate_documents()
    nodes = await service.generate_chunks(documents=docs)
    await service.generate_embeddings(nodes=nodes)
    return {"success": True}
16 changes: 16 additions & 0 deletions api/query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from fastapi import APIRouter
from models.query import RequestPayload, ResponsePayload
from service.vector_database import get_vector_service, VectorService

# Sub-router for query endpoints; mounted under /api/v1 in router.py.
router = APIRouter()


@router.post("/query", response_model=ResponsePayload)
async def query(payload: RequestPayload, top_k: int = 4):
    """Search the index named in the payload and return reranked chunks.

    ``top_k`` (optional query parameter, default 4 to preserve the previous
    hard-coded behavior) caps how many chunks are retrieved from the vector
    store before reranking.
    """
    vector_service: VectorService = get_vector_service(
        index_name=payload.index_name, credentials=payload.vector_database
    )
    chunks = await vector_service.query(input=payload.input, top_k=top_k)
    documents = await vector_service.convert_to_dict(points=chunks)
    # Rerank with the original query text so the top results are the most
    # semantically relevant, not just the nearest neighbors.
    results = await vector_service.rerank(query=payload.input, documents=documents)
    return {"success": True, "data": results}
27 changes: 3 additions & 24 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from typing import Dict, List
from enum import Enum
from pydantic import BaseModel
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from decouple import config
from router import router


app = FastAPI(
title="SuperRag",
Expand All @@ -21,24 +20,4 @@
allow_headers=["*"],
)


class DatabaseType(Enum):
    """Vector-database backends accepted in request payloads."""

    qdrant = "qdrant"
    pinecone = "pinecone"
    weaviate = "weaviate"
    astra = "astra"


class VectorDatabase(BaseModel):
    """A vector-store selection: backend type plus its connection config."""

    type: DatabaseType  # which backend to connect to
    config: Dict  # backend-specific credentials/settings (free-form dict)


class RequestPayload(BaseModel):
    """Body of ``POST /ingest``."""

    files: List  # untyped list of file descriptors to ingest
    vector_database: VectorDatabase  # target store for the resulting vectors


@app.post("/ingest")
async def ingest(payload: RequestPayload) -> Dict:
    # Stub endpoint: pydantic validates the payload; echo it back as a dict.
    return payload.model_dump()
app.include_router(router)
17 changes: 17 additions & 0 deletions models/file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from enum import Enum
from pydantic import BaseModel


class FileType(Enum):
    """Source-file formats accepted for ingestion.

    Values are the uppercase labels used in request payloads.
    """

    pdf = "PDF"
    docx = "DOCX"
    txt = "TXT"
    pptx = "PPTX"
    csv = "CSV"
    xlsx = "XLSX"
    md = "MARKDOWN"


class File(BaseModel):
    """A remote file to ingest: its format plus a downloadable URL."""

    type: FileType  # format of the content at ``url``
    url: str  # location the file is downloaded from (HTTP GET during ingest)
10 changes: 10 additions & 0 deletions models/ingest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import List
from pydantic import BaseModel
from models.file import File
from models.vector_database import VectorDatabase


class RequestPayload(BaseModel):
    """Body of ``POST /api/v1/ingest``."""

    files: List[File]  # files to download, chunk, and embed
    vector_database: VectorDatabase  # target store type + credentials/config
    index_name: str  # index/collection the vectors are upserted into
20 changes: 20 additions & 0 deletions models/query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from pydantic import BaseModel
from typing import List
from models.vector_database import VectorDatabase


class RequestPayload(BaseModel):
    """Body of ``POST /api/v1/query``."""

    input: str  # natural-language query text
    vector_database: VectorDatabase  # store holding the index to search
    index_name: str  # index/collection to query


class ResponseData(BaseModel):
    """One reranked chunk returned by the query endpoint."""

    content: str  # chunk text
    file_url: str  # source file the chunk was extracted from
    page_label: str  # presumably a page number/label from the loader metadata — confirm


class ResponsePayload(BaseModel):
    """Envelope for ``POST /api/v1/query`` responses."""

    success: bool  # True when the query completed
    data: List[ResponseData]  # reranked chunks, most relevant first — TODO confirm ordering
15 changes: 15 additions & 0 deletions models/vector_database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing import Dict
from enum import Enum
from pydantic import BaseModel


class DatabaseType(Enum):
    """Vector-database backends supported by the service.

    Values are the lowercase identifiers expected in request payloads.
    """

    qdrant = "qdrant"
    pinecone = "pinecone"
    weaviate = "weaviate"
    astra = "astra"


class VectorDatabase(BaseModel):
    """A vector-store selection: backend type plus its connection config."""

    type: DatabaseType  # which backend to connect to
    config: Dict  # backend-specific credentials/settings (free-form dict)
63 changes: 63 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,27 +1,90 @@
aiohttp==3.9.1
aiosignal==1.3.1
annotated-types==0.6.0
anyio==4.2.0
attrs==23.2.0
Authlib==1.3.0
backoff==2.2.1
beautifulsoup4==4.12.2
black==23.12.1
certifi==2023.11.17
cffi==1.16.0
charset-normalizer==3.3.2
click==8.1.7
cohere==4.42
cryptography==41.0.7
dataclasses-json==0.6.3
Deprecated==1.2.14
distro==1.9.0
dnspython==2.4.2
fastapi==0.109.0
fastavro==1.9.3
frozenlist==1.4.1
fsspec==2023.12.2
greenlet==3.0.3
grpcio==1.60.0
grpcio-tools==1.60.0
h11==0.14.0
h2==4.1.0
hpack==4.0.0
httpcore==1.0.2
httptools==0.6.1
httpx==0.26.0
hyperframe==6.0.1
idna==3.6
importlib-metadata==6.11.0
joblib==1.3.2
llama-index==0.9.30
loguru==0.7.2
marshmallow==3.20.2
multidict==6.0.4
mypy-extensions==1.0.0
nest-asyncio==1.5.8
networkx==3.2.1
nltk==3.8.1
numpy==1.26.3
openai==1.7.2
packaging==23.2
pandas==2.1.4
pathspec==0.12.1
pinecone-client==2.2.4
platformdirs==4.1.0
portalocker==2.8.2
protobuf==4.25.2
pycparser==2.21
pydantic==2.5.3
pydantic_core==2.14.6
pypdf==3.17.4
python-dateutil==2.8.2
python-decouple==3.8
python-dotenv==1.0.0
pytz==2023.3.post1
PyYAML==6.0.1
qdrant-client==1.7.0
regex==2023.12.25
requests==2.31.0
ruff==0.1.13
setuptools==69.0.3
six==1.16.0
sniffio==1.3.0
soupsieve==2.5
SQLAlchemy==2.0.25
starlette==0.35.1
tenacity==8.2.3
tiktoken==0.5.2
toml==0.10.2
tqdm==4.66.1
typing-inspect==0.9.0
typing_extensions==4.9.0
tzdata==2023.4
urllib3==1.26.18
uvicorn==0.25.0
uvloop==0.19.0
validators==0.22.0
vulture==2.10
watchfiles==0.21.0
weaviate-client==3.26.0
websockets==12.0
wrapt==1.16.0
yarl==1.9.4
zipp==3.17.0
9 changes: 9 additions & 0 deletions router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from fastapi import APIRouter

from api import ingest, query

# Top-level API router: mounts every feature sub-router under one
# versioned prefix so main.py only needs a single include_router call.
router = APIRouter()
api_prefix = "/api/v1"  # bump here to version the whole API surface

router.include_router(ingest.router, tags=["Ingest"], prefix=api_prefix)
router.include_router(query.router, tags=["Query"], prefix=api_prefix)
Empty file added service/__init__.py
Empty file.
81 changes: 81 additions & 0 deletions service/embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import requests
import asyncio

from typing import Any, List, Union
from tempfile import NamedTemporaryFile
from llama_index import Document, SimpleDirectoryReader
from llama_index.node_parser import SimpleNodeParser
from litellm import aembedding
from models.file import File
from decouple import config
from service.vector_database import get_vector_service


class EmbeddingService:
    """Pipeline that turns remote files into embedded chunks in a vector store.

    Flow: download each file -> load it into llama-index ``Document``s ->
    split the documents into chunk nodes -> embed every chunk -> upsert the
    vectors into the index named by ``index_name``.
    """

    # Maps FileType values (models/file.py) to temp-file suffixes so
    # SimpleDirectoryReader picks the right loader for the downloaded bytes.
    # Covers every member of the FileType enum (the old map only handled
    # TXT/PDF/MARKDOWN, so DOCX/PPTX/CSV/XLSX ingests always failed).
    _SUFFIXES = {
        "TXT": ".txt",
        "PDF": ".pdf",
        "MARKDOWN": ".md",
        "DOCX": ".docx",
        "PPTX": ".pptx",
        "CSV": ".csv",
        "XLSX": ".xlsx",
    }

    def __init__(self, files: List[File], index_name: str, vector_credentials: dict):
        self.files = files  # files to download and embed
        self.index_name = index_name  # target index/collection name
        self.vector_credentials = vector_credentials  # store type + config

    def _get_datasource_suffix(self, type: str) -> str:
        """Return the file suffix for a ``FileType`` value (e.g. "PDF" -> ".pdf").

        Raises:
            ValueError: if ``type`` is not a supported datasource type.
        """
        try:
            return self._SUFFIXES[type]
        except KeyError:
            # Name the offending type; suppress the KeyError context since
            # the ValueError fully describes the failure.
            raise ValueError(f"Unsupported datasource type: {type!r}") from None

    async def generate_documents(self) -> List[Document]:
        """Download every file and load it into llama-index ``Document``s.

        Each document is tagged with its source ``file_url`` so query
        results can point back at the original file.

        NOTE(review): ``requests.get`` is a blocking call inside an async
        method and will stall the event loop on large downloads — consider
        an async HTTP client or a thread executor.
        """
        documents = []
        for file in self.files:
            suffix = self._get_datasource_suffix(file.type.value)
            with NamedTemporaryFile(suffix=suffix, delete=True) as temp_file:
                response = requests.get(url=file.url, timeout=60)
                # Fail loudly instead of silently ingesting an HTTP error page.
                response.raise_for_status()
                temp_file.write(response.content)
                temp_file.flush()
                reader = SimpleDirectoryReader(input_files=[temp_file.name])
                docs = reader.load_data()
                for doc in docs:
                    doc.metadata["file_url"] = file.url
                documents.extend(docs)
        return documents

    async def generate_chunks(
        self, documents: List[Document]
    ) -> List[Union[Document, None]]:
        """Split documents into overlapping chunk nodes (350 tokens, 20 overlap)."""
        parser = SimpleNodeParser.from_defaults(chunk_size=350, chunk_overlap=20)
        nodes = parser.get_nodes_from_documents(documents, show_progress=False)
        return nodes

    async def generate_embeddings(
        self,
        nodes: List[Union[Document, None]],
    ) -> List[tuple[str, list, dict[str, Any]]]:
        """Embed every chunk concurrently and upsert the vectors.

        Returns the upserted ``(node_id, vectors, metadata)`` tuples;
        ``None`` nodes are skipped.
        """

        async def generate_embedding(node):
            # Returns None for None nodes so positions from gather stay aligned.
            if node is None:
                return None
            vectors = []
            embedding_object = await aembedding(
                model="huggingface/intfloat/multilingual-e5-large",
                input=node.text,
                api_key=config("HUGGINGFACE_API_KEY"),
            )
            for vector in embedding_object.data:
                if vector["object"] == "embedding":
                    vectors.append(vector["embedding"])
            return (
                node.id_,
                vectors,
                {
                    **node.metadata,
                    "content": node.text,  # keep chunk text alongside the vector
                },
            )

        tasks = [generate_embedding(node) for node in nodes]
        embeddings = await asyncio.gather(*tasks)
        # Filter once and reuse for both the upsert and the return value
        # (the original built this list twice).
        valid = [e for e in embeddings if e is not None]
        vector_service = get_vector_service(
            index_name=self.index_name, credentials=self.vector_credentials
        )
        await vector_service.upsert(embeddings=valid)
        return valid
Loading

0 comments on commit 4d7e45a

Please sign in to comment.