Files
ai-char-bot/routers/gen_router.py
2026-02-03 09:45:27 +03:00

388 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
import logging
import random
from typing import List
from aiogram import Router, Bot, F
from aiogram.enums import ParseMode
from aiogram.exceptions import TelegramBadRequest
from aiogram.filters import *
from aiogram.fsm.context import FSMContext
from aiogram.fsm.state import StatesGroup, State
from aiogram.types import *
import keyboards
from adapters.google_adapter import GoogleAdapter
from models.Character import Character
from models.enums import AspectRatios, Quality, GenType
from repos.dao import DAO
router = Router()
class States(StatesGroup):
gen_mode_wait_char = State()
gen_mode = State()
async def init_gen_mode(state: FSMContext, dao: DAO):
data = await state.get_data()
data['aspect_ratio'] = AspectRatios.NINESIXTEEN.name
data['quality'] = Quality.ONEK.name
data['type'] = GenType.IMAGE.name
await state.update_data(data)
@router.message(Command("gen_mode"))
async def gen_mode(message: Message, state: FSMContext, dao: DAO):
state_on = await state.get_state()
if state_on is None and state_on is not States.gen_mode:
await message.answer("Включить режим генерации?", reply_markup=InlineKeyboardMarkup(
inline_keyboard=[[InlineKeyboardButton(text="✅ Включить!", callback_data="gen_mode_on")]]))
else:
await gen_mode_base_msg(message, state, dao, call_type="start")
@router.callback_query(F.data == "gen_mode_on")
async def gen_mode_on(callback_query: CallbackQuery, state: FSMContext, dao: DAO):
await callback_query.message.delete()
chars = await dao.chars.get_all_characters()
if len(chars) == 0:
await callback_query.message.answer(
"Персонажи не найдены! Сперва создайте персонажа отправив его фото ФАЙЛОМ с командой /add_char")
keyboards = []
for char in chars:
keyboards.append(InlineKeyboardButton(text=char.name, callback_data=f'select_char_{char.id}'))
await state.set_state(States.gen_mode_wait_char)
await callback_query.message.answer("Выбери персонажа",
reply_markup=InlineKeyboardMarkup(inline_keyboard=[keyboards]))
@router.callback_query(States.gen_mode_wait_char, F.data.startswith("select_char_"))
async def select_char(call: CallbackQuery, state: FSMContext, dao: DAO):
await call.answer()
await call.message.delete()
await state.update_data({"char_id": call.data.split("_")[-1]})
await init_gen_mode(state=state, dao=dao)
await state.set_state(States.gen_mode)
await gen_mode_base_msg(call.message, state=state, dao=dao, call_type="start")
@router.callback_query(States.gen_mode, F.data == 'gen_mode_off')
async def gen_mode_off(call: CallbackQuery, state: FSMContext, dao: DAO):
await call.answer()
await state.clear()
await state.set_data({})
await call.message.delete()
await call.message.answer("Режим генерации выключен. Нажмите /start для продолжения!")
@router.callback_query(States.gen_mode, F.data == 'gen_mode_change_char')
async def gen_mode_change_char(call: CallbackQuery, state: FSMContext, dao: DAO):
await call.answer()
chars = await dao.chars.get_all_characters()
if len(chars) == 0:
await call.message.edit_caption(
"Персонажи не найдены! Сперва создайте персонажа отправив его фото ФАЙЛОМ с командой /add_char",
reply_markup=None)
return
else:
keyboards = []
for char in chars:
keyboards.append(InlineKeyboardButton(text=char.name, callback_data=f'select_new_char_{char.id}'))
keyboards.append(InlineKeyboardButton(text="⬅️ Назад", callback_data="gen_mode_cancel_char_change"))
await call.message.edit_caption("Выбери персонажа",
reply_markup=InlineKeyboardMarkup(inline_keyboard=[keyboards]))
@router.callback_query(States.gen_mode, F.data.startswith('select_new_char_'))
async def change_char(call: CallbackQuery, state: FSMContext, dao: DAO):
await call.answer()
await state.update_data({"char_id": call.data.split("_")[-1]})
await gen_mode_base_msg(call.message, state=state, dao=dao)
@router.callback_query(States.gen_mode, F.data == 'gen_mode_change_aspect_ratio')
async def gen_mode_change_aspect_ratio(call: CallbackQuery, state: FSMContext, dao: DAO):
await call.answer()
keyboards = []
for ratio in AspectRatios:
keyboards.append(InlineKeyboardButton(text=ratio.value, callback_data=f'select_ratio_{ratio.name}'))
await call.message.edit_caption(caption="Выбери соотношение сторон",
reply_markup=InlineKeyboardMarkup(inline_keyboard=[keyboards, [InlineKeyboardButton(
text="⬅️ Назад", callback_data="gen_mode_cancel_ratio_change")]]))
@router.callback_query(States.gen_mode, F.data.startswith('select_ratio_'))
async def change_aspect_ratio(call: CallbackQuery, state: FSMContext, dao: DAO):
await call.answer()
await state.update_data({"aspect_ratio": call.data.split("_")[-1]})
await gen_mode_base_msg(call.message, state=state, dao=dao)
@router.callback_query(States.gen_mode, F.data == 'gen_mode_change_quality')
async def gen_mode_change_quality(call: CallbackQuery, state: FSMContext, dao: DAO):
await call.answer()
keyboards = []
for quality in Quality:
keyboards.append(InlineKeyboardButton(text=quality.value, callback_data=f'select_quality_{quality.name}'))
await call.message.edit_caption(caption="Выбери качество",
reply_markup=InlineKeyboardMarkup(inline_keyboard=[keyboards, [InlineKeyboardButton(
text="⬅️ Назад", callback_data="gen_mode_cancel_quality_change")]]))
@router.callback_query(States.gen_mode, F.data.startswith('select_quality_'))
async def change_quality(call: CallbackQuery, state: FSMContext, dao: DAO):
await call.answer()
await state.update_data({"quality": call.data.split("_")[-1]})
await gen_mode_base_msg(call.message, state=state, dao=dao)
@router.callback_query(States.gen_mode, F.data == 'gen_mode_change_type')
async def gen_mode_change_type(call: CallbackQuery, state: FSMContext, dao: DAO):
await call.answer()
data = await state.get_data()
if GenType(data['type']) is GenType.IMAGE:
await state.update_data({"type": GenType.TEXT.name})
else:
await state.update_data({"type": GenType.IMAGE.name})
await gen_mode_base_msg(call.message, state=state, dao=dao)
@router.callback_query(States.gen_mode, F.data.startswith('select_type_'))
async def change_quality(call: CallbackQuery, state: FSMContext, dao: DAO):
await call.answer()
await state.update_data({"type": call.data.split("_")[-1]})
await gen_mode_base_msg(call.message, state=state, dao=dao)
@router.callback_query(States.gen_mode, F.data == 'gen_mode_cancel_char_change')
@router.callback_query(States.gen_mode, F.data == 'gen_mode_cancel_ratio_change')
@router.callback_query(States.gen_mode, F.data == 'gen_mode_cancel_type_change')
@router.callback_query(States.gen_mode, F.data == 'gen_mode_cancel_quality_change')
async def cancel_gen_mode_change(call: CallbackQuery, state: FSMContext, dao: DAO):
await call.answer()
await gen_mode_base_msg(call.message, state=state, dao=dao)
async def gen_mode_base_msg(message: Message, state: FSMContext, dao: DAO, call_type='continue'):
data = await state.get_data()
char: Character = await dao.chars.get_character(data["char_id"])
if call_type == "start":
await message.answer_photo(char.character_image_tg_id,
caption="🎉 Режим генерации включен! Просто пиши мне промпт и я отправлю в генерацию по указанным настройкам.\n\n"
"<b>Фото девушки грузить не надо, оно загрузится по дефолту</b>\n\n"
"Но дополнительные фото можно загрузить.",
reply_markup=await keyboards.get_gen_mode_kb(state=state, dao=dao))
else:
try:
await message.edit_caption(
caption="🎉 Режим генерации включен!",
reply_markup=await keyboards.get_gen_mode_kb(state=state, dao=dao))
except TelegramBadRequest as tbr:
await message.edit_text(
text="🎉 Режим генерации включен!",
reply_markup=await keyboards.get_gen_mode_kb(state=state, dao=dao))
@router.message(States.gen_mode, F.media_group_id)
async def handle_album(
message: Message,
album: List[Message],
state: FSMContext,
dao: DAO,
gemini: GoogleAdapter,
bot: Bot
):
"""
Обработка альбома (группы фото).
"""
# 1. Ищем промпт (подпись) в любом из сообщений альбома
# message.text в альбомах обычно None
prompt = next((msg.caption for msg in album if msg.caption), None)
if not prompt:
await message.answer("⚠️ Вы отправили альбом, но не добавили описание (промпт).")
return
# 2. Собираем file_id всех фото
file_ids = []
for msg in album:
if msg.photo:
file_ids.append(msg.photo[-1].file_id)
elif msg.video:
# Если нужно, можно добавить обработку видео (пока пропускаем)
pass
await message.answer(f"📥 Принято {len(album)} файлов. Начинаю генерацию...")
wait_msg = await message.answer("🎨 Генерирую...")
# 3. Вызываем генерацию
try:
generated_files = await generate_image(
prompt=prompt,
media=file_ids,
state=state,
dao=dao,
bot=bot,
gemini=gemini
)
await wait_msg.delete()
# 4. Отправляем результат
if generated_files:
for file in generated_files:
await message.answer_document(file, caption="✨ Generated result")
else:
await message.answer("❌ Генерация не вернула изображений.")
await gen_mode_base_msg(message=message, state=state, dao=dao, call_type="start")
except Exception as e:
await wait_msg.edit_text(f"❌ Ошибка: {e}")
@router.message(States.gen_mode)
async def gen_mode_start(
message: Message,
state: FSMContext,
dao: DAO,
gemini: GoogleAdapter,
bot: Bot
):
"""
Обработка одиночного сообщения (Текст или Фото+Подпись)
"""
# 1. Получаем промпт (Текст или Подпись)
prompt = message.text or message.caption
if not prompt:
await message.answer("⚠️ Напиши промпт (или отправь фото с подписью).")
return
# 2. Проверяем, есть ли прикрепленное фото
media_ids = []
if message.photo:
media_ids.append(message.photo[-1].file_id)
elif message.reply_to_message and message.reply_to_message.photo:
# Поддержка реплая на фото
media_ids.append(message.reply_to_message.photo[-1].file_id)
wait_msg = await message.answer("🎨 Генерирую...")
data = await state.get_data()
try:
if GenType[data['type']] is GenType.IMAGE:
generated_files = await generate_image(
prompt=prompt,
media=media_ids, # Передаем список (пустой или с 1 фото)
state=state,
dao=dao,
bot=bot,
gemini=gemini
)
await wait_msg.delete()
if generated_files:
for file in generated_files:
await message.answer_document(file, caption="✨ Generated result")
else:
await message.answer("❌ Ничего не сгенерировалось.")
else:
generated_text = await gen_start_text(message=message, state=state, dao=dao, gemini=gemini, bot=bot)
if generated_text:
await wait_msg.edit_text(generated_text)
await gen_mode_base_msg(message=message, state=state, dao=dao, call_type="start")
except Exception as e:
await wait_msg.edit_text(f"❌ Ошибка: {e}")
async def generate_image(
prompt: str,
media: List[str] | None,
state: FSMContext,
dao: DAO,
bot: Bot,
gemini: GoogleAdapter
) -> List[BufferedInputFile]:
data = await state.get_data()
char_id = data.get("char_id")
if not char_id:
raise ValueError("Character ID not found in state")
char: Character = await dao.chars.get_character(char_id)
# Начинаем список с фото персонажа
file_byes = await bot.download(char.character_image_doc_tg_id)
media_group_bytes = [file_byes.read()]
file_byes.close()
if media:
# Скачиваем файлы
# tasks вернут список объектов BytesIO
tasks = [bot.download(file_id) for file_id in media]
downloaded_io = await asyncio.gather(*tasks)
for f in downloaded_io:
# ОПТИМИЗАЦИЯ:
# f - это BytesIO. getvalue() копирует байты.
# Если адаптер сразу делает Image.open(io.BytesIO(bytes)),
# мы можем передать байты, но после использования 'f' (BytesIO)
# он будет удален сборщиком мусора быстрее, если мы не держим на него ссылку.
# Читаем байты и сразу забываем про объект f
media_group_bytes.append(f.read())
f.close() # Явно закрываем поток
# ... вызов Gemini ...
generated_images_io = await asyncio.to_thread(
gemini.generate_image,
prompt=prompt,
images_list=media_group_bytes,
aspect_ratio=AspectRatios[data['aspect_ratio']],
quality=Quality[data['quality']],
)
images = []
if generated_images_io:
for img_io in generated_images_io:
# Читаем байты
content = img_io.read()
# Сразу закрываем поток от адаптера, освобождая память
img_io.close()
images.append(
BufferedInputFile(
content,
filename=f"img_{random.randint(1000, 9999)}.png"
)
)
return images
@router.message(F.text)
async def handle_text(message: Message, gemini: GoogleAdapter, state: FSMContext, dao: DAO, bot: Bot):
wait_msg = await message.answer("Генерирую...")
await wait_msg.edit_text(await gen_start_text(message=message, gemini=gemini, state=state, dao=dao, bot=bot))
async def gen_start_text(message: Message, gemini: GoogleAdapter, state: FSMContext, dao: DAO, bot: Bot,
char_id: str = None) -> str:
await bot.send_chat_action(message.chat.id, "typing")
prompt = "Use a TELEGRAM HTML formatting. If you write a prompt use <pre> tag.\n\n"
prompt += f"PROMPT:\n{message.text}\n\n"
if char_id:
char = await dao.chars.get_character(message.chat.id)
prompt += char.character_bio
result = await asyncio.to_thread(gemini.generate_text, prompt=prompt)
if result:
return result
else:
raise Exception("Ошибка генерации")