This commit is contained in:
xds
2026-02-20 13:10:37 +03:00
parent 9e0c522b5f
commit 1868864f76
6 changed files with 358 additions and 478 deletions

View File

@@ -1,5 +1,4 @@
import logging
import os
import json
from typing import List, Optional
@@ -30,12 +29,22 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix='/api/generations', tags=["Generation"])
async def check_project_access(project_id: Optional[str], current_user: dict, dao: DAO):
"""Helper to check if user has access to project."""
if not project_id:
return
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
@router.post("/prompt-assistant", response_model=PromptResponse)
async def ask_prompt_assistant(prompt_request: PromptRequest, request: Request,
generation_service: GenerationService = Depends(
get_generation_service),
current_user: dict = Depends(get_current_user)) -> PromptResponse:
logger.info(f"ask_prompt_assistant called with prompt length: {len(prompt_request.prompt)}. Linked assets: {len(prompt_request.linked_assets) if prompt_request.linked_assets else 0}")
async def ask_prompt_assistant(
prompt_request: PromptRequest,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)
) -> PromptResponse:
logger.info(f"ask_prompt_assistant: {len(prompt_request.prompt)} chars")
generated_prompt = await generation_service.ask_prompt_assistant(prompt_request.prompt, prompt_request.linked_assets)
return PromptResponse(prompt=generated_prompt)
@@ -47,32 +56,33 @@ async def prompt_from_image(
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)
) -> PromptResponse:
logger.info(f"prompt_from_image called. Images count: {len(images)}. Prompt provided: {bool(prompt)}")
images_bytes = []
for image in images:
content = await image.read()
images_bytes.append(content)
images_bytes = [await img.read() for img in images]
generated_prompt = await generation_service.generate_prompt_from_images(images_bytes, prompt)
return PromptResponse(prompt=generated_prompt)
@router.get("", response_model=GenerationsResponse)
async def get_generations(character_id: Optional[str] = None, limit: int = 10, offset: int = 0,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao)):
logger.info(f"get_generations called. CharacterId: {character_id}, Limit: {limit}, Offset: {offset}")
async def get_generations(
character_id: Optional[str] = None,
limit: int = 10,
offset: int = 0,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao)
):
await check_project_access(project_id, current_user, dao)
user_id_filter = str(current_user["_id"])
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
user_id_filter = None # Show all project generations
# If project_id is set, we don't filter by user to show all project-wide generations
created_by_filter = None if project_id else str(current_user["_id"])
return await generation_service.get_generations(character_id, limit=limit, offset=offset, user_id=user_id_filter, project_id=project_id)
return await generation_service.get_generations(
character_id=character_id,
limit=limit,
offset=offset,
created_by=created_by_filter,
project_id=project_id
)
@router.get("/usage", response_model=FinancialReport)
@@ -83,32 +93,15 @@ async def get_usage_report(
project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao)
) -> FinancialReport:
"""
Returns usage statistics (runs, tokens, cost) for the current user or project.
If project_id is provided, returns stats for that project.
Otherwise, returns stats for the current user.
"""
user_id_filter = str(current_user["_id"])
await check_project_access(project_id, current_user, dao)
user_id_filter = str(current_user["_id"]) if not project_id else None
breakdown_by = None
if project_id:
# Permission check
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
user_id_filter = None # If we are in project, we see stats for the WHOLE project by default
if breakdown == "user":
breakdown_by = "created_by"
elif breakdown == "project":
breakdown_by = "project_id"
else:
# Default: Stats for current user
if breakdown == "project":
breakdown_by = "project_id"
elif breakdown == "user":
# This would breakdown personal usage by user (yourself), but could be useful if it included collaborators?
# No, if project_id is None, it's personal.
breakdown_by = "created_by"
if breakdown == "user":
breakdown_by = "created_by"
elif breakdown == "project":
breakdown_by = "project_id"
return await generation_service.get_financial_report(
user_id=user_id_filter,
@@ -116,58 +109,61 @@ async def get_usage_report(
breakdown_by=breakdown_by
)
@router.post("/_run", response_model=GenerationGroupResponse)
async def post_generation(generation: GenerationRequest, request: Request,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao)) -> GenerationGroupResponse:
logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}")
@router.post("/_run", response_model=GenerationGroupResponse)
async def post_generation(
generation: GenerationRequest,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao)
) -> GenerationGroupResponse:
await check_project_access(project_id, current_user, dao)
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
generation.project_id = project_id
return await generation_service.create_generation_task(generation, user_id=str(current_user.get("_id")))
return await generation_service.create_generation_task(
generation,
user_id=str(current_user.get("_id"))
)
@router.get("/running")
async def get_running_generations(request: Request,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao)):
user_id_filter = str(current_user["_id"])
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
user_id_filter = None
async def get_running_generations(
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao)
):
await check_project_access(project_id, current_user, dao)
user_id_filter = None if project_id else str(current_user["_id"])
return await generation_service.get_running_generations(user_id=user_id_filter, project_id=project_id)
return await generation_service.get_running_generations(
user_id=user_id_filter,
project_id=project_id
)
@router.get("/group/{group_id}", response_model=GenerationGroupResponse)
async def get_generation_group(group_id: str,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)):
logger.info(f"get_generation_group called for group_id: {group_id}")
generations = await generation_service.dao.generations.get_generations_by_group(group_id)
gen_responses = [GenerationResponse(**gen.model_dump()) for gen in generations]
return GenerationGroupResponse(generation_group_id=group_id, generations=gen_responses)
async def get_generation_group(
group_id: str,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)
):
return await generation_service.get_generations_by_group(group_id)
@router.get("/{generation_id}", response_model=GenerationResponse)
async def get_generation(generation_id: str,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)) -> GenerationResponse:
logger.debug(f"get_generation called for ID: {generation_id}")
async def get_generation(
generation_id: str,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)
) -> GenerationResponse:
gen = await generation_service.get_generation(generation_id)
if gen and gen.created_by != str(current_user["_id"]):
if not gen:
raise HTTPException(status_code=404, detail="Generation not found")
if gen.created_by != str(current_user["_id"]):
# Check project membership
is_member = False
if gen.project_id:
@@ -180,43 +176,24 @@ async def get_generation(generation_id: str,
return gen
@router.post("/import", response_model=GenerationResponse)
async def import_external_generation(
request: Request,
generation_service: GenerationService = Depends(get_generation_service),
x_signature: str = Header(..., alias="X-Signature")
) -> GenerationResponse:
"""
Import a generation from an external source.
Requires server-to-server authentication via HMAC signature.
"""
logger.info("import_external_generation called")
# Get raw request body for signature verification
body = await request.body()
# Verify signature
secret = settings.EXTERNAL_API_SECRET
if not secret:
logger.error("EXTERNAL_API_SECRET not configured")
raise HTTPException(status_code=500, detail="Server configuration error")
if not verify_signature(body, x_signature, secret):
logger.warning("Invalid signature for external generation import")
raise HTTPException(status_code=401, detail="Invalid signature")
# Parse request body
try:
data = json.loads(body.decode('utf-8'))
external_gen = ExternalGenerationRequest(**data)
except Exception as e:
logger.error(f"Failed to parse request body: {e}")
raise HTTPException(status_code=400, detail=f"Invalid request body: {str(e)}")
# Import generation
try:
generation = await generation_service.import_external_generation(external_gen)
return GenerationResponse(**generation.model_dump())
except Exception as e:
@@ -225,11 +202,11 @@ async def import_external_generation(
@router.delete("/{generation_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_generation(generation_id: str,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)):
logger.info(f"delete_generation called for ID: {generation_id}")
deleted = await generation_service.delete_generation(generation_id)
if not deleted:
async def delete_generation(
generation_id: str,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)
):
if not await generation_service.delete_generation(generation_id):
raise HTTPException(status_code=404, detail="Generation not found")
return None