Skip to content

Commit

Permalink
Merge pull request #29 from NREL/sp/refs_cleanup
Browse files Browse the repository at this point in the history
add new references code, query error handling
  • Loading branch information
spodgorny9 authored Sep 17, 2024
2 parents 6ca811b + a7083f8 commit 090f828
Show file tree
Hide file tree
Showing 6 changed files with 266 additions and 67 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@ jobs:
python -m pip install .
- name: Run pytest and Generate coverage report
run: |
python -m pytest --ignore=tests/ords --ignore=tests/utilities --ignore=tests/web -v --disable-warnings
python -m pytest --ignore=tests/ords --ignore=tests/utilities --ignore=tests/web --ignore=tests/test_wizard_postgres.py -v --disable-warnings
40 changes: 40 additions & 0 deletions .github/workflows/pytest_postgres.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
name: pytests-postgres

on: pull_request

jobs:
build:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: [3.11]

steps:
- uses: actions/checkout@v4
with:
ref: ${{ github.event.pull_request.head.ref }}
fetch-depth: 1
- name: Set up Python ${{ matrix.python-version }}
uses: conda-incubator/setup-miniconda@v3
with:
auto-update-conda: true
channels: conda-forge,defaults
python-version: ${{ matrix.python-version }}
miniconda-version: "latest"
- name: Install dependencies
shell: bash -l {0}
run: |
python -m pip install --upgrade pip
python -m pip install psycopg2-binary
python -m pip install boto3
python -m pip install pytest
python -m pip install pytest-mock
python -m pip install pytest-cov
python -m pip install .
- name: Run pytest for postgres
shell: bash -l {0}
run: |
python -m pytest --ignore=tests/ords --ignore=tests/utilities --ignore=tests/web -v --disable-warnings
39 changes: 27 additions & 12 deletions elm/web/rhub.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,22 +585,27 @@ def authors(self):
"""
pa = self.get('personAssociations')

if not pa:
return None

authors = []

for r in pa:
first = r.get('name').get('firstName')
last = r.get('name').get('lastName')
name = r.get('name')

if not name:
continue

if first and last:
full = first + ' ' + last
elif first:
full = first
elif last:
full = last
first = name.get('firstName')
last = name.get('lastName')
full = " ".join(filter(bool, [first, last]))

if not full:
continue

authors.append(full)

out = ', '.join(authors)
out = ', '.join(authors)

return out

Expand Down Expand Up @@ -653,8 +658,16 @@ def abstract(self):
String containing abstract text.
"""
abstract = self.get('abstract')
text = abstract.get('text')[0]
value = text.get('value')

if not abstract:
return None

text = abstract.get('text')

if not text:
return None

value = text[0].get('value')

return value

Expand Down Expand Up @@ -701,6 +714,9 @@ def download(self, pdf_dir, txt_dir):
if not os.path.exists(fp):
if abstract:
self.save_abstract(abstract, fp)
else:
logger.info(f'{self.title}: does not have an '
'abstract to downlod')
else:
if pdf_url and pdf_url.endswith('.pdf'):
fn = self.id.replace('/', '-') + '.pdf'
Expand Down Expand Up @@ -876,7 +892,6 @@ def download(self, pdf_dir, txt_dir):
try:
record.download(pdf_dir, txt_dir)
except Exception as e:
print(f"Could not download {record.title} with error {e}")
logger.exception('Could not download {}: {}'
.format(record.title, e))
logger.info('Finished publications download!')
170 changes: 120 additions & 50 deletions elm/wizard.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,11 +404,14 @@ class EnergyWizardPostgres(EnergyWizardBase):
}
"""Optional mappings for weird azure names to tiktoken/openai names."""

DEFAULT_META_COLS = ('title', 'url', 'authors', 'year', 'category', 'id')
"""Default columns to retrieve for metadata"""

def __init__(self, db_host, db_port, db_name,
db_schema, db_table, meta_columns=None,
cursor=None, boto_client=None,
model=None, token_budget=3500,
tag=False):
db_schema, db_table, probes=25,
meta_columns=None, cursor=None,
boto_client=None, model=None,
token_budget=3500, tag=False):
"""
Parameters
----------
Expand All @@ -423,6 +426,9 @@ def __init__(self, db_host, db_port, db_name,
db_table : str
Table to query in Postgres database. Necessary columns: id,
chunks, embedding, title, and url.
probes : int
Number of lists to search in vector database. Recommended
value is sqrt(n_lists).
meta_columns : list
List of metadata columns to retrieve from database. Default
query returns title and url.
Expand All @@ -441,31 +447,33 @@ def __init__(self, db_host, db_port, db_name,
GPT.
"""
boto3 = try_import('boto3')
psycopg2 = try_import('psycopg2')
self.psycopg2 = try_import('psycopg2')

self.db_schema = db_schema
self.db_table = db_table
if meta_columns is None:
self.meta_columns = ['title', 'url']
self.meta_columns = ['title', 'url', 'id']
else:
self.meta_columns = meta_columns

assert 'id' in self.meta_columns, "Please include the 'id' column!"

if cursor is None:
db_user = os.getenv("EWIZ_DB_USER")
db_password = os.getenv('EWIZ_DB_PASSWORD')
assert db_user is not None, "Must set EWIZ_DB_USER!"
assert db_password is not None, "Must set EWIZ_DB_PASSWORD!"
self.conn = psycopg2.connect(user=db_user,
password=db_password,
host=db_host,
port=db_port,
database=db_name)
self.db_kwargs = dict(user=db_user, password=db_password,
host=db_host, port=db_port,
database=db_name)
self.conn = self.psycopg2.connect(**self.db_kwargs)

self.cursor = self.conn.cursor()
else:
self.cursor = cursor

self.db_schema = db_schema
self.db_table = db_table
self.tag = tag
self.probes = probes

if boto_client is None:
access_key = os.getenv('AWS_ACCESS_KEY_ID')
Expand Down Expand Up @@ -553,16 +561,18 @@ def _add_tag(meta):

return tag

def query_vector_db(self, query, probes=25, limit=100):
def query_vector_db(self, query, limit=100):
"""Returns a list of strings and relatednesses, sorted from most
related to least.
related to least. SQL query uses a context handler and rollback
to ensure a failed query does not interupt future questions from
the user. Ex: a user submitting a new question before the first
one completes will close the cursor preventing future database
access.
Parameters
----------
query : str
Question being asked of GPT
probes: int
Number of lists to search in vector database index.
limit : int
Number of top results to return.
Expand All @@ -579,19 +589,28 @@ def query_vector_db(self, query, probes=25, limit=100):

query_embedding = self.get_embedding(query)

self.cursor.execute(f"SET LOCAL ivfflat.probes = {probes};"
f"SELECT {self.db_table}.id, "
f"{self.db_table}.chunks, "
f"{self.db_table}.embedding "
"<=> %s::vector as score, "
f"{self.db_table}.title, "
f"{self.db_table}.authors, "
f"{self.db_table}.year "
f"FROM {self.db_schema}.{self.db_table} "
"ORDER BY embedding <=> %s::vector LIMIT %s;",
(query_embedding, query_embedding, limit,), )

result = self.cursor.fetchall()
with self.psycopg2.connect(**self.db_kwargs) as conn:
cursor = conn.cursor()
try:
cursor.execute(f"SET LOCAL ivfflat.probes = {self.probes};"
f"SELECT {self.db_table}.id, "
f"{self.db_table}.chunks, "
f"{self.db_table}.embedding "
"<=> %s::vector as score, "
f"{self.db_table}.title, "
f"{self.db_table}.authors, "
f"{self.db_table}.year "
f"FROM {self.db_schema}.{self.db_table} "
"ORDER BY embedding <=> %s::vector LIMIT %s;",
(query_embedding, query_embedding, limit,), )
except Exception as exc:
conn.rollback()
msg = (f'Received error when querying the postgres '
f'vector database: {exc}')
raise RuntimeError(msg) from exc
else:
conn.commit()
result = cursor.fetchall()

if self.tag:
strings = [self._add_tag(s[3:]) + s[1] for s in result]
Expand All @@ -603,12 +622,66 @@ def query_vector_db(self, query, probes=25, limit=100):

return strings, scores, best

def _format_refs(self, refs, ids):
"""Parse and nicely format a reference dictionary into a list of well
formatted string representations
Parameters
----------
refs : list
List of references returned from the vector db
ids : np.ndarray
IDs of the used text from the text corpus sorted by embedding
relevance.
Returns
-------
out : list
Unique ordered list of references (most relevant first)
"""

ref_list = []
for item in refs:
ref_dict = {col: str(value).replace(chr(34), '')
for col, value in zip(self.meta_columns, item)}

ref_list.append(ref_dict)

assert len(ref_list) > 0, ("The Wizard did not return any "
"references. Please check your database "
"connection or query.")

unique_ref_list = []
for ref_dict in ref_list:
if any(ref_dict == d for d in unique_ref_list):
continue
unique_ref_list.append(ref_dict)
ref_list = unique_ref_list

if 'id' in ref_list[0]:
ids_list = list(ids)
sorted_ref_list = []
for ref_id in ids_list:
for ref_dict in ref_list:
if ref_dict['id'] == ref_id:
sorted_ref_list.append(ref_dict)
break
ref_list = sorted_ref_list

ref_list = [json.dumps(ref) for ref in ref_list]

return ref_list

def make_ref_list(self, ids):
"""Make a reference list
"""Make a reference list. SQL query uses a context handler and
rollback to ensure a failed query does not interupt future questions
from the user. Ex: a user submitting a new question before the first
one completes will close the cursor preventing future database
access.
Parameters
----------
used_index : np.ndarray
ids : np.ndarray
IDs of the used text from the text corpus
Returns
Expand All @@ -626,22 +699,19 @@ def make_ref_list(self, ids):
f"FROM {self.db_schema}.{self.db_table} "
f"WHERE {self.db_table}.id IN (" + placeholders + ")")

self.cursor.execute(sql_query, ids)

refs = self.cursor.fetchall()

ref_list = []
for item in refs:
ref_dict = {self.meta_columns[i]: item[i]
for i in range(len(self.meta_columns))}
ref_str = "{"
ref_str += ", ".join([f"\"{key}\": \"{value}\""
for key, value in ref_dict.items()])
ref_str += "}"

ref_list.append(ref_str)
with self.psycopg2.connect(**self.db_kwargs) as conn:
cursor = conn.cursor()
try:
cursor.execute(sql_query, ids)
except Exception as exc:
conn.rollback()
msg = (f'Received error when querying the postgres '
f'vector database: {exc}')
raise RuntimeError(msg) from exc
else:
conn.commit()
refs = cursor.fetchall()

ref_list = self._format_refs(refs, ids)

unique_values = set(ref_list)
unique_list = list(unique_values)

return unique_list
return ref_list
2 changes: 1 addition & 1 deletion examples/research_hub/retrieve_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


logger = logging.getLogger(__name__)
init_logger(__name__, log_level='DEBUG')
init_logger(__name__, log_level='INFO')
init_logger('elm', log_level='INFO')


Expand Down
Loading

0 comments on commit 090f828

Please sign in to comment.