diff --git a/api/endpoints/generation_router.py b/api/endpoints/generation_router.py index 58eadb4..dc6a98b 100644 --- a/api/endpoints/generation_router.py +++ b/api/endpoints/generation_router.py @@ -1,5 +1,4 @@ import logging -import os import json from typing import List, Optional @@ -30,12 +29,22 @@ logger = logging.getLogger(__name__) router = APIRouter(prefix='/api/generations', tags=["Generation"]) +async def check_project_access(project_id: Optional[str], current_user: dict, dao: DAO): + """Helper to check if user has access to project.""" + if not project_id: + return + project = await dao.projects.get_project(project_id) + if not project or str(current_user["_id"]) not in project.members: + raise HTTPException(status_code=403, detail="Project access denied") + + @router.post("/prompt-assistant", response_model=PromptResponse) -async def ask_prompt_assistant(prompt_request: PromptRequest, request: Request, - generation_service: GenerationService = Depends( - get_generation_service), - current_user: dict = Depends(get_current_user)) -> PromptResponse: - logger.info(f"ask_prompt_assistant called with prompt length: {len(prompt_request.prompt)}. Linked assets: {len(prompt_request.linked_assets) if prompt_request.linked_assets else 0}") +async def ask_prompt_assistant( + prompt_request: PromptRequest, + generation_service: GenerationService = Depends(get_generation_service), + current_user: dict = Depends(get_current_user) +) -> PromptResponse: + logger.info(f"ask_prompt_assistant: {len(prompt_request.prompt)} chars") generated_prompt = await generation_service.ask_prompt_assistant(prompt_request.prompt, prompt_request.linked_assets) return PromptResponse(prompt=generated_prompt) @@ -47,32 +56,33 @@ async def prompt_from_image( generation_service: GenerationService = Depends(get_generation_service), current_user: dict = Depends(get_current_user) ) -> PromptResponse: - logger.info(f"prompt_from_image called. Images count: {len(images)}. Prompt provided: {bool(prompt)}") - images_bytes = [] - for image in images: - content = await image.read() - images_bytes.append(content) - + images_bytes = [await img.read() for img in images] generated_prompt = await generation_service.generate_prompt_from_images(images_bytes, prompt) return PromptResponse(prompt=generated_prompt) @router.get("", response_model=GenerationsResponse) -async def get_generations(character_id: Optional[str] = None, limit: int = 10, offset: int = 0, - generation_service: GenerationService = Depends(get_generation_service), - current_user: dict = Depends(get_current_user), - project_id: Optional[str] = Depends(get_project_id), - dao: DAO = Depends(get_dao)): - logger.info(f"get_generations called. CharacterId: {character_id}, Limit: {limit}, Offset: {offset}") +async def get_generations( + character_id: Optional[str] = None, + limit: int = 10, + offset: int = 0, + generation_service: GenerationService = Depends(get_generation_service), + current_user: dict = Depends(get_current_user), + project_id: Optional[str] = Depends(get_project_id), + dao: DAO = Depends(get_dao) +): + await check_project_access(project_id, current_user, dao) - user_id_filter = str(current_user["_id"]) - if project_id: - project = await dao.projects.get_project(project_id) - if not project or str(current_user["_id"]) not in project.members: - raise HTTPException(status_code=403, detail="Project access denied") - user_id_filter = None # Show all project generations + # If project_id is set, we don't filter by user to show all project-wide generations + created_by_filter = None if project_id else str(current_user["_id"]) - return await generation_service.get_generations(character_id, limit=limit, offset=offset, user_id=user_id_filter, project_id=project_id) + return await generation_service.get_generations( + character_id=character_id, + limit=limit, + offset=offset, + created_by=created_by_filter, + project_id=project_id + ) @router.get("/usage", response_model=FinancialReport) @@ -83,32 +93,15 @@ async def get_usage_report( project_id: Optional[str] = Depends(get_project_id), dao: DAO = Depends(get_dao) ) -> FinancialReport: - """ - Returns usage statistics (runs, tokens, cost) for the current user or project. - If project_id is provided, returns stats for that project. - Otherwise, returns stats for the current user. - """ - user_id_filter = str(current_user["_id"]) + await check_project_access(project_id, current_user, dao) + + user_id_filter = str(current_user["_id"]) if not project_id else None breakdown_by = None - if project_id: - # Permission check - project = await dao.projects.get_project(project_id) - if not project or str(current_user["_id"]) not in project.members: - raise HTTPException(status_code=403, detail="Project access denied") - user_id_filter = None # If we are in project, we see stats for the WHOLE project by default - if breakdown == "user": - breakdown_by = "created_by" - elif breakdown == "project": - breakdown_by = "project_id" - else: - # Default: Stats for current user - if breakdown == "project": - breakdown_by = "project_id" - elif breakdown == "user": - # This would breakdown personal usage by user (yourself), but could be useful if it included collaborators? - # No, if project_id is None, it's personal. - breakdown_by = "created_by" + if breakdown == "user": + breakdown_by = "created_by" + elif breakdown == "project": + breakdown_by = "project_id" return await generation_service.get_financial_report( user_id=user_id_filter, @@ -116,58 +109,61 @@ async def get_usage_report( breakdown_by=breakdown_by ) -@router.post("/_run", response_model=GenerationGroupResponse) -async def post_generation(generation: GenerationRequest, request: Request, - generation_service: GenerationService = Depends(get_generation_service), - current_user: dict = Depends(get_current_user), - project_id: Optional[str] = Depends(get_project_id), - dao: DAO = Depends(get_dao)) -> GenerationGroupResponse: - logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}") +@router.post("/_run", response_model=GenerationGroupResponse) +async def post_generation( + generation: GenerationRequest, + generation_service: GenerationService = Depends(get_generation_service), + current_user: dict = Depends(get_current_user), + project_id: Optional[str] = Depends(get_project_id), + dao: DAO = Depends(get_dao) +) -> GenerationGroupResponse: + await check_project_access(project_id, current_user, dao) if project_id: - project = await dao.projects.get_project(project_id) - if not project or str(current_user["_id"]) not in project.members: - raise HTTPException(status_code=403, detail="Project access denied") generation.project_id = project_id - return await generation_service.create_generation_task(generation, user_id=str(current_user.get("_id"))) - + return await generation_service.create_generation_task( + generation, + user_id=str(current_user.get("_id")) + ) @router.get("/running") -async def get_running_generations(request: Request, - generation_service: GenerationService = Depends(get_generation_service), - current_user: dict = Depends(get_current_user), - project_id: Optional[str] = Depends(get_project_id), - dao: DAO = Depends(get_dao)): - - user_id_filter = str(current_user["_id"]) - if project_id: - project = await dao.projects.get_project(project_id) - if not project or str(current_user["_id"]) not in project.members: - raise HTTPException(status_code=403, detail="Project access denied") - user_id_filter = None +async def get_running_generations( + generation_service: GenerationService = Depends(get_generation_service), + current_user: dict = Depends(get_current_user), + project_id: Optional[str] = Depends(get_project_id), + dao: DAO = Depends(get_dao) +): + await check_project_access(project_id, current_user, dao) + user_id_filter = None if project_id else str(current_user["_id"]) - return await generation_service.get_running_generations(user_id=user_id_filter, project_id=project_id) + return await generation_service.get_running_generations( + user_id=user_id_filter, + project_id=project_id + ) @router.get("/group/{group_id}", response_model=GenerationGroupResponse) -async def get_generation_group(group_id: str, - generation_service: GenerationService = Depends(get_generation_service), - current_user: dict = Depends(get_current_user)): - logger.info(f"get_generation_group called for group_id: {group_id}") - generations = await generation_service.dao.generations.get_generations_by_group(group_id) - gen_responses = [GenerationResponse(**gen.model_dump()) for gen in generations] - return GenerationGroupResponse(generation_group_id=group_id, generations=gen_responses) +async def get_generation_group( + group_id: str, + generation_service: GenerationService = Depends(get_generation_service), + current_user: dict = Depends(get_current_user) +): + return await generation_service.get_generations_by_group(group_id) @router.get("/{generation_id}", response_model=GenerationResponse) -async def get_generation(generation_id: str, - generation_service: GenerationService = Depends(get_generation_service), - current_user: dict = Depends(get_current_user)) -> GenerationResponse: - logger.debug(f"get_generation called for ID: {generation_id}") +async def get_generation( + generation_id: str, + generation_service: GenerationService = Depends(get_generation_service), + current_user: dict = Depends(get_current_user) +) -> GenerationResponse: gen = await generation_service.get_generation(generation_id) - if gen and gen.created_by != str(current_user["_id"]): + if not gen: + raise HTTPException(status_code=404, detail="Generation not found") + + if gen.created_by != str(current_user["_id"]): # Check project membership is_member = False if gen.project_id: @@ -180,43 +176,24 @@ async def get_generation(generation_id: str, return gen - - @router.post("/import", response_model=GenerationResponse) async def import_external_generation( request: Request, generation_service: GenerationService = Depends(get_generation_service), x_signature: str = Header(..., alias="X-Signature") ) -> GenerationResponse: - """ - Import a generation from an external source. - Requires server-to-server authentication via HMAC signature. - """ - - logger.info("import_external_generation called") - # Get raw request body for signature verification body = await request.body() - # Verify signature secret = settings.EXTERNAL_API_SECRET if not secret: - logger.error("EXTERNAL_API_SECRET not configured") raise HTTPException(status_code=500, detail="Server configuration error") if not verify_signature(body, x_signature, secret): - logger.warning("Invalid signature for external generation import") raise HTTPException(status_code=401, detail="Invalid signature") - # Parse request body try: data = json.loads(body.decode('utf-8')) external_gen = ExternalGenerationRequest(**data) - except Exception as e: - logger.error(f"Failed to parse request body: {e}") - raise HTTPException(status_code=400, detail=f"Invalid request body: {str(e)}") - - # Import generation - try: generation = await generation_service.import_external_generation(external_gen) return GenerationResponse(**generation.model_dump()) except Exception as e: @@ -225,11 +202,11 @@ async def import_external_generation( @router.delete("/{generation_id}", status_code=status.HTTP_204_NO_CONTENT) -async def delete_generation(generation_id: str, - generation_service: GenerationService = Depends(get_generation_service), - current_user: dict = Depends(get_current_user)): - logger.info(f"delete_generation called for ID: {generation_id}") - deleted = await generation_service.delete_generation(generation_id) - if not deleted: +async def delete_generation( + generation_id: str, + generation_service: GenerationService = Depends(get_generation_service), + current_user: dict = Depends(get_current_user) +): + if not await generation_service.delete_generation(generation_id): raise HTTPException(status_code=404, detail="Generation not found") return None diff --git a/api/models/ExternalGenerationDTO.py b/api/models/ExternalGenerationDTO.py index b9a2c00..a56deca 100644 --- a/api/models/ExternalGenerationDTO.py +++ b/api/models/ExternalGenerationDTO.py @@ -14,7 +14,7 @@ class ExternalGenerationRequest(BaseModel): image_url: Optional[str] = Field(None, description="URL to download image from") # Generation metadata - aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN + aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN # "1:1","2:3","3:2","3:4","4:3","4:5","5:4","9:16","16:9","21:9" quality: Quality = Quality.ONEK # Optional linking diff --git a/api/models/GenerationRequest.py b/api/models/GenerationRequest.py index 0094189..c7fe14e 100644 --- a/api/models/GenerationRequest.py +++ b/api/models/GenerationRequest.py @@ -10,7 +10,7 @@ from models.enums import AspectRatios, Quality, GenType class GenerationRequest(BaseModel): linked_character_id: Optional[str] = None - aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN + aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN # "1:1","2:3","3:2","3:4","4:3","4:5","5:4","9:16","16:9","21:9" quality: Quality = Quality.ONEK prompt: str telegram_id: Optional[int] = None diff --git a/api/service/generation_service.py b/api/service/generation_service.py index 81c8bd3..70711ae 100644 --- a/api/service/generation_service.py +++ b/api/service/generation_service.py @@ -13,13 +13,15 @@ from aiogram.types import BufferedInputFile from adapters.Exception import GoogleGenerationException from adapters.google_adapter import GoogleAdapter from adapters.s3_adapter import S3Adapter -from api.models import FinancialReport, UsageStats, UsageByEntity -from api.models import GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse -# Импортируйте ваши модели DAO, Asset, Generation корректно +from api.models import ( + FinancialReport, UsageStats, UsageByEntity, + GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse +) from models.Asset import Asset, AssetType, AssetContentType from models.Generation import Generation, GenerationStatus from models.enums import AspectRatios, Quality from repos.dao import DAO +from utils.image_utils import create_thumbnail logger = logging.getLogger(__name__) @@ -27,22 +29,18 @@ logger = logging.getLogger(__name__) generation_semaphore = asyncio.Semaphore(4) -# --- Вспомогательная функция генерации --- async def generate_image_task( prompt: str, media_group_bytes: List[bytes], aspect_ratio: AspectRatios, quality: Quality, gemini: GoogleAdapter, - ) -> Tuple[List[bytes], Dict[str, Any]]: """ - Обертка для вызова синхронного метода Gemini в отдельном потоке. - Возвращает список байтов сгенерированных изображений. + Wrapper for calling Gemini's synchronous method in a separate thread. """ - try : + try: logger.info(f"Starting generate_image_task with prompt length: {len(prompt)}") - # Запускаем блокирующую операцию в отдельном потоке, чтобы не тормозить Event Loop result = await asyncio.to_thread( gemini.generate_image, prompt=prompt, @@ -51,12 +49,10 @@ async def generate_image_task( quality=quality, ) generated_images_io, metrics = result - logger.info(f"generate_image_task completed, received {len(generated_images_io) if generated_images_io else 0} images") - except GoogleGenerationException as e: - raise e + except GoogleGenerationException: + raise finally: - # Освобождаем входные данные — они больше не нужны del media_group_bytes images_bytes = [] @@ -65,371 +61,136 @@ async def generate_image_task( img_io.seek(0) images_bytes.append(img_io.read()) img_io.close() - # Освобождаем список BytesIO сразу del generated_images_io return images_bytes, metrics + class GenerationService: - def __init__(self, dao: DAO, gemini: GoogleAdapter, s3_adapter: S3Adapter, bot: Optional[Bot] = None): + def __init__(self, dao: DAO, gemini: GoogleAdapter, s3_adapter: S3Adapter, bot: Optional[Bot] = None): self.dao = dao self.gemini = gemini self.s3_adapter = s3_adapter self.bot = bot - + + # --- Public API --- async def ask_prompt_assistant(self, prompt: str, assets: list[str] | None = None) -> str: - future_prompt = """You are an prompt-assistant. You improving user-entered prompts for image generation. User may upload reference image too. - I will provide sources prompt entered by user. Understand user needs and generate best variation of prompt. - ANSWER ONLY PROMPT STRING!!! USER_ENTERED_PROMPT: """ - future_prompt += prompt + future_prompt = ( + "You are an prompt-assistant. You improving user-entered prompts for image generation. " + "User may upload reference image too. I will provide sources prompt entered by user. " + "Understand user needs and generate best variation of prompt. ANSWER ONLY PROMPT STRING!!! " + f"USER_ENTERED_PROMPT: {prompt}" + ) assets_data = [] - if assets is not None: + if assets: assets_db = await self.dao.assets.get_assets_by_ids(assets) - assets_data.extend(asset.data for asset in assets_db) + assets_data.extend(asset.data for asset in assets_db if asset.data) + generated_prompt = await asyncio.to_thread(self.gemini.generate_text, future_prompt, assets_data) - logger.info(future_prompt) - logger.info(generated_prompt) + logger.info(f"Prompt Assistant: {generated_prompt}") return generated_prompt async def generate_prompt_from_images(self, images: List[bytes], user_prompt: Optional[str] = None) -> str: technical_prompt = "You are a prompt engineer. Describe this image in detail to create a stable diffusion using this image as reference. " if user_prompt: technical_prompt += f"User also provided this context: {user_prompt}. " - technical_prompt += "Provide ONLY the detailed prompt." return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, images_list=images) - async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0, user_id: Optional[str] = None, project_id: Optional[str] = None, idea_id: Optional[str] = None) -> GenerationsResponse: - generations = await self.dao.generations.get_generations(character_id = character_id,limit=limit, offset=offset, created_by=user_id, project_id=project_id, idea_id=idea_id) - total_count = await self.dao.generations.count_generations(character_id = character_id, created_by=user_id, project_id=project_id, idea_id=idea_id) - generations = [GenerationResponse(**gen.model_dump()) for gen in generations] - return GenerationsResponse(generations=generations, total_count=total_count) + async def get_generations(self, **kwargs) -> GenerationsResponse: + generations = await self.dao.generations.get_generations(**kwargs) + total_count = await self.dao.generations.count_generations( + character_id=kwargs.get('character_id'), + created_by=kwargs.get('created_by'), + project_id=kwargs.get('project_id'), + idea_id=kwargs.get('idea_id') + ) + return GenerationsResponse( + generations=[GenerationResponse(**gen.model_dump()) for gen in generations], + total_count=total_count + ) async def get_generation(self, generation_id: str) -> Optional[GenerationResponse]: gen = await self.dao.generations.get_generation(generation_id) - if gen is None: - return None - else: - return GenerationResponse(**gen.model_dump()) + return GenerationResponse(**gen.model_dump()) if gen else None + + async def get_generations_by_group(self, group_id: str) -> GenerationGroupResponse: + generations = await self.dao.generations.get_generations_by_group(group_id) + return GenerationGroupResponse( + generation_group_id=group_id, + generations=[GenerationResponse(**gen.model_dump()) for gen in generations] + ) async def get_running_generations(self, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]: return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING, created_by=user_id, project_id=project_id) async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationGroupResponse: - count = generation_request.count - if generation_group_id is None: generation_group_id = str(uuid4()) - + results = [] - for _ in range(count): + for _ in range(generation_request.count): gen_response = await self._create_single_generation(generation_request, user_id, generation_group_id) results.append(gen_response) return GenerationGroupResponse(generation_group_id=generation_group_id, generations=results) - async def _create_single_generation(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationResponse: - gen_id = None - generation_model = None - - try: - generation_model = Generation(**generation_request.model_dump(exclude={'count'})) - if user_id: - generation_model.created_by = user_id - if generation_group_id: - generation_model.generation_group_id = generation_group_id - - # Explicitly set idea_id from request if present (already in model_dump, but ensuring clarity) - if generation_request.idea_id: - generation_model.idea_id = generation_request.idea_id - - gen_id = await self.dao.generations.create_generation(generation_model) - generation_model.id = gen_id - - async def runner(gen): - logger.info(f"Generation {gen.id} entered queue (waiting for slot)...") - try: - async with generation_semaphore: - logger.info(f"Starting background generation task for ID: {gen.id}") - await self.create_generation(gen) - logger.info(f"Background generation task finished for ID: {gen.id}") - except Exception: - # если генерация уже пошла и упала — пометим FAILED - try: - db_gen = await self.dao.generations.get_generation(gen.id) - if db_gen is not None: - db_gen.status = GenerationStatus.FAILED - await self.dao.generations.update_generation(db_gen) - except Exception: - logger.exception("Failed to mark generation as FAILED") - logger.exception("create_generation task failed") - - asyncio.create_task(runner(generation_model)) - - return GenerationResponse(**generation_model.model_dump()) - - except Exception: - # если не успели создать запись — нечего помечать - if gen_id is not None: - try: - gen = await self.dao.generations.get_generation(gen_id) - if gen is not None: - gen.status = GenerationStatus.FAILED - await self.dao.generations.update_generation(gen) - except Exception: - logger.exception("Failed to mark generation as FAILED in create_generation_task") - raise - async def create_generation(self, generation: Generation): start_time = datetime.now() logger.info(f"Processing generation {generation.id}. Character ID: {generation.linked_character_id}") - # 2. Получаем ассеты-референсы (если они есть) - reference_assets: List[Asset] = [] - media_group_bytes: List[bytes] = [] - generation_prompt = generation.prompt -# generation_prompt = f""" - -# Create detailed image of character in scene. - -# SCENE DESCRIPTION: {generation.prompt} - -# Rules: -# - Integrate the character's appearance naturally into the scene description. -# - Focus on lighting, texture, and composition. -# """ - if generation.linked_character_id is not None: - char_info = await self.dao.chars.get_character(generation.linked_character_id) - if char_info is None: - raise Exception(f"Character ID {generation.linked_character_id} not found") - if generation.use_profile_image: - if char_info.avatar_asset_id is not None: - avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id) - if avatar_asset and avatar_asset.data: - media_group_bytes.append(avatar_asset.data) - # generation_prompt = generation_prompt.replace("$char_bio_inserted", f"1. CHARACTER BIO (Must be strictly followed): {char_info.character_bio}") + # 1. Prepare input + media_group_bytes, generation_prompt = await self._prepare_generation_input(generation) - reference_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list) - - # Извлекаем данные (bytes) из ассетов для отправки в Gemini - for asset in reference_assets: - if asset.content_type != AssetContentType.IMAGE: - continue - - img_data = None - if asset.minio_object_name: - img_data = await self.s3_adapter.get_file(asset.minio_object_name) - elif asset.data: - img_data = asset.data - - if img_data: - media_group_bytes.append(img_data) - - if media_group_bytes: - generation_prompt += " \n\n[Reference Image Guidance]: Use the provided image(s) as the STRICT reference for the main character's facial features and hair, enviroment or clothes. Maintain high fidelity to the reference identity." - - logger.info(f"Final generation prompt assembled. Length: {len(generation_prompt)}. Media count: {len(media_group_bytes)}") - - # 3. Запускаем процесс генерации и симуляцию прогресса + # 2. Run generation with progress simulation progress_task = asyncio.create_task(self._simulate_progress(generation)) - try: - - # Default to Image Generation (Gemini) generated_bytes_list, metrics = await generate_image_task( - prompt=generation_prompt, # или request.prompt + prompt=generation_prompt, media_group_bytes=media_group_bytes, - aspect_ratio=generation.aspect_ratio, # предполагаем поля в request + aspect_ratio=generation.aspect_ratio, quality=generation.quality, gemini=self.gemini ) - - - # Update metrics from API (Common for both) - generation.api_execution_time_seconds = metrics.get("api_execution_time_seconds") - generation.token_usage = metrics.get("token_usage") - generation.input_token_usage = metrics.get("input_token_usage") - generation.output_token_usage = metrics.get("output_token_usage") - - except GoogleGenerationException as e: - generation.status = GenerationStatus.FAILED - generation.failed_reason = str(e) - generation.updated_at = datetime.now(UTC) - await self.dao.generations.update_generation(generation) - raise e + self._update_generation_metrics(generation, metrics) except Exception as e: - # Тут стоит добавить логирование ошибки - logging.error(f"Generation failed: {e}") - generation.status = GenerationStatus.FAILED - generation.failed_reason = str(e) - generation.updated_at = datetime.now(UTC) - await self.dao.generations.update_generation(generation) - raise e + await self._handle_generation_failure(generation, e) + raise finally: - if not progress_task.done(): + if not progress_task.done(): progress_task.cancel() try: await progress_task except asyncio.CancelledError: pass - # 4. Сохраняем полученные изображения как новые Ассеты - created_assets: List[Asset] = [] + # 3. Process results + created_assets = await self._process_generated_images(generation, generated_bytes_list) - for idx, img_bytes in enumerate(generated_bytes_list): - # Generate thumbnail - thumbnail_bytes = None - from utils.image_utils import create_thumbnail - thumbnail_bytes = await asyncio.to_thread(create_thumbnail, img_bytes) - - # Save to S3 - filename = f"generated/{generation.linked_character_id}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.png" - await self.s3_adapter.upload_file(filename, img_bytes, content_type="image/png") - - new_asset = Asset( - name=f"Generated_{generation.linked_character_id}", - type=AssetType.GENERATED, - content_type=AssetContentType.IMAGE, - linked_char_id=generation.linked_character_id, - data=None, # Not storing bytes in DB anymore - minio_object_name=filename, - minio_bucket=self.s3_adapter.bucket_name, - thumbnail=thumbnail_bytes, - created_by=generation.created_by, - project_id=generation.project_id - ) + # 4. Finalize generation record + await self._finalize_generation(generation, created_assets, generation_prompt, start_time) - # Сохраняем в БД - asset_id = await self.dao.assets.create_asset(new_asset) - new_asset.id = str(asset_id) # Присваиваем ID, полученный от базы - - created_assets.append(new_asset) - - # 5. (Опционально) Обновляем запись генерации ссылками на результаты - # Предполагаем, что у модели Generation есть поле result_asset_ids - result_ids = [] - for a in created_assets: - result_ids.append(a.id) - - generation.result_list = result_ids - generation.status = GenerationStatus.DONE - generation.progress = 100 - generation.updated_at = datetime.now(UTC) - generation.tech_prompt = generation_prompt - - end_time = datetime.now() - generation.execution_time_seconds = (end_time - start_time).total_seconds() - - logger.info(f"DEBUG: Saving generation {generation.id}. Metrics: api_exec={generation.api_execution_time_seconds}, tokens={generation.token_usage}, in_tokens={generation.input_token_usage}, out_tokens={generation.output_token_usage}, exec={generation.execution_time_seconds}") - - await self.dao.generations.update_generation(generation) - logger.info(f"Generation {generation.id} completed successfully. {len(created_assets)} assets created. Total Time: {generation.execution_time_seconds:.2f}s") - - # 6. Send to Telegram if telegram_id is provided + # 5. Notify if generation.telegram_id and self.bot: - try: - for asset in created_assets: - if asset.data: - await self.bot.send_photo( - chat_id=generation.telegram_id, - photo=BufferedInputFile(asset.data, filename=f"{asset.name}.jpg"), - caption=f"Generated from prompt: {generation.prompt[:100]}..." - ) - logger.info(f"Sent {len(created_assets)} assets to Telegram ID: {generation.telegram_id}") - except Exception as e: - logger.error(f"Failed to send assets to Telegram ID {generation.telegram_id}: {e}") - - - async def _simulate_progress(self, generation: Generation): - """ - Increments progress from 0 to 90 over ~20 seconds. - """ - current_progress = 0 - try: - while current_progress < 90: - await asyncio.sleep(4) - # Random increment between 5 and 15 - increment = random.randint(5, 15) - current_progress = min(current_progress + increment, 90) - - # Fetch latest state (optional, but good practice to avoid overwriting unrelated fields) - # But for simplicity here we just use the object we have and save it. - # Ideally, we should fetch-update-save or use partial update if DAO supports it. - # Assuming simple update is fine for now. - generation.progress = current_progress - await self.dao.generations.update_generation(generation) - except asyncio.CancelledError: - # Task cancelled, generation finished (or failed) - pass - except Exception as e: - logger.error(f"Error in progress simulation: {e}") - - - + await self._notify_telegram(generation, created_assets) async def import_external_generation(self, external_gen) -> Generation: - """ - Import a generation from an external source. - - Args: - external_gen: ExternalGenerationRequest with generation data and image - - Returns: - Created Generation object - """ - - # Validate image source external_gen.validate_image_source() - logger.info(f"Importing external generation for user: {external_gen.created_by}") + + image_bytes = await self._fetch_external_image(external_gen) - # 1. Process image (download or decode) - image_bytes = None - - if external_gen.image_url: - # Download image from URL - logger.info(f"Downloading image from URL: {external_gen.image_url}") - async with httpx.AsyncClient() as client: - response = await client.get(external_gen.image_url, timeout=30.0) - response.raise_for_status() - image_bytes = response.content - elif external_gen.image_data: - # Decode base64 image - logger.info("Decoding base64 image data") - image_bytes = base64.b64decode(external_gen.image_data) - - if not image_bytes: - raise ValueError("Failed to process image data") - - # 2. Generate thumbnail - from utils.image_utils import create_thumbnail - thumbnail_bytes = await asyncio.to_thread(create_thumbnail, image_bytes) - - # 3. Save to S3 - filename = f"external/{external_gen.created_by}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.png" - await self.s3_adapter.upload_file(filename, image_bytes, content_type="image/png") - - # 4. Create Asset - new_asset = Asset( + # Reuse internal processing logic + new_asset = await self._save_asset( + image_bytes=image_bytes, name=f"External_Generated_{external_gen.linked_character_id or 'no_char'}", - type=AssetType.GENERATED, - content_type=AssetContentType.IMAGE, - linked_char_id=external_gen.linked_character_id, - data=None, # Not storing bytes in DB - minio_object_name=filename, - minio_bucket=self.s3_adapter.bucket_name, - thumbnail=thumbnail_bytes, created_by=external_gen.created_by, - project_id=external_gen.project_id + project_id=external_gen.project_id, + linked_char_id=external_gen.linked_character_id, + folder="external" ) - - asset_id = await self.dao.assets.create_asset(new_asset) - new_asset.id = str(asset_id) - - logger.info(f"Created asset {asset_id} for external generation") - - # 5. Create Generation record + generation = Generation( status=GenerationStatus.DONE, linked_character_id=external_gen.linked_character_id, @@ -446,27 +207,18 @@ class GenerationService: input_token_usage=external_gen.input_token_usage, output_token_usage=external_gen.output_token_usage, created_by=external_gen.created_by, - project_id=external_gen.project_id, - created_at=datetime.now(UTC), - updated_at=datetime.now(UTC) + project_id=external_gen.project_id ) gen_id = await self.dao.generations.create_generation(generation) generation.id = gen_id - - logger.info(f"Created generation {gen_id} from external source") - return generation async def delete_generation(self, generation_id: str) -> bool: - """ - Soft delete generation by marking it as deleted. - """ try: generation = await self.dao.generations.get_generation(generation_id) if not generation: return False - generation.is_deleted = True generation.updated_at = datetime.now(UTC) await self.dao.generations.update_generation(generation) @@ -476,59 +228,200 @@ class GenerationService: return False async def cleanup_stale_generations(self): - """ - Cancels generations that have been running for more than 1 hour. - """ try: - count = await self.dao.generations.cancel_stale_generations(timeout_minutes=60) + count = await self.dao.generations.cancel_stale_generations(timeout_minutes=5) if count > 0: - logger.info(f"Cleaned up {count} stale generations (timeout)") + logger.info(f"Cleaned up {count} stale generations") except Exception as e: logger.error(f"Error cleaning up stale generations: {e}") async def cleanup_old_data(self, days: int = 30): - """ - Очистка старых данных: - 1. Мягко удаляет генерации старше N дней - 2. Мягко удаляет связанные ассеты + жёстко удаляет файлы из S3 - """ try: - # 1. Мягко удаляем генерации и собираем asset IDs gen_count, asset_ids = await self.dao.generations.soft_delete_old_generations(days=days) - if gen_count > 0: - logger.info(f"Soft-deleted {gen_count} generations older than {days} days. " - f"Found {len(asset_ids)} associated asset IDs.") - - # 2. Мягко удаляем ассеты + жёстко удаляем файлы из S3 + logger.info(f"Soft-deleted {gen_count} generations. Purging {len(asset_ids)} assets.") if asset_ids: - purged = await self.dao.assets.soft_delete_and_purge_assets(asset_ids) - logger.info(f"Purged {purged} assets (soft-deleted + S3 files removed).") - + await self.dao.assets.soft_delete_and_purge_assets(asset_ids) except Exception as e: logger.error(f"Error during old data cleanup: {e}") async def get_financial_report(self, user_id: Optional[str] = None, project_id: Optional[str] = None, breakdown_by: Optional[str] = None) -> FinancialReport: - """ - Generates a financial usage report for a specific user or project. - 'breakdown_by' can be 'created_by' or 'project_id'. - """ summary_data = await self.dao.generations.get_usage_stats(created_by=user_id, project_id=project_id) summary = UsageStats(**summary_data) - by_user = None - by_project = None - + by_user, by_project = None, None if breakdown_by == "created_by": res = await self.dao.generations.get_usage_breakdown(group_by="created_by", project_id=project_id, created_by=user_id) by_user = [UsageByEntity(**item) for item in res] - if breakdown_by == "project_id": res = await self.dao.generations.get_usage_breakdown(group_by="project_id", project_id=project_id, created_by=user_id) by_project = [UsageByEntity(**item) for item in res] - return FinancialReport( - summary=summary, - by_user=by_user, - by_project=by_project - ) \ No newline at end of file + return FinancialReport(summary=summary, by_user=by_user, by_project=by_project) + + # --- Private Helpers --- + + async def _create_single_generation(self, generation_request: GenerationRequest, user_id: Optional[str], generation_group_id: str) -> GenerationResponse: + try: + gen_model = Generation(**generation_request.model_dump(exclude={'count'})) + gen_model.created_by = user_id + gen_model.generation_group_id = generation_group_id + + gen_id = await self.dao.generations.create_generation(gen_model) + gen_model.id = gen_id + + asyncio.create_task(self._queued_generation_runner(gen_model)) + return GenerationResponse(**gen_model.model_dump()) + except Exception: + logger.exception("Failed to initiate single generation") + raise + + async def _queued_generation_runner(self, gen: Generation): + logger.info(f"Generation {gen.id} waiting for slot...") + try: + async with generation_semaphore: + await self.create_generation(gen) + except Exception: + await self._handle_generation_failure(gen, None) + logger.exception(f"Background generation task failed for ID: {gen.id}") + + async def _prepare_generation_input(self, generation: Generation) -> Tuple[List[bytes], str]: + media_group_bytes: List[bytes] = [] + prompt = generation.prompt + + # 1. Character Avatar + if generation.linked_character_id: + char_info = await self.dao.chars.get_character(generation.linked_character_id) + if not char_info: + raise ValueError(f"Character {generation.linked_character_id} not found") + + if generation.use_profile_image and char_info.avatar_asset_id: + avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id) + if avatar_asset: + data = await self._get_asset_data_bytes(avatar_asset) + if data: media_group_bytes.append(data) + + # 2. Reference Assets + if generation.assets_list: + assets = await self.dao.assets.get_assets_by_ids(generation.assets_list) + for asset in assets: + data = await self._get_asset_data_bytes(asset) + if data: media_group_bytes.append(data) + + # 3. Environment Assets + if generation.environment_id: + env = await self.dao.environments.get_env(generation.environment_id) + if env and env.asset_ids: + env_assets = await self.dao.assets.get_assets_by_ids(env.asset_ids) + for asset in env_assets: + data = await self._get_asset_data_bytes(asset) + if data: media_group_bytes.append(data) + + if media_group_bytes: + prompt += ( + " \n\n[Reference Image Guidance]: Use the provided image(s) as the STRICT reference for the main " + "character's facial features and hair, environment or clothes. Maintain high fidelity to the reference identity." + ) + + return media_group_bytes, prompt + + async def _get_asset_data_bytes(self, asset: Asset) -> Optional[bytes]: + if asset.content_type != AssetContentType.IMAGE: + return None + if asset.minio_object_name: + return await self.s3_adapter.get_file(asset.minio_object_name) + return asset.data + + def _update_generation_metrics(self, generation: Generation, metrics: Dict[str, Any]): + generation.api_execution_time_seconds = metrics.get("api_execution_time_seconds") + generation.token_usage = metrics.get("token_usage") + generation.input_token_usage = metrics.get("input_token_usage") + generation.output_token_usage = metrics.get("output_token_usage") + + async def _handle_generation_failure(self, generation: Generation, error: Optional[Exception]): + logger.error(f"Generation {generation.id} failed: {error}") + generation.status = GenerationStatus.FAILED + generation.failed_reason = str(error) if error else "Unknown error" + generation.updated_at = datetime.now(UTC) + await self.dao.generations.update_generation(generation) + + async def _process_generated_images(self, generation: Generation, bytes_list: List[bytes]) -> List[Asset]: + created_assets = [] + for img_bytes in bytes_list: + asset = await self._save_asset( + image_bytes=img_bytes, + name=f"Generated_{generation.linked_character_id}", + created_by=generation.created_by, + project_id=generation.project_id, + linked_char_id=generation.linked_character_id, + folder="generated" + ) + created_assets.append(asset) + return created_assets + + async def _save_asset(self, image_bytes: bytes, name: str, created_by: str, project_id: str, linked_char_id: str, folder: str) -> Asset: + thumbnail_bytes = await asyncio.to_thread(create_thumbnail, image_bytes) + filename = f"{folder}/{linked_char_id}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.png" + + await self.s3_adapter.upload_file(filename, image_bytes, content_type="image/png") + + new_asset = Asset( + name=name, + type=AssetType.GENERATED, + content_type=AssetContentType.IMAGE, + linked_char_id=linked_char_id, + data=None, + minio_object_name=filename, + minio_bucket=self.s3_adapter.bucket_name, + thumbnail=thumbnail_bytes, + created_by=created_by, + project_id=project_id + ) + asset_id = await self.dao.assets.create_asset(new_asset) + new_asset.id = str(asset_id) + return new_asset + + async def _finalize_generation(self, generation: Generation, assets: List[Asset], tech_prompt: str, start_time: datetime): + generation.result_list = [a.id for a in assets] + generation.status = GenerationStatus.DONE + generation.progress = 100 + generation.updated_at = datetime.now(UTC) + generation.tech_prompt = tech_prompt + generation.execution_time_seconds = (datetime.now() - start_time).total_seconds() + await self.dao.generations.update_generation(generation) + logger.info(f"Generation {generation.id} finalized. Time: {generation.execution_time_seconds:.2f}s") + + async def _notify_telegram(self, generation: Generation, assets: List[Asset]): + try: + for asset in assets: + # Need to get data for telegram if it's not in Asset object + img_data = await self.s3_adapter.get_file(asset.minio_object_name) if asset.minio_object_name else asset.data + if img_data: + await self.bot.send_photo( + chat_id=generation.telegram_id, + photo=BufferedInputFile(img_data, filename=f"{asset.name}.png"), + caption=f"Generated from: {generation.prompt[:100]}..." + ) + except Exception as e: + logger.error(f"Failed to send to Telegram: {e}") + + async def _simulate_progress(self, generation: Generation): + current_progress = 0 + try: + while current_progress < 90: + await asyncio.sleep(4) + current_progress = min(current_progress + random.randint(5, 15), 90) + generation.progress = current_progress + await self.dao.generations.update_generation(generation) + except asyncio.CancelledError: + pass + + async def _fetch_external_image(self, external_gen) -> bytes: + if external_gen.image_url: + async with httpx.AsyncClient() as client: + response = await client.get(external_gen.image_url, timeout=30.0) + response.raise_for_status() + return response.content + elif external_gen.image_data: + return base64.b64decode(external_gen.image_data) + raise ValueError("No image source provided") diff --git a/models/enums.py b/models/enums.py index a0fd856..cfc7401 100644 --- a/models/enums.py +++ b/models/enums.py @@ -2,19 +2,30 @@ from enum import Enum class AspectRatios(str, Enum): - NINESIXTEEN = "NINESIXTEEN" - SIXTEENNINE = "SIXTEENNINE" - THREEFOUR = "THREEFOUR" - FOURTHREE = "FOURTHREE" + ONEONE = "1:1" + TWOTHREE = "2:3" + THREETWO = "3:2" + THREEFOUR = "3:4" + FOURTHREE = "4:3" + FOURFIVE = "4:5" + FIVEFOUR = "5:4" + NINESIXTEEN = "9:16" + SIXTEENNINE = "16:9" + TWENTYONENINE = "21:9" + + @classmethod + def _missing_(cls, value): + mapping = { + "NINESIXTEEN": cls.NINESIXTEEN, + "SIXTEENNINE": cls.SIXTEENNINE, + "THREEFOUR": cls.THREEFOUR, + "FOURTHREE": cls.FOURTHREE, + } + return mapping.get(value) @property def value_ratio(self) -> str: - return { - AspectRatios.NINESIXTEEN: "9:16", - AspectRatios.SIXTEENNINE: "16:9", - AspectRatios.THREEFOUR: "3:4", - AspectRatios.FOURTHREE: "4:3", - }[self] + return self.value class Quality(str, Enum): diff --git a/routers/gen_router.py b/routers/gen_router.py index ca525c4..5d0a26a 100644 --- a/routers/gen_router.py +++ b/routers/gen_router.py @@ -126,12 +126,11 @@ async def change_char(call: CallbackQuery, state: FSMContext, dao: DAO): @router.callback_query(States.gen_mode, F.data == 'gen_mode_change_aspect_ratio') async def gen_mode_change_aspect_ratio(call: CallbackQuery, state: FSMContext, dao: DAO): await call.answer() - keyboards = [] - for ratio in AspectRatios: - keyboards.append(InlineKeyboardButton(text=ratio.value, callback_data=f'select_ratio_{ratio.name}')) + buttons = [InlineKeyboardButton(text=ratio.value, callback_data=f'select_ratio_{ratio.name}') for ratio in AspectRatios] + keyboard_rows = [buttons[i:i + 4] for i in range(0, len(buttons), 4)] + keyboard_rows.append([InlineKeyboardButton(text="⬅️ Назад", callback_data="gen_mode_cancel_ratio_change")]) await call.message.edit_caption(caption="Выбери соотношение сторон", - reply_markup=InlineKeyboardMarkup(inline_keyboard=[keyboards, [InlineKeyboardButton( - text="⬅️ Назад", callback_data="gen_mode_cancel_ratio_change")]])) + reply_markup=InlineKeyboardMarkup(inline_keyboard=keyboard_rows)) @router.callback_query(States.gen_mode, F.data.startswith('select_ratio_'))