"""HTTP endpoints for creating, listing, importing and deleting generations."""
import json
import logging
import os
from typing import List, Optional

from fastapi import APIRouter, UploadFile, File, Form, Header, HTTPException
from fastapi.params import Depends
from starlette import status
from starlette.requests import Request

from api import service  # noqa: F401 - unused in this module; kept in case it is imported for side effects
from api.dependency import get_generation_service, get_project_id, get_dao
from repos.dao import DAO
from api.models.GenerationRequest import (
    GenerationResponse,
    GenerationRequest,
    GenerationsResponse,
    PromptResponse,
    PromptRequest,
    GenerationGroupResponse,
)
from api.service.generation_service import GenerationService
from models.Generation import Generation

# NOTE(review): originally imported last, after the logger; kept last in case
# ordering matters for an import cycle with the auth module.
from api.endpoints.auth import get_current_user

logger = logging.getLogger(__name__)

router = APIRouter(prefix='/api/generations', tags=["Generation"])


def _user_id(current_user: dict) -> str:
    """Return the authenticated user's ``_id`` as a string."""
    return str(current_user["_id"])


async def _ensure_project_member(dao: DAO, project_id: str, user_id: str) -> None:
    """Raise HTTP 403 unless the project exists and *user_id* is one of its members.

    Raises:
        HTTPException: 403 "Project access denied".
    """
    project = await dao.projects.get_project(project_id)
    if not project or user_id not in project.members:
        raise HTTPException(status_code=403, detail="Project access denied")


async def _can_access_generation(dao: DAO, generation: Generation, user_id: str) -> bool:
    """Return True if *user_id* may read or delete *generation*.

    Access is granted to the creator, or to a member of the project the
    generation belongs to. This mirrors ``get_generations``, which shows every
    project generation to project members.
    """
    if generation.created_by == user_id:
        return True
    # NOTE(review): assumes Generation carries the project_id that post_generation
    # sets on the request; getattr keeps this safe if the field is absent.
    project_id = getattr(generation, "project_id", None)
    if not project_id:
        return False
    project = await dao.projects.get_project(project_id)
    return bool(project) and user_id in project.members


@router.post("/prompt-assistant", response_model=PromptResponse)
async def ask_prompt_assistant(
    prompt_request: PromptRequest,
    request: Request,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
) -> PromptResponse:
    """Ask the prompt assistant for a prompt based on the given prompt and linked assets."""
    linked_assets = prompt_request.linked_assets
    logger.info(
        "ask_prompt_assistant called with prompt length: %s. Linked assets: %s",
        len(prompt_request.prompt),
        len(linked_assets) if linked_assets else 0,
    )
    generated_prompt = await generation_service.ask_prompt_assistant(prompt_request.prompt, linked_assets)
    return PromptResponse(prompt=generated_prompt)


@router.post("/prompt-from-image", response_model=PromptResponse)
async def prompt_from_image(
    prompt: Optional[str] = Form(None),
    images: List[UploadFile] = File(...),
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
) -> PromptResponse:
    """Generate a prompt from uploaded images, optionally guided by a text prompt."""
    logger.info("prompt_from_image called. Images count: %s. Prompt provided: %s", len(images), bool(prompt))
    # Uploaded files are read fully into memory before being handed to the service.
    images_bytes = [await image.read() for image in images]
    generated_prompt = await generation_service.generate_prompt_from_images(images_bytes, prompt)
    return PromptResponse(prompt=generated_prompt)


@router.get("", response_model=GenerationsResponse)
async def get_generations(
    character_id: Optional[str] = None,
    limit: int = 10,
    offset: int = 0,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
    project_id: Optional[str] = Depends(get_project_id),
    dao: DAO = Depends(get_dao),
):
    """List generations, paginated.

    Without a project, only the caller's own generations are listed. With a
    project, the caller must be a member and all project generations are listed.
    """
    logger.info("get_generations called. CharacterId: %s, Limit: %s, Offset: %s", character_id, limit, offset)
    user_id = _user_id(current_user)
    user_id_filter: Optional[str] = user_id
    if project_id:
        await _ensure_project_member(dao, project_id, user_id)
        user_id_filter = None  # Show all project generations
    return await generation_service.get_generations(
        character_id, limit=limit, offset=offset, user_id=user_id_filter, project_id=project_id
    )


@router.post("/_run", response_model=GenerationGroupResponse)
async def post_generation(
    generation: GenerationRequest,
    request: Request,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
    project_id: Optional[str] = Depends(get_project_id),
    dao: DAO = Depends(get_dao),
) -> GenerationGroupResponse:
    """Create a generation task for the caller, attached to the project when one is given."""
    logger.info(
        "post_generation (run) called. LinkedCharId: %s, PromptLength: %s",
        generation.linked_character_id,
        len(generation.prompt),
    )
    user_id = _user_id(current_user)
    if project_id:
        await _ensure_project_member(dao, project_id, user_id)
        generation.project_id = project_id
    return await generation_service.create_generation_task(generation, user_id=user_id)


@router.get("/running")
async def get_running_generations(
    request: Request,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
    project_id: Optional[str] = Depends(get_project_id),
    dao: DAO = Depends(get_dao),
):
    """List running generations, scoped to the caller or, for project members, the project."""
    user_id = _user_id(current_user)
    user_id_filter: Optional[str] = user_id
    if project_id:
        await _ensure_project_member(dao, project_id, user_id)
        user_id_filter = None
    return await generation_service.get_running_generations(user_id=user_id_filter, project_id=project_id)


@router.get("/group/{group_id}", response_model=GenerationGroupResponse)
async def get_generation_group(
    group_id: str,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
):
    """Return every generation in a group.

    Raises:
        HTTPException: 403 if the caller may not access any generation in the group.
    """
    logger.info("get_generation_group called for group_id: %s", group_id)
    user_id = _user_id(current_user)
    generations = await generation_service.dao.generations.get_generations_by_group(group_id)
    # Previously any authenticated user could read any group; enforce the same
    # per-generation access rule as get_generation.
    for gen in generations:
        if not await _can_access_generation(generation_service.dao, gen, user_id):
            raise HTTPException(status_code=403, detail="Access denied")
    gen_responses = [GenerationResponse(**gen.model_dump()) for gen in generations]
    return GenerationGroupResponse(generation_group_id=group_id, generations=gen_responses)


@router.get("/{generation_id}", response_model=GenerationResponse)
async def get_generation(
    generation_id: str,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
) -> GenerationResponse:
    """Return a single generation.

    Raises:
        HTTPException: 404 if it does not exist; 403 if the caller may not access it.
    """
    logger.debug("get_generation called for ID: %s", generation_id)
    gen = await generation_service.get_generation(generation_id)
    if not gen:
        # Returning None would fail response_model validation (HTTP 500).
        raise HTTPException(status_code=404, detail="Generation not found")
    if not await _can_access_generation(generation_service.dao, gen, _user_id(current_user)):
        raise HTTPException(status_code=403, detail="Access denied")
    return gen


@router.post("/import", response_model=GenerationResponse)
async def import_external_generation(
    request: Request,
    generation_service: GenerationService = Depends(get_generation_service),
    x_signature: str = Header(..., alias="X-Signature"),
) -> GenerationResponse:
    """
    Import a generation from an external source.

    Requires server-to-server authentication via HMAC signature.

    Raises:
        HTTPException: 500 if EXTERNAL_API_SECRET is unset or the import fails;
            401 on a bad signature; 400 on an unparsable body.
    """
    # Kept as lazy imports, as in the original; presumably to avoid an import
    # cycle. Confirm before hoisting.
    from utils.external_auth import verify_signature
    from api.models.ExternalGenerationDTO import ExternalGenerationRequest

    logger.info("import_external_generation called")

    # The signature is computed over the exact raw bytes, so read them before parsing.
    body = await request.body()

    secret = os.getenv("EXTERNAL_API_SECRET")
    if not secret:
        logger.error("EXTERNAL_API_SECRET not configured")
        raise HTTPException(status_code=500, detail="Server configuration error")

    if not verify_signature(body, x_signature, secret):
        logger.warning("Invalid signature for external generation import")
        raise HTTPException(status_code=401, detail="Invalid signature")

    try:
        data = json.loads(body.decode('utf-8'))
        external_gen = ExternalGenerationRequest(**data)
    except (ValueError, TypeError) as e:
        # ValueError covers UnicodeDecodeError, JSONDecodeError and pydantic
        # ValidationError. TypeError covers a non-object JSON payload passed to **data.
        logger.error("Failed to parse request body: %s", e)
        raise HTTPException(status_code=400, detail=f"Invalid request body: {str(e)}") from e

    try:
        generation = await generation_service.import_external_generation(external_gen)
        return GenerationResponse(**generation.model_dump())
    except HTTPException:
        # Preserve status codes deliberately raised by the service layer.
        raise
    except Exception as e:
        logger.exception("Failed to import external generation: %s", e)
        raise HTTPException(status_code=500, detail=f"Import failed: {str(e)}") from e


@router.delete("/{generation_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_generation(
    generation_id: str,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
):
    """Delete a generation the caller may access.

    Raises:
        HTTPException: 404 if it does not exist; 403 if the caller may not access it.
    """
    logger.info("delete_generation called for ID: %s", generation_id)
    gen = await generation_service.get_generation(generation_id)
    if not gen:
        raise HTTPException(status_code=404, detail="Generation not found")
    # Previously any authenticated user could delete any generation.
    # NOTE(review): uses the read-access rule (creator or project member);
    # tighten to creator-only if project members must not delete others' work.
    if not await _can_access_generation(generation_service.dao, gen, _user_id(current_user)):
        raise HTTPException(status_code=403, detail="Access denied")
    deleted = await generation_service.delete_generation(generation_id)
    if not deleted:
        raise HTTPException(status_code=404, detail="Generation not found")
    return None