+ api
This commit is contained in:
@@ -80,8 +80,8 @@ class GoogleAdapter:
|
||||
response_modalities=['IMAGE'],
|
||||
temperature=1.0,
|
||||
image_config=types.ImageConfig(
|
||||
aspect_ratio=aspect_ratio.value,
|
||||
image_size=quality.value
|
||||
aspect_ratio=aspect_ratio.value_ratio,
|
||||
image_size=quality.value_quality
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
30
api/dependency.py
Normal file
30
api/dependency.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# dependency.py
|
||||
from fastapi import Request, Depends
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
|
||||
from adapters.google_adapter import GoogleAdapter
|
||||
from api.service.generation_service import GenerationService
|
||||
from repos.dao import DAO
|
||||
|
||||
|
||||
# ... ваши импорты ...
|
||||
|
||||
# Провайдеры "сырых" клиентов из состояния приложения
|
||||
def get_mongo_client(request: Request) -> AsyncIOMotorClient:
|
||||
return request.app.state.mongo_client
|
||||
|
||||
def get_gemini_client(request: Request) -> GoogleAdapter:
|
||||
return request.app.state.gemini_client
|
||||
|
||||
# Провайдер DAO (собирается из mongo_client)
|
||||
def get_dao(mongo_client: AsyncIOMotorClient = Depends(get_mongo_client)) -> DAO:
|
||||
# FastAPI кэширует результат Depends в рамках одного запроса,
|
||||
# так что DAO создастся один раз за запрос.
|
||||
return DAO(mongo_client)
|
||||
|
||||
# Провайдер сервиса (собирается из DAO и Gemini)
|
||||
def get_generation_service(
|
||||
dao: DAO = Depends(get_dao),
|
||||
gemini: GoogleAdapter = Depends(get_gemini_client)
|
||||
) -> GenerationService:
|
||||
return GenerationService(dao, gemini)
|
||||
@@ -1,92 +0,0 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from aiogram.types import BufferedInputFile
|
||||
from fastapi import APIRouter, UploadFile, File, Form
|
||||
from fastapi.openapi.models import MediaType
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response, JSONResponse
|
||||
|
||||
from models.Asset import Asset, AssetType
|
||||
from repos.dao import DAO
|
||||
|
||||
router = APIRouter(prefix="/api/assets", tags=["Assets"])
|
||||
|
||||
|
||||
@router.get("/{asset_id}")
|
||||
async def get_asset(asset_id: str, request: Request) -> Response:
|
||||
dao = request.app.state.dao
|
||||
asset = await dao.assets.get_asset(asset_id)
|
||||
# 2. Проверка на существование
|
||||
if not asset:
|
||||
raise HTTPException(status_code=404, detail="Asset not found")
|
||||
headers = {
|
||||
# Кэшировать на 1 год (31536000 сек)
|
||||
"Cache-Control": "public, max-age=31536000, immutable"
|
||||
}
|
||||
return Response(content=asset.data, media_type="image/png", headers=headers)
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_assets(request: Request) -> List[Asset]:
|
||||
dao: DAO = request.app.state.dao
|
||||
assets = await dao.assets.get_assets()
|
||||
|
||||
return assets
|
||||
|
||||
|
||||
@router.post("/upload", response_model=Asset)
|
||||
async def upload_asset(
|
||||
request: Request,
|
||||
# Файл обязателен
|
||||
file: UploadFile = File(...),
|
||||
# Остальные поля принимаем как Form-data (не JSON!)
|
||||
name: str = Form(...),
|
||||
type: AssetType = Form(...),
|
||||
linked_char_id: Optional[str] = Form(None)
|
||||
):
|
||||
"""
|
||||
Загружает файл, отправляет его в ТГ (для получения ID) и сохраняет в БД.
|
||||
"""
|
||||
# 1. Читаем байты файла
|
||||
file_content = await file.read()
|
||||
|
||||
if not file_content:
|
||||
raise HTTPException(status_code=400, detail="File is empty")
|
||||
|
||||
# 2. Получаем необходимые зависимости из state
|
||||
bot = request.app.state.bot # Бот нужен, чтобы получить tg_file_id
|
||||
admin_id = request.app.state.admin_id # Куда отправлять файл "на хранение"
|
||||
dao = request.app.state.assets_dao
|
||||
|
||||
# 3. Отправляем файл в Telegram, чтобы получить tg_doc_file_id
|
||||
# (Это обязательно, так как ваша модель требует этот ID)
|
||||
try:
|
||||
tg_msg = await bot.send_document(
|
||||
chat_id=admin_id,
|
||||
document=BufferedInputFile(file_content, filename=file.filename),
|
||||
caption=f"📥 Uploaded via API: {name}"
|
||||
)
|
||||
# Получаем ID документа из ответа ТГ
|
||||
tg_doc_id = tg_msg.document.file_id
|
||||
|
||||
# Если это картинка, можно попытаться достать и photo_id (для превью)
|
||||
# Но send_document обычно возвращает именно документ.
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to upload to Telegram: {e}")
|
||||
|
||||
# 4. Создаем объект Asset
|
||||
# Pydantic сам подставит created_at и вычислит link
|
||||
new_asset = Asset(
|
||||
name=name,
|
||||
type=type,
|
||||
linked_char_id=linked_char_id,
|
||||
data=file_content, # Сохраняем байты в БД
|
||||
tg_doc_file_id=tg_doc_id # ID из телеграма
|
||||
)
|
||||
|
||||
# 5. Сохраняем через DAO
|
||||
saved_asset = await dao.save_asset(new_asset)
|
||||
|
||||
return saved_asset
|
||||
74
api/endpoints/assets_router.py
Normal file
74
api/endpoints/assets_router.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from aiogram.types import BufferedInputFile
|
||||
from fastapi import APIRouter, UploadFile, File, Form, Depends
|
||||
from fastapi.openapi.models import MediaType
|
||||
from starlette import status
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response, JSONResponse
|
||||
|
||||
from api.models.AssetDTO import AssetsResponse, AssetResponse
|
||||
from models.Asset import Asset, AssetType
|
||||
from repos.dao import DAO
|
||||
from api.dependency import get_dao
|
||||
|
||||
router = APIRouter(prefix="/api/assets", tags=["Assets"])
|
||||
|
||||
|
||||
@router.get("/{asset_id}")
|
||||
async def get_asset(asset_id: str, request: Request,dao: DAO = Depends(get_dao),) -> Response:
|
||||
|
||||
asset = await dao.assets.get_asset(asset_id)
|
||||
# 2. Проверка на существование
|
||||
if not asset:
|
||||
raise HTTPException(status_code=404, detail="Asset not found")
|
||||
headers = {
|
||||
# Кэшировать на 1 год (31536000 сек)
|
||||
"Cache-Control": "public, max-age=31536000, immutable"
|
||||
}
|
||||
return Response(content=asset.data, media_type="image/png", headers=headers)
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_assets(request: Request, dao: DAO = Depends(get_dao), limit: int = 10, offset: int = 0) -> AssetsResponse:
|
||||
assets = await dao.assets.get_assets(limit, offset)
|
||||
assets = await dao.assets.get_assets()
|
||||
total_count = await dao.assets.get_asset_count()
|
||||
|
||||
return AssetsResponse(assets=assets, total_count=total_count)
|
||||
|
||||
|
||||
@router.post("/upload", response_model=AssetResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def upload_asset(
|
||||
file: UploadFile = File(...),
|
||||
linked_char_id: Optional[str] = Form(None),
|
||||
dao: DAO = Depends(get_dao),
|
||||
):
|
||||
if not file.content_type:
|
||||
raise HTTPException(status_code=400, detail="Unknown file type")
|
||||
|
||||
if not file.content_type.startswith("image/"):
|
||||
raise HTTPException(status_code=400, detail=f"Unsupported content type: {file.content_type}")
|
||||
|
||||
data = await file.read()
|
||||
if not data:
|
||||
raise HTTPException(status_code=400, detail="Empty file")
|
||||
|
||||
asset = Asset(
|
||||
name=file.filename or "upload",
|
||||
type=AssetType.IMAGE,
|
||||
linked_char_id=linked_char_id,
|
||||
data=data,
|
||||
)
|
||||
|
||||
asset_id = await dao.assets.create_asset(asset)
|
||||
asset.id = str(asset_id)
|
||||
|
||||
return AssetResponse(
|
||||
id=asset.id,
|
||||
name=asset.name,
|
||||
type=asset.type.value if hasattr(asset.type, "value") else asset.type,
|
||||
linked_char_id=asset.linked_char_id,
|
||||
created_at=asset.created_at,
|
||||
)
|
||||
@@ -1,35 +1,44 @@
|
||||
from typing import List
|
||||
from typing import List, Any, Coroutine
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
|
||||
from api.models.AssetDTO import AssetsResponse
|
||||
from api.models.GenerationRequest import GenerationRequest, GenerationResponse
|
||||
from models.Asset import Asset
|
||||
from models.Character import Character
|
||||
from repos.dao import DAO
|
||||
from api.dependency import get_dao
|
||||
|
||||
router = APIRouter(prefix="/api/characters", tags=["Characters"])
|
||||
|
||||
|
||||
@router.get("/", response_model=List[Character])
|
||||
async def get_characters(request: Request) -> List[Character]:
|
||||
dao: DAO = request.app.state.dao
|
||||
async def get_characters(request: Request, dao: DAO = Depends(get_dao), ) -> List[Character]:
|
||||
characters = await dao.chars.get_all_characters()
|
||||
return characters
|
||||
|
||||
|
||||
@router.get("/{character_id}/assets", response_model=List[Asset])
|
||||
async def get_character_assets(character_id: str, request: Request) -> List[Asset]:
|
||||
dao: DAO = request.app.state.dao
|
||||
@router.get("/{character_id}/assets", response_model=AssetsResponse)
|
||||
async def get_character_assets(character_id: str, dao: DAO = Depends(get_dao), limit: int = 10,
|
||||
offset: int = 0, ) -> AssetsResponse:
|
||||
character = await dao.chars.get_character(character_id)
|
||||
if character is None:
|
||||
raise HTTPException(status_code=404, detail="Character not found")
|
||||
return await dao.assets.get_assets_by_char_id(character_id)
|
||||
assets = await dao.assets.get_assets_by_char_id(character_id, limit, offset)
|
||||
total_count = await dao.assets.get_asset_count(character_id)
|
||||
return AssetsResponse(assets=assets, total_count=total_count)
|
||||
|
||||
|
||||
@router.get("/{character_id}", response_model=Character)
|
||||
async def get_character_by_id(character_id: str, request: Request) -> Character:
|
||||
dao: DAO = request.app.state.dao
|
||||
async def get_character_by_id(character_id: str, request: Request, dao: DAO = Depends(get_dao)) -> Character:
|
||||
character = await dao.chars.get_character(character_id)
|
||||
return character
|
||||
|
||||
|
||||
@router.post("/{character_id}/_run", response_model=Asset)
|
||||
async def post_character_generation(character_id: str, generation: GenerationRequest,
|
||||
request: Request) -> GenerationResponse:
|
||||
generation_service = request.app.state.generation_service
|
||||
|
||||
47
api/endpoints/generation_router.py
Normal file
47
api/endpoints/generation_router.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi.params import Depends
|
||||
from starlette.requests import Request
|
||||
|
||||
from api import service
|
||||
from api.dependency import get_generation_service
|
||||
|
||||
from api.models.GenerationRequest import GenerationResponse, GenerationRequest, PromptResponse, PromptRequest
|
||||
from api.service.generation_service import GenerationService
|
||||
from models.Generation import Generation
|
||||
|
||||
router = APIRouter(prefix='/api/generations', tags=["Generation"])
|
||||
|
||||
|
||||
@router.post("/prompt-assistant", response_model=PromptResponse)
|
||||
async def ask_prompt_assistant(prompt_request: PromptRequest, request: Request,
|
||||
generation_service: GenerationService = Depends(
|
||||
get_generation_service)) -> PromptResponse:
|
||||
generated_prompt = await generation_service.ask_prompt_assistant(prompt_request.prompt, prompt_request.linked_assets)
|
||||
return PromptResponse(prompt=generated_prompt)
|
||||
|
||||
|
||||
@router.get("", response_model=List[GenerationResponse])
|
||||
async def get_generations(character_id: Optional[str], limit: int = 10, offset: int = 0,
|
||||
generation_service: GenerationService = Depends(get_generation_service)):
|
||||
return await generation_service.get_generations(character_id, limit=limit, offset=offset)
|
||||
|
||||
|
||||
@router.post("/_run", response_model=GenerationResponse)
|
||||
async def post_generation(generation: GenerationRequest, request: Request,
|
||||
generation_service: GenerationService = Depends(
|
||||
get_generation_service)) -> GenerationResponse:
|
||||
return await generation_service.create_generation_task(generation)
|
||||
|
||||
|
||||
@router.get("/{generation_id}", response_model=GenerationResponse)
|
||||
async def get_generation(generation_id: str,
|
||||
generation_service: GenerationService = Depends(get_generation_service)) -> GenerationResponse:
|
||||
return await generation_service.get_generation(generation_id)
|
||||
|
||||
|
||||
@router.get("/running")
|
||||
async def get_running_generations(request: Request,
|
||||
generation_service: GenerationService = Depends(get_generation_service)):
|
||||
return await generation_service.get_running_generations()
|
||||
19
api/models/AssetDTO.py
Normal file
19
api/models/AssetDTO.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from models.Asset import Asset
|
||||
|
||||
|
||||
class AssetsResponse(BaseModel):
|
||||
assets: List[Asset]
|
||||
total_count: int
|
||||
|
||||
|
||||
class AssetResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
type: str
|
||||
linked_char_id: Optional[str] = None
|
||||
created_at: datetime
|
||||
39
api/models/GenerationRequest.py
Normal file
39
api/models/GenerationRequest.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from datetime import datetime, UTC
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from models.Asset import Asset
|
||||
from models.Generation import GenerationStatus
|
||||
from models.enums import AspectRatios, Quality
|
||||
|
||||
|
||||
class GenerationRequest(BaseModel):
|
||||
linked_character_id: Optional[str] = None
|
||||
aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN
|
||||
quality: Quality = Quality.ONEK
|
||||
prompt: str
|
||||
assets_list: List[str]
|
||||
|
||||
|
||||
class GenerationResponse(BaseModel):
|
||||
id: str
|
||||
status: GenerationStatus
|
||||
linked_character_id: Optional[str] = None
|
||||
aspect_ratio: AspectRatios
|
||||
quality: Quality
|
||||
prompt: str
|
||||
assets_list: List[str]
|
||||
result: Optional[str] = None
|
||||
created_at: datetime = datetime.now(UTC)
|
||||
updated_at: datetime = datetime.now(UTC)
|
||||
|
||||
|
||||
|
||||
class PromptRequest(BaseModel):
|
||||
prompt: str
|
||||
linked_assets: List[str] = []
|
||||
|
||||
|
||||
class PromptResponse(BaseModel):
|
||||
prompt: str
|
||||
0
api/models/__init__.py
Normal file
0
api/models/__init__.py
Normal file
0
api/service/__init__.py
Normal file
0
api/service/__init__.py
Normal file
186
api/service/generation_service.py
Normal file
186
api/service/generation_service.py
Normal file
@@ -0,0 +1,186 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
from datetime import datetime, UTC
|
||||
from typing import List, Optional
|
||||
from io import BytesIO
|
||||
|
||||
from adapters.google_adapter import GoogleAdapter
|
||||
from api.models.GenerationRequest import GenerationRequest, GenerationResponse
|
||||
# Импортируйте ваши модели DAO, Asset, Generation корректно
|
||||
from models.Asset import Asset, AssetType
|
||||
from models.Generation import Generation, GenerationStatus
|
||||
from models.enums import AspectRatios, Quality
|
||||
from repos.dao import DAO
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# --- Вспомогательная функция генерации ---
|
||||
async def generate_image_task(
|
||||
prompt: str,
|
||||
media_group_bytes: List[bytes],
|
||||
aspect_ratio: AspectRatios,
|
||||
quality: Quality,
|
||||
gemini: GoogleAdapter
|
||||
) -> List[bytes]:
|
||||
"""
|
||||
Обертка для вызова синхронного метода Gemini в отдельном потоке.
|
||||
Возвращает список байтов сгенерированных изображений.
|
||||
"""
|
||||
|
||||
# Запускаем блокирующую операцию в отдельном потоке, чтобы не тормозить Event Loop
|
||||
generated_images_io: List[BytesIO] = await asyncio.to_thread(
|
||||
gemini.generate_image,
|
||||
prompt=prompt,
|
||||
images_list=media_group_bytes,
|
||||
aspect_ratio=aspect_ratio,
|
||||
quality=quality,
|
||||
)
|
||||
|
||||
images_bytes = []
|
||||
if generated_images_io:
|
||||
for img_io in generated_images_io:
|
||||
# Читаем байты из BytesIO
|
||||
img_io.seek(0)
|
||||
content = img_io.read()
|
||||
images_bytes.append(content)
|
||||
|
||||
# Закрываем поток
|
||||
img_io.close()
|
||||
|
||||
return images_bytes
|
||||
|
||||
|
||||
class GenerationService:
|
||||
def __init__(self, dao: DAO, gemini: GoogleAdapter):
|
||||
self.dao = dao
|
||||
self.gemini = gemini
|
||||
|
||||
async def ask_prompt_assistant(self, prompt: str, assets: List[str] = None) -> str:
|
||||
future_prompt = """You are an prompt-assistant. You improving user-entered prompts for image generation. User may upload reference image too.
|
||||
I will provide sources prompt entered by user. Understand user needs and generate best variation of prompt.
|
||||
ANSWER ONLY PROMPT STRING!!! USER_ENTERED_PROMPT:"""
|
||||
future_prompt += prompt
|
||||
assets_data = []
|
||||
if assets is not None:
|
||||
assets_data = await self.dao.assets.get_assets_by_ids(assets)
|
||||
generated_prompt = self.gemini.generate_text(future_prompt, assets_data)
|
||||
return generated_prompt
|
||||
|
||||
async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0) -> List[
|
||||
Generation]:
|
||||
return await self.dao.generations.get_generations(limit=limit, offset=offset)
|
||||
|
||||
async def get_generation(self, generation_id: str) -> Optional[GenerationResponse]:
|
||||
gen = await self.dao.generations.get_generation(generation_id)
|
||||
if gen is None:
|
||||
return None
|
||||
else:
|
||||
return GenerationResponse(**gen.model_dump())
|
||||
|
||||
async def get_running_generations(self) -> List[Generation]:
|
||||
return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING)
|
||||
|
||||
async def create_generation_task(self, generation_request: GenerationRequest) -> GenerationResponse:
|
||||
gen_id = None
|
||||
generation_model = None
|
||||
|
||||
try:
|
||||
generation_model = Generation(**generation_request.model_dump())
|
||||
gen_id = await self.dao.generations.create_generation(generation_model)
|
||||
generation_model.id = gen_id
|
||||
|
||||
async def runner(gen):
|
||||
try:
|
||||
await self.create_generation(gen)
|
||||
except Exception:
|
||||
# если генерация уже пошла и упала — пометим FAILED
|
||||
try:
|
||||
db_gen = await self.dao.generations.get_generation(gen.id)
|
||||
db_gen.status = GenerationStatus.FAILED
|
||||
await self.dao.generations.update_generation(db_gen)
|
||||
except Exception:
|
||||
logger.exception("Failed to mark generation as FAILED")
|
||||
logger.exception("create_generation task failed")
|
||||
|
||||
asyncio.create_task(runner(generation_model))
|
||||
|
||||
return GenerationResponse(**generation_model.model_dump())
|
||||
|
||||
except Exception:
|
||||
# если не успели создать запись — нечего помечать
|
||||
if gen_id is not None:
|
||||
try:
|
||||
gen = await self.dao.generations.get_generation(gen_id)
|
||||
gen.status = GenerationStatus.FAILED
|
||||
await self.dao.generations.update_generation(gen)
|
||||
except Exception:
|
||||
logger.exception("Failed to mark generation as FAILED in create_generation_task")
|
||||
raise
|
||||
|
||||
async def create_generation(self, generation: Generation):
|
||||
|
||||
# 2. Получаем ассеты-референсы (если они есть)
|
||||
reference_assets: List[Asset] = []
|
||||
media_group_bytes: List[bytes] = []
|
||||
generation_prompt = "You are creating image. "
|
||||
if generation.linked_character_id is not None:
|
||||
char_info = await self.dao.chars.get_character(generation.linked_character_id, with_image_data=True)
|
||||
if char_info is None:
|
||||
raise Exception(f"Character ID {generation.linked_character_id} not found")
|
||||
media_group_bytes.append(char_info.character_image_data)
|
||||
generation_prompt = f"""You are creating image for {char_info.character_bio}"""
|
||||
|
||||
reference_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
|
||||
# Извлекаем данные (bytes) из ассетов для отправки в Gemini
|
||||
# Фильтруем, чтобы отправлять только картинки, и где есть data
|
||||
media_group_bytes.extend(
|
||||
asset.data
|
||||
for asset in reference_assets
|
||||
if asset.data is not None and asset.type == AssetType.IMAGE
|
||||
)
|
||||
generation_prompt+=f"PROMPT: {generation.prompt}"
|
||||
|
||||
# 3. Запускаем процесс генерации
|
||||
try:
|
||||
generated_bytes_list = await generate_image_task(
|
||||
prompt=generation_prompt, # или request.prompt
|
||||
media_group_bytes=media_group_bytes,
|
||||
aspect_ratio=generation.aspect_ratio, # предполагаем поля в request
|
||||
quality=generation.quality,
|
||||
gemini=self.gemini
|
||||
)
|
||||
except Exception as e:
|
||||
# Тут стоит добавить логирование ошибки
|
||||
logging.error(f"Generation failed: {e}")
|
||||
# Можно обновить статус генерации на FAILED в БД
|
||||
raise e
|
||||
|
||||
# 4. Сохраняем полученные изображения как новые Ассеты
|
||||
created_assets: List[Asset] = []
|
||||
|
||||
for idx, img_bytes in enumerate(generated_bytes_list):
|
||||
new_asset = Asset(
|
||||
name=f"Generated_{generation.linked_character_id}_{random.randint(1000, 9999)}",
|
||||
type=AssetType.IMAGE,
|
||||
linked_char_id=generation.linked_character_id, # Если генерация привязана к персонажу
|
||||
data=img_bytes,
|
||||
# Остальные поля заполнятся дефолтными значениями (created_at)
|
||||
)
|
||||
|
||||
# Сохраняем в БД
|
||||
asset_id = await self.dao.assets.create_asset(new_asset)
|
||||
new_asset.id = str(asset_id) # Присваиваем ID, полученный от базы
|
||||
|
||||
created_assets.append(new_asset)
|
||||
|
||||
# 5. (Опционально) Обновляем запись генерации ссылками на результаты
|
||||
# Предполагаем, что у модели Generation есть поле result_asset_ids
|
||||
result_ids = [a.id for a in created_assets]
|
||||
|
||||
generation.assets_list = result_ids
|
||||
generation.status = GenerationStatus.DONE
|
||||
generation.updated_at = datetime.now(UTC)
|
||||
generation.tech_prompt = generation_prompt
|
||||
await self.dao.generations.update_generation(generation)
|
||||
17
main.py
17
main.py
@@ -16,6 +16,7 @@ from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
# --- ИМПОРТЫ ПРОЕКТА ---
|
||||
from adapters.google_adapter import GoogleAdapter
|
||||
from api.service.generation_service import GenerationService
|
||||
from middlewares.album import AlbumMiddleware
|
||||
from middlewares.auth import AuthMiddleware
|
||||
from middlewares.dao import DaoMiddleware
|
||||
@@ -33,8 +34,9 @@ from routers.auth_router import router as auth_router
|
||||
from routers.gen_router import router as gen_router
|
||||
from routers.char_router import router as char_router
|
||||
from routers.assets_router import router as assets_router # Роутер бота для ассетов
|
||||
from api.endpoints.assets import router as api_assets_router # Роутер FastAPI
|
||||
from api.endpoints.assets_router import router as api_assets_router # Роутер FastAPI
|
||||
from api.endpoints.character_router import router as api_char_router # Роутер FastAPI
|
||||
from api.endpoints.generation_router import router as api_gen_router
|
||||
|
||||
load_dotenv()
|
||||
|
||||
@@ -47,7 +49,7 @@ ADMIN_ID = int(os.getenv("ADMIN_ID", 0))
|
||||
|
||||
|
||||
def setup_logging():
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
logging.basicConfig(level=logging.DEBUG,
|
||||
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s")
|
||||
|
||||
|
||||
@@ -61,6 +63,8 @@ mongo_client = AsyncIOMotorClient(MONGO_HOST)
|
||||
users_repo = UsersRepo(mongo_client)
|
||||
char_repo = CharacterRepo(mongo_client)
|
||||
dao = DAO(mongo_client) # Главный DAO для бота
|
||||
gemini = GoogleAdapter(api_key=GEMINI_API_KEY)
|
||||
generation_service = GenerationService(dao, gemini)
|
||||
|
||||
# Dispatcher
|
||||
dp = Dispatcher(storage=MongoStorage(mongo_client, db_name=DB_NAME))
|
||||
@@ -68,7 +72,7 @@ dp = Dispatcher(storage=MongoStorage(mongo_client, db_name=DB_NAME))
|
||||
# Внедрение зависимостей (глобально для бота)
|
||||
dp["repo"] = users_repo
|
||||
dp["admin_id"] = ADMIN_ID
|
||||
dp["gemini"] = GoogleAdapter(api_key=GEMINI_API_KEY)
|
||||
dp["gemini"] = gemini
|
||||
|
||||
# --- НАСТРОЙКА РОУТЕРОВ БОТА ---
|
||||
|
||||
@@ -108,7 +112,9 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
# Инициализируем DAO для ассетов и кладем в state приложения
|
||||
# Теперь в эндпоинтах можно делать request.app.state.assets_dao
|
||||
app.state.dao = dao
|
||||
|
||||
app.state.mongo_client = mongo_client
|
||||
app.state.gemini_client = gemini
|
||||
|
||||
print("✅ DB & DAO initialized")
|
||||
|
||||
@@ -152,6 +158,7 @@ app.add_middleware(
|
||||
# Подключаем роутер API
|
||||
app.include_router(api_assets_router)
|
||||
app.include_router(api_char_router)
|
||||
app.include_router(api_gen_router)
|
||||
|
||||
|
||||
# --- ХЕНДЛЕРЫ БОТА (Main Router) ---
|
||||
@@ -179,7 +186,7 @@ if __name__ == "__main__":
|
||||
async def main():
|
||||
# Создаем конфигурацию uvicorn вручную
|
||||
# loop="asyncio" заставляет использовать стандартный цикл
|
||||
config = uvicorn.Config(app, host="0.0.0.0", port=8090, loop="asyncio")
|
||||
config = uvicorn.Config(app, host="0.0.0.0", port=8090, loop="asyncio", timeout_keep_alive=120)
|
||||
server = uvicorn.Server(config)
|
||||
|
||||
# Запускаем сервер (lifespan запустится внутри)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from datetime import datetime, UTC
|
||||
from enum import Enum
|
||||
from typing import Optional, Any
|
||||
from typing import Optional, Any, List
|
||||
|
||||
from pydantic import BaseModel, computed_field
|
||||
|
||||
@@ -16,8 +16,9 @@ class Asset(BaseModel):
|
||||
type: AssetType
|
||||
linked_char_id: Optional[str] = None
|
||||
data: Optional[bytes] = None
|
||||
tg_doc_file_id: str
|
||||
tg_doc_file_id: Optional[str] = None
|
||||
tg_photo_file_id: Optional[str] = None
|
||||
tags: List[str] = []
|
||||
created_at: datetime = datetime.now(UTC)
|
||||
|
||||
# --- CALCULATED FIELD ---
|
||||
|
||||
27
models/Generation.py
Normal file
27
models/Generation.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from datetime import datetime, UTC
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from models.Asset import Asset
|
||||
from models.enums import AspectRatios, Quality
|
||||
|
||||
|
||||
class GenerationStatus(str, Enum):
|
||||
RUNNING = "running"
|
||||
DONE = "done"
|
||||
FAILED = "failed"
|
||||
|
||||
class Generation(BaseModel):
|
||||
id: Optional[str] = None
|
||||
status: GenerationStatus = GenerationStatus.RUNNING
|
||||
linked_character_id: Optional[str] = None
|
||||
aspect_ratio: AspectRatios
|
||||
quality: Quality
|
||||
prompt: str
|
||||
tech_prompt: Optional[str] = None
|
||||
assets_list: List[str]
|
||||
result: Optional[str] = None
|
||||
created_at: datetime = datetime.now(UTC)
|
||||
updated_at: datetime = datetime.now(UTC)
|
||||
@@ -1,19 +1,43 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class AspectRatios(Enum):
|
||||
NINESIXTEEN = '9:16'
|
||||
SIXTEENNINE = '16:9'
|
||||
THREEFOUR = '3:4'
|
||||
FOURTHREE = '4:3'
|
||||
class AspectRatios(str, Enum):
|
||||
NINESIXTEEN = "NINESIXTEEN"
|
||||
SIXTEENNINE = "SIXTEENNINE"
|
||||
THREEFOUR = "THREEFOUR"
|
||||
FOURTHREE = "FOURTHREE"
|
||||
|
||||
@property
|
||||
def value_ratio(self) -> str:
|
||||
return {
|
||||
AspectRatios.NINESIXTEEN: "9:16",
|
||||
AspectRatios.SIXTEENNINE: "16:9",
|
||||
AspectRatios.THREEFOUR: "3:4",
|
||||
AspectRatios.FOURTHREE: "4:3",
|
||||
}[self]
|
||||
|
||||
|
||||
class Quality(Enum):
|
||||
ONEK = '1K'
|
||||
TWOK = '2K'
|
||||
FOURK = '4K'
|
||||
class Quality(str, Enum):
|
||||
ONEK = 'ONEK'
|
||||
TWOK = 'TWOK'
|
||||
FOURK = 'FOURK'
|
||||
|
||||
@property
|
||||
def value_quality(self) -> str:
|
||||
return {
|
||||
Quality.ONEK: '1K',
|
||||
Quality.TWOK: '2K',
|
||||
Quality.FOURK: '4K'
|
||||
}[self]
|
||||
|
||||
|
||||
class GenType(Enum):
|
||||
class GenType(str, Enum):
|
||||
TEXT = 'Text'
|
||||
IMAGE = 'Image'
|
||||
|
||||
@property
|
||||
def value_type(self) -> str:
|
||||
return {
|
||||
GenType.TEXT: 'Text',
|
||||
GenType.IMAGE: 'Image'
|
||||
}[self]
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from bson import ObjectId
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
@@ -10,10 +10,10 @@ class AssetsRepo:
|
||||
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
|
||||
self.collection = client[db_name]["assets"]
|
||||
|
||||
async def save_asset(self, asset: Asset) -> Asset:
|
||||
async def create_asset(self, asset: Asset) -> str:
|
||||
res = await self.collection.insert_one(asset.model_dump())
|
||||
asset.id = res.inserted_id
|
||||
return asset
|
||||
|
||||
return str(res.inserted_id)
|
||||
|
||||
async def get_assets(self, limit: int = 10, offset: int = 0) -> List[Asset]:
|
||||
res = await self.collection.find({}, {"data": 0}).sort("created_at", -1).skip(offset).limit(limit).to_list(None)
|
||||
@@ -27,6 +27,7 @@ class AssetsRepo:
|
||||
|
||||
return assets
|
||||
|
||||
|
||||
async def get_asset(self, asset_id: str, with_data: bool = True) -> Asset:
|
||||
projection = {"_id": 1, "name": 1, "type": 1, "tg_doc_file_id": 1}
|
||||
if with_data:
|
||||
@@ -54,3 +55,17 @@ class AssetsRepo:
|
||||
doc["id"] = str(doc.pop("_id"))
|
||||
assets.append(Asset(**doc))
|
||||
return assets
|
||||
|
||||
async def get_asset_count(self, character_id: Optional[str] = None) -> int:
|
||||
return await self.collection.count_documents({"linked_char_id": character_id} if character_id else {})
|
||||
|
||||
|
||||
|
||||
async def get_assets_by_ids(self, asset_ids: List[str]) -> List[Asset]:
|
||||
object_ids = [ObjectId(asset_id) for asset_id in asset_ids]
|
||||
res = self.collection.find({"_id": {"$in": object_ids}})
|
||||
assets = []
|
||||
async for doc in res:
|
||||
doc["id"] = str(doc.pop("_id"))
|
||||
assets.append(Asset(**doc))
|
||||
return assets
|
||||
|
||||
@@ -2,6 +2,7 @@ from motor.motor_asyncio import AsyncIOMotorClient
|
||||
|
||||
from repos.assets_repo import AssetsRepo
|
||||
from repos.char_repo import CharacterRepo
|
||||
from repos.generation_repo import GenerationRepo
|
||||
from repos.user_repo import UsersRepo
|
||||
|
||||
|
||||
@@ -9,3 +10,4 @@ class DAO:
|
||||
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
|
||||
self.chars = CharacterRepo(client, db_name)
|
||||
self.assets = AssetsRepo(client, db_name)
|
||||
self.generations = GenerationRepo(client, db_name)
|
||||
|
||||
43
repos/generation_repo.py
Normal file
43
repos/generation_repo.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from typing import Optional, List
|
||||
|
||||
from PIL.ImageChops import offset
|
||||
from bson import ObjectId
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
|
||||
from api.models.GenerationRequest import GenerationResponse
|
||||
from models.Generation import Generation, GenerationStatus
|
||||
|
||||
|
||||
class GenerationRepo:
|
||||
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
|
||||
self.collection = client[db_name]["generations"]
|
||||
|
||||
async def create_generation(self, generation: Generation) -> str:
|
||||
res = await self.collection.insert_one(generation.model_dump())
|
||||
return str(res.inserted_id)
|
||||
|
||||
async def get_generation(self, generation_id: str) -> Optional[Generation]:
|
||||
res = await self.collection.find_one({"_id": ObjectId(generation_id)})
|
||||
if res is None:
|
||||
return None
|
||||
else:
|
||||
res["id"] = str(res.pop("_id"))
|
||||
return Generation(**res)
|
||||
|
||||
async def get_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
|
||||
limit: int = 10, offset: int = 10) -> List[Generation]:
|
||||
args = {}
|
||||
if character_id is not None:
|
||||
args["character_id"] = character_id
|
||||
if status is not None:
|
||||
args["status"] = status
|
||||
res = await self.collection.find(args).sort("created_at", -1).skip(
|
||||
offset).limit(limit).to_list(None)
|
||||
generations: List[Generation] = []
|
||||
for generation in res:
|
||||
generation["id"] = str(generation.pop("_id"))
|
||||
generations.append(Generation(**generation))
|
||||
return generations
|
||||
|
||||
async def update_generation(self, generation: Generation, ):
|
||||
res = await self.collection.update_one({"_id": ObjectId(generation.id)}, {"$set": generation.model_dump()})
|
||||
@@ -73,7 +73,7 @@ async def new_char_bio(message: Message, state: FSMContext, dao: DAO, bot: Bot):
|
||||
file_info = await bot.get_file(char.character_image_doc_tg_id)
|
||||
file_bytes = await bot.download_file(file_info.file_path)
|
||||
file_io = file_bytes.read()
|
||||
avatar_asset = await dao.assets.save_asset(
|
||||
avatar_asset = await dao.assets.create_asset(
|
||||
Asset(name="avatar.png", type=AssetType.IMAGE, linked_char_id=str(char.id), data=file_io,
|
||||
tg_doc_file_id=file_id))
|
||||
char.avatar_image = avatar_asset.link
|
||||
|
||||
@@ -50,8 +50,8 @@ async def generate_image_cmd(message: Message, state: FSMContext, dao: DAO, gemi
|
||||
gemini=gemini)
|
||||
await wait_msg.delete()
|
||||
doc = await message.answer_document(res[0], caption="Generated result 💫")
|
||||
await dao.assets.save_asset(Asset(id=None, name=res[0].filename, type=AssetType.IMAGE, data=res[0].data,
|
||||
tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None, linked_char_id=None))
|
||||
await dao.assets.create_asset(Asset(id=None, name=res[0].filename, type=AssetType.IMAGE, data=res[0].data,
|
||||
tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None, linked_char_id=None))
|
||||
|
||||
|
||||
@router.message(Command("gen_mode"))
|
||||
@@ -260,9 +260,9 @@ async def handle_album(
|
||||
if generated_files:
|
||||
for file in generated_files:
|
||||
doc = await message.answer_document(file, caption="✨ Generated result")
|
||||
await dao.assets.save_asset(Asset(id=None, name=file.filename, type=AssetType.IMAGE, data=file.data,
|
||||
tg_doc_file_id = doc.document.file_id, tg_photo_file_id = None,
|
||||
linked_char_id = data["char_id"]))
|
||||
await dao.assets.create_asset(Asset(id=None, name=file.filename, type=AssetType.IMAGE, data=file.data,
|
||||
tg_doc_file_id = doc.document.file_id, tg_photo_file_id = None,
|
||||
linked_char_id = data["char_id"]))
|
||||
else:
|
||||
await message.answer("❌ Генерация не вернула изображений.")
|
||||
await gen_mode_base_msg(message=message, state=state, dao=dao, call_type="start")
|
||||
@@ -315,9 +315,9 @@ async def gen_mode_start(
|
||||
if generated_files:
|
||||
for file in generated_files:
|
||||
doc = await message.answer_document(file, caption="✨ Generated result")
|
||||
await dao.assets.save_asset(Asset(id=None, name=file.filename, type=AssetType.IMAGE, data=file.data,
|
||||
tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None,
|
||||
linked_char_id=data["char_id"]))
|
||||
await dao.assets.create_asset(Asset(id=None, name=file.filename, type=AssetType.IMAGE, data=file.data,
|
||||
tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None,
|
||||
linked_char_id=data["char_id"]))
|
||||
|
||||
else:
|
||||
await message.answer("❌ Ничего не сгенерировалось.")
|
||||
|
||||
Reference in New Issue
Block a user