diff --git a/apps/service_providers/llm_service/runnables.py b/apps/service_providers/llm_service/runnables.py index 4b24e0178..76d821c18 100644 --- a/apps/service_providers/llm_service/runnables.py +++ b/apps/service_providers/llm_service/runnables.py @@ -334,7 +334,8 @@ def _save_response_annotations(self, output, thread_id, run_id) -> tuple[str, di from apps.assistants.sync import get_and_store_openai_file client = self.state.raw_client - generated_files = [] + chat = self.state.session.chat + session_id = self.state.session.id if isinstance(output, str): output_message = output @@ -350,11 +351,15 @@ def _save_response_annotations(self, output, thread_id, run_id) -> tuple[str, di ).values_list("external_id", flat=True) file_ids = set() + image_file_attachments = [] + file_path_attachments = [] for message in client.beta.threads.messages.list(thread_id, run_id=run_id): for message_content in message.content: if message_content.type == "image_file": - # Ignore these for now. Typically, they are also referenced in the text content - pass + if created_file := self._create_image_file_from_image_message(client, message_content.image_file): + image_file_attachments.append(created_file) + file_ids.add(created_file.external_id) + elif message_content.type == "text": annotations = message_content.text.annotations for idx, annotation in enumerate(annotations): @@ -382,22 +387,56 @@ def _save_response_annotations(self, output, thread_id, run_id) -> tuple[str, di # Original citation text example: sandbox:/mnt/data/the_file.csv. # This is the link part in what looks like # [Download the CSV file](sandbox:/mnt/data/the_file.csv) - session_id = self.state.session.id output_message = output_message.replace( file_ref_text, f"file:{team.slug}:{session_id}:{created_file.id}" ) - generated_files.append(created_file) - + file_path_attachments.append(created_file) file_ids.add(file_id) + else: + # Ignore any other type for now + pass # Attach the generated files to the chat object as an annotation - if generated_files: - chat = self.state.session.chat + if file_path_attachments: resource, _created = chat.attachments.get_or_create(tool_type="file_path") - resource.files.add(*generated_files) + resource.files.add(*file_path_attachments) + + if image_file_attachments: + resource, _created = chat.attachments.get_or_create(tool_type="image_file") + resource.files.add(*image_file_attachments) return output_message, list(file_ids) + def _create_image_file_from_image_message(self, client, image_file_message) -> File | None: + """ + Creates a File record from `image_file_message` by pulling the data from OpenAI. Typically, these files don't + have extentions, so we'll need to guess it based on the content. We know it will be an image, but not which + extention to use. + """ + from mimetypes import guess_extension + + import magic + + from apps.assistants.sync import get_and_store_openai_file + + try: + file_id = image_file_message.file_id + openai_file = client.files.retrieve(file_id=file_id) + created_file = get_and_store_openai_file( + client=client, + file_name=f"{openai_file.filename}", + file_id=file_id, + team_id=self.state.experiment.team_id, + ) + mimetype = magic.from_buffer(created_file.file.open().read(), mime=True) + extention = guess_extension(mimetype) + # extention looks like '.png' + created_file.name = f"{created_file.name}{extention}" + created_file.save() + return created_file + except Exception as ex: + logger.exception(ex) + def _get_file_name_and_link_for_citation(self, file_id: str, forbidden_file_ids: list[str]) -> tuple[str, str]: """Returns a file name and a link constructor for `file_id`. If `file_id` is a member of `forbidden_file_ids`, the link will be empty to prevent unauthorized access. diff --git a/apps/service_providers/tests/test_assistant_runnable.py b/apps/service_providers/tests/test_assistant_runnable.py index 4f494d110..ef2795c9a 100644 --- a/apps/service_providers/tests/test_assistant_runnable.py +++ b/apps/service_providers/tests/test_assistant_runnable.py @@ -426,6 +426,51 @@ def test_assistant_response_with_annotations( assert "openai-file-2" in message.metadata["openai_file_ids"] +@pytest.mark.django_db() +@patch("openai.resources.files.Files.retrieve") +@patch("apps.assistants.sync.get_and_store_openai_file") +@patch("openai.resources.beta.threads.runs.Runs.retrieve") +@patch("openai.resources.beta.Threads.create_and_run") +@patch("openai.resources.beta.threads.messages.Messages.list") +def test_assistant_response_with_image_file_content_block( + list_messages, + create_and_run, + retrieve_run, + get_and_store_openai_file, + retrieve_openai_file, + db_session, +): + """ + Test that ImageFileContentBlock entries in the content array in an OpenAI message response saves the file to a new + "image_file" tool type + """ + retrieve_openai_file.return_value = FileObject( + id="local_file_openai_id", + bytes=1, + created_at=1, + filename="3fac0517-6367-4f92-a1f3-c9d9087c9085", + object="file", + purpose="assistants", + status="processed", + status_details=None, + ) + openai_generated_file = FileFactory(external_id="openai-file-1", id=10) + get_and_store_openai_file.return_value = openai_generated_file + + thread_id = "test_thread_id" + run = _create_run(ASSISTANT_ID, thread_id) + list_messages.return_value = _create_thread_messages(ASSISTANT_ID, run.id, thread_id, [{"assistant": "Ola"}]) + create_and_run.return_value = run + retrieve_run.return_value = run + assistant = create_experiment_runnable(db_session.experiment, db_session) + + # Run assistant + result = assistant.invoke("test", attachments=[]) + assert result.output == "Ola" + assert db_session.chat.attachments.filter(tool_type="image_file").exists() is True + assert db_session.chat.attachments.get(tool_type="image_file").files.count() == 1 + + @pytest.mark.parametrize( ("messages", "thread_id", "thread_created", "messages_created"), [ diff --git a/requirements/requirements.in b/requirements/requirements.in index cd9bb926a..2b3b6062f 100644 --- a/requirements/requirements.in +++ b/requirements/requirements.in @@ -53,3 +53,4 @@ twilio whitenoise[brotli] phonenumberslite emoji +python-magic diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 5e0060143..6a008d4b1 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -18,7 +18,7 @@ annotated-types==0.6.0 # via pydantic anthropic==0.25.2 # via - # -r requirements.in + # -r requirements/requirements.in # langchain-anthropic anyio==4.4.0 # via @@ -28,6 +28,8 @@ anyio==4.4.0 # openai asgiref==3.8.1 # via django +async-timeout==4.0.3 + # via redis attrs==23.1.0 # via # aiohttp @@ -35,14 +37,14 @@ attrs==23.1.0 # referencing # taskbadger azure-cognitiveservices-speech==1.32.1 - # via -r requirements.in + # via -r requirements/requirements.in backoff==2.2.1 # via langfuse billiard==4.2.0 # via celery boto3==1.28.85 # via - # -r requirements.in + # -r requirements/requirements.in # django-storages botocore==1.31.85 # via @@ -52,10 +54,10 @@ brotli==1.1.0 # via whitenoise celery[redis]==5.3.5 # via - # -r requirements.in + # -r requirements/requirements.in # django-celery-beat celery-progress==0.3 - # via -r requirements.in + # via -r requirements/requirements.in certifi==2024.7.4 # via # httpcore @@ -102,7 +104,7 @@ distro==1.8.0 # openai django==5.1 # via - # -r requirements.in + # -r requirements/requirements.in # django-allauth # django-allauth-2fa # django-anymail @@ -123,56 +125,56 @@ django==5.1 # drf-spectacular django-allauth==0.58.2 # via - # -r requirements.in + # -r requirements/requirements.in # django-allauth-2fa django-allauth-2fa==0.11.1 - # via -r requirements.in + # via -r requirements/requirements.in django-anymail==10.2 - # via -r requirements.in + # via -r requirements/requirements.in django-appconf==1.0.5 # via django-cryptography-django5 django-celery-beat==2.7.0 - # via -r requirements.in + # via -r requirements/requirements.in django-cryptography-django5==2.2 - # via -r requirements.in + # via -r requirements/requirements.in django-environ==0.11.2 - # via -r requirements.in + # via -r requirements/requirements.in django-field-audit==1.2.8 - # via -r requirements.in + # via -r requirements/requirements.in django-health-check==3.18.3 - # via -r requirements.in + # via -r requirements/requirements.in django-hijack==3.4.2 - # via -r requirements.in + # via -r requirements/requirements.in django-otp==1.3.0 # via django-allauth-2fa django-redis==5.4.0 - # via -r requirements.in + # via -r requirements/requirements.in django-storages[s3]==1.14.2 - # via -r requirements.in + # via -r requirements/requirements.in django-tables2==2.6.0 - # via -r requirements.in + # via -r requirements/requirements.in django-taggit==5.0.1 - # via -r requirements.in + # via -r requirements/requirements.in django-timezone-field==7.0 # via django-celery-beat django-tz-detect==0.5.0 - # via -r requirements.in + # via -r requirements/requirements.in django-waffle==4.0.0 - # via -r requirements.in + # via -r requirements/requirements.in djangorestframework==3.15.2 # via - # -r requirements.in + # -r requirements/requirements.in # drf-spectacular djangorestframework-api-key==3.0.0 - # via -r requirements.in + # via -r requirements/requirements.in drf-spectacular==0.26.5 - # via -r requirements.in + # via -r requirements/requirements.in emoji==2.12.1 - # via -r requirements.in + # via -r requirements/requirements.in fbmessenger==6.0.0 - # via -r requirements.in + # via -r requirements/requirements.in ffmpeg==1.4 - # via -r requirements.in + # via -r requirements/requirements.in filelock==3.13.1 # via # huggingface-hub @@ -191,7 +193,7 @@ httpcore==0.17.3 # via httpx httpx==0.24.1 # via - # -r requirements.in + # -r requirements/requirements.in # anthropic # langfuse # openai @@ -210,7 +212,7 @@ idna==3.7 inflection==0.5.1 # via drf-spectacular jinja2==3.1.4 - # via -r requirements.in + # via -r requirements/requirements.in jmespath==1.0.1 # via # boto3 @@ -228,9 +230,9 @@ jsonschema-specifications==2023.7.1 kombu==5.3.3 # via celery langchain==0.1.16 - # via -r requirements.in + # via -r requirements/requirements.in langchain-anthropic==0.1.8 - # via -r requirements.in + # via -r requirements/requirements.in langchain-community==0.0.32 # via langchain langchain-core==0.1.42 @@ -242,22 +244,22 @@ langchain-core==0.1.42 # langchain-text-splitters # langgraph langchain-openai==0.1.3 - # via -r requirements.in + # via -r requirements/requirements.in langchain-text-splitters==0.0.1 # via langchain langfuse==2.43.3 - # via -r requirements.in + # via -r requirements/requirements.in langgraph==0.0.38 - # via -r requirements.in + # via -r requirements/requirements.in langsmith==0.1.47 # via # langchain # langchain-community # langchain-core loguru==0.7.2 - # via -r requirements.in + # via -r requirements/requirements.in markdown==3.5.1 - # via -r requirements.in + # via -r requirements/requirements.in markdown-it-py==3.0.0 # via rich markupsafe==2.1.5 @@ -282,7 +284,7 @@ oauthlib==3.2.2 # via requests-oauthlib openai==1.23.6 # via - # -r requirements.in + # -r requirements/requirements.in # langchain-openai orjson==3.10.0 # via langsmith @@ -295,20 +297,20 @@ packaging==23.2 # marshmallow # transformers pandas==2.1.3 - # via -r requirements.in + # via -r requirements/requirements.in phonenumberslite==8.13.40 - # via -r requirements.in + # via -r requirements/requirements.in prompt-toolkit==3.0.41 # via click-repl psycopg[binary]==3.2.1 - # via -r requirements.in + # via -r requirements/requirements.in psycopg-binary==3.2.1 # via psycopg pycparser==2.21 # via cffi pydantic==2.5.0 # via - # -r requirements.in + # -r requirements/requirements.in # anthropic # langchain # langchain-core @@ -318,7 +320,7 @@ pydantic==2.5.0 pydantic-core==2.14.1 # via pydantic pydub==0.25.1 - # via -r requirements.in + # via -r requirements/requirements.in pygments==2.16.1 # via rich pyjwt[crypto]==2.8.0 @@ -328,7 +330,7 @@ pyjwt[crypto]==2.8.0 pypng==0.20220715.0 # via qrcode pytelegrambotapi==4.12.0 - # via -r requirements.in + # via -r requirements/requirements.in python-crontab==3.0.0 # via django-celery-beat python-dateutil==2.8.2 @@ -338,6 +340,8 @@ python-dateutil==2.8.2 # pandas # python-crontab # taskbadger +python-magic==0.4.27 + # via -r requirements/requirements.in python3-openid==3.2.0 # via django-allauth pytz==2023.3.post1 @@ -393,13 +397,13 @@ s3transfer==0.7.0 safetensors==0.4.3 # via transformers sentry-sdk==2.8.0 - # via -r requirements.in + # via -r requirements/requirements.in shellingham==1.5.4 # via typer six==1.16.0 # via python-dateutil slack-bolt==1.18.1 - # via -r requirements.in + # via -r requirements/requirements.in slack-sdk==3.27.2 # via slack-bolt sniffio==1.3.0 @@ -416,16 +420,16 @@ sqlalchemy==2.0.23 sqlparse==0.5.0 # via django taskbadger==1.3.3 - # via -r requirements.in + # via -r requirements/requirements.in tenacity==8.2.3 # via - # -r requirements.in + # -r requirements/requirements.in # langchain # langchain-community # langchain-core tiktoken==0.7.0 # via - # -r requirements.in + # -r requirements/requirements.in # langchain-openai tokenizers==0.15.0 # via @@ -439,11 +443,11 @@ tqdm==4.66.3 # openai # transformers transformers==4.39.3 - # via -r requirements.in + # via -r requirements/requirements.in turn-python==0.2.0 - # via -r requirements.in + # via -r requirements/requirements.in twilio==8.10.1 - # via -r requirements.in + # via -r requirements/requirements.in typer[all]==0.9.0 # via taskbadger typing-extensions==4.8.0 @@ -482,7 +486,7 @@ vine==5.1.0 wcwidth==0.2.10 # via prompt-toolkit whitenoise[brotli]==6.6.0 - # via -r requirements.in + # via -r requirements/requirements.in wrapt==1.16.0 # via langfuse yarl==1.9.2 diff --git a/templates/experiments/chat/ai_message.html b/templates/experiments/chat/ai_message.html index 2a34bc452..e9d4c79f2 100644 --- a/templates/experiments/chat/ai_message.html +++ b/templates/experiments/chat/ai_message.html @@ -17,5 +17,22 @@

{{ message.content|render_markdown }}

+
+ {% for file in message.get_attached_files %} +
+ {% if request.user.is_authenticated %} + + {{ file.name }} + + {% else %} +
+ {{ file.name }} +
+ {% endif %} +
+ {% endfor %} +