import asyncio
import base64
import logging
import random
from datetime import datetime, UTC
from io import BytesIO
from typing import List, Optional, Tuple, Any, Dict
from uuid import uuid4

import httpx
from aiogram import Bot
from aiogram.types import BufferedInputFile

from adapters.Exception import GoogleGenerationException
from adapters.google_adapter import GoogleAdapter
from api.models import GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse
from api.models import FinancialReport, UsageStats, UsageByEntity
# Make sure the DAO, Asset and Generation models are imported correctly
from models.Asset import Asset, AssetType, AssetContentType
from models.Generation import Generation, GenerationStatus
from models.enums import AspectRatios, Quality, GenType
from repos.dao import DAO
from adapters.s3_adapter import S3Adapter

logger = logging.getLogger(__name__)

# Limit concurrent generations to 4
generation_semaphore = asyncio.Semaphore(4)

# Strong references to fire-and-forget generation tasks. The event loop only
# keeps weak references to tasks, so an unreferenced task may be garbage
# collected before it finishes.
_background_tasks: set = set()


# --- Generation helper ---
async def generate_image_task(
    prompt: str,
    media_group_bytes: List[bytes],
    aspect_ratio: AspectRatios,
    quality: Quality,
    gemini: GoogleAdapter,
) -> Tuple[List[bytes], Dict[str, Any]]:
    """
    Run ``gemini.generate_image`` in a worker thread and return raw image bytes.

    Args:
        prompt: Final prompt text sent to the model.
        media_group_bytes: Reference images passed as ``images_list``.
        aspect_ratio: Forwarded to ``gemini.generate_image``.
        quality: Forwarded to ``gemini.generate_image``.
        gemini: Adapter whose ``generate_image`` is called.

    Returns:
        A tuple ``(images, metrics)`` where ``images`` is the content of every
        returned stream and ``metrics`` is the second element returned by the
        adapter.

    Raises:
        Whatever ``gemini.generate_image`` raises (e.g.
        ``GoogleGenerationException``) propagates unchanged.
    """
    logger.info(f"Starting generate_image_task with prompt length: {len(prompt)}")

    # Run the blocking call in a separate thread so the event loop stays responsive.
    generated_images_io, metrics = await asyncio.to_thread(
        gemini.generate_image,
        prompt=prompt,
        images_list=media_group_bytes,
        aspect_ratio=aspect_ratio,
        quality=quality,
    )
    logger.info(f"generate_image_task completed, received {len(generated_images_io) if generated_images_io else 0} images")

    images_bytes = []
    for img_io in generated_images_io or []:
        img_io.seek(0)
        images_bytes.append(img_io.read())
        # Close each stream as soon as it has been read.
        img_io.close()

    return images_bytes, metrics


class GenerationService:
    """Creates, runs, imports, cleans up and reports on image generations."""

    def __init__(self, dao: DAO, gemini: GoogleAdapter, s3_adapter: S3Adapter, bot: Optional[Bot] = None):
        self.dao = dao
        self.gemini = gemini
        self.s3_adapter = s3_adapter
        self.bot = bot

    async def _load_asset_bytes(self, asset: Asset) -> Optional[bytes]:
        """Return the asset's bytes: from S3 when it has an object name, else from ``asset.data``."""
        if asset.minio_object_name:
            return await self.s3_adapter.get_file(asset.minio_object_name)
        return asset.data or None

    async def ask_prompt_assistant(self, prompt: str, assets: list[str] | None = None) -> str:
        """
        Ask the text model to rewrite a user prompt for image generation.

        Args:
            prompt: Prompt entered by the user.
            assets: Optional asset IDs whose bytes are sent along as references.

        Returns:
            The value returned by ``gemini.generate_text``.
        """
        future_prompt = (
            "You are an prompt-assistant. You improving user-entered prompts for image generation. "
            "User may upload reference image too. I will provide sources prompt entered by user. "
            "Understand user needs and generate best variation of prompt. "
            "ANSWER ONLY PROMPT STRING!!! USER_ENTERED_PROMPT: "
        )
        future_prompt += prompt

        assets_data = []
        if assets is not None:
            assets_db = await self.dao.assets.get_assets_by_ids(assets)
            for asset in assets_db:
                # Assets may live in S3 with ``data`` unset; skip the ones with no bytes.
                asset_bytes = await self._load_asset_bytes(asset)
                if asset_bytes:
                    assets_data.append(asset_bytes)

        generated_prompt = await asyncio.to_thread(self.gemini.generate_text, future_prompt, assets_data)
        logger.info(future_prompt)
        logger.info(generated_prompt)
        return generated_prompt

    async def generate_prompt_from_images(self, images: List[bytes], user_prompt: Optional[str] = None) -> str:
        """Ask the text model to describe ``images`` as a prompt, optionally with user context."""
        technical_prompt = "You are a prompt engineer. Describe this image in detail to create a stable diffusion using this image as reference. "
        if user_prompt:
            technical_prompt += f"User also provided this context: {user_prompt}. "
        technical_prompt += "Provide ONLY the detailed prompt."
        return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, images_list=images)

    async def get_generations(
        self,
        character_id: Optional[str] = None,
        limit: int = 10,
        offset: int = 0,
        user_id: Optional[str] = None,
        project_id: Optional[str] = None,
        idea_id: Optional[str] = None,
    ) -> GenerationsResponse:
        """Return one page of generations matching the filters plus the total match count."""
        generations = await self.dao.generations.get_generations(
            character_id=character_id,
            limit=limit,
            offset=offset,
            created_by=user_id,
            project_id=project_id,
            idea_id=idea_id,
        )
        total_count = await self.dao.generations.count_generations(
            character_id=character_id,
            created_by=user_id,
            project_id=project_id,
            idea_id=idea_id,
        )
        responses = [GenerationResponse(**gen.model_dump()) for gen in generations]
        return GenerationsResponse(generations=responses, total_count=total_count)

    async def get_generation(self, generation_id: str) -> Optional[GenerationResponse]:
        """Return the generation as a response model, or ``None`` when the DAO finds nothing."""
        gen = await self.dao.generations.get_generation(generation_id)
        if gen is None:
            return None
        return GenerationResponse(**gen.model_dump())

    async def get_running_generations(self, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
        """Return generations with status RUNNING, optionally filtered by user and project."""
        return await self.dao.generations.get_generations(
            status=GenerationStatus.RUNNING, created_by=user_id, project_id=project_id
        )

    async def create_generation_task(
        self,
        generation_request: GenerationRequest,
        user_id: Optional[str] = None,
        generation_group_id: Optional[str] = None,
    ) -> GenerationGroupResponse:
        """
        Create ``generation_request.count`` generations that share one group ID.

        A new UUID is used as the group ID when none is supplied.
        """
        count = generation_request.count
        if generation_group_id is None:
            generation_group_id = str(uuid4())

        results = []
        for _ in range(count):
            gen_response = await self._create_single_generation(generation_request, user_id, generation_group_id)
            results.append(gen_response)
        return GenerationGroupResponse(generation_group_id=generation_group_id, generations=results)

    async def _mark_generation_failed(self, generation_id, log_message: str) -> None:
        """Best effort: reload the generation and persist FAILED; log ``log_message`` if that fails."""
        try:
            db_gen = await self.dao.generations.get_generation(generation_id)
            if db_gen is not None:
                db_gen.status = GenerationStatus.FAILED
                await self.dao.generations.update_generation(db_gen)
        except Exception:
            logger.exception(log_message)

    async def _run_generation(self, gen: Generation) -> None:
        """Background runner: wait for a semaphore slot, run the generation, mark FAILED on error."""
        logger.info(f"Generation {gen.id} entered queue (waiting for slot)...")
        try:
            async with generation_semaphore:
                logger.info(f"Starting background generation task for ID: {gen.id}")
                await self.create_generation(gen)
                logger.info(f"Background generation task finished for ID: {gen.id}")
        except Exception:
            # The generation started and crashed: mark it FAILED.
            await self._mark_generation_failed(gen.id, "Failed to mark generation as FAILED")
            logger.exception("create_generation task failed")

    async def _create_single_generation(
        self,
        generation_request: GenerationRequest,
        user_id: Optional[str] = None,
        generation_group_id: Optional[str] = None,
    ) -> GenerationResponse:
        """
        Persist one generation record and schedule its background execution.

        Returns:
            The created generation as a response model; the image work itself
            runs in a background task.

        Raises:
            Any exception raised while building or persisting the record. If
            the record was already created it is marked FAILED first.
        """
        gen_id = None
        try:
            generation_model = Generation(**generation_request.model_dump(exclude={'count'}))
            if user_id:
                generation_model.created_by = user_id
            if generation_group_id:
                generation_model.generation_group_id = generation_group_id
            # idea_id is already in model_dump; set explicitly for clarity.
            if generation_request.idea_id:
                generation_model.idea_id = generation_request.idea_id

            gen_id = await self.dao.generations.create_generation(generation_model)
            generation_model.id = gen_id

            task = asyncio.create_task(self._run_generation(generation_model))
            # Keep the task referenced until it completes.
            _background_tasks.add(task)
            task.add_done_callback(_background_tasks.discard)

            return GenerationResponse(**generation_model.model_dump())
        except Exception:
            # Nothing to mark if the record was never created.
            if gen_id is not None:
                await self._mark_generation_failed(
                    gen_id, "Failed to mark generation as FAILED in create_generation_task"
                )
            raise

    async def _collect_reference_images(self, generation: Generation) -> List[bytes]:
        """Gather reference image bytes: optional character avatar first, then image assets."""
        media_group_bytes: List[bytes] = []

        if generation.linked_character_id is not None:
            char_info = await self.dao.chars.get_character(generation.linked_character_id)
            if char_info is None:
                raise Exception(f"Character ID {generation.linked_character_id} not found")
            if generation.use_profile_image and char_info.avatar_asset_id is not None:
                avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
                if avatar_asset:
                    # The avatar may be stored in S3 like any other asset.
                    avatar_bytes = await self._load_asset_bytes(avatar_asset)
                    if avatar_bytes:
                        media_group_bytes.append(avatar_bytes)

        reference_assets: List[Asset] = await self.dao.assets.get_assets_by_ids(generation.assets_list)
        for asset in reference_assets:
            if asset.content_type != AssetContentType.IMAGE:
                continue
            img_data = await self._load_asset_bytes(asset)
            if img_data:
                media_group_bytes.append(img_data)

        return media_group_bytes

    @staticmethod
    async def _stop_progress(progress_task: asyncio.Task) -> None:
        """Cancel the progress simulation if it is still running and wait for it to stop."""
        if progress_task.done():
            return
        progress_task.cancel()
        try:
            await progress_task
        except asyncio.CancelledError:
            pass

    async def create_generation(self, generation: Generation):
        """
        Run one generation end to end.

        Steps: collect reference images, call the image model while a
        simulated progress task updates the record, upload each result to S3
        and store it as an asset, persist the generation as DONE, and
        optionally send the images to Telegram.

        Raises:
            Exception: when the linked character does not exist.
            Any exception from the image model; the generation is persisted
            as FAILED with ``failed_reason`` before it is re-raised.
        """
        start_time = datetime.now(UTC)
        logger.info(f"Processing generation {generation.id}. Character ID: {generation.linked_character_id}")

        # 2. Fetch reference assets (if any) as raw bytes for the model.
        media_group_bytes = await self._collect_reference_images(generation)

        generation_prompt = generation.prompt
        if media_group_bytes:
            generation_prompt += " \n\n[Reference Image Guidance]: Use the provided image(s) as the STRICT reference for the main character's facial features and hair, enviroment or clothes. Maintain high fidelity to the reference identity."

        logger.info(f"Final generation prompt assembled. Length: {len(generation_prompt)}. Media count: {len(media_group_bytes)}")

        # 3. Start the generation and the simulated progress updates.
        progress_task = asyncio.create_task(self._simulate_progress(generation))
        try:
            try:
                generated_bytes_list, metrics = await generate_image_task(
                    prompt=generation_prompt,
                    media_group_bytes=media_group_bytes,
                    aspect_ratio=generation.aspect_ratio,
                    quality=generation.quality,
                    gemini=self.gemini,
                )
            finally:
                # Stop progress updates before any final status is written so
                # they cannot interleave with the DONE/FAILED update.
                await self._stop_progress(progress_task)
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            generation.status = GenerationStatus.FAILED
            generation.failed_reason = str(e)
            generation.updated_at = datetime.now(UTC)
            await self.dao.generations.update_generation(generation)
            raise

        # Metrics reported by the API call.
        generation.api_execution_time_seconds = metrics.get("api_execution_time_seconds")
        generation.token_usage = metrics.get("token_usage")
        generation.input_token_usage = metrics.get("input_token_usage")
        generation.output_token_usage = metrics.get("output_token_usage")

        # 4. Store the generated images as new assets.
        from utils.image_utils import create_thumbnail

        # Each asset is kept with its bytes: the Asset row has data=None, so
        # the bytes are needed later for Telegram delivery.
        created: List[Tuple[Asset, bytes]] = []
        for img_bytes in generated_bytes_list:
            thumbnail_bytes = await asyncio.to_thread(create_thumbnail, img_bytes)

            # Save to S3
            filename = f"generated/{generation.linked_character_id}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.png"
            await self.s3_adapter.upload_file(filename, img_bytes, content_type="image/png")

            new_asset = Asset(
                name=f"Generated_{generation.linked_character_id}",
                type=AssetType.GENERATED,
                content_type=AssetContentType.IMAGE,
                linked_char_id=generation.linked_character_id,
                data=None,  # Not storing bytes in DB anymore
                minio_object_name=filename,
                minio_bucket=self.s3_adapter.bucket_name,
                thumbnail=thumbnail_bytes,
                created_by=generation.created_by,
                project_id=generation.project_id,
            )
            asset_id = await self.dao.assets.create_asset(new_asset)
            new_asset.id = str(asset_id)  # ID assigned by the database
            created.append((new_asset, img_bytes))

        created_assets: List[Asset] = [asset for asset, _ in created]

        # 5. Update the generation record with the result asset IDs.
        generation.result_list = [asset.id for asset in created_assets]
        generation.status = GenerationStatus.DONE
        generation.progress = 100
        generation.updated_at = datetime.now(UTC)
        generation.tech_prompt = generation_prompt

        end_time = datetime.now(UTC)
        generation.execution_time_seconds = (end_time - start_time).total_seconds()

        logger.info(f"DEBUG: Saving generation {generation.id}. Metrics: api_exec={generation.api_execution_time_seconds}, tokens={generation.token_usage}, in_tokens={generation.input_token_usage}, out_tokens={generation.output_token_usage}, exec={generation.execution_time_seconds}")

        await self.dao.generations.update_generation(generation)
        logger.info(f"Generation {generation.id} completed successfully. {len(created_assets)} assets created. Total Time: {generation.execution_time_seconds:.2f}s")

        # 6. Send to Telegram if telegram_id is provided (best effort).
        if generation.telegram_id and self.bot:
            try:
                for asset, img_bytes in created:
                    await self.bot.send_photo(
                        chat_id=generation.telegram_id,
                        photo=BufferedInputFile(img_bytes, filename=f"{asset.name}.jpg"),
                        caption=f"Generated from prompt: {generation.prompt[:100]}...",
                    )
                logger.info(f"Sent {len(created_assets)} assets to Telegram ID: {generation.telegram_id}")
            except Exception as e:
                logger.error(f"Failed to send assets to Telegram ID {generation.telegram_id}: {e}")

    async def _simulate_progress(self, generation: Generation):
        """
        Raise ``generation.progress`` from 0 towards 90 while the real work runs.

        Every 4 seconds a random 5-15 points are added (capped at 90) and the
        whole generation object is saved. Cancellation ends the loop silently;
        any other error is logged and swallowed.
        """
        current_progress = 0
        try:
            while current_progress < 90:
                await asyncio.sleep(4)
                # Random increment between 5 and 15
                increment = random.randint(5, 15)
                current_progress = min(current_progress + increment, 90)

                # The in-memory object is saved as a whole; a fetch-update-save
                # or partial update would avoid overwriting unrelated fields.
                generation.progress = current_progress
                await self.dao.generations.update_generation(generation)
        except asyncio.CancelledError:
            # Task cancelled, generation finished (or failed)
            pass
        except Exception as e:
            logger.error(f"Error in progress simulation: {e}")

    async def import_external_generation(self, external_gen) -> Generation:
        """
        Import a generation from an external source.

        Args:
            external_gen: ExternalGenerationRequest with generation data and image

        Returns:
            Created Generation object

        Raises:
            ValueError: when neither the URL nor the base64 payload yields bytes.
            httpx.HTTPStatusError: when the image download returns an error status.
        """
        # Imported lazily inside the function, as in the original code.
        from api.models.ExternalGenerationDTO import ExternalGenerationRequest
        from utils.image_utils import create_thumbnail

        # Validate image source
        external_gen.validate_image_source()

        logger.info(f"Importing external generation for user: {external_gen.created_by}")

        # 1. Process image (download or decode)
        image_bytes = None
        if external_gen.image_url:
            # NOTE(review): the URL is fetched as given; confirm that
            # validate_image_source() restricts it to trusted hosts.
            logger.info(f"Downloading image from URL: {external_gen.image_url}")
            async with httpx.AsyncClient() as client:
                response = await client.get(external_gen.image_url, timeout=30.0)
                response.raise_for_status()
                image_bytes = response.content
        elif external_gen.image_data:
            logger.info("Decoding base64 image data")
            image_bytes = base64.b64decode(external_gen.image_data)

        if not image_bytes:
            raise ValueError("Failed to process image data")

        # 2. Generate thumbnail
        thumbnail_bytes = await asyncio.to_thread(create_thumbnail, image_bytes)

        # 3. Save to S3
        filename = f"external/{external_gen.created_by}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.png"
        await self.s3_adapter.upload_file(filename, image_bytes, content_type="image/png")

        # 4. Create Asset
        new_asset = Asset(
            name=f"External_Generated_{external_gen.linked_character_id or 'no_char'}",
            type=AssetType.GENERATED,
            content_type=AssetContentType.IMAGE,
            linked_char_id=external_gen.linked_character_id,
            data=None,  # Not storing bytes in DB
            minio_object_name=filename,
            minio_bucket=self.s3_adapter.bucket_name,
            thumbnail=thumbnail_bytes,
            created_by=external_gen.created_by,
            project_id=external_gen.project_id,
        )
        asset_id = await self.dao.assets.create_asset(new_asset)
        new_asset.id = str(asset_id)
        logger.info(f"Created asset {asset_id} for external generation")

        # 5. Create Generation record
        generation = Generation(
            status=GenerationStatus.DONE,
            linked_character_id=external_gen.linked_character_id,
            aspect_ratio=external_gen.aspect_ratio,
            quality=external_gen.quality,
            prompt=external_gen.prompt,
            tech_prompt=external_gen.tech_prompt,
            result_list=[new_asset.id],
            result=new_asset.id,
            progress=100,
            execution_time_seconds=external_gen.execution_time_seconds,
            api_execution_time_seconds=external_gen.api_execution_time_seconds,
            token_usage=external_gen.token_usage,
            input_token_usage=external_gen.input_token_usage,
            output_token_usage=external_gen.output_token_usage,
            created_by=external_gen.created_by,
            project_id=external_gen.project_id,
            created_at=datetime.now(UTC),
            updated_at=datetime.now(UTC),
        )
        gen_id = await self.dao.generations.create_generation(generation)
        generation.id = gen_id
        logger.info(f"Created generation {gen_id} from external source")
        return generation

    async def delete_generation(self, generation_id: str) -> bool:
        """
        Soft delete generation by marking it as deleted.

        Returns:
            ``True`` when the record was found and updated; ``False`` when it
            does not exist or any error occurred (the error is logged).
        """
        try:
            generation = await self.dao.generations.get_generation(generation_id)
            if not generation:
                return False
            generation.is_deleted = True
            generation.updated_at = datetime.now(UTC)
            await self.dao.generations.update_generation(generation)
            return True
        except Exception as e:
            logger.error(f"Error deleting generation {generation_id}: {e}")
            return False

    async def cleanup_stale_generations(self):
        """
        Cancels generations that have been running for more than 1 hour.

        Errors are logged and not propagated.
        """
        try:
            count = await self.dao.generations.cancel_stale_generations(timeout_minutes=60)
            if count > 0:
                logger.info(f"Cleaned up {count} stale generations (timeout)")
        except Exception as e:
            logger.error(f"Error cleaning up stale generations: {e}")

    async def cleanup_old_data(self, days: int = 2):
        """
        Clean up old data.

        1. Soft-deletes generations older than ``days`` days.
        2. Soft-deletes the related assets and hard-deletes their files from S3.

        Both steps are delegated to the DAO; errors are logged and not propagated.
        """
        try:
            # 1. Soft-delete generations and collect their asset IDs.
            gen_count, asset_ids = await self.dao.generations.soft_delete_old_generations(days=days)
            if gen_count > 0:
                logger.info(f"Soft-deleted {gen_count} generations older than {days} days. "
                            f"Found {len(asset_ids)} associated asset IDs.")

            # 2. Soft-delete assets and hard-delete their files from S3.
            if asset_ids:
                purged = await self.dao.assets.soft_delete_and_purge_assets(asset_ids)
                logger.info(f"Purged {purged} assets (soft-deleted + S3 files removed).")
        except Exception as e:
            logger.error(f"Error during old data cleanup: {e}")

    async def get_financial_report(
        self,
        user_id: Optional[str] = None,
        project_id: Optional[str] = None,
        breakdown_by: Optional[str] = None,
    ) -> FinancialReport:
        """
        Generates a financial usage report for a specific user or project.

        'breakdown_by' can be 'created_by' or 'project_id'; any other value
        yields a report with the summary only.
        """
        summary_data = await self.dao.generations.get_usage_stats(created_by=user_id, project_id=project_id)
        summary = UsageStats(**summary_data)

        by_user = None
        by_project = None
        if breakdown_by == "created_by":
            res = await self.dao.generations.get_usage_breakdown(group_by="created_by", project_id=project_id, created_by=user_id)
            by_user = [UsageByEntity(**item) for item in res]
        elif breakdown_by == "project_id":
            res = await self.dao.generations.get_usage_breakdown(group_by="project_id", project_id=project_id, created_by=user_id)
            by_project = [UsageByEntity(**item) for item in res]

        return FinancialReport(
            summary=summary,
            by_user=by_user,
            by_project=by_project,
        )