from typing import Any, Optional, List
from datetime import datetime, timedelta, UTC

# NOTE(review): `offset` appears unused in this module (and is shadowed by the
# `offset` parameter of `get_generations`); confirm and drop this PIL import.
from PIL.ImageChops import offset
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient

# NOTE(review): `GenerationResponse` appears unused in this module; confirm.
from api.models.GenerationRequest import GenerationResponse
from models.Generation import Generation, GenerationStatus


class GenerationRepo:
    """Async MongoDB repository for ``Generation`` documents.

    Documents are stored in the ``generations`` collection of ``db_name``.
    Deletion is soft (an ``is_deleted`` flag): the listing/counting queries
    exclude soft-deleted documents, while the by-id lookups
    (``get_generation``, ``get_generations_by_ids``) do not filter on it.
    """

    def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
        self.collection = client[db_name]["generations"]

    @staticmethod
    def _doc_to_generation(doc: dict[str, Any]) -> Generation:
        """Convert a raw Mongo document into a ``Generation`` (``_id`` -> ``id`` as str)."""
        doc["id"] = str(doc.pop("_id"))
        return Generation(**doc)

    async def create_generation(self, generation: Generation) -> str:
        """Insert ``generation`` and return the new document's ``_id`` as a hex string.

        NOTE(review): ``model_dump()`` presumably includes the model's ``id``
        field, which is then stored as a plain field of the document — confirm
        whether it should be excluded.
        """
        res = await self.collection.insert_one(generation.model_dump())
        return str(res.inserted_id)

    async def get_generation(self, generation_id: str) -> Generation | None:
        """Fetch one generation by its hex ObjectId string.

        Returns ``None`` when no document matches. No ``is_deleted`` filter is
        applied. A malformed ``generation_id`` makes ``ObjectId`` raise
        ``bson.errors.InvalidId``.
        """
        doc = await self.collection.find_one({"_id": ObjectId(generation_id)})
        if doc is None:
            return None
        return self._doc_to_generation(doc)

    async def get_generations(self, character_id: Optional[str] = None,
                              status: Optional[GenerationStatus] = None,
                              limit: int = 10, offset: int = 0,
                              created_by: Optional[str] = None,
                              project_id: Optional[str] = None,
                              idea_id: Optional[str] = None) -> List[Generation]:
        """List non-deleted generations matching the given filters, paginated.

        Args:
            character_id: Match ``linked_character_id``.
            status: Match ``status``.
            limit: Maximum number of results.
            offset: Number of results to skip (after sorting).
            created_by: Match ``created_by``. When given without
                ``project_id``, only generations with no project (personal
                scope) are returned.
            project_id: Match ``project_id``.
            idea_id: Match ``idea_id``; also switches the sort to ascending.

        Returns:
            Generations sorted by ``created_at`` — ascending when ``idea_id``
            is truthy, descending (newest first) otherwise.
        """
        query: dict[str, Any] = {"is_deleted": False}
        if character_id is not None:
            query["linked_character_id"] = character_id
        if status is not None:
            query["status"] = status
        if created_by is not None:
            query["created_by"] = created_by
            # Filtering by creator (e.g. "My Generations") without a project
            # implies the user's personal scope. In Mongo, `{"project_id": None}`
            # matches documents where the field is null or missing.
            if project_id is None:
                query["project_id"] = None
        if project_id is not None:
            query["project_id"] = project_id
        if idea_id is not None:
            query["idea_id"] = idea_id

        # Generations of an idea are shown chronologically (oldest first);
        # every other listing is newest first.
        sort_order = 1 if idea_id else -1
        docs = await (
            self.collection.find(query)
            .sort("created_at", sort_order)
            .skip(offset)
            .limit(limit)
            .to_list(None)
        )
        return [self._doc_to_generation(doc) for doc in docs]

    async def count_generations(self, character_id: Optional[str] = None,
                                status: Optional[GenerationStatus] = None,
                                album_id: Optional[str] = None,
                                created_by: Optional[str] = None,
                                project_id: Optional[str] = None,
                                idea_id: Optional[str] = None) -> int:
        """Count non-deleted generations matching every filter that is not ``None``.

        Soft-deleted generations are excluded, consistent with
        ``get_generations`` and ``get_generations_by_group``.
        """
        query: dict[str, Any] = {"is_deleted": False}
        if character_id is not None:
            query["linked_character_id"] = character_id
        if status is not None:
            query["status"] = status
        if created_by is not None:
            query["created_by"] = created_by
        if project_id is not None:
            query["project_id"] = project_id
        if idea_id is not None:
            query["idea_id"] = idea_id
        if album_id is not None:
            query["album_id"] = album_id
        return await self.collection.count_documents(query)

    async def get_generations_by_ids(self, generation_ids: List[str]) -> List[Generation]:
        """Fetch generations by hex id, preserving the order of ``generation_ids``.

        Invalid or unknown ids are skipped. Duplicate ids yield the same
        ``Generation`` object at each position. No ``is_deleted`` filter is
        applied.
        """
        object_ids = [ObjectId(gen_id) for gen_id in generation_ids if ObjectId.is_valid(gen_id)]
        if not object_ids:
            # Nothing can match; skip the database round trip.
            return []

        docs = await self.collection.find({"_id": {"$in": object_ids}}).to_list(None)

        # Convert every document exactly once, up front. Converting lazily while
        # walking `generation_ids` would pop "_id" from a shared dict and raise
        # KeyError on the second occurrence of a duplicated id.
        by_id: dict[str, Generation] = {}
        for doc in docs:
            key = str(doc["_id"])
            by_id[key] = self._doc_to_generation(doc)

        return [by_id[gen_id] for gen_id in generation_ids if gen_id in by_id]

    async def update_generation(self, generation: Generation):
        """Overwrite the stored fields of ``generation``, matched by ``generation.id``.

        NOTE(review): ``model_dump()`` presumably includes ``id`` itself, which
        is then written into the document as a plain field — confirm whether it
        should be excluded.
        """
        await self.collection.update_one(
            {"_id": ObjectId(generation.id)},
            {"$set": generation.model_dump()},
        )

    async def get_generations_by_group(self, group_id: str) -> List[Generation]:
        """Return non-deleted generations of ``generation_group_id == group_id``, oldest first."""
        docs = await (
            self.collection.find({"generation_group_id": group_id, "is_deleted": False})
            .sort("created_at", 1)
            .to_list(None)
        )
        return [self._doc_to_generation(doc) for doc in docs]

    async def cancel_stale_generations(self, timeout_minutes: int = 5) -> int:
        """Mark RUNNING generations created more than ``timeout_minutes`` ago as FAILED.

        Staleness is measured from ``created_at``. Returns the number of
        documents modified.
        """
        now = datetime.now(UTC)
        cutoff_time = now - timedelta(minutes=timeout_minutes)
        res = await self.collection.update_many(
            {
                "status": GenerationStatus.RUNNING,
                "created_at": {"$lt": cutoff_time},
            },
            {
                "$set": {
                    "status": GenerationStatus.FAILED,
                    "failed_reason": "Timeout: Execution time limit exceeded",
                    "updated_at": now,
                }
            },
        )
        return res.modified_count

    async def soft_delete_old_generations(self, days: int = 2) -> tuple[int, List[str]]:
        """Soft-delete finished generations older than ``days`` days.

        Affects generations that are not yet deleted, have status DONE or
        FAILED, and have ``created_at`` before the cutoff.

        Returns:
            ``(modified_count, asset_ids)``. ``asset_ids`` are the de-duplicated
            IDs found in ``result_list`` and ``assets_list`` of the matched
            generations, for the caller to clean up.
        """
        now = datetime.now(UTC)
        cutoff_time = now - timedelta(days=days)
        filter_query: dict[str, Any] = {
            "is_deleted": False,
            "status": {"$in": [GenerationStatus.DONE, GenerationStatus.FAILED]},
            "created_at": {"$lt": cutoff_time},
        }

        # First snapshot the matching documents and their asset IDs.
        doc_ids: List[ObjectId] = []
        asset_ids: List[str] = []
        cursor = self.collection.find(filter_query, {"result_list": 1, "assets_list": 1})
        async for doc in cursor:
            doc_ids.append(doc["_id"])
            # `or []` also covers fields that exist but are stored as null.
            asset_ids.extend(doc.get("result_list") or [])
            asset_ids.extend(doc.get("assets_list") or [])

        if not doc_ids:
            return 0, []

        # Soft-delete exactly the snapshotted documents (still requiring the
        # original conditions). A generation that becomes eligible between the
        # read and this write is left for the next run instead of being
        # deleted without its assets being reported.
        res = await self.collection.update_many(
            {**filter_query, "_id": {"$in": doc_ids}},
            {
                "$set": {
                    "is_deleted": True,
                    "updated_at": now,
                }
            },
        )

        # dict.fromkeys de-duplicates while keeping first-seen order.
        return res.modified_count, list(dict.fromkeys(asset_ids))