from typing import Optional, List
from PIL.ImageChops import offset
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient

from api.models.GenerationRequest import GenerationResponse
from models.Generation import Generation, GenerationStatus


class GenerationRepo:
    """Async MongoDB repository for ``Generation`` documents.

    Documents are stored in the ``generations`` collection of the configured
    database. MongoDB's ``_id`` (an ``ObjectId``) is exposed on the returned
    ``Generation`` models as the string field ``id``.
    """

    def __init__(self, client: AsyncIOMotorClient, db_name: str = "bot_db"):
        """Bind the repository to ``client[db_name]["generations"]``.

        Args:
            client: Motor client used for all queries.
            db_name: Name of the database holding the ``generations`` collection.
        """
        self.collection = client[db_name]["generations"]

    @staticmethod
    def _build_filter(character_id: Optional[str],
                      status: Optional[GenerationStatus]) -> dict:
        """Build the query shared by listing and counting: non-deleted generations, optionally narrowed."""
        query: dict = {"is_deleted": False}
        if character_id is not None:
            query["linked_character_id"] = character_id
        if status is not None:
            query["status"] = status
        return query

    @staticmethod
    def _to_model(document: dict) -> Generation:
        """Convert a raw Mongo document to a ``Generation``, moving ``_id`` to a string ``id``."""
        document["id"] = str(document.pop("_id"))
        return Generation(**document)

    async def create_generation(self, generation: Generation) -> str:
        """Insert ``generation`` and return the new document's ``_id`` as a string."""
        result = await self.collection.insert_one(generation.model_dump())
        return str(result.inserted_id)

    async def get_generation(self, generation_id: str) -> Optional[Generation]:
        """Fetch a single generation by its id.

        Note that, unlike ``get_generations``, this does not exclude
        soft-deleted (``is_deleted``) documents.

        Args:
            generation_id: Hex string of the document's ``ObjectId``.

        Returns:
            The matching ``Generation``, or ``None`` if no document has this ``_id``.

        Raises:
            bson.errors.InvalidId: if ``generation_id`` is not a valid ObjectId string.
        """
        document = await self.collection.find_one({"_id": ObjectId(generation_id)})
        if document is None:
            return None
        return self._to_model(document)

    async def get_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
                              limit: int = 10, offset: int = 10) -> List[Generation]:
        """List non-deleted generations, newest ``created_at`` first.

        Args:
            character_id: If given, only generations whose ``linked_character_id`` equals it.
            status: If given, only generations with this ``status``.
            limit: Maximum number of generations to return.
            offset: Number of matching generations to skip before collecting results.

        Returns:
            Up to ``limit`` generations matching the filters.
        """
        # NOTE(review): with limit=10, the default offset=10 skips the first
        # page whenever callers omit `offset`; 0 is probably intended. Kept
        # unchanged to avoid altering callers' behavior — confirm and fix.
        cursor = (
            self.collection.find(self._build_filter(character_id, status))
            .sort("created_at", -1)
            .skip(offset)
            .limit(limit)
        )
        documents = await cursor.to_list(None)
        return [self._to_model(document) for document in documents]

    async def count_generations(self, character_id: Optional[str] = None,
                                status: Optional[GenerationStatus] = None) -> int:
        """Count non-deleted generations matching the same filters as ``get_generations``.

        Using the same filter (including ``is_deleted: False``) keeps the total
        consistent with the pages returned by ``get_generations``.
        """
        return await self.collection.count_documents(self._build_filter(character_id, status))

    async def update_generation(self, generation: Generation) -> None:
        """Overwrite the stored fields of the document whose ``_id`` is ``generation.id``.

        All fields from ``generation.model_dump()`` are written with ``$set``.
        NOTE(review): if ``Generation`` declares an ``id`` field, it is written
        into the document as well — verify this duplication is intended.
        """
        await self.collection.update_one(
            {"_id": ObjectId(generation.id)},
            {"$set": generation.model_dump()},
        )