8 Commits

37 changed files with 275 additions and 80 deletions

14
.gitignore vendored
View File

@@ -9,3 +9,17 @@ minio_backup.tar.gz
.idea .idea
.venv .venv
.vscode .vscode
.vscode/launch.json
middlewares/__pycache__/
middlewares/*.pyc
api/__pycache__/
api/*.pyc
repos/__pycache__/
repos/*.pyc
adapters/__pycache__/
adapters/*.pyc
services/__pycache__/
services/*.pyc
utils/__pycache__/
utils/*.pyc
.vscode/launch.json

27
.vscode/launch.json vendored
View File

@@ -7,7 +7,7 @@
"request": "launch", "request": "launch",
"module": "uvicorn", "module": "uvicorn",
"args": [ "args": [
"main:app", "aiws:app",
"--reload", "--reload",
"--port", "--port",
"8090", "8090",
@@ -16,31 +16,6 @@
], ],
"jinja": true, "jinja": true,
"justMyCode": true "justMyCode": true
},
{
"name": "Python: Current File",
"type": "debugpy",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"justMyCode": true,
"env": {
"PYTHONPATH": "${workspaceFolder}"
}
},
{
"name": "Debug Tests: Current File",
"type": "debugpy",
"request": "launch",
"module": "pytest",
"args": [
"${file}"
],
"console": "integratedTerminal",
"justMyCode": true,
"env": {
"PYTHONPATH": "${workspaceFolder}"
}
} }
] ]
} }

View File

@@ -23,28 +23,30 @@ class GoogleAdapter:
self.TEXT_MODEL = "gemini-3-pro-preview" self.TEXT_MODEL = "gemini-3-pro-preview"
self.IMAGE_MODEL = "gemini-3-pro-image-preview" self.IMAGE_MODEL = "gemini-3-pro-image-preview"
def _prepare_contents(self, prompt: str, images_list: List[bytes] = None) -> list: def _prepare_contents(self, prompt: str, images_list: List[bytes] = None) -> tuple:
"""Вспомогательный метод для подготовки контента (текст + картинки)""" """Вспомогательный метод для подготовки контента (текст + картинки).
Returns (contents, opened_images) — caller MUST close opened_images after use."""
contents = [prompt] contents = [prompt]
opened_images = []
if images_list: if images_list:
logger.info(f"Preparing content with {len(images_list)} images") logger.info(f"Preparing content with {len(images_list)} images")
for img_bytes in images_list: for img_bytes in images_list:
try: try:
# Gemini API требует PIL Image на входе
image = Image.open(io.BytesIO(img_bytes)) image = Image.open(io.BytesIO(img_bytes))
contents.append(image) contents.append(image)
opened_images.append(image)
except Exception as e: except Exception as e:
logger.error(f"Error processing input image: {e}") logger.error(f"Error processing input image: {e}")
else: else:
logger.info("Preparing content with no images") logger.info("Preparing content with no images")
return contents return contents, opened_images
def generate_text(self, prompt: str, images_list: List[bytes] = None) -> str: def generate_text(self, prompt: str, images_list: List[bytes] = None) -> str:
""" """
Генерация текста (Чат или Vision). Генерация текста (Чат или Vision).
Возвращает строку с ответом. Возвращает строку с ответом.
""" """
contents = self._prepare_contents(prompt, images_list) contents, opened_images = self._prepare_contents(prompt, images_list)
logger.info(f"Generating text: {prompt}") logger.info(f"Generating text: {prompt}")
try: try:
response = self.client.models.generate_content( response = self.client.models.generate_content(
@@ -68,6 +70,9 @@ class GoogleAdapter:
except Exception as e: except Exception as e:
logger.error(f"Gemini Text API Error: {e}") logger.error(f"Gemini Text API Error: {e}")
raise GoogleGenerationException(f"Gemini Text API Error: {e}") raise GoogleGenerationException(f"Gemini Text API Error: {e}")
finally:
for img in opened_images:
img.close()
def generate_image(self, prompt: str, aspect_ratio: AspectRatios, quality: Quality, images_list: List[bytes] = None, ) -> Tuple[List[io.BytesIO], Dict[str, Any]]: def generate_image(self, prompt: str, aspect_ratio: AspectRatios, quality: Quality, images_list: List[bytes] = None, ) -> Tuple[List[io.BytesIO], Dict[str, Any]]:
""" """
@@ -75,7 +80,7 @@ class GoogleAdapter:
Возвращает список байтовых потоков (готовых к отправке). Возвращает список байтовых потоков (готовых к отправке).
""" """
contents = self._prepare_contents(prompt, images_list) contents, opened_images = self._prepare_contents(prompt, images_list)
logger.info(f"Generating image. Prompt length: {len(prompt)}, Ratio: {aspect_ratio}, Quality: {quality}") logger.info(f"Generating image. Prompt length: {len(prompt)}, Ratio: {aspect_ratio}, Quality: {quality}")
start_time = datetime.now() start_time = datetime.now()
@@ -101,8 +106,20 @@ class GoogleAdapter:
if response.usage_metadata: if response.usage_metadata:
token_usage = response.usage_metadata.total_token_count token_usage = response.usage_metadata.total_token_count
if response.parts is None and response.candidates[0].finish_reason is not None: # Check prompt-level block (e.g. PROHIBITED_CONTENT) — no candidates in this case
raise GoogleGenerationException(f"Generation blocked in cause of {response.candidates[0].finish_reason.value}") if response.prompt_feedback and response.prompt_feedback.block_reason:
raise GoogleGenerationException(
f"Generation blocked at prompt level: {response.prompt_feedback.block_reason.value}"
)
# Check candidate-level block
if response.parts is None:
response_reason = (
response.candidates[0].finish_reason
if response.candidates and len(response.candidates) > 0
else "Unknown"
)
raise GoogleGenerationException(f"Generation blocked: {response_reason}")
generated_images = [] generated_images = []
@@ -148,3 +165,7 @@ class GoogleAdapter:
except Exception as e: except Exception as e:
logger.error(f"Gemini Image API Error: {e}") logger.error(f"Gemini Image API Error: {e}")
raise GoogleGenerationException(f"Gemini Image API Error: {e}") raise GoogleGenerationException(f"Gemini Image API Error: {e}")
finally:
for img in opened_images:
img.close()
del contents

View File

@@ -56,6 +56,21 @@ class S3Adapter:
print(f"Error downloading from S3: {e}") print(f"Error downloading from S3: {e}")
return None return None
async def stream_file(self, object_name: str, chunk_size: int = 65536):
"""Streams a file from S3 yielding chunks. Memory-efficient for large files."""
try:
async with self._get_client() as client:
response = await client.get_object(Bucket=self.bucket_name, Key=object_name)
# aioboto3 Body is an aiohttp StreamReader wrapper
body = response['Body']
data = await body.read()
# Yield in chunks to avoid holding entire response in StreamingResponse buffer
for i in range(0, len(data), chunk_size):
yield data[i:i + chunk_size]
except ClientError as e:
print(f"Error streaming from S3: {e}")
return
async def delete_file(self, object_name: str): async def delete_file(self, object_name: str):
"""Deletes a file from S3.""" """Deletes a file from S3."""
try: try:

30
aiws.py
View File

@@ -120,6 +120,17 @@ assets_router.message.middleware(AuthMiddleware(repo=users_repo, admin_id=ADMIN_
gen_router.message.middleware(AlbumMiddleware(latency=0.8)) gen_router.message.middleware(AlbumMiddleware(latency=0.8))
async def start_scheduler(service: GenerationService):
while True:
try:
logger.info("Running scheduler for stacked generation killing")
await service.cleanup_stale_generations()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Scheduler error: {e}")
await asyncio.sleep(600) # Check every 10 minutes
# --- LIFESPAN (Запуск FastAPI + Bot) --- # --- LIFESPAN (Запуск FastAPI + Bot) ---
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
@@ -151,17 +162,28 @@ async def lifespan(app: FastAPI):
# ) # )
# print("🤖 Bot polling started") # print("🤖 Bot polling started")
# 3. ЗАПУСК ШЕДУЛЕРА
scheduler_task = asyncio.create_task(start_scheduler(generation_service))
print("⏰ Scheduler started")
yield yield
# --- SHUTDOWN --- # --- SHUTDOWN ---
print("🛑 Shutting down...") print("🛑 Shutting down...")
# 3. Остановка бота # 4. Остановка шедулера
polling_task.cancel() scheduler_task.cancel()
try: try:
await polling_task await scheduler_task
except asyncio.CancelledError: except asyncio.CancelledError:
print("🤖 Bot polling stopped") print("⏰ Scheduler stopped")
# 3. Остановка бота
# polling_task.cancel()
# try:
# await polling_task
# except asyncio.CancelledError:
# print("🤖 Bot polling stopped")
# 4. Отключение БД # 4. Отключение БД
# Обычно Motor закрывать не обязательно при выходе, но хорошим тоном считается # Обычно Motor закрывать не обязательно при выходе, но хорошим тоном считается

View File

@@ -9,7 +9,7 @@ from pymongo import MongoClient
from starlette import status from starlette import status
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import Response, JSONResponse from starlette.responses import Response, JSONResponse, StreamingResponse
from adapters.s3_adapter import S3Adapter from adapters.s3_adapter import S3Adapter
from api.models.AssetDTO import AssetsResponse, AssetResponse from api.models.AssetDTO import AssetsResponse, AssetResponse
@@ -33,27 +33,46 @@ async def get_asset(
asset_id: str, asset_id: str,
request: Request, request: Request,
thumbnail: bool = False, thumbnail: bool = False,
dao: DAO = Depends(get_dao) dao: DAO = Depends(get_dao),
s3_adapter: S3Adapter = Depends(get_s3_adapter),
) -> Response: ) -> Response:
logger.debug(f"get_asset called for ID: {asset_id}, thumbnail={thumbnail}") logger.debug(f"get_asset called for ID: {asset_id}, thumbnail={thumbnail}")
asset = await dao.assets.get_asset(asset_id) # Загружаем только метаданные (без data/thumbnail bytes)
# 2. Проверка на существование asset = await dao.assets.get_asset(asset_id, with_data=False)
if not asset: if not asset:
raise HTTPException(status_code=404, detail="Asset not found") raise HTTPException(status_code=404, detail="Asset not found")
headers = { headers = {
# Кэшировать на 1 год (31536000 сек)
"Cache-Control": "public, max-age=31536000, immutable" "Cache-Control": "public, max-age=31536000, immutable"
} }
content = asset.data # Thumbnail: маленький, можно грузить в RAM
media_type = "image/png" # Default, or detect if thumbnail:
if asset.minio_thumbnail_object_name and s3_adapter:
thumb_bytes = await s3_adapter.get_file(asset.minio_thumbnail_object_name)
if thumb_bytes:
return Response(content=thumb_bytes, media_type="image/jpeg", headers=headers)
# Fallback: thumbnail in DB
if asset.thumbnail:
return Response(content=asset.thumbnail, media_type="image/jpeg", headers=headers)
# No thumbnail available — fall through to main content
if thumbnail and asset.thumbnail: # Main content: стримим из S3 без загрузки в RAM
content = asset.thumbnail if asset.minio_object_name and s3_adapter:
media_type = "image/jpeg" content_type = "image/png"
# if asset.content_type == AssetContentType.VIDEO:
# content_type = "video/mp4"
return StreamingResponse(
s3_adapter.stream_file(asset.minio_object_name),
media_type=content_type,
headers=headers,
)
return Response(content=content, media_type=media_type, headers=headers) # Fallback: data stored in DB (legacy)
if asset.data:
return Response(content=asset.data, media_type="image/png", headers=headers)
raise HTTPException(status_code=404, detail="Asset data not found")
@router.delete("/orphans", dependencies=[Depends(get_current_user)]) @router.delete("/orphans", dependencies=[Depends(get_current_user)])
async def delete_orphan_assets_from_minio( async def delete_orphan_assets_from_minio(

View File

@@ -8,7 +8,7 @@ from api import service
from api.dependency import get_generation_service, get_project_id, get_dao from api.dependency import get_generation_service, get_project_id, get_dao
from repos.dao import DAO from repos.dao import DAO
from api.models.GenerationRequest import GenerationResponse, GenerationRequest, GenerationsResponse, PromptResponse, PromptRequest from api.models.GenerationRequest import GenerationResponse, GenerationRequest, GenerationsResponse, PromptResponse, PromptRequest, GenerationGroupResponse
from api.service.generation_service import GenerationService from api.service.generation_service import GenerationService
from models.Generation import Generation from models.Generation import Generation
@@ -68,12 +68,12 @@ async def get_generations(character_id: Optional[str] = None, limit: int = 10, o
return await generation_service.get_generations(character_id, limit=limit, offset=offset, user_id=user_id_filter, project_id=project_id) return await generation_service.get_generations(character_id, limit=limit, offset=offset, user_id=user_id_filter, project_id=project_id)
@router.post("/_run", response_model=GenerationResponse) @router.post("/_run", response_model=GenerationGroupResponse)
async def post_generation(generation: GenerationRequest, request: Request, async def post_generation(generation: GenerationRequest, request: Request,
generation_service: GenerationService = Depends(get_generation_service), generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user), current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id), project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao)) -> GenerationResponse: dao: DAO = Depends(get_dao)) -> GenerationGroupResponse:
logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}") logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}")
if project_id: if project_id:
@@ -85,16 +85,6 @@ async def post_generation(generation: GenerationRequest, request: Request,
return await generation_service.create_generation_task(generation, user_id=str(current_user.get("_id"))) return await generation_service.create_generation_task(generation, user_id=str(current_user.get("_id")))
@router.get("/{generation_id}", response_model=GenerationResponse)
async def get_generation(generation_id: str,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)) -> GenerationResponse:
logger.debug(f"get_generation called for ID: {generation_id}")
gen = await generation_service.get_generation(generation_id)
if gen and gen.created_by != str(current_user["_id"]):
raise HTTPException(status_code=403, detail="Access denied")
return gen
@router.get("/running") @router.get("/running")
async def get_running_generations(request: Request, async def get_running_generations(request: Request,
@@ -113,6 +103,27 @@ async def get_running_generations(request: Request,
return await generation_service.get_running_generations(user_id=user_id_filter, project_id=project_id) return await generation_service.get_running_generations(user_id=user_id_filter, project_id=project_id)
@router.get("/group/{group_id}", response_model=GenerationGroupResponse)
async def get_generation_group(group_id: str,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)):
logger.info(f"get_generation_group called for group_id: {group_id}")
generations = await generation_service.dao.generations.get_generations_by_group(group_id)
gen_responses = [GenerationResponse(**gen.model_dump()) for gen in generations]
return GenerationGroupResponse(generation_group_id=group_id, generations=gen_responses)
@router.get("/{generation_id}", response_model=GenerationResponse)
async def get_generation(generation_id: str,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)) -> GenerationResponse:
logger.debug(f"get_generation called for ID: {generation_id}")
gen = await generation_service.get_generation(generation_id)
if gen and gen.created_by != str(current_user["_id"]):
raise HTTPException(status_code=403, detail="Access denied")
return gen
@router.post("/import", response_model=GenerationResponse) @router.post("/import", response_model=GenerationResponse)

View File

@@ -1,7 +1,7 @@
from datetime import datetime, UTC from datetime import datetime, UTC
from typing import List, Optional from typing import List, Optional
from pydantic import BaseModel from pydantic import BaseModel, Field
from models.Asset import Asset from models.Asset import Asset
from models.Generation import GenerationStatus from models.Generation import GenerationStatus
@@ -17,6 +17,7 @@ class GenerationRequest(BaseModel):
use_profile_image: bool = True use_profile_image: bool = True
assets_list: List[str] assets_list: List[str]
project_id: Optional[str] = None project_id: Optional[str] = None
count: int = Field(default=1, ge=1, le=10)
class GenerationsResponse(BaseModel): class GenerationsResponse(BaseModel):
@@ -45,10 +46,15 @@ class GenerationResponse(BaseModel):
progress: int = 0 progress: int = 0
cost: Optional[float] = None cost: Optional[float] = None
created_by: Optional[str] = None created_by: Optional[str] = None
generation_group_id: Optional[str] = None
created_at: datetime = datetime.now(UTC) created_at: datetime = datetime.now(UTC)
updated_at: datetime = datetime.now(UTC) updated_at: datetime = datetime.now(UTC)
class GenerationGroupResponse(BaseModel):
generation_group_id: str
generations: List[GenerationResponse]
class PromptRequest(BaseModel): class PromptRequest(BaseModel):
prompt: str prompt: str

View File

@@ -5,13 +5,14 @@ import base64
from datetime import datetime, UTC from datetime import datetime, UTC
from typing import List, Optional, Tuple, Any, Dict from typing import List, Optional, Tuple, Any, Dict
from io import BytesIO from io import BytesIO
from uuid import uuid4
import httpx import httpx
from aiogram import Bot from aiogram import Bot
from aiogram.types import BufferedInputFile from aiogram.types import BufferedInputFile
from adapters.Exception import GoogleGenerationException from adapters.Exception import GoogleGenerationException
from adapters.google_adapter import GoogleAdapter from adapters.google_adapter import GoogleAdapter
from api.models.GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse from api.models.GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse
# Импортируйте ваши модели DAO, Asset, Generation корректно # Импортируйте ваши модели DAO, Asset, Generation корректно
from models.Asset import Asset, AssetType, AssetContentType from models.Asset import Asset, AssetType, AssetContentType
from models.Generation import Generation, GenerationStatus from models.Generation import Generation, GenerationStatus
@@ -21,6 +22,9 @@ from adapters.s3_adapter import S3Adapter
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Limit concurrent generations to 4
generation_semaphore = asyncio.Semaphore(4)
# --- Вспомогательная функция генерации --- # --- Вспомогательная функция генерации ---
async def generate_image_task( async def generate_image_task(
@@ -50,16 +54,18 @@ async def generate_image_task(
logger.info(f"generate_image_task completed, received {len(generated_images_io) if generated_images_io else 0} images") logger.info(f"generate_image_task completed, received {len(generated_images_io) if generated_images_io else 0} images")
except GoogleGenerationException as e: except GoogleGenerationException as e:
raise e raise e
finally:
# Освобождаем входные данные — они больше не нужны
del media_group_bytes
images_bytes = [] images_bytes = []
if generated_images_io: if generated_images_io:
for img_io in generated_images_io: for img_io in generated_images_io:
# Читаем байты из BytesIO
img_io.seek(0) img_io.seek(0)
content = img_io.read() images_bytes.append(img_io.read())
images_bytes.append(content)
# Закрываем поток
img_io.close() img_io.close()
# Освобождаем список BytesIO сразу
del generated_images_io
return images_bytes, metrics return images_bytes, metrics
@@ -111,21 +117,37 @@ class GenerationService:
async def get_running_generations(self, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]: async def get_running_generations(self, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING, created_by=user_id, project_id=project_id) return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING, created_by=user_id, project_id=project_id)
async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None) -> GenerationResponse: async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationGroupResponse:
count = generation_request.count
if generation_group_id is None:
generation_group_id = str(uuid4())
results = []
for _ in range(count):
gen_response = await self._create_single_generation(generation_request, user_id, generation_group_id)
results.append(gen_response)
return GenerationGroupResponse(generation_group_id=generation_group_id, generations=results)
async def _create_single_generation(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationResponse:
gen_id = None gen_id = None
generation_model = None generation_model = None
try: try:
generation_model = Generation(**generation_request.model_dump()) generation_model = Generation(**generation_request.model_dump(exclude={'count'}))
if user_id: if user_id:
generation_model.created_by = user_id generation_model.created_by = user_id
if generation_group_id:
generation_model.generation_group_id = generation_group_id
gen_id = await self.dao.generations.create_generation(generation_model) gen_id = await self.dao.generations.create_generation(generation_model)
generation_model.id = gen_id generation_model.id = gen_id
async def runner(gen): async def runner(gen):
logger.info(f"Starting background generation task for ID: {gen.id}") logger.info(f"Generation {gen.id} entered queue (waiting for slot)...")
try: try:
async with generation_semaphore:
logger.info(f"Starting background generation task for ID: {gen.id}")
await self.create_generation(gen) await self.create_generation(gen)
logger.info(f"Background generation task finished for ID: {gen.id}") logger.info(f"Background generation task finished for ID: {gen.id}")
except Exception: except Exception:
@@ -444,3 +466,14 @@ class GenerationService:
except Exception as e: except Exception as e:
logger.error(f"Error deleting generation {generation_id}: {e}") logger.error(f"Error deleting generation {generation_id}: {e}")
return False return False
async def cleanup_stale_generations(self):
"""
Cancels generations that have been running for more than 1 hour.
"""
try:
count = await self.dao.generations.cancel_stale_generations(timeout_minutes=60)
if count > 0:
logger.info(f"Cleaned up {count} stale generations (timeout)")
except Exception as e:
logger.error(f"Error cleaning up stale generations: {e}")

View File

@@ -35,6 +35,7 @@ class Generation(BaseModel):
output_token_usage: Optional[int] = None output_token_usage: Optional[int] = None
is_deleted: bool = False is_deleted: bool = False
album_id: Optional[str] = None album_id: Optional[str] = None
generation_group_id: Optional[str] = None
created_by: Optional[str] = None # Stores User ID (Telegram ID or Web User ObjectId) created_by: Optional[str] = None # Stores User ID (Telegram ID or Web User ObjectId)
project_id: Optional[str] = None project_id: Optional[str] = None
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))

View File

@@ -1,4 +1,5 @@
from typing import Optional, List from typing import Optional, List
from datetime import datetime, timedelta, UTC
from PIL.ImageChops import offset from PIL.ImageChops import offset
from bson import ObjectId from bson import ObjectId
@@ -77,3 +78,28 @@ class GenerationRepo:
async def update_generation(self, generation: Generation, ): async def update_generation(self, generation: Generation, ):
res = await self.collection.update_one({"_id": ObjectId(generation.id)}, {"$set": generation.model_dump()}) res = await self.collection.update_one({"_id": ObjectId(generation.id)}, {"$set": generation.model_dump()})
async def get_generations_by_group(self, group_id: str) -> List[Generation]:
res = await self.collection.find({"generation_group_id": group_id, "is_deleted": False}).sort("created_at", 1).to_list(None)
generations: List[Generation] = []
for generation in res:
generation["id"] = str(generation.pop("_id"))
generations.append(Generation(**generation))
return generations
async def cancel_stale_generations(self, timeout_minutes: int = 60) -> int:
cutoff_time = datetime.now(UTC) - timedelta(minutes=timeout_minutes)
res = await self.collection.update_many(
{
"status": GenerationStatus.RUNNING,
"created_at": {"$lt": cutoff_time}
},
{
"$set": {
"status": GenerationStatus.FAILED,
"failed_reason": "Timeout: Execution time limit exceeded",
"updated_at": datetime.now(UTC)
}
}
)
return res.modified_count

52
tests/test_scheduler.py Normal file
View File

@@ -0,0 +1,52 @@
import asyncio
import os
from datetime import datetime, timedelta, UTC
from motor.motor_asyncio import AsyncIOMotorClient
from models.Generation import Generation, GenerationStatus
from repos.generation_repo import GenerationRepo
from dotenv import load_dotenv
load_dotenv()
# Mock configs if not present in env
MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://localhost:27017")
DB_NAME = os.getenv("DB_NAME", "bot_db")
print(f"Connecting to MongoDB: {MONGO_HOST}, DB: {DB_NAME}")
async def test_scheduler():
client = AsyncIOMotorClient(MONGO_HOST)
repo = GenerationRepo(client, db_name=DB_NAME)
# 1. Create a "stale" generation (2 hours ago)
stale_gen = Generation(
prompt="stale test",
status=GenerationStatus.RUNNING,
created_at=datetime.now(UTC) - timedelta(minutes=120),
assets_list=[],
aspect_ratio="NINESIXTEEN",
quality="ONEK"
)
gen_id = await repo.create_generation(stale_gen)
print(f"Created stale generation: {gen_id}")
# 2. Run cleanup
print("Running cleanup...")
count = await repo.cancel_stale_generations(timeout_minutes=60)
print(f"Cleaned up {count} generations")
# 3. Verify status
updated_gen = await repo.get_generation(gen_id)
print(f"Generation status: {updated_gen.status}")
print(f"Failed reason: {updated_gen.failed_reason}")
if updated_gen.status == GenerationStatus.FAILED:
print("✅ SUCCESS: Generation marked as FAILED")
else:
print("❌ FAILURE: Generation status not updated")
# Cleanup
await repo.collection.delete_one({"_id": updated_gen.id}) # Remove test data
if __name__ == "__main__":
asyncio.run(test_scheduler())