diff --git a/main.py b/main.py index a117e91..c63e858 100644 --- a/main.py +++ b/main.py @@ -24,6 +24,7 @@ from routers import char_router from routers.auth_router import router as auth_router from routers.gen_router import router as gen_router from routers.char_router import router as char_router +from routers.assets_router import router as assets_router load_dotenv() @@ -57,6 +58,7 @@ dp["gemini"] = GoogleAdapter(api_key=GEMINI_API_KEY) # Инициализиру dp.include_router(auth_router) main_router = Router() dp.include_router(main_router) +dp.include_router(assets_router) dp.include_router(char_router) dp.include_router(gen_router) @@ -66,6 +68,7 @@ dp.include_router(gen_router) main_router.message.middleware(AuthMiddleware(repo=users_repo, admin_id=ADMIN_ID)) gen_router.message.middleware(AuthMiddleware(repo=users_repo, admin_id=ADMIN_ID)) gen_router.message.middleware(AlbumMiddleware(latency=0.8)) +assets_router.message.middleware(AuthMiddleware(repo=users_repo, admin_id=ADMIN_ID)) dp.update.middleware(DaoMiddleware(dao=DAO(client=mongo_client))) diff --git a/models/Asset.py b/models/Asset.py new file mode 100644 index 0000000..6e81d5d --- /dev/null +++ b/models/Asset.py @@ -0,0 +1,19 @@ +from datetime import datetime, UTC +from enum import Enum +from typing import Optional + +from pydantic import BaseModel + +class AssetType(str, Enum): + IMAGE = 'image' + PROMPT = 'prompt' + +class Asset(BaseModel): + id: Optional[str] = None + name: str + type: AssetType + linked_char_id: Optional[str] = None + data: Optional[bytes] = None + tg_doc_file_id: str + tg_photo_file_id: Optional[str] = None + created_at: datetime = datetime.now(UTC) diff --git a/repos/assets_repo.py b/repos/assets_repo.py new file mode 100644 index 0000000..58d45e0 --- /dev/null +++ b/repos/assets_repo.py @@ -0,0 +1,39 @@ +from typing import List + +from bson import ObjectId +from motor.motor_asyncio import AsyncIOMotorClient + +from models.Asset import Asset + + +class AssetsRepo: + def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"): + self.collection = client[db_name]["assets"] + + async def save_asset(self, asset: Asset) -> Asset: + res = await self.collection.insert_one(asset.model_dump()) + asset.id = res.inserted_id + return asset + + async def get_assets(self, limit: int = 10, offset: int = 0) -> List[Asset]: + res = await self.collection.find({},{"data":0}).sort("created_at", -1).skip(offset).limit(limit).to_list(None) + assets = [] + for doc in res: + # Конвертируем ObjectId в строку и кладем в поле id + doc["id"] = str(doc.pop("_id")) + + # Создаем объект + assets.append(Asset(**doc)) + + return assets + + async def get_asset(self, asset_id: str) -> Asset: + res = await self.collection.find_one({"_id": ObjectId(asset_id)}) + res["id"] = str(res.pop("_id")) + return Asset(**res) + + + async def update_asset(self, asset_id: str, asset: Asset): + if not asset.id: + raise Exception(f"Asset ID not found: {asset_id}") + await self.collection.update_one({"_id": ObjectId(asset_id)}, {"$set": asset.model_dump()}) diff --git a/repos/dao.py b/repos/dao.py index 65fc769..0034d15 100644 --- a/repos/dao.py +++ b/repos/dao.py @@ -1,5 +1,6 @@ from motor.motor_asyncio import AsyncIOMotorClient +from repos.assets_repo import AssetsRepo from repos.char_repo import CharacterRepo from repos.user_repo import UsersRepo @@ -7,3 +8,4 @@ from repos.user_repo import UsersRepo class DAO: def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"): self.chars = CharacterRepo(client, db_name) + self.assets = AssetsRepo(client, db_name) diff --git a/routers/assets_router.py b/routers/assets_router.py new file mode 100644 index 0000000..9ab7e5e --- /dev/null +++ b/routers/assets_router.py @@ -0,0 +1,26 @@ +from aiogram import Router +from aiogram.filters import Command +from aiogram.types import Message, InputMediaPhoto, InputMedia, BufferedInputFile + +from repos.dao import DAO + +router = Router() + + +@router.message(Command("assets")) +async def assets_command(msg: Message, dao: DAO): + assets = await dao.assets.get_assets(limit=10, offset=0) + media_group = [] + for asset in assets: + if asset.tg_photo_file_id: + media_group.append(InputMediaPhoto(media=asset.tg_photo_file_id)) + elif asset.tg_doc_file_id: + asset_full_info = await dao.assets.get_asset(asset.id) + media_group.append(InputMediaPhoto(media=BufferedInputFile(asset_full_info.data, asset_full_info.name))) + else: + continue + mg = await msg.answer_media_group(media_group) + for media_index, media in enumerate(mg): + if assets[media_index].tg_photo_file_id is None: + assets[media_index].tg_photo_file_id = media.photo[-1].file_id + await dao.assets.update_asset(assets[media_index].id, assets[media_index]) diff --git a/routers/gen_router.py b/routers/gen_router.py index 676613c..7fa006d 100644 --- a/routers/gen_router.py +++ b/routers/gen_router.py @@ -14,6 +14,7 @@ from aiogram.types import * import keyboards from adapters.google_adapter import GoogleAdapter +from models.Asset import Asset, AssetType from models.Character import Character from models.enums import AspectRatios, Quality, GenType from repos.dao import DAO @@ -37,16 +38,20 @@ async def init_gen_mode(state: FSMContext, dao: DAO): @router.message(Command("image")) async def generate_image_cmd(message: Message, state: FSMContext, dao: DAO, gemini: GoogleAdapter, bot: Bot): wait_msg = await message.answer("Генерирую...") + if message.photo: - res = await generate_image(prompt=message.text, media=[message.photo[0].file_id], state=state, dao=dao, bot=bot, + res = await generate_image(prompt=message.caption, media=[message.photo[0].file_id], state=state, dao=dao, + bot=bot, gemini=gemini) - await wait_msg.delete() - await message.answer_document(res[0], caption="Generated result 💫") + + else: res = await generate_image(prompt=message.text, media=None, state=state, dao=dao, bot=bot, gemini=gemini) - await wait_msg.delete() - await message.answer_document(res[0], caption="Generated result 💫") + await wait_msg.delete() + doc = await message.answer_document(res[0], caption="Generated result 💫") + await dao.assets.save_asset(Asset(id=None, name=res[0].filename, type=AssetType.IMAGE, data=res[0].data, + tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None, linked_char_id=None)) @router.message(Command("gen_mode")) @@ -250,14 +255,17 @@ async def handle_album( ) await wait_msg.delete() - + data = await state.get_data() # 4. Отправляем результат if generated_files: for file in generated_files: - await message.answer_document(file, caption="✨ Generated result") + doc = await message.answer_document(file, caption="✨ Generated result") + await dao.assets.save_asset(Asset(id=None, name=file.filename, type=AssetType.IMAGE, data=file.data, + tg_doc_file_id = doc.document.file_id, tg_photo_file_id = None, + linked_char_id = data["char_id"])) else: await message.answer("❌ Генерация не вернула изображений.") - await gen_mode_base_msg(message=message, state=state, dao=dao, call_type="start") + await gen_mode_base_msg(message=message, state=state, dao=dao, call_type="start") except Exception as e: await wait_msg.edit_text(f"❌ Ошибка: {e}") @@ -290,7 +298,6 @@ async def gen_mode_start( media_ids.append(message.reply_to_message.photo[-1].file_id) wait_msg = await message.answer("🎨 Генерирую...") - data = await state.get_data() try: if GenType[data['type']] is GenType.IMAGE: @@ -307,7 +314,10 @@ async def gen_mode_start( if generated_files: for file in generated_files: - await message.answer_document(file, caption="✨ Generated result") + doc = await message.answer_document(file, caption="✨ Generated result") + await dao.assets.save_asset(Asset(id=None, name=file.filename, type=AssetType.IMAGE, data=file.data, + tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None, + linked_char_id=data["char_id"])) else: await message.answer("❌ Ничего не сгенерировалось.")