297 lines
14 KiB
Python
297 lines
14 KiB
Python
import asyncio
|
||
import logging
|
||
import random
|
||
from datetime import datetime, UTC
|
||
from typing import List, Optional, Tuple, Any, Dict
|
||
from io import BytesIO
|
||
|
||
from aiogram import Bot
|
||
from aiogram.types import BufferedInputFile
|
||
from adapters.Exception import GoogleGenerationException
|
||
from adapters.google_adapter import GoogleAdapter
|
||
from api.models.GenerationRequest import GenerationRequest, GenerationResponse
|
||
# Импортируйте ваши модели DAO, Asset, Generation корректно
|
||
from models.Asset import Asset, AssetType
|
||
from models.Generation import Generation, GenerationStatus
|
||
from models.enums import AspectRatios, Quality, GenType
|
||
from repos.dao import DAO
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
# --- Вспомогательная функция генерации ---
|
||
async def generate_image_task(
    prompt: str,
    media_group_bytes: List[bytes],
    aspect_ratio: "AspectRatios",
    quality: "Quality",
    gemini: "GoogleAdapter",
) -> Tuple[List[bytes], Dict[str, Any]]:
    """Run the blocking Gemini image-generation call in a worker thread.

    Args:
        prompt: Full text prompt sent to the model.
        media_group_bytes: Reference images (raw bytes) forwarded to the model.
        aspect_ratio: Requested aspect ratio for the output images.
        quality: Requested output quality preset.
        gemini: Adapter exposing the synchronous ``generate_image`` method.

    Returns:
        Tuple of (list of generated image bytes, metrics dict reported by
        the adapter).

    Raises:
        GoogleGenerationException: propagated unchanged from the adapter
            (the previous ``except ... raise e`` was a no-op and was removed).
    """
    log = logging.getLogger(__name__)
    log.info("Starting generate_image_task with prompt length: %d", len(prompt))

    # Run the blocking operation in a separate thread so the event loop
    # is not stalled. Exceptions propagate directly to the caller; this
    # also fixes the old hazard of reading try-scoped locals after an
    # exception path.
    generated_images_io, metrics = await asyncio.to_thread(
        gemini.generate_image,
        prompt=prompt,
        images_list=media_group_bytes,
        aspect_ratio=aspect_ratio,
        quality=quality,
    )

    log.info(
        "generate_image_task completed, received %d images",
        len(generated_images_io) if generated_images_io else 0,
    )

    # Drain each BytesIO into raw bytes and release the buffer.
    images_bytes: List[bytes] = []
    if generated_images_io:
        for img_io in generated_images_io:
            img_io.seek(0)
            images_bytes.append(img_io.read())
            img_io.close()

    return images_bytes, metrics
|
||
|
||
class GenerationService:
    """Orchestrates image generation end-to-end: prompt preparation,
    Gemini calls, asset persistence, progress simulation and optional
    delivery of results to Telegram."""

    def __init__(self, dao: DAO, gemini: GoogleAdapter, bot: Optional[Bot] = None):
        # Data-access layer (generations / assets / characters repositories).
        self.dao = dao
        # Adapter wrapping the synchronous Google Gemini API.
        self.gemini = gemini
        # Optional aiogram bot used to deliver results to Telegram users.
        self.bot = bot

    async def ask_prompt_assistant(self, prompt: str, assets: Optional[List[str]] = None) -> str:
        """Improve a user-entered prompt via the LLM.

        Args:
            prompt: Raw prompt text entered by the user.
            assets: Optional list of asset ids whose image data is sent
                to the model as reference context.

        Returns:
            The improved prompt string produced by the model.
        """
        future_prompt = """You are an prompt-assistant. You improving user-entered prompts for image generation. User may upload reference image too.
I will provide sources prompt entered by user. Understand user needs and generate best variation of prompt.
ANSWER ONLY PROMPT STRING!!! USER_ENTERED_PROMPT: """
        future_prompt += prompt
        assets_data = []
        if assets is not None:
            assets_db = await self.dao.assets.get_assets_by_ids(assets)
            assets_data.extend(asset.data for asset in assets_db)
        # generate_text is blocking; run it off the event loop.
        generated_prompt = await asyncio.to_thread(self.gemini.generate_text, future_prompt, assets_data)
        logger.info(future_prompt)
        logger.info(generated_prompt)
        return generated_prompt

    async def generate_prompt_from_images(self, images: List[bytes], user_prompt: Optional[str] = None) -> str:
        """Describe the given images as a detailed generation prompt.

        Args:
            images: Raw image bytes used as reference.
            user_prompt: Optional extra context supplied by the user.

        Returns:
            The model-produced prompt string.
        """
        technical_prompt = "You are a prompt engineer. Describe this image in detail to create a stable diffusion using this image as reference. "
        if user_prompt:
            technical_prompt += f"User also provided this context: {user_prompt}. "

        technical_prompt += "Provide ONLY the detailed prompt."

        return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, images_list=images)

    async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0) -> List[Generation]:
        """Return stored generations, optionally filtered by character, paginated."""
        return await self.dao.generations.get_generations(character_id=character_id, limit=limit, offset=offset)

    async def get_generation(self, generation_id: str) -> Optional[GenerationResponse]:
        """Return a single generation as a response model, or None if missing."""
        gen = await self.dao.generations.get_generation(generation_id)
        if gen is None:
            return None
        return GenerationResponse(**gen.model_dump())

    async def get_running_generations(self) -> List[Generation]:
        """Return all generations currently in RUNNING status."""
        return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING)

    async def create_generation_task(self, generation_request: GenerationRequest) -> GenerationResponse:
        """Persist a new generation record and launch processing in the background.

        Returns immediately with the created record; the heavy work runs in
        a fire-and-forget asyncio task.

        Raises:
            Exception: re-raised after best-effort marking the record FAILED.
        """
        gen_id = None
        generation_model = None

        try:
            generation_model = Generation(**generation_request.model_dump())
            gen_id = await self.dao.generations.create_generation(generation_model)
            generation_model.id = gen_id

            async def runner(gen):
                # Background wrapper: run the generation and mark FAILED on error.
                logger.info(f"Starting background generation task for ID: {gen.id}")
                try:
                    await self.create_generation(gen)
                    logger.info(f"Background generation task finished for ID: {gen.id}")
                except Exception:
                    # Generation started and then failed — mark the DB record FAILED.
                    try:
                        db_gen = await self.dao.generations.get_generation(gen.id)
                        db_gen.status = GenerationStatus.FAILED
                        await self.dao.generations.update_generation(db_gen)
                    except Exception:
                        logger.exception("Failed to mark generation as FAILED")
                    logger.exception("create_generation task failed")

            asyncio.create_task(runner(generation_model))

            return GenerationResponse(**generation_model.model_dump())

        except Exception:
            # If the DB record was never created there is nothing to mark.
            if gen_id is not None:
                try:
                    gen = await self.dao.generations.get_generation(gen_id)
                    gen.status = GenerationStatus.FAILED
                    await self.dao.generations.update_generation(gen)
                except Exception:
                    logger.exception("Failed to mark generation as FAILED in create_generation_task")
            raise

    async def create_generation(self, generation: Generation):
        """Execute a generation synchronously (within this coroutine).

        Steps: gather reference assets, call Gemini, store result images as
        new assets, update the generation record, optionally send results
        to Telegram.

        Raises:
            Exception: when the linked character cannot be found, or any
                failure from the generation call (record is marked FAILED
                first).
        """
        # Naive local time is fine here: only the delta is stored.
        start_time = datetime.now()
        logger.info(f"Processing generation {generation.id}. Character ID: {generation.linked_character_id}")

        # Collect reference assets (if any).
        reference_assets: List[Asset] = []
        media_group_bytes: List[bytes] = []
        generation_prompt = "You are creating image. "
        if generation.linked_character_id is not None:
            char_info = await self.dao.chars.get_character(generation.linked_character_id, with_image_data=True)
            if char_info is None:
                raise Exception(f"Character ID {generation.linked_character_id} not found")
            if generation.use_profile_image:
                media_group_bytes.append(char_info.character_image_data)
                # NOTE(review): the bio-based prompt is tied to use_profile_image
                # here; confirm it should not instead apply whenever a character
                # is linked.
                generation_prompt = f"""You are creating image for {char_info.character_bio}"""

        reference_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
        # Only image-type assets that actually carry data are sent to Gemini.
        media_group_bytes.extend(
            asset.data
            for asset in reference_assets
            if asset.data is not None and asset.type == AssetType.IMAGE
        )
        generation_prompt += f" PROMPT: {generation.prompt}"
        logger.info(f"Final generation prompt assembled. Length: {len(generation_prompt)}. Media count: {len(media_group_bytes)}")

        # Run the real generation while a background task fakes progress updates.
        progress_task = asyncio.create_task(self._simulate_progress(generation))

        try:
            # Default to image generation (Gemini).
            generated_bytes_list, metrics = await generate_image_task(
                prompt=generation_prompt,
                media_group_bytes=media_group_bytes,
                aspect_ratio=generation.aspect_ratio,
                quality=generation.quality,
                gemini=self.gemini,
            )

            # Copy API-side metrics onto the generation record.
            generation.api_execution_time_seconds = metrics.get("api_execution_time_seconds")
            generation.token_usage = metrics.get("token_usage")
            generation.input_token_usage = metrics.get("input_token_usage")
            generation.output_token_usage = metrics.get("output_token_usage")

        except GoogleGenerationException as e:
            generation.status = GenerationStatus.FAILED
            generation.failed_reason = str(e)
            generation.updated_at = datetime.now(UTC)
            await self.dao.generations.update_generation(generation)
            raise e
        except Exception as e:
            # Fix: was logging.error (root logger); use the module logger.
            logger.error(f"Generation failed: {e}")
            generation.status = GenerationStatus.FAILED
            generation.failed_reason = str(e)
            generation.updated_at = datetime.now(UTC)
            await self.dao.generations.update_generation(generation)
            raise e
        finally:
            # Always stop the progress simulator, success or failure.
            if not progress_task.done():
                progress_task.cancel()
                try:
                    await progress_task
                except asyncio.CancelledError:
                    pass

        # Persist each generated image as a new Asset.
        # Import hoisted out of the per-image loop (was re-imported each pass).
        from utils.image_utils import create_thumbnail

        created_assets: List[Asset] = []
        for img_bytes in generated_bytes_list:
            # Thumbnail creation is CPU-bound; keep it off the event loop.
            thumbnail_bytes = await asyncio.to_thread(create_thumbnail, img_bytes)

            new_asset = Asset(
                name=f"Generated_{generation.linked_character_id}_{random.randint(1000, 9999)}",
                type=AssetType.IMAGE,
                linked_char_id=generation.linked_character_id,  # bind to character if any
                data=img_bytes,
                thumbnail=thumbnail_bytes,
                # Remaining fields fall back to model defaults (created_at, ...).
            )

            # Store in the DB and keep the DB-assigned id.
            asset_id = await self.dao.assets.create_asset(new_asset)
            new_asset.id = str(asset_id)

            created_assets.append(new_asset)

        # Link results back to the generation record and mark it done.
        result_ids = [a.id for a in created_assets]

        generation.assets_list = result_ids
        generation.status = GenerationStatus.DONE
        generation.progress = 100
        generation.updated_at = datetime.now(UTC)
        generation.tech_prompt = generation_prompt

        end_time = datetime.now()
        generation.execution_time_seconds = (end_time - start_time).total_seconds()

        logger.info(f"DEBUG: Saving generation {generation.id}. Metrics: api_exec={generation.api_execution_time_seconds}, tokens={generation.token_usage}, in_tokens={generation.input_token_usage}, out_tokens={generation.output_token_usage}, exec={generation.execution_time_seconds}")

        await self.dao.generations.update_generation(generation)
        logger.info(f"Generation {generation.id} completed successfully. {len(created_assets)} assets created. Total Time: {generation.execution_time_seconds:.2f}s")

        # Deliver results to Telegram when requested and a bot is configured.
        if generation.telegram_id and self.bot:
            try:
                for asset in created_assets:
                    if asset.data:
                        await self.bot.send_photo(
                            chat_id=generation.telegram_id,
                            photo=BufferedInputFile(asset.data, filename=f"{asset.name}.jpg"),
                            caption=f"Generated from prompt: {generation.prompt[:100]}..."
                        )
                logger.info(f"Sent {len(created_assets)} assets to Telegram ID: {generation.telegram_id}")
            except Exception as e:
                # Delivery is best-effort: the generation itself already succeeded.
                logger.error(f"Failed to send assets to Telegram ID {generation.telegram_id}: {e}")

    async def _simulate_progress(self, generation: Generation):
        """Increment progress from 0 towards 90 in random steps (~every 4s).

        Runs until cancelled by create_generation; the final jump to 100
        happens there on success.
        """
        current_progress = 0
        try:
            while current_progress < 90:
                await asyncio.sleep(4)
                # Random increment between 5 and 15, capped at 90.
                increment = random.randint(5, 15)
                current_progress = min(current_progress + increment, 90)

                # NOTE(review): this overwrites the whole record with our
                # possibly-stale object; a fetch-update-save or partial
                # update would be safer if the DAO supports it.
                generation.progress = current_progress
                await self.dao.generations.update_generation(generation)
        except asyncio.CancelledError:
            # Task cancelled: the real generation finished (or failed).
            pass
        except Exception as e:
            logger.error(f"Error in progress simulation: {e}")
|