-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b88fb8c
commit 222431d
Showing
2 changed files
with
223 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
import os | ||
import asyncio | ||
import pandas as pd | ||
import logging | ||
import openai | ||
import time | ||
from glob import glob | ||
from rex import init_logger | ||
|
||
from elm.pdf import PDFtoTXT | ||
from elm.embed import ChunkAndEmbed | ||
from elm.osti import OstiList | ||
|
||
|
||
# Module-level logging: debug output for this script, info-level for elm.
logger = logging.getLogger(__name__)
init_logger(__name__, log_level='DEBUG')
init_logger('elm', log_level='INFO')


# NREL-Azure endpoint. You can also use just the openai endpoint.
# NOTE: embedding values are different between OpenAI and Azure models!
openai.api_type = 'azure'
openai.api_version = '2023-03-15-preview'
openai.api_base = os.getenv("AZURE_OPENAI_ENDPOINT")
openai.api_key = os.getenv("AZURE_OPENAI_API_KEY")

# Point the embedding helper at the Azure deployment of ada-002.
ChunkAndEmbed.EMBEDDING_MODEL = 'text-embedding-ada-002-2'
ChunkAndEmbed.EMBEDDING_URL = (
    'https://stratus-embeddings-south-central.openai.azure.com/'
    'openai/deployments/text-embedding-ada-002-2/'
    f'embeddings?api-version={openai.api_version}')
ChunkAndEmbed.HEADERS = {"Content-Type": "application/json",
                         "Authorization": f"Bearer {openai.api_key}",
                         "api-key": f"{openai.api_key}"}

# Local working directories for downloaded PDFs, extracted text, and
# chunk embeddings (all created on startup in __main__).
PDF_DIR = './pdfs/'
TXT_DIR = './txt/'
EMBED_DIR = './embed/'

# OSTI query: 2023 NREL technical reports that have full text available.
URL = ('https://www.osti.gov/api/v1/records?'
       'research_org=NREL'
       '&sort=publication_date%20desc'
       '&product_type=Technical%20Report'
       '&has_fulltext=true'
       '&publication_date_start=01/01/2023'
       '&publication_date_end=12/31/2023')
def _collect_meta(osti):
    """Build the report metadata table from an OstiList.

    Casts ``osti_id`` to str, drops duplicate records, records the local
    PDF path per report, persists the table to ``./meta.csv``, and drops
    any rows whose PDF was not actually downloaded.

    Returns a pandas DataFrame with one row per available report.
    """
    meta = osti.meta.copy()
    meta['osti_id'] = meta['osti_id'].astype(str)
    meta = meta.drop_duplicates(subset=['osti_id'])
    # PDF_DIR ends with '/', so plain string concat yields a valid path
    meta['fp'] = PDF_DIR + meta['fn']
    meta.to_csv('./meta.csv', index=False)

    # Drop records whose PDF download failed (file never appeared on disk)
    missing = [i for i, row in meta.iterrows()
               if not os.path.exists(row['fp'])]
    return meta.drop(missing, axis=0)


def _get_text(fp, txt_fp):
    """Return clean text for one PDF, using ``txt_fp`` as a cache.

    If the text file already exists it is read back; otherwise the PDF is
    parsed, cleaned, written to ``txt_fp``, and returned.
    """
    if os.path.exists(txt_fp):
        with open(txt_fp, 'r') as f:
            return f.read()

    pdf_obj = PDFtoTXT(fp)
    text = pdf_obj.clean_poppler(layout=True)
    # Double-column layouts extract better without poppler's layout mode
    if pdf_obj.is_double_col():
        text = pdf_obj.clean_poppler(layout=False)
    text = pdf_obj.clean_headers(char_thresh=0.6, page_thresh=0.8,
                                 split_on='\n',
                                 iheaders=[0, 1, 3, -3, -2, -1])
    with open(txt_fp, 'w') as f:
        f.write(text)
    logger.info(f'Saved: {txt_fp}')
    return text


if __name__ == '__main__':
    for d in (PDF_DIR, TXT_DIR, EMBED_DIR):
        os.makedirs(d, exist_ok=True)

    # Download the first page of OSTI results and their PDFs
    osti = OstiList(URL, n_pages=1)
    osti.download(PDF_DIR)

    meta = _collect_meta(osti)

    # NOTE: enumerate gives a true 1-based progress counter; the original
    # used the DataFrame label index, which is non-contiguous after the
    # drop_duplicates/drop calls above and made the log misleading.
    for count, (i, row) in enumerate(meta.iterrows(), start=1):
        fp = os.path.join(PDF_DIR, row['fn'])
        txt_fp = os.path.join(TXT_DIR, row['fn'].replace('.pdf', '.txt'))
        embed_fp = os.path.join(EMBED_DIR, row['fn'].replace('.pdf', '.json'))

        # Explicit raises instead of bare asserts: asserts are stripped
        # under ``python -O`` and must not guard runtime invariants.
        if not fp.endswith('.pdf'):
            raise ValueError(f'Expected a .pdf file, got: {fp}')
        if not os.path.exists(fp):
            raise FileNotFoundError(f'Missing PDF: {fp}')

        text = _get_text(fp, txt_fp)

        if not os.path.exists(embed_fp):
            logger.info('Embedding {}/{}: "{}"'
                        .format(count, len(meta), row['title']))
            tag = f"Title: {row['title']}\nAuthors: {row['authors']}"
            obj = ChunkAndEmbed(text, tag=tag, tokens_per_chunk=500, overlap=1)
            embeddings = asyncio.run(obj.run_async(rate_limit=3e4))
            if any(e is None for e in embeddings):
                raise RuntimeError('Embeddings are None!')
            df = pd.DataFrame({'text': obj.text_chunks.chunks,
                               'embedding': embeddings,
                               'osti_id': row['osti_id']})
            df.to_json(embed_fp, indent=2)
            logger.info('Saved: {}'.format(embed_fp))
            time.sleep(5)  # crude pacing between reports to ease rate limits

    # Sanity check: no output file may contain null embeddings
    bad = [fp for fp in glob(os.path.join(EMBED_DIR, '*.json'))
           if pd.read_json(fp)['embedding'].isna().any()]
    if bad:
        raise RuntimeError(f'Bad output: {bad}')

    logger.info('Finished!')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
import streamlit as st | ||
import os | ||
import openai | ||
from glob import glob | ||
import pandas as pd | ||
|
||
from elm import EnergyWizard | ||
|
||
|
||
# Chat model deployment name used for both the wizard and its chat URL.
model = 'gpt-4'

# NREL-Azure endpoint. You can also use just the openai endpoint.
# NOTE: embedding values are different between OpenAI and Azure models!
openai.api_type = 'azure'
openai.api_version = '2023-03-15-preview'
openai.api_base = os.getenv("AZURE_OPENAI_ENDPOINT")
openai.api_key = os.getenv("AZURE_OPENAI_API_KEY")

# Azure deployment endpoints for embeddings and chat completions.
EnergyWizard.EMBEDDING_MODEL = 'text-embedding-ada-002-2'
EnergyWizard.EMBEDDING_URL = (
    'https://stratus-embeddings-south-central.openai.azure.com/'
    'openai/deployments/text-embedding-ada-002-2/'
    f'embeddings?api-version={openai.api_version}')
EnergyWizard.URL = (
    'https://stratus-embeddings-south-central.openai.azure.com/'
    f'openai/deployments/{model}/chat/completions?'
    f'api-version={openai.api_version}')
EnergyWizard.HEADERS = {"Content-Type": "application/json",
                        "Authorization": f"Bearer {openai.api_key}",
                        "api-key": f"{openai.api_key}"}

# System prompt: ground answers in the retrieved articles only.
EnergyWizard.MODEL_ROLE = ('You are a energy research assistant. Use the '
                           'articles below to answer the question. If '
                           'articles do not provide enough information to '
                           'answer the question, say "I do not know."')
EnergyWizard.MODEL_INSTRUCTION = EnergyWizard.MODEL_ROLE
|
||
|
||
@st.cache_data
def get_corpus():
    """Load every chunk/embedding JSON file, join the OSTI metadata onto
    the chunks by ``osti_id``, and attach a per-chunk citation string in
    the ``ref`` column. Cached by streamlit across reruns."""
    frames = [pd.read_json(fp) for fp in sorted(glob('./embed/*.json'))]
    corpus = pd.concat(frames, ignore_index=True)
    meta = pd.read_csv('./meta.csv')

    # Normalize the join key to str on both sides before indexing
    corpus['osti_id'] = corpus['osti_id'].astype(str)
    meta['osti_id'] = meta['osti_id'].astype(str)
    corpus = corpus.set_index('osti_id')
    meta = meta.set_index('osti_id')

    # Left join keeps every chunk even if its metadata row is missing
    corpus = corpus.join(meta, on='osti_id', rsuffix='_record', how='left')

    corpus['ref'] = [f"{row['title']} ({row['doi']})"
                     for _, row in corpus.iterrows()]

    return corpus
|
||
|
||
@st.cache_resource
def get_wizard():
    """Build the EnergyWizard chat object over the cached corpus.

    Cached as a streamlit resource so the wizard survives reruns."""
    return EnergyWizard(get_corpus(), ref_col='ref', model=model)
|
||
|
||
if __name__ == '__main__':
    wizard = get_wizard()
    st.title("Energy Wizard")

    # Per-session chat history lives in streamlit session state
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay the conversation so far on every rerun
    for entry in st.session_state.messages:
        with st.chat_message(entry["role"]):
            st.markdown(entry["content"])

    placeholder = "Ask the Energy Wizard a question about NREL research!"
    if prompt := st.chat_input(placeholder):
        st.chat_message("user").markdown(prompt)
        st.session_state.messages.append({"role": "user", "content": prompt})

        with st.chat_message("assistant"):

            slot = st.empty()
            answer = ""

            out = wizard.chat(prompt, debug=True, stream=True,
                              token_budget=6000, temperature=0.0,
                              print_references=True, convo=False,
                              return_chat_obj=True)
            references = out[-1]

            # Stream tokens into the placeholder with a cursor glyph
            for chunk in out[0]:
                answer += chunk.choices[0].delta.get("content", "")
                slot.markdown(answer + "▌")

            # Append the supporting-document references to the answer
            ref_msg = ('\n\nThe wizard was provided with the '
                       'following documents to support its answer:')
            ref_msg += '\n - ' + '\n - '.join(references)
            answer += ref_msg

            slot.markdown(answer)

        st.session_state.messages.append({"role": "assistant",
                                          "content": answer})