226 lines
10 KiB
Python
226 lines
10 KiB
Python
from typing import List, Optional
|
|
|
|
from fastapi import APIRouter, UploadFile, File, Form, Header, HTTPException
|
|
from fastapi.params import Depends
|
|
from starlette.requests import Request
|
|
|
|
from api import service
|
|
from api.dependency import get_generation_service, get_project_id, get_dao
|
|
from repos.dao import DAO
|
|
|
|
from api.models.GenerationRequest import GenerationResponse, GenerationRequest, GenerationsResponse, PromptResponse, PromptRequest, GenerationGroupResponse
|
|
from api.models.FinancialUsageDTO import FinancialReport
|
|
from api.service.generation_service import GenerationService
|
|
from models.Generation import Generation
|
|
|
|
from starlette import status
|
|
|
|
import logging
|
|
|
|
# Module-level logger; records are emitted under this module's dotted name.
logger = logging.getLogger(__name__)
|
|
|
|
from api.endpoints.auth import get_current_user
|
|
|
|
# Every endpoint in this module is mounted under /api/generations and tagged
# "Generation". The fixed GET paths ("/usage", "/running", "/group/{group_id}")
# are registered before GET "/{generation_id}" so they are matched first.
router = APIRouter(prefix='/api/generations', tags=["Generation"])
|
|
|
|
|
|
@router.post("/prompt-assistant", response_model=PromptResponse)
async def ask_prompt_assistant(prompt_request: PromptRequest, request: Request,
                               generation_service: GenerationService = Depends(
                                   get_generation_service),
                               current_user: dict = Depends(get_current_user)) -> PromptResponse:
    """Ask the prompt assistant to produce a prompt from the user's draft.

    The draft text and any linked assets from ``prompt_request`` are passed to
    ``GenerationService.ask_prompt_assistant``. The returned text is wrapped in a
    ``PromptResponse``. ``current_user`` is resolved only so that the endpoint
    requires authentication; its value is not used here.
    """
    assets = prompt_request.linked_assets
    asset_count = len(assets) if assets else 0
    logger.info(f"ask_prompt_assistant called with prompt length: {len(prompt_request.prompt)}. Linked assets: {asset_count}")

    suggestion = await generation_service.ask_prompt_assistant(prompt_request.prompt, assets)
    return PromptResponse(prompt=suggestion)
|
|
|
|
|
|
@router.post("/prompt-from-image", response_model=PromptResponse)
async def prompt_from_image(
        prompt: Optional[str] = Form(None),
        images: List[UploadFile] = File(...),
        generation_service: GenerationService = Depends(get_generation_service),
        current_user: dict = Depends(get_current_user)
) -> PromptResponse:
    """Generate a prompt from one or more uploaded images.

    Each upload is read fully into memory, one after another in request order.
    The resulting list of raw bytes is passed, together with the optional text
    ``prompt``, to ``GenerationService.generate_prompt_from_images``.
    ``current_user`` is resolved only to require authentication.
    """
    logger.info(f"prompt_from_image called. Images count: {len(images)}. Prompt provided: {bool(prompt)}")

    # Sequential awaits inside the comprehension keep the original read order.
    payloads = [await upload.read() for upload in images]

    suggestion = await generation_service.generate_prompt_from_images(payloads, prompt)
    return PromptResponse(prompt=suggestion)
|
|
|
|
|
|
@router.get("", response_model=GenerationsResponse)
async def get_generations(character_id: Optional[str] = None, limit: int = 10, offset: int = 0,
                          generation_service: GenerationService = Depends(get_generation_service),
                          current_user: dict = Depends(get_current_user),
                          project_id: Optional[str] = Depends(get_project_id),
                          dao: DAO = Depends(get_dao)):
    """List generations with limit/offset paging, optionally filtered by character.

    Without a project context, only the caller's own generations are listed.
    With a ``project_id``, the caller must be a member of that project. In that
    case the per-user filter is dropped and every generation of the project is
    listed.

    Raises:
        HTTPException(403): the project does not exist or the caller is not a member.
    """
    logger.info(f"get_generations called. CharacterId: {character_id}, Limit: {limit}, Offset: {offset}")

    caller_id = str(current_user["_id"])
    owner_filter: Optional[str] = caller_id

    if project_id:
        project = await dao.projects.get_project(project_id)
        is_member = bool(project) and caller_id in project.members
        if not is_member:
            raise HTTPException(status_code=403, detail="Project access denied")
        owner_filter = None  # project view: include every member's generations

    return await generation_service.get_generations(
        character_id,
        limit=limit,
        offset=offset,
        user_id=owner_filter,
        project_id=project_id,
    )
|
|
|
|
|
|
@router.get("/usage", response_model=FinancialReport)
async def get_usage_report(
        breakdown: Optional[str] = None,  # "user" or "project"
        generation_service: GenerationService = Depends(get_generation_service),
        current_user: dict = Depends(get_current_user),
        project_id: Optional[str] = Depends(get_project_id),
        dao: DAO = Depends(get_dao)
) -> FinancialReport:
    """
    Return usage statistics (runs, tokens, cost) for the caller or for a project.

    With a ``project_id``, the caller must be a project member, and the report
    covers the whole project (no per-user filter). Without one, the report covers
    only the caller's own usage.

    ``breakdown`` picks an optional grouping:
    - "user" groups by ``created_by``.
    - "project" groups by ``project_id``.
    - Any other value, or none, means no grouping.
    The mapping is the same in project scope and personal scope.

    Raises:
        HTTPException(403): the project does not exist or the caller is not a member.
    """
    caller_id = str(current_user["_id"])
    scope_user: Optional[str] = caller_id

    if project_id:
        # Permission check before any project-wide data is read.
        project = await dao.projects.get_project(project_id)
        if not project or caller_id not in project.members:
            raise HTTPException(status_code=403, detail="Project access denied")
        scope_user = None  # project scope: aggregate across all members

    grouping_by_breakdown = {"user": "created_by", "project": "project_id"}
    grouping = grouping_by_breakdown.get(breakdown)

    return await generation_service.get_financial_report(
        user_id=scope_user,
        project_id=project_id,
        breakdown_by=grouping,
    )
|
|
|
|
@router.post("/_run", response_model=GenerationGroupResponse)
async def post_generation(generation: GenerationRequest, request: Request,
                          generation_service: GenerationService = Depends(get_generation_service),
                          current_user: dict = Depends(get_current_user),
                          project_id: Optional[str] = Depends(get_project_id),
                          dao: DAO = Depends(get_dao)) -> GenerationGroupResponse:
    """Start a generation task on behalf of the authenticated user.

    The target project is the one resolved by ``get_project_id``. If that is
    empty, the ``project_id`` carried in the request body (if any) is used.
    Whichever project ends up being used, the caller must be a member of it.

    Returns:
        The ``GenerationGroupResponse`` produced by ``create_generation_task``.

    Raises:
        HTTPException(403): the target project does not exist or the caller is not a member.
    """
    logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}")

    user_id = str(current_user["_id"])

    # GenerationRequest has a project_id field (it is assigned below), so a client
    # can presumably set it in the body. The old code validated membership only
    # when get_project_id returned a value, so a body-only project_id skipped the
    # check entirely. Validate the project that will actually be used instead.
    # The value from get_project_id still takes precedence over the body.
    target_project_id = project_id or getattr(generation, "project_id", None)
    if target_project_id:
        project = await dao.projects.get_project(target_project_id)
        if not project or user_id not in project.members:
            raise HTTPException(status_code=403, detail="Project access denied")
        generation.project_id = target_project_id

    return await generation_service.create_generation_task(generation, user_id=user_id)
|
|
|
|
|
|
|
|
@router.get("/running")
async def get_running_generations(request: Request,
                                  generation_service: GenerationService = Depends(get_generation_service),
                                  current_user: dict = Depends(get_current_user),
                                  project_id: Optional[str] = Depends(get_project_id),
                                  dao: DAO = Depends(get_dao)):
    """Return the generations reported as running by the generation service.

    Without a project context, the query is limited to the caller's own
    generations. With a ``project_id``, the caller must be a project member, and
    the query covers the whole project (no per-user filter).

    Raises:
        HTTPException(403): the project does not exist or the caller is not a member.
    """
    caller_id = str(current_user["_id"])

    # Personal scope: nothing to authorize beyond authentication.
    if not project_id:
        return await generation_service.get_running_generations(user_id=caller_id, project_id=project_id)

    project = await dao.projects.get_project(project_id)
    if not project or caller_id not in project.members:
        raise HTTPException(status_code=403, detail="Project access denied")

    return await generation_service.get_running_generations(user_id=None, project_id=project_id)
|
|
|
|
|
|
@router.get("/group/{group_id}", response_model=GenerationGroupResponse)
async def get_generation_group(group_id: str,
                               generation_service: GenerationService = Depends(get_generation_service),
                               current_user: dict = Depends(get_current_user),
                               dao: DAO = Depends(get_dao)):
    """Return every generation that belongs to ``group_id``.

    Previously no authorization was applied, so knowing a group id was enough to
    read another user's generations. Each generation in the group must now be
    visible to the caller, meaning either:
    - the caller created it (the same ``created_by`` rule ``get_generation``
      enforces), or
    - it carries a ``project_id`` of a project the caller is a member of.

    Raises:
        HTTPException(403): some generation in the group is not visible to the caller.
    """
    logger.info(f"get_generation_group called for group_id: {group_id}")
    generations = await generation_service.dao.generations.get_generations_by_group(group_id)

    user_id = str(current_user["_id"])
    # Cache project membership lookups so a group from one project costs at most one query.
    membership_by_project: dict = {}
    for gen in generations:
        if gen.created_by == user_id:
            continue
        # NOTE(review): assumes Generation exposes project_id like GenerationRequest does.
        gen_project_id = getattr(gen, "project_id", None)
        if not gen_project_id:
            raise HTTPException(status_code=403, detail="Access denied")
        if gen_project_id not in membership_by_project:
            project = await dao.projects.get_project(gen_project_id)
            membership_by_project[gen_project_id] = bool(project) and user_id in project.members
        if not membership_by_project[gen_project_id]:
            raise HTTPException(status_code=403, detail="Access denied")

    gen_responses = [GenerationResponse(**gen.model_dump()) for gen in generations]
    return GenerationGroupResponse(generation_group_id=group_id, generations=gen_responses)
|
|
|
|
|
|
@router.get("/{generation_id}", response_model=GenerationResponse)
async def get_generation(generation_id: str,
                         generation_service: GenerationService = Depends(get_generation_service),
                         current_user: dict = Depends(get_current_user)) -> GenerationResponse:
    """Return a single generation created by the caller.

    Raises:
        HTTPException(404): no generation exists with ``generation_id``.
        HTTPException(403): the generation was created by a different user.
    """
    logger.debug(f"get_generation called for ID: {generation_id}")
    gen = await generation_service.get_generation(generation_id)
    if not gen:
        # Returning None would fail response_model validation and surface as a
        # 500, so report the missing resource explicitly.
        raise HTTPException(status_code=404, detail="Generation not found")
    if gen.created_by != str(current_user["_id"]):
        raise HTTPException(status_code=403, detail="Access denied")
    return gen
|
|
|
|
|
|
|
|
|
|
@router.post("/import", response_model=GenerationResponse)
async def import_external_generation(
        request: Request,
        generation_service: GenerationService = Depends(get_generation_service),
        x_signature: str = Header(..., alias="X-Signature")
) -> GenerationResponse:
    """
    Import a generation from an external source.

    Requires server-to-server authentication via HMAC signature. The signature in
    the ``X-Signature`` header is checked with ``verify_signature`` against the
    raw request body and the ``EXTERNAL_API_SECRET`` environment variable. Only
    after that check is the body parsed into an ``ExternalGenerationRequest``.

    Raises:
        HTTPException(500): ``EXTERNAL_API_SECRET`` is not set, or the import itself fails.
        HTTPException(401): the signature does not verify.
        HTTPException(400): the body is not UTF-8 JSON, or does not match ``ExternalGenerationRequest``.
        HTTPException: any ``HTTPException`` raised by the service is re-raised unchanged.
    """
    import json
    import os

    from utils.external_auth import verify_signature
    from api.models.ExternalGenerationDTO import ExternalGenerationRequest

    logger.info("import_external_generation called")

    # The signature covers the exact raw bytes, so read them before any parsing.
    body = await request.body()

    secret = os.getenv("EXTERNAL_API_SECRET")
    if not secret:
        logger.error("EXTERNAL_API_SECRET not configured")
        raise HTTPException(status_code=500, detail="Server configuration error")

    if not verify_signature(body, x_signature, secret):
        logger.warning("Invalid signature for external generation import")
        raise HTTPException(status_code=401, detail="Invalid signature")

    # ValueError covers UnicodeDecodeError, json.JSONDecodeError and pydantic's
    # ValidationError (a ValueError subclass). TypeError covers a JSON payload
    # that is not an object, since ** then needs a mapping.
    try:
        data = json.loads(body.decode('utf-8'))
        external_gen = ExternalGenerationRequest(**data)
    except (ValueError, TypeError) as e:
        logger.error(f"Failed to parse request body: {e}")
        raise HTTPException(status_code=400, detail=f"Invalid request body: {str(e)}") from e

    try:
        generation = await generation_service.import_external_generation(external_gen)
    except HTTPException:
        # Deliberate HTTP errors from the service keep their own status code
        # instead of being flattened into a 500.
        raise
    except Exception as e:
        logger.exception(f"Failed to import external generation: {e}")
        raise HTTPException(status_code=500, detail=f"Import failed: {str(e)}") from e

    # Built outside the try: the import already succeeded at this point, so a
    # serialization problem must not be reported as "Import failed".
    return GenerationResponse(**generation.model_dump())
|
|
|
|
|
|
@router.delete("/{generation_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_generation(generation_id: str,
                            generation_service: GenerationService = Depends(get_generation_service),
                            current_user: dict = Depends(get_current_user)):
    """Delete a generation created by the caller.

    Previously any authenticated user could delete any generation by id. Now
    only the creator may delete it, using the same ``created_by`` rule that the
    read endpoint ``get_generation`` applies.

    Raises:
        HTTPException(404): no generation exists with ``generation_id``, or it was
            already gone when the delete ran.
        HTTPException(403): the generation was created by a different user.
    """
    logger.info(f"delete_generation called for ID: {generation_id}")

    gen = await generation_service.get_generation(generation_id)
    if not gen:
        raise HTTPException(status_code=404, detail="Generation not found")
    if gen.created_by != str(current_user["_id"]):
        raise HTTPException(status_code=403, detail="Access denied")

    deleted = await generation_service.delete_generation(generation_id)
    if not deleted:
        # The generation may have been removed between the lookup and the delete.
        raise HTTPException(status_code=404, detail="Generation not found")
    return None