Files
ai-char-bot/api/service/generation_service.py
2026-02-04 16:45:25 +03:00

190 lines
8.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
import logging
import random
from datetime import datetime, UTC
from typing import List, Optional
from io import BytesIO
from adapters.google_adapter import GoogleAdapter
from api.models.GenerationRequest import GenerationRequest, GenerationResponse
# Импортируйте ваши модели DAO, Asset, Generation корректно
from models.Asset import Asset, AssetType
from models.Generation import Generation, GenerationStatus
from models.enums import AspectRatios, Quality
from repos.dao import DAO
logger = logging.getLogger(__name__)
# --- Вспомогательная функция генерации ---
async def generate_image_task(
prompt: str,
media_group_bytes: List[bytes],
aspect_ratio: AspectRatios,
quality: Quality,
gemini: GoogleAdapter
) -> List[bytes]:
"""
Обертка для вызова синхронного метода Gemini в отдельном потоке.
Возвращает список байтов сгенерированных изображений.
"""
# Запускаем блокирующую операцию в отдельном потоке, чтобы не тормозить Event Loop
generated_images_io: List[BytesIO] = await asyncio.to_thread(
gemini.generate_image,
prompt=prompt,
images_list=media_group_bytes,
aspect_ratio=aspect_ratio,
quality=quality,
)
images_bytes = []
if generated_images_io:
for img_io in generated_images_io:
# Читаем байты из BytesIO
img_io.seek(0)
content = img_io.read()
images_bytes.append(content)
# Закрываем поток
img_io.close()
return images_bytes
class GenerationService:
def __init__(self, dao: DAO, gemini: GoogleAdapter):
self.dao = dao
self.gemini = gemini
async def ask_prompt_assistant(self, prompt: str, assets: List[str] = None) -> str:
future_prompt = """You are an prompt-assistant. You improving user-entered prompts for image generation. User may upload reference image too.
I will provide sources prompt entered by user. Understand user needs and generate best variation of prompt.
ANSWER ONLY PROMPT STRING!!! USER_ENTERED_PROMPT:"""
future_prompt += prompt
assets_data = []
if assets is not None:
assets_db = await self.dao.assets.get_assets_by_ids(assets)
assets_data.extend(asset.data for asset in assets_db)
generated_prompt = self.gemini.generate_text(future_prompt, assets_data)
logger.info(future_prompt)
logger.info(generated_prompt)
return generated_prompt
async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0) -> List[
Generation]:
return await self.dao.generations.get_generations(limit=limit, offset=offset)
async def get_generation(self, generation_id: str) -> Optional[GenerationResponse]:
gen = await self.dao.generations.get_generation(generation_id)
if gen is None:
return None
else:
return GenerationResponse(**gen.model_dump())
async def get_running_generations(self) -> List[Generation]:
return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING)
async def create_generation_task(self, generation_request: GenerationRequest) -> GenerationResponse:
gen_id = None
generation_model = None
try:
generation_model = Generation(**generation_request.model_dump())
gen_id = await self.dao.generations.create_generation(generation_model)
generation_model.id = gen_id
async def runner(gen):
try:
await self.create_generation(gen)
except Exception:
# если генерация уже пошла и упала — пометим FAILED
try:
db_gen = await self.dao.generations.get_generation(gen.id)
db_gen.status = GenerationStatus.FAILED
await self.dao.generations.update_generation(db_gen)
except Exception:
logger.exception("Failed to mark generation as FAILED")
logger.exception("create_generation task failed")
asyncio.create_task(runner(generation_model))
return GenerationResponse(**generation_model.model_dump())
except Exception:
# если не успели создать запись — нечего помечать
if gen_id is not None:
try:
gen = await self.dao.generations.get_generation(gen_id)
gen.status = GenerationStatus.FAILED
await self.dao.generations.update_generation(gen)
except Exception:
logger.exception("Failed to mark generation as FAILED in create_generation_task")
raise
async def create_generation(self, generation: Generation):
# 2. Получаем ассеты-референсы (если они есть)
reference_assets: List[Asset] = []
media_group_bytes: List[bytes] = []
generation_prompt = "You are creating image. "
if generation.linked_character_id is not None:
char_info = await self.dao.chars.get_character(generation.linked_character_id, with_image_data=True)
if char_info is None:
raise Exception(f"Character ID {generation.linked_character_id} not found")
media_group_bytes.append(char_info.character_image_data)
generation_prompt = f"""You are creating image for {char_info.character_bio}"""
reference_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
# Извлекаем данные (bytes) из ассетов для отправки в Gemini
# Фильтруем, чтобы отправлять только картинки, и где есть data
media_group_bytes.extend(
asset.data
for asset in reference_assets
if asset.data is not None and asset.type == AssetType.IMAGE
)
generation_prompt+=f"PROMPT: {generation.prompt}"
# 3. Запускаем процесс генерации
try:
generated_bytes_list = await generate_image_task(
prompt=generation_prompt, # или request.prompt
media_group_bytes=media_group_bytes,
aspect_ratio=generation.aspect_ratio, # предполагаем поля в request
quality=generation.quality,
gemini=self.gemini
)
except Exception as e:
# Тут стоит добавить логирование ошибки
logging.error(f"Generation failed: {e}")
# Можно обновить статус генерации на FAILED в БД
raise e
# 4. Сохраняем полученные изображения как новые Ассеты
created_assets: List[Asset] = []
for idx, img_bytes in enumerate(generated_bytes_list):
new_asset = Asset(
name=f"Generated_{generation.linked_character_id}_{random.randint(1000, 9999)}",
type=AssetType.IMAGE,
linked_char_id=generation.linked_character_id, # Если генерация привязана к персонажу
data=img_bytes,
# Остальные поля заполнятся дефолтными значениями (created_at)
)
# Сохраняем в БД
asset_id = await self.dao.assets.create_asset(new_asset)
new_asset.id = str(asset_id) # Присваиваем ID, полученный от базы
created_assets.append(new_asset)
# 5. (Опционально) Обновляем запись генерации ссылками на результаты
# Предполагаем, что у модели Generation есть поле result_asset_ids
result_ids = [a.id for a in created_assets]
generation.assets_list = result_ids
generation.status = GenerationStatus.DONE
generation.updated_at = datetime.now(UTC)
generation.tech_prompt = generation_prompt
await self.dao.generations.update_generation(generation)