98 lines
3.1 KiB
Python
98 lines
3.1 KiB
Python
from typing import List, Optional
|
|
from datetime import datetime
|
|
import logging
|
|
from bson import ObjectId
|
|
from motor.motor_asyncio import AsyncIOMotorClient
|
|
|
|
from models.Post import Post
|
|
|
|
# Logger named after this module's import path (standard logging convention).
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class PostRepo:
    """MongoDB-backed repository for ``Post`` documents in the ``posts`` collection.

    Posts are soft-deleted: ``delete_post`` sets ``is_deleted`` to ``True``,
    and the read methods only return documents whose ``is_deleted`` is
    ``False``. Post ids are exchanged with callers as the string form of the
    document's ``_id`` ObjectId.
    """

    def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
        """Bind the repository to the ``posts`` collection of *db_name*.

        Args:
            client: Motor client used to reach the database.
            db_name: Name of the database that holds the ``posts`` collection.
        """
        self.collection = client[db_name]["posts"]

    @staticmethod
    def _to_object_id(post_id: str) -> Optional[ObjectId]:
        """Return *post_id* as an ``ObjectId``, or ``None`` if it is not a valid id."""
        if not ObjectId.is_valid(post_id):
            return None
        return ObjectId(post_id)

    @staticmethod
    def _doc_to_post(doc: dict) -> Post:
        """Build a ``Post`` from a raw document, exposing ``_id`` as a string ``id``."""
        doc["id"] = str(doc.pop("_id"))
        return Post(**doc)

    async def create_post(self, post: Post) -> str:
        """Insert *post* and return the new document's id as a string."""
        # NOTE(review): the full model_dump() is stored; if Post declares an
        # ``id`` field it is persisted next to ``_id`` (reads overwrite it with
        # the stringified ``_id``) -- consider model_dump(exclude={"id"}) after
        # confirming against models.Post.
        res = await self.collection.insert_one(post.model_dump())
        return str(res.inserted_id)

    async def get_post(self, post_id: str) -> Optional[Post]:
        """Return the non-deleted post with id *post_id*, or ``None``.

        ``None`` is returned both when *post_id* is not a valid ObjectId and
        when no matching non-deleted document exists.
        """
        oid = self._to_object_id(post_id)
        if oid is None:
            return None
        doc = await self.collection.find_one({"_id": oid, "is_deleted": False})
        return self._doc_to_post(doc) if doc else None

    async def get_posts(
        self,
        project_id: Optional[str],
        user_id: str,
        limit: int = 20,
        offset: int = 0,
        date_from: Optional[datetime] = None,
        date_to: Optional[datetime] = None,
    ) -> List[Post]:
        """Return one page of non-deleted posts, newest ``date`` first.

        With a truthy *project_id*, all posts of that project are returned
        regardless of author. Otherwise the personal posts of *user_id* are
        returned, i.e. posts it created that have no project.

        Args:
            project_id: Project to list posts for; ``None``/empty selects
                personal posts instead.
            user_id: Author filter, used only when *project_id* is falsy.
            limit: Maximum number of posts to return; passed straight to the
                cursor, where ``0`` means "no limit".
            offset: Number of posts to skip, for pagination.
            date_from: Inclusive lower bound on the ``date`` field.
            date_to: Inclusive upper bound on the ``date`` field.

        Returns:
            The matching posts, ordered by ``date`` descending.
        """
        if project_id:
            query = {"project_id": project_id, "is_deleted": False}
        else:
            # Equality with None also matches documents lacking project_id.
            query = {"created_by": user_id, "project_id": None, "is_deleted": False}

        date_filter = {}
        if date_from is not None:
            date_filter["$gte"] = date_from
        if date_to is not None:
            date_filter["$lte"] = date_to
        if date_filter:
            query["date"] = date_filter

        # "date" is not unique: without a unique tiebreaker MongoDB may order
        # posts with equal dates differently between queries, so skip/limit
        # pages could overlap or miss posts. _id makes the order total.
        cursor = (
            self.collection.find(query)
            .sort([("date", -1), ("_id", -1)])
            .skip(offset)
            .limit(limit)
        )
        return [self._doc_to_post(doc) async for doc in cursor]

    async def update_post(self, post_id: str, data: dict) -> bool:
        """Apply *data* to the post via ``$set``; return whether it changed.

        Returns ``False`` when *post_id* is invalid, *data* is empty, no
        document matched, or the stored values already equal *data*
        (``modified_count`` is 0 in that last case).

        NOTE(review): the filter does not exclude soft-deleted posts, so a
        deleted post can still be updated -- confirm this is intended.
        """
        oid = self._to_object_id(post_id)
        if oid is None or not data:
            # An empty $set can modify nothing; MongoDB servers before 5.0
            # reject it outright, so skip the round trip.
            return False
        res = await self.collection.update_one({"_id": oid}, {"$set": data})
        return res.modified_count > 0

    async def delete_post(self, post_id: str) -> bool:
        """Soft-delete the post by setting ``is_deleted`` to ``True``.

        Returns ``True`` only if the document changed; an invalid id, a
        missing post, or an already-deleted post yields ``False``.
        """
        oid = self._to_object_id(post_id)
        if oid is None:
            return False
        res = await self.collection.update_one(
            {"_id": oid},
            {"$set": {"is_deleted": True}},
        )
        return res.modified_count > 0

    async def add_generations(self, post_id: str, generation_ids: List[str]) -> bool:
        """Add *generation_ids* to the post's ``generation_ids`` set.

        ``$addToSet`` skips ids that are already present. Returns ``True`` only
        if at least one new id was added.
        """
        oid = self._to_object_id(post_id)
        if oid is None or not generation_ids:
            # Nothing to add: the update could not modify the document.
            return False
        res = await self.collection.update_one(
            {"_id": oid},
            {"$addToSet": {"generation_ids": {"$each": generation_ids}}},
        )
        return res.modified_count > 0

    async def remove_generation(self, post_id: str, generation_id: str) -> bool:
        """Remove every occurrence of *generation_id* from ``generation_ids``.

        Returns ``True`` only if the id was present and removed.
        """
        oid = self._to_object_id(post_id)
        if oid is None:
            return False
        res = await self.collection.update_one(
            {"_id": oid},
            {"$pull": {"generation_ids": generation_id}},
        )
        return res.modified_count > 0
|