From 35de8efc565f627be6b1a4af5b79358a848ceb43 Mon Sep 17 00:00:00 2001 From: xds Date: Wed, 4 Feb 2026 15:10:55 +0300 Subject: [PATCH] + api --- adapters/google_adapter.py | 4 +- api/dependency.py | 30 +++++ api/endpoints/assets.py | 92 -------------- api/endpoints/assets_router.py | 74 ++++++++++++ api/endpoints/character_router.py | 29 +++-- api/endpoints/generation_router.py | 47 ++++++++ api/models/AssetDTO.py | 19 +++ api/models/GenerationRequest.py | 39 ++++++ api/models/__init__.py | 0 api/service/__init__.py | 0 api/service/generation_service.py | 186 +++++++++++++++++++++++++++++ main.py | 17 ++- models/Asset.py | 5 +- models/Generation.py | 27 +++++ models/enums.py | 46 +++++-- repos/assets_repo.py | 23 +++- repos/dao.py | 2 + repos/generation_repo.py | 43 +++++++ routers/char_router.py | 2 +- routers/gen_router.py | 16 +-- 20 files changed, 566 insertions(+), 135 deletions(-) create mode 100644 api/dependency.py delete mode 100644 api/endpoints/assets.py create mode 100644 api/endpoints/assets_router.py create mode 100644 api/endpoints/generation_router.py create mode 100644 api/models/AssetDTO.py create mode 100644 api/models/GenerationRequest.py create mode 100644 api/models/__init__.py create mode 100644 api/service/__init__.py create mode 100644 api/service/generation_service.py create mode 100644 models/Generation.py create mode 100644 repos/generation_repo.py diff --git a/adapters/google_adapter.py b/adapters/google_adapter.py index 52dcfb9..5a8e99e 100644 --- a/adapters/google_adapter.py +++ b/adapters/google_adapter.py @@ -80,8 +80,8 @@ class GoogleAdapter: response_modalities=['IMAGE'], temperature=1.0, image_config=types.ImageConfig( - aspect_ratio=aspect_ratio.value, - image_size=quality.value + aspect_ratio=aspect_ratio.value_ratio, + image_size=quality.value_quality ), ) ) diff --git a/api/dependency.py b/api/dependency.py new file mode 100644 index 0000000..70f4479 --- /dev/null +++ b/api/dependency.py @@ -0,0 +1,30 @@ +# dependency.py +from fastapi import Request, Depends +from motor.motor_asyncio import AsyncIOMotorClient + +from adapters.google_adapter import GoogleAdapter +from api.service.generation_service import GenerationService +from repos.dao import DAO + + +# ... ваши импорты ... + +# Провайдеры "сырых" клиентов из состояния приложения +def get_mongo_client(request: Request) -> AsyncIOMotorClient: + return request.app.state.mongo_client + +def get_gemini_client(request: Request) -> GoogleAdapter: + return request.app.state.gemini_client + +# Провайдер DAO (собирается из mongo_client) +def get_dao(mongo_client: AsyncIOMotorClient = Depends(get_mongo_client)) -> DAO: + # FastAPI кэширует результат Depends в рамках одного запроса, + # так что DAO создастся один раз за запрос. + return DAO(mongo_client) + +# Провайдер сервиса (собирается из DAO и Gemini) +def get_generation_service( + dao: DAO = Depends(get_dao), + gemini: GoogleAdapter = Depends(get_gemini_client) +) -> GenerationService: + return GenerationService(dao, gemini) \ No newline at end of file diff --git a/api/endpoints/assets.py b/api/endpoints/assets.py deleted file mode 100644 index e2243d1..0000000 --- a/api/endpoints/assets.py +++ /dev/null @@ -1,92 +0,0 @@ -from typing import List, Optional - -from aiogram.types import BufferedInputFile -from fastapi import APIRouter, UploadFile, File, Form -from fastapi.openapi.models import MediaType -from starlette.exceptions import HTTPException -from starlette.requests import Request -from starlette.responses import Response, JSONResponse - -from models.Asset import Asset, AssetType -from repos.dao import DAO - -router = APIRouter(prefix="/api/assets", tags=["Assets"]) - - -@router.get("/{asset_id}") -async def get_asset(asset_id: str, request: Request) -> Response: - dao = request.app.state.dao - asset = await dao.assets.get_asset(asset_id) - # 2. Проверка на существование - if not asset: - raise HTTPException(status_code=404, detail="Asset not found") - headers = { - # Кэшировать на 1 год (31536000 сек) - "Cache-Control": "public, max-age=31536000, immutable" - } - return Response(content=asset.data, media_type="image/png", headers=headers) - - -@router.get("") -async def get_assets(request: Request) -> List[Asset]: - dao: DAO = request.app.state.dao - assets = await dao.assets.get_assets() - - return assets - - -@router.post("/upload", response_model=Asset) -async def upload_asset( - request: Request, - # Файл обязателен - file: UploadFile = File(...), - # Остальные поля принимаем как Form-data (не JSON!) - name: str = Form(...), - type: AssetType = Form(...), - linked_char_id: Optional[str] = Form(None) -): - """ - Загружает файл, отправляет его в ТГ (для получения ID) и сохраняет в БД. - """ - # 1. Читаем байты файла - file_content = await file.read() - - if not file_content: - raise HTTPException(status_code=400, detail="File is empty") - - # 2. Получаем необходимые зависимости из state - bot = request.app.state.bot # Бот нужен, чтобы получить tg_file_id - admin_id = request.app.state.admin_id # Куда отправлять файл "на хранение" - dao = request.app.state.assets_dao - - # 3. Отправляем файл в Telegram, чтобы получить tg_doc_file_id - # (Это обязательно, так как ваша модель требует этот ID) - try: - tg_msg = await bot.send_document( - chat_id=admin_id, - document=BufferedInputFile(file_content, filename=file.filename), - caption=f"📥 Uploaded via API: {name}" - ) - # Получаем ID документа из ответа ТГ - tg_doc_id = tg_msg.document.file_id - - # Если это картинка, можно попытаться достать и photo_id (для превью) - # Но send_document обычно возвращает именно документ. - - except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to upload to Telegram: {e}") - - # 4. Создаем объект Asset - # Pydantic сам подставит created_at и вычислит link - new_asset = Asset( - name=name, - type=type, - linked_char_id=linked_char_id, - data=file_content, # Сохраняем байты в БД - tg_doc_file_id=tg_doc_id # ID из телеграма - ) - - # 5. Сохраняем через DAO - saved_asset = await dao.save_asset(new_asset) - - return saved_asset diff --git a/api/endpoints/assets_router.py b/api/endpoints/assets_router.py new file mode 100644 index 0000000..e923f0c --- /dev/null +++ b/api/endpoints/assets_router.py @@ -0,0 +1,74 @@ +from typing import List, Optional + +from aiogram.types import BufferedInputFile +from fastapi import APIRouter, UploadFile, File, Form, Depends +from fastapi.openapi.models import MediaType +from starlette import status +from starlette.exceptions import HTTPException +from starlette.requests import Request +from starlette.responses import Response, JSONResponse + +from api.models.AssetDTO import AssetsResponse, AssetResponse +from models.Asset import Asset, AssetType +from repos.dao import DAO +from api.dependency import get_dao + +router = APIRouter(prefix="/api/assets", tags=["Assets"]) + + +@router.get("/{asset_id}") +async def get_asset(asset_id: str, request: Request,dao: DAO = Depends(get_dao),) -> Response: + + asset = await dao.assets.get_asset(asset_id) + # 2. Проверка на существование + if not asset: + raise HTTPException(status_code=404, detail="Asset not found") + headers = { + # Кэшировать на 1 год (31536000 сек) + "Cache-Control": "public, max-age=31536000, immutable" + } + return Response(content=asset.data, media_type="image/png", headers=headers) + + +@router.get("") +async def get_assets(request: Request, dao: DAO = Depends(get_dao), limit: int = 10, offset: int = 0) -> AssetsResponse: + assets = await dao.assets.get_assets(limit, offset) + assets = await dao.assets.get_assets() + total_count = await dao.assets.get_asset_count() + + return AssetsResponse(assets=assets, total_count=total_count) + + +@router.post("/upload", response_model=AssetResponse, status_code=status.HTTP_201_CREATED) +async def upload_asset( + file: UploadFile = File(...), + linked_char_id: Optional[str] = Form(None), + dao: DAO = Depends(get_dao), +): + if not file.content_type: + raise HTTPException(status_code=400, detail="Unknown file type") + + if not file.content_type.startswith("image/"): + raise HTTPException(status_code=400, detail=f"Unsupported content type: {file.content_type}") + + data = await file.read() + if not data: + raise HTTPException(status_code=400, detail="Empty file") + + asset = Asset( + name=file.filename or "upload", + type=AssetType.IMAGE, + linked_char_id=linked_char_id, + data=data, + ) + + asset_id = await dao.assets.create_asset(asset) + asset.id = str(asset_id) + + return AssetResponse( + id=asset.id, + name=asset.name, + type=asset.type.value if hasattr(asset.type, "value") else asset.type, + linked_char_id=asset.linked_char_id, + created_at=asset.created_at, + ) \ No newline at end of file diff --git a/api/endpoints/character_router.py b/api/endpoints/character_router.py index 9fc6045..8f20bba 100644 --- a/api/endpoints/character_router.py +++ b/api/endpoints/character_router.py @@ -1,35 +1,44 @@ -from typing import List +from typing import List, Any, Coroutine -from fastapi import APIRouter +from fastapi import APIRouter, Depends +from pydantic import BaseModel from starlette.exceptions import HTTPException from starlette.requests import Request +from api.models.AssetDTO import AssetsResponse +from api.models.GenerationRequest import GenerationRequest, GenerationResponse from models.Asset import Asset from models.Character import Character from repos.dao import DAO +from api.dependency import get_dao router = APIRouter(prefix="/api/characters", tags=["Characters"]) @router.get("/", response_model=List[Character]) -async def get_characters(request: Request) -> List[Character]: - dao: DAO = request.app.state.dao +async def get_characters(request: Request, dao: DAO = Depends(get_dao), ) -> List[Character]: characters = await dao.chars.get_all_characters() return characters -@router.get("/{character_id}/assets", response_model=List[Asset]) -async def get_character_assets(character_id: str, request: Request) -> List[Asset]: - dao: DAO = request.app.state.dao +@router.get("/{character_id}/assets", response_model=AssetsResponse) +async def get_character_assets(character_id: str, dao: DAO = Depends(get_dao), limit: int = 10, + offset: int = 0, ) -> AssetsResponse: character = await dao.chars.get_character(character_id) if character is None: raise HTTPException(status_code=404, detail="Character not found") - return await dao.assets.get_assets_by_char_id(character_id) + assets = await dao.assets.get_assets_by_char_id(character_id, limit, offset) + total_count = await dao.assets.get_asset_count(character_id) + return AssetsResponse(assets=assets, total_count=total_count) @router.get("/{character_id}", response_model=Character) -async def get_character_by_id(character_id: str, request: Request) -> Character: - dao: DAO = request.app.state.dao +async def get_character_by_id(character_id: str, request: Request, dao: DAO = Depends(get_dao)) -> Character: character = await dao.chars.get_character(character_id) return character + +@router.post("/{character_id}/_run", response_model=Asset) +async def post_character_generation(character_id: str, generation: GenerationRequest, + request: Request) -> GenerationResponse: + generation_service = request.app.state.generation_service diff --git a/api/endpoints/generation_router.py b/api/endpoints/generation_router.py new file mode 100644 index 0000000..72d103a --- /dev/null +++ b/api/endpoints/generation_router.py @@ -0,0 +1,47 @@ +from typing import List, Optional + +from fastapi import APIRouter +from fastapi.params import Depends +from starlette.requests import Request + +from api import service +from api.dependency import get_generation_service + +from api.models.GenerationRequest import GenerationResponse, GenerationRequest, PromptResponse, PromptRequest +from api.service.generation_service import GenerationService +from models.Generation import Generation + +router = APIRouter(prefix='/api/generations', tags=["Generation"]) + + +@router.post("/prompt-assistant", response_model=PromptResponse) +async def ask_prompt_assistant(prompt_request: PromptRequest, request: Request, + generation_service: GenerationService = Depends( + get_generation_service)) -> PromptResponse: + generated_prompt = await generation_service.ask_prompt_assistant(prompt_request.prompt, prompt_request.linked_assets) + return PromptResponse(prompt=generated_prompt) + + +@router.get("", response_model=List[GenerationResponse]) +async def get_generations(character_id: Optional[str], limit: int = 10, offset: int = 0, + generation_service: GenerationService = Depends(get_generation_service)): + return await generation_service.get_generations(character_id, limit=limit, offset=offset) + + +@router.post("/_run", response_model=GenerationResponse) +async def post_generation(generation: GenerationRequest, request: Request, + generation_service: GenerationService = Depends( + get_generation_service)) -> GenerationResponse: + return await generation_service.create_generation_task(generation) + + +@router.get("/{generation_id}", response_model=GenerationResponse) +async def get_generation(generation_id: str, + generation_service: GenerationService = Depends(get_generation_service)) -> GenerationResponse: + return await generation_service.get_generation(generation_id) + + +@router.get("/running") +async def get_running_generations(request: Request, + generation_service: GenerationService = Depends(get_generation_service)): + return await generation_service.get_running_generations() diff --git a/api/models/AssetDTO.py b/api/models/AssetDTO.py new file mode 100644 index 0000000..4df084e --- /dev/null +++ b/api/models/AssetDTO.py @@ -0,0 +1,19 @@ +from datetime import datetime +from typing import List, Optional + +from pydantic import BaseModel + +from models.Asset import Asset + + +class AssetsResponse(BaseModel): + assets: List[Asset] + total_count: int + + +class AssetResponse(BaseModel): + id: str + name: str + type: str + linked_char_id: Optional[str] = None + created_at: datetime \ No newline at end of file diff --git a/api/models/GenerationRequest.py b/api/models/GenerationRequest.py new file mode 100644 index 0000000..fd30755 --- /dev/null +++ b/api/models/GenerationRequest.py @@ -0,0 +1,39 @@ +from datetime import datetime, UTC +from typing import List, Optional + +from pydantic import BaseModel + +from models.Asset import Asset +from models.Generation import GenerationStatus +from models.enums import AspectRatios, Quality + + +class GenerationRequest(BaseModel): + linked_character_id: Optional[str] = None + aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN + quality: Quality = Quality.ONEK + prompt: str + assets_list: List[str] + + +class GenerationResponse(BaseModel): + id: str + status: GenerationStatus + linked_character_id: Optional[str] = None + aspect_ratio: AspectRatios + quality: Quality + prompt: str + assets_list: List[str] + result: Optional[str] = None + created_at: datetime = datetime.now(UTC) + updated_at: datetime = datetime.now(UTC) + + + +class PromptRequest(BaseModel): + prompt: str + linked_assets: List[str] = [] + + +class PromptResponse(BaseModel): + prompt: str \ No newline at end of file diff --git a/api/models/__init__.py b/api/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/service/__init__.py b/api/service/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/service/generation_service.py b/api/service/generation_service.py new file mode 100644 index 0000000..04aee4b --- /dev/null +++ b/api/service/generation_service.py @@ -0,0 +1,186 @@ +import asyncio +import logging +import random +from datetime import datetime, UTC +from typing import List, Optional +from io import BytesIO + +from adapters.google_adapter import GoogleAdapter +from api.models.GenerationRequest import GenerationRequest, GenerationResponse +# Импортируйте ваши модели DAO, Asset, Generation корректно +from models.Asset import Asset, AssetType +from models.Generation import Generation, GenerationStatus +from models.enums import AspectRatios, Quality +from repos.dao import DAO + +logger = logging.getLogger(__name__) + + +# --- Вспомогательная функция генерации --- +async def generate_image_task( + prompt: str, + media_group_bytes: List[bytes], + aspect_ratio: AspectRatios, + quality: Quality, + gemini: GoogleAdapter +) -> List[bytes]: + """ + Обертка для вызова синхронного метода Gemini в отдельном потоке. + Возвращает список байтов сгенерированных изображений. + """ + + # Запускаем блокирующую операцию в отдельном потоке, чтобы не тормозить Event Loop + generated_images_io: List[BytesIO] = await asyncio.to_thread( + gemini.generate_image, + prompt=prompt, + images_list=media_group_bytes, + aspect_ratio=aspect_ratio, + quality=quality, + ) + + images_bytes = [] + if generated_images_io: + for img_io in generated_images_io: + # Читаем байты из BytesIO + img_io.seek(0) + content = img_io.read() + images_bytes.append(content) + + # Закрываем поток + img_io.close() + + return images_bytes + + +class GenerationService: + def __init__(self, dao: DAO, gemini: GoogleAdapter): + self.dao = dao + self.gemini = gemini + + async def ask_prompt_assistant(self, prompt: str, assets: List[str] = None) -> str: + future_prompt = """You are an prompt-assistant. You improving user-entered prompts for image generation. User may upload reference image too. + I will provide sources prompt entered by user. Understand user needs and generate best variation of prompt. + ANSWER ONLY PROMPT STRING!!! USER_ENTERED_PROMPT:""" + future_prompt += prompt + assets_data = [] + if assets is not None: + assets_data = await self.dao.assets.get_assets_by_ids(assets) + generated_prompt = self.gemini.generate_text(future_prompt, assets_data) + return generated_prompt + + async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0) -> List[ + Generation]: + return await self.dao.generations.get_generations(limit=limit, offset=offset) + + async def get_generation(self, generation_id: str) -> Optional[GenerationResponse]: + gen = await self.dao.generations.get_generation(generation_id) + if gen is None: + return None + else: + return GenerationResponse(**gen.model_dump()) + + async def get_running_generations(self) -> List[Generation]: + return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING) + + async def create_generation_task(self, generation_request: GenerationRequest) -> GenerationResponse: + gen_id = None + generation_model = None + + try: + generation_model = Generation(**generation_request.model_dump()) + gen_id = await self.dao.generations.create_generation(generation_model) + generation_model.id = gen_id + + async def runner(gen): + try: + await self.create_generation(gen) + except Exception: + # если генерация уже пошла и упала — пометим FAILED + try: + db_gen = await self.dao.generations.get_generation(gen.id) + db_gen.status = GenerationStatus.FAILED + await self.dao.generations.update_generation(db_gen) + except Exception: + logger.exception("Failed to mark generation as FAILED") + logger.exception("create_generation task failed") + + asyncio.create_task(runner(generation_model)) + + return GenerationResponse(**generation_model.model_dump()) + + except Exception: + # если не успели создать запись — нечего помечать + if gen_id is not None: + try: + gen = await self.dao.generations.get_generation(gen_id) + gen.status = GenerationStatus.FAILED + await self.dao.generations.update_generation(gen) + except Exception: + logger.exception("Failed to mark generation as FAILED in create_generation_task") + raise + + async def create_generation(self, generation: Generation): + + # 2. Получаем ассеты-референсы (если они есть) + reference_assets: List[Asset] = [] + media_group_bytes: List[bytes] = [] + generation_prompt = "You are creating image. " + if generation.linked_character_id is not None: + char_info = await self.dao.chars.get_character(generation.linked_character_id, with_image_data=True) + if char_info is None: + raise Exception(f"Character ID {generation.linked_character_id} not found") + media_group_bytes.append(char_info.character_image_data) + generation_prompt = f"""You are creating image for {char_info.character_bio}""" + + reference_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list) + # Извлекаем данные (bytes) из ассетов для отправки в Gemini + # Фильтруем, чтобы отправлять только картинки, и где есть data + media_group_bytes.extend( + asset.data + for asset in reference_assets + if asset.data is not None and asset.type == AssetType.IMAGE + ) + generation_prompt+=f"PROMPT: {generation.prompt}" + + # 3. Запускаем процесс генерации + try: + generated_bytes_list = await generate_image_task( + prompt=generation_prompt, # или request.prompt + media_group_bytes=media_group_bytes, + aspect_ratio=generation.aspect_ratio, # предполагаем поля в request + quality=generation.quality, + gemini=self.gemini + ) + except Exception as e: + # Тут стоит добавить логирование ошибки + logging.error(f"Generation failed: {e}") + # Можно обновить статус генерации на FAILED в БД + raise e + + # 4. Сохраняем полученные изображения как новые Ассеты + created_assets: List[Asset] = [] + + for idx, img_bytes in enumerate(generated_bytes_list): + new_asset = Asset( + name=f"Generated_{generation.linked_character_id}_{random.randint(1000, 9999)}", + type=AssetType.IMAGE, + linked_char_id=generation.linked_character_id, # Если генерация привязана к персонажу + data=img_bytes, + # Остальные поля заполнятся дефолтными значениями (created_at) + ) + + # Сохраняем в БД + asset_id = await self.dao.assets.create_asset(new_asset) + new_asset.id = str(asset_id) # Присваиваем ID, полученный от базы + + created_assets.append(new_asset) + + # 5. (Опционально) Обновляем запись генерации ссылками на результаты + # Предполагаем, что у модели Generation есть поле result_asset_ids + result_ids = [a.id for a in created_assets] + + generation.assets_list = result_ids + generation.status = GenerationStatus.DONE + generation.updated_at = datetime.now(UTC) + generation.tech_prompt = generation_prompt + await self.dao.generations.update_generation(generation) diff --git a/main.py b/main.py index 882a8b0..15e4012 100644 --- a/main.py +++ b/main.py @@ -16,6 +16,7 @@ from starlette.middleware.cors import CORSMiddleware # --- ИМПОРТЫ ПРОЕКТА --- from adapters.google_adapter import GoogleAdapter +from api.service.generation_service import GenerationService from middlewares.album import AlbumMiddleware from middlewares.auth import AuthMiddleware from middlewares.dao import DaoMiddleware @@ -33,8 +34,9 @@ from routers.auth_router import router as auth_router from routers.gen_router import router as gen_router from routers.char_router import router as char_router from routers.assets_router import router as assets_router # Роутер бота для ассетов -from api.endpoints.assets import router as api_assets_router # Роутер FastAPI +from api.endpoints.assets_router import router as api_assets_router # Роутер FastAPI from api.endpoints.character_router import router as api_char_router # Роутер FastAPI +from api.endpoints.generation_router import router as api_gen_router load_dotenv() @@ -47,7 +49,7 @@ ADMIN_ID = int(os.getenv("ADMIN_ID", 0)) def setup_logging(): - logging.basicConfig(level=logging.INFO, + logging.basicConfig(level=logging.DEBUG, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s") @@ -61,6 +63,8 @@ mongo_client = AsyncIOMotorClient(MONGO_HOST) users_repo = UsersRepo(mongo_client) char_repo = CharacterRepo(mongo_client) dao = DAO(mongo_client) # Главный DAO для бота +gemini = GoogleAdapter(api_key=GEMINI_API_KEY) +generation_service = GenerationService(dao, gemini) # Dispatcher dp = Dispatcher(storage=MongoStorage(mongo_client, db_name=DB_NAME)) @@ -68,7 +72,7 @@ dp = Dispatcher(storage=MongoStorage(mongo_client, db_name=DB_NAME)) # Внедрение зависимостей (глобально для бота) dp["repo"] = users_repo dp["admin_id"] = ADMIN_ID -dp["gemini"] = GoogleAdapter(api_key=GEMINI_API_KEY) +dp["gemini"] = gemini # --- НАСТРОЙКА РОУТЕРОВ БОТА --- @@ -108,7 +112,9 @@ async def lifespan(app: FastAPI): # Инициализируем DAO для ассетов и кладем в state приложения # Теперь в эндпоинтах можно делать request.app.state.assets_dao - app.state.dao = dao + + app.state.mongo_client = mongo_client + app.state.gemini_client = gemini print("✅ DB & DAO initialized") @@ -152,6 +158,7 @@ app.add_middleware( # Подключаем роутер API app.include_router(api_assets_router) app.include_router(api_char_router) +app.include_router(api_gen_router) # --- ХЕНДЛЕРЫ БОТА (Main Router) --- @@ -179,7 +186,7 @@ if __name__ == "__main__": async def main(): # Создаем конфигурацию uvicorn вручную # loop="asyncio" заставляет использовать стандартный цикл - config = uvicorn.Config(app, host="0.0.0.0", port=8090, loop="asyncio") + config = uvicorn.Config(app, host="0.0.0.0", port=8090, loop="asyncio", timeout_keep_alive=120) server = uvicorn.Server(config) # Запускаем сервер (lifespan запустится внутри) diff --git a/models/Asset.py b/models/Asset.py index 130a476..0333495 100644 --- a/models/Asset.py +++ b/models/Asset.py @@ -1,6 +1,6 @@ from datetime import datetime, UTC from enum import Enum -from typing import Optional, Any +from typing import Optional, Any, List from pydantic import BaseModel, computed_field @@ -16,8 +16,9 @@ class Asset(BaseModel): type: AssetType linked_char_id: Optional[str] = None data: Optional[bytes] = None - tg_doc_file_id: str + tg_doc_file_id: Optional[str] = None tg_photo_file_id: Optional[str] = None + tags: List[str] = [] created_at: datetime = datetime.now(UTC) # --- CALCULATED FIELD --- diff --git a/models/Generation.py b/models/Generation.py new file mode 100644 index 0000000..d875b32 --- /dev/null +++ b/models/Generation.py @@ -0,0 +1,27 @@ +from datetime import datetime, UTC +from enum import Enum +from typing import List, Optional + +from pydantic import BaseModel + +from models.Asset import Asset +from models.enums import AspectRatios, Quality + + +class GenerationStatus(str, Enum): + RUNNING = "running" + DONE = "done" + FAILED = "failed" + +class Generation(BaseModel): + id: Optional[str] = None + status: GenerationStatus = GenerationStatus.RUNNING + linked_character_id: Optional[str] = None + aspect_ratio: AspectRatios + quality: Quality + prompt: str + tech_prompt: Optional[str] = None + assets_list: List[str] + result: Optional[str] = None + created_at: datetime = datetime.now(UTC) + updated_at: datetime = datetime.now(UTC) diff --git a/models/enums.py b/models/enums.py index 10d5c49..099622f 100644 --- a/models/enums.py +++ b/models/enums.py @@ -1,19 +1,43 @@ from enum import Enum -class AspectRatios(Enum): - NINESIXTEEN = '9:16' - SIXTEENNINE = '16:9' - THREEFOUR = '3:4' - FOURTHREE = '4:3' +class AspectRatios(str, Enum): + NINESIXTEEN = "NINESIXTEEN" + SIXTEENNINE = "SIXTEENNINE" + THREEFOUR = "THREEFOUR" + FOURTHREE = "FOURTHREE" + + @property + def value_ratio(self) -> str: + return { + AspectRatios.NINESIXTEEN: "9:16", + AspectRatios.SIXTEENNINE: "16:9", + AspectRatios.THREEFOUR: "3:4", + AspectRatios.FOURTHREE: "4:3", + }[self] -class Quality(Enum): - ONEK = '1K' - TWOK = '2K' - FOURK = '4K' +class Quality(str, Enum): + ONEK = 'ONEK' + TWOK = 'TWOK' + FOURK = 'FOURK' + + @property + def value_quality(self) -> str: + return { + Quality.ONEK: '1K', + Quality.TWOK: '2K', + Quality.FOURK: '4K' + }[self] -class GenType(Enum): +class GenType(str, Enum): TEXT = 'Text' - IMAGE = 'Image' \ No newline at end of file + IMAGE = 'Image' + + @property + def value_type(self) -> str: + return { + GenType.TEXT: 'Text', + GenType.IMAGE: 'Image' + }[self] diff --git a/repos/assets_repo.py b/repos/assets_repo.py index a33d3fc..da3af33 100644 --- a/repos/assets_repo.py +++ b/repos/assets_repo.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional from bson import ObjectId from motor.motor_asyncio import AsyncIOMotorClient @@ -10,10 +10,10 @@ class AssetsRepo: def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"): self.collection = client[db_name]["assets"] - async def save_asset(self, asset: Asset) -> Asset: + async def create_asset(self, asset: Asset) -> str: res = await self.collection.insert_one(asset.model_dump()) - asset.id = res.inserted_id - return asset + + return str(res.inserted_id) async def get_assets(self, limit: int = 10, offset: int = 0) -> List[Asset]: res = await self.collection.find({}, {"data": 0}).sort("created_at", -1).skip(offset).limit(limit).to_list(None) @@ -27,6 +27,7 @@ class AssetsRepo: return assets + async def get_asset(self, asset_id: str, with_data: bool = True) -> Asset: projection = {"_id": 1, "name": 1, "type": 1, "tg_doc_file_id": 1} if with_data: @@ -54,3 +55,17 @@ class AssetsRepo: doc["id"] = str(doc.pop("_id")) assets.append(Asset(**doc)) return assets + + async def get_asset_count(self, character_id: Optional[str] = None) -> int: + return await self.collection.count_documents({"linked_char_id": character_id} if character_id else {}) + + + + async def get_assets_by_ids(self, asset_ids: List[str]) -> List[Asset]: + object_ids = [ObjectId(asset_id) for asset_id in asset_ids] + res = self.collection.find({"_id": {"$in": object_ids}}) + assets = [] + async for doc in res: + doc["id"] = str(doc.pop("_id")) + assets.append(Asset(**doc)) + return assets diff --git a/repos/dao.py b/repos/dao.py index 0034d15..e8c03f3 100644 --- a/repos/dao.py +++ b/repos/dao.py @@ -2,6 +2,7 @@ from motor.motor_asyncio import AsyncIOMotorClient from repos.assets_repo import AssetsRepo from repos.char_repo import CharacterRepo +from repos.generation_repo import GenerationRepo from repos.user_repo import UsersRepo @@ -9,3 +10,4 @@ class DAO: def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"): self.chars = CharacterRepo(client, db_name) self.assets = AssetsRepo(client, db_name) + self.generations = GenerationRepo(client, db_name) diff --git a/repos/generation_repo.py b/repos/generation_repo.py new file mode 100644 index 0000000..ed5af66 --- /dev/null +++ b/repos/generation_repo.py @@ -0,0 +1,43 @@ +from typing import Optional, List + +from PIL.ImageChops import offset +from bson import ObjectId +from motor.motor_asyncio import AsyncIOMotorClient + +from api.models.GenerationRequest import GenerationResponse +from models.Generation import Generation, GenerationStatus + + +class GenerationRepo: + def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"): + self.collection = client[db_name]["generations"] + + async def create_generation(self, generation: Generation) -> str: + res = await self.collection.insert_one(generation.model_dump()) + return str(res.inserted_id) + + async def get_generation(self, generation_id: str) -> Optional[Generation]: + res = await self.collection.find_one({"_id": ObjectId(generation_id)}) + if res is None: + return None + else: + res["id"] = str(res.pop("_id")) + return Generation(**res) + + async def get_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None, + limit: int = 10, offset: int = 10) -> List[Generation]: + args = {} + if character_id is not None: + args["character_id"] = character_id + if status is not None: + args["status"] = status + res = await self.collection.find(args).sort("created_at", -1).skip( + offset).limit(limit).to_list(None) + generations: List[Generation] = [] + for generation in res: + generation["id"] = str(generation.pop("_id")) + generations.append(Generation(**generation)) + return generations + + async def update_generation(self, generation: Generation, ): + res = await self.collection.update_one({"_id": ObjectId(generation.id)}, {"$set": generation.model_dump()}) diff --git a/routers/char_router.py b/routers/char_router.py index d6ec870..7ec7c71 100644 --- a/routers/char_router.py +++ b/routers/char_router.py @@ -73,7 +73,7 @@ async def new_char_bio(message: Message, state: FSMContext, dao: DAO, bot: Bot): file_info = await bot.get_file(char.character_image_doc_tg_id) file_bytes = await bot.download_file(file_info.file_path) file_io = file_bytes.read() - avatar_asset = await dao.assets.save_asset( + avatar_asset = await dao.assets.create_asset( Asset(name="avatar.png", type=AssetType.IMAGE, linked_char_id=str(char.id), data=file_io, tg_doc_file_id=file_id)) char.avatar_image = avatar_asset.link diff --git a/routers/gen_router.py b/routers/gen_router.py index 7fa006d..38d76e9 100644 --- a/routers/gen_router.py +++ b/routers/gen_router.py @@ -50,8 +50,8 @@ async def generate_image_cmd(message: Message, state: FSMContext, dao: DAO, gemi gemini=gemini) await wait_msg.delete() doc = await message.answer_document(res[0], caption="Generated result 💫") - await dao.assets.save_asset(Asset(id=None, name=res[0].filename, type=AssetType.IMAGE, data=res[0].data, - tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None, linked_char_id=None)) + await dao.assets.create_asset(Asset(id=None, name=res[0].filename, type=AssetType.IMAGE, data=res[0].data, + tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None, linked_char_id=None)) @router.message(Command("gen_mode")) @@ -260,9 +260,9 @@ async def handle_album( if generated_files: for file in generated_files: doc = await message.answer_document(file, caption="✨ Generated result") - await dao.assets.save_asset(Asset(id=None, name=file.filename, type=AssetType.IMAGE, data=file.data, - tg_doc_file_id = doc.document.file_id, tg_photo_file_id = None, - linked_char_id = data["char_id"])) + await dao.assets.create_asset(Asset(id=None, name=file.filename, type=AssetType.IMAGE, data=file.data, + tg_doc_file_id = doc.document.file_id, tg_photo_file_id = None, + linked_char_id = data["char_id"])) else: await message.answer("❌ Генерация не вернула изображений.") await gen_mode_base_msg(message=message, state=state, dao=dao, call_type="start") @@ -315,9 +315,9 @@ async def gen_mode_start( if generated_files: for file in generated_files: doc = await message.answer_document(file, caption="✨ Generated result") - await dao.assets.save_asset(Asset(id=None, name=file.filename, type=AssetType.IMAGE, data=file.data, - tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None, - linked_char_id=data["char_id"])) + await dao.assets.create_asset(Asset(id=None, name=file.filename, type=AssetType.IMAGE, data=file.data, + tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None, + linked_char_id=data["char_id"])) else: await message.answer("❌ Ничего не сгенерировалось.")