Lite llm #36

Merged
merged 3 commits on Oct 30, 2023
13 changes: 7 additions & 6 deletions nagato/service/__init__.py
@@ -75,11 +75,11 @@ def predict_with_embedding(
     input: str,
     provider: str,
     model: str,
-    embedding_provider: str,
+    vector_db: str,
     embedding_model: str,
     embedding_filter_id: str,
     callback: Callable = None,
-    system_prompt: str = None,
+    system_prompt: str = "You are a helpful assistant",
     enable_streaming: bool = False,
 ) -> dict:
     from nagato.service.query import get_query_service
@@ -88,9 +88,10 @@ def predict_with_embedding(
         query=input,
         model=embedding_model,
         filter_id=embedding_filter_id,
-        provider=embedding_provider,
+        vector_db=vector_db,
     )
-    context = similarity_search["results"][0]["matches"]
+    docs = similarity_search["results"][0]["matches"]
+    context = docs[0]["metadata"]["content"]
     query_service = get_query_service(provider=provider, model=model)
     output = query_service.predict_with_embedding(
         input=input,
@@ -105,7 +106,7 @@ def predict_with_embedding(
 def query_embedding(
     query: str,
     model: str = "thenlper/gte-small",
-    provider: str = "PINECONE",
+    vector_db: str = "PINECONE",
     filter_id: str = None,
     top_k: int = 5,
 ) -> dict:
@@ -115,7 +116,7 @@ def query_embedding(

     embedding_model = SentenceTransformer(model, use_auth_token=config("HF_API_KEY"))
     vectordb = get_vector_service(
-        provider=provider,
+        provider=vector_db,
         index_name=MODEL_TO_INDEX[model].get("index_name"),
         filter_id=filter_id,
         dimension=MODEL_TO_INDEX[model].get("dimensions"),
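The substantive change in this hunk: `predict_with_embedding` used to pass the raw `matches` list as context, and it now unwraps the top-ranked match and feeds only its stored text to the model. A minimal sketch of that extraction, using a stand-in response with the shape the hunk assumes:

```python
# Stand-in similarity-search response; shape mirrors the hunk above:
# {"results": [{"matches": [{"metadata": {"content": ...}}]}]}
similarity_search = {
    "results": [
        {"matches": [{"metadata": {"content": "(text of the top-ranked chunk)"}}]}
    ]
}

docs = similarity_search["results"][0]["matches"]
context = docs[0]["metadata"]["content"]  # only the best match reaches the prompt
print(context)  # -> (text of the top-ranked chunk)
```

Note that although `query_embedding` fetches `top_k=5` matches, only `docs[0]` is injected as context.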
7 changes: 6 additions & 1 deletion nagato/service/embedding.py
@@ -16,7 +16,12 @@
     "thenlper/gte-base": {"index_name": "gte-base", "dimensions": 768},
     "thenlper/gte-small": {"index_name": "gte-small", "dimensions": 384},
     "thenlper/gte-large": {"index_name": "gte-large", "dimensions": 1024},
-    "infgrad/stella-base-en-v2": {"index_name": "stella-base", "dimensions": 768}
+    "infgrad/stella-base-en-v2": {"index_name": "stella-base", "dimensions": 768},
+    "BAAI/bge-large-en-v1.5": {"index_name": "bge-large", "dimensions": 1024},
+    "jinaai/jina-embeddings-v2-base-en": {
+        "index_name": "jina-embeddings-v2",
+        "dimensions": 768,
+    },
     # Add more mappings here as needed
 }

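Each entry pairs a Hugging Face model ID with the index it writes to and the embedding width that index must be provisioned with. A quick sketch of the lookup `query_embedding` performs against this table, trimmed to the new Jina entry:

```python
# Resolving index config from MODEL_TO_INDEX, mirroring the
# MODEL_TO_INDEX[model].get(...) calls in nagato/service/__init__.py.
MODEL_TO_INDEX = {
    "jinaai/jina-embeddings-v2-base-en": {
        "index_name": "jina-embeddings-v2",
        "dimensions": 768,
    },
}

model = "jinaai/jina-embeddings-v2-base-en"
index_name = MODEL_TO_INDEX[model].get("index_name")  # "jina-embeddings-v2"
dimension = MODEL_TO_INDEX[model].get("dimensions")   # 768
print(index_name, dimension)
```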
10 changes: 10 additions & 0 deletions nagato/service/prompts.py
@@ -18,6 +18,16 @@
 )


+def generate_replicate_rag_prompt(context: str, input: str) -> str:
+    prompt = (
+        "You are a helpful assistant that's an expert at answering questions.\n"
+        "Use the following context to answer any questions.\n\n"
+        f"Context:\n{context}\n\n"
+        f"Question:\n{input}"
+    )
+    return prompt
+
+
 def generate_qa_pair_prompt(
     format: str, context: str, num_of_qa_pairs: int = 10
 ) -> str:
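For reference, here is what the new helper renders; the context string below is a placeholder, not real retrieved content:

```python
from nagato.service.prompts import generate_replicate_rag_prompt

# Placeholder context; in the service this comes from the vector store.
prompt = generate_replicate_rag_prompt(
    context="(retrieved document chunks go here)",
    input="What was Tesla's total revenue in Q2 2023?",
)
print(prompt)
# You are a helpful assistant that's an expert at answering questions.
# Use the following context to answer any questions.
#
# Context:
# (retrieved document chunks go here)
#
# Question:
# What was Tesla's total revenue in Q2 2023?
```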
70 changes: 43 additions & 27 deletions nagato/service/query.py
@@ -1,11 +1,11 @@
 from abc import ABC, abstractmethod
 from typing import Callable

-import replicate
+import litellm
 from decouple import config

 from nagato.service.prompts import (
-    generate_replicaste_system_prompt,
+    generate_replicate_rag_prompt,
 )


@@ -17,6 +17,12 @@ def __init__(
     ):
         self.provider = provider
         self.model = model
+        if self.provider == "REPLICATE":
+            self.api_key = config("REPLICATE_API_KEY")
+        elif self.provider == "OPENAI":
+            self.api_key = config("OPENAI_API_KEY")
+        else:
+            self.api_key = None

     @abstractmethod
     def predict(
@@ -47,43 +53,53 @@ def predict(
         system_prompt: str = None,
         callback: Callable = None,
     ):
-        client = replicate.Client(api_token=config("REPLICATE_API_KEY"))
-        output = client.run(
-            self.model,
-            input={
-                "prompt": input,
-                "system_prompt": system_prompt,
-            },
+        litellm.api_key = self.api_key
+
+        output = litellm.completion(
+            model=self.model,
+            messages=[
+                {"content": system_prompt, "role": "system"},
+                {"content": input, "role": "user"},
+            ],
+            max_tokens=450,
+            temperature=0,
+            stream=enable_streaming,
         )
         if enable_streaming:
-            for item in output:
-                callback(item)
-        else:
-            return "".join(item for item in output)
+            for chunk in output:
+                callback(chunk["choices"][0]["delta"]["content"])
+        return output

     def predict_with_embedding(
         self,
         input: str,
         context: str,
-        system_prompt: str,
         enable_streaming: bool = False,
         callback: Callable = None,
+        system_prompt: str = None,
     ):
-        client = replicate.Client(api_token=config("REPLICATE_API_KEY"))
-        output = client.run(
-            self.model,
-            input={
-                "prompt": input,
-                "system_prompt": generate_replicaste_system_prompt(
-                    context=context, system_prompt=system_prompt
-                ),
-            },
+        litellm.api_key = self.api_key
+        prompt = generate_replicate_rag_prompt(context=context, input=input)
+        output = litellm.completion(
+            model=self.model,
+            messages=[
+                {
+                    "content": system_prompt,
+                    "role": "system",
+                },
+                {
+                    "content": prompt,
+                    "role": "user",
+                },
+            ],
+            max_tokens=2000,
+            temperature=0,
+            stream=enable_streaming,
         )
         if enable_streaming:
-            for item in output:
-                callback(item)
-        else:
-            return "".join(item for item in output)
+            for chunk in output:
+                callback(chunk["choices"][0]["delta"]["content"])
+        return output


 def get_query_service(
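The swap from the Replicate client to `litellm.completion` puts both providers behind one call. A standalone sketch of the streaming path, assuming (as the new code does) that each chunk exposes `choices[0].delta.content`; the model name and API key are placeholders, and the `if text:` guard is a small addition since some providers send empty or None deltas on the final chunk:

```python
import litellm


def print_chunk(text):
    # Callback invoked once per streamed segment.
    print(text, end="", flush=True)


litellm.api_key = "sk-..."  # placeholder; the service reads keys via decouple.config
output = litellm.completion(
    model="gpt-3.5-turbo",  # illustrative; the service passes self.model through
    messages=[
        {"content": "You are a helpful assistant", "role": "system"},
        {"content": "Summarize Tesla's Q2 2023 update.", "role": "user"},
    ],
    max_tokens=450,
    temperature=0,
    stream=True,
)
for chunk in output:
    text = chunk["choices"][0]["delta"].get("content")
    if text:
        print_chunk(text)
```

One design note on the new code: when `enable_streaming` is true, the chunks are consumed inside the loop, so the generator handed back by `return output` is already exhausted by the time the caller receives it.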
69 changes: 68 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -27,6 +27,7 @@ wheel = "^0.41.0"
 python-dotenv = "^1.0.0"
 tqdm = "^4.66.1"
 setuptools = "^68.2.2"
+litellm = "^0.12.5"

 [build-system]
 requires = ["poetry-core"]
File renamed without changes.
4 changes: 2 additions & 2 deletions test/embedding.py → tests/embedding.py
@@ -5,8 +5,8 @@ def main():
     result = create_vector_embeddings(
         type="PDF",
         url="https://digitalassets.tesla.com/tesla-contents/image/upload/IR/TSLA-Q2-2023-Update.pdf",
-        filter_id="008",
-        model="thenlper/gte-large",
+        filter_id="010",
+        model="jinaai/jina-embeddings-v2-base-en",
     )
     print(result)

File renamed without changes.
20 changes: 20 additions & 0 deletions tests/predict.py
@@ -0,0 +1,20 @@
+from nagato.service import predict
+
+
+def callback_method(item):
+    print(item)
+
+
+def main():
+    result = predict(
+        input="What was Tesla's YoY revenue increase in Q2 2023?",
+        provider="REPLICATE",
+        model="homanp/test:bc8afbabceaec8abb9b15fade05ff42db371b01fa251541b49c8ba9a9d44bc1f",
+        system_prompt="You are a helpful assistant",
+        enable_streaming=True,
+        callback=callback_method,
+    )
+    print(result)
+
+
+main()
22 changes: 22 additions & 0 deletions tests/predict_with_embedding.py
@@ -0,0 +1,22 @@
+from nagato.service import predict_with_embedding
+
+
+def callback_method(item):
+    print(item)
+
+
+def main():
+    result = predict_with_embedding(
+        input="What was Tesla's total revenue in Q2 2023?",
+        provider="REPLICATE",
+        model="homanp/test:bc8afbabceaec8abb9b15fade05ff42db371b01fa251541b49c8ba9a9d44bc1f",
+        vector_db="PINECONE",
+        embedding_model="jinaai/jina-embeddings-v2-base-en",
+        embedding_filter_id="010",
+        enable_streaming=True,
+        callback=callback_method,
+    )
+    print(result)
+
+
+main()
6 changes: 3 additions & 3 deletions test/query_embedding.py → tests/query_embedding.py
@@ -7,9 +7,9 @@ def callback_method(item):

 def main():
     result = query_embedding(
-        query="How many cars were sold in total?",
-        filter_id="007",
-        model="thenlper/gte-large",
+        query="What was Tesla's total revenue in Q2 2023?",
+        filter_id="010",
+        model="jinaai/jina-embeddings-v2-base-en",
     )
     print(result)
