feat: Introduce generation grouping, enabling multiple generations per request via a new count parameter and retrieval by group ID.

This commit is contained in:
xds
2026-02-13 11:18:11 +03:00
parent 977cab92f8
commit 30138bab38
10 changed files with 58 additions and 17 deletions

View File

@@ -8,7 +8,7 @@ from api import service
from api.dependency import get_generation_service, get_project_id, get_dao
from repos.dao import DAO
from api.models.GenerationRequest import GenerationResponse, GenerationRequest, GenerationsResponse, PromptResponse, PromptRequest
from api.models.GenerationRequest import GenerationResponse, GenerationRequest, GenerationsResponse, PromptResponse, PromptRequest, GenerationGroupResponse
from api.service.generation_service import GenerationService
from models.Generation import Generation
@@ -68,12 +68,12 @@ async def get_generations(character_id: Optional[str] = None, limit: int = 10, o
return await generation_service.get_generations(character_id, limit=limit, offset=offset, user_id=user_id_filter, project_id=project_id)
@router.post("/_run", response_model=GenerationResponse)
@router.post("/_run", response_model=GenerationGroupResponse)
async def post_generation(generation: GenerationRequest, request: Request,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao)) -> GenerationResponse:
dao: DAO = Depends(get_dao)) -> GenerationGroupResponse:
logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}")
if project_id:
@@ -85,16 +85,6 @@ async def post_generation(generation: GenerationRequest, request: Request,
return await generation_service.create_generation_task(generation, user_id=str(current_user.get("_id")))
@router.get("/{generation_id}", response_model=GenerationResponse)
async def get_generation(generation_id: str,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)) -> GenerationResponse:
logger.debug(f"get_generation called for ID: {generation_id}")
gen = await generation_service.get_generation(generation_id)
if gen and gen.created_by != str(current_user["_id"]):
raise HTTPException(status_code=403, detail="Access denied")
return gen
@router.get("/running")
async def get_running_generations(request: Request,
@@ -113,6 +103,27 @@ async def get_running_generations(request: Request,
return await generation_service.get_running_generations(user_id=user_id_filter, project_id=project_id)
@router.get("/group/{group_id}", response_model=GenerationGroupResponse)
async def get_generation_group(group_id: str,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)):
logger.info(f"get_generation_group called for group_id: {group_id}")
generations = await generation_service.dao.generations.get_generations_by_group(group_id)
gen_responses = [GenerationResponse(**gen.model_dump()) for gen in generations]
return GenerationGroupResponse(generation_group_id=group_id, generations=gen_responses)
@router.get("/{generation_id}", response_model=GenerationResponse)
async def get_generation(generation_id: str,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)) -> GenerationResponse:
logger.debug(f"get_generation called for ID: {generation_id}")
gen = await generation_service.get_generation(generation_id)
if gen and gen.created_by != str(current_user["_id"]):
raise HTTPException(status_code=403, detail="Access denied")
return gen
@router.post("/import", response_model=GenerationResponse)

View File

@@ -1,7 +1,7 @@
from datetime import datetime, UTC
from typing import List, Optional
from pydantic import BaseModel
from pydantic import BaseModel, Field
from models.Asset import Asset
from models.Generation import GenerationStatus
@@ -17,6 +17,7 @@ class GenerationRequest(BaseModel):
use_profile_image: bool = True
assets_list: List[str]
project_id: Optional[str] = None
count: int = Field(default=1, ge=1, le=10)
class GenerationsResponse(BaseModel):
@@ -45,10 +46,15 @@ class GenerationResponse(BaseModel):
progress: int = 0
cost: Optional[float] = None
created_by: Optional[str] = None
generation_group_id: Optional[str] = None
created_at: datetime = datetime.now(UTC)
updated_at: datetime = datetime.now(UTC)
class GenerationGroupResponse(BaseModel):
generation_group_id: str
generations: List[GenerationResponse]
class PromptRequest(BaseModel):
prompt: str

View File

@@ -5,13 +5,14 @@ import base64
from datetime import datetime, UTC
from typing import List, Optional, Tuple, Any, Dict
from io import BytesIO
from uuid import uuid4
import httpx
from aiogram import Bot
from aiogram.types import BufferedInputFile
from adapters.Exception import GoogleGenerationException
from adapters.google_adapter import GoogleAdapter
from api.models.GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse
from api.models.GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse
# Импортируйте ваши модели DAO, Asset, Generation корректно
from models.Asset import Asset, AssetType, AssetContentType
from models.Generation import Generation, GenerationStatus
@@ -113,14 +114,28 @@ class GenerationService:
async def get_running_generations(self, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING, created_by=user_id, project_id=project_id)
async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None) -> GenerationResponse:
async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationGroupResponse:
count = generation_request.count
if generation_group_id is None:
generation_group_id = str(uuid4())
results = []
for _ in range(count):
gen_response = await self._create_single_generation(generation_request, user_id, generation_group_id)
results.append(gen_response)
return GenerationGroupResponse(generation_group_id=generation_group_id, generations=results)
async def _create_single_generation(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationResponse:
gen_id = None
generation_model = None
try:
generation_model = Generation(**generation_request.model_dump())
generation_model = Generation(**generation_request.model_dump(exclude={'count'}))
if user_id:
generation_model.created_by = user_id
if generation_group_id:
generation_model.generation_group_id = generation_group_id
gen_id = await self.dao.generations.create_generation(generation_model)
generation_model.id = gen_id

View File

@@ -35,6 +35,7 @@ class Generation(BaseModel):
output_token_usage: Optional[int] = None
is_deleted: bool = False
album_id: Optional[str] = None
generation_group_id: Optional[str] = None
created_by: Optional[str] = None # Stores User ID (Telegram ID or Web User ObjectId)
project_id: Optional[str] = None
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))

View File

@@ -77,3 +77,11 @@ class GenerationRepo:
async def update_generation(self, generation: Generation, ):
res = await self.collection.update_one({"_id": ObjectId(generation.id)}, {"$set": generation.model_dump()})
async def get_generations_by_group(self, group_id: str) -> List[Generation]:
res = await self.collection.find({"generation_group_id": group_id, "is_deleted": False}).sort("created_at", 1).to_list(None)
generations: List[Generation] = []
for generation in res:
generation["id"] = str(generation.pop("_id"))
generations.append(Generation(**generation))
return generations