Skip to content

Commit

Permalink
Small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Amnah199 committed Sep 19, 2024
1 parent 1bdfdb3 commit fee4a70
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -341,16 +341,15 @@ def _get_stream_response(
:returns: The extracted response with the content of all streaming chunks.
"""
replies: Union[List[str], List[ChatMessage]] = []
metadata = stream.to_dict()

for candidate in metadata.get("candidates", []):
candidate.pop("content", None)

for chunk in stream:
metadata = chunk.to_dict()
for candidate in chunk.candidates:
for part in candidate.content.parts:
if part.text != "":
replies.append(part.text)
replies.append(
ChatMessage(content=part.text, role=ChatRole.ASSISTANT, meta=metadata, name=None)
)
elif part.function_call is not None:
metadata["function_call"] = part.function_call
replies.append(
Expand All @@ -363,9 +362,4 @@ def _get_stream_response(
)

streaming_callback(StreamingChunk(content=part.text, meta=chunk.to_dict()))

if isinstance(replies[0], ChatMessage):
return replies

combined_response = "".join(replies).lstrip()
return [ChatMessage.from_assistant(content=combined_response)]
return replies
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001
response = gemini_chat.run(messages=messages)
assert "replies" in response
assert len(response["replies"]) > 0
assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])
assert all(reply.role == ChatRole.SYSTEM for reply in response["replies"])


@pytest.mark.skipif(not os.environ.get("GOOGLE_API_KEY", None), reason="GOOGLE_API_KEY env var not set")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,6 @@ def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]:
replies = []
for candidate in response_body.candidates:
metadata = candidate.to_dict()
metadata.pop("content")
for part in candidate.content.parts:
if part._raw_part.text != "":
replies.append(
Expand Down Expand Up @@ -260,11 +259,22 @@ def _get_stream_response(
:param streaming_callback: The handler for the streaming response.
:returns: The extracted response with the content of all streaming chunks.
"""
responses = []
replies = []
for chunk in stream:
metadata = chunk.to_dict()
streaming_chunk = StreamingChunk(content=chunk.text, meta=chunk.to_dict())
streaming_callback(streaming_chunk)
responses.append(streaming_chunk.content)

combined_response = "".join(responses).lstrip()
return [ChatMessage.from_assistant(content=combined_response)]
if chunk.text != "":
replies.append(ChatMessage(chunk.text, role=ChatRole.ASSISTANT, name=None, meta=metadata))
elif chunk.function_call is not None:
metadata["function_call"] = chunk.function_call
replies.append(
ChatMessage(
content=dict(chunk.function_call.args.items()),
role=ChatRole.ASSISTANT,
name=chunk.function_call.name,
meta=metadata,
)
)
return replies
2 changes: 0 additions & 2 deletions integrations/google_vertex/tests/chat/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,6 @@ def streaming_callback(chunk: StreamingChunk) -> None:
assert streaming_callback_called == ["First part", " Second part"]
assert "replies" in response
assert len(response["replies"]) > 0

assert response["replies"][0].content == "First part Second part"
assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])


Expand Down

0 comments on commit fee4a70

Please sign in to comment.