Files
ai-char-bot/api/endpoints/generation_router.py

84 lines
4.1 KiB
Python

import logging
from typing import List, Optional

from fastapi import APIRouter, File, Form, HTTPException, UploadFile
from fastapi.params import Depends
from starlette import status
from starlette.requests import Request

from api import service
from api.dependency import get_generation_service
from api.models.GenerationRequest import GenerationResponse, GenerationRequest, GenerationsResponse, PromptResponse, PromptRequest
from api.service.generation_service import GenerationService
from models.Generation import Generation
# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)
from api.endpoints.auth import get_current_user

# Router for the generation endpoints. `get_current_user` is attached as a
# router-level dependency, so it applies to every route registered on it.
router = APIRouter(
    prefix='/api/generations',
    tags=["Generation"],
    dependencies=[Depends(get_current_user)],
)
@router.post("/prompt-assistant", response_model=PromptResponse)
async def ask_prompt_assistant(
    prompt_request: PromptRequest,
    request: Request,
    generation_service: GenerationService = Depends(get_generation_service),
) -> PromptResponse:
    """Pass a prompt and its linked assets to the prompt assistant.

    Args:
        prompt_request: Request body; its ``prompt`` and ``linked_assets``
            fields are handed to the service.
        request: Incoming request object (not used in the handler body).
        generation_service: Service injected via ``get_generation_service``.

    Returns:
        A ``PromptResponse`` wrapping whatever
        ``generation_service.ask_prompt_assistant`` returns.
    """
    prompt_length = len(prompt_request.prompt)
    assets = prompt_request.linked_assets
    # A missing or empty asset list is logged as zero.
    asset_count = len(assets) if assets else 0
    logger.info(f"ask_prompt_assistant called with prompt length: {prompt_length}. Linked assets: {asset_count}")

    assistant_reply = await generation_service.ask_prompt_assistant(prompt_request.prompt, assets)
    return PromptResponse(prompt=assistant_reply)
@router.post("/prompt-from-image", response_model=PromptResponse)
async def prompt_from_image(
    prompt: Optional[str] = Form(None),
    images: List[UploadFile] = File(...),
    generation_service: GenerationService = Depends(get_generation_service),
) -> PromptResponse:
    """Build a prompt from one or more uploaded images.

    Args:
        prompt: Optional form field forwarded to the service alongside the
            image contents.
        images: Uploaded files (required); each one is read fully.
        generation_service: Service injected via ``get_generation_service``.

    Returns:
        A ``PromptResponse`` wrapping whatever
        ``generation_service.generate_prompt_from_images`` returns.
    """
    logger.info(f"prompt_from_image called. Images count: {len(images)}. Prompt provided: {bool(prompt)}")

    # Uploads are read one after another, in the order they were received.
    payloads = [await upload.read() for upload in images]

    derived_prompt = await generation_service.generate_prompt_from_images(payloads, prompt)
    return PromptResponse(prompt=derived_prompt)
@router.get("", response_model=GenerationsResponse)
async def get_generations(
    character_id: Optional[str] = None,
    limit: int = 10,
    offset: int = 0,
    generation_service: GenerationService = Depends(get_generation_service),
):
    """List generations, optionally narrowed by character.

    Args:
        character_id: Optional character ID passed through to the service.
        limit: Passed to the service as ``limit`` (defaults to 10).
        offset: Passed to the service as ``offset`` (defaults to 0).
        generation_service: Service injected via ``get_generation_service``.

    Returns:
        Whatever ``generation_service.get_generations`` returns; it is
        serialised using ``GenerationsResponse`` as the response model.
    """
    logger.info(f"get_generations called. CharacterId: {character_id}, Limit: {limit}, Offset: {offset}")
    page = await generation_service.get_generations(character_id, limit=limit, offset=offset)
    return page
@router.post("/_run", response_model=GenerationResponse)
async def post_generation(
    generation: GenerationRequest,
    request: Request,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
) -> GenerationResponse:
    """Create a generation task on behalf of the current user.

    Args:
        generation: Request body describing the generation to run.
        request: Incoming request object (not used in the handler body).
        generation_service: Service injected via ``get_generation_service``.
        current_user: Value produced by ``get_current_user``; its
            ``"username"`` entry is passed to the service as ``user_id``.

    Returns:
        Whatever ``generation_service.create_generation_task`` returns.
    """
    logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}")
    # `.get` yields None when the user dict carries no "username" key.
    username = current_user.get("username")
    created = await generation_service.create_generation_task(generation, user_id=username)
    return created
# NOTE: "/running" must be registered before "/{generation_id}". Routes are
# matched in declaration order, so the parameterised path would otherwise
# capture "running" as a generation ID and this handler would never be reached.
@router.get("/running")
async def get_running_generations(
    request: Request,
    generation_service: GenerationService = Depends(get_generation_service),
):
    """Return the generations the service reports as currently running.

    Args:
        request: Incoming request object (not used in the handler body).
        generation_service: Service injected via ``get_generation_service``.

    Returns:
        Whatever ``generation_service.get_running_generations`` returns.
    """
    return await generation_service.get_running_generations()


@router.get("/{generation_id}", response_model=GenerationResponse)
async def get_generation(
    generation_id: str,
    generation_service: GenerationService = Depends(get_generation_service),
) -> GenerationResponse:
    """Fetch a single generation by its ID.

    Args:
        generation_id: Path parameter identifying the generation.
        generation_service: Service injected via ``get_generation_service``.

    Returns:
        Whatever ``generation_service.get_generation`` returns for the ID.
    """
    logger.debug(f"get_generation called for ID: {generation_id}")
    return await generation_service.get_generation(generation_id)
@router.delete("/{generation_id}", status_code=status.HTTP_204_NO_CONTENT, dependencies=[Depends(get_current_user)])
async def delete_generation(
    generation_id: str,
    generation_service: GenerationService = Depends(get_generation_service),
):
    """Delete a generation by its ID.

    Responds with 204 No Content when the service reports a deletion.

    Args:
        generation_id: Path parameter identifying the generation to delete.
        generation_service: Service injected via ``get_generation_service``.

    Raises:
        HTTPException: 404 when ``generation_service.delete_generation``
            returns a falsy value.
    """
    logger.info(f"delete_generation called for ID: {generation_id}")
    deleted = await generation_service.delete_generation(generation_id)
    if not deleted:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Generation not found")
    return None