Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RAG using Gemini and Elasticsearch #162

Merged
merged 9 commits into from
Jan 26, 2024
286 changes: 286 additions & 0 deletions notebooks/integrations/gemini/qa-langchain-gemini-elasticsearch.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,286 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "2a3143e8-3949-4ecc-905c-8333a43c9c87",
"metadata": {},
"source": [
"# Question Answering using Gemini, Langchain & Elasticsearch\n",
"\n",
"This tutorial demonstrates how to use the [Gemini API](https://ai.google.dev/docs) to create [embeddings](https://ai.google.dev/docs/embeddings_guide) and store them in Elasticsearch. We will learn how to connect Gemini to private data stored in Elasticsearch and build question/answer capabilities over it using [LangChian](https://python.langchain.com/docs/get_started/introduction)."
]
},
{
"cell_type": "markdown",
"id": "68c5e34d-28f9-4195-9f9c-2a8aec1effe6",
"metadata": {},
"source": [
"## setup\n",
"\n",
"* Elastic Credentials - Create an [Elastic Cloud deployment](https://www.elastic.co/search-labs/tutorials/install-elasticsearch/elastic-cloud) to get all Elastic credentials (`ELASTIC_CLOUD_ID`, `ELASTIC_API_KEY`).\n",
"\n",
"* `GOOGLE_API_KEY` - To use the Gemini API, you need to [create an API key in Google AI Studio](https://ai.google.dev/tutorials/setup)."
]
},
{
"cell_type": "markdown",
"id": "b8e9a58a-942f-4039-96c0-b276d5b8a97f",
"metadata": {},
"source": [
"## Install packages"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5c4781ec-06a5-48dd-963e-fb832b3f7ca2",
"metadata": {},
"outputs": [],
"source": [
"pip install -q -U google-generativeai elasticsearch langchain langchain_google_genai"
]
},
{
"cell_type": "markdown",
"id": "851db243-ca7d-4a7c-a93b-d22ab149a1bb",
"metadata": {},
"source": [
"## Import packages and credentials"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "26e7f569-c680-447b-9246-b5140ff47b6b",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import os\n",
"from getpass import getpass\n",
"from urllib.request import urlopen\n",
"\n",
"from elasticsearch import Elasticsearch, helpers\n",
"from langchain.vectorstores import ElasticsearchStore\n",
"from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain_google_genai import GoogleGenerativeAIEmbeddings\n",
"from langchain_google_genai import ChatGoogleGenerativeAI\n",
"from langchain.prompts import ChatPromptTemplate\n",
"from langchain.schema.output_parser import StrOutputParser\n",
"from langchain.schema.runnable import RunnablePassthrough"
]
},
{
"cell_type": "markdown",
"id": "b2f68db5-21ac-47b0-941b-1d816b586e18",
"metadata": {},
"source": [
"## Get Credentials"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "543d27f4-2c53-4726-a324-716900d72338",
"metadata": {},
"outputs": [],
"source": [
"os.environ[\"GOOGLE_API_KEY\"] = getpass(\"Google API Key :\")\n",
"ELASTIC_API_KEY = getpass(\"Elastic API Key :\")\n",
"ELASTIC_CLOUD_ID = getpass(\"Elastic Cloud ID :\")\n",
"elastic_index_name = \"gemini-qa\""
]
},
{
"cell_type": "markdown",
"id": "e8bd47b9-b946-46d1-ba02-adbda118415a",
"metadata": {},
"source": [
"## Add documents"
]
},
{
"cell_type": "markdown",
"id": "bd04e921-206e-4c8b-937a-277d2c5a02e6",
"metadata": {},
"source": [
"### Let's download the sample dataset and deserialize the document."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "44a5b79d-326e-4317-82a3-7918a11ff7b7",
"metadata": {},
"outputs": [],
"source": [
"url = \"https://raw.githubusercontent.com/ashishtiwari1993/langchain-elasticsearch-RAG/main/data.json\"\n",
"\n",
"response = urlopen(url)\n",
"\n",
"workplace_docs = json.loads(response.read())"
]
},
{
"cell_type": "markdown",
"id": "8591dce2-3fe6-4c87-b268-4694bb86e803",
"metadata": {},
"source": [
"### Split Documents into Passages"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3963b0db-80d5-4908-897c-bec6357adc0a",
"metadata": {},
"outputs": [],
"source": [
"metadata = []\n",
"content = []\n",
"\n",
"for doc in workplace_docs:\n",
" content.append(doc[\"content\"])\n",
" metadata.append({\n",
" \"name\": doc[\"name\"],\n",
" \"summary\": doc[\"summary\"],\n",
" \"rolePermissions\":doc[\"rolePermissions\"]\n",
" })\n",
"\n",
"text_splitter = CharacterTextSplitter(chunk_size=50, chunk_overlap=0)\n",
"docs = text_splitter.create_documents(content, metadatas=metadata)"
]
},
{
"cell_type": "markdown",
"id": "a066d5dc-dbc9-495f-934f-bfe96e0fdeec",
"metadata": {},
"source": [
"## Index Documents into Elasticsearch using Gemini Embeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "42ba370a-e4b4-4375-b71a-2aee7c40a330",
"metadata": {},
"outputs": [],
"source": [
"query_embedding = GoogleGenerativeAIEmbeddings(\n",
" model=\"models/embedding-001\", task_type=\"retrieval_document\"\n",
")\n",
"\n",
"es = ElasticsearchStore.from_documents(\n",
" docs,\n",
" es_cloud_id=ELASTIC_CLOUD_ID,\n",
" es_api_key=ELASTIC_API_KEY,\n",
" index_name=elastic_index_name,\n",
" embedding=query_embedding\n",
")"
]
},
{
"cell_type": "markdown",
"id": "dbdb2d55-3349-4e95-8087-68f927f0d864",
"metadata": {},
"source": [
"## Create a retriever using Elasticsearch"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "17920c1e-9228-42f5-893d-29b666d6f7b2",
"metadata": {},
"outputs": [],
"source": [
"query_embedding = GoogleGenerativeAIEmbeddings(\n",
" model=\"models/embedding-001\", task_type=\"retrieval_query\"\n",
")\n",
"\n",
"es = ElasticsearchStore(\n",
maxjakob marked this conversation as resolved.
Show resolved Hide resolved
" es_cloud_id=ELASTIC_CLOUD_ID,\n",
" es_api_key=ELASTIC_API_KEY,\n",
" embedding=query_embedding,\n",
" index_name=elastic_index_name\n",
")\n",
"\n",
"retriever = es.as_retriever(search_kwargs={\"k\": 3})"
]
},
{
"cell_type": "markdown",
"id": "3647d005-d70e-4c3a-b784-052b21e9f143",
"metadata": {},
"source": [
"## Fromat Docs"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "04ee4d3d-09fb-4a35-bc66-8a2951c402a8",
"metadata": {},
"outputs": [],
"source": [
"def format_docs(docs):\n",
" return \"\\n\\n\".join(doc.page_content for doc in docs)"
]
},
{
"cell_type": "markdown",
"id": "864afb6a-a671-434a-bd30-006c79ccda24",
"metadata": {},
"source": [
"## Create a Chain using Prompt Template + `gemini-pro` model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4818aef7-3535-494d-a5d4-16ef6d0581af",
"metadata": {},
"outputs": [],
"source": [
"template = \"\"\"Answer the question based only on the following context:\\n\n",
"\n",
"{context}\n",
"\n",
"Question: {question}\n",
"\"\"\"\n",
"prompt = ChatPromptTemplate.from_template(template)\n",
"\n",
"\n",
"chain = (\n",
" {\"context\": retriever | format_docs, \"question\": RunnablePassthrough()} \n",
" | prompt \n",
" | ChatGoogleGenerativeAI(model=\"gemini-pro\", temperature=0.7) \n",
" | StrOutputParser()\n",
")\n",
"\n",
"chain.invoke(\"what is our sales goals?\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}