feat: Implement cancellation of stale generations in the service and repository, along with a new test.
This commit is contained in:
32
aiws.py
32
aiws.py
@@ -120,6 +120,17 @@ assets_router.message.middleware(AuthMiddleware(repo=users_repo, admin_id=ADMIN_
|
|||||||
gen_router.message.middleware(AlbumMiddleware(latency=0.8))
|
gen_router.message.middleware(AlbumMiddleware(latency=0.8))
|
||||||
|
|
||||||
|
|
||||||
|
async def start_scheduler(service: GenerationService):
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
logger.info("Running scheduler for stacked generation killing")
|
||||||
|
await service.cleanup_stale_generations()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Scheduler error: {e}")
|
||||||
|
await asyncio.sleep(300) # Check every 5 minutes
|
||||||
|
|
||||||
# --- LIFESPAN (Запуск FastAPI + Bot) ---
|
# --- LIFESPAN (Запуск FastAPI + Bot) ---
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
@@ -151,17 +162,28 @@ async def lifespan(app: FastAPI):
|
|||||||
# )
|
# )
|
||||||
# print("🤖 Bot polling started")
|
# print("🤖 Bot polling started")
|
||||||
|
|
||||||
|
# 3. ЗАПУСК ШЕДУЛЕРА
|
||||||
|
scheduler_task = asyncio.create_task(start_scheduler(generation_service))
|
||||||
|
print("⏰ Scheduler started")
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# --- SHUTDOWN ---
|
# --- SHUTDOWN ---
|
||||||
print("🛑 Shutting down...")
|
print("🛑 Shutting down...")
|
||||||
|
|
||||||
# 3. Остановка бота
|
# 4. Остановка шедулера
|
||||||
polling_task.cancel()
|
scheduler_task.cancel()
|
||||||
try:
|
try:
|
||||||
await polling_task
|
await scheduler_task
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
print("🤖 Bot polling stopped")
|
print("⏰ Scheduler stopped")
|
||||||
|
|
||||||
|
# 3. Остановка бота
|
||||||
|
# polling_task.cancel()
|
||||||
|
# try:
|
||||||
|
# await polling_task
|
||||||
|
# except asyncio.CancelledError:
|
||||||
|
# print("🤖 Bot polling stopped")
|
||||||
|
|
||||||
# 4. Отключение БД
|
# 4. Отключение БД
|
||||||
# Обычно Motor закрывать не обязательно при выходе, но хорошим тоном считается
|
# Обычно Motor закрывать не обязательно при выходе, но хорошим тоном считается
|
||||||
|
|||||||
Binary file not shown.
@@ -460,4 +460,15 @@ class GenerationService:
|
|||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error deleting generation {generation_id}: {e}")
|
logger.error(f"Error deleting generation {generation_id}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
async def cleanup_stale_generations(self):
|
||||||
|
"""
|
||||||
|
Cancels generations that have been running for more than 1 hour.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
count = await self.dao.generations.cancel_stale_generations(timeout_minutes=60)
|
||||||
|
if count > 0:
|
||||||
|
logger.info(f"Cleaned up {count} stale generations (timeout)")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error cleaning up stale generations: {e}")
|
||||||
Binary file not shown.
@@ -1,4 +1,5 @@
|
|||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
|
from datetime import datetime, timedelta, UTC
|
||||||
|
|
||||||
from PIL.ImageChops import offset
|
from PIL.ImageChops import offset
|
||||||
from bson import ObjectId
|
from bson import ObjectId
|
||||||
@@ -85,3 +86,20 @@ class GenerationRepo:
|
|||||||
generation["id"] = str(generation.pop("_id"))
|
generation["id"] = str(generation.pop("_id"))
|
||||||
generations.append(Generation(**generation))
|
generations.append(Generation(**generation))
|
||||||
return generations
|
return generations
|
||||||
|
|
||||||
|
async def cancel_stale_generations(self, timeout_minutes: int = 60) -> int:
|
||||||
|
cutoff_time = datetime.now(UTC) - timedelta(minutes=timeout_minutes)
|
||||||
|
res = await self.collection.update_many(
|
||||||
|
{
|
||||||
|
"status": GenerationStatus.RUNNING,
|
||||||
|
"created_at": {"$lt": cutoff_time}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"$set": {
|
||||||
|
"status": GenerationStatus.FAILED,
|
||||||
|
"failed_reason": "Timeout: Execution time limit exceeded",
|
||||||
|
"updated_at": datetime.now(UTC)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return res.modified_count
|
||||||
|
|||||||
52
tests/test_scheduler.py
Normal file
52
tests/test_scheduler.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
from datetime import datetime, timedelta, UTC
|
||||||
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
from models.Generation import Generation, GenerationStatus
|
||||||
|
from repos.generation_repo import GenerationRepo
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
# Mock configs if not present in env
|
||||||
|
MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://localhost:27017")
|
||||||
|
DB_NAME = os.getenv("DB_NAME", "bot_db")
|
||||||
|
|
||||||
|
print(f"Connecting to MongoDB: {MONGO_HOST}, DB: {DB_NAME}")
|
||||||
|
|
||||||
|
async def test_scheduler():
|
||||||
|
client = AsyncIOMotorClient(MONGO_HOST)
|
||||||
|
repo = GenerationRepo(client, db_name=DB_NAME)
|
||||||
|
|
||||||
|
# 1. Create a "stale" generation (2 hours ago)
|
||||||
|
stale_gen = Generation(
|
||||||
|
prompt="stale test",
|
||||||
|
status=GenerationStatus.RUNNING,
|
||||||
|
created_at=datetime.now(UTC) - timedelta(minutes=120),
|
||||||
|
assets_list=[],
|
||||||
|
aspect_ratio="NINESIXTEEN",
|
||||||
|
quality="ONEK"
|
||||||
|
)
|
||||||
|
gen_id = await repo.create_generation(stale_gen)
|
||||||
|
print(f"Created stale generation: {gen_id}")
|
||||||
|
|
||||||
|
# 2. Run cleanup
|
||||||
|
print("Running cleanup...")
|
||||||
|
count = await repo.cancel_stale_generations(timeout_minutes=60)
|
||||||
|
print(f"Cleaned up {count} generations")
|
||||||
|
|
||||||
|
# 3. Verify status
|
||||||
|
updated_gen = await repo.get_generation(gen_id)
|
||||||
|
print(f"Generation status: {updated_gen.status}")
|
||||||
|
print(f"Failed reason: {updated_gen.failed_reason}")
|
||||||
|
|
||||||
|
if updated_gen.status == GenerationStatus.FAILED:
|
||||||
|
print("✅ SUCCESS: Generation marked as FAILED")
|
||||||
|
else:
|
||||||
|
print("❌ FAILURE: Generation status not updated")
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
await repo.collection.delete_one({"_id": updated_gen.id}) # Remove test data
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(test_scheduler())
|
||||||
Reference in New Issue
Block a user