"""HTTP endpoints for prompting, running, listing, reading, importing and deleting generations."""
import json
import logging
import os
from typing import List, Optional

from fastapi import APIRouter, Depends, File, Form, Header, HTTPException, UploadFile
from starlette import status
from starlette.requests import Request

from api.dependency import get_dao, get_generation_service, get_project_id
from api.endpoints.auth import get_current_user
from api.models import (
    ExternalGenerationRequest,
    FinancialReport,
    GenerationGroupResponse,
    GenerationRequest,
    GenerationResponse,
    GenerationsResponse,
    PromptRequest,
    PromptResponse,
)
from api.service.generation_service import GenerationService
from repos.dao import DAO
from utils.external_auth import verify_signature

logger = logging.getLogger(__name__)

router = APIRouter(prefix='/api/generations', tags=["Generation"])

# Maps the public ``breakdown`` query value of ``/usage`` to the field the
# financial report is grouped by. Any other value (including None) means
# "no breakdown".
_BREAKDOWN_FIELDS = {
    "user": "created_by",
    "project": "project_id",
}


async def _ensure_project_member(dao: DAO, project_id: str, user_id: str) -> None:
    """Raise HTTP 403 unless ``user_id`` is listed in the members of ``project_id``.

    A project that does not exist also yields 403, so callers cannot use this
    check to probe which project ids exist.
    """
    project = await dao.projects.get_project(project_id)
    if not project or user_id not in project.members:
        raise HTTPException(status_code=403, detail="Project access denied")


async def _can_access_generation(dao: DAO, generation, user_id: str) -> bool:
    """Return True if ``user_id`` created ``generation`` or is a member of its project.

    ``generation`` is expected to expose ``created_by`` and ``project_id``
    attributes (as the objects returned by the generation service/DAO do).
    """
    if generation.created_by == user_id:
        return True
    if not generation.project_id:
        return False
    project = await dao.projects.get_project(generation.project_id)
    return bool(project) and user_id in project.members


@router.post("/prompt-assistant", response_model=PromptResponse)
async def ask_prompt_assistant(
    prompt_request: PromptRequest,
    request: Request,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
) -> PromptResponse:
    """Return the prompt produced by the assistant for the given text and linked assets."""
    linked_count = len(prompt_request.linked_assets) if prompt_request.linked_assets else 0
    logger.info(
        "ask_prompt_assistant called with prompt length: %d. Linked assets: %d",
        len(prompt_request.prompt),
        linked_count,
    )
    generated_prompt = await generation_service.ask_prompt_assistant(
        prompt_request.prompt, prompt_request.linked_assets
    )
    return PromptResponse(prompt=generated_prompt)


@router.post("/prompt-from-image", response_model=PromptResponse)
async def prompt_from_image(
    prompt: Optional[str] = Form(None),
    images: List[UploadFile] = File(...),
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
) -> PromptResponse:
    """Return a prompt generated from the uploaded images and an optional seed prompt.

    All uploaded files are read fully into memory before calling the service.
    """
    logger.info(
        "prompt_from_image called. Images count: %d. Prompt provided: %s",
        len(images),
        bool(prompt),
    )
    images_bytes = [await image.read() for image in images]
    generated_prompt = await generation_service.generate_prompt_from_images(images_bytes, prompt)
    return PromptResponse(prompt=generated_prompt)


@router.get("", response_model=GenerationsResponse)
async def get_generations(
    character_id: Optional[str] = None,
    limit: int = 10,
    offset: int = 0,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
    project_id: Optional[str] = Depends(get_project_id),
    dao: DAO = Depends(get_dao),
):
    """List a page of generations, optionally filtered by character.

    Without a project, only the caller's own generations are listed. With a
    project, the caller must be a member and all of the project's generations
    are listed.

    Raises:
        HTTPException(403): the caller is not a member of ``project_id``.
    """
    logger.info(
        "get_generations called. CharacterId: %s, Limit: %s, Offset: %s",
        character_id,
        limit,
        offset,
    )
    user_id = str(current_user["_id"])
    user_id_filter: Optional[str] = user_id
    if project_id:
        await _ensure_project_member(dao, project_id, user_id)
        user_id_filter = None  # Show all project generations
    return await generation_service.get_generations(
        character_id,
        limit=limit,
        offset=offset,
        user_id=user_id_filter,
        project_id=project_id,
    )


@router.get("/usage", response_model=FinancialReport)
async def get_usage_report(
    breakdown: Optional[str] = None,  # "user" or "project"
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
    project_id: Optional[str] = Depends(get_project_id),
    dao: DAO = Depends(get_dao),
) -> FinancialReport:
    """
    Returns usage statistics (runs, tokens, cost) for the current user or project.

    If project_id is provided, returns stats for the whole project (the caller
    must be a member). Otherwise, returns stats for the current user.
    ``breakdown`` may be "user" (group by creator) or "project" (group by
    project); any other value yields an ungrouped report.

    Raises:
        HTTPException(403): the caller is not a member of ``project_id``.
    """
    user_id = str(current_user["_id"])
    user_id_filter: Optional[str] = user_id
    if project_id:
        await _ensure_project_member(dao, project_id, user_id)
        # Inside a project the report covers the WHOLE project by default.
        user_id_filter = None

    return await generation_service.get_financial_report(
        user_id=user_id_filter,
        project_id=project_id,
        breakdown_by=_BREAKDOWN_FIELDS.get(breakdown),
    )


@router.post("/_run", response_model=GenerationGroupResponse)
async def post_generation(
    generation: GenerationRequest,
    request: Request,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
    project_id: Optional[str] = Depends(get_project_id),
    dao: DAO = Depends(get_dao),
) -> GenerationGroupResponse:
    """Start a generation task owned by the caller, optionally attached to a project.

    Raises:
        HTTPException(403): the caller is not a member of ``project_id``.
    """
    logger.info(
        "post_generation (run) called. LinkedCharId: %s, PromptLength: %d",
        generation.linked_character_id,
        len(generation.prompt),
    )
    # Index (not .get) so a user record without an id fails loudly instead of
    # creating a task owned by the string "None".
    user_id = str(current_user["_id"])
    if project_id:
        await _ensure_project_member(dao, project_id, user_id)
        generation.project_id = project_id
    return await generation_service.create_generation_task(generation, user_id=user_id)


@router.get("/running")
async def get_running_generations(
    request: Request,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
    project_id: Optional[str] = Depends(get_project_id),
    dao: DAO = Depends(get_dao),
):
    """Return running generations for the caller, or for the whole project if one is given.

    Raises:
        HTTPException(403): the caller is not a member of ``project_id``.
    """
    user_id = str(current_user["_id"])
    user_id_filter: Optional[str] = user_id
    if project_id:
        await _ensure_project_member(dao, project_id, user_id)
        user_id_filter = None
    return await generation_service.get_running_generations(
        user_id=user_id_filter, project_id=project_id
    )


@router.get("/group/{group_id}", response_model=GenerationGroupResponse)
async def get_generation_group(
    group_id: str,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
):
    """Return every generation in ``group_id``.

    Each generation must be readable by the caller (creator or member of its
    project); otherwise the whole request is rejected rather than leaking
    other users' generations.

    Raises:
        HTTPException(403): some generation in the group is not accessible.
    """
    logger.info("get_generation_group called for group_id: %s", group_id)
    user_id = str(current_user["_id"])
    generations = await generation_service.dao.generations.get_generations_by_group(group_id)
    for gen in generations:
        if not await _can_access_generation(generation_service.dao, gen, user_id):
            raise HTTPException(status_code=403, detail="Access denied")
    gen_responses = [GenerationResponse(**gen.model_dump()) for gen in generations]
    return GenerationGroupResponse(generation_group_id=group_id, generations=gen_responses)


@router.get("/{generation_id}", response_model=GenerationResponse)
async def get_generation(
    generation_id: str,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
) -> GenerationResponse:
    """Return one generation if the caller created it or belongs to its project.

    Raises:
        HTTPException(404): no generation with this id.
        HTTPException(403): the caller may not read it.
    """
    logger.debug("get_generation called for ID: %s", generation_id)
    gen = await generation_service.get_generation(generation_id)
    if gen is None:
        # Returning None would fail response-model validation (HTTP 500).
        raise HTTPException(status_code=404, detail="Generation not found")
    if not await _can_access_generation(generation_service.dao, gen, str(current_user["_id"])):
        raise HTTPException(status_code=403, detail="Access denied")
    return gen


@router.post("/import", response_model=GenerationResponse)
async def import_external_generation(
    request: Request,
    generation_service: GenerationService = Depends(get_generation_service),
    x_signature: str = Header(..., alias="X-Signature"),
) -> GenerationResponse:
    """
    Import a generation from an external source.

    Requires server-to-server authentication via HMAC signature: the
    ``X-Signature`` header is verified against the raw request body using the
    ``EXTERNAL_API_SECRET`` environment variable.

    Raises:
        HTTPException(500): the secret is not configured, or the import fails.
        HTTPException(401): the signature does not match.
        HTTPException(400): the body is not valid JSON / not a valid request.
    """
    logger.info("import_external_generation called")

    # The signature must be checked against the exact bytes received, so read
    # the raw body rather than letting FastAPI parse it first.
    body = await request.body()

    secret = os.getenv("EXTERNAL_API_SECRET")
    if not secret:
        logger.error("EXTERNAL_API_SECRET not configured")
        raise HTTPException(status_code=500, detail="Server configuration error")

    if not verify_signature(body, x_signature, secret):
        logger.warning("Invalid signature for external generation import")
        raise HTTPException(status_code=401, detail="Invalid signature")

    # UnicodeDecodeError, JSONDecodeError and pydantic's ValidationError are all
    # ValueError subclasses; TypeError covers a JSON body that is not an object.
    try:
        data = json.loads(body.decode('utf-8'))
        external_gen = ExternalGenerationRequest(**data)
    except (ValueError, TypeError) as e:
        logger.error("Failed to parse request body: %s", e)
        raise HTTPException(status_code=400, detail=f"Invalid request body: {str(e)}") from e

    try:
        generation = await generation_service.import_external_generation(external_gen)
    except HTTPException:
        # Let deliberate HTTP errors from the service keep their status code.
        raise
    except Exception as e:
        # Log the full traceback server-side, but do not echo internal error
        # details back to the caller.
        logger.exception("Failed to import external generation")
        raise HTTPException(status_code=500, detail="Import failed") from e
    return GenerationResponse(**generation.model_dump())


@router.delete("/{generation_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_generation(
    generation_id: str,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
):
    """Delete a generation the caller created or that belongs to one of their projects.

    Raises:
        HTTPException(404): no generation with this id (or it was already gone).
        HTTPException(403): the caller may not access it.
    """
    logger.info("delete_generation called for ID: %s", generation_id)
    gen = await generation_service.get_generation(generation_id)
    if gen is None:
        raise HTTPException(status_code=404, detail="Generation not found")
    # NOTE(review): uses the same owner-or-project-member rule as reads;
    # confirm whether deletion should be restricted to the creator only.
    if not await _can_access_generation(generation_service.dao, gen, str(current_user["_id"])):
        raise HTTPException(status_code=403, detail="Access denied")
    deleted = await generation_service.delete_generation(generation_id)
    if not deleted:
        raise HTTPException(status_code=404, detail="Generation not found")
    return None