models + refactor
This commit is contained in:
@@ -12,6 +12,7 @@ from aiogram.types import BufferedInputFile
|
||||
|
||||
from adapters.Exception import GoogleGenerationException
|
||||
from adapters.google_adapter import GoogleAdapter
|
||||
from adapters.ai_proxy_adapter import AIProxyAdapter
|
||||
from adapters.s3_adapter import S3Adapter
|
||||
from api.models import (
|
||||
FinancialReport, UsageStats, UsageByEntity,
|
||||
@@ -72,6 +73,7 @@ class GenerationService:
|
||||
def __init__(self, dao: DAO, gemini: GoogleAdapter, s3_adapter: S3Adapter, bot: Optional[Bot] = None):
|
||||
self.dao = dao
|
||||
self.gemini = gemini
|
||||
self.ai_proxy = AIProxyAdapter()
|
||||
self.s3_adapter = s3_adapter
|
||||
self.bot = bot
|
||||
|
||||
@@ -84,12 +86,19 @@ class GenerationService:
|
||||
"Understand user needs and generate best variation of prompt. ANSWER ONLY PROMPT STRING!!! "
|
||||
f"USER_ENTERED_PROMPT: {prompt}"
|
||||
)
|
||||
assets_data = []
|
||||
if assets:
|
||||
assets_db = await self.dao.assets.get_assets_by_ids(assets)
|
||||
assets_data.extend(asset.data for asset in assets_db if asset.data)
|
||||
|
||||
generated_prompt = await asyncio.to_thread(self.gemini.generate_text, future_prompt, model, assets_data)
|
||||
settings = await self.dao.settings.get_settings()
|
||||
if settings.use_ai_proxy:
|
||||
asset_urls = await self._prepare_asset_urls(assets) if assets else None
|
||||
generated_prompt = await self.ai_proxy.generate_text(future_prompt, model, asset_urls)
|
||||
else:
|
||||
assets_data = []
|
||||
if assets:
|
||||
assets_db = await self.dao.assets.get_assets_by_ids(assets)
|
||||
assets_data.extend(asset.data for asset in assets_db if asset.data)
|
||||
|
||||
generated_prompt = await asyncio.to_thread(self.gemini.generate_text, future_prompt, model, assets_data)
|
||||
|
||||
logger.info(f"Prompt Assistant: {generated_prompt}")
|
||||
return generated_prompt
|
||||
|
||||
@@ -99,6 +108,15 @@ class GenerationService:
|
||||
technical_prompt += f"User also provided this context: {user_prompt}. "
|
||||
technical_prompt += "Provide ONLY the detailed prompt."
|
||||
|
||||
settings = await self.dao.settings.get_settings()
|
||||
if settings.use_ai_proxy:
|
||||
# Proxy doesn't support raw bytes currently.
|
||||
# In a real scenario we'd upload them to a temp bucket.
|
||||
# For now, we call the proxy with just the prompt,
|
||||
# or we could fall back to GoogleAdapter if images are essential.
|
||||
# To be safe and follow instructions to use proxy, we use it.
|
||||
return await self.ai_proxy.generate_text(prompt=technical_prompt, model=model)
|
||||
|
||||
return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, model=model, images_list=images)
|
||||
|
||||
async def get_generations(self, **kwargs) -> GenerationsResponse:
|
||||
@@ -154,19 +172,36 @@ class GenerationService:
|
||||
logger.info(f"Processing generation {generation.id}. Character ID: {generation.linked_character_id}")
|
||||
|
||||
# 1. Prepare input
|
||||
media_group_bytes, generation_prompt = await self._prepare_generation_input(generation)
|
||||
media_group_bytes, generation_prompt, asset_ids = await self._prepare_generation_input(generation)
|
||||
|
||||
# 2. Run generation with progress simulation
|
||||
progress_task = asyncio.create_task(self._simulate_progress(generation))
|
||||
try:
|
||||
generated_bytes_list, metrics = await generate_image_task(
|
||||
prompt=generation_prompt,
|
||||
media_group_bytes=media_group_bytes,
|
||||
aspect_ratio=generation.aspect_ratio,
|
||||
quality=generation.quality,
|
||||
model=generation.model or "gemini-3-pro-image-preview",
|
||||
gemini=self.gemini
|
||||
)
|
||||
settings = await self.dao.settings.get_settings()
|
||||
if settings.use_ai_proxy:
|
||||
asset_urls = await self._prepare_asset_urls(asset_ids) if asset_ids else None
|
||||
generated_images_io, metrics = await self.ai_proxy.generate_image(
|
||||
prompt=generation_prompt,
|
||||
aspect_ratio=generation.aspect_ratio,
|
||||
quality=generation.quality,
|
||||
model=generation.model or "gemini-3-pro-image-preview",
|
||||
asset_urls=asset_urls
|
||||
)
|
||||
generated_bytes_list = []
|
||||
for img_io in generated_images_io:
|
||||
img_io.seek(0)
|
||||
generated_bytes_list.append(img_io.read())
|
||||
img_io.close()
|
||||
else:
|
||||
generated_bytes_list, metrics = await generate_image_task(
|
||||
prompt=generation_prompt,
|
||||
media_group_bytes=media_group_bytes,
|
||||
aspect_ratio=generation.aspect_ratio,
|
||||
quality=generation.quality,
|
||||
model=generation.model or "gemini-3-pro-image-preview",
|
||||
gemini=self.gemini
|
||||
)
|
||||
|
||||
self._update_generation_metrics(generation, metrics)
|
||||
|
||||
# 3. Process results
|
||||
@@ -299,36 +334,39 @@ class GenerationService:
|
||||
await self._handle_generation_failure(gen, e)
|
||||
logger.exception(f"Background generation task failed for ID: {gen.id}")
|
||||
|
||||
async def _prepare_generation_input(self, generation: Generation) -> Tuple[List[bytes], str]:
|
||||
async def _prepare_generation_input(self, generation: Generation) -> Tuple[List[bytes], str, List[str]]:
|
||||
media_group_bytes: List[bytes] = []
|
||||
prompt = generation.prompt
|
||||
asset_ids = []
|
||||
|
||||
# 1. Character Avatar
|
||||
if generation.linked_character_id:
|
||||
char_info = await self.dao.chars.get_character(generation.linked_character_id)
|
||||
if not char_info:
|
||||
raise ValueError(f"Character {generation.linked_character_id} not found")
|
||||
|
||||
|
||||
if generation.use_profile_image and char_info.avatar_asset_id:
|
||||
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
|
||||
if avatar_asset:
|
||||
data = await self._get_asset_data_bytes(avatar_asset)
|
||||
if data: media_group_bytes.append(data)
|
||||
asset_ids.append(char_info.avatar_asset_id)
|
||||
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id, with_data=True)
|
||||
if avatar_asset and avatar_asset.content_type == AssetContentType.IMAGE and avatar_asset.data:
|
||||
media_group_bytes.append(avatar_asset.data)
|
||||
|
||||
# 2. Reference Assets
|
||||
if generation.assets_list:
|
||||
asset_ids.extend(generation.assets_list)
|
||||
assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
|
||||
for asset in assets:
|
||||
data = await self._get_asset_data_bytes(asset)
|
||||
data = await self._load_asset_image_data(asset)
|
||||
if data: media_group_bytes.append(data)
|
||||
|
||||
# 3. Environment Assets
|
||||
if generation.environment_id:
|
||||
env = await self.dao.environments.get_env(generation.environment_id)
|
||||
if env and env.asset_ids:
|
||||
asset_ids.extend(env.asset_ids)
|
||||
env_assets = await self.dao.assets.get_assets_by_ids(env.asset_ids)
|
||||
for asset in env_assets:
|
||||
data = await self._get_asset_data_bytes(asset)
|
||||
data = await self._load_asset_image_data(asset)
|
||||
if data: media_group_bytes.append(data)
|
||||
|
||||
if media_group_bytes:
|
||||
@@ -337,14 +375,26 @@ class GenerationService:
|
||||
"character's facial features and hair, environment or clothes. Maintain high fidelity to the reference identity."
|
||||
)
|
||||
|
||||
return media_group_bytes, prompt
|
||||
return media_group_bytes, prompt, asset_ids
|
||||
|
||||
async def _get_asset_data_bytes(self, asset: Asset) -> Optional[bytes]:
|
||||
async def _prepare_asset_urls(self, asset_ids: List[str]) -> List[str]:
|
||||
assets = await self.dao.assets.get_assets_by_ids(asset_ids)
|
||||
urls = []
|
||||
for asset in assets:
|
||||
if asset.minio_object_name:
|
||||
bucket = asset.minio_bucket or self.s3_adapter.bucket_name
|
||||
urls.append(f"{bucket}/{asset.minio_object_name}")
|
||||
return urls
|
||||
|
||||
async def _load_asset_image_data(self, asset: Asset) -> Optional[bytes]:
|
||||
"""Load image bytes for an asset that was fetched without data (e.g. from get_assets_by_ids)."""
|
||||
if asset.content_type != AssetContentType.IMAGE:
|
||||
return None
|
||||
if asset.data:
|
||||
return asset.data
|
||||
if asset.minio_object_name:
|
||||
return await self.s3_adapter.get_file(asset.minio_object_name)
|
||||
return asset.data
|
||||
return None
|
||||
|
||||
def _update_generation_metrics(self, generation: Generation, metrics: Dict[str, Any]):
|
||||
generation.api_execution_time_seconds = metrics.get("api_execution_time_seconds")
|
||||
|
||||
Reference in New Issue
Block a user