# File: ai-char-bot/repos/generation_repo.py
# (code-viewer metadata: 69 lines, 2.8 KiB, Python)

from typing import Optional, List
from PIL.ImageChops import offset
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient
from api.models.GenerationRequest import GenerationResponse
from models.Generation import Generation, GenerationStatus
class GenerationRepo:
    """Async repository for ``Generation`` documents stored in MongoDB.

    Every method works on the ``generations`` collection of the database
    chosen at construction time. MongoDB's ``_id`` is surfaced to callers as
    the string field ``id`` of the returned ``Generation`` models.
    """

    def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
        """Bind the repository to ``client[db_name]["generations"]``."""
        self.collection = client[db_name]["generations"]

    @staticmethod
    def _to_generation(doc: dict) -> Generation:
        """Build a ``Generation`` from a raw document, mapping ``_id`` to a string ``id``."""
        # Work on a copy so the same raw document can be converted more than once.
        fields = {key: value for key, value in doc.items() if key != "_id"}
        fields["id"] = str(doc["_id"])
        return Generation(**fields)

    @staticmethod
    def _build_query(character_id: Optional[str], status: Optional[GenerationStatus]) -> dict:
        """Return the filter shared by listing and counting non-deleted generations."""
        query = {"is_deleted": False}
        if character_id is not None:
            query["linked_character_id"] = character_id
        if status is not None:
            query["status"] = status
        return query

    async def create_generation(self, generation: Generation) -> str:
        """Insert ``generation`` and return the new document's id as a string."""
        res = await self.collection.insert_one(generation.model_dump())
        return str(res.inserted_id)

    async def get_generation(self, generation_id: str) -> Optional[Generation]:
        """Fetch a single generation by id.

        Returns:
            The matching ``Generation``, or ``None`` when no document has that
            id or when ``generation_id`` is not a valid ObjectId (the same
            leniency ``get_generations_by_ids`` applies to malformed ids).
        """
        if not ObjectId.is_valid(generation_id):
            return None
        doc = await self.collection.find_one({"_id": ObjectId(generation_id)})
        return None if doc is None else self._to_generation(doc)

    # NOTE(review): the default ``offset`` of 10 skips the ten newest rows when
    # the caller omits it; 0 was probably intended. Left unchanged because the
    # default is part of the public signature - confirm against callers.
    async def get_generations(self, character_id: Optional[str] = None,
                              status: Optional[GenerationStatus] = None,
                              limit: int = 10, offset: int = 10) -> List[Generation]:
        """List non-deleted generations, newest ``created_at`` first.

        Args:
            character_id: When given, only rows whose ``linked_character_id``
                equals this value.
            status: When given, only rows whose ``status`` equals this value.
            limit: Passed to the cursor's ``limit``.
            offset: Passed to the cursor's ``skip``.
        """
        query = self._build_query(character_id, status)
        cursor = self.collection.find(query).sort("created_at", -1).skip(offset).limit(limit)
        docs = await cursor.to_list(None)
        return [self._to_generation(doc) for doc in docs]

    async def count_generations(self, character_id: Optional[str] = None,
                                status: Optional[GenerationStatus] = None,
                                album_id: Optional[str] = None) -> int:
        """Count the generations that ``get_generations`` would list for the same filters.

        The ``is_deleted: False`` condition is applied here too, so the count
        agrees with the rows a listing can actually return.

        ``album_id`` is accepted for interface compatibility but is not applied
        to the query.
        """
        # NOTE(review): album_id was never used by this method; the field it
        # should filter on is not visible here, so it remains ignored.
        return await self.collection.count_documents(self._build_query(character_id, status))

    async def get_generations_by_ids(self, generation_ids: List[str]) -> List[Generation]:
        """Fetch several generations, preserving the order of ``generation_ids``.

        Ids that are not valid ObjectIds, and ids with no matching document,
        are skipped. An id listed more than once yields the same ``Generation``
        at each of its positions. No ``is_deleted`` condition is applied.
        """
        # Normalise through ObjectId so every key has the exact form produced by
        # str(doc["_id"]) (e.g. upper-case hex input still matches).
        wanted = [str(ObjectId(gen_id)) for gen_id in generation_ids if ObjectId.is_valid(gen_id)]
        if not wanted:
            return []
        object_ids = [ObjectId(key) for key in dict.fromkeys(wanted)]
        docs = await self.collection.find({"_id": {"$in": object_ids}}).to_list(None)
        # Convert each document exactly once; repeated ids then reuse the model
        # instead of re-processing a document that was already consumed.
        by_id = {str(doc["_id"]): self._to_generation(doc) for doc in docs}
        return [by_id[key] for key in wanted if key in by_id]

    async def update_generation(self, generation: Generation) -> bool:
        """Overwrite the stored fields of ``generation`` with its current values.

        Returns:
            ``True`` when a document with ``generation.id`` was matched,
            ``False`` otherwise (including when ``generation.id`` is ``None``).
        """
        # ObjectId(None) would mint a brand-new id that can never match, so a
        # model without an id is reported as "nothing matched" up front.
        if generation.id is None:
            return False
        res = await self.collection.update_one(
            {"_id": ObjectId(generation.id)},
            {"$set": generation.model_dump()},
        )
        return res.matched_count > 0