import asyncio
import base64
import logging
import random
from datetime import datetime, UTC
from typing import List, Optional, Tuple, Any, Dict
from uuid import uuid4

import httpx
from aiogram import Bot
from aiogram.types import BufferedInputFile

from adapters.Exception import GoogleGenerationException
from adapters.google_adapter import GoogleAdapter
from adapters.s3_adapter import S3Adapter
from api.models import (
    FinancialReport, UsageStats, UsageByEntity,
    GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse
)
from models.Asset import Asset, AssetType, AssetContentType
from models.Generation import Generation, GenerationStatus
from models.enums import AspectRatios, Quality
from repos.dao import DAO
from utils.image_utils import create_thumbnail

logger = logging.getLogger(__name__)

# Limit concurrent generations to 4
generation_semaphore = asyncio.Semaphore(4)

# Strong references to fire-and-forget background tasks. The event loop keeps
# only weak references to tasks, so a task nobody references can be garbage
# collected before it finishes. Each task removes itself when done.
_background_tasks: set[asyncio.Task] = set()


async def generate_image_task(
        prompt: str,
        media_group_bytes: List[bytes],
        aspect_ratio: AspectRatios,
        quality: Quality,
        gemini: GoogleAdapter,
) -> Tuple[List[bytes], Dict[str, Any]]:
    """
    Wrapper for calling Gemini's synchronous method in a separate thread.

    ``gemini.generate_image`` runs via ``asyncio.to_thread`` so the event loop
    is not blocked. Each returned file-like object is rewound, read into bytes
    and closed.

    Args:
        prompt: Final prompt sent to the model.
        media_group_bytes: Reference images passed as ``images_list``.
        aspect_ratio: Requested aspect ratio.
        quality: Requested quality.
        gemini: Adapter whose ``generate_image`` returns ``(images_io, metrics)``.

    Returns:
        ``(images_bytes, metrics)``. ``images_bytes`` is empty if the adapter
        returned no images.

    Raises:
        Anything raised by ``gemini.generate_image`` (for example
        ``GoogleGenerationException``) propagates unchanged.
    """
    logger.info(f"Starting generate_image_task with prompt length: {len(prompt)}")
    generated_images_io, metrics = await asyncio.to_thread(
        gemini.generate_image,
        prompt=prompt,
        images_list=media_group_bytes,
        aspect_ratio=aspect_ratio,
        quality=quality,
    )
    logger.info(f"generate_image_task completed, received {len(generated_images_io) if generated_images_io else 0} images")

    images_bytes: List[bytes] = []
    for img_io in generated_images_io or []:
        try:
            img_io.seek(0)
            images_bytes.append(img_io.read())
        finally:
            # Close the buffer even if reading fails.
            img_io.close()
    return images_bytes, metrics


class GenerationService:
    """Orchestrates image generations: persistence, Gemini calls, S3 storage and notifications."""

    def __init__(self, dao: DAO, gemini: GoogleAdapter, s3_adapter: S3Adapter, bot: Optional[Bot] = None):
        self.dao = dao
        self.gemini = gemini
        self.s3_adapter = s3_adapter
        self.bot = bot

    # --- Public API ---

    async def ask_prompt_assistant(self, prompt: str, assets: list[str] | None = None) -> str:
        """Ask Gemini to rewrite a user-entered image prompt.

        Args:
            prompt: The prompt as typed by the user.
            assets: Optional asset ids. Their image bytes are sent as references.

        Returns:
            The text returned by ``gemini.generate_text``.
        """
        future_prompt = (
            "You are an prompt-assistant. You improving user-entered prompts for image generation. "
            "User may upload reference image too. I will provide sources prompt entered by user. "
            "Understand user needs and generate best variation of prompt. ANSWER ONLY PROMPT STRING!!! "
            f"USER_ENTERED_PROMPT: {prompt}"
        )

        assets_data: List[bytes] = []
        if assets:
            assets_db = await self.dao.assets.get_assets_by_ids(assets)
            for asset in assets_db:
                # Generated assets keep their bytes in S3 (``data`` is None), so
                # resolve them the same way generation inputs are resolved.
                data = await self._get_asset_data_bytes(asset)
                if data:
                    assets_data.append(data)

        generated_prompt = await asyncio.to_thread(self.gemini.generate_text, future_prompt, assets_data)
        logger.info(f"Prompt Assistant: {generated_prompt}")
        return generated_prompt

    async def generate_prompt_from_images(self, images: List[bytes], user_prompt: Optional[str] = None) -> str:
        """Ask Gemini to describe ``images`` as a detailed generation prompt.

        Args:
            images: Raw image bytes used as references.
            user_prompt: Optional extra context appended to the instruction.

        Returns:
            The text returned by ``gemini.generate_text``.
        """
        technical_prompt = "You are a prompt engineer. Describe this image in detail to create a stable diffusion using this image as reference. "
        if user_prompt:
            technical_prompt += f"User also provided this context: {user_prompt}. "
        technical_prompt += "Provide ONLY the detailed prompt."

        return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, images_list=images)

    async def get_generations(self, **kwargs) -> GenerationsResponse:
        """List generations plus a total count.

        ``current_user_id`` is removed from ``kwargs`` and used only to compute
        ``is_liked``. The remaining kwargs go to ``dao.generations.get_generations``.
        The count uses only the character, creator, project, idea and liked-by filters.
        """
        current_user_id = kwargs.pop('current_user_id', None)
        generations = await self.dao.generations.get_generations(**kwargs)
        total_count = await self.dao.generations.count_generations(
            character_id=kwargs.get('character_id'),
            created_by=kwargs.get('created_by'),
            project_id=kwargs.get('project_id'),
            idea_id=kwargs.get('idea_id'),
            only_liked_by=kwargs.get('only_liked_by')
        )
        return GenerationsResponse(
            generations=[self._map_to_response(gen, current_user_id) for gen in generations],
            total_count=total_count
        )

    async def get_generation(self, generation_id: str, current_user_id: Optional[str] = None) -> Optional[GenerationResponse]:
        """Return one generation as a response model, or ``None`` if not found."""
        gen = await self.dao.generations.get_generation(generation_id)
        return self._map_to_response(gen, current_user_id) if gen else None

    async def toggle_like(self, generation_id: str, user_id: str) -> bool | None:
        """Delegate like toggling to the DAO and return its result unchanged."""
        return await self.dao.generations.toggle_like(generation_id, user_id)

    async def get_generations_by_group(self, group_id: str, current_user_id: Optional[str] = None) -> GenerationGroupResponse:
        """Return every generation that shares ``group_id``."""
        generations = await self.dao.generations.get_generations_by_group(group_id)
        return GenerationGroupResponse(
            generation_group_id=group_id,
            generations=[self._map_to_response(gen, current_user_id) for gen in generations]
        )

    def _map_to_response(self, gen: Generation, current_user_id: Optional[str] = None) -> GenerationResponse:
        """Convert a ``Generation`` to a response and fill in like count / ``is_liked``."""
        res = GenerationResponse(**gen.model_dump())
        res.likes_count = len(gen.liked_by) if gen.liked_by else 0
        res.is_liked = current_user_id in gen.liked_by if current_user_id and gen.liked_by else False
        return res

    async def get_running_generations(self, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
        """Return generations in RUNNING status, optionally filtered by creator/project."""
        return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING, created_by=user_id, project_id=project_id)

    async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationGroupResponse:
        """Create ``generation_request.count`` generation records and schedule each one.

        A new group id is generated when ``generation_group_id`` is ``None``.
        Processing happens in background tasks; this returns once the records
        are stored.
        """
        if generation_group_id is None:
            generation_group_id = str(uuid4())

        results = []
        for _ in range(generation_request.count):
            gen_response = await self._create_single_generation(generation_request, user_id, generation_group_id)
            results.append(gen_response)

        return GenerationGroupResponse(generation_group_id=generation_group_id, generations=results)

    async def create_generation(self, generation: Generation):
        """Run one generation end to end.

        Steps: build inputs, call Gemini, store the images, finalize the record,
        then optionally notify Telegram. Exceptions propagate to the caller.
        """
        start_time = datetime.now(UTC)
        logger.info(f"Processing generation {generation.id}. Character ID: {generation.linked_character_id}")

        # 1. Prepare input
        media_group_bytes, generation_prompt = await self._prepare_generation_input(generation)

        # 2. Run generation with progress simulation
        progress_task = asyncio.create_task(self._simulate_progress(generation))
        try:
            generated_bytes_list, metrics = await generate_image_task(
                prompt=generation_prompt,
                media_group_bytes=media_group_bytes,
                aspect_ratio=generation.aspect_ratio,
                quality=generation.quality,
                gemini=self.gemini
            )
            self._update_generation_metrics(generation, metrics)

            # 3. Process results
            created_assets = await self._process_generated_images(generation, generated_bytes_list)
        finally:
            # Stop the simulated progress BEFORE finalizing. A later tick would
            # otherwise overwrite progress=100 on the finished record.
            await self._stop_progress(progress_task)

        # 4. Finalize generation record
        await self._finalize_generation(generation, created_assets, generation_prompt, start_time)

        # 5. Notify
        if generation.telegram_id and self.bot:
            await self._notify_telegram(generation, created_assets)

    async def import_external_generation(self, external_gen) -> Generation:
        """Store an externally produced image as a DONE generation.

        The image is fetched from ``image_url`` or decoded from base64
        ``image_data``, uploaded as an asset, and linked to a new record.
        """
        external_gen.validate_image_source()
        logger.info(f"Importing external generation for user: {external_gen.created_by}")

        image_bytes = await self._fetch_external_image(external_gen)

        # Reuse internal processing logic
        new_asset = await self._save_asset(
            image_bytes=image_bytes,
            name=f"External_Generated_{external_gen.linked_character_id or 'no_char'}",
            created_by=external_gen.created_by,
            project_id=external_gen.project_id,
            linked_char_id=external_gen.linked_character_id,
            folder="external"
        )

        generation = Generation(
            status=GenerationStatus.DONE,
            linked_character_id=external_gen.linked_character_id,
            aspect_ratio=external_gen.aspect_ratio,
            quality=external_gen.quality,
            prompt=external_gen.prompt,
            tech_prompt=external_gen.tech_prompt,
            result_list=[new_asset.id],
            result=new_asset.id,
            progress=100,
            execution_time_seconds=external_gen.execution_time_seconds,
            api_execution_time_seconds=external_gen.api_execution_time_seconds,
            token_usage=external_gen.token_usage,
            input_token_usage=external_gen.input_token_usage,
            output_token_usage=external_gen.output_token_usage,
            created_by=external_gen.created_by,
            project_id=external_gen.project_id
        )

        gen_id = await self.dao.generations.create_generation(generation)
        generation.id = gen_id
        return generation

    async def delete_generation(self, generation_id: str) -> bool:
        """Soft-delete a generation (sets ``is_deleted``).

        Returns ``False`` if it is not found or on any error; errors are logged.
        """
        try:
            generation = await self.dao.generations.get_generation(generation_id)
            if not generation:
                return False
            generation.is_deleted = True
            generation.updated_at = datetime.now(UTC)
            await self.dao.generations.update_generation(generation)
            return True
        except Exception as e:
            logger.error(f"Error deleting generation {generation_id}: {e}")
            return False

    async def cleanup_stale_generations(self):
        """Call the DAO's ``cancel_stale_generations`` with a 5-minute timeout.

        Best-effort: errors are logged, not raised.
        """
        try:
            count = await self.dao.generations.cancel_stale_generations(timeout_minutes=5)
            if count > 0:
                logger.info(f"Cleaned up {count} stale generations")
        except Exception as e:
            logger.error(f"Error cleaning up stale generations: {e}")

    async def cleanup_old_data(self, days: int = 30):
        """Soft-delete generations older than ``days`` and purge their assets.

        Best-effort: errors are logged, not raised.
        """
        try:
            gen_count, asset_ids = await self.dao.generations.soft_delete_old_generations(days=days)
            if gen_count > 0:
                logger.info(f"Soft-deleted {gen_count} generations. Purging {len(asset_ids)} assets.")
                if asset_ids:
                    await self.dao.assets.soft_delete_and_purge_assets(asset_ids)
        except Exception as e:
            logger.error(f"Error during old data cleanup: {e}")

    async def get_financial_report(self, user_id: Optional[str] = None, project_id: Optional[str] = None, breakdown_by: Optional[str] = None) -> FinancialReport:
        """Build a usage summary, optionally broken down.

        ``breakdown_by`` may be ``"created_by"`` or ``"project_id"``; any other
        value yields no breakdown.
        """
        summary_data = await self.dao.generations.get_usage_stats(created_by=user_id, project_id=project_id)
        summary = UsageStats(**summary_data)

        by_user, by_project = None, None
        if breakdown_by == "created_by":
            res = await self.dao.generations.get_usage_breakdown(group_by="created_by", project_id=project_id, created_by=user_id)
            by_user = [UsageByEntity(**item) for item in res]
        elif breakdown_by == "project_id":
            res = await self.dao.generations.get_usage_breakdown(group_by="project_id", project_id=project_id, created_by=user_id)
            by_project = [UsageByEntity(**item) for item in res]

        return FinancialReport(summary=summary, by_user=by_user, by_project=by_project)

    # --- Private Helpers ---

    async def _create_single_generation(self, generation_request: GenerationRequest, user_id: Optional[str], generation_group_id: str) -> GenerationResponse:
        """Persist one generation record and schedule its background runner."""
        try:
            gen_model = Generation(**generation_request.model_dump(exclude={'count'}))
            gen_model.created_by = user_id
            gen_model.generation_group_id = generation_group_id

            gen_id = await self.dao.generations.create_generation(gen_model)
            gen_model.id = gen_id

            # Keep a strong reference so the task cannot be garbage collected
            # mid-run; the done-callback removes it from the set.
            task = asyncio.create_task(self._queued_generation_runner(gen_model))
            _background_tasks.add(task)
            task.add_done_callback(_background_tasks.discard)
            return GenerationResponse(**gen_model.model_dump())
        except Exception:
            logger.exception("Failed to initiate single generation")
            raise

    async def _queued_generation_runner(self, gen: Generation):
        """Wait for a concurrency slot, run the generation, and record failures."""
        logger.info(f"Generation {gen.id} waiting for slot...")
        try:
            async with generation_semaphore:
                await self.create_generation(gen)
        except Exception as e:
            # Log first so the original traceback survives even if the
            # failure write below also fails.
            logger.exception(f"Background generation task failed for ID: {gen.id}")
            try:
                await self._handle_generation_failure(gen, e)
            except Exception:
                logger.exception(f"Failed to mark generation {gen.id} as failed")

    async def _prepare_generation_input(self, generation: Generation) -> Tuple[List[bytes], str]:
        """Collect reference image bytes and build the final prompt.

        Images come from the character avatar (if enabled), ``assets_list`` and
        the environment's assets. A reference-guidance suffix is appended to
        the prompt when any image is present.

        Raises:
            ValueError: if ``linked_character_id`` is set but the character is missing.
        """
        media_group_bytes: List[bytes] = []
        prompt = generation.prompt

        # 1. Character Avatar
        if generation.linked_character_id:
            char_info = await self.dao.chars.get_character(generation.linked_character_id)
            if not char_info:
                raise ValueError(f"Character {generation.linked_character_id} not found")

            if generation.use_profile_image and char_info.avatar_asset_id:
                avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
                if avatar_asset:
                    data = await self._get_asset_data_bytes(avatar_asset)
                    if data:
                        media_group_bytes.append(data)

        # 2. Reference Assets
        if generation.assets_list:
            assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
            for asset in assets:
                data = await self._get_asset_data_bytes(asset)
                if data:
                    media_group_bytes.append(data)

        # 3. Environment Assets
        if generation.environment_id:
            env = await self.dao.environments.get_env(generation.environment_id)
            if env and env.asset_ids:
                env_assets = await self.dao.assets.get_assets_by_ids(env.asset_ids)
                for asset in env_assets:
                    data = await self._get_asset_data_bytes(asset)
                    if data:
                        media_group_bytes.append(data)

        if media_group_bytes:
            prompt += (
                " \n\n[Reference Image Guidance]: Use the provided image(s) as the STRICT reference for the main "
                "character's facial features and hair, environment or clothes. Maintain high fidelity to the reference identity."
            )

        return media_group_bytes, prompt

    async def _get_asset_data_bytes(self, asset: Asset) -> Optional[bytes]:
        """Return an image asset's bytes (S3 first, then inline ``data``).

        Returns ``None`` for non-image assets.
        """
        if asset.content_type != AssetContentType.IMAGE:
            return None
        if asset.minio_object_name:
            return await self.s3_adapter.get_file(asset.minio_object_name)
        return asset.data

    def _update_generation_metrics(self, generation: Generation, metrics: Dict[str, Any]):
        """Copy timing and token metrics from the adapter's dict onto ``generation``."""
        generation.api_execution_time_seconds = metrics.get("api_execution_time_seconds")
        generation.token_usage = metrics.get("token_usage")
        generation.input_token_usage = metrics.get("input_token_usage")
        generation.output_token_usage = metrics.get("output_token_usage")

    async def _handle_generation_failure(self, generation: Generation, error: Optional[Exception]):
        """Mark ``generation`` FAILED and persist it."""
        logger.error(f"Generation {generation.id} failed: {error}")
        generation.status = GenerationStatus.FAILED
        # Don't overwrite if reason is already set, unless a new error is provided
        if error:
            generation.failed_reason = str(error)
        elif not generation.failed_reason:
            generation.failed_reason = "Unknown error"

        generation.updated_at = datetime.now(UTC)
        await self.dao.generations.update_generation(generation)

    async def _process_generated_images(self, generation: Generation, bytes_list: List[bytes]) -> List[Asset]:
        """Upload each generated image and create its asset record, in order."""
        created_assets = []
        for img_bytes in bytes_list:
            asset = await self._save_asset(
                image_bytes=img_bytes,
                name=f"Generated_{generation.linked_character_id}",
                created_by=generation.created_by,
                project_id=generation.project_id,
                linked_char_id=generation.linked_character_id,
                folder="generated"
            )
            created_assets.append(asset)
        return created_assets

    async def _save_asset(self, image_bytes: bytes, name: str, created_by: str, project_id: str, linked_char_id: str, folder: str) -> Asset:
        """Upload a PNG to S3, build a thumbnail, and create a GENERATED image asset.

        The object key is ``{folder}/{linked_char_id}/{timestamp}_{uuid}.png``.
        """
        thumbnail_bytes = await asyncio.to_thread(create_thumbnail, image_bytes)
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        # A uuid (not a 4-digit random number) keeps keys unique when several
        # images are saved in the same second for the same character.
        filename = f"{folder}/{linked_char_id}/{timestamp}_{uuid4().hex}.png"

        await self.s3_adapter.upload_file(filename, image_bytes, content_type="image/png")

        new_asset = Asset(
            name=name,
            type=AssetType.GENERATED,
            content_type=AssetContentType.IMAGE,
            linked_char_id=linked_char_id,
            data=None,
            minio_object_name=filename,
            minio_bucket=self.s3_adapter.bucket_name,
            thumbnail=thumbnail_bytes,
            created_by=created_by,
            project_id=project_id
        )
        asset_id = await self.dao.assets.create_asset(new_asset)
        new_asset.id = str(asset_id)
        return new_asset

    async def _finalize_generation(self, generation: Generation, assets: List[Asset], tech_prompt: str, start_time: datetime):
        """Mark ``generation`` DONE with its asset ids, tech prompt and elapsed time."""
        generation.result_list = [a.id for a in assets]
        generation.status = GenerationStatus.DONE
        generation.progress = 100
        generation.updated_at = datetime.now(UTC)
        generation.tech_prompt = tech_prompt
        # Measure in start_time's own tz so naive and aware inputs both work.
        generation.execution_time_seconds = (datetime.now(start_time.tzinfo) - start_time).total_seconds()

        await self.dao.generations.update_generation(generation)
        logger.info(f"Generation {generation.id} finalized. Time: {generation.execution_time_seconds:.2f}s")

    async def _notify_telegram(self, generation: Generation, assets: List[Asset]):
        """Send each generated image to ``generation.telegram_id``.

        Best-effort: a failure for one asset is logged and the rest are still sent.
        """
        for asset in assets:
            try:
                # Need to get data for telegram if it's not in Asset object
                if asset.minio_object_name:
                    img_data = await self.s3_adapter.get_file(asset.minio_object_name)
                else:
                    img_data = asset.data
                if not img_data:
                    continue
                await self.bot.send_photo(
                    chat_id=generation.telegram_id,
                    photo=BufferedInputFile(img_data, filename=f"{asset.name}.png"),
                    caption=f"Generated from: {generation.prompt[:100]}..."
                )
            except Exception as e:
                logger.error(f"Failed to send to Telegram: {e}")

    async def _simulate_progress(self, generation: Generation):
        """Raise ``generation.progress`` by 5-15 every 4s, capped at 90, until cancelled.

        Progress writes are best-effort: a failed write is logged and the loop
        continues, so it can never fail the generation.
        """
        current_progress = 0
        try:
            while current_progress < 90:
                await asyncio.sleep(4)
                current_progress = min(current_progress + random.randint(5, 15), 90)
                generation.progress = current_progress
                try:
                    await self.dao.generations.update_generation(generation)
                except Exception as e:
                    logger.warning(f"Failed to persist progress for generation {generation.id}: {e}")
        except asyncio.CancelledError:
            pass

    @staticmethod
    async def _stop_progress(progress_task: asyncio.Task) -> None:
        """Cancel the progress task and wait for it to end without propagating its errors.

        Cancellation is re-raised only when the *calling* task is itself being cancelled.
        """
        if not progress_task.done():
            progress_task.cancel()
        try:
            await progress_task
        except asyncio.CancelledError:
            current = asyncio.current_task()
            if current is not None and current.cancelling():
                raise
        except Exception:
            logger.exception("Progress simulation task failed")

    async def _fetch_external_image(self, external_gen) -> bytes:
        """Return image bytes from ``image_url`` (HTTP GET, 30s timeout) or base64 ``image_data``.

        Raises:
            httpx.HTTPStatusError: on a non-success HTTP response.
            ValueError: if neither source is set.
        """
        if external_gen.image_url:
            async with httpx.AsyncClient() as client:
                response = await client.get(external_gen.image_url, timeout=30.0)
                response.raise_for_status()
                return response.content
        elif external_gen.image_data:
            return base64.b64decode(external_gen.image_data)

        raise ValueError("No image source provided")