Skip to content

Commit

Permalink
Add embed instruction option
Browse files Browse the repository at this point in the history
  • Loading branch information
HamadaSalhab committed Oct 19, 2024
1 parent 8720d46 commit ea49775
Show file tree
Hide file tree
Showing 10 changed files with 174 additions and 145 deletions.
8 changes: 8 additions & 0 deletions agents-api/agents_api/autogen/Docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ class CreateDocRequest(BaseModel):
"""
Contents of the document
"""
embed_instruction: str | None = None
"""
Instruction for the embedding model.
"""


class Doc(BaseModel):
Expand Down Expand Up @@ -113,6 +117,10 @@ class EmbedQueryRequest(BaseModel):
"""
Text or texts to embed
"""
embed_instruction: str | None = None
"""
Instruction for the embedding model.
"""


class EmbedQueryResponse(BaseModel):
Expand Down
1 change: 1 addition & 0 deletions agents-api/agents_api/models/docs/create_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def create_doc(
data.metadata = data.metadata or {}

doc_data = data.model_dump()
doc_data.pop("embed_instruction", None)
content = doc_data.pop("content")

doc_data["owner_type"] = owner_type
Expand Down
9 changes: 5 additions & 4 deletions agents-api/agents_api/routers/docs/create_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from temporalio.client import Client as TemporalClient

from ...activities.types import EmbedDocsPayload
from ...autogen.openapi_model import CreateDocRequest, ResourceCreatedResponse
from ...autogen.openapi_model import CreateDocRequest, Doc, ResourceCreatedResponse
from ...clients import temporal
from ...common.retry_policies import DEFAULT_RETRY_POLICY
from ...dependencies.developer_id import get_developer_id
Expand All @@ -21,6 +21,7 @@ async def run_embed_docs_task(
doc_id: UUID,
title: str,
content: list[str],
embed_instruction: str | None = None,
job_id: UUID,
background_tasks: BackgroundTasks,
client: TemporalClient | None = None,
Expand All @@ -34,7 +35,7 @@ async def run_embed_docs_task(
doc_id=doc_id,
content=content,
title=title,
embed_instruction=None,
embed_instruction=embed_instruction,
)

handle = await client.start_workflow(
Expand All @@ -60,7 +61,7 @@ async def create_user_doc(
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
background_tasks: BackgroundTasks,
) -> ResourceCreatedResponse:
doc = create_doc_query(
doc: Doc = create_doc_query(
developer_id=x_developer_id,
owner_type="user",
owner_id=user_id,
Expand Down Expand Up @@ -90,7 +91,7 @@ async def create_agent_doc(
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
background_tasks: BackgroundTasks,
) -> ResourceCreatedResponse:
doc = create_doc_query(
doc: Doc = create_doc_query(
developer_id=x_developer_id,
owner_type="agent",
owner_id=agent_id,
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/routers/docs/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@ async def embed(
[text_to_embed] if isinstance(text_to_embed, str) else text_to_embed
)

vectors = await litellm.aembedding(inputs=text_to_embed)
vectors = await litellm.aembedding(inputs=data.embed_instruction + text_to_embed)

return EmbedQueryResponse(vectors=vectors)
2 changes: 1 addition & 1 deletion agents-api/notebooks/RecSum-experiments.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@
"\n",
"\n",
"def chat():\n",
" while (user_input := input(\"You: \").lower()) != \"bye\": \n",
" while (user_input := input(\"You: \").lower()) != \"bye\":\n",
" chat_session.append(user(user_input))\n",
"\n",
" result = generate(chat_session)\n",
Expand Down
144 changes: 71 additions & 73 deletions agents-api/notebooks/main-3-Copy1.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,9 @@
"metadata": {},
"outputs": [],
"source": [
"#Keys\n",
"api_key = 'eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzUxMiJ9.eyJleHAiOjE3MzM2OTkxOTEsImlhdCI6MTcyODUxNTE5MSwic3ViIjoiMDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDAwIn0.OOiS_MkP1QEOMQ2Gs13JeZsFCPkR-ldbNtedK9iS3qIxSN_fSPGzajcdbLtedZZYD9OwMsBg4sKvmkeyrBti9w'\n",
"environment = 'local_multi_tenant'"
"# Keys\n",
"api_key = \"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzUxMiJ9.eyJleHAiOjE3MzM2OTkxOTEsImlhdCI6MTcyODUxNTE5MSwic3ViIjoiMDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDAwIn0.OOiS_MkP1QEOMQ2Gs13JeZsFCPkR-ldbNtedK9iS3qIxSN_fSPGzajcdbLtedZZYD9OwMsBg4sKvmkeyrBti9w\"\n",
"environment = \"local_multi_tenant\""
]
},
{
Expand All @@ -174,7 +174,7 @@
"metadata": {},
"outputs": [],
"source": [
"client = Julep(api_key = api_key, environment = environment)"
"client = Julep(api_key=api_key, environment=environment)"
]
},
{
Expand Down Expand Up @@ -2607,16 +2607,16 @@
" \"min_p\": 0.05,\n",
" \"presence_penalty\": 0.2,\n",
" \"frequency_penalty\": 0.2,\n",
" \"length_penalty\": 1.0\n",
" \"length_penalty\": 1.0,\n",
"}\n",
"\n",
"\n",
"# Create the agent\n",
"agent = client.agents.create_or_update(\n",
" agent_id = '847b03b1-856a-4ae1-a1f5-ad994ba5c87d',\n",
" agent_id=\"847b03b1-856a-4ae1-a1f5-ad994ba5c87d\",\n",
" name=name,\n",
" about=about,\n",
" model=\"gpt-4o\"\n",
" model=\"gpt-4o\",\n",
")"
]
},
Expand All @@ -2639,9 +2639,9 @@
],
"source": [
"client.agents.docs.create(\n",
" agent_id = '847d03b1-856a-4ae1-a1f5-ad994ba5c87d',\n",
" title = \"NewsTitles\",\n",
" content = articles[:25000] \n",
" agent_id=\"847d03b1-856a-4ae1-a1f5-ad994ba5c87d\",\n",
" title=\"NewsTitles\",\n",
" content=articles[:25000],\n",
")"
]
},
Expand All @@ -2664,8 +2664,8 @@
],
"source": [
"client.agents.docs.search(\n",
" agent_id = '847b03b1-856a-4ae1-a1f5-ad994ba5c87d',\n",
" text = \"nascar-news-kyle-busch-denny-hamlin-brad-keselowski-fans-beef-over-their-superstars-filling-the-void-in-nascar-left-by-legendary-cup-champ\"\n",
" agent_id=\"847b03b1-856a-4ae1-a1f5-ad994ba5c87d\",\n",
" text=\"nascar-news-kyle-busch-denny-hamlin-brad-keselowski-fans-beef-over-their-superstars-filling-the-void-in-nascar-left-by-legendary-cup-champ\",\n",
")"
]
},
Expand All @@ -2687,9 +2687,7 @@
}
],
"source": [
"client.agents.docs.list(\n",
" agent_id = '847b03b1-856a-4ae1-a1f5-ad994ba5c87d'\n",
")"
"client.agents.docs.list(agent_id=\"847b03b1-856a-4ae1-a1f5-ad994ba5c87d\")"
]
},
{
Expand All @@ -2716,21 +2714,28 @@
],
"source": [
"for i in tqdm(range(len(user_data))):\n",
" \n",
" client.users.create(\n",
" name = f\"user_{i}\",\n",
" name=f\"user_{i}\",\n",
" metadata={\n",
" \"ppid\": str(user_data[i][\"metadata\"][\"ppid\"]), # Ensure it's a native int\n",
" \"age\": int(user_data[i][\"metadata\"][\"age\"]), # Ensure it's a native int\n",
" \"state\": str(user_data[i][\"metadata\"][\"state\"]), # Convert to string if not already\n",
" \"city\": str(user_data[i][\"metadata\"][\"city\"]), # Convert to string if not already\n",
" \"sports_likes\": str(user_data[i][\"metadata\"][\"sports_likes\"]), # Convert to string or json serializable format\n",
" \"entity_likes\": str(user_data[i][\"metadata\"][\"entity_likes\"]), # Same as above\n",
" \"state\": str(\n",
" user_data[i][\"metadata\"][\"state\"]\n",
" ), # Convert to string if not already\n",
" \"city\": str(\n",
" user_data[i][\"metadata\"][\"city\"]\n",
" ), # Convert to string if not already\n",
" \"sports_likes\": str(\n",
" user_data[i][\"metadata\"][\"sports_likes\"]\n",
" ), # Convert to string or json serializable format\n",
" \"entity_likes\": str(\n",
" user_data[i][\"metadata\"][\"entity_likes\"]\n",
" ), # Same as above\n",
" \"top_sport\": str(user_data[i][\"metadata\"][\"top_sport\"]),\n",
" \"top_entity\": str(user_data[i][\"metadata\"][\"top_entity\"]),\n",
" \"latest_sport_read\": str(user_data[i][\"metadata\"][\"latest_sport_read\"]),\n",
" \"top_sources\": str(user_data[i][\"metadata\"][\"top_sources\"])\n",
" }\n",
" \"top_sources\": str(user_data[i][\"metadata\"][\"top_sources\"]),\n",
" },\n",
" )"
]
},
Expand Down Expand Up @@ -2765,7 +2770,7 @@
"\n",
"# Create the agent\n",
"agent = client.agents.create_or_update(\n",
" agent_id = '844e03b1-856a-4ae1-a1f5-ad994ba5c87d',\n",
" agent_id=\"844e03b1-856a-4ae1-a1f5-ad994ba5c87d\",\n",
" name=name,\n",
" about=about,\n",
" model=\"gpt-4o\",\n",
Expand Down Expand Up @@ -2865,8 +2870,7 @@
" # metadata:\n",
" # user_persona: true\n",
" content: _\n",
"\"\"\"\n",
" )"
"\"\"\")"
]
},
{
Expand All @@ -2878,9 +2882,9 @@
"source": [
"# Creating/Updating a task\n",
"task = client.tasks.create_or_update(\n",
" task_id= '813a03b1-856a-4ae1-a1f5-ad994ba5c87d',\n",
" agent_id= '844e03b1-856a-4ae1-a1f5-ad994ba5c87d',\n",
" **task_def\n",
" task_id=\"813a03b1-856a-4ae1-a1f5-ad994ba5c87d\",\n",
" agent_id=\"844e03b1-856a-4ae1-a1f5-ad994ba5c87d\",\n",
" **task_def,\n",
")"
]
},
Expand All @@ -2899,24 +2903,24 @@
}
],
"source": [
"for i in tqdm(range(10)): # range(len(user_data))):\n",
" ppid = str(user_data[i][\"metadata\"]['ppid']) \n",
" # filtered_df1 = df1[df1['ppid'] == ppid ] \n",
" \n",
" # if not filtered_df1.empty: \n",
" # Creating an Execution\n",
"for i in tqdm(range(10)): # range(len(user_data))):\n",
" ppid = str(user_data[i][\"metadata\"][\"ppid\"])\n",
" # filtered_df1 = df1[df1['ppid'] == ppid ]\n",
"\n",
" # if not filtered_df1.empty:\n",
" # Creating an Execution\n",
" execution = client.executions.create(\n",
" task_id= '813a03b1-856a-4ae1-a1f5-ad994ba5c87d',\n",
" input = {\n",
" 'user_ppid': ppid,\n",
" 'articles_read' : [\n",
" 'nascar-news-they-didnt-wait-for-tony-stewart',\n",
" 'nfl-ncaa-news-privilege-to-love-travis-hunters',\n",
" 'nascar-news-stuck-in-a-thirty-million-hole-dal',\n",
" 'nascar-news-that-shouldnt-take-away-my-playoff',\n",
" 'nascar-news-kyle-busch-denny-hamlin-brad-kesel',\n",
" ]\n",
" }\n",
" task_id=\"813a03b1-856a-4ae1-a1f5-ad994ba5c87d\",\n",
" input={\n",
" \"user_ppid\": ppid,\n",
" \"articles_read\": [\n",
" \"nascar-news-they-didnt-wait-for-tony-stewart\",\n",
" \"nfl-ncaa-news-privilege-to-love-travis-hunters\",\n",
" \"nascar-news-stuck-in-a-thirty-million-hole-dal\",\n",
" \"nascar-news-that-shouldnt-take-away-my-playoff\",\n",
" \"nascar-news-kyle-busch-denny-hamlin-brad-kesel\",\n",
" ],\n",
" },\n",
" )"
]
},
Expand Down Expand Up @@ -2959,9 +2963,7 @@
}
],
"source": [
"client.executions.transitions.list(\n",
" execution_id = '8e69fbc6-7f10-4067-9b81-4b2015508628'\n",
")"
"client.executions.transitions.list(execution_id=\"8e69fbc6-7f10-4067-9b81-4b2015508628\")"
]
},
{
Expand Down Expand Up @@ -2994,11 +2996,11 @@
"\n",
"# Create the agent\n",
"agent = client.agents.create_or_update(\n",
" agent_id = '865e03b1-856a-4ae1-a1f5-ad994ba5c87d',\n",
" agent_id=\"865e03b1-856a-4ae1-a1f5-ad994ba5c87d\",\n",
" name=name,\n",
" about=about,\n",
" model=\"gpt-4o\",\n",
") "
")"
]
},
{
Expand All @@ -3008,10 +3010,9 @@
"metadata": {},
"outputs": [],
"source": [
"for docs in client.agents.docs.list(agent_id = '847b03b1-856a-4ae1-a1f5-ad994ba5c87d'):\n",
"for docs in client.agents.docs.list(agent_id=\"847b03b1-856a-4ae1-a1f5-ad994ba5c87d\"):\n",
" client.agents.docs.delete(\n",
" agent_id = '847b03b1-856a-4ae1-a1f5-ad994ba5c87d',\n",
" doc_id = docs.id\n",
" agent_id=\"847b03b1-856a-4ae1-a1f5-ad994ba5c87d\", doc_id=docs.id\n",
" )"
]
},
Expand All @@ -3033,7 +3034,7 @@
}
],
"source": [
"df2 = pd.read_csv('new_titles.csv')\n",
"df2 = pd.read_csv(\"new_titles.csv\")\n",
"df2.shape"
]
},
Expand All @@ -3053,12 +3054,10 @@
],
"source": [
"for i in tqdm(range(len(df2))):\n",
" title = df2.iloc[i]['post_name']\n",
" \n",
" title = df2.iloc[i][\"post_name\"]\n",
"\n",
" client.agents.docs.create(\n",
" agent_id = '847b03b1-856a-4ae1-a1f5-ad994ba5c87d',\n",
" title = title,\n",
" content = title \n",
" agent_id=\"847b03b1-856a-4ae1-a1f5-ad994ba5c87d\", title=title, content=title\n",
" )"
]
},
Expand All @@ -3080,7 +3079,7 @@
}
],
"source": [
"client.agents.docs.list(agent_id = '847b03b1-856a-4ae1-a1f5-ad994ba5c87d')"
"client.agents.docs.list(agent_id=\"847b03b1-856a-4ae1-a1f5-ad994ba5c87d\")"
]
},
{
Expand Down Expand Up @@ -3204,9 +3203,9 @@
"source": [
"# Creating/Updating a task\n",
"task1 = client.tasks.create_or_update(\n",
" task_id = '825a03b1-856a-4ae1-a1f5-ad994ba5c87d',\n",
" agent_id = '865e03b1-856a-4ae1-a1f5-ad994ba5c87d',\n",
" **task_def1\n",
" task_id=\"825a03b1-856a-4ae1-a1f5-ad994ba5c87d\",\n",
" agent_id=\"865e03b1-856a-4ae1-a1f5-ad994ba5c87d\",\n",
" **task_def1,\n",
")"
]
},
Expand Down Expand Up @@ -3257,17 +3256,16 @@
"source": [
"execution_array = []\n",
"for i in tqdm(range(len(df))):\n",
" ppid = str(df.iloc[i]['ppid']) \n",
" \n",
" ppid = str(df.iloc[i][\"ppid\"])\n",
"\n",
" # Creating an Execution\n",
" execution = client.executions.create(\n",
" task_id= '825a03b1-856a-4ae1-a1f5-ad994ba5c87d',\n",
" input = {\n",
" 'user_ppid': ppid,\n",
" }\n",
" task_id=\"825a03b1-856a-4ae1-a1f5-ad994ba5c87d\",\n",
" input={\n",
" \"user_ppid\": ppid,\n",
" },\n",
" )\n",
" execution_array.append(execution.id)\n",
" "
" execution_array.append(execution.id)"
]
},
{
Expand Down Expand Up @@ -4309,7 +4307,7 @@
}
],
"source": [
"client.executions.get(execution_id = 'caed8408-235c-4a36-91fe-0559b91ece8d')"
"client.executions.get(execution_id=\"caed8408-235c-4a36-91fe-0559b91ece8d\")"
]
},
{
Expand Down
Loading

0 comments on commit ea49775

Please sign in to comment.