Files
ai-char-bot/api/service/generation_service.py
2026-02-27 14:33:37 +03:00

442 lines
20 KiB
Python

import asyncio
import base64
import logging
import random
from datetime import datetime, UTC
from typing import List, Optional, Tuple, Any, Dict
from uuid import uuid4
import httpx
from aiogram import Bot
from aiogram.types import BufferedInputFile
from adapters.Exception import GoogleGenerationException
from adapters.google_adapter import GoogleAdapter
from adapters.s3_adapter import S3Adapter
from api.models import (
FinancialReport, UsageStats, UsageByEntity,
GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse
)
from models.Asset import Asset, AssetType, AssetContentType
from models.Generation import Generation, GenerationStatus
from models.enums import AspectRatios, Quality
from repos.dao import DAO
from utils.image_utils import create_thumbnail
logger = logging.getLogger(__name__)
# Process-wide cap of 4 concurrent image generations; acquired by
# GenerationService._queued_generation_runner before each Gemini call.
generation_semaphore = asyncio.Semaphore(4)
async def generate_image_task(
    prompt: str,
    media_group_bytes: List[bytes],
    aspect_ratio: AspectRatios,
    quality: Quality,
    gemini: GoogleAdapter,
) -> Tuple[List[bytes], Dict[str, Any]]:
    """
    Run the synchronous Gemini image-generation call in a worker thread.

    Args:
        prompt: Full text prompt sent to the model.
        media_group_bytes: Reference images (raw bytes) forwarded to the model.
        aspect_ratio: Requested output aspect ratio.
        quality: Requested output quality tier.
        gemini: Adapter wrapping the Gemini client (its ``generate_image``
            is synchronous, hence the ``asyncio.to_thread`` offload).

    Returns:
        Tuple of (list of generated images as raw bytes, metrics dict as
        produced by the adapter — e.g. token usage / API timing keys).

    Raises:
        GoogleGenerationException: propagated unchanged from the adapter.
    """
    # The previous version had `except GoogleGenerationException: raise`,
    # which is a no-op — try/finally alone gives identical propagation.
    try:
        logger.info("Starting generate_image_task with prompt length: %d", len(prompt))
        generated_images_io, metrics = await asyncio.to_thread(
            gemini.generate_image,
            prompt=prompt,
            images_list=media_group_bytes,
            aspect_ratio=aspect_ratio,
            quality=quality,
        )
        logger.info(
            "generate_image_task completed, received %d images",
            len(generated_images_io) if generated_images_io else 0,
        )
    finally:
        # Drop the local reference to the (potentially large) input images
        # as early as possible so they can be garbage-collected.
        del media_group_bytes
    # Drain each file-like result into plain bytes, then close the buffer
    # so its memory is released promptly.
    images_bytes: List[bytes] = []
    if generated_images_io:
        for img_io in generated_images_io:
            img_io.seek(0)
            images_bytes.append(img_io.read())
            img_io.close()
        del generated_images_io
    return images_bytes, metrics
class GenerationService:
    """
    Service layer for image generations.

    Coordinates the full lifecycle: queuing requests (bounded by the
    module-level ``generation_semaphore``), calling the Gemini adapter,
    storing results in S3 and the asset repository, updating generation
    records through the DAO, and optionally delivering results to users
    over Telegram.
    """

    def __init__(self, dao: DAO, gemini: GoogleAdapter, s3_adapter: S3Adapter, bot: Optional[Bot] = None):
        self.dao = dao  # repositories: generations, assets, chars, environments
        self.gemini = gemini  # sync Gemini adapter; always invoked via asyncio.to_thread
        self.s3_adapter = s3_adapter  # object storage for generated images
        self.bot = bot  # optional aiogram bot for result delivery

    # --- Public API ---
    async def ask_prompt_assistant(self, prompt: str, assets: list[str] | None = None) -> str:
        """
        Improve a user-entered prompt via Gemini, optionally supplying
        reference images (looked up by asset id) as extra context.

        Returns the model's improved prompt string verbatim.
        """
        future_prompt = (
            "You are an prompt-assistant. You improving user-entered prompts for image generation. "
            "User may upload reference image too. I will provide sources prompt entered by user. "
            "Understand user needs and generate best variation of prompt. ANSWER ONLY PROMPT STRING!!! "
            f"USER_ENTERED_PROMPT: {prompt}"
        )
        assets_data = []
        if assets:
            assets_db = await self.dao.assets.get_assets_by_ids(assets)
            # Only inline-stored bytes (asset.data) are used here; assets
            # without inline data are skipped.
            assets_data.extend(asset.data for asset in assets_db if asset.data)
        generated_prompt = await asyncio.to_thread(self.gemini.generate_text, future_prompt, assets_data)
        logger.info(f"Prompt Assistant: {generated_prompt}")
        return generated_prompt

    async def generate_prompt_from_images(self, images: List[bytes], user_prompt: Optional[str] = None) -> str:
        """Ask Gemini to describe the given image(s) as a detailed generation prompt."""
        technical_prompt = "You are a prompt engineer. Describe this image in detail to create a stable diffusion using this image as reference. "
        if user_prompt:
            technical_prompt += f"User also provided this context: {user_prompt}. "
        technical_prompt += "Provide ONLY the detailed prompt."
        return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, images_list=images)

    async def get_generations(self, **kwargs) -> GenerationsResponse:
        """
        List generations matching the given filters, with a total count.

        ``current_user_id`` is popped from kwargs and used only to compute
        the per-user ``is_liked`` flag; all remaining kwargs are forwarded
        to the DAO listing query.
        """
        current_user_id = kwargs.pop('current_user_id', None)
        generations = await self.dao.generations.get_generations(**kwargs)
        # The count query takes only this filter subset (no pagination args).
        total_count = await self.dao.generations.count_generations(
            character_id=kwargs.get('character_id'),
            created_by=kwargs.get('created_by'),
            project_id=kwargs.get('project_id'),
            idea_id=kwargs.get('idea_id'),
            only_liked_by=kwargs.get('only_liked_by')
        )
        return GenerationsResponse(
            generations=[self._map_to_response(gen, current_user_id) for gen in generations],
            total_count=total_count
        )

    async def get_generation(self, generation_id: str, current_user_id: Optional[str] = None) -> Optional[GenerationResponse]:
        """Fetch a single generation by id, or None when not found."""
        gen = await self.dao.generations.get_generation(generation_id)
        return self._map_to_response(gen, current_user_id) if gen else None

    async def toggle_like(self, generation_id: str, user_id: str) -> bool | None:
        """Toggle the user's like on a generation; result is delegated to the DAO."""
        return await self.dao.generations.toggle_like(generation_id, user_id)

    async def get_generations_by_group(self, group_id: str, current_user_id: Optional[str] = None) -> GenerationGroupResponse:
        """Return all generations that share one generation_group_id."""
        generations = await self.dao.generations.get_generations_by_group(group_id)
        return GenerationGroupResponse(
            generation_group_id=group_id,
            generations=[self._map_to_response(gen, current_user_id) for gen in generations]
        )

    def _map_to_response(self, gen: Generation, current_user_id: Optional[str] = None) -> GenerationResponse:
        """Convert a Generation model into the API response shape,
        filling likes_count and the viewer-specific is_liked flag."""
        res = GenerationResponse(**gen.model_dump())
        res.likes_count = len(gen.liked_by) if gen.liked_by else 0
        res.is_liked = current_user_id in gen.liked_by if current_user_id and gen.liked_by else False
        return res

    async def get_running_generations(self, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
        """List generations currently in RUNNING status, optionally filtered."""
        return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING, created_by=user_id, project_id=project_id)

    async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationGroupResponse:
        """
        Create ``generation_request.count`` generation records under one
        group id (generated here when not supplied) and schedule each as a
        background task. Returns immediately with the queued records.
        """
        if generation_group_id is None:
            generation_group_id = str(uuid4())
        results = []
        for _ in range(generation_request.count):
            gen_response = await self._create_single_generation(generation_request, user_id, generation_group_id)
            results.append(gen_response)
        return GenerationGroupResponse(generation_group_id=generation_group_id, generations=results)

    async def create_generation(self, generation: Generation) -> None:
        """
        Execute one generation end-to-end: prepare inputs, call Gemini
        (with a concurrent progress-simulation task), persist the result
        assets, finalize the DB record, and notify Telegram if configured.

        The progress task is always cancelled in ``finally``, including on
        failure; failure handling/DB status update happens in the caller
        (_queued_generation_runner).
        """
        # NOTE(review): start_time is naive local time; it is only used for
        # the elapsed-seconds delta in _finalize_generation, which also uses
        # naive datetime.now(), so the subtraction stays consistent.
        start_time = datetime.now()
        logger.info(f"Processing generation {generation.id}. Character ID: {generation.linked_character_id}")
        # 1. Prepare input
        media_group_bytes, generation_prompt = await self._prepare_generation_input(generation)
        # 2. Run generation with progress simulation
        progress_task = asyncio.create_task(self._simulate_progress(generation))
        try:
            generated_bytes_list, metrics = await generate_image_task(
                prompt=generation_prompt,
                media_group_bytes=media_group_bytes,
                aspect_ratio=generation.aspect_ratio,
                quality=generation.quality,
                gemini=self.gemini
            )
            self._update_generation_metrics(generation, metrics)
            # 3. Process results
            created_assets = await self._process_generated_images(generation, generated_bytes_list)
            # 4. Finalize generation record
            await self._finalize_generation(generation, created_assets, generation_prompt, start_time)
            # 5. Notify
            if generation.telegram_id and self.bot:
                await self._notify_telegram(generation, created_assets)
        finally:
            # Stop the fake-progress updater and swallow its cancellation.
            if not progress_task.done():
                progress_task.cancel()
                try:
                    await progress_task
                except asyncio.CancelledError:
                    pass

    async def import_external_generation(self, external_gen) -> Generation:
        """
        Import an externally produced image as a completed (DONE) generation:
        fetch the image (URL or base64), store it as an asset, and create a
        Generation record carrying over the external metrics.
        """
        external_gen.validate_image_source()
        logger.info(f"Importing external generation for user: {external_gen.created_by}")
        image_bytes = await self._fetch_external_image(external_gen)
        # Reuse internal processing logic
        new_asset = await self._save_asset(
            image_bytes=image_bytes,
            name=f"External_Generated_{external_gen.linked_character_id or 'no_char'}",
            created_by=external_gen.created_by,
            project_id=external_gen.project_id,
            linked_char_id=external_gen.linked_character_id,
            folder="external"
        )
        generation = Generation(
            status=GenerationStatus.DONE,
            linked_character_id=external_gen.linked_character_id,
            aspect_ratio=external_gen.aspect_ratio,
            quality=external_gen.quality,
            prompt=external_gen.prompt,
            tech_prompt=external_gen.tech_prompt,
            result_list=[new_asset.id],
            result=new_asset.id,
            progress=100,
            nsfw=external_gen.nsfw,
            execution_time_seconds=external_gen.execution_time_seconds,
            api_execution_time_seconds=external_gen.api_execution_time_seconds,
            token_usage=external_gen.token_usage,
            input_token_usage=external_gen.input_token_usage,
            output_token_usage=external_gen.output_token_usage,
            created_by=external_gen.created_by,
            project_id=external_gen.project_id
        )
        gen_id = await self.dao.generations.create_generation(generation)
        generation.id = gen_id
        return generation

    async def delete_generation(self, generation_id: str) -> bool:
        """Soft-delete a generation (is_deleted flag). Returns False when the
        record is missing or the update fails; never raises."""
        try:
            generation = await self.dao.generations.get_generation(generation_id)
            if not generation:
                return False
            generation.is_deleted = True
            generation.updated_at = datetime.now(UTC)
            await self.dao.generations.update_generation(generation)
            return True
        except Exception as e:
            logger.error(f"Error deleting generation {generation_id}: {e}")
            return False

    async def cleanup_stale_generations(self) -> None:
        """Best-effort: cancel generations stuck longer than 5 minutes."""
        try:
            count = await self.dao.generations.cancel_stale_generations(timeout_minutes=5)
            if count > 0:
                logger.info(f"Cleaned up {count} stale generations")
        except Exception as e:
            logger.error(f"Error cleaning up stale generations: {e}")

    async def cleanup_old_data(self, days: int = 30) -> None:
        """Best-effort: soft-delete generations older than ``days`` and purge
        the asset ids the DAO reports as orphaned by that deletion."""
        try:
            gen_count, asset_ids = await self.dao.generations.soft_delete_old_generations(days=days)
            if gen_count > 0:
                logger.info(f"Soft-deleted {gen_count} generations. Purging {len(asset_ids)} assets.")
                if asset_ids:
                    await self.dao.assets.soft_delete_and_purge_assets(asset_ids)
        except Exception as e:
            logger.error(f"Error during old data cleanup: {e}")

    async def get_financial_report(self, user_id: Optional[str] = None, project_id: Optional[str] = None, breakdown_by: Optional[str] = None) -> FinancialReport:
        """
        Build a usage/cost report: overall summary plus an optional breakdown
        grouped by "created_by" or "project_id" (any other value yields no
        breakdown).
        """
        summary_data = await self.dao.generations.get_usage_stats(created_by=user_id, project_id=project_id)
        summary = UsageStats(**summary_data)
        by_user, by_project = None, None
        if breakdown_by == "created_by":
            res = await self.dao.generations.get_usage_breakdown(group_by="created_by", project_id=project_id, created_by=user_id)
            by_user = [UsageByEntity(**item) for item in res]
        if breakdown_by == "project_id":
            res = await self.dao.generations.get_usage_breakdown(group_by="project_id", project_id=project_id, created_by=user_id)
            by_project = [UsageByEntity(**item) for item in res]
        return FinancialReport(summary=summary, by_user=by_user, by_project=by_project)

    # --- Private Helpers ---
    async def _create_single_generation(self, generation_request: GenerationRequest, user_id: Optional[str], generation_group_id: str) -> GenerationResponse:
        """Persist one generation record and schedule its background run;
        returns the queued record immediately."""
        try:
            gen_model = Generation(**generation_request.model_dump(exclude={'count'}))
            gen_model.created_by = user_id
            gen_model.generation_group_id = generation_group_id
            gen_id = await self.dao.generations.create_generation(gen_model)
            gen_model.id = gen_id
            # NOTE(review): fire-and-forget — the Task reference is not kept,
            # so the event loop only holds a weak reference; consider storing
            # it (e.g. in a set) to guarantee it is not garbage-collected.
            asyncio.create_task(self._queued_generation_runner(gen_model))
            return GenerationResponse(**gen_model.model_dump())
        except Exception:
            logger.exception("Failed to initiate single generation")
            raise

    async def _queued_generation_runner(self, gen: Generation) -> None:
        """Wait for a slot in the global semaphore, run the generation, and
        mark the record FAILED on any exception (exception is not re-raised)."""
        logger.info(f"Generation {gen.id} waiting for slot...")
        try:
            async with generation_semaphore:
                await self.create_generation(gen)
        except Exception as e:
            await self._handle_generation_failure(gen, e)
            logger.exception(f"Background generation task failed for ID: {gen.id}")

    async def _prepare_generation_input(self, generation: Generation) -> Tuple[List[bytes], str]:
        """
        Collect reference images and build the final prompt.

        Image sources, in order: the linked character's avatar (when
        use_profile_image is set), explicit reference assets
        (generation.assets_list), and the environment's assets. When any
        reference image is present, strict identity-guidance text is
        appended to the prompt.

        Raises:
            ValueError: when linked_character_id points to a missing character.
        """
        media_group_bytes: List[bytes] = []
        prompt = generation.prompt
        # 1. Character Avatar
        if generation.linked_character_id:
            char_info = await self.dao.chars.get_character(generation.linked_character_id)
            if not char_info:
                raise ValueError(f"Character {generation.linked_character_id} not found")
            if generation.use_profile_image and char_info.avatar_asset_id:
                avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
                if avatar_asset:
                    data = await self._get_asset_data_bytes(avatar_asset)
                    if data: media_group_bytes.append(data)
        # 2. Reference Assets
        if generation.assets_list:
            assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
            for asset in assets:
                data = await self._get_asset_data_bytes(asset)
                if data: media_group_bytes.append(data)
        # 3. Environment Assets
        if generation.environment_id:
            env = await self.dao.environments.get_env(generation.environment_id)
            if env and env.asset_ids:
                env_assets = await self.dao.assets.get_assets_by_ids(env.asset_ids)
                for asset in env_assets:
                    data = await self._get_asset_data_bytes(asset)
                    if data: media_group_bytes.append(data)
        if media_group_bytes:
            prompt += (
                " \n\n[Reference Image Guidance]: Use the provided image(s) as the STRICT reference for the main "
                "character's facial features and hair, environment or clothes. Maintain high fidelity to the reference identity."
            )
        return media_group_bytes, prompt

    async def _get_asset_data_bytes(self, asset: Asset) -> Optional[bytes]:
        """Return image bytes for an asset: from S3 when it has an object
        name, otherwise the inline data field; None for non-image assets."""
        if asset.content_type != AssetContentType.IMAGE:
            return None
        if asset.minio_object_name:
            return await self.s3_adapter.get_file(asset.minio_object_name)
        return asset.data

    def _update_generation_metrics(self, generation: Generation, metrics: Dict[str, Any]) -> None:
        """Copy adapter-reported metrics (API time, token counts) onto the
        generation record; missing keys become None."""
        generation.api_execution_time_seconds = metrics.get("api_execution_time_seconds")
        generation.token_usage = metrics.get("token_usage")
        generation.input_token_usage = metrics.get("input_token_usage")
        generation.output_token_usage = metrics.get("output_token_usage")

    async def _handle_generation_failure(self, generation: Generation, error: Optional[Exception]) -> None:
        """Mark the generation FAILED with a reason and persist the update."""
        logger.error(f"Generation {generation.id} failed: {error}")
        generation.status = GenerationStatus.FAILED
        # Don't overwrite if reason is already set, unless a new error is provided
        if error:
            generation.failed_reason = str(error)
        elif not generation.failed_reason:
            generation.failed_reason = "Unknown error"
        generation.updated_at = datetime.now(UTC)
        await self.dao.generations.update_generation(generation)

    async def _process_generated_images(self, generation: Generation, bytes_list: List[bytes]) -> List[Asset]:
        """Persist each generated image as an Asset (S3 + DB) and return the
        created Asset objects in order."""
        created_assets = []
        for img_bytes in bytes_list:
            asset = await self._save_asset(
                image_bytes=img_bytes,
                name=f"Generated_{generation.linked_character_id}",
                created_by=generation.created_by,
                project_id=generation.project_id,
                linked_char_id=generation.linked_character_id,
                folder="generated"
            )
            created_assets.append(asset)
        return created_assets

    async def _save_asset(self, image_bytes: bytes, name: str, created_by: str, project_id: str, linked_char_id: str, folder: str) -> Asset:
        """
        Upload an image to S3 under ``folder/<char>/<timestamp>_<rand>.png``,
        build a thumbnail (off-thread), and create the Asset DB record.
        The full image lives only in S3 (data=None); the thumbnail is inline.
        """
        thumbnail_bytes = await asyncio.to_thread(create_thumbnail, image_bytes)
        # Random suffix avoids collisions for same-second uploads.
        filename = f"{folder}/{linked_char_id}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.png"
        await self.s3_adapter.upload_file(filename, image_bytes, content_type="image/png")
        new_asset = Asset(
            name=name,
            type=AssetType.GENERATED,
            content_type=AssetContentType.IMAGE,
            linked_char_id=linked_char_id,
            data=None,
            minio_object_name=filename,
            minio_bucket=self.s3_adapter.bucket_name,
            thumbnail=thumbnail_bytes,
            created_by=created_by,
            project_id=project_id
        )
        asset_id = await self.dao.assets.create_asset(new_asset)
        new_asset.id = str(asset_id)
        return new_asset

    async def _finalize_generation(self, generation: Generation, assets: List[Asset], tech_prompt: str, start_time: datetime) -> None:
        """Mark the generation DONE with results, elapsed time, and the final
        (augmented) prompt, then persist it."""
        generation.result_list = [a.id for a in assets]
        generation.status = GenerationStatus.DONE
        generation.progress = 100
        generation.updated_at = datetime.now(UTC)
        generation.tech_prompt = tech_prompt
        # Naive now() matches the naive start_time captured in create_generation.
        generation.execution_time_seconds = (datetime.now() - start_time).total_seconds()
        await self.dao.generations.update_generation(generation)
        logger.info(f"Generation {generation.id} finalized. Time: {generation.execution_time_seconds:.2f}s")

    async def _notify_telegram(self, generation: Generation, assets: List[Asset]) -> None:
        """Best-effort: send each result image to the generation's Telegram
        chat; any failure is logged and swallowed."""
        try:
            for asset in assets:
                # Need to get data for telegram if it's not in Asset object
                img_data = await self.s3_adapter.get_file(asset.minio_object_name) if asset.minio_object_name else asset.data
                if img_data:
                    await self.bot.send_photo(
                        chat_id=generation.telegram_id,
                        photo=BufferedInputFile(img_data, filename=f"{asset.name}.png"),
                        caption=f"Generated from: {generation.prompt[:100]}..."
                    )
        except Exception as e:
            logger.error(f"Failed to send to Telegram: {e}")

    async def _simulate_progress(self, generation: Generation) -> None:
        """Fake progress updates while the real generation runs: every 4s bump
        progress by 5-15 (capped at 90) and persist; exits quietly on cancel."""
        current_progress = 0
        try:
            while current_progress < 90:
                await asyncio.sleep(4)
                current_progress = min(current_progress + random.randint(5, 15), 90)
                generation.progress = current_progress
                await self.dao.generations.update_generation(generation)
        except asyncio.CancelledError:
            pass

    async def _fetch_external_image(self, external_gen) -> bytes:
        """Fetch the external image: download image_url (30s timeout) or
        decode base64 image_data; raises ValueError when neither is set."""
        if external_gen.image_url:
            async with httpx.AsyncClient() as client:
                response = await client.get(external_gen.image_url, timeout=30.0)
                response.raise_for_status()
                return response.content
        elif external_gen.image_data:
            return base64.b64decode(external_gen.image_data)
        raise ValueError("No image source provided")