models + refactor

This commit is contained in:
xds
2026-02-27 20:37:24 +03:00
parent d9caececd7
commit e011805186
31 changed files with 234 additions and 223 deletions

View File

@@ -1,6 +1,5 @@
import logging
import json
from typing import List, Optional
from fastapi import APIRouter, UploadFile, File, Form, Header, HTTPException
from fastapi.params import Depends
@@ -30,7 +29,7 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix='/api/generations', tags=["Generation"])
async def check_project_access(project_id: Optional[str], current_user: dict, dao: DAO):
async def check_project_access(project_id: str | None, current_user: dict, dao: DAO):
"""Helper to check if user has access to project."""
if not project_id:
return
@@ -46,31 +45,36 @@ async def ask_prompt_assistant(
current_user: dict = Depends(get_current_user)
) -> PromptResponse:
logger.info(f"ask_prompt_assistant: {len(prompt_request.prompt)} chars")
generated_prompt = await generation_service.ask_prompt_assistant(prompt_request.prompt, prompt_request.linked_assets)
generated_prompt = await generation_service.ask_prompt_assistant(
prompt_request.prompt,
prompt_request.linked_assets,
prompt_request.model
)
return PromptResponse(prompt=generated_prompt)
@router.post("/prompt-from-image", response_model=PromptResponse)
async def prompt_from_image(
prompt: Optional[str] = Form(None),
images: List[UploadFile] = File(...),
prompt: str | None = Form(None),
model: str = Form("gemini-3.1-pro-preview"),
images: list[UploadFile] = File(...),
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)
) -> PromptResponse:
images_bytes = [await img.read() for img in images]
generated_prompt = await generation_service.generate_prompt_from_images(images_bytes, prompt)
generated_prompt = await generation_service.generate_prompt_from_images(images_bytes, prompt, model)
return PromptResponse(prompt=generated_prompt)
@router.get("", response_model=GenerationsResponse)
async def get_generations(
character_id: Optional[str] = None,
character_id: str | None = None,
limit: int = 10,
offset: int = 0,
only_liked: bool = False,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
project_id: str | None = Depends(get_project_id),
dao: DAO = Depends(get_dao)
):
await check_project_access(project_id, current_user, dao)
@@ -92,10 +96,10 @@ async def get_generations(
@router.get("/usage", response_model=FinancialReport)
async def get_usage_report(
breakdown: Optional[str] = None, # "user" or "project"
breakdown: str | None = None, # "user" or "project"
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
project_id: str | None = Depends(get_project_id),
dao: DAO = Depends(get_dao)
) -> FinancialReport:
await check_project_access(project_id, current_user, dao)
@@ -120,7 +124,7 @@ async def post_generation(
generation: GenerationRequest,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
project_id: str | None = Depends(get_project_id),
dao: DAO = Depends(get_dao)
) -> GenerationGroupResponse:
await check_project_access(project_id, current_user, dao)
@@ -137,7 +141,7 @@ async def post_generation(
async def get_running_generations(
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
project_id: str | None = Depends(get_project_id),
dao: DAO = Depends(get_dao)
):
await check_project_access(project_id, current_user, dao)