This commit is contained in:
xds
2026-02-04 15:10:55 +03:00
parent 11c1f4f7dc
commit 35de8efc56
20 changed files with 566 additions and 135 deletions

View File

@@ -1,4 +1,4 @@
from typing import List
from typing import List, Optional
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient
@@ -10,10 +10,10 @@ class AssetsRepo:
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
self.collection = client[db_name]["assets"]
async def save_asset(self, asset: Asset) -> Asset:
async def create_asset(self, asset: Asset) -> str:
res = await self.collection.insert_one(asset.model_dump())
asset.id = res.inserted_id
return asset
return str(res.inserted_id)
async def get_assets(self, limit: int = 10, offset: int = 0) -> List[Asset]:
res = await self.collection.find({}, {"data": 0}).sort("created_at", -1).skip(offset).limit(limit).to_list(None)
@@ -27,6 +27,7 @@ class AssetsRepo:
return assets
async def get_asset(self, asset_id: str, with_data: bool = True) -> Asset:
projection = {"_id": 1, "name": 1, "type": 1, "tg_doc_file_id": 1}
if with_data:
@@ -54,3 +55,17 @@ class AssetsRepo:
doc["id"] = str(doc.pop("_id"))
assets.append(Asset(**doc))
return assets
async def get_asset_count(self, character_id: Optional[str] = None) -> int:
return await self.collection.count_documents({"linked_char_id": character_id} if character_id else {})
async def get_assets_by_ids(self, asset_ids: List[str]) -> List[Asset]:
object_ids = [ObjectId(asset_id) for asset_id in asset_ids]
res = self.collection.find({"_id": {"$in": object_ids}})
assets = []
async for doc in res:
doc["id"] = str(doc.pop("_id"))
assets.append(Asset(**doc))
return assets

View File

@@ -2,6 +2,7 @@ from motor.motor_asyncio import AsyncIOMotorClient
from repos.assets_repo import AssetsRepo
from repos.char_repo import CharacterRepo
from repos.generation_repo import GenerationRepo
from repos.user_repo import UsersRepo
@@ -9,3 +10,4 @@ class DAO:
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
self.chars = CharacterRepo(client, db_name)
self.assets = AssetsRepo(client, db_name)
self.generations = GenerationRepo(client, db_name)

43
repos/generation_repo.py Normal file
View File

@@ -0,0 +1,43 @@
from typing import Optional, List
from PIL.ImageChops import offset
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient
from api.models.GenerationRequest import GenerationResponse
from models.Generation import Generation, GenerationStatus
class GenerationRepo:
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
self.collection = client[db_name]["generations"]
async def create_generation(self, generation: Generation) -> str:
res = await self.collection.insert_one(generation.model_dump())
return str(res.inserted_id)
async def get_generation(self, generation_id: str) -> Optional[Generation]:
res = await self.collection.find_one({"_id": ObjectId(generation_id)})
if res is None:
return None
else:
res["id"] = str(res.pop("_id"))
return Generation(**res)
async def get_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
limit: int = 10, offset: int = 10) -> List[Generation]:
args = {}
if character_id is not None:
args["character_id"] = character_id
if status is not None:
args["status"] = status
res = await self.collection.find(args).sort("created_at", -1).skip(
offset).limit(limit).to_list(None)
generations: List[Generation] = []
for generation in res:
generation["id"] = str(generation.pop("_id"))
generations.append(Generation(**generation))
return generations
async def update_generation(self, generation: Generation, ):
res = await self.collection.update_one({"_id": ObjectId(generation.id)}, {"$set": generation.model_dump()})