+ api
This commit is contained in:
186
api/service/generation_service.py
Normal file
186
api/service/generation_service.py
Normal file
@@ -0,0 +1,186 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
from datetime import datetime, UTC
|
||||
from typing import List, Optional
|
||||
from io import BytesIO
|
||||
|
||||
from adapters.google_adapter import GoogleAdapter
|
||||
from api.models.GenerationRequest import GenerationRequest, GenerationResponse
|
||||
# Импортируйте ваши модели DAO, Asset, Generation корректно
|
||||
from models.Asset import Asset, AssetType
|
||||
from models.Generation import Generation, GenerationStatus
|
||||
from models.enums import AspectRatios, Quality
|
||||
from repos.dao import DAO
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# --- Вспомогательная функция генерации ---
|
||||
async def generate_image_task(
|
||||
prompt: str,
|
||||
media_group_bytes: List[bytes],
|
||||
aspect_ratio: AspectRatios,
|
||||
quality: Quality,
|
||||
gemini: GoogleAdapter
|
||||
) -> List[bytes]:
|
||||
"""
|
||||
Обертка для вызова синхронного метода Gemini в отдельном потоке.
|
||||
Возвращает список байтов сгенерированных изображений.
|
||||
"""
|
||||
|
||||
# Запускаем блокирующую операцию в отдельном потоке, чтобы не тормозить Event Loop
|
||||
generated_images_io: List[BytesIO] = await asyncio.to_thread(
|
||||
gemini.generate_image,
|
||||
prompt=prompt,
|
||||
images_list=media_group_bytes,
|
||||
aspect_ratio=aspect_ratio,
|
||||
quality=quality,
|
||||
)
|
||||
|
||||
images_bytes = []
|
||||
if generated_images_io:
|
||||
for img_io in generated_images_io:
|
||||
# Читаем байты из BytesIO
|
||||
img_io.seek(0)
|
||||
content = img_io.read()
|
||||
images_bytes.append(content)
|
||||
|
||||
# Закрываем поток
|
||||
img_io.close()
|
||||
|
||||
return images_bytes
|
||||
|
||||
|
||||
class GenerationService:
|
||||
def __init__(self, dao: DAO, gemini: GoogleAdapter):
|
||||
self.dao = dao
|
||||
self.gemini = gemini
|
||||
|
||||
async def ask_prompt_assistant(self, prompt: str, assets: List[str] = None) -> str:
|
||||
future_prompt = """You are an prompt-assistant. You improving user-entered prompts for image generation. User may upload reference image too.
|
||||
I will provide sources prompt entered by user. Understand user needs and generate best variation of prompt.
|
||||
ANSWER ONLY PROMPT STRING!!! USER_ENTERED_PROMPT:"""
|
||||
future_prompt += prompt
|
||||
assets_data = []
|
||||
if assets is not None:
|
||||
assets_data = await self.dao.assets.get_assets_by_ids(assets)
|
||||
generated_prompt = self.gemini.generate_text(future_prompt, assets_data)
|
||||
return generated_prompt
|
||||
|
||||
async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0) -> List[
|
||||
Generation]:
|
||||
return await self.dao.generations.get_generations(limit=limit, offset=offset)
|
||||
|
||||
async def get_generation(self, generation_id: str) -> Optional[GenerationResponse]:
|
||||
gen = await self.dao.generations.get_generation(generation_id)
|
||||
if gen is None:
|
||||
return None
|
||||
else:
|
||||
return GenerationResponse(**gen.model_dump())
|
||||
|
||||
async def get_running_generations(self) -> List[Generation]:
|
||||
return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING)
|
||||
|
||||
async def create_generation_task(self, generation_request: GenerationRequest) -> GenerationResponse:
|
||||
gen_id = None
|
||||
generation_model = None
|
||||
|
||||
try:
|
||||
generation_model = Generation(**generation_request.model_dump())
|
||||
gen_id = await self.dao.generations.create_generation(generation_model)
|
||||
generation_model.id = gen_id
|
||||
|
||||
async def runner(gen):
|
||||
try:
|
||||
await self.create_generation(gen)
|
||||
except Exception:
|
||||
# если генерация уже пошла и упала — пометим FAILED
|
||||
try:
|
||||
db_gen = await self.dao.generations.get_generation(gen.id)
|
||||
db_gen.status = GenerationStatus.FAILED
|
||||
await self.dao.generations.update_generation(db_gen)
|
||||
except Exception:
|
||||
logger.exception("Failed to mark generation as FAILED")
|
||||
logger.exception("create_generation task failed")
|
||||
|
||||
asyncio.create_task(runner(generation_model))
|
||||
|
||||
return GenerationResponse(**generation_model.model_dump())
|
||||
|
||||
except Exception:
|
||||
# если не успели создать запись — нечего помечать
|
||||
if gen_id is not None:
|
||||
try:
|
||||
gen = await self.dao.generations.get_generation(gen_id)
|
||||
gen.status = GenerationStatus.FAILED
|
||||
await self.dao.generations.update_generation(gen)
|
||||
except Exception:
|
||||
logger.exception("Failed to mark generation as FAILED in create_generation_task")
|
||||
raise
|
||||
|
||||
async def create_generation(self, generation: Generation):
|
||||
|
||||
# 2. Получаем ассеты-референсы (если они есть)
|
||||
reference_assets: List[Asset] = []
|
||||
media_group_bytes: List[bytes] = []
|
||||
generation_prompt = "You are creating image. "
|
||||
if generation.linked_character_id is not None:
|
||||
char_info = await self.dao.chars.get_character(generation.linked_character_id, with_image_data=True)
|
||||
if char_info is None:
|
||||
raise Exception(f"Character ID {generation.linked_character_id} not found")
|
||||
media_group_bytes.append(char_info.character_image_data)
|
||||
generation_prompt = f"""You are creating image for {char_info.character_bio}"""
|
||||
|
||||
reference_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
|
||||
# Извлекаем данные (bytes) из ассетов для отправки в Gemini
|
||||
# Фильтруем, чтобы отправлять только картинки, и где есть data
|
||||
media_group_bytes.extend(
|
||||
asset.data
|
||||
for asset in reference_assets
|
||||
if asset.data is not None and asset.type == AssetType.IMAGE
|
||||
)
|
||||
generation_prompt+=f"PROMPT: {generation.prompt}"
|
||||
|
||||
# 3. Запускаем процесс генерации
|
||||
try:
|
||||
generated_bytes_list = await generate_image_task(
|
||||
prompt=generation_prompt, # или request.prompt
|
||||
media_group_bytes=media_group_bytes,
|
||||
aspect_ratio=generation.aspect_ratio, # предполагаем поля в request
|
||||
quality=generation.quality,
|
||||
gemini=self.gemini
|
||||
)
|
||||
except Exception as e:
|
||||
# Тут стоит добавить логирование ошибки
|
||||
logging.error(f"Generation failed: {e}")
|
||||
# Можно обновить статус генерации на FAILED в БД
|
||||
raise e
|
||||
|
||||
# 4. Сохраняем полученные изображения как новые Ассеты
|
||||
created_assets: List[Asset] = []
|
||||
|
||||
for idx, img_bytes in enumerate(generated_bytes_list):
|
||||
new_asset = Asset(
|
||||
name=f"Generated_{generation.linked_character_id}_{random.randint(1000, 9999)}",
|
||||
type=AssetType.IMAGE,
|
||||
linked_char_id=generation.linked_character_id, # Если генерация привязана к персонажу
|
||||
data=img_bytes,
|
||||
# Остальные поля заполнятся дефолтными значениями (created_at)
|
||||
)
|
||||
|
||||
# Сохраняем в БД
|
||||
asset_id = await self.dao.assets.create_asset(new_asset)
|
||||
new_asset.id = str(asset_id) # Присваиваем ID, полученный от базы
|
||||
|
||||
created_assets.append(new_asset)
|
||||
|
||||
# 5. (Опционально) Обновляем запись генерации ссылками на результаты
|
||||
# Предполагаем, что у модели Generation есть поле result_asset_ids
|
||||
result_ids = [a.id for a in created_assets]
|
||||
|
||||
generation.assets_list = result_ids
|
||||
generation.status = GenerationStatus.DONE
|
||||
generation.updated_at = datetime.now(UTC)
|
||||
generation.tech_prompt = generation_prompt
|
||||
await self.dao.generations.update_generation(generation)
|
||||
Reference in New Issue
Block a user