# Generation service module: queued image generation, asset persistence, and reporting.
import asyncio
|
|
import base64
|
|
import logging
|
|
import random
|
|
from datetime import datetime, UTC
|
|
from typing import List, Optional, Tuple, Any, Dict
|
|
from uuid import uuid4
|
|
|
|
import httpx
|
|
from aiogram import Bot
|
|
from aiogram.types import BufferedInputFile
|
|
|
|
from adapters.Exception import GoogleGenerationException
|
|
from adapters.google_adapter import GoogleAdapter
|
|
from adapters.s3_adapter import S3Adapter
|
|
from api.models import (
|
|
FinancialReport, UsageStats, UsageByEntity,
|
|
GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse
|
|
)
|
|
from models.Asset import Asset, AssetType, AssetContentType
|
|
from models.Generation import Generation, GenerationStatus
|
|
from models.enums import AspectRatios, Quality
|
|
from repos.dao import DAO
|
|
from utils.image_utils import create_thumbnail
|
|
|
|
# Module-level logger; handlers/levels are configured by the application entry point.
logger = logging.getLogger(__name__)

# Limit concurrent generations to 4
# (process-wide semaphore shared by all GenerationService instances).
generation_semaphore = asyncio.Semaphore(4)
|
|
|
|
|
|
async def generate_image_task(
    prompt: str,
    media_group_bytes: List[bytes],
    aspect_ratio: "AspectRatios",
    quality: "Quality",
    gemini: "GoogleAdapter",
) -> Tuple[List[bytes], Dict[str, Any]]:
    """Run the adapter's synchronous ``generate_image`` in a worker thread.

    Args:
        prompt: Final text prompt sent to the model.
        media_group_bytes: Reference images; the local reference is dropped
            in ``finally`` to release the (potentially large) payload early.
        aspect_ratio: Requested output aspect ratio.
        quality: Requested output quality tier.
        gemini: Adapter exposing a blocking ``generate_image`` method.

    Returns:
        Tuple of (list of raw PNG/JPEG byte strings, metrics dict as
        returned by the adapter).

    Raises:
        Whatever ``gemini.generate_image`` raises (including
        GoogleGenerationException) — errors propagate unchanged.
        The previous ``except GoogleGenerationException: raise`` clause was
        a no-op and has been removed.
    """
    log = logging.getLogger(__name__)
    try:
        # Lazy %-style args avoid formatting when the level is disabled.
        log.info("Starting generate_image_task with prompt length: %d", len(prompt))
        generated_images_io, metrics = await asyncio.to_thread(
            gemini.generate_image,
            prompt=prompt,
            images_list=media_group_bytes,
            aspect_ratio=aspect_ratio,
            quality=quality,
        )
        log.info(
            "generate_image_task completed, received %d images",
            len(generated_images_io) if generated_images_io else 0,
        )
    finally:
        # Drop our reference to the input images whether or not generation
        # succeeded, so the bytes can be reclaimed as soon as possible.
        del media_group_bytes

    # Drain each file-like object into plain bytes and close it.
    images_bytes: List[bytes] = []
    if generated_images_io:
        for img_io in generated_images_io:
            img_io.seek(0)
            images_bytes.append(img_io.read())
            img_io.close()
    del generated_images_io

    return images_bytes, metrics
|
|
|
|
|
|
class GenerationService:
    """Orchestrates image generation: prompt preparation, queued execution,
    asset storage, usage reporting, and optional Telegram delivery."""

    def __init__(self, dao: DAO, gemini: GoogleAdapter, s3_adapter: S3Adapter, bot: Optional[Bot] = None):
        # All collaborators are injected; `bot` stays None in deployments
        # where Telegram notifications are not used.
        self.bot = bot
        self.dao = dao
        self.s3_adapter = s3_adapter
        self.gemini = gemini
|
|
|
|
# --- Public API ---
|
|
|
|
async def ask_prompt_assistant(self, prompt: str, assets: list[str] | None = None) -> str:
|
|
future_prompt = (
|
|
"You are an prompt-assistant. You improving user-entered prompts for image generation. "
|
|
"User may upload reference image too. I will provide sources prompt entered by user. "
|
|
"Understand user needs and generate best variation of prompt. ANSWER ONLY PROMPT STRING!!! "
|
|
f"USER_ENTERED_PROMPT: {prompt}"
|
|
)
|
|
assets_data = []
|
|
if assets:
|
|
assets_db = await self.dao.assets.get_assets_by_ids(assets)
|
|
assets_data.extend(asset.data for asset in assets_db if asset.data)
|
|
|
|
generated_prompt = await asyncio.to_thread(self.gemini.generate_text, future_prompt, assets_data)
|
|
logger.info(f"Prompt Assistant: {generated_prompt}")
|
|
return generated_prompt
|
|
|
|
async def generate_prompt_from_images(self, images: List[bytes], user_prompt: Optional[str] = None) -> str:
|
|
technical_prompt = "You are a prompt engineer. Describe this image in detail to create a stable diffusion using this image as reference. "
|
|
if user_prompt:
|
|
technical_prompt += f"User also provided this context: {user_prompt}. "
|
|
technical_prompt += "Provide ONLY the detailed prompt."
|
|
|
|
return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, images_list=images)
|
|
|
|
async def get_generations(self, **kwargs) -> GenerationsResponse:
|
|
current_user_id = kwargs.pop('current_user_id', None)
|
|
generations = await self.dao.generations.get_generations(**kwargs)
|
|
total_count = await self.dao.generations.count_generations(
|
|
character_id=kwargs.get('character_id'),
|
|
created_by=kwargs.get('created_by'),
|
|
project_id=kwargs.get('project_id'),
|
|
idea_id=kwargs.get('idea_id'),
|
|
only_liked_by=kwargs.get('only_liked_by')
|
|
)
|
|
return GenerationsResponse(
|
|
generations=[self._map_to_response(gen, current_user_id) for gen in generations],
|
|
total_count=total_count
|
|
)
|
|
|
|
async def get_generation(self, generation_id: str, current_user_id: Optional[str] = None) -> Optional[GenerationResponse]:
|
|
gen = await self.dao.generations.get_generation(generation_id)
|
|
return self._map_to_response(gen, current_user_id) if gen else None
|
|
|
|
    async def toggle_like(self, generation_id: str, user_id: str) -> bool | None:
        """Flip the user's like on a generation.

        Thin delegate — the like semantics and the meaning of the ``None``
        return are defined by the DAO layer.
        """
        return await self.dao.generations.toggle_like(generation_id, user_id)
|
|
|
|
async def get_generations_by_group(self, group_id: str, current_user_id: Optional[str] = None) -> GenerationGroupResponse:
|
|
generations = await self.dao.generations.get_generations_by_group(group_id)
|
|
return GenerationGroupResponse(
|
|
generation_group_id=group_id,
|
|
generations=[self._map_to_response(gen, current_user_id) for gen in generations]
|
|
)
|
|
|
|
def _map_to_response(self, gen: Generation, current_user_id: Optional[str] = None) -> GenerationResponse:
|
|
res = GenerationResponse(**gen.model_dump())
|
|
res.likes_count = len(gen.liked_by) if gen.liked_by else 0
|
|
res.is_liked = current_user_id in gen.liked_by if current_user_id and gen.liked_by else False
|
|
return res
|
|
|
|
async def get_running_generations(self, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
|
|
return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING, created_by=user_id, project_id=project_id)
|
|
|
|
async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationGroupResponse:
|
|
if generation_group_id is None:
|
|
generation_group_id = str(uuid4())
|
|
|
|
results = []
|
|
for _ in range(generation_request.count):
|
|
gen_response = await self._create_single_generation(generation_request, user_id, generation_group_id)
|
|
results.append(gen_response)
|
|
return GenerationGroupResponse(generation_group_id=generation_group_id, generations=results)
|
|
|
|
    async def create_generation(self, generation: Generation):
        """Execute one generation end-to-end.

        Prepares inputs, calls the model, persists the resulting assets,
        finalizes the DB record, and (when configured) notifies Telegram.
        Errors propagate to the caller (_queued_generation_runner), which
        marks the record as failed; the progress-simulation task is always
        cancelled in ``finally`` so it never outlives this call.
        """
        # NOTE(review): naive local time here, while other paths stamp
        # updated_at with timezone-aware UTC — confirm the elapsed-time math
        # in _finalize_generation is intentional.
        start_time = datetime.now()
        logger.info(f"Processing generation {generation.id}. Character ID: {generation.linked_character_id}")

        # 1. Prepare input (reference images + final prompt)
        media_group_bytes, generation_prompt = await self._prepare_generation_input(generation)

        # 2. Run generation with progress simulation running alongside
        progress_task = asyncio.create_task(self._simulate_progress(generation))
        try:
            generated_bytes_list, metrics = await generate_image_task(
                prompt=generation_prompt,
                media_group_bytes=media_group_bytes,
                aspect_ratio=generation.aspect_ratio,
                quality=generation.quality,
                gemini=self.gemini
            )
            self._update_generation_metrics(generation, metrics)

            # 3. Process results into stored assets
            created_assets = await self._process_generated_images(generation, generated_bytes_list)

            # 4. Finalize generation record
            await self._finalize_generation(generation, created_assets, generation_prompt, start_time)

            # 5. Notify via Telegram only when both chat id and bot exist
            if generation.telegram_id and self.bot:
                await self._notify_telegram(generation, created_assets)
        finally:
            # Stop the fake progress updates on success AND failure paths.
            if not progress_task.done():
                progress_task.cancel()
                try:
                    await progress_task
                except asyncio.CancelledError:
                    pass
|
|
|
|
async def import_external_generation(self, external_gen) -> Generation:
|
|
external_gen.validate_image_source()
|
|
logger.info(f"Importing external generation for user: {external_gen.created_by}")
|
|
|
|
image_bytes = await self._fetch_external_image(external_gen)
|
|
|
|
# Reuse internal processing logic
|
|
new_asset = await self._save_asset(
|
|
image_bytes=image_bytes,
|
|
name=f"External_Generated_{external_gen.linked_character_id or 'no_char'}",
|
|
created_by=external_gen.created_by,
|
|
project_id=external_gen.project_id,
|
|
linked_char_id=external_gen.linked_character_id,
|
|
folder="external"
|
|
)
|
|
|
|
generation = Generation(
|
|
status=GenerationStatus.DONE,
|
|
linked_character_id=external_gen.linked_character_id,
|
|
aspect_ratio=external_gen.aspect_ratio,
|
|
quality=external_gen.quality,
|
|
prompt=external_gen.prompt,
|
|
tech_prompt=external_gen.tech_prompt,
|
|
result_list=[new_asset.id],
|
|
result=new_asset.id,
|
|
progress=100,
|
|
nsfw=external_gen.nsfw,
|
|
execution_time_seconds=external_gen.execution_time_seconds,
|
|
api_execution_time_seconds=external_gen.api_execution_time_seconds,
|
|
token_usage=external_gen.token_usage,
|
|
input_token_usage=external_gen.input_token_usage,
|
|
output_token_usage=external_gen.output_token_usage,
|
|
created_by=external_gen.created_by,
|
|
project_id=external_gen.project_id
|
|
)
|
|
|
|
gen_id = await self.dao.generations.create_generation(generation)
|
|
generation.id = gen_id
|
|
return generation
|
|
|
|
async def delete_generation(self, generation_id: str) -> bool:
|
|
try:
|
|
generation = await self.dao.generations.get_generation(generation_id)
|
|
if not generation:
|
|
return False
|
|
generation.is_deleted = True
|
|
generation.updated_at = datetime.now(UTC)
|
|
await self.dao.generations.update_generation(generation)
|
|
return True
|
|
except Exception as e:
|
|
logger.error(f"Error deleting generation {generation_id}: {e}")
|
|
return False
|
|
|
|
async def cleanup_stale_generations(self):
|
|
try:
|
|
count = await self.dao.generations.cancel_stale_generations(timeout_minutes=5)
|
|
if count > 0:
|
|
logger.info(f"Cleaned up {count} stale generations")
|
|
except Exception as e:
|
|
logger.error(f"Error cleaning up stale generations: {e}")
|
|
|
|
async def cleanup_old_data(self, days: int = 30):
|
|
try:
|
|
gen_count, asset_ids = await self.dao.generations.soft_delete_old_generations(days=days)
|
|
if gen_count > 0:
|
|
logger.info(f"Soft-deleted {gen_count} generations. Purging {len(asset_ids)} assets.")
|
|
if asset_ids:
|
|
await self.dao.assets.soft_delete_and_purge_assets(asset_ids)
|
|
except Exception as e:
|
|
logger.error(f"Error during old data cleanup: {e}")
|
|
|
|
async def get_financial_report(self, user_id: Optional[str] = None, project_id: Optional[str] = None, breakdown_by: Optional[str] = None) -> FinancialReport:
|
|
summary_data = await self.dao.generations.get_usage_stats(created_by=user_id, project_id=project_id)
|
|
summary = UsageStats(**summary_data)
|
|
|
|
by_user, by_project = None, None
|
|
if breakdown_by == "created_by":
|
|
res = await self.dao.generations.get_usage_breakdown(group_by="created_by", project_id=project_id, created_by=user_id)
|
|
by_user = [UsageByEntity(**item) for item in res]
|
|
if breakdown_by == "project_id":
|
|
res = await self.dao.generations.get_usage_breakdown(group_by="project_id", project_id=project_id, created_by=user_id)
|
|
by_project = [UsageByEntity(**item) for item in res]
|
|
|
|
return FinancialReport(summary=summary, by_user=by_user, by_project=by_project)
|
|
|
|
# --- Private Helpers ---
|
|
|
|
async def _create_single_generation(self, generation_request: GenerationRequest, user_id: Optional[str], generation_group_id: str) -> GenerationResponse:
|
|
try:
|
|
gen_model = Generation(**generation_request.model_dump(exclude={'count'}))
|
|
gen_model.created_by = user_id
|
|
gen_model.generation_group_id = generation_group_id
|
|
|
|
gen_id = await self.dao.generations.create_generation(gen_model)
|
|
gen_model.id = gen_id
|
|
|
|
asyncio.create_task(self._queued_generation_runner(gen_model))
|
|
return GenerationResponse(**gen_model.model_dump())
|
|
except Exception:
|
|
logger.exception("Failed to initiate single generation")
|
|
raise
|
|
|
|
    async def _queued_generation_runner(self, gen: Generation):
        """Wait for a free concurrency slot, then run the generation.

        Any exception from create_generation is absorbed here: the record is
        marked FAILED and the traceback logged, so background tasks never
        propagate errors into the event loop.
        """
        logger.info(f"Generation {gen.id} waiting for slot...")
        try:
            # generation_semaphore caps concurrent runs at 4 (module level).
            async with generation_semaphore:
                await self.create_generation(gen)
        except Exception as e:
            await self._handle_generation_failure(gen, e)
            logger.exception(f"Background generation task failed for ID: {gen.id}")
|
|
|
|
async def _prepare_generation_input(self, generation: Generation) -> Tuple[List[bytes], str]:
|
|
media_group_bytes: List[bytes] = []
|
|
prompt = generation.prompt
|
|
|
|
# 1. Character Avatar
|
|
if generation.linked_character_id:
|
|
char_info = await self.dao.chars.get_character(generation.linked_character_id)
|
|
if not char_info:
|
|
raise ValueError(f"Character {generation.linked_character_id} not found")
|
|
|
|
if generation.use_profile_image and char_info.avatar_asset_id:
|
|
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
|
|
if avatar_asset:
|
|
data = await self._get_asset_data_bytes(avatar_asset)
|
|
if data: media_group_bytes.append(data)
|
|
|
|
# 2. Reference Assets
|
|
if generation.assets_list:
|
|
assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
|
|
for asset in assets:
|
|
data = await self._get_asset_data_bytes(asset)
|
|
if data: media_group_bytes.append(data)
|
|
|
|
# 3. Environment Assets
|
|
if generation.environment_id:
|
|
env = await self.dao.environments.get_env(generation.environment_id)
|
|
if env and env.asset_ids:
|
|
env_assets = await self.dao.assets.get_assets_by_ids(env.asset_ids)
|
|
for asset in env_assets:
|
|
data = await self._get_asset_data_bytes(asset)
|
|
if data: media_group_bytes.append(data)
|
|
|
|
if media_group_bytes:
|
|
prompt += (
|
|
" \n\n[Reference Image Guidance]: Use the provided image(s) as the STRICT reference for the main "
|
|
"character's facial features and hair, environment or clothes. Maintain high fidelity to the reference identity."
|
|
)
|
|
|
|
return media_group_bytes, prompt
|
|
|
|
async def _get_asset_data_bytes(self, asset: Asset) -> Optional[bytes]:
|
|
if asset.content_type != AssetContentType.IMAGE:
|
|
return None
|
|
if asset.minio_object_name:
|
|
return await self.s3_adapter.get_file(asset.minio_object_name)
|
|
return asset.data
|
|
|
|
def _update_generation_metrics(self, generation: Generation, metrics: Dict[str, Any]):
|
|
generation.api_execution_time_seconds = metrics.get("api_execution_time_seconds")
|
|
generation.token_usage = metrics.get("token_usage")
|
|
generation.input_token_usage = metrics.get("input_token_usage")
|
|
generation.output_token_usage = metrics.get("output_token_usage")
|
|
|
|
async def _handle_generation_failure(self, generation: Generation, error: Optional[Exception]):
|
|
logger.error(f"Generation {generation.id} failed: {error}")
|
|
generation.status = GenerationStatus.FAILED
|
|
# Don't overwrite if reason is already set, unless a new error is provided
|
|
if error:
|
|
generation.failed_reason = str(error)
|
|
elif not generation.failed_reason:
|
|
generation.failed_reason = "Unknown error"
|
|
|
|
generation.updated_at = datetime.now(UTC)
|
|
await self.dao.generations.update_generation(generation)
|
|
|
|
async def _process_generated_images(self, generation: Generation, bytes_list: List[bytes]) -> List[Asset]:
|
|
created_assets = []
|
|
for img_bytes in bytes_list:
|
|
asset = await self._save_asset(
|
|
image_bytes=img_bytes,
|
|
name=f"Generated_{generation.linked_character_id}",
|
|
created_by=generation.created_by,
|
|
project_id=generation.project_id,
|
|
linked_char_id=generation.linked_character_id,
|
|
folder="generated"
|
|
)
|
|
created_assets.append(asset)
|
|
return created_assets
|
|
|
|
async def _save_asset(self, image_bytes: bytes, name: str, created_by: str, project_id: str, linked_char_id: str, folder: str) -> Asset:
|
|
thumbnail_bytes = await asyncio.to_thread(create_thumbnail, image_bytes)
|
|
filename = f"{folder}/{linked_char_id}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.png"
|
|
|
|
await self.s3_adapter.upload_file(filename, image_bytes, content_type="image/png")
|
|
|
|
new_asset = Asset(
|
|
name=name,
|
|
type=AssetType.GENERATED,
|
|
content_type=AssetContentType.IMAGE,
|
|
linked_char_id=linked_char_id,
|
|
data=None,
|
|
minio_object_name=filename,
|
|
minio_bucket=self.s3_adapter.bucket_name,
|
|
thumbnail=thumbnail_bytes,
|
|
created_by=created_by,
|
|
project_id=project_id
|
|
)
|
|
asset_id = await self.dao.assets.create_asset(new_asset)
|
|
new_asset.id = str(asset_id)
|
|
return new_asset
|
|
|
|
async def _finalize_generation(self, generation: Generation, assets: List[Asset], tech_prompt: str, start_time: datetime):
|
|
generation.result_list = [a.id for a in assets]
|
|
generation.status = GenerationStatus.DONE
|
|
generation.progress = 100
|
|
generation.updated_at = datetime.now(UTC)
|
|
generation.tech_prompt = tech_prompt
|
|
generation.execution_time_seconds = (datetime.now() - start_time).total_seconds()
|
|
await self.dao.generations.update_generation(generation)
|
|
logger.info(f"Generation {generation.id} finalized. Time: {generation.execution_time_seconds:.2f}s")
|
|
|
|
async def _notify_telegram(self, generation: Generation, assets: List[Asset]):
|
|
try:
|
|
for asset in assets:
|
|
# Need to get data for telegram if it's not in Asset object
|
|
img_data = await self.s3_adapter.get_file(asset.minio_object_name) if asset.minio_object_name else asset.data
|
|
if img_data:
|
|
await self.bot.send_photo(
|
|
chat_id=generation.telegram_id,
|
|
photo=BufferedInputFile(img_data, filename=f"{asset.name}.png"),
|
|
caption=f"Generated from: {generation.prompt[:100]}..."
|
|
)
|
|
except Exception as e:
|
|
logger.error(f"Failed to send to Telegram: {e}")
|
|
|
|
async def _simulate_progress(self, generation: Generation):
|
|
current_progress = 0
|
|
try:
|
|
while current_progress < 90:
|
|
await asyncio.sleep(4)
|
|
current_progress = min(current_progress + random.randint(5, 15), 90)
|
|
generation.progress = current_progress
|
|
await self.dao.generations.update_generation(generation)
|
|
except asyncio.CancelledError:
|
|
pass
|
|
|
|
async def _fetch_external_image(self, external_gen) -> bytes:
|
|
if external_gen.image_url:
|
|
async with httpx.AsyncClient() as client:
|
|
response = await client.get(external_gen.image_url, timeout=30.0)
|
|
response.raise_for_status()
|
|
return response.content
|
|
elif external_gen.image_data:
|
|
return base64.b64decode(external_gen.image_data)
|
|
raise ValueError("No image source provided")
|