Skip to content

Commit

Permalink
Updating Wizard to include mixture of experts
Browse files Browse the repository at this point in the history
  • Loading branch information
Cox, Jordan authored and Cox, Jordan committed Jun 18, 2024
1 parent a0b93d9 commit 4c4f840
Show file tree
Hide file tree
Showing 8 changed files with 461 additions and 698 deletions.
85 changes: 5 additions & 80 deletions elm/db_wiz.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,90 +44,14 @@ def __init__(self, connection_string, model=None, token_budget=3500, ref_col=Non
self.connection = psycopg2.connect(self.connection_string)
self.token_budget = token_budget

fpcache = './db_description.txt'
fpcache = './database_manual.txt'

if os.path.exists(fpcache):
with open(fpcache, 'r') as f:
self.database_describe = f.read()

else:
# Initializing database schema
self.database_schema = self.get_schema()
self.database_first_lines = self.get_lines()
self.database_unique_values = self.get_unique_values()

self.database_describe = ('You have been given access to the database '
'schema {}.\n The first ten lines of the database are {}.\n '
'Each column of text contains the following unique '
'values {}.\n The table name is loads.lc_day_profile_demand_enduse.'
.format(self.database_schema,
self.database_first_lines,
self.database_unique_values))

with open(fpcache, 'w') as f:
f.write(self.database_describe)


## Getting database Schema
def get_schema(self):
    """Read the column layout of the load-profile table from Postgres.

    Queries ``information_schema.columns`` for the
    ``loads.lc_day_profile_demand_enduse`` table and returns a JSON
    string mapping each table name to a list of
    ``{"column": ..., "type": ...}`` entries in ordinal order.
    """
    query = """
            SELECT table_name, column_name, data_type
            FROM information_schema.columns
            WHERE table_schema = 'loads' AND table_name = 'lc_day_profile_demand_enduse'
            ORDER BY table_name, ordinal_position;
            """

    with self.connection.cursor() as cursor:
        cursor.execute(query)
        schema = {}
        for table_name, column_name, data_type in cursor.fetchall():
            entry = {"column": column_name, "type": data_type}
            schema.setdefault(table_name, []).append(entry)

    return json.dumps(schema)


def json_serial(self, obj):
    """Fallback serializer for ``json.dumps``.

    Converts ``datetime``/``date`` objects to their ISO-8601 string
    form; any other non-serializable type raises ``TypeError``.
    """
    if not isinstance(obj, (datetime, date)):
        raise TypeError("Type %s not serializable" % type(obj))
    return obj.isoformat()

## Getting First 10 lines of database
def get_lines(self):
    """Fetch a 10-row sample of the load-profile table as JSON.

    Dates/datetimes in the rows are serialized through
    :meth:`json_serial` so the result is valid JSON.
    """
    query = '''
            SELECT *
            FROM loads.lc_day_profile_demand_enduse
            LIMIT 10;
            '''

    with self.connection.cursor() as cur:
        cur.execute(query)
        sample_rows = cur.fetchall()

    return json.dumps(sample_rows, default=self.json_serial)

# Getting Unique values in non-float columns of the database
def get_unique_values(self):
    """Collect the distinct values of every text column in the table.

    Parses the JSON schema produced by :meth:`get_schema`, runs one
    ``SELECT DISTINCT`` per text column, and returns a JSON string
    mapping column name -> stringified list of unique values.
    """
    schema = json.loads(self.database_schema)

    with self.connection.cursor() as cursor:

        structure_dict = {}
        for table in schema:
            for entry in schema[table]:
                # Only text columns have a small, enumerable set of
                # values worth listing in the LLM prompt.
                if entry['type'] == 'text':
                    column_name = entry['column']
                    # NOTE(review): identifiers are interpolated directly
                    # into the SQL; this looks safe only because they come
                    # from information_schema, not user input -- confirm.
                    query = f'SELECT DISTINCT {column_name} FROM loads.{table}'

                    cursor.execute(query)
                    structure_dict[entry['column']] = str(cursor.fetchall())

        return json.dumps(structure_dict)
print('Error no expert database file')

# Getting sql from a generic query
def get_sql_for(self, query):
Expand All @@ -137,6 +61,7 @@ def get_sql_for(self, query):
e_query = ('{}\n\nPlease create a SQL query that will answer this '
'user question: "{}"\n\n'
'Return all columns from the database. '
'All the tables are in the schema "loads"'
'Please only return the SQL query with no commentary or preface.'
.format(self.database_describe, query))
out = super().chat(e_query, temperature=0)
Expand All @@ -148,6 +73,7 @@ def run_sql(self, sql):
based on the db connection (self.connection), returns dataframe
response."""
query = sql
print(query)
# Move Connection or cursor to init and test so that you aren't re-intializing
# it with each instance.
with self.connection.cursor() as cursor:
Expand Down Expand Up @@ -175,7 +101,7 @@ def get_py_code(self, query, df):
full_response = out
#print(full_response)
## get python code from response
full_response = full_response[full_response.find('python')+6:]
full_response = full_response[full_response.find('```python')+9:]
full_response = full_response[:full_response.find('```')]
py = full_response
return py
Expand All @@ -186,7 +112,6 @@ def run_py_code(self, py, df):
return plt
except:
print(py)
"""Jordan to write code that takes LLM response and generates plots"""

def chat(self, query,
debug=True,
Expand Down
228 changes: 228 additions & 0 deletions elm/experts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
"""
ELM mixture of experts
"""
import streamlit as st
import os
import openai
from glob import glob
import pandas as pd
import sys
import copy
import numpy as np


from elm.base import ApiBase
from elm.wizard import EnergyWizard
from elm.db_wiz import DataBaseWizard

model = 'gpt-4'  # Azure deployment name used by both wizards below

# NREL-Azure endpoint. You can also use just the openai endpoint.
# NOTE: embedding values are different between OpenAI and Azure models!
openai.api_base = os.getenv("AZURE_OPENAI_ENDPOINT")
openai.api_key = os.getenv("AZURE_OPENAI_KEY")
openai.api_type = 'azure'
openai.api_version = os.getenv('AZURE_OPENAI_VERSION')

# Point the report wizard at the Azure-hosted embedding and chat deployments.
EnergyWizard.EMBEDDING_MODEL = 'text-embedding-ada-002-2'
EnergyWizard.EMBEDDING_URL = ('https://stratus-embeddings-south-central.'
'openai.azure.com/openai/deployments/'
'text-embedding-ada-002-2/embeddings?'
f'api-version={openai.api_version}')
EnergyWizard.URL = ('https://stratus-embeddings-south-central.'
'openai.azure.com/openai/deployments/'
f'{model}/chat/completions?'
f'api-version={openai.api_version}')
# Both auth header styles are sent; Azure uses "api-key".
EnergyWizard.HEADERS = {"Content-Type": "application/json",
"Authorization": f"Bearer {openai.api_key}",
"api-key": f"{openai.api_key}"}

EnergyWizard.MODEL_ROLE = ('You are a energy research assistant. Use the '
'articles below to answer the question. If '
'articles do not provide enough information to '
'answer the question, say "I do not know."')
EnergyWizard.MODEL_INSTRUCTION = EnergyWizard.MODEL_ROLE

# The database wizard shares the same chat deployment and headers.
DataBaseWizard.URL = (f'https://stratus-embeddings-south-central.openai.azure.com/'
f'openai/deployments/{model}/chat/'
f'completions?api-version={openai.api_version}')
DataBaseWizard.HEADERS = {"Content-Type": "application/json",
"Authorization": f"Bearer {openai.api_key}",
"api-key": f"{openai.api_key}",
}

# Suppress Streamlit's warning about calling st.pyplot() with no figure.
st.set_option('deprecation.showPyplotGlobalUse', False)

@st.cache_data
def get_corpus():
    """Load and concatenate the embedded-document corpus from ./embed.

    Reads every ``./embed/*.json`` file (in sorted order) into a
    DataFrame and concatenates them; cached by Streamlit so the files
    are only read once per session.
    """
    frames = [pd.read_json(path) for path in sorted(glob('./embed/*.json'))]
    return pd.concat(frames, ignore_index=True)


@st.cache_resource
def get_wizard(model = model):
    """Build the report-chat EnergyWizard backed by the local corpus.

    Parameters
    ----------
    model : str
        State which model to use for the energy wizard.

    Returns
    -------
    wizard : EnergyWizard
        Energy wizard object for use in chat responses; cached by
        Streamlit so it is constructed once per session.
    """
    # Without a corpus the wizard is useless: tell the user how to
    # generate one and stop the app.
    try:
        corpus = get_corpus()
    except Exception:
        print("Error: Have you run 'retrieve_docs.py'?")
        st.header("Error")
        st.write("Error: Have you run 'retrieve_docs.py'?")
        sys.exit(0)

    return EnergyWizard(corpus, ref_col='ref', model=model)

class MixtureOfExperts(ApiBase):
    """Interface to ask OpenAI LLMs about energy research either from a
    database or report: the query is first classified by the LLM and
    then routed to the database wizard (figure) or the report wizard
    (text answer).

    Parameters
    ----------
    connection_string : str
        String used to connect to SQL databases.
    model : str
        State which model to use for the energy wizard.
    """

    MODEL_ROLE = ("You are an expert given a query. Which of the "
                  "following best describes the query? Please "
                  "answer with just the number and nothing else."
                  "1. This is a query best answered by a text-based report."
                  "2. This is a query best answered by pulling data from "
                  "a database and creating a figure.")
    """High level model role, somewhat redundant to MODEL_INSTRUCTION"""

    def __init__(self, connection_string, model=None, token_budget=3500,
                 ref_col=None):
        # Database-backed expert (SQL + plotting) and report-backed expert.
        self.wizard_db = DataBaseWizard(model=model,
                                        connection_string=connection_string)
        self.wizard_chat = get_wizard()
        self.model = model
        super().__init__(model)

    def chat(self, query,
             debug=True,
             stream=True,
             temperature=0,
             convo=False,
             token_budget=None,
             new_info_threshold=0.7,
             print_references=False,
             return_chat_obj=False):
        """Answer a query by routing it to the appropriate expert.

        The LLM is first asked to classify the query ('1' = report,
        '2' = database figure); the matching wizard then produces the
        answer, which is rendered into the Streamlit app.

        Parameters
        ----------
        query : str
            Question being asked of EnergyWizard
        debug : bool
            Flag to return extra diagnostics on the engineered question.
        stream : bool
            Flag to print subsequent chunks of the response in a streaming
            fashion
        temperature : float
            GPT model temperature, a measure of response entropy from 0 to 1.
            0 is more reliable and nearly deterministic; 1 will give the
            model more creative freedom and may not return as factual of
            results.
        convo : bool
            Flag to perform semantic search with full conversation history
            (True) or just the single query (False). Call
            EnergyWizard.clear() to reset the chat history.
        token_budget : int
            Option to override the class init token budget.
        new_info_threshold : float
            New text added to the engineered query must contain at least
            this much new information. This helps prevent (for example) the
            table of contents being added multiple times.
        print_references : bool
            Flag to print references if EnergyWizard is initialized with a
            valid ref_col.
        return_chat_obj : bool
            Flag to only return the ChatCompletion from OpenAI API.

        Returns
        -------
        response : str
            GPT output / answer (empty string when the answer is a figure
            rendered via st.pyplot).
        """

        # Step 1: classify the query as report-type ('1') or database-type
        # ('2') using the routing prompt in MODEL_ROLE.
        messages = [{"role": "system", "content": self.MODEL_ROLE},
                    {"role": "user", "content": query}]
        response_message = ''
        kwargs = dict(model=self.model,
                      messages=messages,
                      temperature=temperature,
                      stream=stream)

        response = self._client.chat.completions.create(**kwargs)

        if stream:
            for chunk in response:
                chunk_msg = chunk.choices[0].delta.content or ""
                response_message += chunk_msg
                print(chunk_msg, end='')
        else:
            # openai>=1.0 returns a pydantic object, not a dict, so use
            # attribute access (matches the streaming branch above).
            response_message = response.choices[0].message.content

        # Step 2: route to the chosen expert and render the answer.
        message_placeholder = st.empty()
        full_response = ""

        if '1' in response_message:
            out = self.wizard_chat.chat(query,
                                        debug=True, stream=True,
                                        token_budget=6000,
                                        temperature=0.0,
                                        print_references=True,
                                        convo=False,
                                        return_chat_obj=True)

            for chunk in out[0]:
                full_response += chunk.choices[0].delta.content or ""
                message_placeholder.markdown(full_response + "▌")
            # Re-render the final text without the typing cursor.
            message_placeholder.markdown(full_response)

        elif '2' in response_message:
            out = self.wizard_db.chat(query,
                                      debug=True, stream=True,
                                      token_budget=6000,
                                      temperature=0.0,
                                      print_references=True,
                                      convo=False,
                                      return_chat_obj=True)

            st.pyplot(fig=out, clear_figure=False)

        else:
            # The classifier produced neither option; surface the failure
            # to the user instead of silently returning an empty string.
            full_response = 'Error cannot find data in report or database.'
            message_placeholder.markdown(full_response)

        return full_response
Loading

0 comments on commit 4c4f840

Please sign in to comment.