229 lines
8.5 KiB
Python
229 lines
8.5 KiB
Python
import logging
|
|
import json
|
|
from typing import List, Optional
|
|
|
|
from fastapi import APIRouter, UploadFile, File, Form, Header, HTTPException
|
|
from fastapi.params import Depends
|
|
from starlette import status
|
|
from starlette.requests import Request
|
|
|
|
from config import settings
|
|
from api.dependency import get_generation_service, get_project_id, get_dao
|
|
from api.endpoints.auth import get_current_user
|
|
from api.models import (
|
|
GenerationResponse,
|
|
GenerationRequest,
|
|
GenerationsResponse,
|
|
PromptResponse,
|
|
PromptRequest,
|
|
GenerationGroupResponse,
|
|
FinancialReport,
|
|
ExternalGenerationRequest
|
|
)
|
|
from api.service.generation_service import GenerationService
|
|
from repos.dao import DAO
|
|
from utils.external_auth import verify_signature
|
|
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Router for every endpoint in this module: all paths are mounted under
# /api/generations and grouped under the "Generation" tag in the OpenAPI docs.
router = APIRouter(prefix='/api/generations', tags=["Generation"])
|
|
|
|
|
|
async def check_project_access(project_id: Optional[str], current_user: dict, dao: DAO):
    """Reject the request unless the current user belongs to ``project_id``.

    A falsy ``project_id`` means the request is not project-scoped, so there is
    nothing to check. Otherwise the project is loaded via ``dao.projects`` and the
    caller's stringified ``_id`` must be listed in ``project.members``.

    Raises:
        HTTPException: 403 when the project cannot be found or the user is not
            one of its members.
    """
    if project_id:
        project = await dao.projects.get_project(project_id)
        # Short-circuits so the user id is only read once a project was found.
        is_member = bool(project) and str(current_user["_id"]) in project.members
        if not is_member:
            raise HTTPException(status_code=403, detail="Project access denied")
|
|
|
|
|
|
@router.post("/prompt-assistant", response_model=PromptResponse)
async def ask_prompt_assistant(
    prompt_request: PromptRequest,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user)
) -> PromptResponse:
    """Turn the user's draft prompt into an assistant-generated prompt.

    The draft text and the request's ``linked_assets`` are forwarded to
    ``GenerationService.ask_prompt_assistant`` and the result is wrapped in a
    ``PromptResponse``. ``current_user`` is not read in the body; the dependency
    presumably exists to require an authenticated caller.
    """
    draft = prompt_request.prompt
    logger.info(f"ask_prompt_assistant: {len(draft)} chars")

    suggestion = await generation_service.ask_prompt_assistant(draft, prompt_request.linked_assets)
    return PromptResponse(prompt=suggestion)
|
|
|
|
|
|
@router.post("/prompt-from-image", response_model=PromptResponse)
async def prompt_from_image(
    prompt: Optional[str] = Form(None),
    images: List[UploadFile] = File(...),
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user)
) -> PromptResponse:
    """Build a prompt from one or more uploaded images plus an optional hint.

    Each upload is read fully into memory, in order, and the raw bytes are
    passed with the optional ``prompt`` form field to
    ``GenerationService.generate_prompt_from_images``.
    """
    payloads = []
    for upload in images:
        payloads.append(await upload.read())

    suggestion = await generation_service.generate_prompt_from_images(payloads, prompt)
    return PromptResponse(prompt=suggestion)
|
|
|
|
|
|
@router.get("", response_model=GenerationsResponse)
async def get_generations(
    character_id: Optional[str] = None,
    limit: int = 10,
    offset: int = 0,
    only_liked: bool = False,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
    project_id: Optional[str] = Depends(get_project_id),
    dao: DAO = Depends(get_dao)
):
    """List generations, paginated by ``limit``/``offset``.

    When a project is selected, access to it is verified and the listing is not
    narrowed to the caller's own generations. Without a project, only the
    caller's generations are requested. ``only_liked`` additionally restricts
    the listing to generations liked by the caller.
    """
    await check_project_access(project_id, current_user, dao)

    user_id = str(current_user["_id"])

    # Project-wide listings show every member's work; personal ones only the caller's.
    if project_id:
        created_by_filter = None
    else:
        created_by_filter = user_id

    return await generation_service.get_generations(
        character_id=character_id,
        limit=limit,
        offset=offset,
        created_by=created_by_filter,
        project_id=project_id,
        only_liked_by=user_id if only_liked else None,
        current_user_id=user_id,
    )
|
|
|
|
|
|
@router.get("/usage", response_model=FinancialReport)
async def get_usage_report(
    breakdown: Optional[str] = None,  # "user" or "project"
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
    project_id: Optional[str] = Depends(get_project_id),
    dao: DAO = Depends(get_dao)
) -> FinancialReport:
    """Return a usage/cost report, optionally broken down by user or project.

    Scope: the selected project (after an access check) or, without one, the
    caller's own usage. ``breakdown`` accepts ``"user"`` or ``"project"``; any
    other value (or none) requests an un-broken-down report.
    """
    await check_project_access(project_id, current_user, dao)

    if project_id:
        user_id_filter = None
    else:
        user_id_filter = str(current_user["_id"])

    # Public query values -> the field name the report is grouped on.
    grouping_fields = {"user": "created_by", "project": "project_id"}
    breakdown_by = grouping_fields.get(breakdown)

    return await generation_service.get_financial_report(
        user_id=user_id_filter,
        project_id=project_id,
        breakdown_by=breakdown_by,
    )
|
|
|
|
|
|
@router.post("/_run", response_model=GenerationGroupResponse)
async def post_generation(
    generation: GenerationRequest,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
    project_id: Optional[str] = Depends(get_project_id),
    dao: DAO = Depends(get_dao)
) -> GenerationGroupResponse:
    """Start a generation task on behalf of the current user.

    If the ``get_project_id`` dependency resolves a project, the caller's access
    to it is checked and it overrides any ``project_id`` carried in the body.
    """
    await check_project_access(project_id, current_user, dao)

    if project_id:
        generation.project_id = project_id

    owner_id = str(current_user.get("_id"))
    return await generation_service.create_generation_task(generation, user_id=owner_id)
|
|
|
|
|
|
@router.get("/running")
async def get_running_generations(
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
    project_id: Optional[str] = Depends(get_project_id),
    dao: DAO = Depends(get_dao)
):
    """List in-progress generations for the selected project or the caller.

    With a project (access-checked first), no per-user filter is passed;
    without one, the query is restricted to the caller's own generations.
    """
    await check_project_access(project_id, current_user, dao)

    if project_id:
        owner_filter = None
    else:
        owner_filter = str(current_user["_id"])

    return await generation_service.get_running_generations(
        user_id=owner_filter,
        project_id=project_id,
    )
|
|
|
|
|
|
@router.get("/group/{group_id}", response_model=GenerationGroupResponse)
async def get_generation_group(
    group_id: str,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user)
):
    """Fetch all generations that belong to ``group_id``.

    NOTE(review): unlike the single-generation endpoint, no creator/project
    membership check happens here; confirm that
    ``get_generations_by_group`` enforces access using ``current_user_id``.
    """
    viewer_id = str(current_user["_id"])
    group = await generation_service.get_generations_by_group(group_id, current_user_id=viewer_id)
    return group
|
|
|
|
|
|
@router.get("/{generation_id}", response_model=GenerationResponse)
async def get_generation(
    generation_id: str,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user)
) -> GenerationResponse:
    """Return one generation if the caller created it or shares its project.

    Raises:
        HTTPException: 404 if the generation does not exist; 403 if the caller
            is neither its creator nor a member of its project.
    """
    viewer_id = str(current_user["_id"])
    gen = await generation_service.get_generation(generation_id, current_user_id=viewer_id)
    if not gen:
        raise HTTPException(status_code=404, detail="Generation not found")

    if gen.created_by == viewer_id:
        return gen

    # Not the creator: allow access through membership of the generation's project.
    if gen.project_id:
        project = await generation_service.dao.projects.get_project(gen.project_id)
        if project and viewer_id in project.members:
            return gen

    raise HTTPException(status_code=403, detail="Access denied")
|
|
|
|
|
|
@router.post("/{generation_id}/like", response_model=dict)
async def toggle_like(
    generation_id: str,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user)
):
    """Flip the caller's like on a generation and report the new state.

    The service returns ``None`` when the generation is unknown, which maps to
    404. NOTE(review): no creator/project access check is visible here; confirm
    whether liking should be limited to generations the caller can view.
    """
    liked_state = await generation_service.toggle_like(generation_id, str(current_user["_id"]))
    if liked_state is not None:
        return {"is_liked": liked_state}
    raise HTTPException(status_code=404, detail="Generation not found")
|
|
|
|
|
|
@router.post("/import", response_model=GenerationResponse)
async def import_external_generation(
    request: Request,
    generation_service: GenerationService = Depends(get_generation_service),
    x_signature: str = Header(..., alias="X-Signature")
) -> GenerationResponse:
    """Import a generation pushed by an external, signature-authenticated system.

    The raw body is read once and ``X-Signature`` is verified with
    ``verify_signature`` against ``settings.EXTERNAL_API_SECRET`` over those exact
    bytes before anything is parsed.

    Raises:
        HTTPException:
            500 if no shared secret is configured.
            401 if the signature does not match.
            400 if the body is not UTF-8 JSON describing a valid
            ``ExternalGenerationRequest`` object.
            500 if the import itself fails. Details are logged, not returned.
            Any ``HTTPException`` raised by the service is passed through unchanged.
    """
    body = await request.body()

    secret = settings.EXTERNAL_API_SECRET
    if not secret:
        raise HTTPException(status_code=500, detail="Server configuration error")

    if not verify_signature(body, x_signature, secret):
        raise HTTPException(status_code=401, detail="Invalid signature")

    # Malformed input is the sender's error, so answer 400 instead of letting it
    # fall into the 500 handler below. UnicodeDecodeError and json.JSONDecodeError
    # are ValueError subclasses; assuming ExternalGenerationRequest is a pydantic
    # model, its ValidationError is one too.
    try:
        data = json.loads(body.decode('utf-8'))
        if not isinstance(data, dict):
            # `**data` would otherwise raise TypeError for arrays/scalars.
            raise ValueError("payload must be a JSON object")
        external_gen = ExternalGenerationRequest(**data)
    except ValueError as e:
        logger.warning(f"Rejected external generation payload: {e}")
        raise HTTPException(status_code=400, detail=f"Invalid payload: {e}") from e

    try:
        generation = await generation_service.import_external_generation(external_gen)
        return GenerationResponse(**generation.model_dump())
    except HTTPException:
        # Deliberate HTTP errors from the service keep their own status code.
        raise
    except Exception as e:
        # logger.exception records the traceback; the client only gets a generic
        # message so internal error text is not leaked.
        logger.exception(f"Failed to import external generation: {e}")
        raise HTTPException(status_code=500, detail="Import failed") from e
|
|
|
|
|
|
@router.delete("/{generation_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_generation(
    generation_id: str,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user)
):
    """Delete a generation the caller is allowed to access.

    The caller must be the generation's creator or a member of the project the
    generation belongs to. The service's ``delete_generation`` receives only the
    id, so this check must happen here.

    NOTE(review): project members may delete each other's generations (mirroring
    the read rule); tighten to creator-only if that is not intended.

    Raises:
        HTTPException: 404 if the generation does not exist (or vanished before
            deletion); 403 if the caller is neither creator nor project member.
    """
    user_id = str(current_user["_id"])

    gen = await generation_service.get_generation(generation_id, current_user_id=user_id)
    if not gen:
        raise HTTPException(status_code=404, detail="Generation not found")

    if gen.created_by != user_id:
        is_member = False
        if gen.project_id:
            project = await generation_service.dao.projects.get_project(gen.project_id)
            is_member = bool(project) and user_id in project.members
        if not is_member:
            raise HTTPException(status_code=403, detail="Access denied")

    # The generation can still disappear between the lookup and the delete.
    if not await generation_service.delete_generation(generation_id):
        raise HTTPException(status_code=404, detail="Generation not found")
    return None
|