"""Image-generation service: orchestrates Gemini calls, asset storage and delivery."""
import asyncio
import logging
import random
import base64
import time
from datetime import datetime, UTC
from typing import List, Optional, Set, Tuple, Any, Dict
from io import BytesIO

import httpx
from aiogram import Bot
from aiogram.types import BufferedInputFile

from adapters.Exception import GoogleGenerationException
from adapters.google_adapter import GoogleAdapter
from api.models.GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse
# Make sure the DAO, Asset and Generation models are imported correctly.
from models.Asset import Asset, AssetType, AssetContentType
from models.Generation import Generation, GenerationStatus
from models.enums import AspectRatios, Quality, GenType
from repos.dao import DAO
from adapters.s3_adapter import S3Adapter

logger = logging.getLogger(__name__)

# Strong references to fire-and-forget tasks. The event loop only keeps weak
# references to tasks, so an unreferenced task may be garbage-collected before
# it finishes; each task removes itself from this set when done.
_background_tasks: Set[asyncio.Task] = set()


# --- Generation helper ---
async def generate_image_task(
    prompt: str,
    media_group_bytes: List[bytes],
    aspect_ratio: AspectRatios,
    quality: Quality,
    gemini: GoogleAdapter,
) -> Tuple[List[bytes], Dict[str, Any]]:
    """Run the synchronous ``gemini.generate_image`` call in a worker thread.

    Args:
        prompt: Final prompt text sent to the model.
        media_group_bytes: Reference images passed as ``images_list``.
        aspect_ratio: Forwarded unchanged to the adapter.
        quality: Forwarded unchanged to the adapter.
        gemini: Adapter whose ``generate_image`` is executed off the event loop.

    Returns:
        A tuple ``(images, metrics)``: the raw bytes of every generated image
        and the metrics object returned by the adapter.

    Raises:
        GoogleGenerationException: Propagated unchanged from the adapter.
    """
    logger.info(f"Starting generate_image_task with prompt length: {len(prompt)}")
    # Run the blocking call in a separate thread so the event loop is not stalled.
    generated_images_io, metrics = await asyncio.to_thread(
        gemini.generate_image,
        prompt=prompt,
        images_list=media_group_bytes,
        aspect_ratio=aspect_ratio,
        quality=quality,
    )
    logger.info(f"generate_image_task completed, received {len(generated_images_io) if generated_images_io else 0} images")

    images_bytes: List[bytes] = []
    for img_io in generated_images_io or []:
        try:
            # Rewind before reading: the stream position is not guaranteed to be 0.
            img_io.seek(0)
            images_bytes.append(img_io.read())
        finally:
            # Release the buffer even if reading it failed.
            img_io.close()
    return images_bytes, metrics


class GenerationService:
    """Creates, tracks, imports and deletes image generations."""

    def __init__(self, dao: DAO, gemini: GoogleAdapter, s3_adapter: S3Adapter, bot: Optional[Bot] = None):
        """Store the collaborators; ``bot`` is optional and only used for Telegram delivery."""
        self.dao = dao
        self.gemini = gemini
        self.s3_adapter = s3_adapter
        self.bot = bot

    async def ask_prompt_assistant(self, prompt: str, assets: Optional[List[str]] = None) -> str:
        """Ask the text model to improve a user-entered prompt.

        Args:
            prompt: Prompt text entered by the user.
            assets: Optional asset IDs whose ``data`` is sent along as reference.

        Returns:
            The value returned by ``gemini.generate_text``.
        """
        future_prompt = """You are an prompt-assistant. You improving user-entered prompts for image generation. User may upload reference image too. I will provide sources prompt entered by user. Understand user needs and generate best variation of prompt. ANSWER ONLY PROMPT STRING!!! USER_ENTERED_PROMPT: """
        future_prompt += prompt

        assets_data = []
        if assets is not None:
            assets_db = await self.dao.assets.get_assets_by_ids(assets)
            # NOTE(review): only the inline ``data`` field is used here; assets stored
            # in S3 appear to have ``data=None`` - confirm whether they should be fetched.
            assets_data.extend(asset.data for asset in assets_db)

        generated_prompt = await asyncio.to_thread(self.gemini.generate_text, future_prompt, assets_data)
        logger.info(future_prompt)
        logger.info(generated_prompt)
        return generated_prompt

    async def generate_prompt_from_images(self, images: List[bytes], user_prompt: Optional[str] = None) -> str:
        """Ask the text model to describe ``images`` as a generation prompt.

        Args:
            images: Image bytes passed to the model as ``images_list``.
            user_prompt: Optional extra context appended to the instruction.

        Returns:
            The value returned by ``gemini.generate_text``.
        """
        technical_prompt = "You are a prompt engineer. Describe this image in detail to create a stable diffusion using this image as reference. "
        if user_prompt:
            technical_prompt += f"User also provided this context: {user_prompt}. "
        technical_prompt += "Provide ONLY the detailed prompt."
        return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, images_list=images)

    async def get_generations(
        self,
        character_id: Optional[str] = None,
        limit: int = 10,
        offset: int = 0,
        user_id: Optional[str] = None,
        project_id: Optional[str] = None,
    ) -> GenerationsResponse:
        """Return one page of generations together with the total matching count.

        Args:
            character_id: Optional filter forwarded to the DAO.
            limit: Page size forwarded to the DAO.
            offset: Page offset forwarded to the DAO.
            user_id: Forwarded to the DAO as ``created_by``.
            project_id: Optional filter forwarded to the DAO.

        Returns:
            A ``GenerationsResponse`` holding the page and ``total_count``.
        """
        records = await self.dao.generations.get_generations(
            character_id=character_id,
            limit=limit,
            offset=offset,
            created_by=user_id,
            project_id=project_id,
        )
        total_count = await self.dao.generations.count_generations(
            character_id=character_id,
            created_by=user_id,
            project_id=project_id,
        )
        page = [GenerationResponse(**gen.model_dump()) for gen in records]
        return GenerationsResponse(generations=page, total_count=total_count)

    async def get_generation(self, generation_id: str) -> Optional[GenerationResponse]:
        """Return the generation with ``generation_id``, or ``None`` if the DAO finds none."""
        gen = await self.dao.generations.get_generation(generation_id)
        if gen is None:
            return None
        return GenerationResponse(**gen.model_dump())

    async def get_running_generations(self, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
        """Return generations whose status is ``RUNNING``, filtered by creator and project."""
        return await self.dao.generations.get_generations(
            status=GenerationStatus.RUNNING,
            created_by=user_id,
            project_id=project_id,
        )

    async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None) -> GenerationResponse:
        """Persist a new generation and start processing it in the background.

        Args:
            generation_request: Request whose dumped fields initialise the record.
            user_id: When truthy, stored as ``created_by``.

        Returns:
            The freshly created generation; processing continues asynchronously.

        Raises:
            Exception: Any error raised while creating the record is re-raised
                after a best-effort attempt to mark an already created record FAILED.
        """
        gen_id = None
        try:
            generation_model = Generation(**generation_request.model_dump())
            if user_id:
                generation_model.created_by = user_id
            gen_id = await self.dao.generations.create_generation(generation_model)
            generation_model.id = gen_id

            # Build the response before scheduling the task: if serialisation fails,
            # no background work has been started for a record we then mark FAILED.
            response = GenerationResponse(**generation_model.model_dump())

            task = asyncio.create_task(self._run_generation_in_background(generation_model))
            _background_tasks.add(task)
            task.add_done_callback(_background_tasks.discard)
            return response
        except Exception:
            # If the record was never created there is nothing to mark.
            if gen_id is not None:
                await self._mark_generation_failed(
                    gen_id, "Failed to mark generation as FAILED in create_generation_task"
                )
            raise

    async def _run_generation_in_background(self, gen: Generation) -> None:
        """Run ``create_generation`` and record a failure instead of propagating it."""
        logger.info(f"Starting background generation task for ID: {gen.id}")
        try:
            await self.create_generation(gen)
            logger.info(f"Background generation task finished for ID: {gen.id}")
        except Exception:
            logger.exception("create_generation task failed")
            # The generation started and then crashed - mark it FAILED.
            await self._mark_generation_failed(gen.id, "Failed to mark generation as FAILED")

    async def _mark_generation_failed(self, generation_id: Any, log_message: str) -> None:
        """Best-effort: reload the record and store it with status FAILED.

        Never raises; any error (including a missing record) is logged with
        ``log_message`` and swallowed so the caller's own error handling proceeds.
        """
        try:
            db_gen = await self.dao.generations.get_generation(generation_id)
            db_gen.status = GenerationStatus.FAILED
            await self.dao.generations.update_generation(db_gen)
        except Exception:
            logger.exception(log_message)

    async def create_generation(self, generation: Generation):
        """Generate images for ``generation``, store them and update the record.

        Steps: collect reference images, call the image model while a simulated
        progress task runs, save every result as an asset, mark the generation
        DONE and, when ``telegram_id`` and a bot are available, send the images.

        Args:
            generation: Record to process; it is mutated and persisted in place.

        Raises:
            LookupError: The linked character does not exist.
            GoogleGenerationException: The model call failed (record is marked FAILED).
            Exception: Any other failure of the model call (record is marked FAILED)
                or of the storage steps that follow.
        """
        # Monotonic clock: the elapsed time must not be affected by wall-clock changes.
        started = time.monotonic()
        logger.info(f"Processing generation {generation.id}. Character ID: {generation.linked_character_id}")

        # Fetch the reference images (if any).
        media_group_bytes = await self._collect_reference_images(generation)

        generation_prompt = generation.prompt
        if media_group_bytes:
            generation_prompt += " \n\n[Reference Image Guidance]: Use the provided image(s) as the STRICT reference for the main character's facial features and hair, enviroment or clothes. Maintain high fidelity to the reference identity."
        logger.info(f"Final generation prompt assembled. Length: {len(generation_prompt)}. Media count: {len(media_group_bytes)}")

        # Start the generation together with the simulated progress updates.
        progress_task = asyncio.create_task(self._simulate_progress(generation))
        try:
            try:
                generated_bytes_list, metrics = await generate_image_task(
                    prompt=generation_prompt,
                    media_group_bytes=media_group_bytes,
                    aspect_ratio=generation.aspect_ratio,
                    quality=generation.quality,
                    gemini=self.gemini,
                )
            finally:
                # Stop progress writes first so they cannot race the final/FAILED update.
                await self._stop_progress(progress_task)

            # Copy the metrics reported by the API onto the record.
            generation.api_execution_time_seconds = metrics.get("api_execution_time_seconds")
            generation.token_usage = metrics.get("token_usage")
            generation.input_token_usage = metrics.get("input_token_usage")
            generation.output_token_usage = metrics.get("output_token_usage")
        except Exception as e:
            if not isinstance(e, GoogleGenerationException):
                logger.error(f"Generation failed: {e}")
            generation.status = GenerationStatus.FAILED
            generation.failed_reason = str(e)
            generation.updated_at = datetime.now(UTC)
            await self.dao.generations.update_generation(generation)
            raise

        # Save the generated images as new assets.
        stored = await self._store_generated_images(generation, generated_bytes_list)
        created_assets = [asset for asset, _ in stored]

        # Link the results to the generation record and finalise it.
        generation.result_list = [asset.id for asset in created_assets]
        generation.status = GenerationStatus.DONE
        generation.progress = 100
        generation.updated_at = datetime.now(UTC)
        generation.tech_prompt = generation_prompt
        generation.execution_time_seconds = time.monotonic() - started

        logger.info(f"DEBUG: Saving generation {generation.id}. Metrics: api_exec={generation.api_execution_time_seconds}, tokens={generation.token_usage}, in_tokens={generation.input_token_usage}, out_tokens={generation.output_token_usage}, exec={generation.execution_time_seconds}")
        await self.dao.generations.update_generation(generation)
        logger.info(f"Generation {generation.id} completed successfully. {len(created_assets)} assets created. Total Time: {generation.execution_time_seconds:.2f}s")

        # Send to Telegram if telegram_id is provided.
        if generation.telegram_id and self.bot:
            await self._send_to_telegram(generation, stored)

    async def _load_asset_bytes(self, asset: Asset) -> Optional[bytes]:
        """Return the asset's bytes from S3 when it has an object name, else its inline data."""
        if asset.minio_object_name:
            return await self.s3_adapter.get_file(asset.minio_object_name)
        return asset.data or None

    async def _collect_reference_images(self, generation: Generation) -> List[bytes]:
        """Gather the character avatar (if requested) and all image assets of the request."""
        images: List[bytes] = []

        if generation.linked_character_id is not None:
            char_info = await self.dao.chars.get_character(generation.linked_character_id)
            if char_info is None:
                raise LookupError(f"Character ID {generation.linked_character_id} not found")
            if generation.use_profile_image:
                avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
                if avatar_asset:
                    # Load the avatar like any other asset: it may live in S3 with no inline data.
                    avatar_bytes = await self._load_asset_bytes(avatar_asset)
                    if avatar_bytes:
                        images.append(avatar_bytes)

        reference_assets: List[Asset] = await self.dao.assets.get_assets_by_ids(generation.assets_list)
        # Extract the raw bytes of every image asset for the model call.
        for asset in reference_assets:
            if asset.content_type != AssetContentType.IMAGE:
                continue
            img_data = await self._load_asset_bytes(asset)
            if img_data:
                images.append(img_data)
        return images

    @staticmethod
    async def _stop_progress(progress_task: asyncio.Task) -> None:
        """Cancel the progress task if it is still running and wait for it to finish."""
        if progress_task.done():
            return
        progress_task.cancel()
        try:
            await progress_task
        except asyncio.CancelledError:
            pass

    async def _store_generated_images(self, generation: Generation, images: List[bytes]) -> List[Tuple[Asset, bytes]]:
        """Upload each image to S3, create its asset row and return ``(asset, bytes)`` pairs.

        The bytes are returned next to the asset because the asset itself is
        stored with ``data=None`` and therefore cannot be used to re-send the image.
        """
        from utils.image_utils import create_thumbnail

        stored: List[Tuple[Asset, bytes]] = []
        for img_bytes in images:
            # Thumbnail creation runs in a worker thread.
            thumbnail_bytes = await asyncio.to_thread(create_thumbnail, img_bytes)

            # Save to S3.
            filename = f"generated/{generation.linked_character_id}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.png"
            await self.s3_adapter.upload_file(filename, img_bytes, content_type="image/png")

            new_asset = Asset(
                name=f"Generated_{generation.linked_character_id}",
                type=AssetType.GENERATED,
                content_type=AssetContentType.IMAGE,
                linked_char_id=generation.linked_character_id,
                data=None,  # Not storing bytes in DB anymore
                minio_object_name=filename,
                minio_bucket=self.s3_adapter.bucket_name,
                thumbnail=thumbnail_bytes,
                created_by=generation.created_by,
                project_id=generation.project_id,
            )
            # Persist and keep the ID assigned by the database.
            asset_id = await self.dao.assets.create_asset(new_asset)
            new_asset.id = str(asset_id)
            stored.append((new_asset, img_bytes))
        return stored

    async def _send_to_telegram(self, generation: Generation, stored: List[Tuple[Asset, bytes]]) -> None:
        """Best-effort delivery of the generated images to ``generation.telegram_id``.

        Failures are logged and swallowed: the generation itself already succeeded.
        """
        try:
            for asset, img_bytes in stored:
                # Use the in-memory bytes; ``asset.data`` is None for S3-backed assets.
                await self.bot.send_photo(
                    chat_id=generation.telegram_id,
                    photo=BufferedInputFile(img_bytes, filename=f"{asset.name}.jpg"),
                    caption=f"Generated from prompt: {generation.prompt[:100]}...",
                )
            logger.info(f"Sent {len(stored)} assets to Telegram ID: {generation.telegram_id}")
        except Exception as e:
            logger.error(f"Failed to send assets to Telegram ID {generation.telegram_id}: {e}")

    async def _simulate_progress(self, generation: Generation):
        """Raise ``generation.progress`` towards 90 and persist every step.

        Every 4 seconds the progress grows by a random 5-15 points (capped at
        90), so reaching 90 takes between 6 and 18 steps. Cancellation ends the
        loop silently; any other error is logged and ends the loop.
        """
        current_progress = 0
        try:
            while current_progress < 90:
                await asyncio.sleep(4)
                # Random increment between 5 and 15.
                increment = random.randint(5, 15)
                current_progress = min(current_progress + increment, 90)
                # The shared in-memory object is saved as a whole; a fetch-update-save
                # cycle or a partial update would be safer if the DAO supports it.
                generation.progress = current_progress
                await self.dao.generations.update_generation(generation)
        except asyncio.CancelledError:
            # Task cancelled: the generation finished (or failed).
            pass
        except Exception as e:
            logger.error(f"Error in progress simulation: {e}")

    async def import_external_generation(self, external_gen) -> Generation:
        """Import a generation produced by an external source.

        Args:
            external_gen: Request object (the original code imported
                ``ExternalGenerationRequest`` for it) carrying the generation
                data and either ``image_url`` or base64 ``image_data``.

        Returns:
            The created ``Generation`` with status DONE and its new ID.

        Raises:
            ValueError: No image bytes could be obtained.
            httpx.HTTPStatusError: The image download returned an error status.
        """
        # Validate image source.
        external_gen.validate_image_source()
        logger.info(f"Importing external generation for user: {external_gen.created_by}")

        # 1. Process image (download or decode).
        image_bytes = None
        if external_gen.image_url:
            # NOTE(review): the URL is fetched as given; if it can come from untrusted
            # callers, consider restricting the allowed hosts.
            logger.info(f"Downloading image from URL: {external_gen.image_url}")
            async with httpx.AsyncClient() as client:
                response = await client.get(external_gen.image_url, timeout=30.0)
                response.raise_for_status()
                image_bytes = response.content
        elif external_gen.image_data:
            logger.info("Decoding base64 image data")
            image_bytes = base64.b64decode(external_gen.image_data)

        if not image_bytes:
            raise ValueError("Failed to process image data")

        # 2. Generate thumbnail.
        from utils.image_utils import create_thumbnail
        thumbnail_bytes = await asyncio.to_thread(create_thumbnail, image_bytes)

        # 3. Save to S3.
        filename = f"external/{external_gen.created_by}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.png"
        await self.s3_adapter.upload_file(filename, image_bytes, content_type="image/png")

        # 4. Create Asset.
        new_asset = Asset(
            name=f"External_Generated_{external_gen.linked_character_id or 'no_char'}",
            type=AssetType.GENERATED,
            content_type=AssetContentType.IMAGE,
            linked_char_id=external_gen.linked_character_id,
            data=None,  # Not storing bytes in DB
            minio_object_name=filename,
            minio_bucket=self.s3_adapter.bucket_name,
            thumbnail=thumbnail_bytes,
            created_by=external_gen.created_by,
            project_id=external_gen.project_id,
        )
        asset_id = await self.dao.assets.create_asset(new_asset)
        new_asset.id = str(asset_id)
        logger.info(f"Created asset {asset_id} for external generation")

        # 5. Create Generation record.
        generation = Generation(
            status=GenerationStatus.DONE,
            linked_character_id=external_gen.linked_character_id,
            aspect_ratio=external_gen.aspect_ratio,
            quality=external_gen.quality,
            prompt=external_gen.prompt,
            tech_prompt=external_gen.tech_prompt,
            result_list=[new_asset.id],
            result=new_asset.id,
            progress=100,
            execution_time_seconds=external_gen.execution_time_seconds,
            api_execution_time_seconds=external_gen.api_execution_time_seconds,
            token_usage=external_gen.token_usage,
            input_token_usage=external_gen.input_token_usage,
            output_token_usage=external_gen.output_token_usage,
            created_by=external_gen.created_by,
            project_id=external_gen.project_id,
            created_at=datetime.now(UTC),
            updated_at=datetime.now(UTC),
        )
        gen_id = await self.dao.generations.create_generation(generation)
        generation.id = gen_id
        logger.info(f"Created generation {gen_id} from external source")
        return generation

    async def delete_generation(self, generation_id: str) -> bool:
        """Soft-delete a generation by setting ``is_deleted``.

        Returns:
            ``True`` when the record was updated; ``False`` when it does not
            exist or any error occurred (the error is logged, not raised).
        """
        try:
            generation = await self.dao.generations.get_generation(generation_id)
            if not generation:
                return False
            generation.is_deleted = True
            generation.updated_at = datetime.now(UTC)
            await self.dao.generations.update_generation(generation)
            return True
        except Exception as e:
            # ``exception`` keeps the traceback that ``error`` would drop.
            logger.exception(f"Error deleting generation {generation_id}: {e}")
            return False