from typing import Optional, List

# NOTE(review): `offset` is not used by this module and is shadowed by the
# `offset` parameter of `get_generations`; it looks like an accidental IDE
# auto-import (and drags in Pillow) — confirm and remove.
from PIL.ImageChops import offset
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient

from api.models.GenerationRequest import GenerationResponse  # NOTE(review): unused in this module
from models.Generation import Generation, GenerationStatus


class GenerationRepo:
    """Async MongoDB repository for ``Generation`` records.

    Records live in the ``generations`` collection of ``db_name``. Mongo's
    ``_id`` (an ``ObjectId``) is exposed on the model as the string ``id``.
    """

    def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
        self.collection = client[db_name]["generations"]

    @staticmethod
    def _to_model(doc: dict) -> Generation:
        """Convert a raw Mongo document into a ``Generation`` (``_id`` -> ``id``).

        Works on a shallow copy so the input document is never mutated; this
        lets the same document be converted more than once.
        """
        data = dict(doc)
        data["id"] = str(data.pop("_id"))
        return Generation(**data)

    @staticmethod
    def _build_filter(character_id: Optional[str],
                      status: Optional[GenerationStatus]) -> dict:
        """Build the query shared by listing and counting; excludes soft-deleted rows."""
        query = {"is_deleted": False}
        if character_id is not None:
            query["linked_character_id"] = character_id
        if status is not None:
            query["status"] = status
        return query

    async def create_generation(self, generation: Generation) -> str:
        """Insert ``generation`` and return the new document's id as a string."""
        res = await self.collection.insert_one(generation.model_dump())
        return str(res.inserted_id)

    async def get_generation(self, generation_id: str) -> Optional[Generation]:
        """Fetch a single generation by id, or ``None`` if no document matches.

        Soft-deleted documents are NOT filtered out here. A malformed
        ``generation_id`` makes ``ObjectId`` raise ``bson.errors.InvalidId``.
        """
        doc = await self.collection.find_one({"_id": ObjectId(generation_id)})
        return None if doc is None else self._to_model(doc)

    async def get_generations(self, character_id: Optional[str] = None,
                              status: Optional[GenerationStatus] = None,
                              limit: int = 10, offset: int = 10) -> List[Generation]:
        """List non-deleted generations, newest first (by ``created_at``).

        Args:
            character_id: if given, only generations linked to this character.
            status: if given, only generations in this status.
            limit: maximum number of results.
            offset: number of matching results to skip.
                NOTE(review): defaults to 10, so a call without ``offset``
                skips the 10 most recent generations — a default of 0 looks
                intended; confirm callers always pass it before changing.
        """
        docs = await (
            self.collection.find(self._build_filter(character_id, status))
            .sort("created_at", -1)
            .skip(offset)
            .limit(limit)
            .to_list(None)
        )
        return [self._to_model(doc) for doc in docs]

    async def count_generations(self, character_id: Optional[str] = None,
                                status: Optional[GenerationStatus] = None,
                                album_id: Optional[str] = None) -> int:
        """Count non-deleted generations matching the same filters as ``get_generations``.

        Soft-deleted rows are excluded so this total agrees with what
        ``get_generations`` can page through.
        NOTE(review): ``album_id`` is accepted but currently ignored.
        """
        return await self.collection.count_documents(
            self._build_filter(character_id, status))

    async def get_generations_by_ids(self, generation_ids: List[str]) -> List[Generation]:
        """Fetch generations by id, preserving the order of ``generation_ids``.

        Invalid ids and ids with no matching document are skipped. An id that
        appears more than once yields one (separate) model per occurrence.
        """
        object_ids = [ObjectId(gen_id) for gen_id in generation_ids if ObjectId.is_valid(gen_id)]
        if not object_ids:
            return []
        docs = await self.collection.find({"_id": {"$in": object_ids}}).to_list(None)
        # Key by ObjectId rather than by the caller's string, so hex case differences
        # still match. _to_model copies each doc, so duplicate ids are safe.
        docs_by_id = {doc["_id"]: doc for doc in docs}
        return [self._to_model(docs_by_id[oid]) for oid in object_ids if oid in docs_by_id]

    async def update_generation(self, generation: Generation):
        """Overwrite the stored fields of ``generation``, matched by its ``id``.

        Raises:
            ValueError: if ``generation.id`` is ``None``. ``ObjectId(None)``
                would otherwise mint a fresh id and the update would silently
                match no document.
        """
        if generation.id is None:
            raise ValueError("Cannot update a generation without an id")
        await self.collection.update_one({"_id": ObjectId(generation.id)},
                                         {"$set": generation.model_dump()})