46 lines
1.8 KiB
Python
46 lines
1.8 KiB
Python
from typing import Optional, List
|
|
|
|
from PIL.ImageChops import offset
|
|
from bson import ObjectId
|
|
from motor.motor_asyncio import AsyncIOMotorClient
|
|
|
|
from api.models.GenerationRequest import GenerationResponse
|
|
from models.Generation import Generation, GenerationStatus
|
|
|
|
|
|
class GenerationRepo:
    """Async MongoDB repository for ``Generation`` documents.

    All operations target the ``generations`` collection of the configured
    database through the Motor client passed to the constructor.
    """

    def __init__(self, client: AsyncIOMotorClient, db_name: str = "bot_db"):
        """Bind the repository to ``client[db_name]["generations"]``.

        Args:
            client: Motor client used for every query.
            db_name: Name of the database that holds the collection.
        """
        self.collection = client[db_name]["generations"]

    @staticmethod
    def _to_generation(document: dict) -> Generation:
        """Convert a raw Mongo document into a ``Generation`` model.

        Mongo's ``_id`` is removed and exposed to the model as a string ``id``.
        The passed-in dict is modified in place.
        """
        document["id"] = str(document.pop("_id"))
        return Generation(**document)

    async def create_generation(self, generation: Generation) -> str:
        """Insert *generation* and return the new document's ``_id`` as a string."""
        result = await self.collection.insert_one(generation.model_dump())
        return str(result.inserted_id)

    async def get_generation(self, generation_id: str) -> Optional[Generation]:
        """Fetch a single generation by id.

        Args:
            generation_id: String form of the document's ObjectId.

        Returns:
            The matching ``Generation``, or ``None`` if no document has that id.

        Note:
            ``ObjectId`` raises for a malformed id (``bson.errors.InvalidId``);
            that error is not caught here.
        """
        document = await self.collection.find_one({"_id": ObjectId(generation_id)})
        if document is None:
            return None
        return self._to_generation(document)

    async def get_generations(
        self,
        character_id: Optional[str] = None,
        status: Optional[GenerationStatus] = None,
        limit: int = 10,
        offset: int = 10,
    ) -> List[Generation]:
        """List generations, newest first by ``created_at``.

        Args:
            character_id: Only return generations whose ``linked_character_id``
                equals this value. ``None`` is NOT "no filter": the query keeps
                ``linked_character_id: None``, which in MongoDB matches documents
                where the field is null or missing, i.e. unlinked generations.
            status: If given, only return generations with this status.
            limit: Maximum number of generations to return.
            offset: Number of matching generations to skip before returning.

        Returns:
            The matching generations (possibly an empty list).
        """
        # NOTE(review): the default offset of 10 skips the ten newest matches
        # whenever a caller omits ``offset``; 0 looks intended. Left unchanged
        # because altering a default would change behavior for existing callers.

        # character_id goes into the query even when it is None, so an absent
        # character selects unlinked generations rather than all of them.
        query = {"linked_character_id": character_id}
        if status is not None:
            query["status"] = status

        cursor = (
            self.collection.find(query)
            .sort("created_at", -1)
            .skip(offset)
            .limit(limit)
        )
        documents = await cursor.to_list(None)
        return [self._to_generation(document) for document in documents]

    async def update_generation(self, generation: Generation) -> None:
        """Overwrite the stored document whose ``_id`` matches ``generation.id``.

        Every field produced by ``generation.model_dump()`` is written via
        ``$set``.
        """
        # NOTE(review): the UpdateResult is not inspected, so updating a
        # generation that does not exist is a silent no-op. Confirm callers
        # are fine with that.
        await self.collection.update_one(
            {"_id": ObjectId(generation.id)},
            {"$set": generation.model_dump()},
        )
|