diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml
index 9e89968a..ac88fb41 100644
--- a/.github/workflows/pytest.yml
+++ b/.github/workflows/pytest.yml
@@ -34,4 +34,4 @@ jobs:
         python -m pip install .
     - name: Run pytest and Generate coverage report
       run: |
-        python -m pytest --ignore=tests/ords --ignore=tests/utilities --ignore=tests/web -v --disable-warnings
+        python -m pytest --ignore=tests/ords --ignore=tests/utilities --ignore=tests/web --ignore=tests/test_wizard_postgres.py -v --disable-warnings
diff --git a/.github/workflows/pytest_postgres.yml b/.github/workflows/pytest_postgres.yml
new file mode 100644
index 00000000..84054c66
--- /dev/null
+++ b/.github/workflows/pytest_postgres.yml
@@ -0,0 +1,40 @@
+name: pytests-postgres
+
+on: pull_request
+
+jobs:
+  build:
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ubuntu-latest, macos-latest, windows-latest]
+        python-version: [3.11]
+
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          ref: ${{ github.event.pull_request.head.ref }}
+          fetch-depth: 1
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: conda-incubator/setup-miniconda@v3
+        with:
+          auto-update-conda: true
+          channels: conda-forge,defaults
+          python-version: ${{ matrix.python-version }}
+          miniconda-version: "latest"
+      - name: Install dependencies
+        shell: bash -l {0}
+        run: |
+          python -m pip install --upgrade pip
+          python -m pip install psycopg2-binary
+          python -m pip install boto3
+          python -m pip install pytest
+          python -m pip install pytest-mock
+          python -m pip install pytest-cov
+          python -m pip install .
+      - name: Run pytest for postgres
+        shell: bash -l {0}
+        run: |
+          python -m pytest --ignore=tests/ords --ignore=tests/utilities --ignore=tests/web -v --disable-warnings
+
diff --git a/elm/web/rhub.py b/elm/web/rhub.py
index df1d3cad..7b88503c 100644
--- a/elm/web/rhub.py
+++ b/elm/web/rhub.py
@@ -585,22 +585,27 @@ def authors(self):
         """
         pa = self.get('personAssociations')
+        if not pa:
+            return None
+
         authors = []
         for r in pa:
-            first = r.get('name').get('firstName')
-            last = r.get('name').get('lastName')
+            name = r.get('name')
+
+            if not name:
+                continue
 
-            if first and last:
-                full = first + ' ' + last
-            elif first:
-                full = first
-            elif last:
-                full = last
+            first = name.get('firstName')
+            last = name.get('lastName')
+            full = " ".join(filter(bool, [first, last]))
+
+            if not full:
+                continue
 
             authors.append(full)
-            out = ', '.join(authors)
+        out = ', '.join(authors)
 
         return out
 
@@ -653,8 +658,16 @@ def abstract(self):
             String containing abstract text.
""" abstract = self.get('abstract') - text = abstract.get('text')[0] - value = text.get('value') + + if not abstract: + return None + + text = abstract.get('text') + + if not text: + return None + + value = text[0].get('value') return value @@ -701,6 +714,9 @@ def download(self, pdf_dir, txt_dir): if not os.path.exists(fp): if abstract: self.save_abstract(abstract, fp) + else: + logger.info(f'{self.title}: does not have an ' + 'abstract to downlod') else: if pdf_url and pdf_url.endswith('.pdf'): fn = self.id.replace('/', '-') + '.pdf' @@ -876,7 +892,6 @@ def download(self, pdf_dir, txt_dir): try: record.download(pdf_dir, txt_dir) except Exception as e: - print(f"Could not download {record.title} with error {e}") logger.exception('Could not download {}: {}' .format(record.title, e)) logger.info('Finished publications download!') diff --git a/elm/wizard.py b/elm/wizard.py index 8f9472d2..e369b274 100644 --- a/elm/wizard.py +++ b/elm/wizard.py @@ -404,11 +404,14 @@ class EnergyWizardPostgres(EnergyWizardBase): } """Optional mappings for weird azure names to tiktoken/openai names.""" + DEFAULT_META_COLS = ('title', 'url', 'authors', 'year', 'category', 'id') + """Default columns to retrieve for metadata""" + def __init__(self, db_host, db_port, db_name, - db_schema, db_table, meta_columns=None, - cursor=None, boto_client=None, - model=None, token_budget=3500, - tag=False): + db_schema, db_table, probes=25, + meta_columns=None, cursor=None, + boto_client=None, model=None, + token_budget=3500, tag=False): """ Parameters ---------- @@ -423,6 +426,9 @@ def __init__(self, db_host, db_port, db_name, db_table : str Table to query in Postgres database. Necessary columns: id, chunks, embedding, title, and url. + probes : int + Number of lists to search in vector database. Recommended + value is sqrt(n_lists). meta_columns : list List of metadata columns to retrieve from database. Default query returns title and url. @@ -441,31 +447,33 @@ def __init__(self, db_host, db_port, db_name, GPT. """ boto3 = try_import('boto3') - psycopg2 = try_import('psycopg2') + self.psycopg2 = try_import('psycopg2') - self.db_schema = db_schema - self.db_table = db_table if meta_columns is None: - self.meta_columns = ['title', 'url'] + self.meta_columns = ['title', 'url', 'id'] else: self.meta_columns = meta_columns + assert 'id' in self.meta_columns, "Please include the 'id' column!" + if cursor is None: db_user = os.getenv("EWIZ_DB_USER") db_password = os.getenv('EWIZ_DB_PASSWORD') assert db_user is not None, "Must set EWIZ_DB_USER!" assert db_password is not None, "Must set EWIZ_DB_PASSWORD!" - self.conn = psycopg2.connect(user=db_user, - password=db_password, - host=db_host, - port=db_port, - database=db_name) + self.db_kwargs = dict(user=db_user, password=db_password, + host=db_host, port=db_port, + database=db_name) + self.conn = self.psycopg2.connect(**self.db_kwargs) self.cursor = self.conn.cursor() else: self.cursor = cursor + self.db_schema = db_schema + self.db_table = db_table self.tag = tag + self.probes = probes if boto_client is None: access_key = os.getenv('AWS_ACCESS_KEY_ID') @@ -553,16 +561,18 @@ def _add_tag(meta): return tag - def query_vector_db(self, query, probes=25, limit=100): + def query_vector_db(self, query, limit=100): """Returns a list of strings and relatednesses, sorted from most - related to least. + related to least. SQL query uses a context handler and rollback + to ensure a failed query does not interupt future questions from + the user. 
+        Ex: a user submitting a new question before the first
+        one completes will close the cursor preventing future database
+        access.
 
         Parameters
         ----------
         query : str
             Question being asked of GPT
-        probes: int
-            Number of lists to search in vector database index.
         limit : int
             Number of top results to return.
 
@@ -579,19 +589,28 @@ def query_vector_db(self, query, probes=25, limit=100):
 
         query_embedding = self.get_embedding(query)
 
-        self.cursor.execute(f"SET LOCAL ivfflat.probes = {probes};"
-                            f"SELECT {self.db_table}.id, "
-                            f"{self.db_table}.chunks, "
-                            f"{self.db_table}.embedding "
-                            "<=> %s::vector as score, "
-                            f"{self.db_table}.title, "
-                            f"{self.db_table}.authors, "
-                            f"{self.db_table}.year "
-                            f"FROM {self.db_schema}.{self.db_table} "
-                            "ORDER BY embedding <=> %s::vector LIMIT %s;",
-                            (query_embedding, query_embedding, limit,), )
-
-        result = self.cursor.fetchall()
+        with self.psycopg2.connect(**self.db_kwargs) as conn:
+            cursor = conn.cursor()
+            try:
+                cursor.execute(f"SET LOCAL ivfflat.probes = {self.probes};"
+                               f"SELECT {self.db_table}.id, "
+                               f"{self.db_table}.chunks, "
+                               f"{self.db_table}.embedding "
+                               "<=> %s::vector as score, "
+                               f"{self.db_table}.title, "
+                               f"{self.db_table}.authors, "
+                               f"{self.db_table}.year "
+                               f"FROM {self.db_schema}.{self.db_table} "
+                               "ORDER BY embedding <=> %s::vector LIMIT %s;",
+                               (query_embedding, query_embedding, limit,), )
+            except Exception as exc:
+                conn.rollback()
+                msg = (f'Received error when querying the postgres '
+                       f'vector database: {exc}')
+                raise RuntimeError(msg) from exc
+            else:
+                conn.commit()
+                result = cursor.fetchall()
 
         if self.tag:
             strings = [self._add_tag(s[3:]) + s[1] for s in result]
@@ -603,12 +622,66 @@ def query_vector_db(self, query, probes=25, limit=100):
 
         return strings, scores, best
 
+    def _format_refs(self, refs, ids):
+        """Parse and nicely format a reference dictionary into a list of well
+        formatted string representations
+
+        Parameters
+        ----------
+        refs : list
+            List of references returned from the vector db
+        ids : np.ndarray
+            IDs of the used text from the text corpus sorted by embedding
+            relevance.
+
+        Returns
+        -------
+        out : list
+            Unique ordered list of references (most relevant first)
+        """
+
+        ref_list = []
+        for item in refs:
+            ref_dict = {col: str(value).replace(chr(34), '')
+                        for col, value in zip(self.meta_columns, item)}
+
+            ref_list.append(ref_dict)
+
+        assert len(ref_list) > 0, ("The Wizard did not return any "
+                                   "references. Please check your database "
+                                   "connection or query.")
+
+        unique_ref_list = []
+        for ref_dict in ref_list:
+            if any(ref_dict == d for d in unique_ref_list):
+                continue
+            unique_ref_list.append(ref_dict)
+        ref_list = unique_ref_list
+
+        if 'id' in ref_list[0]:
+            ids_list = list(ids)
+            sorted_ref_list = []
+            for ref_id in ids_list:
+                for ref_dict in ref_list:
+                    if ref_dict['id'] == ref_id:
+                        sorted_ref_list.append(ref_dict)
+                        break
+            ref_list = sorted_ref_list
+
+        ref_list = [json.dumps(ref) for ref in ref_list]
+
+        return ref_list
+
     def make_ref_list(self, ids):
-        """Make a reference list
+        """Make a reference list. SQL query uses a context handler and
+        rollback to ensure a failed query does not interrupt future questions
+        from the user. Ex: a user submitting a new question before the first
+        one completes will close the cursor preventing future database
+        access.
 
         Parameters
         ----------
-        used_index : np.ndarray
+        ids : np.ndarray
             IDs of the used text from the text corpus
 
         Returns
         -------
@@ -626,22 +699,19 @@ def make_ref_list(self, ids):
                      f"FROM {self.db_schema}.{self.db_table} "
                      f"WHERE {self.db_table}.id IN (" + placeholders + ")")
 
-        self.cursor.execute(sql_query, ids)
-
-        refs = self.cursor.fetchall()
-
-        ref_list = []
-        for item in refs:
-            ref_dict = {self.meta_columns[i]: item[i]
-                        for i in range(len(self.meta_columns))}
-            ref_str = "{"
-            ref_str += ", ".join([f"\"{key}\": \"{value}\""
-                                  for key, value in ref_dict.items()])
-            ref_str += "}"
-
-            ref_list.append(ref_str)
+        with self.psycopg2.connect(**self.db_kwargs) as conn:
+            cursor = conn.cursor()
+            try:
+                cursor.execute(sql_query, ids)
+            except Exception as exc:
+                conn.rollback()
+                msg = (f'Received error when querying the postgres '
+                       f'vector database: {exc}')
+                raise RuntimeError(msg) from exc
+            else:
+                conn.commit()
+                refs = cursor.fetchall()
+
+        ref_list = self._format_refs(refs, ids)
 
-        unique_values = set(ref_list)
-        unique_list = list(unique_values)
-
-        return unique_list
+        return ref_list
diff --git a/examples/research_hub/retrieve_docs.py b/examples/research_hub/retrieve_docs.py
index 6deadb2b..999e757d 100644
--- a/examples/research_hub/retrieve_docs.py
+++ b/examples/research_hub/retrieve_docs.py
@@ -17,7 +17,7 @@
 
 logger = logging.getLogger(__name__)
 
-init_logger(__name__, log_level='DEBUG')
+init_logger(__name__, log_level='INFO')
 init_logger('elm', log_level='INFO')
 
diff --git a/tests/test_wizard_postgres.py b/tests/test_wizard_postgres.py
index c2d8fa2e..feb1e376 100644
--- a/tests/test_wizard_postgres.py
+++ b/tests/test_wizard_postgres.py
@@ -5,6 +5,7 @@
 import ast
 import json
 from io import BytesIO
+import numpy as np
 
 from elm import TEST_DATA_DIR
 from elm.wizard import EnergyWizardPostgres
@@ -21,6 +22,9 @@
 QUERY_TUPLE = ast.literal_eval(QUERY_TEXT)
 REF_TUPLE = ast.literal_eval(REF_TEXT)
 
+os.environ["EWIZ_DB_USER"] = "user"
+os.environ["EWIZ_DB_PASSWORD"] = "password"
+
 
 class Cursor:
     """Dummy class for mocking database cursor objects"""
@@ -63,14 +67,18 @@ def invoke_model(self, **kwargs):  # pylint: disable=unused-argument
 
         return dummy_response
 
 
-def test_postgres():
+def test_postgres(mocker):
     """Test to ensure correct response vector db."""
-    os.environ["EWIZ_DB_USER"] = "user"
+    mock_conn_cm = mocker.MagicMock()
+    mock_conn = mock_conn_cm.__enter__.return_value
+    mock_conn.cursor.return_value = Cursor()
+    mock_connect = mocker.patch('psycopg2.connect')
+    mock_connect.return_value = mock_conn_cm
 
     wizard = EnergyWizardPostgres(db_host='Dummy', db_port='Dummy',
                                   db_name='Dummy', db_schema='Dummy',
-                                  db_table='Dummy', cursor=Cursor(),
+                                  db_table='Dummy',
                                   boto_client=BotoClient())
 
     question = 'Is this a dummy question?'
@@ -87,3 +95,69 @@ def test_postgres():
     assert 'title' in str(ref_list)
     assert 'url' in str(ref_list)
     assert 'research-hub.nrel.gov' in str(ref_list)
+
+
+def test_ref_replace():
+    """Test to ensure removal of double quotes from references."""
+
+    wizard = EnergyWizardPostgres(db_host='Dummy', db_port='Dummy',
+                                  db_name='Dummy', db_schema='Dummy',
+                                  db_table='Dummy', cursor=Cursor(),
+                                  boto_client=BotoClient(),
+                                  meta_columns=['title', 'url', 'id'])
+
+    refs = [(chr(34), 'test.com', '5a'),
+            ('remove "double" quotes', 'test_2.com', '7b')]
+
+    ids = np.array(['7b', '5a'])
+
+    out = wizard._format_refs(refs, ids)
+
+    assert len(out) > 1
+
+    for i in out:
+        refs_dict = json.loads(i)
+        assert '"' not in refs_dict['title']
+        assert chr(34) not in refs_dict['title']
+
+
+def test_ids():
+    """Test to ensure only records with valid ids are returned."""
+
+    wizard = EnergyWizardPostgres(db_host='Dummy', db_port='Dummy',
+                                  db_name='Dummy', db_schema='Dummy',
+                                  db_table='Dummy', cursor=Cursor(),
+                                  boto_client=BotoClient(),
+                                  meta_columns=['title', 'url', 'id'])
+
+    refs = [('title', 'test.com', '5a'),
+            ('title2', 'test_2.com', '7b')]
+
+    ids = np.array(['7c', '5a'])
+
+    out = wizard._format_refs(refs, ids)
+
+    assert len(out) == 1
+    assert not any('7b' in item for item in out)
+
+
+def test_sorted_refs():
+    """Test to ensure references are sorted in same order as ids."""
+
+    wizard = EnergyWizardPostgres(db_host='Dummy', db_port='Dummy',
+                                  db_name='Dummy', db_schema='Dummy',
+                                  db_table='Dummy', cursor=Cursor(),
+                                  boto_client=BotoClient(),
+                                  meta_columns=['title', 'url', 'id'])
+
+    refs = [('title', 'test.com', '5a'),
+            ('title2', 'test_2.com', '7b')]
+
+    ids = np.array(['7b', '5a'])
+
+    expected = ['{"title": "title2", "url": "test_2.com", "id": "7b"}',
+                '{"title": "title", "url": "test.com", "id": "5a"}']
+
+    out = wizard._format_refs(refs, ids)
+
+    assert expected == out
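
Usage note: the snippet below is a minimal, hypothetical sketch of how the reworked EnergyWizardPostgres interface in this diff might be called. The host, port, database, schema, and table values are illustrative placeholders (not from this PR), and it assumes EWIZ_DB_USER, EWIZ_DB_PASSWORD, and AWS credentials are already set in the environment.

    # Hypothetical usage sketch -- placeholder connection values, not part of this PR.
    from elm.wizard import EnergyWizardPostgres

    wizard = EnergyWizardPostgres(db_host='example-db-host',   # placeholder host
                                  db_port='5432',              # placeholder port
                                  db_name='example_db',        # placeholder database
                                  db_schema='example_schema',  # placeholder schema
                                  db_table='example_table',    # placeholder table
                                  probes=25)                   # probes now set on init, not per query

    # query_vector_db() no longer accepts a `probes` kwarg; it reads self.probes.
    strings, scores, ids = wizard.query_vector_db('Is this a dummy question?', limit=10)

    # make_ref_list() now returns JSON strings ordered by embedding relevance.
    ref_list = wizard.make_ref_list(ids)
    print(ref_list)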