Files
ai-char-bot/api/endpoints/generation_router.py

72 lines
3.4 KiB
Python

from typing import List, Optional
from fastapi import APIRouter, UploadFile, File, Form
from fastapi.params import Depends
from starlette.requests import Request
from api import service
from api.dependency import get_generation_service
from api.models.GenerationRequest import GenerationResponse, GenerationRequest, GenerationsResponse, PromptResponse, PromptRequest
from api.service.generation_service import GenerationService
from models.Generation import Generation
import logging
logger = logging.getLogger(__name__)
router = APIRouter(prefix='/api/generations', tags=["Generation"])
@router.post("/prompt-assistant", response_model=PromptResponse)
async def ask_prompt_assistant(prompt_request: PromptRequest, request: Request,
generation_service: GenerationService = Depends(
get_generation_service)) -> PromptResponse:
logger.info(f"ask_prompt_assistant called with prompt length: {len(prompt_request.prompt)}. Linked assets: {len(prompt_request.linked_assets) if prompt_request.linked_assets else 0}")
generated_prompt = await generation_service.ask_prompt_assistant(prompt_request.prompt, prompt_request.linked_assets)
return PromptResponse(prompt=generated_prompt)
@router.post("/prompt-from-image", response_model=PromptResponse)
async def prompt_from_image(
prompt: Optional[str] = Form(None),
images: List[UploadFile] = File(...),
generation_service: GenerationService = Depends(get_generation_service)
) -> PromptResponse:
logger.info(f"prompt_from_image called. Images count: {len(images)}. Prompt provided: {bool(prompt)}")
images_bytes = []
for image in images:
content = await image.read()
images_bytes.append(content)
generated_prompt = await generation_service.generate_prompt_from_images(images_bytes, prompt)
return PromptResponse(prompt=generated_prompt)
@router.get("", response_model=GenerationsResponse)
async def get_generations(character_id: Optional[str] = None, limit: int = 10, offset: int = 0,
generation_service: GenerationService = Depends(get_generation_service)):
logger.info(f"get_generations called. CharacterId: {character_id}, Limit: {limit}, Offset: {offset}")
return await generation_service.get_generations(character_id, limit=limit, offset=offset)
@router.post("/_run", response_model=GenerationResponse)
async def post_generation(generation: GenerationRequest, request: Request,
generation_service: GenerationService = Depends(
get_generation_service)) -> GenerationResponse:
logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}")
return await generation_service.create_generation_task(generation)
@router.get("/{generation_id}", response_model=GenerationResponse)
async def get_generation(generation_id: str,
generation_service: GenerationService = Depends(get_generation_service)) -> GenerationResponse:
logger.debug(f"get_generation called for ID: {generation_id}")
return await generation_service.get_generation(generation_id)
@router.get("/running")
async def get_running_generations(request: Request,
generation_service: GenerationService = Depends(get_generation_service)):
return await generation_service.get_running_generations()