This commit is contained in:
xds
2026-02-20 13:10:37 +03:00
parent 9e0c522b5f
commit 1868864f76
6 changed files with 358 additions and 478 deletions

View File

@@ -1,5 +1,4 @@
import logging
import os
import json
from typing import List, Optional
@@ -30,12 +29,22 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix='/api/generations', tags=["Generation"])
async def check_project_access(project_id: Optional[str], current_user: dict, dao: DAO):
    """Ensure the current user may act within the given project.

    A falsy project_id means personal scope — nothing to check.
    Raises HTTPException(403) when the project does not exist or the
    user's id is not among its members.
    """
    if not project_id:
        return
    project = await dao.projects.get_project(project_id)
    member_ids = project.members if project else []
    if str(current_user["_id"]) not in member_ids:
        raise HTTPException(status_code=403, detail="Project access denied")
@router.post("/prompt-assistant", response_model=PromptResponse)
async def ask_prompt_assistant(prompt_request: PromptRequest, request: Request,
generation_service: GenerationService = Depends(
get_generation_service),
current_user: dict = Depends(get_current_user)) -> PromptResponse:
logger.info(f"ask_prompt_assistant called with prompt length: {len(prompt_request.prompt)}. Linked assets: {len(prompt_request.linked_assets) if prompt_request.linked_assets else 0}")
async def ask_prompt_assistant(
    prompt_request: PromptRequest,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user)
) -> PromptResponse:
    """Improve a user-supplied prompt (optionally with linked assets)."""
    logger.info(f"ask_prompt_assistant: {len(prompt_request.prompt)} chars")
    improved = await generation_service.ask_prompt_assistant(
        prompt_request.prompt,
        prompt_request.linked_assets,
    )
    return PromptResponse(prompt=improved)
@@ -47,32 +56,33 @@ async def prompt_from_image(
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)
) -> PromptResponse:
logger.info(f"prompt_from_image called. Images count: {len(images)}. Prompt provided: {bool(prompt)}")
images_bytes = []
for image in images:
content = await image.read()
images_bytes.append(content)
images_bytes = [await img.read() for img in images]
generated_prompt = await generation_service.generate_prompt_from_images(images_bytes, prompt)
return PromptResponse(prompt=generated_prompt)
@router.get("", response_model=GenerationsResponse)
async def get_generations(character_id: Optional[str] = None, limit: int = 10, offset: int = 0,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao)):
logger.info(f"get_generations called. CharacterId: {character_id}, Limit: {limit}, Offset: {offset}")
async def get_generations(
    character_id: Optional[str] = None,
    limit: int = 10,
    offset: int = 0,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
    project_id: Optional[str] = Depends(get_project_id),
    dao: DAO = Depends(get_dao)
):
    """List generations visible to the caller.

    In project scope (project_id set) all project-wide generations are
    returned; otherwise only the current user's own generations.
    Raises HTTPException(403) via check_project_access when the user is
    not a member of the project.
    """
    await check_project_access(project_id, current_user, dao)
    # If project_id is set, we don't filter by user to show all project-wide generations
    created_by_filter = None if project_id else str(current_user["_id"])
    return await generation_service.get_generations(
        character_id=character_id,
        limit=limit,
        offset=offset,
        created_by=created_by_filter,
        project_id=project_id
    )
@router.get("/usage", response_model=FinancialReport)
@@ -83,32 +93,15 @@ async def get_usage_report(
project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao)
) -> FinancialReport:
"""
Returns usage statistics (runs, tokens, cost) for the current user or project.
If project_id is provided, returns stats for that project.
Otherwise, returns stats for the current user.
"""
user_id_filter = str(current_user["_id"])
await check_project_access(project_id, current_user, dao)
user_id_filter = str(current_user["_id"]) if not project_id else None
breakdown_by = None
if project_id:
# Permission check
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
user_id_filter = None # If we are in project, we see stats for the WHOLE project by default
if breakdown == "user":
breakdown_by = "created_by"
elif breakdown == "project":
breakdown_by = "project_id"
else:
# Default: Stats for current user
if breakdown == "project":
breakdown_by = "project_id"
elif breakdown == "user":
# This would breakdown personal usage by user (yourself), but could be useful if it included collaborators?
# No, if project_id is None, it's personal.
breakdown_by = "created_by"
if breakdown == "user":
breakdown_by = "created_by"
elif breakdown == "project":
breakdown_by = "project_id"
return await generation_service.get_financial_report(
user_id=user_id_filter,
@@ -116,58 +109,61 @@ async def get_usage_report(
breakdown_by=breakdown_by
)
@router.post("/_run", response_model=GenerationGroupResponse)
async def post_generation(generation: GenerationRequest, request: Request,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao)) -> GenerationGroupResponse:
logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}")
@router.post("/_run", response_model=GenerationGroupResponse)
async def post_generation(
    generation: GenerationRequest,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
    project_id: Optional[str] = Depends(get_project_id),
    dao: DAO = Depends(get_dao)
) -> GenerationGroupResponse:
    """Create and launch a generation group for the current user.

    Membership is verified first; the request is then stamped with the
    active project (None in personal scope) before the task is queued.
    """
    await check_project_access(project_id, current_user, dao)
    generation.project_id = project_id
    return await generation_service.create_generation_task(
        generation,
        user_id=str(current_user.get("_id"))
    )
@router.get("/running")
async def get_running_generations(request: Request,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao)):
user_id_filter = str(current_user["_id"])
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
user_id_filter = None
async def get_running_generations(
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
    project_id: Optional[str] = Depends(get_project_id),
    dao: DAO = Depends(get_dao)
):
    """List currently running generations for the user or active project.

    Project scope shows every member's running jobs; personal scope only
    the caller's. Raises HTTPException(403) on a membership failure.
    """
    await check_project_access(project_id, current_user, dao)
    # In project scope we deliberately drop the per-user filter.
    user_id_filter = None if project_id else str(current_user["_id"])
    return await generation_service.get_running_generations(
        user_id=user_id_filter,
        project_id=project_id
    )
@router.get("/group/{group_id}", response_model=GenerationGroupResponse)
async def get_generation_group(group_id: str,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)):
logger.info(f"get_generation_group called for group_id: {group_id}")
generations = await generation_service.dao.generations.get_generations_by_group(group_id)
gen_responses = [GenerationResponse(**gen.model_dump()) for gen in generations]
return GenerationGroupResponse(generation_group_id=group_id, generations=gen_responses)
async def get_generation_group(
    group_id: str,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user)
):
    """Return every generation belonging to the given group id."""
    group_response = await generation_service.get_generations_by_group(group_id)
    return group_response
@router.get("/{generation_id}", response_model=GenerationResponse)
async def get_generation(generation_id: str,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)) -> GenerationResponse:
logger.debug(f"get_generation called for ID: {generation_id}")
async def get_generation(
generation_id: str,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)
) -> GenerationResponse:
gen = await generation_service.get_generation(generation_id)
if gen and gen.created_by != str(current_user["_id"]):
if not gen:
raise HTTPException(status_code=404, detail="Generation not found")
if gen.created_by != str(current_user["_id"]):
# Check project membership
is_member = False
if gen.project_id:
@@ -180,43 +176,24 @@ async def get_generation(generation_id: str,
return gen
@router.post("/import", response_model=GenerationResponse)
async def import_external_generation(
request: Request,
generation_service: GenerationService = Depends(get_generation_service),
x_signature: str = Header(..., alias="X-Signature")
) -> GenerationResponse:
"""
Import a generation from an external source.
Requires server-to-server authentication via HMAC signature.
"""
logger.info("import_external_generation called")
# Get raw request body for signature verification
body = await request.body()
# Verify signature
secret = settings.EXTERNAL_API_SECRET
if not secret:
logger.error("EXTERNAL_API_SECRET not configured")
raise HTTPException(status_code=500, detail="Server configuration error")
if not verify_signature(body, x_signature, secret):
logger.warning("Invalid signature for external generation import")
raise HTTPException(status_code=401, detail="Invalid signature")
# Parse request body
try:
data = json.loads(body.decode('utf-8'))
external_gen = ExternalGenerationRequest(**data)
except Exception as e:
logger.error(f"Failed to parse request body: {e}")
raise HTTPException(status_code=400, detail=f"Invalid request body: {str(e)}")
# Import generation
try:
generation = await generation_service.import_external_generation(external_gen)
return GenerationResponse(**generation.model_dump())
except Exception as e:
@@ -225,11 +202,11 @@ async def import_external_generation(
@router.delete("/{generation_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_generation(generation_id: str,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)):
logger.info(f"delete_generation called for ID: {generation_id}")
deleted = await generation_service.delete_generation(generation_id)
if not deleted:
async def delete_generation(
    generation_id: str,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user)
):
    """Delete a generation; 404 when it does not exist, 204 otherwise."""
    deleted = await generation_service.delete_generation(generation_id)
    if not deleted:
        raise HTTPException(status_code=404, detail="Generation not found")
    return None

View File

@@ -14,7 +14,7 @@ class ExternalGenerationRequest(BaseModel):
image_url: Optional[str] = Field(None, description="URL to download image from")
# Generation metadata
aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN
aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN # "1:1","2:3","3:2","3:4","4:3","4:5","5:4","9:16","16:9","21:9"
quality: Quality = Quality.ONEK
# Optional linking

View File

@@ -10,7 +10,7 @@ from models.enums import AspectRatios, Quality, GenType
class GenerationRequest(BaseModel):
linked_character_id: Optional[str] = None
aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN
aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN # "1:1","2:3","3:2","3:4","4:3","4:5","5:4","9:16","16:9","21:9"
quality: Quality = Quality.ONEK
prompt: str
telegram_id: Optional[int] = None

View File

@@ -13,13 +13,15 @@ from aiogram.types import BufferedInputFile
from adapters.Exception import GoogleGenerationException
from adapters.google_adapter import GoogleAdapter
from adapters.s3_adapter import S3Adapter
from api.models import FinancialReport, UsageStats, UsageByEntity
from api.models import GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse
# Импортируйте ваши модели DAO, Asset, Generation корректно
from api.models import (
FinancialReport, UsageStats, UsageByEntity,
GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse
)
from models.Asset import Asset, AssetType, AssetContentType
from models.Generation import Generation, GenerationStatus
from models.enums import AspectRatios, Quality
from repos.dao import DAO
from utils.image_utils import create_thumbnail
logger = logging.getLogger(__name__)
@@ -27,22 +29,18 @@ logger = logging.getLogger(__name__)
generation_semaphore = asyncio.Semaphore(4)
# --- Вспомогательная функция генерации ---
async def generate_image_task(
prompt: str,
media_group_bytes: List[bytes],
aspect_ratio: AspectRatios,
quality: Quality,
gemini: GoogleAdapter,
) -> Tuple[List[bytes], Dict[str, Any]]:
"""
Обертка для вызова синхронного метода Gemini в отдельном потоке.
Возвращает список байтов сгенерированных изображений.
Wrapper for calling Gemini's synchronous method in a separate thread.
"""
try :
try:
logger.info(f"Starting generate_image_task with prompt length: {len(prompt)}")
# Запускаем блокирующую операцию в отдельном потоке, чтобы не тормозить Event Loop
result = await asyncio.to_thread(
gemini.generate_image,
prompt=prompt,
@@ -51,12 +49,10 @@ async def generate_image_task(
quality=quality,
)
generated_images_io, metrics = result
logger.info(f"generate_image_task completed, received {len(generated_images_io) if generated_images_io else 0} images")
except GoogleGenerationException as e:
raise e
except GoogleGenerationException:
raise
finally:
# Освобождаем входные данные — они больше не нужны
del media_group_bytes
images_bytes = []
@@ -65,371 +61,136 @@ async def generate_image_task(
img_io.seek(0)
images_bytes.append(img_io.read())
img_io.close()
# Освобождаем список BytesIO сразу
del generated_images_io
return images_bytes, metrics
class GenerationService:
def __init__(self, dao: DAO, gemini: GoogleAdapter, s3_adapter: S3Adapter, bot: Optional[Bot] = None):
    """Wire the service to its persistence and adapter dependencies.

    bot is optional: Telegram delivery is skipped when it is None.
    """
    # NOTE: the source span contained a duplicated `def __init__` header
    # (diff residue); only one definition is kept here.
    self.dao = dao
    self.gemini = gemini
    self.s3_adapter = s3_adapter
    self.bot = bot
# --- Public API ---
async def ask_prompt_assistant(self, prompt: str, assets: list[str] | None = None) -> str:
    """Rewrite a user prompt into a better image-generation prompt.

    Optionally attaches the binary data of the referenced assets so the
    model can use them as visual context. Returns the improved prompt
    string produced by Gemini.
    """
    future_prompt = (
        "You are an prompt-assistant. You improving user-entered prompts for image generation. "
        "User may upload reference image too. I will provide sources prompt entered by user. "
        "Understand user needs and generate best variation of prompt. ANSWER ONLY PROMPT STRING!!! "
        f"USER_ENTERED_PROMPT: {prompt}"
    )
    assets_data = []
    if assets:
        assets_db = await self.dao.assets.get_assets_by_ids(assets)
        # Skip assets whose payload is missing (e.g. stored only in S3).
        assets_data.extend(asset.data for asset in assets_db if asset.data)
    # generate_text is blocking; run it off the event loop.
    generated_prompt = await asyncio.to_thread(self.gemini.generate_text, future_prompt, assets_data)
    logger.info(f"Prompt Assistant: {generated_prompt}")
    return generated_prompt
async def generate_prompt_from_images(self, images: List[bytes], user_prompt: Optional[str] = None) -> str:
    """Ask Gemini to describe reference images as a generation prompt."""
    parts = [
        "You are a prompt engineer. Describe this image in detail to create a stable diffusion using this image as reference. "
    ]
    if user_prompt:
        parts.append(f"User also provided this context: {user_prompt}. ")
    parts.append("Provide ONLY the detailed prompt.")
    technical_prompt = "".join(parts)
    # Blocking SDK call — run it in a worker thread.
    return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, images_list=images)
async def get_generations(self, **kwargs) -> GenerationsResponse:
    """Fetch a page of generations plus the total matching count.

    kwargs are forwarded verbatim to the DAO query (character_id, limit,
    offset, created_by, project_id, idea_id); the count query reuses the
    filter keys only.
    """
    # NOTE: the source span contained both the old explicit-parameter
    # definition and this kwargs-based one (diff residue); one is kept.
    generations = await self.dao.generations.get_generations(**kwargs)
    total_count = await self.dao.generations.count_generations(
        character_id=kwargs.get('character_id'),
        created_by=kwargs.get('created_by'),
        project_id=kwargs.get('project_id'),
        idea_id=kwargs.get('idea_id')
    )
    return GenerationsResponse(
        generations=[GenerationResponse(**gen.model_dump()) for gen in generations],
        total_count=total_count
    )
async def get_generation(self, generation_id: str) -> Optional[GenerationResponse]:
    """Fetch one generation and wrap it in the API response model.

    Returns None when no generation with that id exists.
    """
    # The source span held an if/else and a duplicate ternary return
    # (diff residue); the single ternary is the surviving form.
    gen = await self.dao.generations.get_generation(generation_id)
    return GenerationResponse(**gen.model_dump()) if gen else None
async def get_generations_by_group(self, group_id: str) -> GenerationGroupResponse:
    """Return all generations of a group as a GenerationGroupResponse."""
    records = await self.dao.generations.get_generations_by_group(group_id)
    responses = [GenerationResponse(**record.model_dump()) for record in records]
    return GenerationGroupResponse(
        generation_group_id=group_id,
        generations=responses
    )
async def get_running_generations(self, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
    """List generations in RUNNING status, optionally filtered by user/project."""
    running = await self.dao.generations.get_generations(
        status=GenerationStatus.RUNNING,
        created_by=user_id,
        project_id=project_id,
    )
    return running
async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationGroupResponse:
    """Create generation_request.count generations under one group id.

    A fresh group id is minted when none is supplied. Each generation is
    persisted and its background runner scheduled via
    _create_single_generation.
    """
    # The source span contained two copies of the loop header (diff
    # residue); the count is read directly from the request.
    if generation_group_id is None:
        generation_group_id = str(uuid4())
    results = []
    for _ in range(generation_request.count):
        gen_response = await self._create_single_generation(generation_request, user_id, generation_group_id)
        results.append(gen_response)
    return GenerationGroupResponse(generation_group_id=generation_group_id, generations=results)
async def _create_single_generation(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationResponse:
    """Persist one Generation record and schedule its background run.

    Returns immediately with the freshly created record; the actual image
    generation happens in a fire-and-forget asyncio task gated by
    generation_semaphore. On failure after the DB record exists, the
    record is marked FAILED.
    """
    gen_id = None
    generation_model = None
    try:
        # 'count' belongs to the request only, not to a single generation.
        generation_model = Generation(**generation_request.model_dump(exclude={'count'}))
        if user_id:
            generation_model.created_by = user_id
        if generation_group_id:
            generation_model.generation_group_id = generation_group_id
        # Explicitly set idea_id from request if present (already in model_dump, but ensuring clarity)
        if generation_request.idea_id:
            generation_model.idea_id = generation_request.idea_id
        gen_id = await self.dao.generations.create_generation(generation_model)
        generation_model.id = gen_id

        async def runner(gen):
            # Waits for a semaphore slot, then performs the full generation.
            logger.info(f"Generation {gen.id} entered queue (waiting for slot)...")
            try:
                async with generation_semaphore:
                    logger.info(f"Starting background generation task for ID: {gen.id}")
                    await self.create_generation(gen)
                    logger.info(f"Background generation task finished for ID: {gen.id}")
            except Exception:
                # Generation already started and crashed — mark the DB record FAILED.
                try:
                    db_gen = await self.dao.generations.get_generation(gen.id)
                    if db_gen is not None:
                        db_gen.status = GenerationStatus.FAILED
                        await self.dao.generations.update_generation(db_gen)
                except Exception:
                    logger.exception("Failed to mark generation as FAILED")
                logger.exception("create_generation task failed")

        asyncio.create_task(runner(generation_model))
        return GenerationResponse(**generation_model.model_dump())
    except Exception:
        # If the record was never created there is nothing to mark.
        if gen_id is not None:
            try:
                gen = await self.dao.generations.get_generation(gen_id)
                if gen is not None:
                    gen.status = GenerationStatus.FAILED
                    await self.dao.generations.update_generation(gen)
            except Exception:
                logger.exception("Failed to mark generation as FAILED in create_generation_task")
        raise
async def create_generation(self, generation: Generation):
start_time = datetime.now()
logger.info(f"Processing generation {generation.id}. Character ID: {generation.linked_character_id}")
# 2. Получаем ассеты-референсы (если они есть)
reference_assets: List[Asset] = []
media_group_bytes: List[bytes] = []
generation_prompt = generation.prompt
# generation_prompt = f"""
# Create detailed image of character in scene.
# SCENE DESCRIPTION: {generation.prompt}
# Rules:
# - Integrate the character's appearance naturally into the scene description.
# - Focus on lighting, texture, and composition.
# """
if generation.linked_character_id is not None:
char_info = await self.dao.chars.get_character(generation.linked_character_id)
if char_info is None:
raise Exception(f"Character ID {generation.linked_character_id} not found")
if generation.use_profile_image:
if char_info.avatar_asset_id is not None:
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
if avatar_asset and avatar_asset.data:
media_group_bytes.append(avatar_asset.data)
# generation_prompt = generation_prompt.replace("$char_bio_inserted", f"1. CHARACTER BIO (Must be strictly followed): {char_info.character_bio}")
# 1. Prepare input
media_group_bytes, generation_prompt = await self._prepare_generation_input(generation)
reference_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
# Извлекаем данные (bytes) из ассетов для отправки в Gemini
for asset in reference_assets:
if asset.content_type != AssetContentType.IMAGE:
continue
img_data = None
if asset.minio_object_name:
img_data = await self.s3_adapter.get_file(asset.minio_object_name)
elif asset.data:
img_data = asset.data
if img_data:
media_group_bytes.append(img_data)
if media_group_bytes:
generation_prompt += " \n\n[Reference Image Guidance]: Use the provided image(s) as the STRICT reference for the main character's facial features and hair, enviroment or clothes. Maintain high fidelity to the reference identity."
logger.info(f"Final generation prompt assembled. Length: {len(generation_prompt)}. Media count: {len(media_group_bytes)}")
# 3. Запускаем процесс генерации и симуляцию прогресса
# 2. Run generation with progress simulation
progress_task = asyncio.create_task(self._simulate_progress(generation))
try:
# Default to Image Generation (Gemini)
generated_bytes_list, metrics = await generate_image_task(
prompt=generation_prompt, # или request.prompt
prompt=generation_prompt,
media_group_bytes=media_group_bytes,
aspect_ratio=generation.aspect_ratio, # предполагаем поля в request
aspect_ratio=generation.aspect_ratio,
quality=generation.quality,
gemini=self.gemini
)
# Update metrics from API (Common for both)
generation.api_execution_time_seconds = metrics.get("api_execution_time_seconds")
generation.token_usage = metrics.get("token_usage")
generation.input_token_usage = metrics.get("input_token_usage")
generation.output_token_usage = metrics.get("output_token_usage")
except GoogleGenerationException as e:
generation.status = GenerationStatus.FAILED
generation.failed_reason = str(e)
generation.updated_at = datetime.now(UTC)
await self.dao.generations.update_generation(generation)
raise e
self._update_generation_metrics(generation, metrics)
except Exception as e:
# Тут стоит добавить логирование ошибки
logging.error(f"Generation failed: {e}")
generation.status = GenerationStatus.FAILED
generation.failed_reason = str(e)
generation.updated_at = datetime.now(UTC)
await self.dao.generations.update_generation(generation)
raise e
await self._handle_generation_failure(generation, e)
raise
finally:
if not progress_task.done():
if not progress_task.done():
progress_task.cancel()
try:
await progress_task
except asyncio.CancelledError:
pass
# 4. Сохраняем полученные изображения как новые Ассеты
created_assets: List[Asset] = []
# 3. Process results
created_assets = await self._process_generated_images(generation, generated_bytes_list)
for idx, img_bytes in enumerate(generated_bytes_list):
# Generate thumbnail
thumbnail_bytes = None
from utils.image_utils import create_thumbnail
thumbnail_bytes = await asyncio.to_thread(create_thumbnail, img_bytes)
# Save to S3
filename = f"generated/{generation.linked_character_id}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.png"
await self.s3_adapter.upload_file(filename, img_bytes, content_type="image/png")
new_asset = Asset(
name=f"Generated_{generation.linked_character_id}",
type=AssetType.GENERATED,
content_type=AssetContentType.IMAGE,
linked_char_id=generation.linked_character_id,
data=None, # Not storing bytes in DB anymore
minio_object_name=filename,
minio_bucket=self.s3_adapter.bucket_name,
thumbnail=thumbnail_bytes,
created_by=generation.created_by,
project_id=generation.project_id
)
# 4. Finalize generation record
await self._finalize_generation(generation, created_assets, generation_prompt, start_time)
# Сохраняем в БД
asset_id = await self.dao.assets.create_asset(new_asset)
new_asset.id = str(asset_id) # Присваиваем ID, полученный от базы
created_assets.append(new_asset)
# 5. (Опционально) Обновляем запись генерации ссылками на результаты
# Предполагаем, что у модели Generation есть поле result_asset_ids
result_ids = []
for a in created_assets:
result_ids.append(a.id)
generation.result_list = result_ids
generation.status = GenerationStatus.DONE
generation.progress = 100
generation.updated_at = datetime.now(UTC)
generation.tech_prompt = generation_prompt
end_time = datetime.now()
generation.execution_time_seconds = (end_time - start_time).total_seconds()
logger.info(f"DEBUG: Saving generation {generation.id}. Metrics: api_exec={generation.api_execution_time_seconds}, tokens={generation.token_usage}, in_tokens={generation.input_token_usage}, out_tokens={generation.output_token_usage}, exec={generation.execution_time_seconds}")
await self.dao.generations.update_generation(generation)
logger.info(f"Generation {generation.id} completed successfully. {len(created_assets)} assets created. Total Time: {generation.execution_time_seconds:.2f}s")
# 6. Send to Telegram if telegram_id is provided
# 5. Notify
if generation.telegram_id and self.bot:
try:
for asset in created_assets:
if asset.data:
await self.bot.send_photo(
chat_id=generation.telegram_id,
photo=BufferedInputFile(asset.data, filename=f"{asset.name}.jpg"),
caption=f"Generated from prompt: {generation.prompt[:100]}..."
)
logger.info(f"Sent {len(created_assets)} assets to Telegram ID: {generation.telegram_id}")
except Exception as e:
logger.error(f"Failed to send assets to Telegram ID {generation.telegram_id}: {e}")
async def _simulate_progress(self, generation: Generation):
    """Advance generation.progress toward 90 while the real work runs.

    Every 4 seconds adds a random 5-15 points, capped at 90; persists
    each step. The caller cancels this task once the generation finishes
    or fails.
    """
    current_progress = 0
    try:
        while current_progress < 90:
            await asyncio.sleep(4)
            # Random increment between 5 and 15
            increment = random.randint(5, 15)
            current_progress = min(current_progress + increment, 90)
            # NOTE(review): saving the whole object may overwrite fields
            # changed concurrently; a partial/fetch-update-save would be
            # safer if the DAO supports it — confirm before relying on it.
            generation.progress = current_progress
            await self.dao.generations.update_generation(generation)
    except asyncio.CancelledError:
        # Expected path: the generation completed (or failed) and the
        # caller cancelled the simulation.
        pass
    except Exception as e:
        logger.error(f"Error in progress simulation: {e}")
await self._notify_telegram(generation, created_assets)
async def import_external_generation(self, external_gen) -> Generation:
"""
Import a generation from an external source.
Args:
external_gen: ExternalGenerationRequest with generation data and image
Returns:
Created Generation object
"""
# Validate image source
external_gen.validate_image_source()
logger.info(f"Importing external generation for user: {external_gen.created_by}")
image_bytes = await self._fetch_external_image(external_gen)
# 1. Process image (download or decode)
image_bytes = None
if external_gen.image_url:
# Download image from URL
logger.info(f"Downloading image from URL: {external_gen.image_url}")
async with httpx.AsyncClient() as client:
response = await client.get(external_gen.image_url, timeout=30.0)
response.raise_for_status()
image_bytes = response.content
elif external_gen.image_data:
# Decode base64 image
logger.info("Decoding base64 image data")
image_bytes = base64.b64decode(external_gen.image_data)
if not image_bytes:
raise ValueError("Failed to process image data")
# 2. Generate thumbnail
from utils.image_utils import create_thumbnail
thumbnail_bytes = await asyncio.to_thread(create_thumbnail, image_bytes)
# 3. Save to S3
filename = f"external/{external_gen.created_by}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.png"
await self.s3_adapter.upload_file(filename, image_bytes, content_type="image/png")
# 4. Create Asset
new_asset = Asset(
# Reuse internal processing logic
new_asset = await self._save_asset(
image_bytes=image_bytes,
name=f"External_Generated_{external_gen.linked_character_id or 'no_char'}",
type=AssetType.GENERATED,
content_type=AssetContentType.IMAGE,
linked_char_id=external_gen.linked_character_id,
data=None, # Not storing bytes in DB
minio_object_name=filename,
minio_bucket=self.s3_adapter.bucket_name,
thumbnail=thumbnail_bytes,
created_by=external_gen.created_by,
project_id=external_gen.project_id
project_id=external_gen.project_id,
linked_char_id=external_gen.linked_character_id,
folder="external"
)
asset_id = await self.dao.assets.create_asset(new_asset)
new_asset.id = str(asset_id)
logger.info(f"Created asset {asset_id} for external generation")
# 5. Create Generation record
generation = Generation(
status=GenerationStatus.DONE,
linked_character_id=external_gen.linked_character_id,
@@ -446,27 +207,18 @@ class GenerationService:
input_token_usage=external_gen.input_token_usage,
output_token_usage=external_gen.output_token_usage,
created_by=external_gen.created_by,
project_id=external_gen.project_id,
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC)
project_id=external_gen.project_id
)
gen_id = await self.dao.generations.create_generation(generation)
generation.id = gen_id
logger.info(f"Created generation {gen_id} from external source")
return generation
async def delete_generation(self, generation_id: str) -> bool:
"""
Soft delete generation by marking it as deleted.
"""
try:
generation = await self.dao.generations.get_generation(generation_id)
if not generation:
return False
generation.is_deleted = True
generation.updated_at = datetime.now(UTC)
await self.dao.generations.update_generation(generation)
@@ -476,59 +228,200 @@ class GenerationService:
return False
async def cleanup_stale_generations(self):
    """Cancel generations that have been running for more than 5 minutes.

    Best-effort maintenance task: failures are logged, never raised.
    """
    try:
        count = await self.dao.generations.cancel_stale_generations(timeout_minutes=5)
        if count > 0:
            logger.info(f"Cleaned up {count} stale generations")
    except Exception as e:
        logger.error(f"Error cleaning up stale generations: {e}")
async def cleanup_old_data(self, days: int = 30):
    """Purge data older than *days*.

    1. Soft-deletes generations older than the cutoff and collects the
       asset IDs they reference.
    2. Soft-deletes those assets and hard-deletes their files from S3.

    Best-effort maintenance task: failures are logged, never raised.
    """
    try:
        gen_count, asset_ids = await self.dao.generations.soft_delete_old_generations(days=days)
        if gen_count > 0:
            logger.info(f"Soft-deleted {gen_count} generations. Purging {len(asset_ids)} assets.")
        if asset_ids:
            await self.dao.assets.soft_delete_and_purge_assets(asset_ids)
    except Exception as e:
        logger.error(f"Error during old data cleanup: {e}")
async def get_financial_report(self, user_id: Optional[str] = None, project_id: Optional[str] = None, breakdown_by: Optional[str] = None) -> FinancialReport:
    """Generate a financial usage report, optionally scoped and broken down.

    Args:
        user_id: restrict the report to generations created by this user.
        project_id: restrict the report to this project.
        breakdown_by: 'created_by' for a per-user breakdown,
            'project_id' for a per-project breakdown, None for summary only.
    """
    summary_data = await self.dao.generations.get_usage_stats(created_by=user_id, project_id=project_id)
    summary = UsageStats(**summary_data)
    by_user, by_project = None, None
    if breakdown_by == "created_by":
        res = await self.dao.generations.get_usage_breakdown(group_by="created_by", project_id=project_id, created_by=user_id)
        by_user = [UsageByEntity(**item) for item in res]
    if breakdown_by == "project_id":
        res = await self.dao.generations.get_usage_breakdown(group_by="project_id", project_id=project_id, created_by=user_id)
        by_project = [UsageByEntity(**item) for item in res]
    return FinancialReport(summary=summary, by_user=by_user, by_project=by_project)
# --- Private Helpers ---
async def _create_single_generation(self, generation_request: GenerationRequest, user_id: Optional[str], generation_group_id: str) -> GenerationResponse:
    """Persist a new Generation record and schedule its background run.

    The returned response mirrors the freshly stored record; the actual
    generation work happens later inside a queued runner task.
    """
    try:
        record = Generation(**generation_request.model_dump(exclude={'count'}))
        record.created_by = user_id
        record.generation_group_id = generation_group_id
        record.id = await self.dao.generations.create_generation(record)
        # Fire-and-forget: the runner blocks on the semaphore before executing.
        asyncio.create_task(self._queued_generation_runner(record))
        return GenerationResponse(**record.model_dump())
    except Exception:
        logger.exception("Failed to initiate single generation")
        raise
async def _queued_generation_runner(self, gen: Generation):
    """Execute a generation once a concurrency slot becomes available."""
    logger.info(f"Generation {gen.id} waiting for slot...")
    try:
        # The module-level semaphore caps how many generations run at once.
        async with generation_semaphore:
            await self.create_generation(gen)
    except Exception:
        # Persist the failure first so the record never stays stuck in-progress.
        await self._handle_generation_failure(gen, None)
        logger.exception(f"Background generation task failed for ID: {gen.id}")
async def _prepare_generation_input(self, generation: Generation) -> Tuple[List[bytes], str]:
    """Collect reference image bytes and build the final prompt for a generation.

    Gathers images in a fixed order: the linked character's avatar (when
    `use_profile_image` is set), explicitly attached reference assets, then
    the linked environment's assets. If any image was collected, a strict
    reference-guidance suffix is appended to the prompt.

    Returns:
        (media_group_bytes, prompt) — possibly-empty image list and the
        (possibly augmented) prompt text.

    Raises:
        ValueError: if the linked character cannot be found.
    """
    media_group_bytes: List[bytes] = []
    prompt = generation.prompt
    # 1. Character avatar — only when the generation opts into the profile image.
    if generation.linked_character_id:
        char_info = await self.dao.chars.get_character(generation.linked_character_id)
        if not char_info:
            raise ValueError(f"Character {generation.linked_character_id} not found")
        if generation.use_profile_image and char_info.avatar_asset_id:
            avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
            if avatar_asset:
                data = await self._get_asset_data_bytes(avatar_asset)
                if data: media_group_bytes.append(data)
    # 2. Reference assets explicitly attached to the request.
    if generation.assets_list:
        assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
        for asset in assets:
            data = await self._get_asset_data_bytes(asset)
            if data: media_group_bytes.append(data)
    # 3. Assets belonging to the linked environment, if one is set.
    if generation.environment_id:
        env = await self.dao.environments.get_env(generation.environment_id)
        if env and env.asset_ids:
            env_assets = await self.dao.assets.get_assets_by_ids(env.asset_ids)
            for asset in env_assets:
                data = await self._get_asset_data_bytes(asset)
                if data: media_group_bytes.append(data)
    # Nudge the model to treat the collected images as strict identity references.
    if media_group_bytes:
        prompt += (
            " \n\n[Reference Image Guidance]: Use the provided image(s) as the STRICT reference for the main "
            "character's facial features and hair, environment or clothes. Maintain high fidelity to the reference identity."
        )
    return media_group_bytes, prompt
async def _get_asset_data_bytes(self, asset: Asset) -> Optional[bytes]:
    """Return the raw image bytes for *asset*, or None for non-image content.

    Prefers the object stored in S3; falls back to inline `data`.
    """
    if asset.content_type == AssetContentType.IMAGE:
        if asset.minio_object_name:
            return await self.s3_adapter.get_file(asset.minio_object_name)
        return asset.data
    return None
def _update_generation_metrics(self, generation: Generation, metrics: Dict[str, Any]):
    """Copy timing/token usage metrics from an adapter result onto the record.

    Missing keys reset the corresponding field to None.
    """
    for field in ("api_execution_time_seconds", "token_usage",
                  "input_token_usage", "output_token_usage"):
        setattr(generation, field, metrics.get(field))
async def _handle_generation_failure(self, generation: Generation, error: Optional[Exception]):
    """Persist a FAILED status with a human-readable reason."""
    logger.error(f"Generation {generation.id} failed: {error}")
    generation.status = GenerationStatus.FAILED
    # Exceptions are always truthy, so this mirrors `str(error) if error else ...`.
    generation.failed_reason = "Unknown error" if error is None else str(error)
    generation.updated_at = datetime.now(UTC)
    await self.dao.generations.update_generation(generation)
async def _process_generated_images(self, generation: Generation, bytes_list: List[bytes]) -> List[Asset]:
    """Persist each generated image as an Asset and return the new records."""
    return [
        await self._save_asset(
            image_bytes=img,
            name=f"Generated_{generation.linked_character_id}",
            created_by=generation.created_by,
            project_id=generation.project_id,
            linked_char_id=generation.linked_character_id,
            folder="generated"
        )
        for img in bytes_list
    ]
async def _save_asset(self, image_bytes: bytes, name: str, created_by: str, project_id: str, linked_char_id: str, folder: str) -> Asset:
    """Upload an image (plus thumbnail) to S3 and register it as an Asset.

    The raw bytes live only in object storage; the DB record stores the
    object key, bucket and a small thumbnail.
    """
    thumb = await asyncio.to_thread(create_thumbnail, image_bytes)
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    object_key = f"{folder}/{linked_char_id}/{stamp}_{random.randint(1000, 9999)}.png"
    await self.s3_adapter.upload_file(object_key, image_bytes, content_type="image/png")
    asset = Asset(
        name=name,
        type=AssetType.GENERATED,
        content_type=AssetContentType.IMAGE,
        linked_char_id=linked_char_id,
        data=None,  # bytes are in S3, not the DB
        minio_object_name=object_key,
        minio_bucket=self.s3_adapter.bucket_name,
        thumbnail=thumb,
        created_by=created_by,
        project_id=project_id
    )
    asset.id = str(await self.dao.assets.create_asset(asset))
    return asset
async def _finalize_generation(self, generation: Generation, assets: List[Asset], tech_prompt: str, start_time: datetime):
    """Mark a generation as DONE, attach its results and persist the record."""
    generation.result_list = [asset.id for asset in assets]
    generation.status = GenerationStatus.DONE
    generation.progress = 100
    generation.updated_at = datetime.now(UTC)
    generation.tech_prompt = tech_prompt
    # NOTE(review): naive now() here vs now(UTC) above — start_time is
    # presumably naive local time; confirm against the caller.
    generation.execution_time_seconds = (datetime.now() - start_time).total_seconds()
    await self.dao.generations.update_generation(generation)
    logger.info(f"Generation {generation.id} finalized. Time: {generation.execution_time_seconds:.2f}s")
async def _notify_telegram(self, generation: Generation, assets: List[Asset]):
    """Best-effort push of the generated images to the user's Telegram chat."""
    try:
        for asset in assets:
            # The Asset record may not carry bytes; fetch from S3 when needed.
            if asset.minio_object_name:
                img_data = await self.s3_adapter.get_file(asset.minio_object_name)
            else:
                img_data = asset.data
            if not img_data:
                continue
            await self.bot.send_photo(
                chat_id=generation.telegram_id,
                photo=BufferedInputFile(img_data, filename=f"{asset.name}.png"),
                caption=f"Generated from: {generation.prompt[:100]}..."
            )
    except Exception as e:
        logger.error(f"Failed to send to Telegram: {e}")
async def _simulate_progress(self, generation: Generation):
    """Advance `progress` in pseudo-random steps up to 90% until cancelled."""
    pct = 0
    try:
        while pct < 90:
            await asyncio.sleep(4)
            pct = min(pct + random.randint(5, 15), 90)
            generation.progress = pct
            await self.dao.generations.update_generation(generation)
    except asyncio.CancelledError:
        # Cancellation is the normal exit once the real result lands.
        pass
async def _fetch_external_image(self, external_gen) -> bytes:
if external_gen.image_url:
async with httpx.AsyncClient() as client:
response = await client.get(external_gen.image_url, timeout=30.0)
response.raise_for_status()
return response.content
elif external_gen.image_data:
return base64.b64decode(external_gen.image_data)
raise ValueError("No image source provided")

View File

class AspectRatios(str, Enum):
    """Supported output aspect ratios; values are the human-readable 'W:H' strings."""
    ONEONE = "1:1"
    TWOTHREE = "2:3"
    THREETWO = "3:2"
    THREEFOUR = "3:4"
    FOURTHREE = "4:3"
    FOURFIVE = "4:5"
    FIVEFOUR = "5:4"
    NINESIXTEEN = "9:16"
    SIXTEENNINE = "16:9"
    TWENTYONENINE = "21:9"

    @classmethod
    def _missing_(cls, value):
        # Backward compatibility: older records stored member *names*
        # (e.g. "NINESIXTEEN") as values. Accept any member name, which
        # covers the four legacy names and all newer ones alike.
        if isinstance(value, str):
            return cls.__members__.get(value)
        return None

    @property
    def value_ratio(self) -> str:
        # Kept for API compatibility; the enum value *is* the ratio string now.
        return self.value
class Quality(str, Enum):

View File

@router.callback_query(States.gen_mode, F.data == 'gen_mode_change_aspect_ratio')
async def gen_mode_change_aspect_ratio(call: CallbackQuery, state: FSMContext, dao: DAO):
    """Show the aspect-ratio picker: ratio buttons in rows of four, plus a back button."""
    await call.answer()
    buttons = [InlineKeyboardButton(text=ratio.value, callback_data=f'select_ratio_{ratio.name}') for ratio in AspectRatios]
    # Chunk into rows of 4 so the keyboard stays readable as ratios are added.
    keyboard_rows = [buttons[i:i + 4] for i in range(0, len(buttons), 4)]
    keyboard_rows.append([InlineKeyboardButton(text="⬅️ Назад", callback_data="gen_mode_cancel_ratio_change")])
    await call.message.edit_caption(caption="Выбери соотношение сторон",
                                    reply_markup=InlineKeyboardMarkup(inline_keyboard=keyboard_rows))
@router.callback_query(States.gen_mode, F.data.startswith('select_ratio_'))