feat: Introduce generation grouping, enabling multiple generations per request via a new count parameter and retrieval by group ID.

This commit is contained in:
xds
2026-02-13 11:18:11 +03:00
parent 977cab92f8
commit 30138bab38
10 changed files with 58 additions and 17 deletions

View File

@@ -8,7 +8,7 @@ from api import service
from api.dependency import get_generation_service, get_project_id, get_dao from api.dependency import get_generation_service, get_project_id, get_dao
from repos.dao import DAO from repos.dao import DAO
from api.models.GenerationRequest import GenerationResponse, GenerationRequest, GenerationsResponse, PromptResponse, PromptRequest from api.models.GenerationRequest import GenerationResponse, GenerationRequest, GenerationsResponse, PromptResponse, PromptRequest, GenerationGroupResponse
from api.service.generation_service import GenerationService from api.service.generation_service import GenerationService
from models.Generation import Generation from models.Generation import Generation
@@ -68,12 +68,12 @@ async def get_generations(character_id: Optional[str] = None, limit: int = 10, o
return await generation_service.get_generations(character_id, limit=limit, offset=offset, user_id=user_id_filter, project_id=project_id) return await generation_service.get_generations(character_id, limit=limit, offset=offset, user_id=user_id_filter, project_id=project_id)
@router.post("/_run", response_model=GenerationResponse) @router.post("/_run", response_model=GenerationGroupResponse)
async def post_generation(generation: GenerationRequest, request: Request, async def post_generation(generation: GenerationRequest, request: Request,
generation_service: GenerationService = Depends(get_generation_service), generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user), current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id), project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao)) -> GenerationResponse: dao: DAO = Depends(get_dao)) -> GenerationGroupResponse:
logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}") logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}")
if project_id: if project_id:
@@ -85,16 +85,6 @@ async def post_generation(generation: GenerationRequest, request: Request,
return await generation_service.create_generation_task(generation, user_id=str(current_user.get("_id"))) return await generation_service.create_generation_task(generation, user_id=str(current_user.get("_id")))
@router.get("/{generation_id}", response_model=GenerationResponse)
async def get_generation(generation_id: str,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)) -> GenerationResponse:
logger.debug(f"get_generation called for ID: {generation_id}")
gen = await generation_service.get_generation(generation_id)
if gen and gen.created_by != str(current_user["_id"]):
raise HTTPException(status_code=403, detail="Access denied")
return gen
@router.get("/running") @router.get("/running")
async def get_running_generations(request: Request, async def get_running_generations(request: Request,
@@ -113,6 +103,27 @@ async def get_running_generations(request: Request,
return await generation_service.get_running_generations(user_id=user_id_filter, project_id=project_id) return await generation_service.get_running_generations(user_id=user_id_filter, project_id=project_id)
@router.get("/group/{group_id}", response_model=GenerationGroupResponse)
async def get_generation_group(group_id: str,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)):
logger.info(f"get_generation_group called for group_id: {group_id}")
generations = await generation_service.dao.generations.get_generations_by_group(group_id)
gen_responses = [GenerationResponse(**gen.model_dump()) for gen in generations]
return GenerationGroupResponse(generation_group_id=group_id, generations=gen_responses)
@router.get("/{generation_id}", response_model=GenerationResponse)
async def get_generation(generation_id: str,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)) -> GenerationResponse:
logger.debug(f"get_generation called for ID: {generation_id}")
gen = await generation_service.get_generation(generation_id)
if gen and gen.created_by != str(current_user["_id"]):
raise HTTPException(status_code=403, detail="Access denied")
return gen
@router.post("/import", response_model=GenerationResponse) @router.post("/import", response_model=GenerationResponse)

View File

@@ -1,7 +1,7 @@
from datetime import datetime, UTC from datetime import datetime, UTC
from typing import List, Optional from typing import List, Optional
from pydantic import BaseModel from pydantic import BaseModel, Field
from models.Asset import Asset from models.Asset import Asset
from models.Generation import GenerationStatus from models.Generation import GenerationStatus
@@ -17,6 +17,7 @@ class GenerationRequest(BaseModel):
use_profile_image: bool = True use_profile_image: bool = True
assets_list: List[str] assets_list: List[str]
project_id: Optional[str] = None project_id: Optional[str] = None
count: int = Field(default=1, ge=1, le=10)
class GenerationsResponse(BaseModel): class GenerationsResponse(BaseModel):
@@ -45,10 +46,15 @@ class GenerationResponse(BaseModel):
progress: int = 0 progress: int = 0
cost: Optional[float] = None cost: Optional[float] = None
created_by: Optional[str] = None created_by: Optional[str] = None
generation_group_id: Optional[str] = None
created_at: datetime = datetime.now(UTC) created_at: datetime = datetime.now(UTC)
updated_at: datetime = datetime.now(UTC) updated_at: datetime = datetime.now(UTC)
class GenerationGroupResponse(BaseModel):
generation_group_id: str
generations: List[GenerationResponse]
class PromptRequest(BaseModel): class PromptRequest(BaseModel):
prompt: str prompt: str

View File

@@ -5,13 +5,14 @@ import base64
from datetime import datetime, UTC from datetime import datetime, UTC
from typing import List, Optional, Tuple, Any, Dict from typing import List, Optional, Tuple, Any, Dict
from io import BytesIO from io import BytesIO
from uuid import uuid4
import httpx import httpx
from aiogram import Bot from aiogram import Bot
from aiogram.types import BufferedInputFile from aiogram.types import BufferedInputFile
from adapters.Exception import GoogleGenerationException from adapters.Exception import GoogleGenerationException
from adapters.google_adapter import GoogleAdapter from adapters.google_adapter import GoogleAdapter
from api.models.GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse from api.models.GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse
# Импортируйте ваши модели DAO, Asset, Generation корректно # Импортируйте ваши модели DAO, Asset, Generation корректно
from models.Asset import Asset, AssetType, AssetContentType from models.Asset import Asset, AssetType, AssetContentType
from models.Generation import Generation, GenerationStatus from models.Generation import Generation, GenerationStatus
@@ -113,14 +114,28 @@ class GenerationService:
async def get_running_generations(self, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]: async def get_running_generations(self, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING, created_by=user_id, project_id=project_id) return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING, created_by=user_id, project_id=project_id)
async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None) -> GenerationResponse: async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationGroupResponse:
count = generation_request.count
if generation_group_id is None:
generation_group_id = str(uuid4())
results = []
for _ in range(count):
gen_response = await self._create_single_generation(generation_request, user_id, generation_group_id)
results.append(gen_response)
return GenerationGroupResponse(generation_group_id=generation_group_id, generations=results)
async def _create_single_generation(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationResponse:
gen_id = None gen_id = None
generation_model = None generation_model = None
try: try:
generation_model = Generation(**generation_request.model_dump()) generation_model = Generation(**generation_request.model_dump(exclude={'count'}))
if user_id: if user_id:
generation_model.created_by = user_id generation_model.created_by = user_id
if generation_group_id:
generation_model.generation_group_id = generation_group_id
gen_id = await self.dao.generations.create_generation(generation_model) gen_id = await self.dao.generations.create_generation(generation_model)
generation_model.id = gen_id generation_model.id = gen_id

View File

@@ -35,6 +35,7 @@ class Generation(BaseModel):
output_token_usage: Optional[int] = None output_token_usage: Optional[int] = None
is_deleted: bool = False is_deleted: bool = False
album_id: Optional[str] = None album_id: Optional[str] = None
generation_group_id: Optional[str] = None
created_by: Optional[str] = None # Stores User ID (Telegram ID or Web User ObjectId) created_by: Optional[str] = None # Stores User ID (Telegram ID or Web User ObjectId)
project_id: Optional[str] = None project_id: Optional[str] = None
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))

View File

@@ -77,3 +77,11 @@ class GenerationRepo:
async def update_generation(self, generation: Generation, ): async def update_generation(self, generation: Generation, ):
res = await self.collection.update_one({"_id": ObjectId(generation.id)}, {"$set": generation.model_dump()}) res = await self.collection.update_one({"_id": ObjectId(generation.id)}, {"$set": generation.model_dump()})
async def get_generations_by_group(self, group_id: str) -> List[Generation]:
res = await self.collection.find({"generation_group_id": group_id, "is_deleted": False}).sort("created_at", 1).to_list(None)
generations: List[Generation] = []
for generation in res:
generation["id"] = str(generation.pop("_id"))
generations.append(Generation(**generation))
return generations