This commit is contained in:
xds
2026-02-04 15:10:55 +03:00
parent 11c1f4f7dc
commit 35de8efc56
20 changed files with 566 additions and 135 deletions

30
api/dependency.py Normal file
View File

@@ -0,0 +1,30 @@
# dependency.py
from fastapi import Request, Depends
from motor.motor_asyncio import AsyncIOMotorClient
from adapters.google_adapter import GoogleAdapter
from api.service.generation_service import GenerationService
from repos.dao import DAO
# ... ваши импорты ...
# Провайдеры "сырых" клиентов из состояния приложения
def get_mongo_client(request: Request) -> AsyncIOMotorClient:
return request.app.state.mongo_client
def get_gemini_client(request: Request) -> GoogleAdapter:
return request.app.state.gemini_client
# Провайдер DAO (собирается из mongo_client)
def get_dao(mongo_client: AsyncIOMotorClient = Depends(get_mongo_client)) -> DAO:
# FastAPI кэширует результат Depends в рамках одного запроса,
# так что DAO создастся один раз за запрос.
return DAO(mongo_client)
# Провайдер сервиса (собирается из DAO и Gemini)
def get_generation_service(
dao: DAO = Depends(get_dao),
gemini: GoogleAdapter = Depends(get_gemini_client)
) -> GenerationService:
return GenerationService(dao, gemini)

View File

@@ -1,92 +0,0 @@
from typing import List, Optional
from aiogram.types import BufferedInputFile
from fastapi import APIRouter, UploadFile, File, Form
from fastapi.openapi.models import MediaType
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import Response, JSONResponse
from models.Asset import Asset, AssetType
from repos.dao import DAO
router = APIRouter(prefix="/api/assets", tags=["Assets"])
@router.get("/{asset_id}")
async def get_asset(asset_id: str, request: Request) -> Response:
dao = request.app.state.dao
asset = await dao.assets.get_asset(asset_id)
# 2. Проверка на существование
if not asset:
raise HTTPException(status_code=404, detail="Asset not found")
headers = {
# Кэшировать на 1 год (31536000 сек)
"Cache-Control": "public, max-age=31536000, immutable"
}
return Response(content=asset.data, media_type="image/png", headers=headers)
@router.get("")
async def get_assets(request: Request) -> List[Asset]:
dao: DAO = request.app.state.dao
assets = await dao.assets.get_assets()
return assets
@router.post("/upload", response_model=Asset)
async def upload_asset(
request: Request,
# Файл обязателен
file: UploadFile = File(...),
# Остальные поля принимаем как Form-data (не JSON!)
name: str = Form(...),
type: AssetType = Form(...),
linked_char_id: Optional[str] = Form(None)
):
"""
Загружает файл, отправляет его в ТГ (для получения ID) и сохраняет в БД.
"""
# 1. Читаем байты файла
file_content = await file.read()
if not file_content:
raise HTTPException(status_code=400, detail="File is empty")
# 2. Получаем необходимые зависимости из state
bot = request.app.state.bot # Бот нужен, чтобы получить tg_file_id
admin_id = request.app.state.admin_id # Куда отправлять файл "на хранение"
dao = request.app.state.assets_dao
# 3. Отправляем файл в Telegram, чтобы получить tg_doc_file_id
# (Это обязательно, так как ваша модель требует этот ID)
try:
tg_msg = await bot.send_document(
chat_id=admin_id,
document=BufferedInputFile(file_content, filename=file.filename),
caption=f"📥 Uploaded via API: {name}"
)
# Получаем ID документа из ответа ТГ
tg_doc_id = tg_msg.document.file_id
# Если это картинка, можно попытаться достать и photo_id (для превью)
# Но send_document обычно возвращает именно документ.
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to upload to Telegram: {e}")
# 4. Создаем объект Asset
# Pydantic сам подставит created_at и вычислит link
new_asset = Asset(
name=name,
type=type,
linked_char_id=linked_char_id,
data=file_content, # Сохраняем байты в БД
tg_doc_file_id=tg_doc_id # ID из телеграма
)
# 5. Сохраняем через DAO
saved_asset = await dao.save_asset(new_asset)
return saved_asset

View File

@@ -0,0 +1,74 @@
from typing import List, Optional
from aiogram.types import BufferedInputFile
from fastapi import APIRouter, UploadFile, File, Form, Depends
from fastapi.openapi.models import MediaType
from starlette import status
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import Response, JSONResponse
from api.models.AssetDTO import AssetsResponse, AssetResponse
from models.Asset import Asset, AssetType
from repos.dao import DAO
from api.dependency import get_dao
router = APIRouter(prefix="/api/assets", tags=["Assets"])
@router.get("/{asset_id}")
async def get_asset(asset_id: str, request: Request,dao: DAO = Depends(get_dao),) -> Response:
asset = await dao.assets.get_asset(asset_id)
# 2. Проверка на существование
if not asset:
raise HTTPException(status_code=404, detail="Asset not found")
headers = {
# Кэшировать на 1 год (31536000 сек)
"Cache-Control": "public, max-age=31536000, immutable"
}
return Response(content=asset.data, media_type="image/png", headers=headers)
@router.get("")
async def get_assets(request: Request, dao: DAO = Depends(get_dao), limit: int = 10, offset: int = 0) -> AssetsResponse:
assets = await dao.assets.get_assets(limit, offset)
assets = await dao.assets.get_assets()
total_count = await dao.assets.get_asset_count()
return AssetsResponse(assets=assets, total_count=total_count)
@router.post("/upload", response_model=AssetResponse, status_code=status.HTTP_201_CREATED)
async def upload_asset(
file: UploadFile = File(...),
linked_char_id: Optional[str] = Form(None),
dao: DAO = Depends(get_dao),
):
if not file.content_type:
raise HTTPException(status_code=400, detail="Unknown file type")
if not file.content_type.startswith("image/"):
raise HTTPException(status_code=400, detail=f"Unsupported content type: {file.content_type}")
data = await file.read()
if not data:
raise HTTPException(status_code=400, detail="Empty file")
asset = Asset(
name=file.filename or "upload",
type=AssetType.IMAGE,
linked_char_id=linked_char_id,
data=data,
)
asset_id = await dao.assets.create_asset(asset)
asset.id = str(asset_id)
return AssetResponse(
id=asset.id,
name=asset.name,
type=asset.type.value if hasattr(asset.type, "value") else asset.type,
linked_char_id=asset.linked_char_id,
created_at=asset.created_at,
)

View File

@@ -1,35 +1,44 @@
from typing import List
from typing import List, Any, Coroutine
from fastapi import APIRouter
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from starlette.exceptions import HTTPException
from starlette.requests import Request
from api.models.AssetDTO import AssetsResponse
from api.models.GenerationRequest import GenerationRequest, GenerationResponse
from models.Asset import Asset
from models.Character import Character
from repos.dao import DAO
from api.dependency import get_dao
router = APIRouter(prefix="/api/characters", tags=["Characters"])
@router.get("/", response_model=List[Character])
async def get_characters(request: Request) -> List[Character]:
dao: DAO = request.app.state.dao
async def get_characters(request: Request, dao: DAO = Depends(get_dao), ) -> List[Character]:
characters = await dao.chars.get_all_characters()
return characters
@router.get("/{character_id}/assets", response_model=List[Asset])
async def get_character_assets(character_id: str, request: Request) -> List[Asset]:
dao: DAO = request.app.state.dao
@router.get("/{character_id}/assets", response_model=AssetsResponse)
async def get_character_assets(character_id: str, dao: DAO = Depends(get_dao), limit: int = 10,
offset: int = 0, ) -> AssetsResponse:
character = await dao.chars.get_character(character_id)
if character is None:
raise HTTPException(status_code=404, detail="Character not found")
return await dao.assets.get_assets_by_char_id(character_id)
assets = await dao.assets.get_assets_by_char_id(character_id, limit, offset)
total_count = await dao.assets.get_asset_count(character_id)
return AssetsResponse(assets=assets, total_count=total_count)
@router.get("/{character_id}", response_model=Character)
async def get_character_by_id(character_id: str, request: Request) -> Character:
dao: DAO = request.app.state.dao
async def get_character_by_id(character_id: str, request: Request, dao: DAO = Depends(get_dao)) -> Character:
character = await dao.chars.get_character(character_id)
return character
@router.post("/{character_id}/_run", response_model=Asset)
async def post_character_generation(character_id: str, generation: GenerationRequest,
request: Request) -> GenerationResponse:
generation_service = request.app.state.generation_service

View File

@@ -0,0 +1,47 @@
from typing import List, Optional
from fastapi import APIRouter
from fastapi.params import Depends
from starlette.requests import Request
from api import service
from api.dependency import get_generation_service
from api.models.GenerationRequest import GenerationResponse, GenerationRequest, PromptResponse, PromptRequest
from api.service.generation_service import GenerationService
from models.Generation import Generation
router = APIRouter(prefix='/api/generations', tags=["Generation"])
@router.post("/prompt-assistant", response_model=PromptResponse)
async def ask_prompt_assistant(prompt_request: PromptRequest, request: Request,
generation_service: GenerationService = Depends(
get_generation_service)) -> PromptResponse:
generated_prompt = await generation_service.ask_prompt_assistant(prompt_request.prompt, prompt_request.linked_assets)
return PromptResponse(prompt=generated_prompt)
@router.get("", response_model=List[GenerationResponse])
async def get_generations(character_id: Optional[str], limit: int = 10, offset: int = 0,
generation_service: GenerationService = Depends(get_generation_service)):
return await generation_service.get_generations(character_id, limit=limit, offset=offset)
@router.post("/_run", response_model=GenerationResponse)
async def post_generation(generation: GenerationRequest, request: Request,
generation_service: GenerationService = Depends(
get_generation_service)) -> GenerationResponse:
return await generation_service.create_generation_task(generation)
@router.get("/{generation_id}", response_model=GenerationResponse)
async def get_generation(generation_id: str,
generation_service: GenerationService = Depends(get_generation_service)) -> GenerationResponse:
return await generation_service.get_generation(generation_id)
@router.get("/running")
async def get_running_generations(request: Request,
generation_service: GenerationService = Depends(get_generation_service)):
return await generation_service.get_running_generations()

19
api/models/AssetDTO.py Normal file
View File

@@ -0,0 +1,19 @@
from datetime import datetime
from typing import List, Optional
from pydantic import BaseModel
from models.Asset import Asset
class AssetsResponse(BaseModel):
assets: List[Asset]
total_count: int
class AssetResponse(BaseModel):
id: str
name: str
type: str
linked_char_id: Optional[str] = None
created_at: datetime

View File

@@ -0,0 +1,39 @@
from datetime import datetime, UTC
from typing import List, Optional
from pydantic import BaseModel
from models.Asset import Asset
from models.Generation import GenerationStatus
from models.enums import AspectRatios, Quality
class GenerationRequest(BaseModel):
linked_character_id: Optional[str] = None
aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN
quality: Quality = Quality.ONEK
prompt: str
assets_list: List[str]
class GenerationResponse(BaseModel):
id: str
status: GenerationStatus
linked_character_id: Optional[str] = None
aspect_ratio: AspectRatios
quality: Quality
prompt: str
assets_list: List[str]
result: Optional[str] = None
created_at: datetime = datetime.now(UTC)
updated_at: datetime = datetime.now(UTC)
class PromptRequest(BaseModel):
prompt: str
linked_assets: List[str] = []
class PromptResponse(BaseModel):
prompt: str

0
api/models/__init__.py Normal file
View File

0
api/service/__init__.py Normal file
View File

View File

@@ -0,0 +1,186 @@
import asyncio
import logging
import random
from datetime import datetime, UTC
from typing import List, Optional
from io import BytesIO
from adapters.google_adapter import GoogleAdapter
from api.models.GenerationRequest import GenerationRequest, GenerationResponse
# Импортируйте ваши модели DAO, Asset, Generation корректно
from models.Asset import Asset, AssetType
from models.Generation import Generation, GenerationStatus
from models.enums import AspectRatios, Quality
from repos.dao import DAO
logger = logging.getLogger(__name__)
# --- Вспомогательная функция генерации ---
async def generate_image_task(
prompt: str,
media_group_bytes: List[bytes],
aspect_ratio: AspectRatios,
quality: Quality,
gemini: GoogleAdapter
) -> List[bytes]:
"""
Обертка для вызова синхронного метода Gemini в отдельном потоке.
Возвращает список байтов сгенерированных изображений.
"""
# Запускаем блокирующую операцию в отдельном потоке, чтобы не тормозить Event Loop
generated_images_io: List[BytesIO] = await asyncio.to_thread(
gemini.generate_image,
prompt=prompt,
images_list=media_group_bytes,
aspect_ratio=aspect_ratio,
quality=quality,
)
images_bytes = []
if generated_images_io:
for img_io in generated_images_io:
# Читаем байты из BytesIO
img_io.seek(0)
content = img_io.read()
images_bytes.append(content)
# Закрываем поток
img_io.close()
return images_bytes
class GenerationService:
def __init__(self, dao: DAO, gemini: GoogleAdapter):
self.dao = dao
self.gemini = gemini
async def ask_prompt_assistant(self, prompt: str, assets: List[str] = None) -> str:
future_prompt = """You are an prompt-assistant. You improving user-entered prompts for image generation. User may upload reference image too.
I will provide sources prompt entered by user. Understand user needs and generate best variation of prompt.
ANSWER ONLY PROMPT STRING!!! USER_ENTERED_PROMPT:"""
future_prompt += prompt
assets_data = []
if assets is not None:
assets_data = await self.dao.assets.get_assets_by_ids(assets)
generated_prompt = self.gemini.generate_text(future_prompt, assets_data)
return generated_prompt
async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0) -> List[
Generation]:
return await self.dao.generations.get_generations(limit=limit, offset=offset)
async def get_generation(self, generation_id: str) -> Optional[GenerationResponse]:
gen = await self.dao.generations.get_generation(generation_id)
if gen is None:
return None
else:
return GenerationResponse(**gen.model_dump())
async def get_running_generations(self) -> List[Generation]:
return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING)
async def create_generation_task(self, generation_request: GenerationRequest) -> GenerationResponse:
gen_id = None
generation_model = None
try:
generation_model = Generation(**generation_request.model_dump())
gen_id = await self.dao.generations.create_generation(generation_model)
generation_model.id = gen_id
async def runner(gen):
try:
await self.create_generation(gen)
except Exception:
# если генерация уже пошла и упала — пометим FAILED
try:
db_gen = await self.dao.generations.get_generation(gen.id)
db_gen.status = GenerationStatus.FAILED
await self.dao.generations.update_generation(db_gen)
except Exception:
logger.exception("Failed to mark generation as FAILED")
logger.exception("create_generation task failed")
asyncio.create_task(runner(generation_model))
return GenerationResponse(**generation_model.model_dump())
except Exception:
# если не успели создать запись — нечего помечать
if gen_id is not None:
try:
gen = await self.dao.generations.get_generation(gen_id)
gen.status = GenerationStatus.FAILED
await self.dao.generations.update_generation(gen)
except Exception:
logger.exception("Failed to mark generation as FAILED in create_generation_task")
raise
async def create_generation(self, generation: Generation):
# 2. Получаем ассеты-референсы (если они есть)
reference_assets: List[Asset] = []
media_group_bytes: List[bytes] = []
generation_prompt = "You are creating image. "
if generation.linked_character_id is not None:
char_info = await self.dao.chars.get_character(generation.linked_character_id, with_image_data=True)
if char_info is None:
raise Exception(f"Character ID {generation.linked_character_id} not found")
media_group_bytes.append(char_info.character_image_data)
generation_prompt = f"""You are creating image for {char_info.character_bio}"""
reference_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
# Извлекаем данные (bytes) из ассетов для отправки в Gemini
# Фильтруем, чтобы отправлять только картинки, и где есть data
media_group_bytes.extend(
asset.data
for asset in reference_assets
if asset.data is not None and asset.type == AssetType.IMAGE
)
generation_prompt+=f"PROMPT: {generation.prompt}"
# 3. Запускаем процесс генерации
try:
generated_bytes_list = await generate_image_task(
prompt=generation_prompt, # или request.prompt
media_group_bytes=media_group_bytes,
aspect_ratio=generation.aspect_ratio, # предполагаем поля в request
quality=generation.quality,
gemini=self.gemini
)
except Exception as e:
# Тут стоит добавить логирование ошибки
logging.error(f"Generation failed: {e}")
# Можно обновить статус генерации на FAILED в БД
raise e
# 4. Сохраняем полученные изображения как новые Ассеты
created_assets: List[Asset] = []
for idx, img_bytes in enumerate(generated_bytes_list):
new_asset = Asset(
name=f"Generated_{generation.linked_character_id}_{random.randint(1000, 9999)}",
type=AssetType.IMAGE,
linked_char_id=generation.linked_character_id, # Если генерация привязана к персонажу
data=img_bytes,
# Остальные поля заполнятся дефолтными значениями (created_at)
)
# Сохраняем в БД
asset_id = await self.dao.assets.create_asset(new_asset)
new_asset.id = str(asset_id) # Присваиваем ID, полученный от базы
created_assets.append(new_asset)
# 5. (Опционально) Обновляем запись генерации ссылками на результаты
# Предполагаем, что у модели Generation есть поле result_asset_ids
result_ids = [a.id for a in created_assets]
generation.assets_list = result_ids
generation.status = GenerationStatus.DONE
generation.updated_at = datetime.now(UTC)
generation.tech_prompt = generation_prompt
await self.dao.generations.update_generation(generation)