From 222431deceed079b2c4f40e59e37ebf3f3770733 Mon Sep 17 00:00:00 2001
From: grantbuster
Date: Mon, 6 Nov 2023 16:22:27 -0700
Subject: [PATCH] added energy wizard example

---
 examples/energy_wizard/retrieve_docs.py | 116 ++++++++++++++++++++++++
 examples/energy_wizard/run_app.py       | 107 ++++++++++++++++++++++
 2 files changed, 223 insertions(+)
 create mode 100644 examples/energy_wizard/retrieve_docs.py
 create mode 100644 examples/energy_wizard/run_app.py

diff --git a/examples/energy_wizard/retrieve_docs.py b/examples/energy_wizard/retrieve_docs.py
new file mode 100644
index 0000000..d1532b0
--- /dev/null
+++ b/examples/energy_wizard/retrieve_docs.py
@@ -0,0 +1,116 @@
+import os
+import asyncio
+import pandas as pd
+import logging
+import openai
+import time
+from glob import glob
+from rex import init_logger
+
+from elm.pdf import PDFtoTXT
+from elm.embed import ChunkAndEmbed
+from elm.osti import OstiList
+
+
+logger = logging.getLogger(__name__)
+init_logger(__name__, log_level='DEBUG')
+init_logger('elm', log_level='INFO')
+
+
+# NREL-Azure endpoint. You can also use just the openai endpoint.
+# NOTE: embedding values are different between OpenAI and Azure models!
+openai.api_base = os.getenv("AZURE_OPENAI_ENDPOINT")
+openai.api_key = os.getenv("AZURE_OPENAI_API_KEY")
+openai.api_type = 'azure'
+openai.api_version = '2023-03-15-preview'
+
+ChunkAndEmbed.EMBEDDING_MODEL = 'text-embedding-ada-002-2'
+ChunkAndEmbed.EMBEDDING_URL = ('https://stratus-embeddings-south-central.'
+                               'openai.azure.com/openai/deployments/'
+                               'text-embedding-ada-002-2/embeddings?'
+                               f'api-version={openai.api_version}')
+ChunkAndEmbed.HEADERS = {"Content-Type": "application/json",
+                         "Authorization": f"Bearer {openai.api_key}",
+                         "api-key": f"{openai.api_key}"}
+
+PDF_DIR = './pdfs/'
+TXT_DIR = './txt/'
+EMBED_DIR = './embed/'
+
+URL = ('https://www.osti.gov/api/v1/records?'
+       'research_org=NREL'
+       '&sort=publication_date%20desc'
+       '&product_type=Technical%20Report'
+       '&has_fulltext=true'
+       '&publication_date_start=01/01/2023'
+       '&publication_date_end=12/31/2023')
+
+
+if __name__ == '__main__':
+    os.makedirs(PDF_DIR, exist_ok=True)
+    os.makedirs(TXT_DIR, exist_ok=True)
+    os.makedirs(EMBED_DIR, exist_ok=True)
+
+    osti = OstiList(URL, n_pages=1)
+    osti.download(PDF_DIR)
+
+    meta = osti.meta.copy()
+    meta['osti_id'] = meta['osti_id'].astype(str)
+    meta = meta.drop_duplicates(subset=['osti_id'])
+    meta['fp'] = PDF_DIR + meta['fn']
+    meta.to_csv('./meta.csv', index=False)
+
+    missing = []
+    for i, row in meta.iterrows():
+        if not os.path.exists(row['fp']):
+            missing.append(i)
+    meta = meta.drop(missing, axis=0)
+
+    for i, row in meta.iterrows():
+        fp = os.path.join(PDF_DIR, row['fn'])
+        txt_fp = os.path.join(TXT_DIR, row['fn'].replace('.pdf', '.txt'))
+        embed_fp = os.path.join(EMBED_DIR, row['fn'].replace('.pdf', '.json'))
+
+        assert fp.endswith('.pdf')
+        assert os.path.exists(fp)
+
+        if os.path.exists(txt_fp):
+            with open(txt_fp, 'r') as f:
+                text = f.read()
+        else:
+            pdf_obj = PDFtoTXT(fp)
+            text = pdf_obj.clean_poppler(layout=True)
+            if pdf_obj.is_double_col():
+                text = pdf_obj.clean_poppler(layout=False)
+            text = pdf_obj.clean_headers(char_thresh=0.6, page_thresh=0.8,
+                                         split_on='\n',
+                                         iheaders=[0, 1, 3, -3, -2, -1])
+            with open(txt_fp, 'w') as f:
+                f.write(text)
+            logger.info(f'Saved: {txt_fp}')
+
+        if not os.path.exists(embed_fp):
+            logger.info('Embedding {}/{}: "{}"'
+                        .format(i+1, len(meta), row['title']))
+            tag = f"Title: {row['title']}\nAuthors: {row['authors']}"
+            obj = ChunkAndEmbed(text, tag=tag, tokens_per_chunk=500, overlap=1)
+            embeddings = asyncio.run(obj.run_async(rate_limit=3e4))
+            if any(e is None for e in embeddings):
+                raise RuntimeError('Embeddings are None!')
+            else:
+                df = pd.DataFrame({'text': obj.text_chunks.chunks,
+                                   'embedding': embeddings,
+                                   'osti_id': row['osti_id']})
+                df.to_json(embed_fp, indent=2)
+                logger.info('Saved: {}'.format(embed_fp))
+            time.sleep(5)
+
+    bad = []
+    fps = glob(EMBED_DIR + '*.json')
+    for fp in fps:
+        data = pd.read_json(fp)
+        if data['embedding'].isna().any():
+            bad.append(fp)
+    assert not any(bad), f'Bad output: {bad}'
+
+    logger.info('Finished!')
diff --git a/examples/energy_wizard/run_app.py b/examples/energy_wizard/run_app.py
new file mode 100644
index 0000000..b840ccb
--- /dev/null
+++ b/examples/energy_wizard/run_app.py
@@ -0,0 +1,107 @@
+import streamlit as st
+import os
+import openai
+from glob import glob
+import pandas as pd
+
+from elm import EnergyWizard
+
+
+model = 'gpt-4'
+
+# NREL-Azure endpoint. You can also use just the openai endpoint.
+# NOTE: embedding values are different between OpenAI and Azure models!
+openai.api_base = os.getenv("AZURE_OPENAI_ENDPOINT")
+openai.api_key = os.getenv("AZURE_OPENAI_API_KEY")
+openai.api_type = 'azure'
+openai.api_version = '2023-03-15-preview'
+
+EnergyWizard.EMBEDDING_MODEL = 'text-embedding-ada-002-2'
+EnergyWizard.EMBEDDING_URL = ('https://stratus-embeddings-south-central.'
+                              'openai.azure.com/openai/deployments/'
+                              'text-embedding-ada-002-2/embeddings?'
+                              f'api-version={openai.api_version}')
+EnergyWizard.URL = ('https://stratus-embeddings-south-central.'
+                    'openai.azure.com/openai/deployments/'
+                    f'{model}/chat/completions?'
+                    f'api-version={openai.api_version}')
+EnergyWizard.HEADERS = {"Content-Type": "application/json",
+                        "Authorization": f"Bearer {openai.api_key}",
+                        "api-key": f"{openai.api_key}"}
+
+EnergyWizard.MODEL_ROLE = ('You are a energy research assistant. Use the '
+                           'articles below to answer the question. If '
+                           'articles do not provide enough information to '
+                           'answer the question, say "I do not know."')
+EnergyWizard.MODEL_INSTRUCTION = EnergyWizard.MODEL_ROLE
+
+
+@st.cache_data
+def get_corpus():
+    """Get the corpus of text data with embeddings."""
+    corpus = sorted(glob('./embed/*.json'))
+    corpus = [pd.read_json(fp) for fp in corpus]
+    corpus = pd.concat(corpus, ignore_index=True)
+    meta = pd.read_csv('./meta.csv')
+
+    corpus['osti_id'] = corpus['osti_id'].astype(str)
+    meta['osti_id'] = meta['osti_id'].astype(str)
+    corpus = corpus.set_index('osti_id')
+    meta = meta.set_index('osti_id')
+
+    corpus = corpus.join(meta, on='osti_id', rsuffix='_record', how='left')
+
+    ref = [f"{row['title']} ({row['doi']})" for _, row in corpus.iterrows()]
+    corpus['ref'] = ref
+
+    return corpus
+
+
+@st.cache_resource
+def get_wizard():
+    """Get the energy wizard object."""
+    corpus = get_corpus()
+    wizard = EnergyWizard(corpus, ref_col='ref', model=model)
+    return wizard
+
+
+if __name__ == '__main__':
+    wizard = get_wizard()
+    st.title("Energy Wizard")
+
+    if "messages" not in st.session_state:
+        st.session_state.messages = []
+
+    for message in st.session_state.messages:
+        with st.chat_message(message["role"]):
+            st.markdown(message["content"])
+
+    msg = "Ask the Energy Wizard a question about NREL research!"
+    if prompt := st.chat_input(msg):
+        st.chat_message("user").markdown(prompt)
+        st.session_state.messages.append({"role": "user", "content": prompt})
+
+        with st.chat_message("assistant"):
+
+            message_placeholder = st.empty()
+            full_response = ""
+
+            out = wizard.chat(prompt,
+                              debug=True, stream=True, token_budget=6000,
+                              temperature=0.0, print_references=True,
+                              convo=False, return_chat_obj=True)
+            references = out[-1]
+
+            for response in out[0]:
+                full_response += response.choices[0].delta.get("content", "")
+                message_placeholder.markdown(full_response + "▌")
+
+            ref_msg = ('\n\nThe wizard was provided with the '
+                       'following documents to support its answer:')
+            ref_msg += '\n - ' + '\n - '.join(references)
+            full_response += ref_msg
+
+            message_placeholder.markdown(full_response)
+
+            st.session_state.messages.append({"role": "assistant",
+                                              "content": full_response})