349 lines
15 KiB
Python
349 lines
15 KiB
Python
import asyncio
|
||
import random
|
||
from enum import Enum
|
||
from typing import List
|
||
|
||
from aiogram import Router, Bot, F
|
||
from aiogram.enums import ParseMode
|
||
from aiogram.exceptions import TelegramBadRequest
|
||
from aiogram.filters import *
|
||
from aiogram.fsm.context import FSMContext
|
||
from aiogram.fsm.state import StatesGroup, State
|
||
from aiogram.types import *
|
||
from aiogram.types import message
|
||
|
||
import keyboards
|
||
from adapters.google_adapter import GoogleAdapter
|
||
from models.Character import Character
|
||
from models.enums import AspectRatios, Quality, GenType
|
||
from repos.dao import DAO
|
||
|
||
router = Router()
|
||
|
||
|
||
class States(StatesGroup):
|
||
gen_mode_wait_char = State()
|
||
gen_mode = State()
|
||
|
||
|
||
|
||
|
||
async def init_gen_mode(state: FSMContext, dao: DAO):
|
||
data = await state.get_data()
|
||
data['aspect_ratio'] = AspectRatios.NINESIXTEEN.name
|
||
data['quality'] = Quality.ONEK.name
|
||
data['type'] = GenType.TEXT.name
|
||
await state.update_data(data)
|
||
|
||
|
||
@router.message(Command("gen_mode"))
|
||
async def gen_mode(message: Message, state: FSMContext, dao: DAO):
|
||
state_on = await state.get_state()
|
||
if state_on is None and state_on is not States.gen_mode:
|
||
await message.answer("Включить режим генерации?", reply_markup=InlineKeyboardMarkup(
|
||
inline_keyboard=[[InlineKeyboardButton(text="✅ Включить!", callback_data="gen_mode_on")]]))
|
||
else:
|
||
await gen_mode_base_msg(message, state, dao, call_type="start")
|
||
|
||
|
||
@router.callback_query(F.data == "gen_mode_on")
|
||
async def gen_mode_on(callback_query: CallbackQuery, state: FSMContext, dao: DAO):
|
||
await callback_query.message.delete()
|
||
chars = await dao.chars.get_all_characters()
|
||
if len(chars) == 0:
|
||
await callback_query.message.answer(
|
||
"Персонажи не найдены! Сперва создайте персонажа отправив его фото ФАЙЛОМ с командой /add_char")
|
||
keyboards = []
|
||
for char in chars:
|
||
keyboards.append(InlineKeyboardButton(text=char.name, callback_data=f'select_char_{char.id}'))
|
||
await state.set_state(States.gen_mode_wait_char)
|
||
await callback_query.message.answer("Выбери персонажа",
|
||
reply_markup=InlineKeyboardMarkup(inline_keyboard=[keyboards]))
|
||
|
||
|
||
@router.callback_query(States.gen_mode_wait_char, F.data.startswith("select_char_"))
|
||
async def select_char(call: CallbackQuery, state: FSMContext, dao: DAO):
|
||
await state.update_data({"char_id": call.data.split("_")[-1]})
|
||
await init_gen_mode(state=state, dao=dao)
|
||
await state.set_state(States.gen_mode)
|
||
await gen_mode_base_msg(call.message, state=state, dao=dao, call_type="start")
|
||
|
||
|
||
@router.callback_query(States.gen_mode, F.data == 'gen_mode_off')
|
||
async def gen_mode_off(call: CallbackQuery, state: FSMContext, dao: DAO):
|
||
await state.clear()
|
||
await state.set_data({})
|
||
await call.message.delete()
|
||
await call.message.answer("Режим генерации выключен. Нажмите /start для продолжения!")
|
||
|
||
|
||
@router.callback_query(States.gen_mode, F.data == 'gen_mode_change_char')
|
||
async def gen_mode_change_char(call: CallbackQuery, state: FSMContext, dao: DAO):
|
||
chars = await dao.chars.get_all_characters()
|
||
if len(chars) == 0:
|
||
await call.message.edit_caption(
|
||
"Персонажи не найдены! Сперва создайте персонажа отправив его фото ФАЙЛОМ с командой /add_char",
|
||
reply_markup=None)
|
||
return
|
||
else:
|
||
keyboards = []
|
||
for char in chars:
|
||
keyboards.append(InlineKeyboardButton(text=char.name, callback_data=f'select_new_char_{char.id}'))
|
||
keyboards.append(InlineKeyboardButton(text="⬅️ Назад", callback_data="gen_mode_cancel_char_change"))
|
||
await call.message.edit_caption("Выбери персонажа", reply_markup=InlineKeyboardMarkup(inline_keyboard=[keyboards]))
|
||
|
||
|
||
@router.callback_query(States.gen_mode, F.data.startswith('select_new_char_'))
|
||
async def change_char(call: CallbackQuery, state: FSMContext, dao: DAO):
|
||
await state.update_data({"char_id": call.data.split("_")[-1]})
|
||
await gen_mode_base_msg(call.message, state=state, dao=dao)
|
||
|
||
|
||
@router.callback_query(States.gen_mode, F.data == 'gen_mode_change_aspect_ratio')
|
||
async def gen_mode_change_aspect_ratio(call: CallbackQuery, state: FSMContext, dao: DAO):
|
||
keyboards = []
|
||
for ratio in AspectRatios:
|
||
keyboards.append(InlineKeyboardButton(text=ratio.value, callback_data=f'select_ratio_{ratio.name}'))
|
||
await call.message.edit_caption(caption="Выбери соотношение сторон",
|
||
reply_markup=InlineKeyboardMarkup(inline_keyboard=[keyboards, [InlineKeyboardButton(
|
||
text="⬅️ Назад", callback_data="gen_mode_cancel_ratio_change")]]))
|
||
|
||
|
||
@router.callback_query(States.gen_mode, F.data.startswith('select_ratio_'))
|
||
async def change_aspect_ratio(call: CallbackQuery, state: FSMContext, dao: DAO):
|
||
await state.update_data({"aspect_ratio": call.data.split("_")[-1]})
|
||
await gen_mode_base_msg(call.message, state=state, dao=dao)
|
||
|
||
|
||
@router.callback_query(States.gen_mode, F.data == 'gen_mode_change_quality')
|
||
async def gen_mode_change_quality(call: CallbackQuery, state: FSMContext, dao: DAO):
|
||
keyboards = []
|
||
for quality in Quality:
|
||
keyboards.append(InlineKeyboardButton(text=quality.value, callback_data=f'select_quality_{quality.name}'))
|
||
await call.message.edit_caption(caption="Выбери качество",
|
||
reply_markup=InlineKeyboardMarkup(inline_keyboard=[keyboards, [InlineKeyboardButton(
|
||
text="⬅️ Назад", callback_data="gen_mode_cancel_quality_change")]]))
|
||
|
||
|
||
@router.callback_query(States.gen_mode, F.data.startswith('select_quality_'))
|
||
async def change_quality(call: CallbackQuery, state: FSMContext, dao: DAO):
|
||
await state.update_data({"quality": call.data.split("_")[-1]})
|
||
await gen_mode_base_msg(call.message, state=state, dao=dao)
|
||
|
||
|
||
@router.callback_query(States.gen_mode, F.data == 'gen_mode_change_type')
|
||
async def gen_mode_change_type(call: CallbackQuery, state: FSMContext, dao: DAO):
|
||
keyboards = []
|
||
for gen_type in GenType:
|
||
keyboards.append(InlineKeyboardButton(text=gen_type.value, callback_data=f'select_type_{gen_type.name}'))
|
||
await call.message.edit_caption(caption="Выбери тип генерации", reply_markup=InlineKeyboardMarkup(
|
||
inline_keyboard=[keyboards,
|
||
[InlineKeyboardButton(text="⬅️ Назад", callback_data="gen_mode_cancel_type_change")]]))
|
||
|
||
|
||
@router.callback_query(States.gen_mode, F.data.startswith('select_type_'))
|
||
async def change_quality(call: CallbackQuery, state: FSMContext, dao: DAO):
|
||
await state.update_data({"type": call.data.split("_")[-1]})
|
||
await gen_mode_base_msg(call.message, state=state, dao=dao)
|
||
|
||
|
||
@router.callback_query(States.gen_mode, F.data == 'gen_mode_cancel_char_change')
|
||
@router.callback_query(States.gen_mode, F.data == 'gen_mode_cancel_ratio_change')
|
||
@router.callback_query(States.gen_mode, F.data == 'gen_mode_cancel_type_change')
|
||
@router.callback_query(States.gen_mode, F.data == 'gen_mode_cancel_quality_change')
|
||
async def cancel_gen_mode_change(call: CallbackQuery, state: FSMContext, dao: DAO):
|
||
await gen_mode_base_msg(call.message, state=state, dao=dao)
|
||
|
||
|
||
async def gen_mode_base_msg(message: Message, state: FSMContext, dao: DAO, call_type='continue'):
|
||
data = await state.get_data()
|
||
char: Character = await dao.chars.get_character(data["char_id"])
|
||
if call_type == "start":
|
||
await message.answer_photo(BufferedInputFile(char.character_image, f'{char.id}.png'),
|
||
caption="🎉 Режим генерации включен! Просто пиши мне промпт и я отправлю в генерацию по указанным настройкам.\n\n"
|
||
"<b>Фото девушки грузить не надо, оно загрузится по дефолту</b>\n\n"
|
||
"Но дополнительные фото можно загрузить.",
|
||
reply_markup=await keyboards.get_gen_mode_kb(state=state, dao=dao))
|
||
else:
|
||
try:
|
||
await message.edit_caption(
|
||
caption="🎉 Режим генерации включен!",
|
||
reply_markup=await keyboards.get_gen_mode_kb(state=state, dao=dao))
|
||
except TelegramBadRequest as tbr:
|
||
await message.edit_text(
|
||
text="🎉 Режим генерации включен!",
|
||
reply_markup=await keyboards.get_gen_mode_kb(state=state, dao=dao))
|
||
|
||
|
||
@router.message(States.gen_mode, F.media_group_id)
|
||
async def handle_album(
|
||
message: Message,
|
||
album: List[Message],
|
||
state: FSMContext,
|
||
dao: DAO,
|
||
gemini: GoogleAdapter,
|
||
bot: Bot
|
||
):
|
||
"""
|
||
Обработка альбома (группы фото).
|
||
"""
|
||
# 1. Ищем промпт (подпись) в любом из сообщений альбома
|
||
# message.text в альбомах обычно None
|
||
prompt = next((msg.caption for msg in album if msg.caption), None)
|
||
|
||
if not prompt:
|
||
await message.answer("⚠️ Вы отправили альбом, но не добавили описание (промпт).")
|
||
return
|
||
|
||
# 2. Собираем file_id всех фото
|
||
file_ids = []
|
||
for msg in album:
|
||
if msg.photo:
|
||
file_ids.append(msg.photo[-1].file_id)
|
||
elif msg.video:
|
||
# Если нужно, можно добавить обработку видео (пока пропускаем)
|
||
pass
|
||
|
||
await message.answer(f"📥 Принято {len(album)} файлов. Начинаю генерацию...")
|
||
wait_msg = await message.answer("🎨 Генерирую...")
|
||
|
||
# 3. Вызываем генерацию
|
||
try:
|
||
generated_files = await generate_image(
|
||
prompt=prompt,
|
||
media=file_ids,
|
||
state=state,
|
||
dao=dao,
|
||
bot=bot,
|
||
gemini=gemini
|
||
)
|
||
|
||
await wait_msg.delete()
|
||
|
||
# 4. Отправляем результат
|
||
if generated_files:
|
||
for file in generated_files:
|
||
await message.answer_document(file, caption="✨ Generated result")
|
||
else:
|
||
await message.answer("❌ Генерация не вернула изображений.")
|
||
await gen_mode_base_msg(message=message, state=state, dao=dao,call_type="start" )
|
||
|
||
except Exception as e:
|
||
await wait_msg.edit_text(f"❌ Ошибка: {e}")
|
||
|
||
|
||
@router.message(States.gen_mode)
|
||
async def gen_mode_start(
|
||
message: Message,
|
||
state: FSMContext,
|
||
dao: DAO,
|
||
gemini: GoogleAdapter,
|
||
bot: Bot
|
||
):
|
||
"""
|
||
Обработка одиночного сообщения (Текст или Фото+Подпись)
|
||
"""
|
||
# 1. Получаем промпт (Текст или Подпись)
|
||
prompt = message.text or message.caption
|
||
|
||
if not prompt:
|
||
await message.answer("⚠️ Напиши промпт (или отправь фото с подписью).")
|
||
return
|
||
|
||
# 2. Проверяем, есть ли прикрепленное фото
|
||
media_ids = []
|
||
if message.photo:
|
||
media_ids.append(message.photo[-1].file_id)
|
||
elif message.reply_to_message and message.reply_to_message.photo:
|
||
# Поддержка реплая на фото
|
||
media_ids.append(message.reply_to_message.photo[-1].file_id)
|
||
|
||
wait_msg = await message.answer("🎨 Генерирую...")
|
||
|
||
try:
|
||
generated_files = await generate_image(
|
||
prompt=prompt,
|
||
media=media_ids, # Передаем список (пустой или с 1 фото)
|
||
state=state,
|
||
dao=dao,
|
||
bot=bot,
|
||
gemini=gemini
|
||
)
|
||
|
||
await wait_msg.delete()
|
||
|
||
if generated_files:
|
||
for file in generated_files:
|
||
await message.answer_document(file, caption="✨ Generated result")
|
||
else:
|
||
await message.answer("❌ Ничего не сгенерировалось.")
|
||
await gen_mode_base_msg(message=message, state=state, dao=dao,call_type="start" )
|
||
|
||
except Exception as e:
|
||
await wait_msg.edit_text(f"❌ Ошибка: {e}")
|
||
|
||
|
||
async def generate_image(
|
||
prompt: str,
|
||
media: List[str] | None,
|
||
state: FSMContext,
|
||
dao: DAO,
|
||
bot: Bot,
|
||
gemini: GoogleAdapter
|
||
) -> List[BufferedInputFile]:
|
||
# 1. Получаем данные персонажа
|
||
data = await state.get_data()
|
||
char_id = data.get("char_id")
|
||
|
||
if not char_id:
|
||
raise ValueError("Character ID not found in state")
|
||
|
||
char: Character = await dao.chars.get_character(char_id)
|
||
|
||
# Начинаем список с фото персонажа
|
||
media_group_bytes = [char.character_image]
|
||
|
||
# 2. Скачиваем дополнительные файлы (если переданы)
|
||
if media:
|
||
# Создаем задачи для скачивания
|
||
tasks = [bot.download(file_id) for file_id in media]
|
||
|
||
# Запускаем все скачивания параллельно
|
||
downloaded_files = await asyncio.gather(*tasks)
|
||
|
||
# Добавляем байты скачанных файлов в общий список
|
||
for f in downloaded_files:
|
||
media_group_bytes.append(f.getvalue())
|
||
|
||
# 3. Генерация в Gemini
|
||
generated_images_io = await asyncio.to_thread(
|
||
gemini.generate_image,
|
||
prompt=prompt,
|
||
images_list=media_group_bytes,
|
||
aspect_ratio=AspectRatios[data['aspect_ratio']],
|
||
quality=Quality[data['quality']],
|
||
)
|
||
|
||
# 4. Упаковка результата
|
||
images = []
|
||
if generated_images_io:
|
||
for i, img_io in enumerate(generated_images_io):
|
||
# Важно: img_io.read() работает корректно, если курсор в начале (adapter это делает)
|
||
images.append(
|
||
BufferedInputFile(
|
||
img_io.read(),
|
||
filename=f"img_{random.randint(1000, 9999)}.png"
|
||
)
|
||
)
|
||
|
||
return images
|
||
|
||
@router.message(F.text)
|
||
async def handle_text(message: Message, gemini: GoogleAdapter, bot: Bot):
|
||
await bot.send_chat_action(message.chat.id, "typing")
|
||
result = await asyncio.to_thread(gemini.generate, prompt=message.text)
|
||
if result.get("text"):
|
||
await message.answer(result["text"], parse_mode=ParseMode.MARKDOWN)
|
||
else:
|
||
await message.answer("Ошибка генерации")
|