From 3497f963fbcab8e93486f6ec2d07519fb1ebdc9a Mon Sep 17 00:00:00 2001 From: xds Date: Tue, 3 Feb 2026 09:14:45 +0300 Subject: [PATCH] + fixes --- middlewares/album.py | 48 +++++++++--------- models/Character.py | 5 +- repos/char_repo.py | 6 ++- routers/char_router.py | 34 ++++++++----- routers/gen_router.py | 110 ++++++++++++++++++++++++++--------------- 5 files changed, 122 insertions(+), 81 deletions(-) diff --git a/middlewares/album.py b/middlewares/album.py index b471316..fbc12a6 100644 --- a/middlewares/album.py +++ b/middlewares/album.py @@ -6,7 +6,6 @@ from aiogram.types import Message class AlbumMiddleware(BaseMiddleware): def __init__(self, latency: float = 0.5): - # latency - задержка в секундах для сбора частей альбома self.latency = latency self.album_data: Dict[str, List[Message]] = {} @@ -16,36 +15,33 @@ class AlbumMiddleware(BaseMiddleware): event: Message, data: Dict[str, Any] ) -> Any: - # Если у сообщения нет media_group_id, это не альбом -> пропускаем дальше как обычно if not event.media_group_id: return await handler(event, data) group_id = event.media_group_id - try: - # Если этот альбом мы еще не видели (первое сообщение из пачки) - if group_id not in self.album_data: - self.album_data[group_id] = [event] # Создаем список - await asyncio.sleep(self.latency) # Ждем остальные части + # Если это первое сообщение группы + if group_id not in self.album_data: + self.album_data[group_id] = [event] + try: + # Ждем сбора остальных частей + await asyncio.sleep(self.latency) - # После ожидания кладем собранный список в data - # Теперь в хендлере будет доступен аргумент 'album' - data["album"] = self.album_data[group_id] + # Проверяем, что ключ все еще существует (на всякий случай) + if group_id in self.album_data: + # Передаем собранный альбом в хендлер + # Сортируем по message_id, чтобы порядок был верным + self.album_data[group_id].sort(key=lambda x: x.message_id) + data["album"] = self.album_data[group_id] + return await handler(event, data) - # Вызываем хендлер ОДИН раз - return await handler(event, data) + finally: + # ЧИСТКА: Удаляем всегда, если это "головной" поток, который создал запись + # Проверяем, что мы удаляем именно то, что создали, и ключ существует + if group_id in self.album_data and self.album_data[group_id][0] == event: + del self.album_data[group_id] - else: - # Если альбом уже собирается, просто добавляем сообщение в список - # и НЕ вызываем хендлер (прерываем цепочку для этого сообщения) - self.album_data[group_id].append(event) - return - - finally: - # Чистим память после обработки, если это был "главный" поток обработки - if group_id in self.album_data and len(self.album_data[group_id]) > 1: - # Маленький хак: удаляем только если обработчик завершился - # Проверка len нужна, чтобы не удалить раньше времени в параллельных тасках, - # но корректнее просто удалять в блоке первого сообщения. - if event == self.album_data[group_id][0]: - del self.album_data[group_id] \ No newline at end of file + else: + # Если группа уже собирается - просто добавляем и выходим + self.album_data[group_id].append(event) + return \ No newline at end of file diff --git a/models/Character.py b/models/Character.py index f73cd9c..39a8c2c 100644 --- a/models/Character.py +++ b/models/Character.py @@ -2,8 +2,9 @@ from pydantic import BaseModel class Character(BaseModel): - id: str + id: str | None name: str - character_image: bytes + character_image_doc_tg_id: str + character_image_tg_id: str | None character_bio: str diff --git a/repos/char_repo.py b/repos/char_repo.py index 89598aa..01d7cf5 100644 --- a/repos/char_repo.py +++ b/repos/char_repo.py @@ -34,4 +34,8 @@ class CharacterRepo: # Создаем объект characters.append(Character(**doc)) - return characters \ No newline at end of file + return characters + + async def update_char(self, char_id: str, character: Character) -> None: + await self.collection.update_one({"_id": ObjectId(char_id)}, {"$set": character.model_dump()}) + diff --git a/routers/char_router.py b/routers/char_router.py index 30c043e..fe39004 100644 --- a/routers/char_router.py +++ b/routers/char_router.py @@ -1,3 +1,6 @@ +import logging +import traceback + from aiogram.filters import Command from aiogram.fsm.context import FSMContext from aiogram.fsm.state import State, StatesGroup @@ -47,36 +50,43 @@ async def new_char_bio(message: Message, state: FSMContext, dao: DAO, bot: Bot): try: # ВОТ ТУТ скачиваем файл (прямо перед сохранением) - file_io = await bot.download(file_id) - photo_bytes = file_io.getvalue() # Получаем байты + # file_io = await bot.download(file_id) + # photo_bytes = file_io.getvalue() # Получаем байты # Создаем модель char = Character( id=None, name=name, - character_image=photo_bytes, + # character_image=photo_bytes, + character_image_tg_id=None, + character_image_doc_tg_id=file_id, character_bio=bio ) # Сохраняем через DAO await dao.chars.add_character(char) - + file_info = await bot.get_file(char.character_image_doc_tg_id) + file_bytes = await bot.download_file(file_info.file_path) # Отправляем подтверждение # Используем байты для отправки обратно - await message.answer_photo( - photo=BufferedInputFile(photo_bytes, filename="char.png"), + photo_msg = await message.answer_photo( + photo=BufferedInputFile(file_bytes.read(), + filename="char.jpg") if not char.character_image_tg_id else char.character_image_tg_id, caption=( "🎉 Персонаж создан!\n\n" f"👤 Имя: {char.name}\n" f"📝 Био: {char.character_bio}" ) ) + char.character_image_tg_id = photo_msg.photo[0].file_id + await dao.chars.update_char(char.id, char) await wait_msg.delete() # Сбрасываем состояние await state.clear() except Exception as e: + logging.error(e) await wait_msg.edit_text(f"❌ Ошибка при сохранении: {e}") # Не сбрасываем стейт, даем возможность попробовать ввести био снова или начать заново @@ -98,7 +108,7 @@ async def get_chars(message: Message, state: FSMContext, dao: DAO): @router.callback_query(F.data.startswith("char_info_")) -async def get_char_info(callback_query: CallbackQuery, state: FSMContext, dao: DAO): +async def get_char_info(callback_query: CallbackQuery, state: FSMContext, dao: DAO, bot: Bot): await callback_query.message.delete() wait_msg = await callback_query.message.answer("Ищем инфу о персонаже") char = await dao.chars.get_character(callback_query.data.split("_")[-1]) @@ -109,17 +119,19 @@ async def get_char_info(callback_query: CallbackQuery, state: FSMContext, dao: D return keyboard = InlineKeyboardMarkup(inline_keyboard=[ [InlineKeyboardButton(text="Запросить фото в документе", callback_data=f'char_photo_file_{char.id}')]]) - await callback_query.message.answer_photo(photo=BufferedInputFile(char.character_image, f"photo_{char.id}.png"), caption=f"👤 Имя: {char.name}\n" - f"📝 Био: {char.character_bio}", - reply_markup=keyboard) + photo_msg = await callback_query.message.answer_photo( + photo=char.character_image_tg_id, + caption=f"👤 Имя: {char.name}\n" + f"📝 Био: {char.character_bio}", + reply_markup=keyboard) await wait_msg.delete() @router.callback_query(F.data.startswith("char_photo_file")) async def get_char_info_photo_file(callback_query: CallbackQuery, state: FSMContext, dao: DAO): char = await dao.chars.get_character(callback_query.data.split("_")[-1]) - await callback_query.message.answer_document(BufferedInputFile(char.character_image, f"photo_{char.id}.png")) + await callback_query.message.answer_document(char.character_image_doc_tg_id) # 4. Хендлер-помощник (если отправили команду без файла) diff --git a/routers/gen_router.py b/routers/gen_router.py index 8c68f5b..9baa519 100644 --- a/routers/gen_router.py +++ b/routers/gen_router.py @@ -1,6 +1,7 @@ import asyncio +import logging import random -from enum import Enum + from typing import List from aiogram import Router, Bot, F @@ -10,7 +11,6 @@ from aiogram.filters import * from aiogram.fsm.context import FSMContext from aiogram.fsm.state import StatesGroup, State from aiogram.types import * -from aiogram.types import message import keyboards from adapters.google_adapter import GoogleAdapter @@ -26,13 +26,11 @@ class States(StatesGroup): gen_mode = State() - - async def init_gen_mode(state: FSMContext, dao: DAO): data = await state.get_data() data['aspect_ratio'] = AspectRatios.NINESIXTEEN.name data['quality'] = Quality.ONEK.name - data['type'] = GenType.TEXT.name + data['type'] = GenType.IMAGE.name await state.update_data(data) @@ -63,6 +61,8 @@ async def gen_mode_on(callback_query: CallbackQuery, state: FSMContext, dao: DAO @router.callback_query(States.gen_mode_wait_char, F.data.startswith("select_char_")) async def select_char(call: CallbackQuery, state: FSMContext, dao: DAO): + await call.answer() + await call.message.delete() await state.update_data({"char_id": call.data.split("_")[-1]}) await init_gen_mode(state=state, dao=dao) await state.set_state(States.gen_mode) @@ -71,6 +71,7 @@ async def select_char(call: CallbackQuery, state: FSMContext, dao: DAO): @router.callback_query(States.gen_mode, F.data == 'gen_mode_off') async def gen_mode_off(call: CallbackQuery, state: FSMContext, dao: DAO): + await call.answer() await state.clear() await state.set_data({}) await call.message.delete() @@ -79,6 +80,7 @@ async def gen_mode_off(call: CallbackQuery, state: FSMContext, dao: DAO): @router.callback_query(States.gen_mode, F.data == 'gen_mode_change_char') async def gen_mode_change_char(call: CallbackQuery, state: FSMContext, dao: DAO): + await call.answer() chars = await dao.chars.get_all_characters() if len(chars) == 0: await call.message.edit_caption( @@ -90,17 +92,20 @@ async def gen_mode_change_char(call: CallbackQuery, state: FSMContext, dao: DAO) for char in chars: keyboards.append(InlineKeyboardButton(text=char.name, callback_data=f'select_new_char_{char.id}')) keyboards.append(InlineKeyboardButton(text="⬅️ Назад", callback_data="gen_mode_cancel_char_change")) - await call.message.edit_caption("Выбери персонажа", reply_markup=InlineKeyboardMarkup(inline_keyboard=[keyboards])) + await call.message.edit_caption("Выбери персонажа", + reply_markup=InlineKeyboardMarkup(inline_keyboard=[keyboards])) @router.callback_query(States.gen_mode, F.data.startswith('select_new_char_')) async def change_char(call: CallbackQuery, state: FSMContext, dao: DAO): + await call.answer() await state.update_data({"char_id": call.data.split("_")[-1]}) await gen_mode_base_msg(call.message, state=state, dao=dao) @router.callback_query(States.gen_mode, F.data == 'gen_mode_change_aspect_ratio') async def gen_mode_change_aspect_ratio(call: CallbackQuery, state: FSMContext, dao: DAO): + await call.answer() keyboards = [] for ratio in AspectRatios: keyboards.append(InlineKeyboardButton(text=ratio.value, callback_data=f'select_ratio_{ratio.name}')) @@ -111,12 +116,14 @@ async def gen_mode_change_aspect_ratio(call: CallbackQuery, state: FSMContext, d @router.callback_query(States.gen_mode, F.data.startswith('select_ratio_')) async def change_aspect_ratio(call: CallbackQuery, state: FSMContext, dao: DAO): + await call.answer() await state.update_data({"aspect_ratio": call.data.split("_")[-1]}) await gen_mode_base_msg(call.message, state=state, dao=dao) @router.callback_query(States.gen_mode, F.data == 'gen_mode_change_quality') async def gen_mode_change_quality(call: CallbackQuery, state: FSMContext, dao: DAO): + await call.answer() keyboards = [] for quality in Quality: keyboards.append(InlineKeyboardButton(text=quality.value, callback_data=f'select_quality_{quality.name}')) @@ -127,12 +134,14 @@ async def gen_mode_change_quality(call: CallbackQuery, state: FSMContext, dao: D @router.callback_query(States.gen_mode, F.data.startswith('select_quality_')) async def change_quality(call: CallbackQuery, state: FSMContext, dao: DAO): + await call.answer() await state.update_data({"quality": call.data.split("_")[-1]}) await gen_mode_base_msg(call.message, state=state, dao=dao) @router.callback_query(States.gen_mode, F.data == 'gen_mode_change_type') async def gen_mode_change_type(call: CallbackQuery, state: FSMContext, dao: DAO): + await call.answer() keyboards = [] for gen_type in GenType: keyboards.append(InlineKeyboardButton(text=gen_type.value, callback_data=f'select_type_{gen_type.name}')) @@ -143,6 +152,7 @@ async def gen_mode_change_type(call: CallbackQuery, state: FSMContext, dao: DAO) @router.callback_query(States.gen_mode, F.data.startswith('select_type_')) async def change_quality(call: CallbackQuery, state: FSMContext, dao: DAO): + await call.answer() await state.update_data({"type": call.data.split("_")[-1]}) await gen_mode_base_msg(call.message, state=state, dao=dao) @@ -152,6 +162,7 @@ async def change_quality(call: CallbackQuery, state: FSMContext, dao: DAO): @router.callback_query(States.gen_mode, F.data == 'gen_mode_cancel_type_change') @router.callback_query(States.gen_mode, F.data == 'gen_mode_cancel_quality_change') async def cancel_gen_mode_change(call: CallbackQuery, state: FSMContext, dao: DAO): + await call.answer() await gen_mode_base_msg(call.message, state=state, dao=dao) @@ -159,7 +170,7 @@ async def gen_mode_base_msg(message: Message, state: FSMContext, dao: DAO, call_ data = await state.get_data() char: Character = await dao.chars.get_character(data["char_id"]) if call_type == "start": - await message.answer_photo(BufferedInputFile(char.character_image, f'{char.id}.png'), + await message.answer_photo(char.character_image_tg_id, caption="🎉 Режим генерации включен! Просто пиши мне промпт и я отправлю в генерацию по указанным настройкам.\n\n" "Фото девушки грузить не надо, оно загрузится по дефолту\n\n" "Но дополнительные фото можно загрузить.", @@ -168,7 +179,7 @@ async def gen_mode_base_msg(message: Message, state: FSMContext, dao: DAO, call_ try: await message.edit_caption( caption="🎉 Режим генерации включен!", - reply_markup=await keyboards.get_gen_mode_kb(state=state, dao=dao)) + reply_markup=await keyboards.get_gen_mode_kb(state=state, dao=dao)) except TelegramBadRequest as tbr: await message.edit_text( text="🎉 Режим генерации включен!", @@ -226,7 +237,7 @@ async def handle_album( await message.answer_document(file, caption="✨ Generated result") else: await message.answer("❌ Генерация не вернула изображений.") - await gen_mode_base_msg(message=message, state=state, dao=dao,call_type="start" ) + await gen_mode_base_msg(message=message, state=state, dao=dao, call_type="start") except Exception as e: await wait_msg.edit_text(f"❌ Ошибка: {e}") @@ -260,24 +271,31 @@ async def gen_mode_start( wait_msg = await message.answer("🎨 Генерирую...") + data = await state.get_data() try: - generated_files = await generate_image( - prompt=prompt, - media=media_ids, # Передаем список (пустой или с 1 фото) - state=state, - dao=dao, - bot=bot, - gemini=gemini - ) + if GenType(data['type']) is GenType.IMAGE: + generated_files = await generate_image( + prompt=prompt, + media=media_ids, # Передаем список (пустой или с 1 фото) + state=state, + dao=dao, + bot=bot, + gemini=gemini + ) - await wait_msg.delete() + await wait_msg.delete() - if generated_files: - for file in generated_files: - await message.answer_document(file, caption="✨ Generated result") + if generated_files: + for file in generated_files: + await message.answer_document(file, caption="✨ Generated result") + + else: + await message.answer("❌ Ничего не сгенерировалось.") else: - await message.answer("❌ Ничего не сгенерировалось.") - await gen_mode_base_msg(message=message, state=state, dao=dao,call_type="start" ) + generated_text = await gen_start_text(message=message, state=state, dao=dao, gemini=gemini, bot=bot) + if generated_text: + await wait_msg.edit_text(generated_text) + await gen_mode_base_msg(message=message, state=state, dao=dao, call_type="start") except Exception as e: await wait_msg.edit_text(f"❌ Ошибка: {e}") @@ -291,10 +309,8 @@ async def generate_image( bot: Bot, gemini: GoogleAdapter ) -> List[BufferedInputFile]: - # 1. Получаем данные персонажа data = await state.get_data() char_id = data.get("char_id") - if not char_id: raise ValueError("Character ID not found in state") @@ -303,19 +319,24 @@ async def generate_image( # Начинаем список с фото персонажа media_group_bytes = [char.character_image] - # 2. Скачиваем дополнительные файлы (если переданы) if media: - # Создаем задачи для скачивания + # Скачиваем файлы + # tasks вернут список объектов BytesIO tasks = [bot.download(file_id) for file_id in media] + downloaded_io = await asyncio.gather(*tasks) - # Запускаем все скачивания параллельно - downloaded_files = await asyncio.gather(*tasks) + for f in downloaded_io: + # ОПТИМИЗАЦИЯ: + # f - это BytesIO. getvalue() копирует байты. + # Если адаптер сразу делает Image.open(io.BytesIO(bytes)), + # мы можем передать байты, но после использования 'f' (BytesIO) + # он будет удален сборщиком мусора быстрее, если мы не держим на него ссылку. - # Добавляем байты скачанных файлов в общий список - for f in downloaded_files: - media_group_bytes.append(f.getvalue()) + # Читаем байты и сразу забываем про объект f + media_group_bytes.append(f.read()) + f.close() # Явно закрываем поток - # 3. Генерация в Gemini + # ... вызов Gemini ... generated_images_io = await asyncio.to_thread( gemini.generate_image, prompt=prompt, @@ -324,25 +345,32 @@ async def generate_image( quality=Quality[data['quality']], ) - # 4. Упаковка результата images = [] if generated_images_io: - for i, img_io in enumerate(generated_images_io): - # Важно: img_io.read() работает корректно, если курсор в начале (adapter это делает) + for img_io in generated_images_io: images.append( BufferedInputFile( img_io.read(), filename=f"img_{random.randint(1000, 9999)}.png" ) ) + # Важно: img_io здесь тоже BytesIO. После отправки aiogram закроет его сам, + # либо он удалится GC. Но если список generated_images_io большой, + # он висит в памяти до выхода из функции. return images + @router.message(F.text) -async def handle_text(message: Message, gemini: GoogleAdapter, bot: Bot): +async def handle_text(message: Message, gemini: GoogleAdapter, state: FSMContext, dao: DAO, bot: Bot): + wait_msg = await message.answer("Генерирую...") + await wait_msg.edit_text(await gen_start_text(message=message, gemini=gemini, state=state, dao=dao, bot=bot)) + + +async def gen_start_text(message: Message, gemini: GoogleAdapter, state: FSMContext, dao: DAO, bot: Bot) -> str: await bot.send_chat_action(message.chat.id, "typing") - result = await asyncio.to_thread(gemini.generate, prompt=message.text) - if result.get("text"): - await message.answer(result["text"], parse_mode=ParseMode.MARKDOWN) + result = await asyncio.to_thread(gemini.generate_text, prompt=message.text) + if result: + return result else: - await message.answer("Ошибка генерации") + raise Exception("Ошибка генерации")