fixes
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
from typing import List, Optional
|
||||
|
||||
@@ -30,12 +29,22 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix='/api/generations', tags=["Generation"])
|
||||
|
||||
|
||||
async def check_project_access(project_id: Optional[str], current_user: dict, dao: DAO):
|
||||
"""Helper to check if user has access to project."""
|
||||
if not project_id:
|
||||
return
|
||||
project = await dao.projects.get_project(project_id)
|
||||
if not project or str(current_user["_id"]) not in project.members:
|
||||
raise HTTPException(status_code=403, detail="Project access denied")
|
||||
|
||||
|
||||
@router.post("/prompt-assistant", response_model=PromptResponse)
|
||||
async def ask_prompt_assistant(prompt_request: PromptRequest, request: Request,
|
||||
generation_service: GenerationService = Depends(
|
||||
get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)) -> PromptResponse:
|
||||
logger.info(f"ask_prompt_assistant called with prompt length: {len(prompt_request.prompt)}. Linked assets: {len(prompt_request.linked_assets) if prompt_request.linked_assets else 0}")
|
||||
async def ask_prompt_assistant(
|
||||
prompt_request: PromptRequest,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
) -> PromptResponse:
|
||||
logger.info(f"ask_prompt_assistant: {len(prompt_request.prompt)} chars")
|
||||
generated_prompt = await generation_service.ask_prompt_assistant(prompt_request.prompt, prompt_request.linked_assets)
|
||||
return PromptResponse(prompt=generated_prompt)
|
||||
|
||||
@@ -47,32 +56,33 @@ async def prompt_from_image(
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
) -> PromptResponse:
|
||||
logger.info(f"prompt_from_image called. Images count: {len(images)}. Prompt provided: {bool(prompt)}")
|
||||
images_bytes = []
|
||||
for image in images:
|
||||
content = await image.read()
|
||||
images_bytes.append(content)
|
||||
|
||||
images_bytes = [await img.read() for img in images]
|
||||
generated_prompt = await generation_service.generate_prompt_from_images(images_bytes, prompt)
|
||||
return PromptResponse(prompt=generated_prompt)
|
||||
|
||||
|
||||
@router.get("", response_model=GenerationsResponse)
|
||||
async def get_generations(character_id: Optional[str] = None, limit: int = 10, offset: int = 0,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
dao: DAO = Depends(get_dao)):
|
||||
logger.info(f"get_generations called. CharacterId: {character_id}, Limit: {limit}, Offset: {offset}")
|
||||
async def get_generations(
|
||||
character_id: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
offset: int = 0,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
dao: DAO = Depends(get_dao)
|
||||
):
|
||||
await check_project_access(project_id, current_user, dao)
|
||||
|
||||
user_id_filter = str(current_user["_id"])
|
||||
if project_id:
|
||||
project = await dao.projects.get_project(project_id)
|
||||
if not project or str(current_user["_id"]) not in project.members:
|
||||
raise HTTPException(status_code=403, detail="Project access denied")
|
||||
user_id_filter = None # Show all project generations
|
||||
# If project_id is set, we don't filter by user to show all project-wide generations
|
||||
created_by_filter = None if project_id else str(current_user["_id"])
|
||||
|
||||
return await generation_service.get_generations(character_id, limit=limit, offset=offset, user_id=user_id_filter, project_id=project_id)
|
||||
return await generation_service.get_generations(
|
||||
character_id=character_id,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
created_by=created_by_filter,
|
||||
project_id=project_id
|
||||
)
|
||||
|
||||
|
||||
@router.get("/usage", response_model=FinancialReport)
|
||||
@@ -83,32 +93,15 @@ async def get_usage_report(
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
dao: DAO = Depends(get_dao)
|
||||
) -> FinancialReport:
|
||||
"""
|
||||
Returns usage statistics (runs, tokens, cost) for the current user or project.
|
||||
If project_id is provided, returns stats for that project.
|
||||
Otherwise, returns stats for the current user.
|
||||
"""
|
||||
user_id_filter = str(current_user["_id"])
|
||||
await check_project_access(project_id, current_user, dao)
|
||||
|
||||
user_id_filter = str(current_user["_id"]) if not project_id else None
|
||||
breakdown_by = None
|
||||
|
||||
if project_id:
|
||||
# Permission check
|
||||
project = await dao.projects.get_project(project_id)
|
||||
if not project or str(current_user["_id"]) not in project.members:
|
||||
raise HTTPException(status_code=403, detail="Project access denied")
|
||||
user_id_filter = None # If we are in project, we see stats for the WHOLE project by default
|
||||
if breakdown == "user":
|
||||
breakdown_by = "created_by"
|
||||
elif breakdown == "project":
|
||||
breakdown_by = "project_id"
|
||||
else:
|
||||
# Default: Stats for current user
|
||||
if breakdown == "project":
|
||||
breakdown_by = "project_id"
|
||||
elif breakdown == "user":
|
||||
# This would breakdown personal usage by user (yourself), but could be useful if it included collaborators?
|
||||
# No, if project_id is None, it's personal.
|
||||
breakdown_by = "created_by"
|
||||
if breakdown == "user":
|
||||
breakdown_by = "created_by"
|
||||
elif breakdown == "project":
|
||||
breakdown_by = "project_id"
|
||||
|
||||
return await generation_service.get_financial_report(
|
||||
user_id=user_id_filter,
|
||||
@@ -116,58 +109,61 @@ async def get_usage_report(
|
||||
breakdown_by=breakdown_by
|
||||
)
|
||||
|
||||
@router.post("/_run", response_model=GenerationGroupResponse)
|
||||
async def post_generation(generation: GenerationRequest, request: Request,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
dao: DAO = Depends(get_dao)) -> GenerationGroupResponse:
|
||||
logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}")
|
||||
|
||||
@router.post("/_run", response_model=GenerationGroupResponse)
|
||||
async def post_generation(
|
||||
generation: GenerationRequest,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
dao: DAO = Depends(get_dao)
|
||||
) -> GenerationGroupResponse:
|
||||
await check_project_access(project_id, current_user, dao)
|
||||
if project_id:
|
||||
project = await dao.projects.get_project(project_id)
|
||||
if not project or str(current_user["_id"]) not in project.members:
|
||||
raise HTTPException(status_code=403, detail="Project access denied")
|
||||
generation.project_id = project_id
|
||||
|
||||
return await generation_service.create_generation_task(generation, user_id=str(current_user.get("_id")))
|
||||
|
||||
return await generation_service.create_generation_task(
|
||||
generation,
|
||||
user_id=str(current_user.get("_id"))
|
||||
)
|
||||
|
||||
|
||||
@router.get("/running")
|
||||
async def get_running_generations(request: Request,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
dao: DAO = Depends(get_dao)):
|
||||
|
||||
user_id_filter = str(current_user["_id"])
|
||||
if project_id:
|
||||
project = await dao.projects.get_project(project_id)
|
||||
if not project or str(current_user["_id"]) not in project.members:
|
||||
raise HTTPException(status_code=403, detail="Project access denied")
|
||||
user_id_filter = None
|
||||
async def get_running_generations(
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
dao: DAO = Depends(get_dao)
|
||||
):
|
||||
await check_project_access(project_id, current_user, dao)
|
||||
user_id_filter = None if project_id else str(current_user["_id"])
|
||||
|
||||
return await generation_service.get_running_generations(user_id=user_id_filter, project_id=project_id)
|
||||
return await generation_service.get_running_generations(
|
||||
user_id=user_id_filter,
|
||||
project_id=project_id
|
||||
)
|
||||
|
||||
|
||||
@router.get("/group/{group_id}", response_model=GenerationGroupResponse)
|
||||
async def get_generation_group(group_id: str,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)):
|
||||
logger.info(f"get_generation_group called for group_id: {group_id}")
|
||||
generations = await generation_service.dao.generations.get_generations_by_group(group_id)
|
||||
gen_responses = [GenerationResponse(**gen.model_dump()) for gen in generations]
|
||||
return GenerationGroupResponse(generation_group_id=group_id, generations=gen_responses)
|
||||
async def get_generation_group(
|
||||
group_id: str,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
return await generation_service.get_generations_by_group(group_id)
|
||||
|
||||
|
||||
@router.get("/{generation_id}", response_model=GenerationResponse)
|
||||
async def get_generation(generation_id: str,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)) -> GenerationResponse:
|
||||
logger.debug(f"get_generation called for ID: {generation_id}")
|
||||
async def get_generation(
|
||||
generation_id: str,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
) -> GenerationResponse:
|
||||
gen = await generation_service.get_generation(generation_id)
|
||||
if gen and gen.created_by != str(current_user["_id"]):
|
||||
if not gen:
|
||||
raise HTTPException(status_code=404, detail="Generation not found")
|
||||
|
||||
if gen.created_by != str(current_user["_id"]):
|
||||
# Check project membership
|
||||
is_member = False
|
||||
if gen.project_id:
|
||||
@@ -180,43 +176,24 @@ async def get_generation(generation_id: str,
|
||||
return gen
|
||||
|
||||
|
||||
|
||||
|
||||
@router.post("/import", response_model=GenerationResponse)
|
||||
async def import_external_generation(
|
||||
request: Request,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
x_signature: str = Header(..., alias="X-Signature")
|
||||
) -> GenerationResponse:
|
||||
"""
|
||||
Import a generation from an external source.
|
||||
Requires server-to-server authentication via HMAC signature.
|
||||
"""
|
||||
|
||||
logger.info("import_external_generation called")
|
||||
# Get raw request body for signature verification
|
||||
body = await request.body()
|
||||
|
||||
# Verify signature
|
||||
secret = settings.EXTERNAL_API_SECRET
|
||||
if not secret:
|
||||
logger.error("EXTERNAL_API_SECRET not configured")
|
||||
raise HTTPException(status_code=500, detail="Server configuration error")
|
||||
|
||||
if not verify_signature(body, x_signature, secret):
|
||||
logger.warning("Invalid signature for external generation import")
|
||||
raise HTTPException(status_code=401, detail="Invalid signature")
|
||||
|
||||
# Parse request body
|
||||
try:
|
||||
data = json.loads(body.decode('utf-8'))
|
||||
external_gen = ExternalGenerationRequest(**data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse request body: {e}")
|
||||
raise HTTPException(status_code=400, detail=f"Invalid request body: {str(e)}")
|
||||
|
||||
# Import generation
|
||||
try:
|
||||
generation = await generation_service.import_external_generation(external_gen)
|
||||
return GenerationResponse(**generation.model_dump())
|
||||
except Exception as e:
|
||||
@@ -225,11 +202,11 @@ async def import_external_generation(
|
||||
|
||||
|
||||
@router.delete("/{generation_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_generation(generation_id: str,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)):
|
||||
logger.info(f"delete_generation called for ID: {generation_id}")
|
||||
deleted = await generation_service.delete_generation(generation_id)
|
||||
if not deleted:
|
||||
async def delete_generation(
|
||||
generation_id: str,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
if not await generation_service.delete_generation(generation_id):
|
||||
raise HTTPException(status_code=404, detail="Generation not found")
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user