feat: Introduce generation grouping, enabling multiple generations per request via a new count parameter and retrieval by group ID.
This commit is contained in:
Binary file not shown.
@@ -8,7 +8,7 @@ from api import service
|
|||||||
from api.dependency import get_generation_service, get_project_id, get_dao
|
from api.dependency import get_generation_service, get_project_id, get_dao
|
||||||
from repos.dao import DAO
|
from repos.dao import DAO
|
||||||
|
|
||||||
from api.models.GenerationRequest import GenerationResponse, GenerationRequest, GenerationsResponse, PromptResponse, PromptRequest
|
from api.models.GenerationRequest import GenerationResponse, GenerationRequest, GenerationsResponse, PromptResponse, PromptRequest, GenerationGroupResponse
|
||||||
from api.service.generation_service import GenerationService
|
from api.service.generation_service import GenerationService
|
||||||
from models.Generation import Generation
|
from models.Generation import Generation
|
||||||
|
|
||||||
@@ -68,12 +68,12 @@ async def get_generations(character_id: Optional[str] = None, limit: int = 10, o
|
|||||||
return await generation_service.get_generations(character_id, limit=limit, offset=offset, user_id=user_id_filter, project_id=project_id)
|
return await generation_service.get_generations(character_id, limit=limit, offset=offset, user_id=user_id_filter, project_id=project_id)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/_run", response_model=GenerationResponse)
|
@router.post("/_run", response_model=GenerationGroupResponse)
|
||||||
async def post_generation(generation: GenerationRequest, request: Request,
|
async def post_generation(generation: GenerationRequest, request: Request,
|
||||||
generation_service: GenerationService = Depends(get_generation_service),
|
generation_service: GenerationService = Depends(get_generation_service),
|
||||||
current_user: dict = Depends(get_current_user),
|
current_user: dict = Depends(get_current_user),
|
||||||
project_id: Optional[str] = Depends(get_project_id),
|
project_id: Optional[str] = Depends(get_project_id),
|
||||||
dao: DAO = Depends(get_dao)) -> GenerationResponse:
|
dao: DAO = Depends(get_dao)) -> GenerationGroupResponse:
|
||||||
logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}")
|
logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}")
|
||||||
|
|
||||||
if project_id:
|
if project_id:
|
||||||
@@ -85,16 +85,6 @@ async def post_generation(generation: GenerationRequest, request: Request,
|
|||||||
return await generation_service.create_generation_task(generation, user_id=str(current_user.get("_id")))
|
return await generation_service.create_generation_task(generation, user_id=str(current_user.get("_id")))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{generation_id}", response_model=GenerationResponse)
|
|
||||||
async def get_generation(generation_id: str,
|
|
||||||
generation_service: GenerationService = Depends(get_generation_service),
|
|
||||||
current_user: dict = Depends(get_current_user)) -> GenerationResponse:
|
|
||||||
logger.debug(f"get_generation called for ID: {generation_id}")
|
|
||||||
gen = await generation_service.get_generation(generation_id)
|
|
||||||
if gen and gen.created_by != str(current_user["_id"]):
|
|
||||||
raise HTTPException(status_code=403, detail="Access denied")
|
|
||||||
return gen
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/running")
|
@router.get("/running")
|
||||||
async def get_running_generations(request: Request,
|
async def get_running_generations(request: Request,
|
||||||
@@ -113,6 +103,27 @@ async def get_running_generations(request: Request,
|
|||||||
return await generation_service.get_running_generations(user_id=user_id_filter, project_id=project_id)
|
return await generation_service.get_running_generations(user_id=user_id_filter, project_id=project_id)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/group/{group_id}", response_model=GenerationGroupResponse)
|
||||||
|
async def get_generation_group(group_id: str,
|
||||||
|
generation_service: GenerationService = Depends(get_generation_service),
|
||||||
|
current_user: dict = Depends(get_current_user)):
|
||||||
|
logger.info(f"get_generation_group called for group_id: {group_id}")
|
||||||
|
generations = await generation_service.dao.generations.get_generations_by_group(group_id)
|
||||||
|
gen_responses = [GenerationResponse(**gen.model_dump()) for gen in generations]
|
||||||
|
return GenerationGroupResponse(generation_group_id=group_id, generations=gen_responses)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{generation_id}", response_model=GenerationResponse)
|
||||||
|
async def get_generation(generation_id: str,
|
||||||
|
generation_service: GenerationService = Depends(get_generation_service),
|
||||||
|
current_user: dict = Depends(get_current_user)) -> GenerationResponse:
|
||||||
|
logger.debug(f"get_generation called for ID: {generation_id}")
|
||||||
|
gen = await generation_service.get_generation(generation_id)
|
||||||
|
if gen and gen.created_by != str(current_user["_id"]):
|
||||||
|
raise HTTPException(status_code=403, detail="Access denied")
|
||||||
|
return gen
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/import", response_model=GenerationResponse)
|
@router.post("/import", response_model=GenerationResponse)
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from datetime import datetime, UTC
|
from datetime import datetime, UTC
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from models.Asset import Asset
|
from models.Asset import Asset
|
||||||
from models.Generation import GenerationStatus
|
from models.Generation import GenerationStatus
|
||||||
@@ -17,6 +17,7 @@ class GenerationRequest(BaseModel):
|
|||||||
use_profile_image: bool = True
|
use_profile_image: bool = True
|
||||||
assets_list: List[str]
|
assets_list: List[str]
|
||||||
project_id: Optional[str] = None
|
project_id: Optional[str] = None
|
||||||
|
count: int = Field(default=1, ge=1, le=10)
|
||||||
|
|
||||||
|
|
||||||
class GenerationsResponse(BaseModel):
|
class GenerationsResponse(BaseModel):
|
||||||
@@ -45,10 +46,15 @@ class GenerationResponse(BaseModel):
|
|||||||
progress: int = 0
|
progress: int = 0
|
||||||
cost: Optional[float] = None
|
cost: Optional[float] = None
|
||||||
created_by: Optional[str] = None
|
created_by: Optional[str] = None
|
||||||
|
generation_group_id: Optional[str] = None
|
||||||
created_at: datetime = datetime.now(UTC)
|
created_at: datetime = datetime.now(UTC)
|
||||||
updated_at: datetime = datetime.now(UTC)
|
updated_at: datetime = datetime.now(UTC)
|
||||||
|
|
||||||
|
|
||||||
|
class GenerationGroupResponse(BaseModel):
|
||||||
|
generation_group_id: str
|
||||||
|
generations: List[GenerationResponse]
|
||||||
|
|
||||||
|
|
||||||
class PromptRequest(BaseModel):
|
class PromptRequest(BaseModel):
|
||||||
prompt: str
|
prompt: str
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
@@ -5,13 +5,14 @@ import base64
|
|||||||
from datetime import datetime, UTC
|
from datetime import datetime, UTC
|
||||||
from typing import List, Optional, Tuple, Any, Dict
|
from typing import List, Optional, Tuple, Any, Dict
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
from uuid import uuid4
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from aiogram import Bot
|
from aiogram import Bot
|
||||||
from aiogram.types import BufferedInputFile
|
from aiogram.types import BufferedInputFile
|
||||||
from adapters.Exception import GoogleGenerationException
|
from adapters.Exception import GoogleGenerationException
|
||||||
from adapters.google_adapter import GoogleAdapter
|
from adapters.google_adapter import GoogleAdapter
|
||||||
from api.models.GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse
|
from api.models.GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse
|
||||||
# Импортируйте ваши модели DAO, Asset, Generation корректно
|
# Импортируйте ваши модели DAO, Asset, Generation корректно
|
||||||
from models.Asset import Asset, AssetType, AssetContentType
|
from models.Asset import Asset, AssetType, AssetContentType
|
||||||
from models.Generation import Generation, GenerationStatus
|
from models.Generation import Generation, GenerationStatus
|
||||||
@@ -113,14 +114,28 @@ class GenerationService:
|
|||||||
async def get_running_generations(self, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
|
async def get_running_generations(self, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
|
||||||
return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING, created_by=user_id, project_id=project_id)
|
return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING, created_by=user_id, project_id=project_id)
|
||||||
|
|
||||||
async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None) -> GenerationResponse:
|
async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationGroupResponse:
|
||||||
|
count = generation_request.count
|
||||||
|
|
||||||
|
if generation_group_id is None:
|
||||||
|
generation_group_id = str(uuid4())
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for _ in range(count):
|
||||||
|
gen_response = await self._create_single_generation(generation_request, user_id, generation_group_id)
|
||||||
|
results.append(gen_response)
|
||||||
|
return GenerationGroupResponse(generation_group_id=generation_group_id, generations=results)
|
||||||
|
|
||||||
|
async def _create_single_generation(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationResponse:
|
||||||
gen_id = None
|
gen_id = None
|
||||||
generation_model = None
|
generation_model = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
generation_model = Generation(**generation_request.model_dump())
|
generation_model = Generation(**generation_request.model_dump(exclude={'count'}))
|
||||||
if user_id:
|
if user_id:
|
||||||
generation_model.created_by = user_id
|
generation_model.created_by = user_id
|
||||||
|
if generation_group_id:
|
||||||
|
generation_model.generation_group_id = generation_group_id
|
||||||
|
|
||||||
gen_id = await self.dao.generations.create_generation(generation_model)
|
gen_id = await self.dao.generations.create_generation(generation_model)
|
||||||
generation_model.id = gen_id
|
generation_model.id = gen_id
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ class Generation(BaseModel):
|
|||||||
output_token_usage: Optional[int] = None
|
output_token_usage: Optional[int] = None
|
||||||
is_deleted: bool = False
|
is_deleted: bool = False
|
||||||
album_id: Optional[str] = None
|
album_id: Optional[str] = None
|
||||||
|
generation_group_id: Optional[str] = None
|
||||||
created_by: Optional[str] = None # Stores User ID (Telegram ID or Web User ObjectId)
|
created_by: Optional[str] = None # Stores User ID (Telegram ID or Web User ObjectId)
|
||||||
project_id: Optional[str] = None
|
project_id: Optional[str] = None
|
||||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
@@ -77,3 +77,11 @@ class GenerationRepo:
|
|||||||
|
|
||||||
async def update_generation(self, generation: Generation, ):
|
async def update_generation(self, generation: Generation, ):
|
||||||
res = await self.collection.update_one({"_id": ObjectId(generation.id)}, {"$set": generation.model_dump()})
|
res = await self.collection.update_one({"_id": ObjectId(generation.id)}, {"$set": generation.model_dump()})
|
||||||
|
|
||||||
|
async def get_generations_by_group(self, group_id: str) -> List[Generation]:
|
||||||
|
res = await self.collection.find({"generation_group_id": group_id, "is_deleted": False}).sort("created_at", 1).to_list(None)
|
||||||
|
generations: List[Generation] = []
|
||||||
|
for generation in res:
|
||||||
|
generation["id"] = str(generation.pop("_id"))
|
||||||
|
generations.append(Generation(**generation))
|
||||||
|
return generations
|
||||||
|
|||||||
Reference in New Issue
Block a user