Skip to content

Commit

Permalink
Merge pull request #46 from Fallenbagel/fix-photo-processing-from-groups
Browse files Browse the repository at this point in the history
fix: image processing from groups/super-groups
  • Loading branch information
ruecat committed May 3, 2024
2 parents 279fac5 + 5d9cd70 commit 43b6498
Showing 1 changed file with 59 additions and 34 deletions.
93 changes: 59 additions & 34 deletions bot/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
from aiogram.types import Message
from aiogram.utils.keyboard import InlineKeyboardBuilder
from func.functions import *

# Other
import asyncio
import traceback
import io
import base64

bot = Bot(token=token)
dp = Dispatcher()
builder = InlineKeyboardBuilder()
Expand All @@ -33,15 +35,19 @@
CHAT_TYPE_GROUP = "group"
CHAT_TYPE_SUPERGROUP = "supergroup"


def is_mentioned_in_group_or_supergroup(message):
return (message.chat.type in [CHAT_TYPE_GROUP, CHAT_TYPE_SUPERGROUP]
and message.text.startswith(mention))
return message.chat.type in [CHAT_TYPE_GROUP, CHAT_TYPE_SUPERGROUP] and (
(message.text is not None and message.text.startswith(mention))
or (message.caption is not None and message.caption.startswith(mention))
)


async def get_bot_info():
global mention
if mention is None:
get = await bot.get_me()
mention = (f"@{get.username}")
mention = f"@{get.username}"
return mention


Expand Down Expand Up @@ -102,7 +108,9 @@ async def modelmanager_callback_handler(query: types.CallbackQuery):
if model["details"]["families"]:
modelicon = {"llama": "🦙", "clip": "📷"}
try:
modelfamilies = "".join([modelicon[family] for family in model['details']['families']])
modelfamilies = "".join(
[modelicon[family] for family in model["details"]["families"]]
)
except KeyError as e:
# Use a default value when the key is not found
modelfamilies = f"✨"
Expand All @@ -113,7 +121,8 @@ async def modelmanager_callback_handler(query: types.CallbackQuery):
)
)
await query.message.edit_text(
f"{len(models)} models available.\n🦙 = Regular\n🦙📷 = Multimodal", reply_markup=modelmanager_builder.as_markup()
f"{len(models)} models available.\n🦙 = Regular\n🦙📷 = Multimodal",
reply_markup=modelmanager_builder.as_markup(),
)


Expand Down Expand Up @@ -147,30 +156,32 @@ async def handle_message(message: types.Message):
await ollama_request(message)
if is_mentioned_in_group_or_supergroup(message):
# Remove the mention from the message
text_without_mention = message.text.replace(mention, "").strip()
if message.text is not None:
text_without_mention = message.text.replace(mention, "").strip()
prompt = text_without_mention
else:
text_without_mention = message.caption.replace(mention, "").strip()
prompt = text_without_mention

# Pass the modified text and bot instance to ollama_request
await ollama_request(types.Message(
message_id=message.message_id,
from_user=message.from_user,
date=message.date,
chat=message.chat,
text=text_without_mention
))
await ollama_request(message, prompt)


...
async def ollama_request(message: types.Message):


async def ollama_request(message: types.Message, prompt: str = None):
try:
await bot.send_chat_action(message.chat.id, "typing")
prompt = message.text or message.caption
image_base64 = ''
if message.content_type == 'photo':
image_base64 = ""
if message.content_type == "photo":
image_buffer = io.BytesIO()
await bot.download(
message.photo[-1],
destination=image_buffer
)
image_base64 = base64.b64encode(image_buffer.getvalue()).decode('utf-8')
await bot.download(message.photo[-1], destination=image_buffer)
image_base64 = base64.b64encode(image_buffer.getvalue()).decode("utf-8")

if prompt is None:
prompt = message.text or message.caption

full_response = ""
sent_message = None
last_sent_text = None
Expand All @@ -180,12 +191,22 @@ async def ollama_request(message: types.Message):
if ACTIVE_CHATS.get(message.from_user.id) is None:
ACTIVE_CHATS[message.from_user.id] = {
"model": modelname,
"messages": [{"role": "user", "content": prompt, "images": ([image_base64] if image_base64 else [])}],
"messages": [
{
"role": "user",
"content": prompt,
"images": ([image_base64] if image_base64 else []),
}
],
"stream": True,
}
else:
ACTIVE_CHATS[message.from_user.id]["messages"].append(
{"role": "user", "content": prompt, "images": ([image_base64] if image_base64 else [])}
{
"role": "user",
"content": prompt,
"images": ([image_base64] if image_base64 else []),
}
)
logging.info(
f"[Request]: Processing '{prompt}' for {message.from_user.first_name} {message.from_user.last_name}"
Expand All @@ -206,8 +227,11 @@ async def ollama_request(message: types.Message):
if "." in chunk or "\n" in chunk or "!" in chunk or "?" in chunk:
if sent_message:
if last_sent_text != full_response_stripped:
await bot.edit_message_text(chat_id=message.chat.id, message_id=sent_message.message_id,
text=full_response_stripped)
await bot.edit_message_text(
chat_id=message.chat.id,
message_id=sent_message.message_id,
text=full_response_stripped,
)
last_sent_text = full_response_stripped
else:
sent_message = await bot.send_message(
Expand All @@ -218,16 +242,17 @@ async def ollama_request(message: types.Message):
last_sent_text = full_response_stripped

if response_data.get("done"):
if (
full_response_stripped
and last_sent_text != full_response_stripped
):
if full_response_stripped and last_sent_text != full_response_stripped:
if sent_message:
await bot.edit_message_text(chat_id=message.chat.id, message_id=sent_message.message_id,
text=full_response_stripped)
await bot.edit_message_text(
chat_id=message.chat.id,
message_id=sent_message.message_id,
text=full_response_stripped,
)
else:
sent_message = await bot.send_message(chat_id=message.chat.id,
text=full_response_stripped)
sent_message = await bot.send_message(
chat_id=message.chat.id, text=full_response_stripped
)
await bot.edit_message_text(
chat_id=message.chat.id,
message_id=sent_message.message_id,
Expand Down

0 comments on commit 43b6498

Please sign in to comment.