
Commit

Tweak ChatAgent so it uses available data (#731)
ahuang11 authored Oct 29, 2024
1 parent ee99796 commit 99998ad
Showing 2 changed files with 30 additions and 15 deletions.
25 changes: 10 additions & 15 deletions lumen/ai/agents.py
@@ -39,8 +39,8 @@
 )
 from .translate import param_to_pydantic
 from .utils import (
-    clean_sql, describe_data, get_data, get_pipeline, get_schema,
-    render_template, report_error, retry_llm_output,
+    clean_sql, describe_data, gather_table_sources, get_data, get_pipeline,
+    get_schema, render_template, report_error, retry_llm_output,
 )
 from .views import AnalysisOutput, LumenOutput, SQLOutput

@@ -247,13 +247,17 @@ async def requirements(self, messages: list | str, errors=None):
         if 'current_data' in memory:
             return self.requires

+        available_sources = memory["available_sources"]
+        _, tables_schema_str = await gather_table_sources(available_sources)
         with self.interface.add_step(title="Checking if data is required") as step:
             response = self.llm.stream(
                 messages,
                 system=(
-                    "The user may or may not want to chat about a particular dataset. "
-                    "Determine whether the provided user prompt requires access to "
-                    "actual data. If they're only searching for one, it's not required."
+                    "Assess if the user's prompt requires loading data. "
+                    "If the inquiry is just about available tables, no data access required. "
+                    "However, if relevant tables apply to the query, load the data for a "
+                    "more accurate and up-to-date response. "
+                    f"Here are the available tables:\n{tables_schema_str}"
                 ),
                 response_model=DataRequired,
             )
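
For readers unfamiliar with the `response_model=DataRequired` pattern above: `self.llm.stream` asks the LLM for structured output validated against a pydantic model. The actual `DataRequired` model is not shown in this diff, so the field name below is an assumption; the sketch only illustrates the general shape of such a model.

```python
from pydantic import BaseModel, Field

# Hypothetical stand-in for lumen.ai's DataRequired model (real field names
# may differ); it gives the LLM a single typed yes/no slot to fill when
# deciding whether the prompt needs the actual data loaded.
class DataRequiredSketch(BaseModel):
    data_required: bool = Field(
        description="Whether answering the user's prompt requires loading the data."
    )
```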
@@ -426,17 +430,8 @@ class SQLAgent(LumenBaseAgent):
     async def _select_relevant_table(self, messages: list | str) -> tuple[str, BaseSQLSource]:
         """Select the most relevant table based on the user query."""
         available_sources = memory["available_sources"]
-        tables_to_source = {}
-        tables_schema_str = "\nHere are the tables\n"
-        for source in available_sources:
-            for table in source.get_tables():
-                tables_to_source[table] = source
-                if isinstance(source, DuckDBSource) and source.ephemeral:
-                    schema = await get_schema(source, table, include_min_max=False, include_enum=True, limit=1)
-                    tables_schema_str += f"### {table}\nSchema:\n```yaml\n{yaml.dump(schema)}```\n"
-                else:
-                    tables_schema_str += f"### {table}\n"
+        tables_to_source, tables_schema_str = await gather_table_sources(available_sources)

         tables = tuple(tables_to_source)
         if messages and messages[-1]["content"].startswith("Show the table: '"):
             # Handle the case where explicitly requested a table
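With this change, both agents build their table overview through the same helper rather than SQLAgent's inlined loop. As a rough illustration of the branch kept above, here is one way an explicit `Show the table: '<name>'` request could be resolved against the mapping the helper returns; this is a sketch under assumed names, not the code `_select_relevant_table` actually contains.

```python
# Sketch only: resolving an explicit "Show the table: '<name>'" request
# against the tables_to_source mapping returned by gather_table_sources.
tables_to_source = {"penguins": "duckdb_source"}   # stand-in mapping for illustration
content = "Show the table: 'penguins'"             # hypothetical last user message
prefix = "Show the table: '"
if content.startswith(prefix):
    requested = content[len(prefix):].rstrip("'")
    source = tables_to_source.get(requested)       # the Source that owns the table
    print(requested, source)                       # penguins duckdb_source
```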
20 changes: 20 additions & 0 deletions lumen/ai/utils.py
@@ -11,9 +11,11 @@

 import jinja2
 import pandas as pd
+import yaml

 from lumen.pipeline import Pipeline
 from lumen.sources.base import Source
+from lumen.sources.duckdb import DuckDBSource

 from .config import THIS_DIR, UNRECOVERABLE_ERRORS

@@ -272,3 +274,21 @@ def report_error(exc: Exception, step: ChatStep):
         error_msg = error_msg[:50] + "..."
     step.failed_title = error_msg
     step.status = "failed"
+
+
+async def gather_table_sources(available_sources: list[Source]) -> tuple[dict[str, Source], str]:
+    """
+    Get a dictionary of tables to their respective sources
+    and a markdown string of the tables and their schemas.
+    """
+    tables_to_source = {}
+    tables_schema_str = "\nHere are the tables\n"
+    for source in available_sources:
+        for table in source.get_tables():
+            tables_to_source[table] = source
+            if isinstance(source, DuckDBSource) and source.ephemeral:
+                schema = await get_schema(source, table, include_min_max=False, include_enum=True, limit=1)
+                tables_schema_str += f"### {table}\nSchema:\n```yaml\n{yaml.dump(schema)}```\n"
+            else:
+                tables_schema_str += f"### {table}\n"
+    return tables_to_source, tables_schema_str
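
A minimal usage sketch of the new helper, using a stand-in source object: real callers pass Lumen `Source` instances from `memory["available_sources"]`, and the class and table names below are invented for illustration. Non-ephemeral sources only contribute a `### <table>` heading, while ephemeral `DuckDBSource` tables additionally get a YAML schema block.

```python
import asyncio

from lumen.ai.utils import gather_table_sources

class FakeSource:
    """Hypothetical source exposing only what gather_table_sources needs."""
    def get_tables(self):
        return ["penguins", "windturbines"]

async def main():
    tables_to_source, tables_schema_str = await gather_table_sources([FakeSource()])
    # tables_to_source maps each table name to the source that provides it
    print(list(tables_to_source))   # ['penguins', 'windturbines']
    # A non-DuckDB (or non-ephemeral) source only gets a heading per table
    print(tables_schema_str)        # "\nHere are the tables\n### penguins\n### windturbines\n"

asyncio.run(main())
```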

