From ade005367674cb6eeb990e0cb7d7a5feddc031d8 Mon Sep 17 00:00:00 2001
From: Ismail Pelaseyed
Date: Sun, 29 Oct 2023 23:32:58 +0100
Subject: [PATCH 1/3] WIP

---
 nagato/service/embedding.py | 3 +-
 nagato/service/prompts.py | 4 ++
 nagato/service/query.py | 67 ++++++++++++++++++-----------
 poetry.lock | 69 +++++++++++++++++++++++++++++-
 pyproject.toml | 1 +
 {test => tests}/__init__.py | 0
 {test => tests}/embedding.py | 0
 {test => tests}/finetune.py | 0
 tests/predict.py | 20 +++++++++
 tests/predict_with_embedding.py | 23 ++++++++++
 {test => tests}/query_embedding.py | 0
 11 files changed, 160 insertions(+), 27 deletions(-)
 rename {test => tests}/__init__.py (100%)
 rename {test => tests}/embedding.py (100%)
 rename {test => tests}/finetune.py (100%)
 create mode 100644 tests/predict.py
 create mode 100644 tests/predict_with_embedding.py
 rename {test => tests}/query_embedding.py (100%)

diff --git a/nagato/service/embedding.py b/nagato/service/embedding.py
index 44c494c..1f6a74d 100644
--- a/nagato/service/embedding.py
+++ b/nagato/service/embedding.py
@@ -16,7 +16,8 @@
     "thenlper/gte-base": {"index_name": "gte-base", "dimensions": 768},
     "thenlper/gte-small": {"index_name": "gte-small", "dimensions": 384},
     "thenlper/gte-large": {"index_name": "gte-large", "dimensions": 1024},
-    "infgrad/stella-base-en-v2": {"index_name": "stella-base", "dimensions": 768}
+    "infgrad/stella-base-en-v2": {"index_name": "stella-base", "dimensions": 768},
+    "BAAI/bge-large-en-v1.5": {"index_name": "bge-large", "dimensions": 1024}
     # Add more mappings here as needed
 }
 
diff --git a/nagato/service/prompts.py b/nagato/service/prompts.py
index 0ae6be9..5be6323 100644
--- a/nagato/service/prompts.py
+++ b/nagato/service/prompts.py
@@ -18,6 +18,10 @@
 )
 
 
+def generate_replicate_rag_prompt(context: str, system_prompt: str):
+    return "You are a helpful assistant"
+
+
 def generate_qa_pair_prompt(
     format: str, context: str, num_of_qa_pairs: int = 10
 ) -> str:
diff --git a/nagato/service/query.py b/nagato/service/query.py
index 20f2a4e..73147d6 100644
--- a/nagato/service/query.py
+++ b/nagato/service/query.py
@@ -4,8 +4,10 @@
 import replicate
 from decouple import config
 
+import litellm
+
 from nagato.service.prompts import (
-    generate_replicaste_system_prompt,
+    generate_replicate_rag_prompt,
 )
 
 
@@ -17,6 +19,12 @@ def __init__(
     ):
         self.provider = provider
         self.model = model
+        if self.provider == "REPLICATE":
+            self.api_key = config("REPLICATE_API_KEY")
+        elif self.provider == "OPENAI":
+            self.api_key = config("OPENAI_API_KEY")
+        else:
+            self.api_key = None
 
     @abstractmethod
     def predict(
@@ -47,19 +55,22 @@ def predict(
         system_prompt: str = None,
         callback: Callable = None,
     ):
-        client = replicate.Client(api_token=config("REPLICATE_API_KEY"))
-        output = client.run(
-            self.model,
-            input={
-                "prompt": input,
-                "system_prompt": system_prompt,
-            },
+        litellm.api_key = self.api_key
+
+        output = litellm.completion(
+            model=self.model,
+            messages=[
+                {"content": system_prompt, "role": "system"},
+                {"content": input, "role": "user"},
+            ],
+            max_tokens=450,
+            temperature=0,
+            stream=enable_streaming,
         )
         if enable_streaming:
-            for item in output:
-                callback(item)
-        else:
-            return "".join(item for item in output)
+            for chunk in output:
+                callback(chunk["choices"][0]["delta"]["content"])
+        return output
 
     def predict_with_embedding(
         self,
@@ -69,21 +80,27 @@ def predict_with_embedding(
         callback: Callable = None,
         system_prompt: str = None,
     ):
-        client = replicate.Client(api_token=config("REPLICATE_API_KEY"))
-        output = client.run(
-            self.model,
-            input={
"prompt": input, - "system_prompt": generate_replicaste_system_prompt( - context=context, system_prompt=system_prompt - ), - }, + litellm.api_key = self.api_key + + output = litellm.completion( + model=self.model, + messages=[ + { + "content": generate_replicate_rag_prompt( + context=context, system_prompt=system_prompt + ), + "role": "system", + }, + {"content": input, "role": "user"}, + ], + max_tokens=450, + temperature=0, + stream=enable_streaming, ) if enable_streaming: - for item in output: - callback(item) - else: - return "".join(item for item in output) + for chunk in output: + callback(chunk["choices"][0]["delta"]["content"]) + return output def get_query_service( diff --git a/poetry.lock b/poetry.lock index 62daa27..71ac520 100644 --- a/poetry.lock +++ b/poetry.lock @@ -143,6 +143,17 @@ doc = ["Sphinx", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd- test = ["anyio[trio]", "coverage[toml] (>=4.5)", "hypothesis (>=4.0)", "mock (>=4)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] trio = ["trio (<0.22)"] +[[package]] +name = "appdirs" +version = "1.4.4" +description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." +optional = false +python-versions = "*" +files = [ + {file = "appdirs-1.4.4-py2.py3-none-any.whl", hash = "sha256:a841dacd6b99318a741b166adb07e19ee71a274450e68237b4650ca1055ab128"}, + {file = "appdirs-1.4.4.tar.gz", hash = "sha256:7d5d0167b2b1ba821647616af46a749d1c653740dd0d2415100fe26e27afdf41"}, +] + [[package]] name = "async-timeout" version = "4.0.3" @@ -790,6 +801,25 @@ files = [ {file = "idna-3.4.tar.gz", hash = "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"}, ] +[[package]] +name = "importlib-metadata" +version = "6.8.0" +description = "Read metadata from Python packages" +optional = false +python-versions = ">=3.8" +files = [ + {file = "importlib_metadata-6.8.0-py3-none-any.whl", hash = "sha256:3ebb78df84a805d7698245025b975d9d67053cd94c79245ba4b3eb694abe68bb"}, + {file = "importlib_metadata-6.8.0.tar.gz", hash = "sha256:dbace7892d8c0c4ac1ad096662232f831d4e64f4c4545bd53016a3e9d4654743"}, +] + +[package.dependencies] +zipp = ">=0.5" + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] +perf = ["ipython"] +testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"] + [[package]] name = "jinja2" version = "3.1.2" @@ -912,6 +942,28 @@ files = [ pydantic = ">=1,<3" requests = ">=2,<3" +[[package]] +name = "litellm" +version = "0.12.5" +description = "Library to easily interface with LLM API providers" +optional = false +python-versions = ">=3.8,<4.0" +files = [ + {file = "litellm-0.12.5-py3-none-any.whl", hash = "sha256:2245dbb4d7be88bf9bbc20643de89f89041a5fffd4e2bfe3df09cf6264198968"}, + {file = "litellm-0.12.5.tar.gz", hash = "sha256:6c6ddaf092e41d1834c280a677e3b8592195d7843fdbaeda84b163cc044bab21"}, +] + +[package.dependencies] +appdirs = ">=1.4.4,<2.0.0" +certifi = ">=2023.7.22,<2024.0.0" +click = "*" +importlib-metadata = ">=6.8.0" +jinja2 = ">=3.1.2,<4.0.0" +openai = ">=0.27.0,<0.29.0" +python-dotenv = ">=0.2.0" +tiktoken = ">=0.4.0" +tokenizers = "*" + [[package]] name = "llama-index" version = "0.8.42" @@ -3167,7 +3219,22 @@ files = [ 
 idna = ">=2.0"
 multidict = ">=4.0"
 
+[[package]]
+name = "zipp"
+version = "3.17.0"
+description = "Backport of pathlib-compatible object wrapper for zip files"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "zipp-3.17.0-py3-none-any.whl", hash = "sha256:0e923e726174922dce09c53c59ad483ff7bbb8e572e00c7f7c46b88556409f31"},
+    {file = "zipp-3.17.0.tar.gz", hash = "sha256:84e64a1c28cf7e91ed2078bb8cc8c259cb19b76942096c8d7b84947690cabaf0"},
+]
+
+[package.extras]
+docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"]
+testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"]
+
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.9"
-content-hash = "11e5ccb14c0954e143f89b189dbcadfb93f4663b47fc24f6339e5e34d16ddee6"
+content-hash = "c6e100ba4794db583716c8b9ee7f3d681c4f6d500aa4b3ee5da8664e5da62daa"
diff --git a/pyproject.toml b/pyproject.toml
index 96eb4b3..fd734e8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,6 +27,7 @@ wheel = "^0.41.0"
 python-dotenv = "^1.0.0"
 tqdm = "^4.66.1"
 setuptools = "^68.2.2"
+litellm = "^0.12.5"
 
 [build-system]
 requires = ["poetry-core"]
diff --git a/test/__init__.py b/tests/__init__.py
similarity index 100%
rename from test/__init__.py
rename to tests/__init__.py
diff --git a/test/embedding.py b/tests/embedding.py
similarity index 100%
rename from test/embedding.py
rename to tests/embedding.py
diff --git a/test/finetune.py b/tests/finetune.py
similarity index 100%
rename from test/finetune.py
rename to tests/finetune.py
diff --git a/tests/predict.py b/tests/predict.py
new file mode 100644
index 0000000..6904bfb
--- /dev/null
+++ b/tests/predict.py
@@ -0,0 +1,20 @@
+from nagato.service import predict
+
+
+def callback_method(item):
+    print(item)
+
+
+def main():
+    result = predict(
+        input="What was Tesla's YoY revenue increase in Q2 2023?",
+        provider="REPLICATE",
+        model="homanp/test:bc8afbabceaec8abb9b15fade05ff42db371b01fa251541b49c8ba9a9d44bc1f",
+        system_prompt="You are a helpful assistant",
+        enable_streaming=True,
+        callback=callback_method,
+    )
+    print(result)
+
+
+main()
diff --git a/tests/predict_with_embedding.py b/tests/predict_with_embedding.py
new file mode 100644
index 0000000..85d46e2
--- /dev/null
+++ b/tests/predict_with_embedding.py
@@ -0,0 +1,23 @@
+from nagato.service import predict_with_embedding
+
+
+def callback_method(item):
+    print(item)
+
+
+def main():
+    result = predict_with_embedding(
+        input="What was Tesla's YoY revenue increase in Q2 2023?",
+        provider="REPLICATE",
+        model="homanp/test:bc8afbabceaec8abb9b15fade05ff42db371b01fa251541b49c8ba9a9d44bc1f",
+        system_prompt="You are a helpful assistant",
+        embedding_provider="",
+        embedding_model="",
+        embedding_filter_id="",
+        enable_streaming=True,
+        callback=callback_method,
+    )
+    print(result)
+
+
+main()
diff --git a/test/query_embedding.py b/tests/query_embedding.py
similarity index 100%
rename from test/query_embedding.py
rename to tests/query_embedding.py

From 2afe5cea12e0bd7dc070c012991cf6fd69524408 Mon Sep 17 00:00:00 2001
From: Ismail Pelaseyed
Date: Mon, 30 Oct 2023 00:27:25 +0100
Subject: [PATCH 2/3] Add support for doing predictions with embeddings

---
 nagato/service/__init__.py | 13 +++++++------
 nagato/service/embedding.py | 6 +++++-
 nagato/service/prompts.py | 10 ++++++++--
 nagato/service/query.py | 15 ++++++++-------
 tests/embedding.py | 4 ++--
 tests/predict_with_embedding.py | 9 ++++-----
 tests/query_embedding.py | 6 +++---
 7 files changed, 37 insertions(+), 26 deletions(-)

diff --git a/nagato/service/__init__.py b/nagato/service/__init__.py
index 749f529..01618f5 100644
--- a/nagato/service/__init__.py
+++ b/nagato/service/__init__.py
@@ -75,11 +75,11 @@ def predict_with_embedding(
     input: str,
     provider: str,
     model: str,
-    embedding_provider: str,
+    vector_db: str,
     embedding_model: str,
     embedding_filter_id: str,
     callback: Callable = None,
-    system_prompt: str = None,
+    system_prompt: str = "You are a helpful assistant",
     enable_streaming: bool = False,
 ) -> dict:
     from nagato.service.query import get_query_service
@@ -88,9 +88,10 @@ def predict_with_embedding(
         query=input,
         model=embedding_model,
         filter_id=embedding_filter_id,
-        provider=embedding_provider,
+        vector_db=vector_db,
     )
-    context = similarity_search["results"][0]["matches"]
+    docs = similarity_search["results"][0]["matches"]
+    context = docs[0]["metadata"]["content"]
     query_service = get_query_service(provider=provider, model=model)
     output = query_service.predict_with_embedding(
         input=input,
@@ -105,7 +106,7 @@ def predict_with_embedding(
 def query_embedding(
     query: str,
     model: str = "thenlper/gte-small",
-    provider: str = "PINECONE",
+    vector_db: str = "PINECONE",
     filter_id: str = None,
     top_k: int = 5,
 ) -> dict:
@@ -115,7 +116,7 @@ def query_embedding(
     embedding_model = SentenceTransformer(model, use_auth_token=config("HF_API_KEY"))
 
     vectordb = get_vector_service(
-        provider=provider,
+        provider=vector_db,
         index_name=MODEL_TO_INDEX[model].get("index_name"),
         filter_id=filter_id,
         dimension=MODEL_TO_INDEX[model].get("dimensions"),
diff --git a/nagato/service/embedding.py b/nagato/service/embedding.py
index 1f6a74d..98c60f8 100644
--- a/nagato/service/embedding.py
+++ b/nagato/service/embedding.py
@@ -17,7 +17,11 @@
     "thenlper/gte-small": {"index_name": "gte-small", "dimensions": 384},
     "thenlper/gte-large": {"index_name": "gte-large", "dimensions": 1024},
     "infgrad/stella-base-en-v2": {"index_name": "stella-base", "dimensions": 768},
-    "BAAI/bge-large-en-v1.5": {"index_name": "bge-large", "dimensions": 1024}
+    "BAAI/bge-large-en-v1.5": {"index_name": "bge-large", "dimensions": 1024},
+    "jinaai/jina-embeddings-v2-base-en": {
+        "index_name": "jina-embeddings-v2",
+        "dimensions": 768,
+    }
     # Add more mappings here as needed
 }
 
diff --git a/nagato/service/prompts.py b/nagato/service/prompts.py
index 5be6323..73615ca 100644
--- a/nagato/service/prompts.py
+++ b/nagato/service/prompts.py
@@ -18,8 +18,14 @@
 )
 
 
-def generate_replicate_rag_prompt(context: str, system_prompt: str):
-    return "You are a helpful assistant"
+def generate_replicate_rag_prompt(context: str, input: str) -> str:
+    prompt = (
+        "You are a helpful assistant that's an expert at answering questions.\n"
+        "Use the following context to answer any questions.\n\n"
+        f"Context:\n{context}\n\n"
+        f"Question:\n{input}"
+    )
+    return prompt
 
 
 def generate_qa_pair_prompt(
diff --git a/nagato/service/query.py b/nagato/service/query.py
index 73147d6..5fc1c7f 100644
--- a/nagato/service/query.py
+++ b/nagato/service/query.py
@@ -76,24 +76,25 @@ def predict_with_embedding(
         self,
         input: str,
         context: str,
+        system_prompt: str,
         enable_streaming: bool = False,
         callback: Callable = None,
-        system_prompt: str = None,
     ):
         litellm.api_key = self.api_key
-
+        prompt = generate_replicate_rag_prompt(context=context, input=input)
         output = litellm.completion(
             model=self.model,
             messages=[
                 {
-                    "content": generate_replicate_rag_prompt(
-                        context=context, system_prompt=system_prompt
-                    ),
+                    "content": system_prompt,
                     "role": "system",
                 },
-                {"content": input, "role": "user"},
+                {
+                    "content": prompt,
+                    "role": "user",
+                },
             ],
-            max_tokens=450,
+            max_tokens=2000,
             temperature=0,
             stream=enable_streaming,
         )
diff --git a/tests/embedding.py b/tests/embedding.py
index e01d97a..b6af9db 100644
--- a/tests/embedding.py
+++ b/tests/embedding.py
@@ -5,8 +5,8 @@ def main():
     result = create_vector_embeddings(
         type="PDF",
         url="https://digitalassets.tesla.com/tesla-contents/image/upload/IR/TSLA-Q2-2023-Update.pdf",
-        filter_id="008",
-        model="thenlper/gte-large",
+        filter_id="010",
+        model="jinaai/jina-embeddings-v2-base-en",
     )
     print(result)
 
diff --git a/tests/predict_with_embedding.py b/tests/predict_with_embedding.py
index 85d46e2..19e4c56 100644
--- a/tests/predict_with_embedding.py
+++ b/tests/predict_with_embedding.py
@@ -7,13 +7,12 @@ def callback_method(item):
 
 def main():
     result = predict_with_embedding(
-        input="What was Tesla's YoY revenue increase in Q2 2023?",
+        input="What was Tesla's total revenue in Q2 2023?",
         provider="REPLICATE",
         model="homanp/test:bc8afbabceaec8abb9b15fade05ff42db371b01fa251541b49c8ba9a9d44bc1f",
-        system_prompt="You are a helpful assistant",
-        embedding_provider="",
-        embedding_model="",
-        embedding_filter_id="",
+        vector_db="PINECONE",
+        embedding_model="jinaai/jina-embeddings-v2-base-en",
+        embedding_filter_id="010",
         enable_streaming=True,
         callback=callback_method,
     )
diff --git a/tests/query_embedding.py b/tests/query_embedding.py
index 450803a..90b11fa 100644
--- a/tests/query_embedding.py
+++ b/tests/query_embedding.py
@@ -7,9 +7,9 @@ def callback_method(item):
 
 def main():
     result = query_embedding(
-        query="How many cars were sold in total?",
-        filter_id="007",
-        model="thenlper/gte-large",
+        query="What was Tesla's total revenue in Q2 2023?",
+        filter_id="010",
+        model="jinaai/jina-embeddings-v2-base-en",
     )
     print(result)
 

From 8009b829c97ccb1c0a61d5f089b222c02bc07aea Mon Sep 17 00:00:00 2001
From: Ismail Pelaseyed
Date: Mon, 30 Oct 2023 12:35:28 +0100
Subject: [PATCH 3/3] Fix formatting

---
 nagato/service/query.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/nagato/service/query.py b/nagato/service/query.py
index 5fc1c7f..cbcfec5 100644
--- a/nagato/service/query.py
+++ b/nagato/service/query.py
@@ -1,10 +1,8 @@
 from abc import ABC, abstractmethod
 from typing import Callable
 
-import replicate
-from decouple import config
-
 import litellm
+from decouple import config
 
 from nagato.service.prompts import (
     generate_replicate_rag_prompt,
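
A minimal usage sketch of the retrieval-augmented flow that PATCH 2/3 enables, mirroring tests/predict_with_embedding.py above. Treat it as illustrative only: the Replicate model id and the filter id are placeholders, and the Pinecone index is assumed to already hold embeddings ingested with the same embedding model (as in tests/embedding.py).

    from nagato.service import predict_with_embedding


    def on_token(token):
        # Receives one streamed token at a time when enable_streaming=True.
        print(token, end="", flush=True)


    result = predict_with_embedding(
        input="What was Tesla's total revenue in Q2 2023?",
        provider="REPLICATE",
        model="<owner>/<model>:<version>",  # placeholder Replicate model id
        vector_db="PINECONE",
        embedding_model="jinaai/jina-embeddings-v2-base-en",
        embedding_filter_id="<filter-id>",  # placeholder; must match the ingested documents
        enable_streaming=True,
        callback=on_token,
    )
    print(result)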
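Under the hood, both predict paths now reduce to the same litellm call. A standalone sketch of that pattern follows; the model string and API key are placeholders, and the chunk access mirrors the streaming handlers in the diffs above, with a guard added because the final streamed delta may carry no content.

    import litellm

    litellm.api_key = "<api-key>"  # placeholder; the service resolves REPLICATE_API_KEY or OPENAI_API_KEY

    output = litellm.completion(
        model="<provider-model-id>",  # placeholder
        messages=[
            {"content": "You are a helpful assistant", "role": "system"},
            {"content": "What was Tesla's total revenue in Q2 2023?", "role": "user"},
        ],
        max_tokens=2000,
        temperature=0,
        stream=True,
    )
    for chunk in output:
        token = chunk["choices"][0]["delta"].get("content")
        if token:  # skip empty deltas, e.g. the terminating chunk
            print(token, end="", flush=True)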