Files
ai-char-bot/api/endpoints/generation_router.py

120 lines
6.0 KiB
Python

from typing import List, Optional
from fastapi import APIRouter, UploadFile, File, Form
from fastapi.params import Depends
from starlette.requests import Request
from api import service
from api.dependency import get_generation_service, get_project_id, get_dao
from repos.dao import DAO
from api.models.GenerationRequest import GenerationResponse, GenerationRequest, GenerationsResponse, PromptResponse, PromptRequest
from api.service.generation_service import GenerationService
from models.Generation import Generation
from starlette import status
import logging
logger = logging.getLogger(__name__)
from api.endpoints.auth import get_current_user
router = APIRouter(prefix='/api/generations', tags=["Generation"], dependencies=[Depends(get_current_user)])
@router.post("/prompt-assistant", response_model=PromptResponse)
async def ask_prompt_assistant(prompt_request: PromptRequest, request: Request,
generation_service: GenerationService = Depends(
get_generation_service)) -> PromptResponse:
logger.info(f"ask_prompt_assistant called with prompt length: {len(prompt_request.prompt)}. Linked assets: {len(prompt_request.linked_assets) if prompt_request.linked_assets else 0}")
generated_prompt = await generation_service.ask_prompt_assistant(prompt_request.prompt, prompt_request.linked_assets)
return PromptResponse(prompt=generated_prompt)
@router.post("/prompt-from-image", response_model=PromptResponse)
async def prompt_from_image(
prompt: Optional[str] = Form(None),
images: List[UploadFile] = File(...),
generation_service: GenerationService = Depends(get_generation_service)
) -> PromptResponse:
logger.info(f"prompt_from_image called. Images count: {len(images)}. Prompt provided: {bool(prompt)}")
images_bytes = []
for image in images:
content = await image.read()
images_bytes.append(content)
generated_prompt = await generation_service.generate_prompt_from_images(images_bytes, prompt)
return PromptResponse(prompt=generated_prompt)
@router.get("", response_model=GenerationsResponse)
async def get_generations(character_id: Optional[str] = None, limit: int = 10, offset: int = 0,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao)):
logger.info(f"get_generations called. CharacterId: {character_id}, Limit: {limit}, Offset: {offset}")
user_id_filter = str(current_user["_id"])
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
user_id_filter = None # Show all project generations
return await generation_service.get_generations(character_id, limit=limit, offset=offset, user_id=user_id_filter, project_id=project_id)
@router.post("/_run", response_model=GenerationResponse)
async def post_generation(generation: GenerationRequest, request: Request,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao)) -> GenerationResponse:
logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}")
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
generation.project_id = project_id
return await generation_service.create_generation_task(generation, user_id=str(current_user.get("_id")))
@router.get("/{generation_id}", response_model=GenerationResponse)
async def get_generation(generation_id: str,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)) -> GenerationResponse:
logger.debug(f"get_generation called for ID: {generation_id}")
gen = await generation_service.get_generation(generation_id)
if gen and gen.created_by != str(current_user["_id"]):
raise HTTPException(status_code=403, detail="Access denied")
return gen
@router.get("/running")
async def get_running_generations(request: Request,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao)):
user_id_filter = str(current_user["_id"])
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
user_id_filter = None
return await generation_service.get_running_generations(user_id=user_id_filter, project_id=project_id)
@router.delete("/{generation_id}", status_code=status.HTTP_204_NO_CONTENT, dependencies=[Depends(get_current_user)])
async def delete_generation(generation_id: str, generation_service: GenerationService = Depends(get_generation_service)):
logger.info(f"delete_generation called for ID: {generation_id}")
deleted = await generation_service.delete_generation(generation_id)
if not deleted:
raise HTTPException(status_code=404, detail="Generation not found")
return None