import json
import logging
from typing import List, Optional

from fastapi import APIRouter, UploadFile, File, Form, Header, HTTPException
from fastapi.params import Depends
from starlette import status
from starlette.requests import Request

from config import settings
from api.dependency import get_generation_service, get_project_id, get_dao
from api.endpoints.auth import get_current_user
from api.models import (
    GenerationResponse,
    GenerationRequest,
    GenerationsResponse,
    PromptResponse,
    PromptRequest,
    GenerationGroupResponse,
    FinancialReport,
    ExternalGenerationRequest,
    NsfwRequest,
)
from api.service.generation_service import GenerationService
from repos.dao import DAO
from utils.external_auth import verify_signature

logger = logging.getLogger(__name__)

router = APIRouter(prefix='/api/generations', tags=["Generation"])

# Maps the public ``breakdown`` query value of /usage to the field the
# financial report is grouped by. Unknown values mean "no breakdown".
_USAGE_BREAKDOWN_FIELDS = {
    "user": "created_by",
    "project": "project_id",
}


async def check_project_access(project_id: Optional[str], current_user: dict, dao: DAO) -> None:
    """Ensure ``current_user`` is a member of the project ``project_id``.

    A falsy ``project_id`` means the request is not project-scoped and is
    always allowed.

    Raises:
        HTTPException: 403 if the project does not exist or the user's id is
            not in its ``members``.
    """
    if not project_id:
        return
    project = await dao.projects.get_project(project_id)
    if not project or str(current_user["_id"]) not in project.members:
        raise HTTPException(status_code=403, detail="Project access denied")


async def _get_accessible_generation(
    generation_id: str,
    generation_service: GenerationService,
    current_user: dict,
) -> GenerationResponse:
    """Load a generation and verify the current user may act on it.

    Access is granted to the generation's creator, or to any member of the
    project the generation belongs to.

    Returns:
        The generation as returned by ``generation_service.get_generation``.

    Raises:
        HTTPException: 404 if the generation does not exist; 403 if the user
            is neither its creator nor a member of its project.
    """
    user_id = str(current_user["_id"])
    gen = await generation_service.get_generation(generation_id, current_user_id=user_id)
    if not gen:
        raise HTTPException(status_code=404, detail="Generation not found")
    if gen.created_by == user_id:
        return gen
    if gen.project_id:
        project = await generation_service.dao.projects.get_project(gen.project_id)
        if project and user_id in project.members:
            return gen
    raise HTTPException(status_code=403, detail="Access denied")


@router.post("/prompt-assistant", response_model=PromptResponse)
async def ask_prompt_assistant(
    prompt_request: PromptRequest,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
) -> PromptResponse:
    """Ask the prompt assistant to rewrite/expand a prompt (auth required)."""
    # Log only the size of the prompt, not its content.
    logger.info("ask_prompt_assistant: %d chars", len(prompt_request.prompt))
    generated_prompt = await generation_service.ask_prompt_assistant(
        prompt_request.prompt, prompt_request.linked_assets
    )
    return PromptResponse(prompt=generated_prompt)


@router.post("/prompt-from-image", response_model=PromptResponse)
async def prompt_from_image(
    prompt: Optional[str] = Form(None),
    images: List[UploadFile] = File(...),
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
) -> PromptResponse:
    """Derive a prompt from uploaded images, optionally guided by ``prompt``.

    NOTE(review): every upload is read fully into memory; there is no size
    limit here — confirm one is enforced upstream.
    """
    images_bytes = [await img.read() for img in images]
    generated_prompt = await generation_service.generate_prompt_from_images(images_bytes, prompt)
    return PromptResponse(prompt=generated_prompt)


@router.get("", response_model=GenerationsResponse)
async def get_generations(
    character_id: Optional[str] = None,
    limit: int = 10,
    offset: int = 0,
    only_liked: bool = False,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
    project_id: Optional[str] = Depends(get_project_id),
    dao: DAO = Depends(get_dao),
):
    """List generations visible to the caller.

    In a project scope the creator filter is dropped so project-wide
    generations are returned; otherwise only the caller's own are requested.
    ``only_liked`` restricts results to generations the caller liked.
    """
    await check_project_access(project_id, current_user, dao)
    user_id = str(current_user["_id"])

    # If project_id is set, we don't filter by user to show all project-wide generations
    created_by_filter = None if project_id else user_id
    only_liked_by = user_id if only_liked else None

    return await generation_service.get_generations(
        character_id=character_id,
        limit=limit,
        offset=offset,
        created_by=created_by_filter,
        project_id=project_id,
        only_liked_by=only_liked_by,
        current_user_id=user_id,
    )


@router.get("/usage", response_model=FinancialReport)
async def get_usage_report(
    breakdown: Optional[str] = None,  # "user" or "project"
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
    project_id: Optional[str] = Depends(get_project_id),
    dao: DAO = Depends(get_dao),
) -> FinancialReport:
    """Return a financial usage report for the caller or the current project.

    ``breakdown`` may be ``"user"`` or ``"project"``; any other value yields
    an un-broken-down report.
    """
    await check_project_access(project_id, current_user, dao)

    user_id_filter = None if project_id else str(current_user["_id"])
    breakdown_by = _USAGE_BREAKDOWN_FIELDS.get(breakdown)

    return await generation_service.get_financial_report(
        user_id=user_id_filter,
        project_id=project_id,
        breakdown_by=breakdown_by,
    )


@router.post("/_run", response_model=GenerationGroupResponse)
async def post_generation(
    generation: GenerationRequest,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
    project_id: Optional[str] = Depends(get_project_id),
    dao: DAO = Depends(get_dao),
) -> GenerationGroupResponse:
    """Start a generation task on behalf of the caller.

    When a project scope is active, it overrides ``generation.project_id``.
    """
    await check_project_access(project_id, current_user, dao)
    if project_id:
        generation.project_id = project_id

    # Index directly (like every other endpoint) so a missing id fails loudly
    # instead of being stringified to "None".
    return await generation_service.create_generation_task(
        generation, user_id=str(current_user["_id"])
    )


@router.get("/running")
async def get_running_generations(
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
    project_id: Optional[str] = Depends(get_project_id),
    dao: DAO = Depends(get_dao),
):
    """List running generations for the current project, or the caller's own."""
    await check_project_access(project_id, current_user, dao)

    user_id_filter = None if project_id else str(current_user["_id"])
    return await generation_service.get_running_generations(
        user_id=user_id_filter,
        project_id=project_id,
    )


@router.get("/group/{group_id}", response_model=GenerationGroupResponse)
async def get_generation_group(
    group_id: str,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
):
    """Return all generations belonging to ``group_id``.

    NOTE(review): unlike the single-generation endpoints, no creator/project
    access check is applied here — confirm the service enforces it.
    """
    return await generation_service.get_generations_by_group(
        group_id, current_user_id=str(current_user["_id"])
    )


@router.get("/{generation_id}", response_model=GenerationResponse)
async def get_generation(
    generation_id: str,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
) -> GenerationResponse:
    """Return one generation if the caller created it or is a project member.

    Raises:
        HTTPException: 404 if not found, 403 if access is denied.
    """
    return await _get_accessible_generation(generation_id, generation_service, current_user)


@router.post("/{generation_id}/like", response_model=dict)
async def toggle_like(
    generation_id: str,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
):
    """Toggle the caller's like on a generation they can access.

    Returns:
        ``{"is_liked": bool}`` with the new like state.

    Raises:
        HTTPException: 404 if not found, 403 if access is denied.
    """
    # Same access rule as reading the generation; previously any user could
    # like (and probe the existence of) any generation id.
    await _get_accessible_generation(generation_id, generation_service, current_user)

    is_liked = await generation_service.toggle_like(generation_id, str(current_user["_id"]))
    if is_liked is None:
        raise HTTPException(status_code=404, detail="Generation not found")
    return {"is_liked": is_liked}


@router.post("/{generation_id}/nsfw", status_code=status.HTTP_204_NO_CONTENT)
async def mark_generation_nsfw(
    generation_id: str,
    request: NsfwRequest,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
):
    """Set or clear the NSFW flag on a generation the caller can access.

    Raises:
        HTTPException: 404 if not found, 403 if access is denied.
    """
    await _get_accessible_generation(generation_id, generation_service, current_user)
    await generation_service.dao.generations.mark_nsfw(generation_id, request.is_nsfw)
    return None


@router.post("/import", response_model=GenerationResponse)
async def import_external_generation(
    request: Request,
    generation_service: GenerationService = Depends(get_generation_service),
    x_signature: str = Header(..., alias="X-Signature"),
) -> GenerationResponse:
    """Import a generation pushed by an external system.

    The raw request body must carry a valid ``X-Signature`` computed with
    ``settings.EXTERNAL_API_SECRET``.

    Raises:
        HTTPException: 500 if the secret is not configured; 401 on a bad
            signature; 400 if the body is not valid UTF-8 JSON matching
            ``ExternalGenerationRequest``; 500 if the import itself fails.
    """
    # The signature is computed over the raw bytes, so read them before parsing.
    body = await request.body()
    secret = settings.EXTERNAL_API_SECRET
    if not secret:
        raise HTTPException(status_code=500, detail="Server configuration error")
    if not verify_signature(body, x_signature, secret):
        raise HTTPException(status_code=401, detail="Invalid signature")

    # Malformed payloads are client errors. UnicodeDecodeError, JSONDecodeError
    # and pydantic ValidationError all subclass ValueError; TypeError covers a
    # JSON body that is not an object (``**data`` on a list/str).
    try:
        data = json.loads(body.decode('utf-8'))
        external_gen = ExternalGenerationRequest(**data)
    except (ValueError, TypeError) as e:
        logger.warning("Rejected external generation payload: %s", e)
        raise HTTPException(status_code=400, detail=f"Invalid payload: {e}") from e

    try:
        generation = await generation_service.import_external_generation(external_gen)
        return GenerationResponse(**generation.model_dump())
    except HTTPException:
        # Let deliberate HTTP errors from the service keep their status code.
        raise
    except Exception as e:
        # Keep the traceback in server logs; don't echo internals to the caller.
        logger.exception("Failed to import external generation")
        raise HTTPException(status_code=500, detail="Import failed") from e


@router.delete("/{generation_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_generation(
    generation_id: str,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
):
    """Delete a generation the caller created or can access via its project.

    Raises:
        HTTPException: 404 if not found, 403 if access is denied.
    """
    # Previously any authenticated user could delete any generation by id.
    # NOTE(review): applies the same creator-or-project-member rule as reads;
    # tighten to creator-only if that is the intended policy.
    await _get_accessible_generation(generation_id, generation_service, current_user)

    if not await generation_service.delete_generation(generation_id):
        raise HTTPException(status_code=404, detail="Generation not found")
    return None