Skip to content

Commit

Permalink
AssistantAgent no longer sends out StopMessage. We use TextMentionTer…
Browse files Browse the repository at this point in the history
…mination(TERMINATE) on the team instead for default setting. (microsoft#4030)
  • Loading branch information
frances720 committed Nov 4, 2024
1 parent 173acc6 commit 0cadeb6
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
HandoffMessage,
InnerMessage,
ResetMessage,
StopMessage,
TextMessage,
ToolCallMessage,
ToolCallResultMessage,
Expand Down Expand Up @@ -232,8 +231,8 @@ def __init__(
def produced_message_types(self) -> List[type[ChatMessage]]:
"""The types of messages that the assistant agent produces."""
if self._handoffs:
return [TextMessage, HandoffMessage, StopMessage]
return [TextMessage, StopMessage]
return [TextMessage, HandoffMessage]
return [TextMessage]

async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
async for message in self.on_messages_stream(messages, cancellation_token):
Expand Down Expand Up @@ -303,16 +302,9 @@ async def on_messages_stream(
self._model_context.append(AssistantMessage(content=result.content, source=self.name))

assert isinstance(result.content, str)
# Detect stop request.
request_stop = "terminate" in result.content.strip().lower()
if request_stop:
yield Response(
chat_message=StopMessage(content=result.content, source=self.name), inner_messages=inner_messages
)
else:
yield Response(
chat_message=TextMessage(content=result.content, source=self.name), inner_messages=inner_messages
)
yield Response(
chat_message=TextMessage(content=result.content, source=self.name), inner_messages=inner_messages
)

async def _execute_tool_call(
self, tool_call: FunctionCall, cancellation_token: CancellationToken
Expand Down
28 changes: 14 additions & 14 deletions python/packages/autogen-agentchat/tests/test_group_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
ToolCallMessage,
ToolCallResultMessage,
)
from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination
from autogen_agentchat.task import MaxMessageTermination, TextMentionTermination
from autogen_agentchat.teams import (
RoundRobinGroupChat,
SelectorGroupChat,
Expand Down Expand Up @@ -151,7 +151,7 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
team = RoundRobinGroupChat(participants=[coding_assistant_agent, code_executor_agent])
result = await team.run(
"Write a program that prints 'Hello, world!'",
termination_condition=StopMessageTermination(),
termination_condition=TextMentionTermination("TERMINATE"),
)
expected_messages = [
"Write a program that prints 'Hello, world!'",
Expand All @@ -172,7 +172,7 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
mock.reset()
index = 0
async for message in team.run_stream(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
"Write a program that prints 'Hello, world!'", termination_condition=TextMentionTermination("TERMINATE")
):
if isinstance(message, TaskResult):
assert message == result
Expand Down Expand Up @@ -247,7 +247,7 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
team = RoundRobinGroupChat(participants=[tool_use_agent, echo_agent])
result = await team.run(
"Write a program that prints 'Hello, world!'",
termination_condition=StopMessageTermination(),
termination_condition=TextMentionTermination("TERMINATE"),
)

assert len(result.messages) == 6
Expand All @@ -256,7 +256,7 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
assert isinstance(result.messages[2], ToolCallResultMessage) # tool call result
assert isinstance(result.messages[3], TextMessage) # tool use agent response
assert isinstance(result.messages[4], TextMessage) # echo agent response
assert isinstance(result.messages[5], StopMessage) # tool use agent response
assert isinstance(result.messages[5], TextMessage) # tool use agent response

context = tool_use_agent._model_context # pyright: ignore
assert context[0].content == "Write a program that prints 'Hello, world!'"
Expand All @@ -275,7 +275,7 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
mock.reset()
index = 0
async for message in team.run_stream(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
"Write a program that prints 'Hello, world!'", termination_condition=TextMentionTermination("TERMINATE")
):
if isinstance(message, TaskResult):
assert message == result
Expand Down Expand Up @@ -351,7 +351,7 @@ async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
)
result = await team.run(
"Write a program that prints 'Hello, world!'",
termination_condition=StopMessageTermination(),
termination_condition=TextMentionTermination("TERMINATE"),
)
assert len(result.messages) == 6
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
Expand All @@ -366,7 +366,7 @@ async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
agent1._count = 0 # pyright: ignore
index = 0
async for message in team.run_stream(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
"Write a program that prints 'Hello, world!'", termination_condition=TextMentionTermination("TERMINATE")
):
if isinstance(message, TaskResult):
assert message == result
Expand Down Expand Up @@ -401,7 +401,7 @@ async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch)
)
result = await team.run(
"Write a program that prints 'Hello, world!'",
termination_condition=StopMessageTermination(),
termination_condition=TextMentionTermination("TERMINATE"),
)
assert len(result.messages) == 5
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
Expand All @@ -417,7 +417,7 @@ async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch)
agent1._count = 0 # pyright: ignore
index = 0
async for message in team.run_stream(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
"Write a program that prints 'Hello, world!'", termination_condition=TextMentionTermination("TERMINATE")
):
if isinstance(message, TaskResult):
assert message == result
Expand Down Expand Up @@ -472,7 +472,7 @@ async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pyte
allow_repeated_speaker=True,
)
result = await team.run(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
"Write a program that prints 'Hello, world!'", termination_condition=TextMentionTermination("TERMINATE")
)
assert len(result.messages) == 4
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
Expand All @@ -484,7 +484,7 @@ async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pyte
mock.reset()
index = 0
async for message in team.run_stream(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
"Write a program that prints 'Hello, world!'", termination_condition=TextMentionTermination("TERMINATE")
):
if isinstance(message, TaskResult):
assert message == result
Expand Down Expand Up @@ -649,7 +649,7 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -
)
agent2 = _HandOffAgent("agent2", description="agent 2", next_agent="agent1")
team = Swarm([agent1, agent2])
result = await team.run("task", termination_condition=StopMessageTermination())
result = await team.run("task", termination_condition=TextMentionTermination("TERMINATE"))
assert len(result.messages) == 7
assert result.messages[0].content == "task"
assert isinstance(result.messages[1], ToolCallMessage)
Expand All @@ -663,7 +663,7 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -
agent1._model_context.clear() # pyright: ignore
mock.reset()
index = 0
stream = team.run_stream("task", termination_condition=StopMessageTermination())
stream = team.run_stream("task", termination_condition=TextMentionTermination("TERMINATE"))
async for message in stream:
if isinstance(message, TaskResult):
assert message == result
Expand Down

0 comments on commit 0cadeb6

Please sign in to comment.