Cleanup and refactor agents (#730)
ahuang11 authored Oct 29, 2024
1 parent 59c36ae commit 3722bd0
Showing 6 changed files with 91 additions and 115 deletions.
166 changes: 55 additions & 111 deletions lumen/ai/agents.py
@@ -33,12 +33,12 @@
 from .memory import memory
 from .models import (
     DataRequired, FuzzyTable, JoinRequired, Sql, TableJoins, Topic,
-    VegaLiteSpec,
+    VegaLiteSpec, make_table_model,
 )
 from .translate import param_to_pydantic
 from .utils import (
-    clean_sql, describe_data, get_data, get_pipeline, get_schema,
-    render_template, report_error, retry_llm_output,
+    clean_sql, describe_data, gather_table_sources, get_data, get_pipeline,
+    get_schema, render_template, report_error, retry_llm_output,
 )
 from .views import AnalysisOutput, LumenOutput, SQLOutput
@@ -245,13 +245,17 @@ async def requirements(self, messages: list | str, errors=None):
         if 'current_data' in memory:
             return self.requires

+        available_sources = memory["available_sources"]
+        _, tables_schema_str = await gather_table_sources(available_sources)
         with self.interface.add_step(title="Checking if data is required") as step:
             response = self.llm.stream(
                 messages,
                 system=(
-                    "The user may or may not want to chat about a particular dataset. "
-                    "Determine whether the provided user prompt requires access to "
-                    "actual data. If they're only searching for one, it's not required."
+                    "Assess if the user's prompt requires loading data. "
+                    "If the inquiry is just about available tables, no data access required. "
+                    "However, if relevant tables apply to the query, load the data for a "
+                    "more accurate and up-to-date response. "
+                    f"Here are the available tables:\n{tables_schema_str}"
                 ),
                 response_model=DataRequired,
             )
@@ -350,107 +354,6 @@ def _render_lumen(self, component: Component, message: pn.chat.ChatMessage = None
         self.interface.stream(message=message, **message_kwargs, replace=True, max_width=self._max_width)


-class TableAgent(LumenBaseAgent):
-    """
-    Displays a single table / dataset. Does not discuss.
-    """
-
-    system_prompt = param.String(
-        default=textwrap.dedent(
-            """
-            Identify the most relevant table that contains the most columns useful
-            for answering the user's query. Keep in mind that additional tables
-            can be joined later, so focus on selecting the best starting point.
-            """
-        )
-    )
-
-    requires = param.List(default=["current_source"], readonly=True)
-
-    provides = param.List(default=["current_table", "current_pipeline"], readonly=True)
-
-    @staticmethod
-    def _create_table_model(tables):
-        table_model = create_model(
-            "Table",
-            chain_of_thought=(str, FieldInfo(
-                description="The thought process behind selecting the table, listing out which columns are useful."
-            )),
-            relevant_table=(Literal[tables], FieldInfo(
-                description="The most relevant table based on the user query; if none are relevant, select the first."
-            ))
-        )
-        return table_model
-
-    async def answer(self, messages: list | str):
-        available_sources = memory["available_sources"]
-        tables_to_source = {}
-        tables_schema_str = "\nHere are the tables\n"
-        for source in available_sources:
-            for table in source.get_tables():
-                tables_to_source[table] = source
-                if isinstance(source, DuckDBSource) and source.ephemeral:
-                    schema = await get_schema(source, table, include_min_max=False, include_enum=True, limit=1)
-                    tables_schema_str += f"### {table}\nSchema:\n```yaml\n{yaml.dump(schema)}```\n"
-                else:
-                    tables_schema_str += f"### {table}\n"
-
-        tables = tuple(tables_to_source)
-        if messages and messages[-1]["content"].startswith("Show the table: '"):
-            # Handle the case where the TableListAgent explicitly requested a table
-            table = messages[-1]["content"].replace("Show the table: '", "")[:-1]
-        elif len(tables) == 1:
-            table = tables[0]
-        else:
-            with self.interface.add_step(title="Choosing the most relevant table...") as step:
-                closest_tables = memory.pop("closest_tables", [])
-                if closest_tables:
-                    tables = closest_tables
-                elif len(tables) > FUZZY_TABLE_LENGTH:
-                    tables = await self._get_closest_tables(messages, tables)
-                system_prompt = await self._system_prompt_with_context(messages, context=tables_schema_str)
-                if self.debug:
-                    print(f"{self.name} is being instructed that it should {system_prompt}")
-                if len(tables) > 1:
-                    table_model = self._create_table_model(tables)
-                    result = await self.llm.invoke(
-                        messages,
-                        system=system_prompt,
-                        response_model=table_model,
-                        allow_partial=False,
-                    )
-                    table = result.relevant_table
-                    step.stream(f"{result.chain_of_thought}\n\nSelected table: {table}")
-                else:
-                    table = tables[0]
-                    step.stream(f"Selected table: {table}")
-
-        if table in tables_to_source:
-            source = tables_to_source[table]
-        else:
-            sources = [src for src in available_sources if table in src]
-            source = sources[0] if sources else memory["current_source"]
-
-        get_kwargs = {}
-        if isinstance(source, BaseSQLSource):
-            get_kwargs['sql_transforms'] = [SQLLimit(limit=1_000_000)]
-        memory["current_source"] = source
-        memory["current_table"] = table
-        memory["current_pipeline"] = pipeline = await get_pipeline(
-            source=source, table=table, **get_kwargs
-        )
-        df = await get_data(pipeline)
-        if len(df) > 0:
-            memory["current_data"] = await describe_data(df)
-        if self.debug:
-            print(f"{self.name} thinks that the user is talking about {table=!r}.")
-        return pipeline
-
-    async def invoke(self, messages: list | str):
-        pipeline = await self.answer(messages)
-        self._render_lumen(pipeline)
-
-
 class TableListAgent(LumenBaseAgent):
     """
     Provides a list of all availables tables/datasets.
@@ -518,9 +421,51 @@ class SQLAgent(LumenBaseAgent):
         )
     )

-    requires = param.List(default=["current_table", "current_source"], readonly=True)
+    requires = param.List(default=["current_source"], readonly=True)

-    provides = param.List(default=["current_sql", "current_pipeline"], readonly=True)
+    provides = param.List(default=["current_table", "current_sql", "current_pipeline"], readonly=True)
+
+    async def _select_relevant_table(self, messages: list | str) -> tuple[str, BaseSQLSource]:
+        """Select the most relevant table based on the user query."""
+        available_sources = memory["available_sources"]
+
+        tables_to_source, tables_schema_str = await gather_table_sources(available_sources)
+        tables = tuple(tables_to_source)
+        if messages and messages[-1]["content"].startswith("Show the table: '"):
+            # Handle the case where explicitly requested a table
+            table = messages[-1]["content"].replace("Show the table: '", "")[:-1]
+        elif len(tables) == 1:
+            table = tables[0]
+        else:
+            with self.interface.add_step(title="Choosing the most relevant table...") as step:
+                closest_tables = memory.pop("closest_tables", [])
+                if closest_tables:
+                    tables = closest_tables
+                elif len(tables) > FUZZY_TABLE_LENGTH:
+                    tables = await self._get_closest_tables(messages, tables)
+                system_prompt = await self._system_prompt_with_context(messages, context=tables_schema_str)
+
+                if len(tables) > 1:
+                    table_model = make_table_model(tables)
+                    result = await self.llm.invoke(
+                        messages,
+                        system=system_prompt,
+                        response_model=table_model,
+                        allow_partial=False,
+                    )
+                    table = result.relevant_table
+                    step.stream(f"{result.chain_of_thought}\n\nSelected table: {table}")
+                else:
+                    table = tables[0]
+                    step.stream(f"Selected table: {table}")
+
+        if table in tables_to_source:
+            source = tables_to_source[table]
+        else:
+            sources = [src for src in available_sources if table in src]
+            source = sources[0] if sources else memory["current_source"]
+
+        return table, source

     def _render_sql(self, query):
         pipeline = memory['current_pipeline']
@@ -699,8 +644,7 @@ async def answer(self, messages: list | str):
         8. If a join is required, remove source/table prefixes from the last message.
         9. Construct the SQL query with `_create_valid_sql`.
         """
-        source = memory["current_source"]
-        table = memory["current_table"]
+        table, source = await self._select_relevant_table(messages)

         if not hasattr(source, "get_sql_expr"):
             return None
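The refactor folds TableAgent's selection logic into `SQLAgent._select_relevant_table`, which short-circuits LLM-based selection when the last message carries an explicit table request. A minimal standalone sketch of that prefix parsing, outside Lumen (the table name here is hypothetical):

```python
# The "Show the table: '<name>'" prefix bypasses LLM-based table selection;
# the name is recovered by stripping the fixed prefix and the trailing quote.
messages = [{"content": "Show the table: 'penguins'"}]  # hypothetical table name

content = messages[-1]["content"]
if content.startswith("Show the table: '"):
    table = content.replace("Show the table: '", "")[:-1]
    assert table == "penguins"
```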
4 changes: 2 additions & 2 deletions lumen/ai/assistant.py
@@ -17,7 +17,7 @@
 from panel.widgets import Button, FileDownload

 from .agents import (
-    Agent, AnalysisAgent, ChatAgent, SQLAgent, TableAgent,
+    Agent, AnalysisAgent, ChatAgent, SQLAgent,
 )
 from .config import DEMO_MESSAGES, GETTING_STARTED_SUGGESTIONS
 from .export import export_notebook
@@ -408,7 +408,7 @@ async def _get_agent(self, messages: list | str):
                 step.stream(f"`{agent_name}` agent is working on the following task:\n\n{instruction}")
             self._current_agent.object = f"## **Current Agent**: {agent_name}"
             custom_messages = messages.copy()
-            if isinstance(subagent, (TableAgent, SQLAgent)):
+            if isinstance(subagent, SQLAgent):
                 custom_agent = next((agent for agent in self.agents if isinstance(agent, AnalysisAgent)), None)
                 if custom_agent:
                     custom_analysis_doc = custom_agent.__doc__.replace("Available analyses include:\n", "")
13 changes: 13 additions & 0 deletions lumen/ai/models.py
@@ -156,3 +156,16 @@ def make_agent_model(agent_names: list[str], primary: bool = False):
             FieldInfo(default=..., description=description)
         ),
     )
+
+
+def make_table_model(tables):
+    table_model = create_model(
+        "Table",
+        chain_of_thought=(str, FieldInfo(
+            description="A concise, one sentence decision-tree-style analysis on choosing a table."
+        )),
+        relevant_table=(Literal[tables], FieldInfo(
+            description="The most relevant table based on the user query; if none are relevant, select the first."
+        ))
+    )
+    return table_model
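The dynamic Pydantic model returned by `make_table_model` constrains structured LLM output to one of the known tables via a `Literal` field. A minimal usage sketch, assuming the module is importable as below (the table names are hypothetical):

```python
from lumen.ai.models import make_table_model

# Build a response model restricted to two hypothetical tables.
TableModel = make_table_model(("sales", "customers"))

selection = TableModel(
    chain_of_thought="Revenue question -> needs order amounts -> sales.",
    relevant_table="sales",
)
print(selection.relevant_table)  # sales
# relevant_table="inventory" would raise a ValidationError, since the
# Literal field only admits the table names passed to make_table_model.
```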
2 changes: 1 addition & 1 deletion lumen/ai/prompts/plan_agent.jinja2
@@ -2,7 +2,7 @@ You are team lead and have to make a plan to solve how to address the user query

 Ensure that the plan solves the entire problem, step-by-step and ensure all steps listed in the chain of thought are listed!

-If some piece of information is already available to you only call an agent to provide the same piece of information if absolutely necessary, e.g. if 'current_table' is avaible do not call the TableAgent again.
+If some piece of information is already available to you only call an agent to provide the same piece of information if absolutely necessary, e.g. if 'current_source' is avaible do not call the SourceAgent again.

 You have to choose which of the experts at your disposal should address the problem.
20 changes: 20 additions & 0 deletions lumen/ai/utils.py
@@ -11,9 +11,11 @@
 import jinja2
 import pandas as pd
+import yaml

 from lumen.pipeline import Pipeline
 from lumen.sources.base import Source
+from lumen.sources.duckdb import DuckDBSource

 from .config import THIS_DIR, UNRECOVERABLE_ERRORS
@@ -272,3 +274,21 @@ def report_error(exc: Exception, step: ChatStep):
         error_msg = error_msg[:50] + "..."
     step.failed_title = error_msg
     step.status = "failed"
+
+
+async def gather_table_sources(available_sources: list[Source]) -> tuple[dict[str, Source], str]:
+    """
+    Get a dictionary of tables to their respective sources
+    and a markdown string of the tables and their schemas.
+    """
+    tables_to_source = {}
+    tables_schema_str = "\nHere are the tables\n"
+    for source in available_sources:
+        for table in source.get_tables():
+            tables_to_source[table] = source
+            if isinstance(source, DuckDBSource) and source.ephemeral:
+                schema = await get_schema(source, table, include_min_max=False, include_enum=True, limit=1)
+                tables_schema_str += f"### {table}\nSchema:\n```yaml\n{yaml.dump(schema)}```\n"
+            else:
+                tables_schema_str += f"### {table}\n"
+    return tables_to_source, tables_schema_str
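A quick sketch of calling the new `gather_table_sources` helper directly; the DuckDB table definition here is a hypothetical stand-in, not from the commit:

```python
import asyncio

from lumen.ai.utils import gather_table_sources
from lumen.sources.duckdb import DuckDBSource

async def main():
    # Hypothetical source exposing a single CSV-backed table.
    source = DuckDBSource(tables={"penguins": "SELECT * FROM 'penguins.csv'"})
    tables_to_source, tables_schema_str = await gather_table_sources([source])
    print(list(tables_to_source))  # ['penguins'], each name mapped to its source
    # tables_schema_str renders one "### <table>" heading per table; ephemeral
    # DuckDB sources additionally include a YAML schema block.
    print(tables_schema_str)

asyncio.run(main())
```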
1 change: 0 additions & 1 deletion lumen/command/ai.py
@@ -36,7 +36,6 @@
     llm=llm,
     agents=[
         lmai.agents.SourceAgent,
-        lmai.agents.TableAgent,
         lmai.agents.TableListAgent,
         lmai.agents.SQLAgent,
         lmai.agents.hvPlotAgent,
