diff --git a/api/endpoints/__pycache__/generation_router.cpython-313.pyc b/api/endpoints/__pycache__/generation_router.cpython-313.pyc index fb99eaf..f9e6eba 100644 Binary files a/api/endpoints/__pycache__/generation_router.cpython-313.pyc and b/api/endpoints/__pycache__/generation_router.cpython-313.pyc differ diff --git a/api/endpoints/generation_router.py b/api/endpoints/generation_router.py index 85c4f61..79f68bd 100644 --- a/api/endpoints/generation_router.py +++ b/api/endpoints/generation_router.py @@ -8,7 +8,7 @@ from api import service from api.dependency import get_generation_service, get_project_id, get_dao from repos.dao import DAO -from api.models.GenerationRequest import GenerationResponse, GenerationRequest, GenerationsResponse, PromptResponse, PromptRequest +from api.models.GenerationRequest import GenerationResponse, GenerationRequest, GenerationsResponse, PromptResponse, PromptRequest, GenerationGroupResponse from api.service.generation_service import GenerationService from models.Generation import Generation @@ -68,12 +68,12 @@ async def get_generations(character_id: Optional[str] = None, limit: int = 10, o return await generation_service.get_generations(character_id, limit=limit, offset=offset, user_id=user_id_filter, project_id=project_id) -@router.post("/_run", response_model=GenerationResponse) +@router.post("/_run", response_model=GenerationGroupResponse) async def post_generation(generation: GenerationRequest, request: Request, generation_service: GenerationService = Depends(get_generation_service), current_user: dict = Depends(get_current_user), project_id: Optional[str] = Depends(get_project_id), - dao: DAO = Depends(get_dao)) -> GenerationResponse: + dao: DAO = Depends(get_dao)) -> GenerationGroupResponse: logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}") if project_id: @@ -85,16 +85,6 @@ async def post_generation(generation: GenerationRequest, request: Request, return await generation_service.create_generation_task(generation, user_id=str(current_user.get("_id"))) -@router.get("/{generation_id}", response_model=GenerationResponse) -async def get_generation(generation_id: str, - generation_service: GenerationService = Depends(get_generation_service), - current_user: dict = Depends(get_current_user)) -> GenerationResponse: - logger.debug(f"get_generation called for ID: {generation_id}") - gen = await generation_service.get_generation(generation_id) - if gen and gen.created_by != str(current_user["_id"]): - raise HTTPException(status_code=403, detail="Access denied") - return gen - @router.get("/running") async def get_running_generations(request: Request, @@ -113,6 +103,27 @@ async def get_running_generations(request: Request, return await generation_service.get_running_generations(user_id=user_id_filter, project_id=project_id) +@router.get("/group/{group_id}", response_model=GenerationGroupResponse) +async def get_generation_group(group_id: str, + generation_service: GenerationService = Depends(get_generation_service), + current_user: dict = Depends(get_current_user)): + logger.info(f"get_generation_group called for group_id: {group_id}") + generations = await generation_service.dao.generations.get_generations_by_group(group_id) + gen_responses = [GenerationResponse(**gen.model_dump()) for gen in generations] + return GenerationGroupResponse(generation_group_id=group_id, generations=gen_responses) + + +@router.get("/{generation_id}", response_model=GenerationResponse) +async def get_generation(generation_id: str, + generation_service: GenerationService = Depends(get_generation_service), + current_user: dict = Depends(get_current_user)) -> GenerationResponse: + logger.debug(f"get_generation called for ID: {generation_id}") + gen = await generation_service.get_generation(generation_id) + if gen and gen.created_by != str(current_user["_id"]): + raise HTTPException(status_code=403, detail="Access denied") + return gen + + @router.post("/import", response_model=GenerationResponse) diff --git a/api/models/GenerationRequest.py b/api/models/GenerationRequest.py index 40e9d18..33a4010 100644 --- a/api/models/GenerationRequest.py +++ b/api/models/GenerationRequest.py @@ -1,7 +1,7 @@ from datetime import datetime, UTC from typing import List, Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field from models.Asset import Asset from models.Generation import GenerationStatus @@ -17,6 +17,7 @@ class GenerationRequest(BaseModel): use_profile_image: bool = True assets_list: List[str] project_id: Optional[str] = None + count: int = Field(default=1, ge=1, le=10) class GenerationsResponse(BaseModel): @@ -45,10 +46,15 @@ class GenerationResponse(BaseModel): progress: int = 0 cost: Optional[float] = None created_by: Optional[str] = None + generation_group_id: Optional[str] = None created_at: datetime = datetime.now(UTC) updated_at: datetime = datetime.now(UTC) +class GenerationGroupResponse(BaseModel): + generation_group_id: str + generations: List[GenerationResponse] + class PromptRequest(BaseModel): prompt: str diff --git a/api/models/__pycache__/GenerationRequest.cpython-313.pyc b/api/models/__pycache__/GenerationRequest.cpython-313.pyc index b958d60..0726ddb 100644 Binary files a/api/models/__pycache__/GenerationRequest.cpython-313.pyc and b/api/models/__pycache__/GenerationRequest.cpython-313.pyc differ diff --git a/api/service/__pycache__/generation_service.cpython-313.pyc b/api/service/__pycache__/generation_service.cpython-313.pyc index 24461f4..1970192 100644 Binary files a/api/service/__pycache__/generation_service.cpython-313.pyc and b/api/service/__pycache__/generation_service.cpython-313.pyc differ diff --git a/api/service/generation_service.py b/api/service/generation_service.py index 8cf3b49..74a388a 100644 --- a/api/service/generation_service.py +++ b/api/service/generation_service.py @@ -5,13 +5,14 @@ import base64 from datetime import datetime, UTC from typing import List, Optional, Tuple, Any, Dict from io import BytesIO +from uuid import uuid4 import httpx from aiogram import Bot from aiogram.types import BufferedInputFile from adapters.Exception import GoogleGenerationException from adapters.google_adapter import GoogleAdapter -from api.models.GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse +from api.models.GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse # Импортируйте ваши модели DAO, Asset, Generation корректно from models.Asset import Asset, AssetType, AssetContentType from models.Generation import Generation, GenerationStatus @@ -113,14 +114,28 @@ class GenerationService: async def get_running_generations(self, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]: return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING, created_by=user_id, project_id=project_id) - async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None) -> GenerationResponse: + async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationGroupResponse: + count = generation_request.count + + if generation_group_id is None: + generation_group_id = str(uuid4()) + + results = [] + for _ in range(count): + gen_response = await self._create_single_generation(generation_request, user_id, generation_group_id) + results.append(gen_response) + return GenerationGroupResponse(generation_group_id=generation_group_id, generations=results) + + async def _create_single_generation(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationResponse: gen_id = None generation_model = None try: - generation_model = Generation(**generation_request.model_dump()) + generation_model = Generation(**generation_request.model_dump(exclude={'count'})) if user_id: generation_model.created_by = user_id + if generation_group_id: + generation_model.generation_group_id = generation_group_id gen_id = await self.dao.generations.create_generation(generation_model) generation_model.id = gen_id diff --git a/models/Generation.py b/models/Generation.py index 6c74100..9a75c84 100644 --- a/models/Generation.py +++ b/models/Generation.py @@ -35,6 +35,7 @@ class Generation(BaseModel): output_token_usage: Optional[int] = None is_deleted: bool = False album_id: Optional[str] = None + generation_group_id: Optional[str] = None created_by: Optional[str] = None # Stores User ID (Telegram ID or Web User ObjectId) project_id: Optional[str] = None created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) diff --git a/models/__pycache__/Generation.cpython-313.pyc b/models/__pycache__/Generation.cpython-313.pyc index b052293..cb9ecdb 100644 Binary files a/models/__pycache__/Generation.cpython-313.pyc and b/models/__pycache__/Generation.cpython-313.pyc differ diff --git a/repos/__pycache__/generation_repo.cpython-313.pyc b/repos/__pycache__/generation_repo.cpython-313.pyc index 7ef0c95..34fb78e 100644 Binary files a/repos/__pycache__/generation_repo.cpython-313.pyc and b/repos/__pycache__/generation_repo.cpython-313.pyc differ diff --git a/repos/generation_repo.py b/repos/generation_repo.py index c668e97..b561548 100644 --- a/repos/generation_repo.py +++ b/repos/generation_repo.py @@ -77,3 +77,11 @@ class GenerationRepo: async def update_generation(self, generation: Generation, ): res = await self.collection.update_one({"_id": ObjectId(generation.id)}, {"$set": generation.model_dump()}) + + async def get_generations_by_group(self, group_id: str) -> List[Generation]: + res = await self.collection.find({"generation_group_id": group_id, "is_deleted": False}).sort("created_at", 1).to_list(None) + generations: List[Generation] = [] + for generation in res: + generation["id"] = str(generation.pop("_id")) + generations.append(Generation(**generation)) + return generations