From b0d6fd52377abcb9378d74f2fa051145280f07e6 Mon Sep 17 00:00:00 2001 From: Simon Kelly Date: Thu, 26 Sep 2024 14:46:55 +0200 Subject: [PATCH 1/2] handle image file types from openai [DIMAGI-BOTS-F7](https://dimagi.sentry.io/issues/5907069447/) --- .../llm_service/runnables.py | 69 ++++++++++--------- 1 file changed, 37 insertions(+), 32 deletions(-) diff --git a/apps/service_providers/llm_service/runnables.py b/apps/service_providers/llm_service/runnables.py index 40106bd1e..8532c4dc4 100644 --- a/apps/service_providers/llm_service/runnables.py +++ b/apps/service_providers/llm_service/runnables.py @@ -349,38 +349,43 @@ def _save_response_annotations(self, output, thread_id, run_id) -> tuple[str, di file_ids = set() for message in client.beta.threads.messages.list(thread_id, run_id=run_id): for message_content in message.content: - annotations = message_content.text.annotations - for idx, annotation in enumerate(annotations): - file_id = None - file_ref_text = annotation.text - if annotation.type == "file_citation": - file_citation = annotation.file_citation - file_id = file_citation.file_id - file_name, file_link = self._get_file_name_and_link_for_citation( - file_id=file_id, forbidden_file_ids=assistant_files_ids - ) - - # Original citation text example:【6:0†source】 - output_message = output_message.replace(file_ref_text, f" [{file_name}]({file_link})") - - elif annotation.type == "file_path": - file_path = annotation.file_path - file_id = file_path.file_id - created_file = get_and_store_openai_file( - client=client, - file_name=annotation.text.split("/")[-1], - file_id=file_id, - team_id=self.state.experiment.team_id, - ) - # Original citation text example: sandbox:/mnt/data/the_file.csv. This is the link part in what - # looks like [Download the CSV file](sandbox:/mnt/data/the_file.csv) - session_id = self.state.session.id - output_message = output_message.replace( - file_ref_text, f"file:{team.slug}:{session_id}:{created_file.id}" - ) - generated_files.append(created_file) - - file_ids.add(file_id) + if message_content.type == "image_file": + # Ignore these for now. Typically, they are also referenced in the text content + pass + elif message_content.type == "text": + annotations = message_content.text.annotations + for idx, annotation in enumerate(annotations): + file_id = None + file_ref_text = annotation.text + if annotation.type == "file_citation": + file_citation = annotation.file_citation + file_id = file_citation.file_id + file_name, file_link = self._get_file_name_and_link_for_citation( + file_id=file_id, forbidden_file_ids=assistant_files_ids + ) + + # Original citation text example:【6:0†source】 + output_message = output_message.replace(file_ref_text, f" [{file_name}]({file_link})") + + elif annotation.type == "file_path": + file_path = annotation.file_path + file_id = file_path.file_id + created_file = get_and_store_openai_file( + client=client, + file_name=annotation.text.split("/")[-1], + file_id=file_id, + team_id=self.state.experiment.team_id, + ) + # Original citation text example: sandbox:/mnt/data/the_file.csv. + # This is the link part in what looks like + # [Download the CSV file](sandbox:/mnt/data/the_file.csv) + session_id = self.state.session.id + output_message = output_message.replace( + file_ref_text, f"file:{team.slug}:{session_id}:{created_file.id}" + ) + generated_files.append(created_file) + + file_ids.add(file_id) # Attach the generated files to the chat object as an annotation if generated_files: From aa91862b589b4ab1b5959bba0ae32adbcbf7d1c1 Mon Sep 17 00:00:00 2001 From: Simon Kelly Date: Thu, 26 Sep 2024 15:27:31 +0200 Subject: [PATCH 2/2] update tests --- apps/service_providers/llm_service/runnables.py | 7 +++++-- apps/service_providers/tests/test_assistant_runnable.py | 8 ++++++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/apps/service_providers/llm_service/runnables.py b/apps/service_providers/llm_service/runnables.py index 8532c4dc4..4b24e0178 100644 --- a/apps/service_providers/llm_service/runnables.py +++ b/apps/service_providers/llm_service/runnables.py @@ -336,8 +336,11 @@ def _save_response_annotations(self, output, thread_id, run_id) -> tuple[str, di client = self.state.raw_client generated_files = [] - # This output is a concatanation of all messages in this run - output_message = output + if isinstance(output, str): + output_message = output + else: + output_message = "\n".join(content.text.value for content in output if content.type == "text") + team = self.state.session.team assistant_file_ids = ToolResources.objects.filter(assistant=self.state.experiment.assistant).values_list( "files" diff --git a/apps/service_providers/tests/test_assistant_runnable.py b/apps/service_providers/tests/test_assistant_runnable.py index 4b52d3e5a..4f494d110 100644 --- a/apps/service_providers/tests/test_assistant_runnable.py +++ b/apps/service_providers/tests/test_assistant_runnable.py @@ -5,8 +5,8 @@ import openai import pytest +from openai.types.beta.threads import ImageFile, ImageFileContentBlock, Run from openai.types.beta.threads import Message as ThreadMessage -from openai.types.beta.threads import Run from openai.types.beta.threads.file_citation_annotation import FileCitation, FileCitationAnnotation from openai.types.beta.threads.file_path_annotation import FilePath, FilePathAnnotation from openai.types.beta.threads.text import Text @@ -493,10 +493,14 @@ def _create_thread_messages( metadata={}, created_at=0, content=[ + ImageFileContentBlock( + type="image_file", + image_file=ImageFile(file_id="test_file_id"), + ), TextContentBlock( text=Text(annotations=annotations if annotations else [], value=list(message.values())[0]), type="text", - ) + ), ], object="thread.message", role=list(message)[0],