+ gen mode
This commit is contained in:
@@ -1,15 +1,14 @@
|
|||||||
import os
|
|
||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List, Union, Dict, Any
|
from typing import List, Union
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
# Импортируем из нового SDK
|
from PIL import Image
|
||||||
from google import genai
|
from google import genai
|
||||||
from google.genai import types
|
from google.genai import types
|
||||||
|
|
||||||
# Для настройки логгера
|
from models.enums import AspectRatios, Quality
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -18,82 +17,100 @@ class GoogleAdapter:
|
|||||||
if not api_key:
|
if not api_key:
|
||||||
raise ValueError("API Key for Gemini is missing")
|
raise ValueError("API Key for Gemini is missing")
|
||||||
self.client = genai.Client(api_key=api_key)
|
self.client = genai.Client(api_key=api_key)
|
||||||
# Укажите актуальную модель.
|
|
||||||
# Если gemini-3-pro-image-preview недоступна, используйте gemini-2.0-flash-exp
|
|
||||||
self.model_name = "gemini-3-pro-preview"
|
|
||||||
|
|
||||||
def generate(
|
# Константы моделей
|
||||||
self,
|
self.TEXT_MODEL = "gemini-3-pro-preview"
|
||||||
prompt: str,
|
self.IMAGE_MODEL = "gemini-3-pro-image-preview"
|
||||||
image_bytes: bytes = None,
|
|
||||||
generate_image: bool = False
|
def _prepare_contents(self, prompt: str, images_list: List[bytes] = None) -> list:
|
||||||
) -> Dict[str, Any]:
|
"""Вспомогательный метод для подготовки контента (текст + картинки)"""
|
||||||
"""
|
|
||||||
Универсальный метод:
|
|
||||||
- Если generate_image=True: просим модель вернуть картинку (Image Generation).
|
|
||||||
- Если image_bytes переданы + generate_image=False: это Vision (описание фото).
|
|
||||||
- Если image_bytes + generate_image=True: это Image-to-Image (редактирование).
|
|
||||||
"""
|
|
||||||
if generate_image:
|
|
||||||
self.model_name = "gemini-3-pro-image-preview"
|
|
||||||
else :
|
|
||||||
self.model_name = "gemini-3-pro-preview"
|
|
||||||
contents = [prompt]
|
contents = [prompt]
|
||||||
|
if images_list:
|
||||||
# Если есть входное изображение (для Vision или для редактирования)
|
for img_bytes in images_list:
|
||||||
if image_bytes:
|
|
||||||
try:
|
try:
|
||||||
image = Image.open(io.BytesIO(image_bytes))
|
# Gemini API требует PIL Image на входе
|
||||||
|
image = Image.open(io.BytesIO(img_bytes))
|
||||||
contents.append(image)
|
contents.append(image)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error processing input image: {e}")
|
logger.error(f"Error processing input image: {e}")
|
||||||
return {"error": "Не удалось обработать входящее изображение."}
|
return contents
|
||||||
|
|
||||||
# Настраиваем конфигурацию
|
def generate_text(self, prompt: str, images_list: List[bytes] = None) -> str:
|
||||||
# Для генерации картинок добавляем 'IMAGE' в response_modalities
|
"""
|
||||||
modalities = ['TEXT']
|
Генерация текста (Чат или Vision).
|
||||||
if generate_image:
|
Возвращает строку с ответом.
|
||||||
modalities.append('IMAGE')
|
"""
|
||||||
|
contents = self._prepare_contents(prompt, images_list)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Вызов API (синхронный метод в обертке, но aiogram вызывает его в треде,
|
|
||||||
# либо используйте client.aio для асинхронности если поддерживается версией SDK)
|
|
||||||
# В google-genai v0.3+ есть асинхронный клиент, но для простоты здесь стандартный вызов.
|
|
||||||
# Чтобы не блокировать event loop, в main.py мы обернем это в to_thread при необходимости,
|
|
||||||
# но пока используем стандартный вызов.
|
|
||||||
|
|
||||||
response = self.client.models.generate_content(
|
response = self.client.models.generate_content(
|
||||||
model=self.model_name,
|
model=self.TEXT_MODEL,
|
||||||
contents=contents,
|
contents=contents,
|
||||||
config=types.GenerateContentConfig(
|
config=types.GenerateContentConfig(
|
||||||
response_modalities=modalities,
|
response_modalities=['TEXT'],
|
||||||
temperature=0.7 if not generate_image else 1.0,
|
temperature=0.7,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
result = {"text": "", "images": []}
|
# Собираем текст из всех частей ответа
|
||||||
|
result_text = ""
|
||||||
# Парсим ответ (Text или Inline Data)
|
|
||||||
if response.parts:
|
if response.parts:
|
||||||
for part in response.parts:
|
for part in response.parts:
|
||||||
if part.text:
|
if part.text:
|
||||||
result["text"] += part.text
|
result_text += part.text
|
||||||
|
|
||||||
# Проверяем наличие сгенерированного изображения
|
return result_text
|
||||||
if part.inline_data:
|
|
||||||
# ИСПРАВЛЕНИЕ: Берем "сырые" байты напрямую из ответа
|
|
||||||
# Это работает быстрее и не вызывает ошибку с PIL
|
|
||||||
|
|
||||||
# part.inline_data.data — это уже bytes
|
|
||||||
byte_arr = io.BytesIO(part.inline_data.data)
|
|
||||||
now = datetime.now()
|
|
||||||
# Имя файла для телеграма (формально)
|
|
||||||
byte_arr.name = f'{now.timestamp()}.png'
|
|
||||||
|
|
||||||
result["images"].append(byte_arr)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Gemini API Error: {e}")
|
logger.error(f"Gemini Text API Error: {e}")
|
||||||
return {"error": f"Ошибка API: {str(e)}"}
|
return f"Ошибка генерации текста: {e}"
|
||||||
|
|
||||||
|
def generate_image(self, prompt: str, aspect_ratio: AspectRatios, quality: Quality, images_list: List[bytes] = None, ) -> List[io.BytesIO]:
|
||||||
|
"""
|
||||||
|
Генерация изображений (Text-to-Image или Image-to-Image).
|
||||||
|
Возвращает список байтовых потоков (готовых к отправке).
|
||||||
|
"""
|
||||||
|
contents = self._prepare_contents(prompt, images_list)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = self.client.models.generate_content(
|
||||||
|
model=self.IMAGE_MODEL,
|
||||||
|
contents=contents,
|
||||||
|
config=types.GenerateContentConfig(
|
||||||
|
response_modalities=['IMAGE'],
|
||||||
|
temperature=1.0,
|
||||||
|
image_config=types.ImageConfig(
|
||||||
|
aspect_ratio=aspect_ratio.value,
|
||||||
|
image_size=quality.value
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
generated_images = []
|
||||||
|
|
||||||
|
if response.parts:
|
||||||
|
for part in response.parts:
|
||||||
|
# Ищем картинки (inline_data)
|
||||||
|
if part.inline_data:
|
||||||
|
try:
|
||||||
|
# 1. Берем сырые байты
|
||||||
|
raw_data = part.inline_data.data
|
||||||
|
byte_arr = io.BytesIO(raw_data)
|
||||||
|
|
||||||
|
# 2. Нейминг (формально, для TG)
|
||||||
|
timestamp = datetime.now().timestamp()
|
||||||
|
byte_arr.name = f'{timestamp}.png'
|
||||||
|
|
||||||
|
# 3. Важно: сбросить курсор в начало
|
||||||
|
byte_arr.seek(0)
|
||||||
|
|
||||||
|
generated_images.append(byte_arr)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing output image: {e}")
|
||||||
|
|
||||||
|
return generated_images
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Gemini Image API Error: {e}")
|
||||||
|
# В случае ошибки возвращаем пустой список (или можно рейзить исключение)
|
||||||
|
return []
|
||||||
23
keyboards.py
23
keyboards.py
@@ -1,5 +1,11 @@
|
|||||||
|
from aiogram.fsm.context import FSMContext
|
||||||
from aiogram.types import InlineKeyboardMarkup, InlineKeyboardButton
|
from aiogram.types import InlineKeyboardMarkup, InlineKeyboardButton
|
||||||
|
|
||||||
|
from models.enums import AspectRatios, Quality, GenType
|
||||||
|
from repos.dao import DAO
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_request_kb():
|
def get_request_kb():
|
||||||
return InlineKeyboardMarkup(inline_keyboard=[
|
return InlineKeyboardMarkup(inline_keyboard=[
|
||||||
[InlineKeyboardButton(text="🔐 Запросить доступ", callback_data="req_access")]
|
[InlineKeyboardButton(text="🔐 Запросить доступ", callback_data="req_access")]
|
||||||
@@ -13,3 +19,20 @@ def get_admin_decision_kb(user_id: int):
|
|||||||
InlineKeyboardButton(text="🚫 Запретить", callback_data=f"access_deny_{user_id}")
|
InlineKeyboardButton(text="🚫 Запретить", callback_data=f"access_deny_{user_id}")
|
||||||
]
|
]
|
||||||
])
|
])
|
||||||
|
|
||||||
|
async def get_gen_mode_kb(state: FSMContext, dao: DAO):
|
||||||
|
data = await state.get_data()
|
||||||
|
char = await dao.chars.get_character(character_id=data['char_id'])
|
||||||
|
return InlineKeyboardMarkup(inline_keyboard=[
|
||||||
|
[
|
||||||
|
InlineKeyboardButton(text=f'Перс: {char.name}', callback_data=f'gen_mode_change_char'),
|
||||||
|
],
|
||||||
|
[
|
||||||
|
InlineKeyboardButton(text=f"🔁{AspectRatios[data['aspect_ratio']].value}", callback_data=f'gen_mode_change_aspect_ratio'),
|
||||||
|
InlineKeyboardButton(text=f"🔁{Quality[data['quality']].value}", callback_data=f'gen_mode_change_quality'),
|
||||||
|
InlineKeyboardButton(text=f"🔁{GenType[data['type']].value}",callback_data=f'gen_mode_change_type')
|
||||||
|
],
|
||||||
|
[
|
||||||
|
InlineKeyboardButton(text="❌ Выйти из режима генерации", callback_data=f'gen_mode_off'),
|
||||||
|
]
|
||||||
|
])
|
||||||
4
main.py
4
main.py
@@ -13,6 +13,7 @@ from motor.motor_asyncio import AsyncIOMotorClient
|
|||||||
|
|
||||||
# Импорты
|
# Импорты
|
||||||
from adapters.google_adapter import GoogleAdapter
|
from adapters.google_adapter import GoogleAdapter
|
||||||
|
from middlewares.album import AlbumMiddleware
|
||||||
from middlewares.auth import AuthMiddleware
|
from middlewares.auth import AuthMiddleware
|
||||||
from middlewares.dao import DaoMiddleware
|
from middlewares.dao import DaoMiddleware
|
||||||
from repos.char_repo import CharacterRepo
|
from repos.char_repo import CharacterRepo
|
||||||
@@ -64,7 +65,8 @@ dp.include_router(gen_router)
|
|||||||
# Вешаем защиту ТОЛЬКО на основной роутер
|
# Вешаем защиту ТОЛЬКО на основной роутер
|
||||||
main_router.message.middleware(AuthMiddleware(repo=users_repo, admin_id=ADMIN_ID))
|
main_router.message.middleware(AuthMiddleware(repo=users_repo, admin_id=ADMIN_ID))
|
||||||
gen_router.message.middleware(AuthMiddleware(repo=users_repo, admin_id=ADMIN_ID))
|
gen_router.message.middleware(AuthMiddleware(repo=users_repo, admin_id=ADMIN_ID))
|
||||||
char_router.message.middleware(DaoMiddleware(dao=DAO(client=mongo_client)))
|
gen_router.message.middleware(AlbumMiddleware(latency=0.8))
|
||||||
|
dp.update.middleware(DaoMiddleware(dao=DAO(client=mongo_client)))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
51
middlewares/album.py
Normal file
51
middlewares/album.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
import asyncio
|
||||||
|
from typing import Any, Dict, List, Union, Callable, Awaitable
|
||||||
|
from aiogram import BaseMiddleware
|
||||||
|
from aiogram.types import Message
|
||||||
|
|
||||||
|
|
||||||
|
class AlbumMiddleware(BaseMiddleware):
|
||||||
|
def __init__(self, latency: float = 0.5):
|
||||||
|
# latency - задержка в секундах для сбора частей альбома
|
||||||
|
self.latency = latency
|
||||||
|
self.album_data: Dict[str, List[Message]] = {}
|
||||||
|
|
||||||
|
async def __call__(
|
||||||
|
self,
|
||||||
|
handler: Callable[[Message, Dict[str, Any]], Awaitable[Any]],
|
||||||
|
event: Message,
|
||||||
|
data: Dict[str, Any]
|
||||||
|
) -> Any:
|
||||||
|
# Если у сообщения нет media_group_id, это не альбом -> пропускаем дальше как обычно
|
||||||
|
if not event.media_group_id:
|
||||||
|
return await handler(event, data)
|
||||||
|
|
||||||
|
group_id = event.media_group_id
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Если этот альбом мы еще не видели (первое сообщение из пачки)
|
||||||
|
if group_id not in self.album_data:
|
||||||
|
self.album_data[group_id] = [event] # Создаем список
|
||||||
|
await asyncio.sleep(self.latency) # Ждем остальные части
|
||||||
|
|
||||||
|
# После ожидания кладем собранный список в data
|
||||||
|
# Теперь в хендлере будет доступен аргумент 'album'
|
||||||
|
data["album"] = self.album_data[group_id]
|
||||||
|
|
||||||
|
# Вызываем хендлер ОДИН раз
|
||||||
|
return await handler(event, data)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Если альбом уже собирается, просто добавляем сообщение в список
|
||||||
|
# и НЕ вызываем хендлер (прерываем цепочку для этого сообщения)
|
||||||
|
self.album_data[group_id].append(event)
|
||||||
|
return
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Чистим память после обработки, если это был "главный" поток обработки
|
||||||
|
if group_id in self.album_data and len(self.album_data[group_id]) > 1:
|
||||||
|
# Маленький хак: удаляем только если обработчик завершился
|
||||||
|
# Проверка len нужна, чтобы не удалить раньше времени в параллельных тасках,
|
||||||
|
# но корректнее просто удалять в блоке первого сообщения.
|
||||||
|
if event == self.album_data[group_id][0]:
|
||||||
|
del self.album_data[group_id]
|
||||||
@@ -2,7 +2,8 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
|
|
||||||
class Character(BaseModel):
|
class Character(BaseModel):
|
||||||
id: int | None
|
id: str
|
||||||
name: str
|
name: str
|
||||||
character_image: bytes
|
character_image: bytes
|
||||||
character_bio: str
|
character_bio: str
|
||||||
|
|
||||||
|
|||||||
19
models/enums.py
Normal file
19
models/enums.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class AspectRatios(Enum):
|
||||||
|
NINESIXTEEN = '9:16'
|
||||||
|
SIXTEENNINE = '16:9'
|
||||||
|
THREEFOUR = '3:4'
|
||||||
|
FOURTHREE = '4:3'
|
||||||
|
|
||||||
|
|
||||||
|
class Quality(Enum):
|
||||||
|
ONEK = '1K'
|
||||||
|
TWOK = '2K'
|
||||||
|
FOURK = '4K'
|
||||||
|
|
||||||
|
|
||||||
|
class GenType(Enum):
|
||||||
|
TEXT = 'Text'
|
||||||
|
IMAGE = 'Image'
|
||||||
@@ -1,3 +1,6 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
from bson import ObjectId
|
||||||
from motor.motor_asyncio import AsyncIOMotorClient
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
|
||||||
from models.Character import Character
|
from models.Character import Character
|
||||||
@@ -12,5 +15,23 @@ class CharacterRepo:
|
|||||||
character.id = op.inserted_id
|
character.id = op.inserted_id
|
||||||
return character
|
return character
|
||||||
|
|
||||||
async def get_character(self, character_id: int) -> Character:
|
async def get_character(self, character_id: str) -> Character | None:
|
||||||
return await self.collection.find_one({"id": character_id})
|
res = await self.collection.find_one({"_id": ObjectId(character_id)})
|
||||||
|
if res is None:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
res["id"] = str(res.pop("_id"))
|
||||||
|
return Character(**res)
|
||||||
|
|
||||||
|
async def get_all_characters(self) -> List[Character]:
|
||||||
|
docs = await self.collection.find().to_list(None)
|
||||||
|
|
||||||
|
characters = []
|
||||||
|
for doc in docs:
|
||||||
|
# Конвертируем ObjectId в строку и кладем в поле id
|
||||||
|
doc["id"] = str(doc.pop("_id"))
|
||||||
|
|
||||||
|
# Создаем объект
|
||||||
|
characters.append(Character(**doc))
|
||||||
|
|
||||||
|
return characters
|
||||||
@@ -25,9 +25,11 @@ async def add_char(message: Message, state: FSMContext, dao: DAO):
|
|||||||
await message.answer("Кайф, теперь напиши ее имя")
|
await message.answer("Кайф, теперь напиши ее имя")
|
||||||
|
|
||||||
|
|
||||||
@router.callback_query(States.char_wait_name)
|
@router.message(States.char_wait_name)
|
||||||
async def new_char_name(message: Message, state: FSMContext, dao: DAO):
|
async def new_char_name(message: Message, state: FSMContext, dao: DAO):
|
||||||
await state.set_data({"name": message.text})
|
data = await state.get_data()
|
||||||
|
data["name"] = message.text
|
||||||
|
await state.set_data(data)
|
||||||
await state.set_state(States.char_wait_bio)
|
await state.set_state(States.char_wait_bio)
|
||||||
await message.answer("А теперь напиши био. Хоть чуть чуть.")
|
await message.answer("А теперь напиши био. Хоть чуть чуть.")
|
||||||
|
|
||||||
@@ -37,7 +39,7 @@ async def new_char_name(message: Message, state: FSMContext, dao: DAO):
|
|||||||
async def new_char_bio(message: Message, state: FSMContext, dao: DAO, bot: Bot):
|
async def new_char_bio(message: Message, state: FSMContext, dao: DAO, bot: Bot):
|
||||||
# Получаем все накопленные данные
|
# Получаем все накопленные данные
|
||||||
data = await state.get_data()
|
data = await state.get_data()
|
||||||
file_id = data["file_id"]
|
file_id = data["photo"]
|
||||||
name = data["name"]
|
name = data["name"]
|
||||||
bio = message.text
|
bio = message.text
|
||||||
|
|
||||||
@@ -64,9 +66,9 @@ async def new_char_bio(message: Message, state: FSMContext, dao: DAO, bot: Bot):
|
|||||||
await message.answer_photo(
|
await message.answer_photo(
|
||||||
photo=BufferedInputFile(photo_bytes, filename="char.png"),
|
photo=BufferedInputFile(photo_bytes, filename="char.png"),
|
||||||
caption=(
|
caption=(
|
||||||
"🎉 **Персонаж создан!**\n\n"
|
"🎉 <b>Персонаж создан!</b>\n\n"
|
||||||
f"👤 **Имя:** {char.name}\n"
|
f"👤 <b>Имя:</b> {char.name}\n"
|
||||||
f"📝 **Био:** {char.character_bio}"
|
f"📝 <b>Био:</b> {char.character_bio}"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
await wait_msg.delete()
|
await wait_msg.delete()
|
||||||
@@ -79,6 +81,47 @@ async def new_char_bio(message: Message, state: FSMContext, dao: DAO, bot: Bot):
|
|||||||
# Не сбрасываем стейт, даем возможность попробовать ввести био снова или начать заново
|
# Не сбрасываем стейт, даем возможность попробовать ввести био снова или начать заново
|
||||||
|
|
||||||
|
|
||||||
|
@router.message(Command("chars"))
|
||||||
|
async def get_chars(message: Message, state: FSMContext, dao: DAO):
|
||||||
|
wait_msg = await message.answer("Ищем персонажей")
|
||||||
|
chars = await dao.chars.get_all_characters()
|
||||||
|
keyboards = []
|
||||||
|
if len(chars) > 0:
|
||||||
|
for char in chars:
|
||||||
|
keyboards.append(InlineKeyboardButton(text=char.name, callback_data=f'char_info_{char.id}'))
|
||||||
|
keyboard = InlineKeyboardMarkup(inline_keyboard=[keyboards])
|
||||||
|
else:
|
||||||
|
keyboard = InlineKeyboardMarkup(
|
||||||
|
inline_keyboard=[[InlineKeyboardButton(text="Персонажей нет", callback_data=f'no_chars')]])
|
||||||
|
await message.answer("Сейчас есть такие персонажи:", reply_markup=keyboard)
|
||||||
|
await wait_msg.delete()
|
||||||
|
|
||||||
|
|
||||||
|
@router.callback_query(F.data.startswith("char_info_"))
|
||||||
|
async def get_char_info(callback_query: CallbackQuery, state: FSMContext, dao: DAO):
|
||||||
|
await callback_query.message.delete()
|
||||||
|
wait_msg = await callback_query.message.answer("Ищем инфу о персонаже")
|
||||||
|
char = await dao.chars.get_character(callback_query.data.split("_")[-1])
|
||||||
|
if char is None:
|
||||||
|
await callback_query.message.answer("Информация о персонаже не найдена")
|
||||||
|
await get_chars(callback_query.message, state, dao)
|
||||||
|
await wait_msg.delete()
|
||||||
|
return
|
||||||
|
keyboard = InlineKeyboardMarkup(inline_keyboard=[
|
||||||
|
[InlineKeyboardButton(text="Запросить фото в документе", callback_data=f'char_photo_file_{char.id}')]])
|
||||||
|
await callback_query.message.answer_photo(photo=BufferedInputFile(char.character_image, f"photo_{char.id}.png"), caption=f"👤 <b>Имя:</b> {char.name}\n"
|
||||||
|
f"📝 <b>Био:</b> {char.character_bio}",
|
||||||
|
reply_markup=keyboard)
|
||||||
|
|
||||||
|
await wait_msg.delete()
|
||||||
|
|
||||||
|
|
||||||
|
@router.callback_query(F.data.startswith("char_photo_file"))
|
||||||
|
async def get_char_info_photo_file(callback_query: CallbackQuery, state: FSMContext, dao: DAO):
|
||||||
|
char = await dao.chars.get_character(callback_query.data.split("_")[-1])
|
||||||
|
await callback_query.message.answer_document(BufferedInputFile(char.character_image, f"photo_{char.id}.png"))
|
||||||
|
|
||||||
|
|
||||||
# 4. Хендлер-помощник (если отправили команду без файла)
|
# 4. Хендлер-помощник (если отправили команду без файла)
|
||||||
@router.message(Command("add_char"))
|
@router.message(Command("add_char"))
|
||||||
async def add_char_help(message: Message):
|
async def add_char_help(message: Message):
|
||||||
|
|||||||
@@ -1,48 +1,342 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import random
|
||||||
|
from enum import Enum
|
||||||
|
from typing import List
|
||||||
|
|
||||||
from aiogram import Router, Bot, F
|
from aiogram import Router, Bot, F
|
||||||
from aiogram.enums import ParseMode
|
from aiogram.enums import ParseMode
|
||||||
|
from aiogram.exceptions import TelegramBadRequest
|
||||||
from aiogram.filters import *
|
from aiogram.filters import *
|
||||||
|
from aiogram.fsm.context import FSMContext
|
||||||
|
from aiogram.fsm.state import StatesGroup, State
|
||||||
from aiogram.types import *
|
from aiogram.types import *
|
||||||
|
from aiogram.types import message
|
||||||
|
|
||||||
|
import keyboards
|
||||||
from adapters.google_adapter import GoogleAdapter
|
from adapters.google_adapter import GoogleAdapter
|
||||||
|
from models.Character import Character
|
||||||
|
from models.enums import AspectRatios, Quality, GenType
|
||||||
|
from repos.dao import DAO
|
||||||
|
|
||||||
router = Router()
|
router = Router()
|
||||||
|
|
||||||
@router.message(Command("image"))
|
|
||||||
async def cmd_image_gen(message: Message, command: CommandObject, gemini: GoogleAdapter, bot: Bot):
|
class States(StatesGroup):
|
||||||
# ... ваш код ...
|
gen_mode_wait_char = State()
|
||||||
# Обратите внимание: gemini теперь прилетает аргументом, так как мы сделали dp["gemini"]
|
gen_mode = State()
|
||||||
prompt = command.args
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
async def init_gen_mode(state: FSMContext, dao: DAO):
|
||||||
|
data = await state.get_data()
|
||||||
|
data['aspect_ratio'] = AspectRatios.NINESIXTEEN.name
|
||||||
|
data['quality'] = Quality.ONEK.name
|
||||||
|
data['type'] = GenType.TEXT.name
|
||||||
|
await state.update_data(data)
|
||||||
|
|
||||||
|
|
||||||
|
@router.message(Command("gen_mode"))
|
||||||
|
async def gen_mode(message: Message, state: FSMContext, dao: DAO):
|
||||||
|
state_on = await state.get_state()
|
||||||
|
if state_on is None and state_on is not States.gen_mode:
|
||||||
|
await message.answer("Включить режим генерации?", reply_markup=InlineKeyboardMarkup(
|
||||||
|
inline_keyboard=[[InlineKeyboardButton(text="✅ Включить!", callback_data="gen_mode_on")]]))
|
||||||
|
else:
|
||||||
|
await gen_mode_base_msg(message, state, dao, call_type="start")
|
||||||
|
|
||||||
|
|
||||||
|
@router.callback_query(F.data == "gen_mode_on")
|
||||||
|
async def gen_mode_on(callback_query: CallbackQuery, state: FSMContext, dao: DAO):
|
||||||
|
await callback_query.message.delete()
|
||||||
|
chars = await dao.chars.get_all_characters()
|
||||||
|
if len(chars) == 0:
|
||||||
|
await callback_query.message.answer(
|
||||||
|
"Персонажи не найдены! Сперва создайте персонажа отправив его фото ФАЙЛОМ с командой /add_char")
|
||||||
|
keyboards = []
|
||||||
|
for char in chars:
|
||||||
|
keyboards.append(InlineKeyboardButton(text=char.name, callback_data=f'select_char_{char.id}'))
|
||||||
|
await state.set_state(States.gen_mode_wait_char)
|
||||||
|
await callback_query.message.answer("Выбери персонажа",
|
||||||
|
reply_markup=InlineKeyboardMarkup(inline_keyboard=[keyboards]))
|
||||||
|
|
||||||
|
|
||||||
|
@router.callback_query(States.gen_mode_wait_char, F.data.startswith("select_char_"))
|
||||||
|
async def select_char(call: CallbackQuery, state: FSMContext, dao: DAO):
|
||||||
|
await state.update_data({"char_id": call.data.split("_")[-1]})
|
||||||
|
await init_gen_mode(state=state, dao=dao)
|
||||||
|
await state.set_state(States.gen_mode)
|
||||||
|
await gen_mode_base_msg(call.message, state=state, dao=dao, call_type="start")
|
||||||
|
|
||||||
|
|
||||||
|
@router.callback_query(States.gen_mode, F.data == 'gen_mode_off')
|
||||||
|
async def gen_mode_off(call: CallbackQuery, state: FSMContext, dao: DAO):
|
||||||
|
await state.clear()
|
||||||
|
await state.set_data({})
|
||||||
|
await call.message.delete()
|
||||||
|
await call.message.answer("Режим генерации выключен. Нажмите /start для продолжения!")
|
||||||
|
|
||||||
|
|
||||||
|
@router.callback_query(States.gen_mode, F.data == 'gen_mode_change_char')
|
||||||
|
async def gen_mode_change_char(call: CallbackQuery, state: FSMContext, dao: DAO):
|
||||||
|
chars = await dao.chars.get_all_characters()
|
||||||
|
if len(chars) == 0:
|
||||||
|
await call.message.edit_caption(
|
||||||
|
"Персонажи не найдены! Сперва создайте персонажа отправив его фото ФАЙЛОМ с командой /add_char",
|
||||||
|
reply_markup=None)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
keyboards = []
|
||||||
|
for char in chars:
|
||||||
|
keyboards.append(InlineKeyboardButton(text=char.name, callback_data=f'select_new_char_{char.id}'))
|
||||||
|
keyboards.append(InlineKeyboardButton(text="⬅️ Назад", callback_data="gen_mode_cancel_char_change"))
|
||||||
|
await call.message.edit_caption("Выбери персонажа", reply_markup=InlineKeyboardMarkup(inline_keyboard=[keyboards]))
|
||||||
|
|
||||||
|
|
||||||
|
@router.callback_query(States.gen_mode, F.data.startswith('select_new_char_'))
|
||||||
|
async def change_char(call: CallbackQuery, state: FSMContext, dao: DAO):
|
||||||
|
await state.update_data({"char_id": call.data.split("_")[-1]})
|
||||||
|
await gen_mode_base_msg(call.message, state=state, dao=dao)
|
||||||
|
|
||||||
|
|
||||||
|
@router.callback_query(States.gen_mode, F.data == 'gen_mode_change_aspect_ratio')
|
||||||
|
async def gen_mode_change_aspect_ratio(call: CallbackQuery, state: FSMContext, dao: DAO):
|
||||||
|
keyboards = []
|
||||||
|
for ratio in AspectRatios:
|
||||||
|
keyboards.append(InlineKeyboardButton(text=ratio.value, callback_data=f'select_ratio_{ratio.name}'))
|
||||||
|
await call.message.edit_caption(caption="Выбери соотношение сторон",
|
||||||
|
reply_markup=InlineKeyboardMarkup(inline_keyboard=[keyboards, [InlineKeyboardButton(
|
||||||
|
text="⬅️ Назад", callback_data="gen_mode_cancel_ratio_change")]]))
|
||||||
|
|
||||||
|
|
||||||
|
@router.callback_query(States.gen_mode, F.data.startswith('select_ratio_'))
|
||||||
|
async def change_aspect_ratio(call: CallbackQuery, state: FSMContext, dao: DAO):
|
||||||
|
await state.update_data({"aspect_ratio": call.data.split("_")[-1]})
|
||||||
|
await gen_mode_base_msg(call.message, state=state, dao=dao)
|
||||||
|
|
||||||
|
|
||||||
|
@router.callback_query(States.gen_mode, F.data == 'gen_mode_change_quality')
|
||||||
|
async def gen_mode_change_quality(call: CallbackQuery, state: FSMContext, dao: DAO):
|
||||||
|
keyboards = []
|
||||||
|
for quality in Quality:
|
||||||
|
keyboards.append(InlineKeyboardButton(text=quality.value, callback_data=f'select_quality_{quality.name}'))
|
||||||
|
await call.message.edit_caption(caption="Выбери качество",
|
||||||
|
reply_markup=InlineKeyboardMarkup(inline_keyboard=[keyboards, [InlineKeyboardButton(
|
||||||
|
text="⬅️ Назад", callback_data="gen_mode_cancel_quality_change")]]))
|
||||||
|
|
||||||
|
|
||||||
|
@router.callback_query(States.gen_mode, F.data.startswith('select_quality_'))
|
||||||
|
async def change_quality(call: CallbackQuery, state: FSMContext, dao: DAO):
|
||||||
|
await state.update_data({"quality": call.data.split("_")[-1]})
|
||||||
|
await gen_mode_base_msg(call.message, state=state, dao=dao)
|
||||||
|
|
||||||
|
|
||||||
|
@router.callback_query(States.gen_mode, F.data == 'gen_mode_change_type')
|
||||||
|
async def gen_mode_change_type(call: CallbackQuery, state: FSMContext, dao: DAO):
|
||||||
|
keyboards = []
|
||||||
|
for gen_type in GenType:
|
||||||
|
keyboards.append(InlineKeyboardButton(text=gen_type.value, callback_data=f'select_type_{gen_type.name}'))
|
||||||
|
await call.message.edit_caption(caption="Выбери тип генерации", reply_markup=InlineKeyboardMarkup(
|
||||||
|
inline_keyboard=[keyboards,
|
||||||
|
[InlineKeyboardButton(text="⬅️ Назад", callback_data="gen_mode_cancel_type_change")]]))
|
||||||
|
|
||||||
|
|
||||||
|
@router.callback_query(States.gen_mode, F.data.startswith('select_type_'))
|
||||||
|
async def change_quality(call: CallbackQuery, state: FSMContext, dao: DAO):
|
||||||
|
await state.update_data({"type": call.data.split("_")[-1]})
|
||||||
|
await gen_mode_base_msg(call.message, state=state, dao=dao)
|
||||||
|
|
||||||
|
|
||||||
|
@router.callback_query(States.gen_mode, F.data == 'gen_mode_cancel_char_change')
|
||||||
|
@router.callback_query(States.gen_mode, F.data == 'gen_mode_cancel_ratio_change')
|
||||||
|
@router.callback_query(States.gen_mode, F.data == 'gen_mode_cancel_type_change')
|
||||||
|
@router.callback_query(States.gen_mode, F.data == 'gen_mode_cancel_quality_change')
|
||||||
|
async def cancel_gen_mode_change(call: CallbackQuery, state: FSMContext, dao: DAO):
|
||||||
|
await gen_mode_base_msg(call.message, state=state, dao=dao)
|
||||||
|
|
||||||
|
|
||||||
|
async def gen_mode_base_msg(message: Message, state: FSMContext, dao: DAO, call_type='continue'):
|
||||||
|
data = await state.get_data()
|
||||||
|
char: Character = await dao.chars.get_character(data["char_id"])
|
||||||
|
if call_type == "start":
|
||||||
|
await message.answer_photo(BufferedInputFile(char.character_image, f'{char.id}.png'),
|
||||||
|
caption="🎉 Режим генерации включен! Просто пиши мне промпт и я отправлю в генерацию по указанным настройкам.\n\n"
|
||||||
|
"<b>Фото девушки грузить не надо, оно загрузится по дефолту</b>\n\n"
|
||||||
|
"Но дополнительные фото можно загрузить.",
|
||||||
|
reply_markup=await keyboards.get_gen_mode_kb(state=state, dao=dao))
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
await message.edit_caption(
|
||||||
|
caption="🎉 Режим генерации включен!",
|
||||||
|
reply_markup=await keyboards.get_gen_mode_kb(state=state, dao=dao))
|
||||||
|
except TelegramBadRequest as tbr:
|
||||||
|
await message.edit_text(
|
||||||
|
text="🎉 Режим генерации включен!",
|
||||||
|
reply_markup=await keyboards.get_gen_mode_kb(state=state, dao=dao))
|
||||||
|
|
||||||
|
|
||||||
|
@router.message(States.gen_mode, F.media_group_id)
|
||||||
|
async def handle_album(
|
||||||
|
message: Message,
|
||||||
|
album: List[Message],
|
||||||
|
state: FSMContext,
|
||||||
|
dao: DAO,
|
||||||
|
gemini: GoogleAdapter,
|
||||||
|
bot: Bot
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Обработка альбома (группы фото).
|
||||||
|
"""
|
||||||
|
# 1. Ищем промпт (подпись) в любом из сообщений альбома
|
||||||
|
# message.text в альбомах обычно None
|
||||||
|
prompt = next((msg.caption for msg in album if msg.caption), None)
|
||||||
|
|
||||||
if not prompt:
|
if not prompt:
|
||||||
await message.answer("⚠️ Напиши промпт.")
|
await message.answer("⚠️ Вы отправили альбом, но не добавили описание (промпт).")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# 2. Собираем file_id всех фото
|
||||||
|
file_ids = []
|
||||||
|
for msg in album:
|
||||||
|
if msg.photo:
|
||||||
|
file_ids.append(msg.photo[-1].file_id)
|
||||||
|
elif msg.video:
|
||||||
|
# Если нужно, можно добавить обработку видео (пока пропускаем)
|
||||||
|
pass
|
||||||
|
|
||||||
|
await message.answer(f"📥 Принято {len(album)} файлов. Начинаю генерацию...")
|
||||||
wait_msg = await message.answer("🎨 Генерирую...")
|
wait_msg = await message.answer("🎨 Генерирую...")
|
||||||
|
|
||||||
# Получение байтов фото (логика та же)
|
# 3. Вызываем генерацию
|
||||||
image_bytes = None
|
try:
|
||||||
if message.photo:
|
generated_files = await generate_image(
|
||||||
file_io = await bot.download(message.photo[-1].file_id)
|
prompt=prompt,
|
||||||
image_bytes = file_io.getvalue()
|
media=file_ids,
|
||||||
elif message.reply_to_message and message.reply_to_message.photo:
|
state=state,
|
||||||
file_io = await bot.download(message.reply_to_message.photo[-1].file_id)
|
dao=dao,
|
||||||
image_bytes = file_io.getvalue()
|
bot=bot,
|
||||||
|
gemini=gemini
|
||||||
result = await asyncio.to_thread(
|
|
||||||
gemini.generate, prompt=prompt, image_bytes=image_bytes, generate_image=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
await wait_msg.delete()
|
await wait_msg.delete()
|
||||||
|
|
||||||
if result.get("images"):
|
# 4. Отправляем результат
|
||||||
for img in result["images"]:
|
if generated_files:
|
||||||
await message.answer_document(BufferedInputFile(img.read(), "img.png"))
|
for file in generated_files:
|
||||||
elif result.get("text"):
|
await message.answer_document(file, caption="✨ Generated result")
|
||||||
await message.answer(result["text"])
|
|
||||||
else:
|
else:
|
||||||
await message.answer(f"Ошибка: {result.get('error', 'Unknown')}")
|
await message.answer("❌ Генерация не вернула изображений.")
|
||||||
|
await gen_mode_base_msg(message=message, state=state, dao=dao,call_type="start" )
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
await wait_msg.edit_text(f"❌ Ошибка: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.message(States.gen_mode)
|
||||||
|
async def gen_mode_start(
|
||||||
|
message: Message,
|
||||||
|
state: FSMContext,
|
||||||
|
dao: DAO,
|
||||||
|
gemini: GoogleAdapter,
|
||||||
|
bot: Bot
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Обработка одиночного сообщения (Текст или Фото+Подпись)
|
||||||
|
"""
|
||||||
|
# 1. Получаем промпт (Текст или Подпись)
|
||||||
|
prompt = message.text or message.caption
|
||||||
|
|
||||||
|
if not prompt:
|
||||||
|
await message.answer("⚠️ Напиши промпт (или отправь фото с подписью).")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 2. Проверяем, есть ли прикрепленное фото
|
||||||
|
media_ids = []
|
||||||
|
if message.photo:
|
||||||
|
media_ids.append(message.photo[-1].file_id)
|
||||||
|
elif message.reply_to_message and message.reply_to_message.photo:
|
||||||
|
# Поддержка реплая на фото
|
||||||
|
media_ids.append(message.reply_to_message.photo[-1].file_id)
|
||||||
|
|
||||||
|
wait_msg = await message.answer("🎨 Генерирую...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
generated_files = await generate_image(
|
||||||
|
prompt=prompt,
|
||||||
|
media=media_ids, # Передаем список (пустой или с 1 фото)
|
||||||
|
state=state,
|
||||||
|
dao=dao,
|
||||||
|
bot=bot,
|
||||||
|
gemini=gemini
|
||||||
|
)
|
||||||
|
|
||||||
|
await wait_msg.delete()
|
||||||
|
|
||||||
|
if generated_files:
|
||||||
|
for file in generated_files:
|
||||||
|
await message.answer_document(file, caption="✨ Generated result")
|
||||||
|
else:
|
||||||
|
await message.answer("❌ Ничего не сгенерировалось.")
|
||||||
|
await gen_mode_base_msg(message=message, state=state, dao=dao,call_type="start" )
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
await wait_msg.edit_text(f"❌ Ошибка: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_image(
|
||||||
|
prompt: str,
|
||||||
|
media: List[str] | None,
|
||||||
|
state: FSMContext,
|
||||||
|
dao: DAO,
|
||||||
|
bot: Bot,
|
||||||
|
gemini: GoogleAdapter
|
||||||
|
) -> List[BufferedInputFile]:
|
||||||
|
# 1. Получаем данные персонажа
|
||||||
|
data = await state.get_data()
|
||||||
|
char_id = data.get("char_id")
|
||||||
|
|
||||||
|
if not char_id:
|
||||||
|
raise ValueError("Character ID not found in state")
|
||||||
|
|
||||||
|
char: Character = await dao.chars.get_character(char_id)
|
||||||
|
|
||||||
|
# Начинаем список с фото персонажа
|
||||||
|
media_group_bytes = [char.character_image]
|
||||||
|
|
||||||
|
# 2. Скачиваем дополнительные файлы (если переданы)
|
||||||
|
if media:
|
||||||
|
# Создаем задачи для скачивания
|
||||||
|
tasks = [bot.download(file_id) for file_id in media]
|
||||||
|
|
||||||
|
# Запускаем все скачивания параллельно
|
||||||
|
downloaded_files = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
# Добавляем байты скачанных файлов в общий список
|
||||||
|
for f in downloaded_files:
|
||||||
|
media_group_bytes.append(f.getvalue())
|
||||||
|
|
||||||
|
# 3. Генерация в Gemini
|
||||||
|
generated_images_io = await asyncio.to_thread(
|
||||||
|
gemini.generate_image,
|
||||||
|
prompt=prompt,
|
||||||
|
images_list=media_group_bytes,
|
||||||
|
aspect_ratio=AspectRatios[data['aspect_ratio']],
|
||||||
|
quality=Quality[data['quality']],
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Упаковка результата
|
||||||
|
images = []
|
||||||
|
if generated_images_io:
|
||||||
|
for i, img_io in enumerate(generated_images_io):
|
||||||
|
# Важно: img_io.read() работает корректно, если курсор в начале (adapter это делает)
|
||||||
|
images.append(
|
||||||
|
BufferedInputFile(
|
||||||
|
img_io.read(),
|
||||||
|
filename=f"img_{random.randint(1000, 9999)}.png"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return images
|
||||||
|
|
||||||
@router.message(F.text)
|
@router.message(F.text)
|
||||||
async def handle_text(message: Message, gemini: GoogleAdapter, bot: Bot):
|
async def handle_text(message: Message, gemini: GoogleAdapter, bot: Bot):
|
||||||
@@ -52,4 +346,3 @@ async def handle_text(message: Message, gemini: GoogleAdapter, bot: Bot):
|
|||||||
await message.answer(result["text"], parse_mode=ParseMode.MARKDOWN)
|
await message.answer(result["text"], parse_mode=ParseMode.MARKDOWN)
|
||||||
else:
|
else:
|
||||||
await message.answer("Ошибка генерации")
|
await message.answer("Ошибка генерации")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user