models + refactor

This commit is contained in:
xds
2026-03-17 16:46:32 +03:00
parent e011805186
commit 14f9e7b7e9
15 changed files with 979 additions and 83 deletions

View File

@@ -7,6 +7,7 @@ from pydantic import BaseModel
from repos.user_repo import UsersRepo, UserStatus
from api.dependency import get_dao
from repos.dao import DAO
from models.Settings import SystemSettings
from utils.security import verify_password, create_access_token, ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, SECRET_KEY
from jose import JWTError, jwt
from starlette.requests import Request
@@ -96,3 +97,21 @@ async def deny_user(
await repo.deny_user(username)
return {"message": f"User {username} denied"}
@router.get("/settings", response_model=SystemSettings)
async def get_settings(
    admin: Annotated[dict, Depends(get_current_admin)],
    dao: Annotated[DAO, Depends(get_dao)]
):
    """Return the current system-wide settings.

    Access is restricted to admins via the get_current_admin dependency.
    """
    current = await dao.settings.get_settings()
    return current
@router.post("/settings")
async def update_settings(
    settings: SystemSettings,
    admin: Annotated[dict, Depends(get_current_admin)],
    dao: Annotated[DAO, Depends(get_dao)]
):
    """Persist updated system settings.

    Responds with HTTP 500 when the settings repository reports a failed write.
    """
    if not await dao.settings.update_settings(settings):
        raise HTTPException(status_code=500, detail="Failed to update settings")
    return {"message": "Settings updated successfully"}

View File

@@ -12,6 +12,7 @@ from aiogram.types import BufferedInputFile
from adapters.Exception import GoogleGenerationException
from adapters.google_adapter import GoogleAdapter
from adapters.ai_proxy_adapter import AIProxyAdapter
from adapters.s3_adapter import S3Adapter
from api.models import (
FinancialReport, UsageStats, UsageByEntity,
@@ -72,6 +73,7 @@ class GenerationService:
# Wire up the service's collaborators: data access (DAO), the Gemini adapter,
# an AI-proxy adapter, S3 storage, and an optional Telegram bot.
def __init__(self, dao: DAO, gemini: GoogleAdapter, s3_adapter: S3Adapter, bot: Optional[Bot] = None):
self.dao = dao
self.gemini = gemini
# The proxy adapter is constructed unconditionally; whether it is used
# is decided per call by the persisted settings.use_ai_proxy flag.
self.ai_proxy = AIProxyAdapter()
self.s3_adapter = s3_adapter
self.bot = bot
@@ -84,12 +86,19 @@ class GenerationService:
"Understand user needs and generate best variation of prompt. ANSWER ONLY PROMPT STRING!!! "
f"USER_ENTERED_PROMPT: {prompt}"
)
assets_data = []
if assets:
assets_db = await self.dao.assets.get_assets_by_ids(assets)
assets_data.extend(asset.data for asset in assets_db if asset.data)
generated_prompt = await asyncio.to_thread(self.gemini.generate_text, future_prompt, model, assets_data)
settings = await self.dao.settings.get_settings()
if settings.use_ai_proxy:
asset_urls = await self._prepare_asset_urls(assets) if assets else None
generated_prompt = await self.ai_proxy.generate_text(future_prompt, model, asset_urls)
else:
assets_data = []
if assets:
assets_db = await self.dao.assets.get_assets_by_ids(assets)
assets_data.extend(asset.data for asset in assets_db if asset.data)
generated_prompt = await asyncio.to_thread(self.gemini.generate_text, future_prompt, model, assets_data)
logger.info(f"Prompt Assistant: {generated_prompt}")
return generated_prompt
@@ -99,6 +108,15 @@ class GenerationService:
technical_prompt += f"User also provided this context: {user_prompt}. "
technical_prompt += "Provide ONLY the detailed prompt."
settings = await self.dao.settings.get_settings()
if settings.use_ai_proxy:
# Proxy doesn't support raw bytes currently.
# In a real scenario we'd upload them to a temp bucket.
# For now, we call the proxy with just the prompt,
# or we could fall back to GoogleAdapter if images are essential.
# To be safe and follow instructions to use proxy, we use it.
return await self.ai_proxy.generate_text(prompt=technical_prompt, model=model)
return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, model=model, images_list=images)
async def get_generations(self, **kwargs) -> GenerationsResponse:
@@ -154,19 +172,36 @@ class GenerationService:
logger.info(f"Processing generation {generation.id}. Character ID: {generation.linked_character_id}")
# 1. Prepare input
media_group_bytes, generation_prompt = await self._prepare_generation_input(generation)
media_group_bytes, generation_prompt, asset_ids = await self._prepare_generation_input(generation)
# 2. Run generation with progress simulation
progress_task = asyncio.create_task(self._simulate_progress(generation))
try:
generated_bytes_list, metrics = await generate_image_task(
prompt=generation_prompt,
media_group_bytes=media_group_bytes,
aspect_ratio=generation.aspect_ratio,
quality=generation.quality,
model=generation.model or "gemini-3-pro-image-preview",
gemini=self.gemini
)
settings = await self.dao.settings.get_settings()
if settings.use_ai_proxy:
asset_urls = await self._prepare_asset_urls(asset_ids) if asset_ids else None
generated_images_io, metrics = await self.ai_proxy.generate_image(
prompt=generation_prompt,
aspect_ratio=generation.aspect_ratio,
quality=generation.quality,
model=generation.model or "gemini-3-pro-image-preview",
asset_urls=asset_urls
)
generated_bytes_list = []
for img_io in generated_images_io:
img_io.seek(0)
generated_bytes_list.append(img_io.read())
img_io.close()
else:
generated_bytes_list, metrics = await generate_image_task(
prompt=generation_prompt,
media_group_bytes=media_group_bytes,
aspect_ratio=generation.aspect_ratio,
quality=generation.quality,
model=generation.model or "gemini-3-pro-image-preview",
gemini=self.gemini
)
self._update_generation_metrics(generation, metrics)
# 3. Process results
@@ -299,36 +334,39 @@ class GenerationService:
await self._handle_generation_failure(gen, e)
logger.exception(f"Background generation task failed for ID: {gen.id}")
async def _prepare_generation_input(self, generation: Generation) -> Tuple[List[bytes], str]:
async def _prepare_generation_input(self, generation: Generation) -> Tuple[List[bytes], str, List[str]]:
media_group_bytes: List[bytes] = []
prompt = generation.prompt
asset_ids = []
# 1. Character Avatar
if generation.linked_character_id:
char_info = await self.dao.chars.get_character(generation.linked_character_id)
if not char_info:
raise ValueError(f"Character {generation.linked_character_id} not found")
if generation.use_profile_image and char_info.avatar_asset_id:
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
if avatar_asset:
data = await self._get_asset_data_bytes(avatar_asset)
if data: media_group_bytes.append(data)
asset_ids.append(char_info.avatar_asset_id)
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id, with_data=True)
if avatar_asset and avatar_asset.content_type == AssetContentType.IMAGE and avatar_asset.data:
media_group_bytes.append(avatar_asset.data)
# 2. Reference Assets
if generation.assets_list:
asset_ids.extend(generation.assets_list)
assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
for asset in assets:
data = await self._get_asset_data_bytes(asset)
data = await self._load_asset_image_data(asset)
if data: media_group_bytes.append(data)
# 3. Environment Assets
if generation.environment_id:
env = await self.dao.environments.get_env(generation.environment_id)
if env and env.asset_ids:
asset_ids.extend(env.asset_ids)
env_assets = await self.dao.assets.get_assets_by_ids(env.asset_ids)
for asset in env_assets:
data = await self._get_asset_data_bytes(asset)
data = await self._load_asset_image_data(asset)
if data: media_group_bytes.append(data)
if media_group_bytes:
@@ -337,14 +375,26 @@ class GenerationService:
"character's facial features and hair, environment or clothes. Maintain high fidelity to the reference identity."
)
return media_group_bytes, prompt
return media_group_bytes, prompt, asset_ids
async def _get_asset_data_bytes(self, asset: Asset) -> Optional[bytes]:
async def _prepare_asset_urls(self, asset_ids: List[str]) -> List[str]:
    """Resolve asset ids into "<bucket>/<object>" locators for the AI proxy.

    Assets without a MinIO object name are skipped; a missing bucket falls
    back to the S3 adapter's default bucket.
    """
    fetched = await self.dao.assets.get_assets_by_ids(asset_ids)
    return [
        f"{a.minio_bucket or self.s3_adapter.bucket_name}/{a.minio_object_name}"
        for a in fetched
        if a.minio_object_name
    ]
async def _load_asset_image_data(self, asset: Asset) -> Optional[bytes]:
    """Return raw image bytes for *asset*, or None when unavailable.

    Non-image assets yield None. Bytes already attached to the asset are
    preferred; otherwise the object is fetched from storage by its MinIO
    object name (covers assets loaded without data, e.g. via
    get_assets_by_ids).
    """
    if asset.content_type == AssetContentType.IMAGE:
        if asset.data:
            return asset.data
        if asset.minio_object_name:
            return await self.s3_adapter.get_file(asset.minio_object_name)
    return None
def _update_generation_metrics(self, generation: Generation, metrics: Dict[str, Any]):
generation.api_execution_time_seconds = metrics.get("api_execution_time_seconds")