diff --git a/elm/wizard.py b/elm/wizard.py index 7e0424c..8f9472d 100644 --- a/elm/wizard.py +++ b/elm/wizard.py @@ -407,7 +407,8 @@ class EnergyWizardPostgres(EnergyWizardBase): def __init__(self, db_host, db_port, db_name, db_schema, db_table, meta_columns=None, cursor=None, boto_client=None, - model=None, token_budget=3500): + model=None, token_budget=3500, + tag=False): """ Parameters ---------- @@ -435,6 +436,9 @@ def __init__(self, db_host, db_port, db_name, Number of tokens that can be embedded in the prompt. Note that the default budget for GPT-3.5-Turbo is 4096, but you want to subtract some tokens to account for the response budget. + tag: bool + Flag to add tag/metadata to text chunks before sending query to + GPT. """ boto3 = try_import('boto3') psycopg2 = try_import('psycopg2') @@ -461,6 +465,8 @@ def __init__(self, db_host, db_port, db_name, else: self.cursor = cursor + self.tag = tag + if boto_client is None: access_key = os.getenv('AWS_ACCESS_KEY_ID') secret_key = os.getenv('AWS_SECRET_ACCESS_KEY') @@ -521,7 +527,33 @@ def get_embedding(self, text): return embedding - def query_vector_db(self, query, limit=100): + @staticmethod + def _add_tag(meta): + """Function to add tag/metadata to text strings before + sending query to GPT. + + Parameters + ---------- + meta : tuple + Text values to include in tag (title, authors, year) + + Returns + ------- + tag : str + Text string containing provided metadata. + """ + title, authors, year = meta + if authors and year: + tag = (f"Title: {title}\n" + f"Authors: {authors}\n" + f"Publication Year: {year}\n\n" + ) + else: + tag = f"Title: {title}\n\n" + + return tag + + def query_vector_db(self, query, probes=25, limit=100): """Returns a list of strings and relatednesses, sorted from most related to least. @@ -529,6 +561,8 @@ def query_vector_db(self, query, limit=100): ---------- query : str Question being asked of GPT + probes: int + Number of lists to search in vector database index. limit : int Number of top results to return. @@ -545,17 +579,25 @@ def query_vector_db(self, query, limit=100): query_embedding = self.get_embedding(query) - self.cursor.execute(f"SELECT {self.db_table}.id, " + self.cursor.execute(f"SET LOCAL ivfflat.probes = {probes};" + f"SELECT {self.db_table}.id, " f"{self.db_table}.chunks, " f"{self.db_table}.embedding " - "<=> %s::vector as score " + "<=> %s::vector as score, " + f"{self.db_table}.title, " + f"{self.db_table}.authors, " + f"{self.db_table}.year " f"FROM {self.db_schema}.{self.db_table} " "ORDER BY embedding <=> %s::vector LIMIT %s;", (query_embedding, query_embedding, limit,), ) result = self.cursor.fetchall() - strings = [s[1] for s in result] + if self.tag: + strings = [self._add_tag(s[3:]) + s[1] for s in result] + else: + strings = [s[1] for s in result] + scores = [s[2] for s in result] best = [s[0] for s in result]