+ assets
This commit is contained in:
3
main.py
3
main.py
@@ -24,6 +24,7 @@ from routers import char_router
|
|||||||
from routers.auth_router import router as auth_router
|
from routers.auth_router import router as auth_router
|
||||||
from routers.gen_router import router as gen_router
|
from routers.gen_router import router as gen_router
|
||||||
from routers.char_router import router as char_router
|
from routers.char_router import router as char_router
|
||||||
|
from routers.assets_router import router as assets_router
|
||||||
|
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
@@ -57,6 +58,7 @@ dp["gemini"] = GoogleAdapter(api_key=GEMINI_API_KEY) # Инициализиру
|
|||||||
dp.include_router(auth_router)
|
dp.include_router(auth_router)
|
||||||
main_router = Router()
|
main_router = Router()
|
||||||
dp.include_router(main_router)
|
dp.include_router(main_router)
|
||||||
|
dp.include_router(assets_router)
|
||||||
dp.include_router(char_router)
|
dp.include_router(char_router)
|
||||||
dp.include_router(gen_router)
|
dp.include_router(gen_router)
|
||||||
|
|
||||||
@@ -66,6 +68,7 @@ dp.include_router(gen_router)
|
|||||||
main_router.message.middleware(AuthMiddleware(repo=users_repo, admin_id=ADMIN_ID))
|
main_router.message.middleware(AuthMiddleware(repo=users_repo, admin_id=ADMIN_ID))
|
||||||
gen_router.message.middleware(AuthMiddleware(repo=users_repo, admin_id=ADMIN_ID))
|
gen_router.message.middleware(AuthMiddleware(repo=users_repo, admin_id=ADMIN_ID))
|
||||||
gen_router.message.middleware(AlbumMiddleware(latency=0.8))
|
gen_router.message.middleware(AlbumMiddleware(latency=0.8))
|
||||||
|
assets_router.message.middleware(AuthMiddleware(repo=users_repo, admin_id=ADMIN_ID))
|
||||||
dp.update.middleware(DaoMiddleware(dao=DAO(client=mongo_client)))
|
dp.update.middleware(DaoMiddleware(dao=DAO(client=mongo_client)))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
19
models/Asset.py
Normal file
19
models/Asset.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
from datetime import datetime, UTC
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
class AssetType(str, Enum):
|
||||||
|
IMAGE = 'image'
|
||||||
|
PROMPT = 'prompt'
|
||||||
|
|
||||||
|
class Asset(BaseModel):
|
||||||
|
id: Optional[str] = None
|
||||||
|
name: str
|
||||||
|
type: AssetType
|
||||||
|
linked_char_id: Optional[str] = None
|
||||||
|
data: Optional[bytes] = None
|
||||||
|
tg_doc_file_id: str
|
||||||
|
tg_photo_file_id: Optional[str] = None
|
||||||
|
created_at: datetime = datetime.now(UTC)
|
||||||
39
repos/assets_repo.py
Normal file
39
repos/assets_repo.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
from bson import ObjectId
|
||||||
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
|
||||||
|
from models.Asset import Asset
|
||||||
|
|
||||||
|
|
||||||
|
class AssetsRepo:
|
||||||
|
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
|
||||||
|
self.collection = client[db_name]["assets"]
|
||||||
|
|
||||||
|
async def save_asset(self, asset: Asset) -> Asset:
|
||||||
|
res = await self.collection.insert_one(asset.model_dump())
|
||||||
|
asset.id = res.inserted_id
|
||||||
|
return asset
|
||||||
|
|
||||||
|
async def get_assets(self, limit: int = 10, offset: int = 0) -> List[Asset]:
|
||||||
|
res = await self.collection.find({},{"data":0}).sort("created_at", -1).skip(offset).limit(limit).to_list(None)
|
||||||
|
assets = []
|
||||||
|
for doc in res:
|
||||||
|
# Конвертируем ObjectId в строку и кладем в поле id
|
||||||
|
doc["id"] = str(doc.pop("_id"))
|
||||||
|
|
||||||
|
# Создаем объект
|
||||||
|
assets.append(Asset(**doc))
|
||||||
|
|
||||||
|
return assets
|
||||||
|
|
||||||
|
async def get_asset(self, asset_id: str) -> Asset:
|
||||||
|
res = await self.collection.find_one({"_id": ObjectId(asset_id)})
|
||||||
|
res["id"] = str(res.pop("_id"))
|
||||||
|
return Asset(**res)
|
||||||
|
|
||||||
|
|
||||||
|
async def update_asset(self, asset_id: str, asset: Asset):
|
||||||
|
if not asset.id:
|
||||||
|
raise Exception(f"Asset ID not found: {asset_id}")
|
||||||
|
await self.collection.update_one({"_id": ObjectId(asset_id)}, {"$set": asset.model_dump()})
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
from motor.motor_asyncio import AsyncIOMotorClient
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
|
||||||
|
from repos.assets_repo import AssetsRepo
|
||||||
from repos.char_repo import CharacterRepo
|
from repos.char_repo import CharacterRepo
|
||||||
from repos.user_repo import UsersRepo
|
from repos.user_repo import UsersRepo
|
||||||
|
|
||||||
@@ -7,3 +8,4 @@ from repos.user_repo import UsersRepo
|
|||||||
class DAO:
|
class DAO:
|
||||||
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
|
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
|
||||||
self.chars = CharacterRepo(client, db_name)
|
self.chars = CharacterRepo(client, db_name)
|
||||||
|
self.assets = AssetsRepo(client, db_name)
|
||||||
|
|||||||
26
routers/assets_router.py
Normal file
26
routers/assets_router.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
from aiogram import Router
|
||||||
|
from aiogram.filters import Command
|
||||||
|
from aiogram.types import Message, InputMediaPhoto, InputMedia, BufferedInputFile
|
||||||
|
|
||||||
|
from repos.dao import DAO
|
||||||
|
|
||||||
|
router = Router()
|
||||||
|
|
||||||
|
|
||||||
|
@router.message(Command("assets"))
|
||||||
|
async def assets_command(msg: Message, dao: DAO):
|
||||||
|
assets = await dao.assets.get_assets(limit=10, offset=0)
|
||||||
|
media_group = []
|
||||||
|
for asset in assets:
|
||||||
|
if asset.tg_photo_file_id:
|
||||||
|
media_group.append(InputMediaPhoto(media=asset.tg_photo_file_id))
|
||||||
|
elif asset.tg_doc_file_id:
|
||||||
|
asset_full_info = await dao.assets.get_asset(asset.id)
|
||||||
|
media_group.append(InputMediaPhoto(media=BufferedInputFile(asset_full_info.data, asset_full_info.name)))
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
mg = await msg.answer_media_group(media_group)
|
||||||
|
for media_index, media in enumerate(mg):
|
||||||
|
if assets[media_index].tg_photo_file_id is None:
|
||||||
|
assets[media_index].tg_photo_file_id = media.photo[-1].file_id
|
||||||
|
await dao.assets.update_asset(assets[media_index].id, assets[media_index])
|
||||||
@@ -14,6 +14,7 @@ from aiogram.types import *
|
|||||||
|
|
||||||
import keyboards
|
import keyboards
|
||||||
from adapters.google_adapter import GoogleAdapter
|
from adapters.google_adapter import GoogleAdapter
|
||||||
|
from models.Asset import Asset, AssetType
|
||||||
from models.Character import Character
|
from models.Character import Character
|
||||||
from models.enums import AspectRatios, Quality, GenType
|
from models.enums import AspectRatios, Quality, GenType
|
||||||
from repos.dao import DAO
|
from repos.dao import DAO
|
||||||
@@ -37,16 +38,20 @@ async def init_gen_mode(state: FSMContext, dao: DAO):
|
|||||||
@router.message(Command("image"))
|
@router.message(Command("image"))
|
||||||
async def generate_image_cmd(message: Message, state: FSMContext, dao: DAO, gemini: GoogleAdapter, bot: Bot):
|
async def generate_image_cmd(message: Message, state: FSMContext, dao: DAO, gemini: GoogleAdapter, bot: Bot):
|
||||||
wait_msg = await message.answer("Генерирую...")
|
wait_msg = await message.answer("Генерирую...")
|
||||||
|
|
||||||
if message.photo:
|
if message.photo:
|
||||||
res = await generate_image(prompt=message.text, media=[message.photo[0].file_id], state=state, dao=dao, bot=bot,
|
res = await generate_image(prompt=message.caption, media=[message.photo[0].file_id], state=state, dao=dao,
|
||||||
|
bot=bot,
|
||||||
gemini=gemini)
|
gemini=gemini)
|
||||||
await wait_msg.delete()
|
|
||||||
await message.answer_document(res[0], caption="Generated result 💫")
|
|
||||||
else:
|
else:
|
||||||
res = await generate_image(prompt=message.text, media=None, state=state, dao=dao, bot=bot,
|
res = await generate_image(prompt=message.text, media=None, state=state, dao=dao, bot=bot,
|
||||||
gemini=gemini)
|
gemini=gemini)
|
||||||
await wait_msg.delete()
|
await wait_msg.delete()
|
||||||
await message.answer_document(res[0], caption="Generated result 💫")
|
doc = await message.answer_document(res[0], caption="Generated result 💫")
|
||||||
|
await dao.assets.save_asset(Asset(id=None, name=res[0].filename, type=AssetType.IMAGE, data=res[0].data,
|
||||||
|
tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None, linked_char_id=None))
|
||||||
|
|
||||||
|
|
||||||
@router.message(Command("gen_mode"))
|
@router.message(Command("gen_mode"))
|
||||||
@@ -250,11 +255,14 @@ async def handle_album(
|
|||||||
)
|
)
|
||||||
|
|
||||||
await wait_msg.delete()
|
await wait_msg.delete()
|
||||||
|
data = await state.get_data()
|
||||||
# 4. Отправляем результат
|
# 4. Отправляем результат
|
||||||
if generated_files:
|
if generated_files:
|
||||||
for file in generated_files:
|
for file in generated_files:
|
||||||
await message.answer_document(file, caption="✨ Generated result")
|
doc = await message.answer_document(file, caption="✨ Generated result")
|
||||||
|
await dao.assets.save_asset(Asset(id=None, name=file.filename, type=AssetType.IMAGE, data=file.data,
|
||||||
|
tg_doc_file_id = doc.document.file_id, tg_photo_file_id = None,
|
||||||
|
linked_char_id = data["char_id"]))
|
||||||
else:
|
else:
|
||||||
await message.answer("❌ Генерация не вернула изображений.")
|
await message.answer("❌ Генерация не вернула изображений.")
|
||||||
await gen_mode_base_msg(message=message, state=state, dao=dao, call_type="start")
|
await gen_mode_base_msg(message=message, state=state, dao=dao, call_type="start")
|
||||||
@@ -290,7 +298,6 @@ async def gen_mode_start(
|
|||||||
media_ids.append(message.reply_to_message.photo[-1].file_id)
|
media_ids.append(message.reply_to_message.photo[-1].file_id)
|
||||||
|
|
||||||
wait_msg = await message.answer("🎨 Генерирую...")
|
wait_msg = await message.answer("🎨 Генерирую...")
|
||||||
|
|
||||||
data = await state.get_data()
|
data = await state.get_data()
|
||||||
try:
|
try:
|
||||||
if GenType[data['type']] is GenType.IMAGE:
|
if GenType[data['type']] is GenType.IMAGE:
|
||||||
@@ -307,7 +314,10 @@ async def gen_mode_start(
|
|||||||
|
|
||||||
if generated_files:
|
if generated_files:
|
||||||
for file in generated_files:
|
for file in generated_files:
|
||||||
await message.answer_document(file, caption="✨ Generated result")
|
doc = await message.answer_document(file, caption="✨ Generated result")
|
||||||
|
await dao.assets.save_asset(Asset(id=None, name=file.filename, type=AssetType.IMAGE, data=file.data,
|
||||||
|
tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None,
|
||||||
|
linked_char_id=data["char_id"]))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
await message.answer("❌ Ничего не сгенерировалось.")
|
await message.answer("❌ Ничего не сгенерировалось.")
|
||||||
|
|||||||
Reference in New Issue
Block a user