Skip to content

Commit

Permalink
Add EnergyWizardPostgres code and associated package reqs
Browse files Browse the repository at this point in the history
  • Loading branch information
spodgorny9 committed May 7, 2024
1 parent 8dc4083 commit 79c541c
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 13 deletions.
47 changes: 47 additions & 0 deletions elm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
"""
from abc import ABC
import os
import json
import numpy as np
import asyncio
import aiohttp
import openai
import boto3
import requests
import tiktoken
import time
Expand Down Expand Up @@ -67,6 +69,19 @@ def __init__(self, model=None):
self._client = openai.AzureOpenAI(api_key=key,
api_version=version,
azure_endpoint=endpoint)
elif 'amazon' in self.EMBEDDING_MODEL.lower():
access_key = os.getenv('AWS_ACCESS_KEY_ID')
secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY')
session_token = os.getenv('AWS_SESSION_TOKEN')

assert access_key is not None, "Must set AWS_ACCESS_KEY_ID!"
assert secret_access_key is not None, "Must set AWS_SECRET_ACCESS_KEY!"
assert session_token is not None, "Must set AWS_SESSION_TOKEN!"
self._client = boto3.client(service_name='bedrock-runtime',
region_name='us-west-2',
aws_access_key_id = access_key,
aws_secret_access_key = secret_access_key,
aws_session_token= session_token)
else:
key = os.environ.get("OPENAI_API_KEY")
assert key is not None, "Must set OPENAI_API_KEY!"
Expand Down Expand Up @@ -338,6 +353,38 @@ def get_embedding(cls, text):

return embedding

def get_aws_embedding(self, text):
    """Get the 1D array (list) embedding of a text string as generated by
    AWS Titan.

    The request is sent through ``self._client.invoke_model`` using
    ``self.EMBEDDING_MODEL`` as the model id.

    Parameters
    ----------
    text : str
        Text to embed

    Returns
    -------
    embedding : list
        List of float that represents the numerical embedding of the text

    Raises
    ------
    RuntimeError
        If the model response does not contain an "embedding" entry.
    """
    body = json.dumps({"inputText": text})

    response = self._client.invoke_model(
        body=body,
        modelId=self.EMBEDDING_MODEL,
        accept='application/json',
        contentType='application/json',
    )

    # response['body'] is a stream-like object holding the JSON payload
    response_body = json.loads(response['body'].read())
    embedding = response_body.get('embedding')

    # Fail loudly here rather than handing None to downstream consumers
    # (e.g. a vector database query), where the cause would be obscured.
    if embedding is None:
        msg = ('Embedding request failed: response from model "{}" '
               'has no "embedding" entry.'.format(self.EMBEDDING_MODEL))
        raise RuntimeError(msg)

    return embedding

@staticmethod
def count_tokens(text, model):
"""Return the number of tokens in a string.
Expand Down
71 changes: 58 additions & 13 deletions elm/wizard.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
"""
from abc import ABC, abstractmethod
import copy
import os
import psycopg2
import numpy as np

from elm.base import ApiBase
Expand Down Expand Up @@ -380,12 +382,16 @@ def make_ref_list(self, idx):
class EnergyWizardPostgres(EnergyWizardBase):
"""Interface to ask OpenAI LLMs about energy research.
This class is for execution with a postgres vector database
TODO: slater describe the vector DB here
This class is for execution with a postgres vector database.
Connecting to the database requires the use of the psycopg2
python package, environment variables storing the db user and
password, and the specification of other connection parameters
such as host, port, and name. The database has the following
columns: id, embedding, chunks, and metadata.
"""

def __init__(self, db_host, db_port, db_name,
             model=None, token_budget=3500):
    """Connect to the postgres vector database and initialize the wizard.

    The database user and password are read from the environment
    variables EWIZ_DB_USER and EWIZ_DB_PASSWORD.

    Parameters
    ----------
    db_host : str
        Host url for postgres database.
    db_port : str
        Port for postgres database. ex: '5432'
    db_name : str
        Postgres database name.
    model : str, optional
        Model specification, forwarded unchanged to the base class
        initializer.
    token_budget : int
        Number of tokens that can be embedded in the prompt. Note that the
        default budget for GPT-3.5-Turbo is 4096, but you want to subtract
        some tokens to account for the response budget.
    """
    db_user = os.getenv("EWIZ_DB_USER")
    db_password = os.getenv('EWIZ_DB_PASSWORD')
    assert db_user is not None, "Must set user for postgreSQL database!"
    assert db_password is not None, ("Must set password for postgreSQL "
                                     "database!")

    # One connection and one cursor are opened here and reused by every
    # subsequent query made through this object.
    self.conn = psycopg2.connect(user=db_user,
                                 password=db_password,
                                 host=db_host,
                                 port=db_port,
                                 database=db_name)

    self.cursor = self.conn.cursor()

    super().__init__(model, token_budget=token_budget)

def query_vector_db(self, query, limit=100):
    """Find the text chunks in the postgres vector database that are most
    related to a query.

    Parameters
    ----------
    query : str
        Question being asked; it is embedded with
        ``self.get_aws_embedding`` before being compared to the stored
        embeddings.
    limit : int
        Maximum number of rows to return.

    Returns
    -------
    strings : list
        Related text chunks, closest match first
    scores : list
        Float distance of each chunk from the query as computed by the
        ``<=>`` operator (lower means more closely related)
    ids : list
        IDs in the text corpus corresponding to the ranked
        strings/scores outputs.
    """
    query_embedding = self.get_aws_embedding(query)

    # "<=>" is the pgvector cosine distance operator, so ordering by it in
    # ascending order puts the most similar chunks first.
    sql_query = ("SELECT ewiz_kb.id, "
                 "ewiz_kb.chunks, "
                 "ewiz_kb.embedding <=> %s::vector as similarity_score "
                 "FROM ewiz_schema.ewiz_kb "
                 "ORDER BY embedding <=> %s::vector LIMIT %s;")

    try:
        self.cursor.execute(sql_query,
                            (query_embedding, query_embedding, limit,))
        rows = self.cursor.fetchall()
    except Exception:
        # A failed statement leaves the shared connection in an aborted
        # transaction; roll back so later queries can still run, then
        # surface the original error to the caller.
        self.conn.rollback()
        raise

    ids = [row[0] for row in rows]
    strings = [row[1] for row in rows]
    scores = [row[2] for row in rows]

    return strings, scores, ids

def make_ref_list(self, ids):
    """Make a reference list

    Parameters
    ----------
    ids : np.ndarray | list
        IDs of the used text from the text corpus

    Returns
    -------
    ref_list : list
        Rows returned by the metadata query, one per matching ID. An empty
        list is returned without querying when no IDs are given.
    """
    # NumPy scalars cannot be passed as query parameters directly, so
    # convert anything exposing .item() to the equivalent Python scalar.
    ids = [i.item() if hasattr(i, 'item') else i for i in ids]

    # "IN ()" is not valid SQL, so there is nothing to look up.
    if not ids:
        return []

    # metadata is not stored in db at the moment, query will be updated

    placeholders = ', '.join(['%s'] * len(ids))

    sql_query = ("SELECT ewiz_kb.metadata "
                 "FROM ewiz_schema.ewiz_kb "
                 "WHERE ewiz_kb.id IN (" + placeholders + ")")

    try:
        self.cursor.execute(sql_query, ids)
        ref_list = self.cursor.fetchall()
    except Exception:
        # A failed statement leaves the shared connection in an aborted
        # transaction; roll back so later queries can still run, then
        # surface the original error to the caller.
        self.conn.rollback()
        raise

    return ref_list
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@ python-slugify
scipy
tabulate
tiktoken
psycopg2
boto3

0 comments on commit 79c541c

Please sign in to comment.