from typing import Optional, List
from datetime import datetime, timedelta, UTC

# NOTE(review): `offset` is not used in this module (looks like an accidental IDE
# auto-import that also pulls in Pillow); kept only in case something imports it
# from here — consider removing.
from PIL.ImageChops import offset
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient

# NOTE(review): `GenerationResponse` is not used in this module.
from api.models.GenerationRequest import GenerationResponse
from models.Generation import Generation, GenerationStatus


class GenerationRepo:
    """Async MongoDB repository for ``Generation`` documents.

    Documents are stored in the ``generations`` collection of ``db_name``.
    Mongo's ``_id`` is exposed on the returned models as a string ``id``.
    """

    def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
        self.collection = client[db_name]["generations"]

    @staticmethod
    def _to_generation(doc: dict) -> Generation:
        """Convert a raw Mongo document into a ``Generation`` (``_id`` -> ``id``).

        Works on a shallow copy so the caller's dict is never mutated; this
        keeps the conversion safe when the same document is converted twice.
        """
        data = dict(doc)
        data["id"] = str(data.pop("_id"))
        return Generation(**data)

    @staticmethod
    def _build_filter(character_id: Optional[str],
                      status: Optional[GenerationStatus],
                      created_by: Optional[str],
                      project_id: Optional[str],
                      idea_id: Optional[str]) -> dict:
        """Build the query shared by ``get_generations`` and ``count_generations``.

        Soft-deleted documents (``is_deleted`` true) are always excluded.
        """
        query: dict = {"is_deleted": False}
        if character_id is not None:
            query["linked_character_id"] = character_id
        if status is not None:
            query["status"] = status
        if created_by is not None:
            query["created_by"] = created_by
            # Filtering by creator (e.g. "My Generations") implies personal
            # scope when no project is given: only project-less generations.
            if project_id is None:
                query["project_id"] = None
        if project_id is not None:
            query["project_id"] = project_id
        if idea_id is not None:
            query["idea_id"] = idea_id
        return query

    async def create_generation(self, generation: Generation) -> str:
        """Insert ``generation`` and return the new document's id as a string."""
        res = await self.collection.insert_one(generation.model_dump())
        return str(res.inserted_id)

    async def get_generation(self, generation_id: str) -> Optional[Generation]:
        """Fetch one generation by id.

        Returns ``None`` when no document matches or when ``generation_id`` is
        not a valid ObjectId (consistent with ``get_generations_by_ids``, which
        skips invalid ids instead of raising).
        """
        if not ObjectId.is_valid(generation_id):
            return None
        doc = await self.collection.find_one({"_id": ObjectId(generation_id)})
        if doc is None:
            return None
        return self._to_generation(doc)

    async def get_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
                              limit: int = 10, offset: int = 0, created_by: Optional[str] = None,
                              project_id: Optional[str] = None, idea_id: Optional[str] = None) -> List[Generation]:
        """List non-deleted generations matching the given filters, paginated.

        Args:
            character_id: match ``linked_character_id``.
            status: match ``status``.
            limit: maximum number of results.
            offset: number of results to skip.
            created_by: match ``created_by``; without ``project_id`` this also
                restricts to generations whose ``project_id`` is ``None``.
            project_id: match ``project_id``.
            idea_id: match ``idea_id``; also switches sorting to ascending.

        Returns:
            Generations sorted by ``created_at``: ascending (chronological) when
            ``idea_id`` is given, otherwise descending (newest first).
        """
        query = self._build_filter(character_id, status, created_by, project_id, idea_id)
        sort_order = 1 if idea_id is not None else -1
        cursor = self.collection.find(query).sort("created_at", sort_order).skip(offset).limit(limit)
        docs = await cursor.to_list(None)
        return [self._to_generation(doc) for doc in docs]

    async def count_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
                                album_id: Optional[str] = None, created_by: Optional[str] = None,
                                project_id: Optional[str] = None, idea_id: Optional[str] = None) -> int:
        """Count the generations that ``get_generations`` would list (ignoring pagination).

        Uses the same filter as ``get_generations`` (including the soft-delete
        exclusion and the personal-scope rule) so pagination totals agree with
        the listed items.

        ``album_id`` is accepted for backward compatibility but is not applied.
        """
        query = self._build_filter(character_id, status, created_by, project_id, idea_id)
        return await self.collection.count_documents(query)

    async def get_generations_by_ids(self, generation_ids: List[str]) -> List[Generation]:
        """Fetch generations by id, preserving the order of ``generation_ids``.

        Invalid ids and ids with no matching document are skipped. An id that
        appears several times yields that generation several times.
        """
        object_ids = [ObjectId(gen_id) for gen_id in generation_ids if ObjectId.is_valid(gen_id)]
        if not object_ids:
            return []
        docs = await self.collection.find({"_id": {"$in": object_ids}}).to_list(None)
        docs_by_id = {str(doc["_id"]): doc for doc in docs}
        # _to_generation copies each doc, so repeated ids no longer hit an
        # already-mutated dict (previously a KeyError on "_id").
        return [self._to_generation(docs_by_id[gen_id]) for gen_id in generation_ids if gen_id in docs_by_id]

    async def update_generation(self, generation: Generation):
        """Overwrite the stored fields of ``generation`` (matched by its ``id``)."""
        await self.collection.update_one(
            {"_id": ObjectId(generation.id)},
            {"$set": generation.model_dump()},
        )

    async def get_generations_by_group(self, group_id: str) -> List[Generation]:
        """Return the non-deleted generations of ``group_id``, oldest first."""
        docs = await self.collection.find(
            {"generation_group_id": group_id, "is_deleted": False}
        ).sort("created_at", 1).to_list(None)
        return [self._to_generation(doc) for doc in docs]

    async def cancel_stale_generations(self, timeout_minutes: int = 60) -> int:
        """Mark RUNNING generations created more than ``timeout_minutes`` ago as FAILED.

        Returns:
            The number of documents modified.
        """
        now = datetime.now(UTC)
        cutoff_time = now - timedelta(minutes=timeout_minutes)
        res = await self.collection.update_many(
            {
                "status": GenerationStatus.RUNNING,
                "created_at": {"$lt": cutoff_time}
            },
            {
                "$set": {
                    "status": GenerationStatus.FAILED,
                    "failed_reason": "Timeout: Execution time limit exceeded",
                    "updated_at": now
                }
            }
        )
        return res.modified_count