Commit

added energy wizard example
grantbuster committed Nov 6, 2023
1 parent b88fb8c commit 222431d
Showing 2 changed files with 223 additions and 0 deletions.
116 changes: 116 additions & 0 deletions examples/energy_wizard/retrieve_docs.py
@@ -0,0 +1,116 @@
import os
import asyncio
import pandas as pd
import logging
import openai
import time
from glob import glob
from rex import init_logger

from elm.pdf import PDFtoTXT
from elm.embed import ChunkAndEmbed
from elm.osti import OstiList


logger = logging.getLogger(__name__)
init_logger(__name__, log_level='DEBUG')
init_logger('elm', log_level='INFO')


# NREL-Azure endpoint. You can also use the standard OpenAI endpoint.
# NOTE: embedding values are different between OpenAI and Azure models!
openai.api_base = os.getenv("AZURE_OPENAI_ENDPOINT")
openai.api_key = os.getenv("AZURE_OPENAI_API_KEY")
openai.api_type = 'azure'
openai.api_version = '2023-03-15-preview'

ChunkAndEmbed.EMBEDDING_MODEL = 'text-embedding-ada-002-2'
ChunkAndEmbed.EMBEDDING_URL = ('https://stratus-embeddings-south-central.'
                               'openai.azure.com/openai/deployments/'
                               'text-embedding-ada-002-2/embeddings?'
                               f'api-version={openai.api_version}')
ChunkAndEmbed.HEADERS = {"Content-Type": "application/json",
                         "Authorization": f"Bearer {openai.api_key}",
                         "api-key": f"{openai.api_key}"}

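# Working directories for downloaded PDFs, extracted text, and embedding output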
PDF_DIR = './pdfs/'
TXT_DIR = './txt/'
EMBED_DIR = './embed/'

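# OSTI API query: NREL technical reports with full text, published in 2023, newest first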
URL = ('https://www.osti.gov/api/v1/records?'
       'research_org=NREL'
       '&sort=publication_date%20desc'
       '&product_type=Technical%20Report'
       '&has_fulltext=true'
       '&publication_date_start=01/01/2023'
       '&publication_date_end=12/31/2023')


if __name__ == '__main__':
os.makedirs(PDF_DIR, exist_ok=True)
os.makedirs(TXT_DIR, exist_ok=True)
os.makedirs(EMBED_DIR, exist_ok=True)

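    # Query OSTI for one page of records and download the corresponding PDFs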
osti = OstiList(URL, n_pages=1)
osti.download(PDF_DIR)

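    # Save the report metadata (one row per OSTI ID); run_app.py joins on this table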
meta = osti.meta.copy()
meta['osti_id'] = meta['osti_id'].astype(str)
meta = meta.drop_duplicates(subset=['osti_id'])
meta['fp'] = PDF_DIR + meta['fn']
meta.to_csv('./meta.csv', index=False)

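    # Drop any records whose PDF failed to download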
missing = []
for i, row in meta.iterrows():
if not os.path.exists(row['fp']):
missing.append(i)
meta = meta.drop(missing, axis=0)

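    # Convert each PDF to clean text and embed it; both steps cache results to
    # disk so reruns skip files that are already finished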
for i, row in meta.iterrows():
fp = os.path.join(PDF_DIR, row['fn'])
txt_fp = os.path.join(TXT_DIR, row['fn'].replace('.pdf', '.txt'))
embed_fp = os.path.join(EMBED_DIR, row['fn'].replace('.pdf', '.json'))

assert fp.endswith('.pdf')
assert os.path.exists(fp)

if os.path.exists(txt_fp):
with open(txt_fp, 'r') as f:
text = f.read()
else:
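            # No cached text yet: extract with poppler, re-extract without
            # layout mode for double-column PDFs, then strip repeated
            # headers/footers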
pdf_obj = PDFtoTXT(fp)
text = pdf_obj.clean_poppler(layout=True)
if pdf_obj.is_double_col():
text = pdf_obj.clean_poppler(layout=False)
text = pdf_obj.clean_headers(char_thresh=0.6, page_thresh=0.8,
split_on='\n',
iheaders=[0, 1, 3, -3, -2, -1])
with open(txt_fp, 'w') as f:
f.write(text)
logger.info(f'Saved: {txt_fp}')

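        # Chunk the document, request embeddings asynchronously under a rate
        # limit, and cache the chunk/embedding pairs to JSON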
if not os.path.exists(embed_fp):
logger.info('Embedding {}/{}: "{}"'
.format(i+1, len(meta), row['title']))
tag = f"Title: {row['title']}\nAuthors: {row['authors']}"
obj = ChunkAndEmbed(text, tag=tag, tokens_per_chunk=500, overlap=1)
embeddings = asyncio.run(obj.run_async(rate_limit=3e4))
if any(e is None for e in embeddings):
raise RuntimeError('Embeddings are None!')
else:
df = pd.DataFrame({'text': obj.text_chunks.chunks,
'embedding': embeddings,
'osti_id': row['osti_id']})
df.to_json(embed_fp, indent=2)
logger.info('Saved: {}'.format(embed_fp))
time.sleep(5)

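    # Sanity check: fail if any saved embedding file contains null embeddings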
bad = []
fps = glob(EMBED_DIR + '*.json')
for fp in fps:
data = pd.read_json(fp)
if data['embedding'].isna().any():
bad.append(fp)
assert not any(bad), f'Bad output: {bad}'

logger.info('Finished!')
107 changes: 107 additions & 0 deletions examples/energy_wizard/run_app.py
@@ -0,0 +1,107 @@
import streamlit as st
import os
import openai
from glob import glob
import pandas as pd

from elm import EnergyWizard


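# Typically launched with: streamlit run run_app.py (expects the ./embed/
# directory and ./meta.csv produced by retrieve_docs.py)

# Azure OpenAI chat deployment used by the wizard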
model = 'gpt-4'

# NREL-Azure endpoint. You can also use the standard OpenAI endpoint.
# NOTE: embedding values are different between OpenAI and Azure models!
openai.api_base = os.getenv("AZURE_OPENAI_ENDPOINT")
openai.api_key = os.getenv("AZURE_OPENAI_API_KEY")
openai.api_type = 'azure'
openai.api_version = '2023-03-15-preview'

EnergyWizard.EMBEDDING_MODEL = 'text-embedding-ada-002-2'
EnergyWizard.EMBEDDING_URL = ('https://stratus-embeddings-south-central.'
                              'openai.azure.com/openai/deployments/'
                              'text-embedding-ada-002-2/embeddings?'
                              f'api-version={openai.api_version}')
EnergyWizard.URL = ('https://stratus-embeddings-south-central.'
                    'openai.azure.com/openai/deployments/'
                    f'{model}/chat/completions?'
                    f'api-version={openai.api_version}')
EnergyWizard.HEADERS = {"Content-Type": "application/json",
                        "Authorization": f"Bearer {openai.api_key}",
                        "api-key": f"{openai.api_key}"}

EnergyWizard.MODEL_ROLE = ('You are an energy research assistant. Use the '
                           'articles below to answer the question. If the '
                           'articles do not provide enough information to '
                           'answer the question, say "I do not know."')
EnergyWizard.MODEL_INSTRUCTION = EnergyWizard.MODEL_ROLE


@st.cache_data
def get_corpus():
"""Get the corpus of text data with embeddings."""
corpus = sorted(glob('./embed/*.json'))
corpus = [pd.read_json(fp) for fp in corpus]
corpus = pd.concat(corpus, ignore_index=True)
meta = pd.read_csv('./meta.csv')

corpus['osti_id'] = corpus['osti_id'].astype(str)
meta['osti_id'] = meta['osti_id'].astype(str)
corpus = corpus.set_index('osti_id')
meta = meta.set_index('osti_id')

corpus = corpus.join(meta, on='osti_id', rsuffix='_record', how='left')

ref = [f"{row['title']} ({row['doi']})" for _, row in corpus.iterrows()]
corpus['ref'] = ref

return corpus


@st.cache_resource
def get_wizard():
"""Get the energy wizard object."""
corpus = get_corpus()
wizard = EnergyWizard(corpus, ref_col='ref', model=model)
return wizard


if __name__ == '__main__':
wizard = get_wizard()
st.title("Energy Wizard")

if "messages" not in st.session_state:
st.session_state.messages = []

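    # Replay any prior messages stored in the Streamlit session state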
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])

msg = "Ask the Energy Wizard a question about NREL research!"
if prompt := st.chat_input(msg):
st.chat_message("user").markdown(prompt)
st.session_state.messages.append({"role": "user", "content": prompt})

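        # Query the wizard, stream the answer into the placeholder, then
        # append the documents it was given as references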
with st.chat_message("assistant"):

message_placeholder = st.empty()
full_response = ""

out = wizard.chat(prompt,
debug=True, stream=True, token_budget=6000,
temperature=0.0, print_references=True,
convo=False, return_chat_obj=True)
references = out[-1]

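            # Accumulate streamed deltas and live-update the placeholder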
for response in out[0]:
full_response += response.choices[0].delta.get("content", "")
message_placeholder.markdown(full_response + "▌")

ref_msg = ('\n\nThe wizard was provided with the '
'following documents to support its answer:')
ref_msg += '\n - ' + '\n - '.join(references)
full_response += ref_msg

message_placeholder.markdown(full_response)

st.session_state.messages.append({"role": "assistant",
"content": full_response})
