fixes
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
from typing import List, Optional
|
||||
|
||||
@@ -30,12 +29,22 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix='/api/generations', tags=["Generation"])
|
||||
|
||||
|
||||
async def check_project_access(project_id: Optional[str], current_user: dict, dao: DAO):
|
||||
"""Helper to check if user has access to project."""
|
||||
if not project_id:
|
||||
return
|
||||
project = await dao.projects.get_project(project_id)
|
||||
if not project or str(current_user["_id"]) not in project.members:
|
||||
raise HTTPException(status_code=403, detail="Project access denied")
|
||||
|
||||
|
||||
@router.post("/prompt-assistant", response_model=PromptResponse)
|
||||
async def ask_prompt_assistant(prompt_request: PromptRequest, request: Request,
|
||||
generation_service: GenerationService = Depends(
|
||||
get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)) -> PromptResponse:
|
||||
logger.info(f"ask_prompt_assistant called with prompt length: {len(prompt_request.prompt)}. Linked assets: {len(prompt_request.linked_assets) if prompt_request.linked_assets else 0}")
|
||||
async def ask_prompt_assistant(
|
||||
prompt_request: PromptRequest,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
) -> PromptResponse:
|
||||
logger.info(f"ask_prompt_assistant: {len(prompt_request.prompt)} chars")
|
||||
generated_prompt = await generation_service.ask_prompt_assistant(prompt_request.prompt, prompt_request.linked_assets)
|
||||
return PromptResponse(prompt=generated_prompt)
|
||||
|
||||
@@ -47,32 +56,33 @@ async def prompt_from_image(
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
) -> PromptResponse:
|
||||
logger.info(f"prompt_from_image called. Images count: {len(images)}. Prompt provided: {bool(prompt)}")
|
||||
images_bytes = []
|
||||
for image in images:
|
||||
content = await image.read()
|
||||
images_bytes.append(content)
|
||||
|
||||
images_bytes = [await img.read() for img in images]
|
||||
generated_prompt = await generation_service.generate_prompt_from_images(images_bytes, prompt)
|
||||
return PromptResponse(prompt=generated_prompt)
|
||||
|
||||
|
||||
@router.get("", response_model=GenerationsResponse)
|
||||
async def get_generations(character_id: Optional[str] = None, limit: int = 10, offset: int = 0,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
dao: DAO = Depends(get_dao)):
|
||||
logger.info(f"get_generations called. CharacterId: {character_id}, Limit: {limit}, Offset: {offset}")
|
||||
async def get_generations(
|
||||
character_id: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
offset: int = 0,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
dao: DAO = Depends(get_dao)
|
||||
):
|
||||
await check_project_access(project_id, current_user, dao)
|
||||
|
||||
user_id_filter = str(current_user["_id"])
|
||||
if project_id:
|
||||
project = await dao.projects.get_project(project_id)
|
||||
if not project or str(current_user["_id"]) not in project.members:
|
||||
raise HTTPException(status_code=403, detail="Project access denied")
|
||||
user_id_filter = None # Show all project generations
|
||||
# If project_id is set, we don't filter by user to show all project-wide generations
|
||||
created_by_filter = None if project_id else str(current_user["_id"])
|
||||
|
||||
return await generation_service.get_generations(character_id, limit=limit, offset=offset, user_id=user_id_filter, project_id=project_id)
|
||||
return await generation_service.get_generations(
|
||||
character_id=character_id,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
created_by=created_by_filter,
|
||||
project_id=project_id
|
||||
)
|
||||
|
||||
|
||||
@router.get("/usage", response_model=FinancialReport)
|
||||
@@ -83,32 +93,15 @@ async def get_usage_report(
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
dao: DAO = Depends(get_dao)
|
||||
) -> FinancialReport:
|
||||
"""
|
||||
Returns usage statistics (runs, tokens, cost) for the current user or project.
|
||||
If project_id is provided, returns stats for that project.
|
||||
Otherwise, returns stats for the current user.
|
||||
"""
|
||||
user_id_filter = str(current_user["_id"])
|
||||
await check_project_access(project_id, current_user, dao)
|
||||
|
||||
user_id_filter = str(current_user["_id"]) if not project_id else None
|
||||
breakdown_by = None
|
||||
|
||||
if project_id:
|
||||
# Permission check
|
||||
project = await dao.projects.get_project(project_id)
|
||||
if not project or str(current_user["_id"]) not in project.members:
|
||||
raise HTTPException(status_code=403, detail="Project access denied")
|
||||
user_id_filter = None # If we are in project, we see stats for the WHOLE project by default
|
||||
if breakdown == "user":
|
||||
breakdown_by = "created_by"
|
||||
elif breakdown == "project":
|
||||
breakdown_by = "project_id"
|
||||
else:
|
||||
# Default: Stats for current user
|
||||
if breakdown == "project":
|
||||
breakdown_by = "project_id"
|
||||
elif breakdown == "user":
|
||||
# This would breakdown personal usage by user (yourself), but could be useful if it included collaborators?
|
||||
# No, if project_id is None, it's personal.
|
||||
breakdown_by = "created_by"
|
||||
if breakdown == "user":
|
||||
breakdown_by = "created_by"
|
||||
elif breakdown == "project":
|
||||
breakdown_by = "project_id"
|
||||
|
||||
return await generation_service.get_financial_report(
|
||||
user_id=user_id_filter,
|
||||
@@ -116,58 +109,61 @@ async def get_usage_report(
|
||||
breakdown_by=breakdown_by
|
||||
)
|
||||
|
||||
@router.post("/_run", response_model=GenerationGroupResponse)
|
||||
async def post_generation(generation: GenerationRequest, request: Request,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
dao: DAO = Depends(get_dao)) -> GenerationGroupResponse:
|
||||
logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}")
|
||||
|
||||
@router.post("/_run", response_model=GenerationGroupResponse)
|
||||
async def post_generation(
|
||||
generation: GenerationRequest,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
dao: DAO = Depends(get_dao)
|
||||
) -> GenerationGroupResponse:
|
||||
await check_project_access(project_id, current_user, dao)
|
||||
if project_id:
|
||||
project = await dao.projects.get_project(project_id)
|
||||
if not project or str(current_user["_id"]) not in project.members:
|
||||
raise HTTPException(status_code=403, detail="Project access denied")
|
||||
generation.project_id = project_id
|
||||
|
||||
return await generation_service.create_generation_task(generation, user_id=str(current_user.get("_id")))
|
||||
|
||||
return await generation_service.create_generation_task(
|
||||
generation,
|
||||
user_id=str(current_user.get("_id"))
|
||||
)
|
||||
|
||||
|
||||
@router.get("/running")
|
||||
async def get_running_generations(request: Request,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
dao: DAO = Depends(get_dao)):
|
||||
|
||||
user_id_filter = str(current_user["_id"])
|
||||
if project_id:
|
||||
project = await dao.projects.get_project(project_id)
|
||||
if not project or str(current_user["_id"]) not in project.members:
|
||||
raise HTTPException(status_code=403, detail="Project access denied")
|
||||
user_id_filter = None
|
||||
async def get_running_generations(
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
dao: DAO = Depends(get_dao)
|
||||
):
|
||||
await check_project_access(project_id, current_user, dao)
|
||||
user_id_filter = None if project_id else str(current_user["_id"])
|
||||
|
||||
return await generation_service.get_running_generations(user_id=user_id_filter, project_id=project_id)
|
||||
return await generation_service.get_running_generations(
|
||||
user_id=user_id_filter,
|
||||
project_id=project_id
|
||||
)
|
||||
|
||||
|
||||
@router.get("/group/{group_id}", response_model=GenerationGroupResponse)
|
||||
async def get_generation_group(group_id: str,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)):
|
||||
logger.info(f"get_generation_group called for group_id: {group_id}")
|
||||
generations = await generation_service.dao.generations.get_generations_by_group(group_id)
|
||||
gen_responses = [GenerationResponse(**gen.model_dump()) for gen in generations]
|
||||
return GenerationGroupResponse(generation_group_id=group_id, generations=gen_responses)
|
||||
async def get_generation_group(
|
||||
group_id: str,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
return await generation_service.get_generations_by_group(group_id)
|
||||
|
||||
|
||||
@router.get("/{generation_id}", response_model=GenerationResponse)
|
||||
async def get_generation(generation_id: str,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)) -> GenerationResponse:
|
||||
logger.debug(f"get_generation called for ID: {generation_id}")
|
||||
async def get_generation(
|
||||
generation_id: str,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
) -> GenerationResponse:
|
||||
gen = await generation_service.get_generation(generation_id)
|
||||
if gen and gen.created_by != str(current_user["_id"]):
|
||||
if not gen:
|
||||
raise HTTPException(status_code=404, detail="Generation not found")
|
||||
|
||||
if gen.created_by != str(current_user["_id"]):
|
||||
# Check project membership
|
||||
is_member = False
|
||||
if gen.project_id:
|
||||
@@ -180,43 +176,24 @@ async def get_generation(generation_id: str,
|
||||
return gen
|
||||
|
||||
|
||||
|
||||
|
||||
@router.post("/import", response_model=GenerationResponse)
|
||||
async def import_external_generation(
|
||||
request: Request,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
x_signature: str = Header(..., alias="X-Signature")
|
||||
) -> GenerationResponse:
|
||||
"""
|
||||
Import a generation from an external source.
|
||||
Requires server-to-server authentication via HMAC signature.
|
||||
"""
|
||||
|
||||
logger.info("import_external_generation called")
|
||||
# Get raw request body for signature verification
|
||||
body = await request.body()
|
||||
|
||||
# Verify signature
|
||||
secret = settings.EXTERNAL_API_SECRET
|
||||
if not secret:
|
||||
logger.error("EXTERNAL_API_SECRET not configured")
|
||||
raise HTTPException(status_code=500, detail="Server configuration error")
|
||||
|
||||
if not verify_signature(body, x_signature, secret):
|
||||
logger.warning("Invalid signature for external generation import")
|
||||
raise HTTPException(status_code=401, detail="Invalid signature")
|
||||
|
||||
# Parse request body
|
||||
try:
|
||||
data = json.loads(body.decode('utf-8'))
|
||||
external_gen = ExternalGenerationRequest(**data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse request body: {e}")
|
||||
raise HTTPException(status_code=400, detail=f"Invalid request body: {str(e)}")
|
||||
|
||||
# Import generation
|
||||
try:
|
||||
generation = await generation_service.import_external_generation(external_gen)
|
||||
return GenerationResponse(**generation.model_dump())
|
||||
except Exception as e:
|
||||
@@ -225,11 +202,11 @@ async def import_external_generation(
|
||||
|
||||
|
||||
@router.delete("/{generation_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_generation(generation_id: str,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)):
|
||||
logger.info(f"delete_generation called for ID: {generation_id}")
|
||||
deleted = await generation_service.delete_generation(generation_id)
|
||||
if not deleted:
|
||||
async def delete_generation(
|
||||
generation_id: str,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
if not await generation_service.delete_generation(generation_id):
|
||||
raise HTTPException(status_code=404, detail="Generation not found")
|
||||
return None
|
||||
|
||||
@@ -14,7 +14,7 @@ class ExternalGenerationRequest(BaseModel):
|
||||
image_url: Optional[str] = Field(None, description="URL to download image from")
|
||||
|
||||
# Generation metadata
|
||||
aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN
|
||||
aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN # "1:1","2:3","3:2","3:4","4:3","4:5","5:4","9:16","16:9","21:9"
|
||||
quality: Quality = Quality.ONEK
|
||||
|
||||
# Optional linking
|
||||
|
||||
@@ -10,7 +10,7 @@ from models.enums import AspectRatios, Quality, GenType
|
||||
|
||||
class GenerationRequest(BaseModel):
|
||||
linked_character_id: Optional[str] = None
|
||||
aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN
|
||||
aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN # "1:1","2:3","3:2","3:4","4:3","4:5","5:4","9:16","16:9","21:9"
|
||||
quality: Quality = Quality.ONEK
|
||||
prompt: str
|
||||
telegram_id: Optional[int] = None
|
||||
|
||||
@@ -13,13 +13,15 @@ from aiogram.types import BufferedInputFile
|
||||
from adapters.Exception import GoogleGenerationException
|
||||
from adapters.google_adapter import GoogleAdapter
|
||||
from adapters.s3_adapter import S3Adapter
|
||||
from api.models import FinancialReport, UsageStats, UsageByEntity
|
||||
from api.models import GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse
|
||||
# Импортируйте ваши модели DAO, Asset, Generation корректно
|
||||
from api.models import (
|
||||
FinancialReport, UsageStats, UsageByEntity,
|
||||
GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse
|
||||
)
|
||||
from models.Asset import Asset, AssetType, AssetContentType
|
||||
from models.Generation import Generation, GenerationStatus
|
||||
from models.enums import AspectRatios, Quality
|
||||
from repos.dao import DAO
|
||||
from utils.image_utils import create_thumbnail
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -27,22 +29,18 @@ logger = logging.getLogger(__name__)
|
||||
generation_semaphore = asyncio.Semaphore(4)
|
||||
|
||||
|
||||
# --- Вспомогательная функция генерации ---
|
||||
async def generate_image_task(
|
||||
prompt: str,
|
||||
media_group_bytes: List[bytes],
|
||||
aspect_ratio: AspectRatios,
|
||||
quality: Quality,
|
||||
gemini: GoogleAdapter,
|
||||
|
||||
) -> Tuple[List[bytes], Dict[str, Any]]:
|
||||
"""
|
||||
Обертка для вызова синхронного метода Gemini в отдельном потоке.
|
||||
Возвращает список байтов сгенерированных изображений.
|
||||
Wrapper for calling Gemini's synchronous method in a separate thread.
|
||||
"""
|
||||
try :
|
||||
try:
|
||||
logger.info(f"Starting generate_image_task with prompt length: {len(prompt)}")
|
||||
# Запускаем блокирующую операцию в отдельном потоке, чтобы не тормозить Event Loop
|
||||
result = await asyncio.to_thread(
|
||||
gemini.generate_image,
|
||||
prompt=prompt,
|
||||
@@ -51,12 +49,10 @@ async def generate_image_task(
|
||||
quality=quality,
|
||||
)
|
||||
generated_images_io, metrics = result
|
||||
|
||||
logger.info(f"generate_image_task completed, received {len(generated_images_io) if generated_images_io else 0} images")
|
||||
except GoogleGenerationException as e:
|
||||
raise e
|
||||
except GoogleGenerationException:
|
||||
raise
|
||||
finally:
|
||||
# Освобождаем входные данные — они больше не нужны
|
||||
del media_group_bytes
|
||||
|
||||
images_bytes = []
|
||||
@@ -65,371 +61,136 @@ async def generate_image_task(
|
||||
img_io.seek(0)
|
||||
images_bytes.append(img_io.read())
|
||||
img_io.close()
|
||||
# Освобождаем список BytesIO сразу
|
||||
del generated_images_io
|
||||
|
||||
return images_bytes, metrics
|
||||
|
||||
|
||||
class GenerationService:
|
||||
def __init__(self, dao: DAO, gemini: GoogleAdapter, s3_adapter: S3Adapter, bot: Optional[Bot] = None):
|
||||
def __init__(self, dao: DAO, gemini: GoogleAdapter, s3_adapter: S3Adapter, bot: Optional[Bot] = None):
|
||||
self.dao = dao
|
||||
self.gemini = gemini
|
||||
self.s3_adapter = s3_adapter
|
||||
self.bot = bot
|
||||
|
||||
|
||||
# --- Public API ---
|
||||
|
||||
async def ask_prompt_assistant(self, prompt: str, assets: list[str] | None = None) -> str:
|
||||
future_prompt = """You are an prompt-assistant. You improving user-entered prompts for image generation. User may upload reference image too.
|
||||
I will provide sources prompt entered by user. Understand user needs and generate best variation of prompt.
|
||||
ANSWER ONLY PROMPT STRING!!! USER_ENTERED_PROMPT: """
|
||||
future_prompt += prompt
|
||||
future_prompt = (
|
||||
"You are an prompt-assistant. You improving user-entered prompts for image generation. "
|
||||
"User may upload reference image too. I will provide sources prompt entered by user. "
|
||||
"Understand user needs and generate best variation of prompt. ANSWER ONLY PROMPT STRING!!! "
|
||||
f"USER_ENTERED_PROMPT: {prompt}"
|
||||
)
|
||||
assets_data = []
|
||||
if assets is not None:
|
||||
if assets:
|
||||
assets_db = await self.dao.assets.get_assets_by_ids(assets)
|
||||
assets_data.extend(asset.data for asset in assets_db)
|
||||
assets_data.extend(asset.data for asset in assets_db if asset.data)
|
||||
|
||||
generated_prompt = await asyncio.to_thread(self.gemini.generate_text, future_prompt, assets_data)
|
||||
logger.info(future_prompt)
|
||||
logger.info(generated_prompt)
|
||||
logger.info(f"Prompt Assistant: {generated_prompt}")
|
||||
return generated_prompt
|
||||
|
||||
async def generate_prompt_from_images(self, images: List[bytes], user_prompt: Optional[str] = None) -> str:
|
||||
technical_prompt = "You are a prompt engineer. Describe this image in detail to create a stable diffusion using this image as reference. "
|
||||
if user_prompt:
|
||||
technical_prompt += f"User also provided this context: {user_prompt}. "
|
||||
|
||||
technical_prompt += "Provide ONLY the detailed prompt."
|
||||
|
||||
return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, images_list=images)
|
||||
|
||||
async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0, user_id: Optional[str] = None, project_id: Optional[str] = None, idea_id: Optional[str] = None) -> GenerationsResponse:
|
||||
generations = await self.dao.generations.get_generations(character_id = character_id,limit=limit, offset=offset, created_by=user_id, project_id=project_id, idea_id=idea_id)
|
||||
total_count = await self.dao.generations.count_generations(character_id = character_id, created_by=user_id, project_id=project_id, idea_id=idea_id)
|
||||
generations = [GenerationResponse(**gen.model_dump()) for gen in generations]
|
||||
return GenerationsResponse(generations=generations, total_count=total_count)
|
||||
async def get_generations(self, **kwargs) -> GenerationsResponse:
|
||||
generations = await self.dao.generations.get_generations(**kwargs)
|
||||
total_count = await self.dao.generations.count_generations(
|
||||
character_id=kwargs.get('character_id'),
|
||||
created_by=kwargs.get('created_by'),
|
||||
project_id=kwargs.get('project_id'),
|
||||
idea_id=kwargs.get('idea_id')
|
||||
)
|
||||
return GenerationsResponse(
|
||||
generations=[GenerationResponse(**gen.model_dump()) for gen in generations],
|
||||
total_count=total_count
|
||||
)
|
||||
|
||||
async def get_generation(self, generation_id: str) -> Optional[GenerationResponse]:
|
||||
gen = await self.dao.generations.get_generation(generation_id)
|
||||
if gen is None:
|
||||
return None
|
||||
else:
|
||||
return GenerationResponse(**gen.model_dump())
|
||||
return GenerationResponse(**gen.model_dump()) if gen else None
|
||||
|
||||
async def get_generations_by_group(self, group_id: str) -> GenerationGroupResponse:
|
||||
generations = await self.dao.generations.get_generations_by_group(group_id)
|
||||
return GenerationGroupResponse(
|
||||
generation_group_id=group_id,
|
||||
generations=[GenerationResponse(**gen.model_dump()) for gen in generations]
|
||||
)
|
||||
|
||||
async def get_running_generations(self, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
|
||||
return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING, created_by=user_id, project_id=project_id)
|
||||
|
||||
async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationGroupResponse:
|
||||
count = generation_request.count
|
||||
|
||||
if generation_group_id is None:
|
||||
generation_group_id = str(uuid4())
|
||||
|
||||
|
||||
results = []
|
||||
for _ in range(count):
|
||||
for _ in range(generation_request.count):
|
||||
gen_response = await self._create_single_generation(generation_request, user_id, generation_group_id)
|
||||
results.append(gen_response)
|
||||
return GenerationGroupResponse(generation_group_id=generation_group_id, generations=results)
|
||||
|
||||
async def _create_single_generation(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationResponse:
|
||||
gen_id = None
|
||||
generation_model = None
|
||||
|
||||
try:
|
||||
generation_model = Generation(**generation_request.model_dump(exclude={'count'}))
|
||||
if user_id:
|
||||
generation_model.created_by = user_id
|
||||
if generation_group_id:
|
||||
generation_model.generation_group_id = generation_group_id
|
||||
|
||||
# Explicitly set idea_id from request if present (already in model_dump, but ensuring clarity)
|
||||
if generation_request.idea_id:
|
||||
generation_model.idea_id = generation_request.idea_id
|
||||
|
||||
gen_id = await self.dao.generations.create_generation(generation_model)
|
||||
generation_model.id = gen_id
|
||||
|
||||
async def runner(gen):
|
||||
logger.info(f"Generation {gen.id} entered queue (waiting for slot)...")
|
||||
try:
|
||||
async with generation_semaphore:
|
||||
logger.info(f"Starting background generation task for ID: {gen.id}")
|
||||
await self.create_generation(gen)
|
||||
logger.info(f"Background generation task finished for ID: {gen.id}")
|
||||
except Exception:
|
||||
# если генерация уже пошла и упала — пометим FAILED
|
||||
try:
|
||||
db_gen = await self.dao.generations.get_generation(gen.id)
|
||||
if db_gen is not None:
|
||||
db_gen.status = GenerationStatus.FAILED
|
||||
await self.dao.generations.update_generation(db_gen)
|
||||
except Exception:
|
||||
logger.exception("Failed to mark generation as FAILED")
|
||||
logger.exception("create_generation task failed")
|
||||
|
||||
asyncio.create_task(runner(generation_model))
|
||||
|
||||
return GenerationResponse(**generation_model.model_dump())
|
||||
|
||||
except Exception:
|
||||
# если не успели создать запись — нечего помечать
|
||||
if gen_id is not None:
|
||||
try:
|
||||
gen = await self.dao.generations.get_generation(gen_id)
|
||||
if gen is not None:
|
||||
gen.status = GenerationStatus.FAILED
|
||||
await self.dao.generations.update_generation(gen)
|
||||
except Exception:
|
||||
logger.exception("Failed to mark generation as FAILED in create_generation_task")
|
||||
raise
|
||||
|
||||
async def create_generation(self, generation: Generation):
|
||||
start_time = datetime.now()
|
||||
logger.info(f"Processing generation {generation.id}. Character ID: {generation.linked_character_id}")
|
||||
|
||||
# 2. Получаем ассеты-референсы (если они есть)
|
||||
reference_assets: List[Asset] = []
|
||||
media_group_bytes: List[bytes] = []
|
||||
generation_prompt = generation.prompt
|
||||
# generation_prompt = f"""
|
||||
|
||||
# Create detailed image of character in scene.
|
||||
|
||||
# SCENE DESCRIPTION: {generation.prompt}
|
||||
|
||||
# Rules:
|
||||
# - Integrate the character's appearance naturally into the scene description.
|
||||
# - Focus on lighting, texture, and composition.
|
||||
# """
|
||||
if generation.linked_character_id is not None:
|
||||
char_info = await self.dao.chars.get_character(generation.linked_character_id)
|
||||
if char_info is None:
|
||||
raise Exception(f"Character ID {generation.linked_character_id} not found")
|
||||
if generation.use_profile_image:
|
||||
if char_info.avatar_asset_id is not None:
|
||||
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
|
||||
if avatar_asset and avatar_asset.data:
|
||||
media_group_bytes.append(avatar_asset.data)
|
||||
# generation_prompt = generation_prompt.replace("$char_bio_inserted", f"1. CHARACTER BIO (Must be strictly followed): {char_info.character_bio}")
|
||||
# 1. Prepare input
|
||||
media_group_bytes, generation_prompt = await self._prepare_generation_input(generation)
|
||||
|
||||
reference_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
|
||||
|
||||
# Извлекаем данные (bytes) из ассетов для отправки в Gemini
|
||||
for asset in reference_assets:
|
||||
if asset.content_type != AssetContentType.IMAGE:
|
||||
continue
|
||||
|
||||
img_data = None
|
||||
if asset.minio_object_name:
|
||||
img_data = await self.s3_adapter.get_file(asset.minio_object_name)
|
||||
elif asset.data:
|
||||
img_data = asset.data
|
||||
|
||||
if img_data:
|
||||
media_group_bytes.append(img_data)
|
||||
|
||||
if media_group_bytes:
|
||||
generation_prompt += " \n\n[Reference Image Guidance]: Use the provided image(s) as the STRICT reference for the main character's facial features and hair, enviroment or clothes. Maintain high fidelity to the reference identity."
|
||||
|
||||
logger.info(f"Final generation prompt assembled. Length: {len(generation_prompt)}. Media count: {len(media_group_bytes)}")
|
||||
|
||||
# 3. Запускаем процесс генерации и симуляцию прогресса
|
||||
# 2. Run generation with progress simulation
|
||||
progress_task = asyncio.create_task(self._simulate_progress(generation))
|
||||
|
||||
try:
|
||||
|
||||
# Default to Image Generation (Gemini)
|
||||
generated_bytes_list, metrics = await generate_image_task(
|
||||
prompt=generation_prompt, # или request.prompt
|
||||
prompt=generation_prompt,
|
||||
media_group_bytes=media_group_bytes,
|
||||
aspect_ratio=generation.aspect_ratio, # предполагаем поля в request
|
||||
aspect_ratio=generation.aspect_ratio,
|
||||
quality=generation.quality,
|
||||
gemini=self.gemini
|
||||
)
|
||||
|
||||
|
||||
# Update metrics from API (Common for both)
|
||||
generation.api_execution_time_seconds = metrics.get("api_execution_time_seconds")
|
||||
generation.token_usage = metrics.get("token_usage")
|
||||
generation.input_token_usage = metrics.get("input_token_usage")
|
||||
generation.output_token_usage = metrics.get("output_token_usage")
|
||||
|
||||
except GoogleGenerationException as e:
|
||||
generation.status = GenerationStatus.FAILED
|
||||
generation.failed_reason = str(e)
|
||||
generation.updated_at = datetime.now(UTC)
|
||||
await self.dao.generations.update_generation(generation)
|
||||
raise e
|
||||
self._update_generation_metrics(generation, metrics)
|
||||
except Exception as e:
|
||||
# Тут стоит добавить логирование ошибки
|
||||
logging.error(f"Generation failed: {e}")
|
||||
generation.status = GenerationStatus.FAILED
|
||||
generation.failed_reason = str(e)
|
||||
generation.updated_at = datetime.now(UTC)
|
||||
await self.dao.generations.update_generation(generation)
|
||||
raise e
|
||||
await self._handle_generation_failure(generation, e)
|
||||
raise
|
||||
finally:
|
||||
if not progress_task.done():
|
||||
if not progress_task.done():
|
||||
progress_task.cancel()
|
||||
try:
|
||||
await progress_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# 4. Сохраняем полученные изображения как новые Ассеты
|
||||
created_assets: List[Asset] = []
|
||||
# 3. Process results
|
||||
created_assets = await self._process_generated_images(generation, generated_bytes_list)
|
||||
|
||||
for idx, img_bytes in enumerate(generated_bytes_list):
|
||||
# Generate thumbnail
|
||||
thumbnail_bytes = None
|
||||
from utils.image_utils import create_thumbnail
|
||||
thumbnail_bytes = await asyncio.to_thread(create_thumbnail, img_bytes)
|
||||
|
||||
# Save to S3
|
||||
filename = f"generated/{generation.linked_character_id}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.png"
|
||||
await self.s3_adapter.upload_file(filename, img_bytes, content_type="image/png")
|
||||
|
||||
new_asset = Asset(
|
||||
name=f"Generated_{generation.linked_character_id}",
|
||||
type=AssetType.GENERATED,
|
||||
content_type=AssetContentType.IMAGE,
|
||||
linked_char_id=generation.linked_character_id,
|
||||
data=None, # Not storing bytes in DB anymore
|
||||
minio_object_name=filename,
|
||||
minio_bucket=self.s3_adapter.bucket_name,
|
||||
thumbnail=thumbnail_bytes,
|
||||
created_by=generation.created_by,
|
||||
project_id=generation.project_id
|
||||
)
|
||||
# 4. Finalize generation record
|
||||
await self._finalize_generation(generation, created_assets, generation_prompt, start_time)
|
||||
|
||||
# Сохраняем в БД
|
||||
asset_id = await self.dao.assets.create_asset(new_asset)
|
||||
new_asset.id = str(asset_id) # Присваиваем ID, полученный от базы
|
||||
|
||||
created_assets.append(new_asset)
|
||||
|
||||
# 5. (Опционально) Обновляем запись генерации ссылками на результаты
|
||||
# Предполагаем, что у модели Generation есть поле result_asset_ids
|
||||
result_ids = []
|
||||
for a in created_assets:
|
||||
result_ids.append(a.id)
|
||||
|
||||
generation.result_list = result_ids
|
||||
generation.status = GenerationStatus.DONE
|
||||
generation.progress = 100
|
||||
generation.updated_at = datetime.now(UTC)
|
||||
generation.tech_prompt = generation_prompt
|
||||
|
||||
end_time = datetime.now()
|
||||
generation.execution_time_seconds = (end_time - start_time).total_seconds()
|
||||
|
||||
logger.info(f"DEBUG: Saving generation {generation.id}. Metrics: api_exec={generation.api_execution_time_seconds}, tokens={generation.token_usage}, in_tokens={generation.input_token_usage}, out_tokens={generation.output_token_usage}, exec={generation.execution_time_seconds}")
|
||||
|
||||
await self.dao.generations.update_generation(generation)
|
||||
logger.info(f"Generation {generation.id} completed successfully. {len(created_assets)} assets created. Total Time: {generation.execution_time_seconds:.2f}s")
|
||||
|
||||
# 6. Send to Telegram if telegram_id is provided
|
||||
# 5. Notify
|
||||
if generation.telegram_id and self.bot:
|
||||
try:
|
||||
for asset in created_assets:
|
||||
if asset.data:
|
||||
await self.bot.send_photo(
|
||||
chat_id=generation.telegram_id,
|
||||
photo=BufferedInputFile(asset.data, filename=f"{asset.name}.jpg"),
|
||||
caption=f"Generated from prompt: {generation.prompt[:100]}..."
|
||||
)
|
||||
logger.info(f"Sent {len(created_assets)} assets to Telegram ID: {generation.telegram_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send assets to Telegram ID {generation.telegram_id}: {e}")
|
||||
|
||||
|
||||
async def _simulate_progress(self, generation: Generation):
|
||||
"""
|
||||
Increments progress from 0 to 90 over ~20 seconds.
|
||||
"""
|
||||
current_progress = 0
|
||||
try:
|
||||
while current_progress < 90:
|
||||
await asyncio.sleep(4)
|
||||
# Random increment between 5 and 15
|
||||
increment = random.randint(5, 15)
|
||||
current_progress = min(current_progress + increment, 90)
|
||||
|
||||
# Fetch latest state (optional, but good practice to avoid overwriting unrelated fields)
|
||||
# But for simplicity here we just use the object we have and save it.
|
||||
# Ideally, we should fetch-update-save or use partial update if DAO supports it.
|
||||
# Assuming simple update is fine for now.
|
||||
generation.progress = current_progress
|
||||
await self.dao.generations.update_generation(generation)
|
||||
except asyncio.CancelledError:
|
||||
# Task cancelled, generation finished (or failed)
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Error in progress simulation: {e}")
|
||||
|
||||
|
||||
|
||||
await self._notify_telegram(generation, created_assets)
|
||||
|
||||
async def import_external_generation(self, external_gen) -> Generation:
|
||||
"""
|
||||
Import a generation from an external source.
|
||||
|
||||
Args:
|
||||
external_gen: ExternalGenerationRequest with generation data and image
|
||||
|
||||
Returns:
|
||||
Created Generation object
|
||||
"""
|
||||
|
||||
# Validate image source
|
||||
external_gen.validate_image_source()
|
||||
|
||||
logger.info(f"Importing external generation for user: {external_gen.created_by}")
|
||||
|
||||
image_bytes = await self._fetch_external_image(external_gen)
|
||||
|
||||
# 1. Process image (download or decode)
|
||||
image_bytes = None
|
||||
|
||||
if external_gen.image_url:
|
||||
# Download image from URL
|
||||
logger.info(f"Downloading image from URL: {external_gen.image_url}")
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(external_gen.image_url, timeout=30.0)
|
||||
response.raise_for_status()
|
||||
image_bytes = response.content
|
||||
elif external_gen.image_data:
|
||||
# Decode base64 image
|
||||
logger.info("Decoding base64 image data")
|
||||
image_bytes = base64.b64decode(external_gen.image_data)
|
||||
|
||||
if not image_bytes:
|
||||
raise ValueError("Failed to process image data")
|
||||
|
||||
# 2. Generate thumbnail
|
||||
from utils.image_utils import create_thumbnail
|
||||
thumbnail_bytes = await asyncio.to_thread(create_thumbnail, image_bytes)
|
||||
|
||||
# 3. Save to S3
|
||||
filename = f"external/{external_gen.created_by}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.png"
|
||||
await self.s3_adapter.upload_file(filename, image_bytes, content_type="image/png")
|
||||
|
||||
# 4. Create Asset
|
||||
new_asset = Asset(
|
||||
# Reuse internal processing logic
|
||||
new_asset = await self._save_asset(
|
||||
image_bytes=image_bytes,
|
||||
name=f"External_Generated_{external_gen.linked_character_id or 'no_char'}",
|
||||
type=AssetType.GENERATED,
|
||||
content_type=AssetContentType.IMAGE,
|
||||
linked_char_id=external_gen.linked_character_id,
|
||||
data=None, # Not storing bytes in DB
|
||||
minio_object_name=filename,
|
||||
minio_bucket=self.s3_adapter.bucket_name,
|
||||
thumbnail=thumbnail_bytes,
|
||||
created_by=external_gen.created_by,
|
||||
project_id=external_gen.project_id
|
||||
project_id=external_gen.project_id,
|
||||
linked_char_id=external_gen.linked_character_id,
|
||||
folder="external"
|
||||
)
|
||||
|
||||
asset_id = await self.dao.assets.create_asset(new_asset)
|
||||
new_asset.id = str(asset_id)
|
||||
|
||||
logger.info(f"Created asset {asset_id} for external generation")
|
||||
|
||||
# 5. Create Generation record
|
||||
|
||||
generation = Generation(
|
||||
status=GenerationStatus.DONE,
|
||||
linked_character_id=external_gen.linked_character_id,
|
||||
@@ -446,27 +207,18 @@ class GenerationService:
|
||||
input_token_usage=external_gen.input_token_usage,
|
||||
output_token_usage=external_gen.output_token_usage,
|
||||
created_by=external_gen.created_by,
|
||||
project_id=external_gen.project_id,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC)
|
||||
project_id=external_gen.project_id
|
||||
)
|
||||
|
||||
gen_id = await self.dao.generations.create_generation(generation)
|
||||
generation.id = gen_id
|
||||
|
||||
logger.info(f"Created generation {gen_id} from external source")
|
||||
|
||||
return generation
|
||||
|
||||
async def delete_generation(self, generation_id: str) -> bool:
|
||||
"""
|
||||
Soft delete generation by marking it as deleted.
|
||||
"""
|
||||
try:
|
||||
generation = await self.dao.generations.get_generation(generation_id)
|
||||
if not generation:
|
||||
return False
|
||||
|
||||
generation.is_deleted = True
|
||||
generation.updated_at = datetime.now(UTC)
|
||||
await self.dao.generations.update_generation(generation)
|
||||
@@ -476,59 +228,200 @@ class GenerationService:
|
||||
return False
|
||||
|
||||
async def cleanup_stale_generations(self):
|
||||
"""
|
||||
Cancels generations that have been running for more than 1 hour.
|
||||
"""
|
||||
try:
|
||||
count = await self.dao.generations.cancel_stale_generations(timeout_minutes=60)
|
||||
count = await self.dao.generations.cancel_stale_generations(timeout_minutes=5)
|
||||
if count > 0:
|
||||
logger.info(f"Cleaned up {count} stale generations (timeout)")
|
||||
logger.info(f"Cleaned up {count} stale generations")
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up stale generations: {e}")
|
||||
|
||||
async def cleanup_old_data(self, days: int = 30):
|
||||
"""
|
||||
Очистка старых данных:
|
||||
1. Мягко удаляет генерации старше N дней
|
||||
2. Мягко удаляет связанные ассеты + жёстко удаляет файлы из S3
|
||||
"""
|
||||
try:
|
||||
# 1. Мягко удаляем генерации и собираем asset IDs
|
||||
gen_count, asset_ids = await self.dao.generations.soft_delete_old_generations(days=days)
|
||||
|
||||
if gen_count > 0:
|
||||
logger.info(f"Soft-deleted {gen_count} generations older than {days} days. "
|
||||
f"Found {len(asset_ids)} associated asset IDs.")
|
||||
|
||||
# 2. Мягко удаляем ассеты + жёстко удаляем файлы из S3
|
||||
logger.info(f"Soft-deleted {gen_count} generations. Purging {len(asset_ids)} assets.")
|
||||
if asset_ids:
|
||||
purged = await self.dao.assets.soft_delete_and_purge_assets(asset_ids)
|
||||
logger.info(f"Purged {purged} assets (soft-deleted + S3 files removed).")
|
||||
|
||||
await self.dao.assets.soft_delete_and_purge_assets(asset_ids)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during old data cleanup: {e}")
|
||||
|
||||
async def get_financial_report(self, user_id: Optional[str] = None, project_id: Optional[str] = None, breakdown_by: Optional[str] = None) -> FinancialReport:
|
||||
"""
|
||||
Generates a financial usage report for a specific user or project.
|
||||
'breakdown_by' can be 'created_by' or 'project_id'.
|
||||
"""
|
||||
summary_data = await self.dao.generations.get_usage_stats(created_by=user_id, project_id=project_id)
|
||||
summary = UsageStats(**summary_data)
|
||||
|
||||
by_user = None
|
||||
by_project = None
|
||||
|
||||
by_user, by_project = None, None
|
||||
if breakdown_by == "created_by":
|
||||
res = await self.dao.generations.get_usage_breakdown(group_by="created_by", project_id=project_id, created_by=user_id)
|
||||
by_user = [UsageByEntity(**item) for item in res]
|
||||
|
||||
if breakdown_by == "project_id":
|
||||
res = await self.dao.generations.get_usage_breakdown(group_by="project_id", project_id=project_id, created_by=user_id)
|
||||
by_project = [UsageByEntity(**item) for item in res]
|
||||
|
||||
return FinancialReport(
|
||||
summary=summary,
|
||||
by_user=by_user,
|
||||
by_project=by_project
|
||||
)
|
||||
return FinancialReport(summary=summary, by_user=by_user, by_project=by_project)
|
||||
|
||||
# --- Private Helpers ---
|
||||
|
||||
async def _create_single_generation(self, generation_request: GenerationRequest, user_id: Optional[str], generation_group_id: str) -> GenerationResponse:
|
||||
try:
|
||||
gen_model = Generation(**generation_request.model_dump(exclude={'count'}))
|
||||
gen_model.created_by = user_id
|
||||
gen_model.generation_group_id = generation_group_id
|
||||
|
||||
gen_id = await self.dao.generations.create_generation(gen_model)
|
||||
gen_model.id = gen_id
|
||||
|
||||
asyncio.create_task(self._queued_generation_runner(gen_model))
|
||||
return GenerationResponse(**gen_model.model_dump())
|
||||
except Exception:
|
||||
logger.exception("Failed to initiate single generation")
|
||||
raise
|
||||
|
||||
async def _queued_generation_runner(self, gen: Generation):
|
||||
logger.info(f"Generation {gen.id} waiting for slot...")
|
||||
try:
|
||||
async with generation_semaphore:
|
||||
await self.create_generation(gen)
|
||||
except Exception:
|
||||
await self._handle_generation_failure(gen, None)
|
||||
logger.exception(f"Background generation task failed for ID: {gen.id}")
|
||||
|
||||
async def _prepare_generation_input(self, generation: Generation) -> Tuple[List[bytes], str]:
|
||||
media_group_bytes: List[bytes] = []
|
||||
prompt = generation.prompt
|
||||
|
||||
# 1. Character Avatar
|
||||
if generation.linked_character_id:
|
||||
char_info = await self.dao.chars.get_character(generation.linked_character_id)
|
||||
if not char_info:
|
||||
raise ValueError(f"Character {generation.linked_character_id} not found")
|
||||
|
||||
if generation.use_profile_image and char_info.avatar_asset_id:
|
||||
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
|
||||
if avatar_asset:
|
||||
data = await self._get_asset_data_bytes(avatar_asset)
|
||||
if data: media_group_bytes.append(data)
|
||||
|
||||
# 2. Reference Assets
|
||||
if generation.assets_list:
|
||||
assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
|
||||
for asset in assets:
|
||||
data = await self._get_asset_data_bytes(asset)
|
||||
if data: media_group_bytes.append(data)
|
||||
|
||||
# 3. Environment Assets
|
||||
if generation.environment_id:
|
||||
env = await self.dao.environments.get_env(generation.environment_id)
|
||||
if env and env.asset_ids:
|
||||
env_assets = await self.dao.assets.get_assets_by_ids(env.asset_ids)
|
||||
for asset in env_assets:
|
||||
data = await self._get_asset_data_bytes(asset)
|
||||
if data: media_group_bytes.append(data)
|
||||
|
||||
if media_group_bytes:
|
||||
prompt += (
|
||||
" \n\n[Reference Image Guidance]: Use the provided image(s) as the STRICT reference for the main "
|
||||
"character's facial features and hair, environment or clothes. Maintain high fidelity to the reference identity."
|
||||
)
|
||||
|
||||
return media_group_bytes, prompt
|
||||
|
||||
async def _get_asset_data_bytes(self, asset: Asset) -> Optional[bytes]:
|
||||
if asset.content_type != AssetContentType.IMAGE:
|
||||
return None
|
||||
if asset.minio_object_name:
|
||||
return await self.s3_adapter.get_file(asset.minio_object_name)
|
||||
return asset.data
|
||||
|
||||
def _update_generation_metrics(self, generation: Generation, metrics: Dict[str, Any]):
|
||||
generation.api_execution_time_seconds = metrics.get("api_execution_time_seconds")
|
||||
generation.token_usage = metrics.get("token_usage")
|
||||
generation.input_token_usage = metrics.get("input_token_usage")
|
||||
generation.output_token_usage = metrics.get("output_token_usage")
|
||||
|
||||
async def _handle_generation_failure(self, generation: Generation, error: Optional[Exception]):
|
||||
logger.error(f"Generation {generation.id} failed: {error}")
|
||||
generation.status = GenerationStatus.FAILED
|
||||
generation.failed_reason = str(error) if error else "Unknown error"
|
||||
generation.updated_at = datetime.now(UTC)
|
||||
await self.dao.generations.update_generation(generation)
|
||||
|
||||
async def _process_generated_images(self, generation: Generation, bytes_list: List[bytes]) -> List[Asset]:
|
||||
created_assets = []
|
||||
for img_bytes in bytes_list:
|
||||
asset = await self._save_asset(
|
||||
image_bytes=img_bytes,
|
||||
name=f"Generated_{generation.linked_character_id}",
|
||||
created_by=generation.created_by,
|
||||
project_id=generation.project_id,
|
||||
linked_char_id=generation.linked_character_id,
|
||||
folder="generated"
|
||||
)
|
||||
created_assets.append(asset)
|
||||
return created_assets
|
||||
|
||||
async def _save_asset(self, image_bytes: bytes, name: str, created_by: str, project_id: str, linked_char_id: str, folder: str) -> Asset:
|
||||
thumbnail_bytes = await asyncio.to_thread(create_thumbnail, image_bytes)
|
||||
filename = f"{folder}/{linked_char_id}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.png"
|
||||
|
||||
await self.s3_adapter.upload_file(filename, image_bytes, content_type="image/png")
|
||||
|
||||
new_asset = Asset(
|
||||
name=name,
|
||||
type=AssetType.GENERATED,
|
||||
content_type=AssetContentType.IMAGE,
|
||||
linked_char_id=linked_char_id,
|
||||
data=None,
|
||||
minio_object_name=filename,
|
||||
minio_bucket=self.s3_adapter.bucket_name,
|
||||
thumbnail=thumbnail_bytes,
|
||||
created_by=created_by,
|
||||
project_id=project_id
|
||||
)
|
||||
asset_id = await self.dao.assets.create_asset(new_asset)
|
||||
new_asset.id = str(asset_id)
|
||||
return new_asset
|
||||
|
||||
async def _finalize_generation(self, generation: Generation, assets: List[Asset], tech_prompt: str, start_time: datetime):
|
||||
generation.result_list = [a.id for a in assets]
|
||||
generation.status = GenerationStatus.DONE
|
||||
generation.progress = 100
|
||||
generation.updated_at = datetime.now(UTC)
|
||||
generation.tech_prompt = tech_prompt
|
||||
generation.execution_time_seconds = (datetime.now() - start_time).total_seconds()
|
||||
await self.dao.generations.update_generation(generation)
|
||||
logger.info(f"Generation {generation.id} finalized. Time: {generation.execution_time_seconds:.2f}s")
|
||||
|
||||
async def _notify_telegram(self, generation: Generation, assets: List[Asset]):
|
||||
try:
|
||||
for asset in assets:
|
||||
# Need to get data for telegram if it's not in Asset object
|
||||
img_data = await self.s3_adapter.get_file(asset.minio_object_name) if asset.minio_object_name else asset.data
|
||||
if img_data:
|
||||
await self.bot.send_photo(
|
||||
chat_id=generation.telegram_id,
|
||||
photo=BufferedInputFile(img_data, filename=f"{asset.name}.png"),
|
||||
caption=f"Generated from: {generation.prompt[:100]}..."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send to Telegram: {e}")
|
||||
|
||||
async def _simulate_progress(self, generation: Generation):
|
||||
current_progress = 0
|
||||
try:
|
||||
while current_progress < 90:
|
||||
await asyncio.sleep(4)
|
||||
current_progress = min(current_progress + random.randint(5, 15), 90)
|
||||
generation.progress = current_progress
|
||||
await self.dao.generations.update_generation(generation)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def _fetch_external_image(self, external_gen) -> bytes:
|
||||
if external_gen.image_url:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(external_gen.image_url, timeout=30.0)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
elif external_gen.image_data:
|
||||
return base64.b64decode(external_gen.image_data)
|
||||
raise ValueError("No image source provided")
|
||||
|
||||
@@ -2,19 +2,30 @@ from enum import Enum
|
||||
|
||||
|
||||
class AspectRatios(str, Enum):
|
||||
NINESIXTEEN = "NINESIXTEEN"
|
||||
SIXTEENNINE = "SIXTEENNINE"
|
||||
THREEFOUR = "THREEFOUR"
|
||||
FOURTHREE = "FOURTHREE"
|
||||
ONEONE = "1:1"
|
||||
TWOTHREE = "2:3"
|
||||
THREETWO = "3:2"
|
||||
THREEFOUR = "3:4"
|
||||
FOURTHREE = "4:3"
|
||||
FOURFIVE = "4:5"
|
||||
FIVEFOUR = "5:4"
|
||||
NINESIXTEEN = "9:16"
|
||||
SIXTEENNINE = "16:9"
|
||||
TWENTYONENINE = "21:9"
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value):
|
||||
mapping = {
|
||||
"NINESIXTEEN": cls.NINESIXTEEN,
|
||||
"SIXTEENNINE": cls.SIXTEENNINE,
|
||||
"THREEFOUR": cls.THREEFOUR,
|
||||
"FOURTHREE": cls.FOURTHREE,
|
||||
}
|
||||
return mapping.get(value)
|
||||
|
||||
@property
|
||||
def value_ratio(self) -> str:
|
||||
return {
|
||||
AspectRatios.NINESIXTEEN: "9:16",
|
||||
AspectRatios.SIXTEENNINE: "16:9",
|
||||
AspectRatios.THREEFOUR: "3:4",
|
||||
AspectRatios.FOURTHREE: "4:3",
|
||||
}[self]
|
||||
return self.value
|
||||
|
||||
|
||||
class Quality(str, Enum):
|
||||
|
||||
@@ -126,12 +126,11 @@ async def change_char(call: CallbackQuery, state: FSMContext, dao: DAO):
|
||||
@router.callback_query(States.gen_mode, F.data == 'gen_mode_change_aspect_ratio')
|
||||
async def gen_mode_change_aspect_ratio(call: CallbackQuery, state: FSMContext, dao: DAO):
|
||||
await call.answer()
|
||||
keyboards = []
|
||||
for ratio in AspectRatios:
|
||||
keyboards.append(InlineKeyboardButton(text=ratio.value, callback_data=f'select_ratio_{ratio.name}'))
|
||||
buttons = [InlineKeyboardButton(text=ratio.value, callback_data=f'select_ratio_{ratio.name}') for ratio in AspectRatios]
|
||||
keyboard_rows = [buttons[i:i + 4] for i in range(0, len(buttons), 4)]
|
||||
keyboard_rows.append([InlineKeyboardButton(text="⬅️ Назад", callback_data="gen_mode_cancel_ratio_change")])
|
||||
await call.message.edit_caption(caption="Выбери соотношение сторон",
|
||||
reply_markup=InlineKeyboardMarkup(inline_keyboard=[keyboards, [InlineKeyboardButton(
|
||||
text="⬅️ Назад", callback_data="gen_mode_cancel_ratio_change")]]))
|
||||
reply_markup=InlineKeyboardMarkup(inline_keyboard=keyboard_rows))
|
||||
|
||||
|
||||
@router.callback_query(States.gen_mode, F.data.startswith('select_ratio_'))
|
||||
|
||||
Reference in New Issue
Block a user