Compare commits
2 Commits
c93e577bcf
...
198ac44960
| Author | SHA1 | Date | |
|---|---|---|---|
| 198ac44960 | |||
| d820d9145b |
2
aiws.py
2
aiws.py
@@ -44,6 +44,7 @@ from api.endpoints.admin import router as api_admin_router
|
||||
from api.endpoints.album_router import router as api_album_router
|
||||
from api.endpoints.project_router import router as project_api_router
|
||||
from api.endpoints.idea_router import router as idea_api_router
|
||||
from api.endpoints.post_router import router as post_api_router
|
||||
|
||||
load_dotenv()
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -219,6 +220,7 @@ app.include_router(api_gen_router)
|
||||
app.include_router(api_album_router)
|
||||
app.include_router(project_api_router)
|
||||
app.include_router(idea_api_router)
|
||||
app.include_router(post_api_router)
|
||||
|
||||
# Prometheus Metrics (Instrument after all routers are added)
|
||||
Instrumentator(
|
||||
|
||||
Binary file not shown.
@@ -57,4 +57,9 @@ async def get_project_id(x_project_id: Optional[str] = Header(None, alias="X-Pro
|
||||
return x_project_id
|
||||
|
||||
async def get_album_service(dao: DAO = Depends(get_dao)) -> AlbumService:
|
||||
return AlbumService(dao)
|
||||
return AlbumService(dao)
|
||||
|
||||
from api.service.post_service import PostService
|
||||
|
||||
def get_post_service(dao: DAO = Depends(get_dao)) -> PostService:
|
||||
return PostService(dao)
|
||||
99
api/endpoints/post_router.py
Normal file
99
api/endpoints/post_router.py
Normal file
@@ -0,0 +1,99 @@
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from api.dependency import get_post_service, get_project_id
|
||||
from api.endpoints.auth import get_current_user
|
||||
from api.service.post_service import PostService
|
||||
from api.models.PostRequest import PostCreateRequest, PostUpdateRequest, AddGenerationsRequest
|
||||
from models.Post import Post
|
||||
|
||||
router = APIRouter(prefix="/api/posts", tags=["posts"])
|
||||
|
||||
|
||||
@router.post("", response_model=Post)
|
||||
async def create_post(
|
||||
request: PostCreateRequest,
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
post_service: PostService = Depends(get_post_service),
|
||||
):
|
||||
pid = project_id or request.project_id
|
||||
return await post_service.create_post(
|
||||
date=request.date,
|
||||
topic=request.topic,
|
||||
generation_ids=request.generation_ids,
|
||||
project_id=pid,
|
||||
user_id=str(current_user["_id"]),
|
||||
)
|
||||
|
||||
|
||||
@router.get("", response_model=List[Post])
|
||||
async def get_posts(
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
limit: int = 200,
|
||||
offset: int = 0,
|
||||
date_from: Optional[datetime] = None,
|
||||
date_to: Optional[datetime] = None,
|
||||
current_user: dict = Depends(get_current_user),
|
||||
post_service: PostService = Depends(get_post_service),
|
||||
):
|
||||
return await post_service.get_posts(project_id, str(current_user["_id"]), limit, offset, date_from, date_to)
|
||||
|
||||
|
||||
@router.get("/{post_id}", response_model=Post)
|
||||
async def get_post(
|
||||
post_id: str,
|
||||
post_service: PostService = Depends(get_post_service),
|
||||
):
|
||||
post = await post_service.get_post(post_id)
|
||||
if not post:
|
||||
raise HTTPException(status_code=404, detail="Post not found")
|
||||
return post
|
||||
|
||||
|
||||
@router.put("/{post_id}", response_model=Post)
|
||||
async def update_post(
|
||||
post_id: str,
|
||||
request: PostUpdateRequest,
|
||||
post_service: PostService = Depends(get_post_service),
|
||||
):
|
||||
post = await post_service.update_post(post_id, date=request.date, topic=request.topic)
|
||||
if not post:
|
||||
raise HTTPException(status_code=404, detail="Post not found")
|
||||
return post
|
||||
|
||||
|
||||
@router.delete("/{post_id}")
|
||||
async def delete_post(
|
||||
post_id: str,
|
||||
post_service: PostService = Depends(get_post_service),
|
||||
):
|
||||
success = await post_service.delete_post(post_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Post not found or could not be deleted")
|
||||
return {"status": "success"}
|
||||
|
||||
|
||||
@router.post("/{post_id}/generations")
|
||||
async def add_generations(
|
||||
post_id: str,
|
||||
request: AddGenerationsRequest,
|
||||
post_service: PostService = Depends(get_post_service),
|
||||
):
|
||||
success = await post_service.add_generations(post_id, request.generation_ids)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Post not found")
|
||||
return {"status": "success"}
|
||||
|
||||
|
||||
@router.delete("/{post_id}/generations/{generation_id}")
|
||||
async def remove_generation(
|
||||
post_id: str,
|
||||
generation_id: str,
|
||||
post_service: PostService = Depends(get_post_service),
|
||||
):
|
||||
success = await post_service.remove_generation(post_id, generation_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Post not found or generation not linked")
|
||||
return {"status": "success"}
|
||||
19
api/models/PostRequest.py
Normal file
19
api/models/PostRequest.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class PostCreateRequest(BaseModel):
|
||||
date: datetime
|
||||
topic: str
|
||||
generation_ids: List[str] = []
|
||||
project_id: Optional[str] = None
|
||||
|
||||
|
||||
class PostUpdateRequest(BaseModel):
|
||||
date: Optional[datetime] = None
|
||||
topic: Optional[str] = None
|
||||
|
||||
|
||||
class AddGenerationsRequest(BaseModel):
|
||||
generation_ids: List[str]
|
||||
79
api/service/post_service.py
Normal file
79
api/service/post_service.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from typing import List, Optional
|
||||
from datetime import datetime, UTC
|
||||
|
||||
from repos.dao import DAO
|
||||
from models.Post import Post
|
||||
|
||||
|
||||
class PostService:
|
||||
def __init__(self, dao: DAO):
|
||||
self.dao = dao
|
||||
|
||||
async def create_post(
|
||||
self,
|
||||
date: datetime,
|
||||
topic: str,
|
||||
generation_ids: List[str],
|
||||
project_id: Optional[str],
|
||||
user_id: str,
|
||||
) -> Post:
|
||||
post = Post(
|
||||
date=date,
|
||||
topic=topic,
|
||||
generation_ids=generation_ids,
|
||||
project_id=project_id,
|
||||
created_by=user_id,
|
||||
)
|
||||
post_id = await self.dao.posts.create_post(post)
|
||||
post.id = post_id
|
||||
return post
|
||||
|
||||
async def get_post(self, post_id: str) -> Optional[Post]:
|
||||
return await self.dao.posts.get_post(post_id)
|
||||
|
||||
async def get_posts(
|
||||
self,
|
||||
project_id: Optional[str],
|
||||
user_id: str,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
date_from: Optional[datetime] = None,
|
||||
date_to: Optional[datetime] = None,
|
||||
) -> List[Post]:
|
||||
return await self.dao.posts.get_posts(project_id, user_id, limit, offset, date_from, date_to)
|
||||
|
||||
async def update_post(
|
||||
self,
|
||||
post_id: str,
|
||||
date: Optional[datetime] = None,
|
||||
topic: Optional[str] = None,
|
||||
) -> Optional[Post]:
|
||||
post = await self.dao.posts.get_post(post_id)
|
||||
if not post:
|
||||
return None
|
||||
|
||||
updates: dict = {"updated_at": datetime.now(UTC)}
|
||||
if date is not None:
|
||||
updates["date"] = date
|
||||
if topic is not None:
|
||||
updates["topic"] = topic
|
||||
|
||||
await self.dao.posts.update_post(post_id, updates)
|
||||
|
||||
# Return refreshed post
|
||||
return await self.dao.posts.get_post(post_id)
|
||||
|
||||
async def delete_post(self, post_id: str) -> bool:
|
||||
return await self.dao.posts.delete_post(post_id)
|
||||
|
||||
async def add_generations(self, post_id: str, generation_ids: List[str]) -> bool:
|
||||
post = await self.dao.posts.get_post(post_id)
|
||||
if not post:
|
||||
return False
|
||||
return await self.dao.posts.add_generations(post_id, generation_ids)
|
||||
|
||||
async def remove_generation(self, post_id: str, generation_id: str) -> bool:
|
||||
post = await self.dao.posts.get_post(post_id)
|
||||
if not post:
|
||||
return False
|
||||
return await self.dao.posts.remove_generation(post_id, generation_id)
|
||||
23
models/Post.py
Normal file
23
models/Post.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from datetime import datetime, timezone, UTC
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class Post(BaseModel):
|
||||
id: Optional[str] = None
|
||||
date: datetime
|
||||
topic: str
|
||||
generation_ids: List[str] = Field(default_factory=list)
|
||||
project_id: Optional[str] = None
|
||||
created_by: str
|
||||
is_deleted: bool = False
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
|
||||
@model_validator(mode="after")
|
||||
def ensure_tz_aware(self):
|
||||
for field in ("date", "created_at", "updated_at"):
|
||||
val = getattr(self, field)
|
||||
if val is not None and val.tzinfo is None:
|
||||
setattr(self, field, val.replace(tzinfo=timezone.utc))
|
||||
return self
|
||||
Binary file not shown.
@@ -7,6 +7,7 @@ from repos.user_repo import UsersRepo
|
||||
from repos.albums_repo import AlbumsRepo
|
||||
from repos.project_repo import ProjectRepo
|
||||
from repos.idea_repo import IdeaRepo
|
||||
from repos.post_repo import PostRepo
|
||||
|
||||
|
||||
from typing import Optional
|
||||
@@ -21,3 +22,4 @@ class DAO:
|
||||
self.projects = ProjectRepo(client, db_name)
|
||||
self.users = UsersRepo(client, db_name)
|
||||
self.ideas = IdeaRepo(client, db_name)
|
||||
self.posts = PostRepo(client, db_name)
|
||||
|
||||
97
repos/post_repo.py
Normal file
97
repos/post_repo.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
import logging
|
||||
from bson import ObjectId
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
|
||||
from models.Post import Post
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PostRepo:
|
||||
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
|
||||
self.collection = client[db_name]["posts"]
|
||||
|
||||
async def create_post(self, post: Post) -> str:
|
||||
res = await self.collection.insert_one(post.model_dump())
|
||||
return str(res.inserted_id)
|
||||
|
||||
async def get_post(self, post_id: str) -> Optional[Post]:
|
||||
if not ObjectId.is_valid(post_id):
|
||||
return None
|
||||
res = await self.collection.find_one({"_id": ObjectId(post_id), "is_deleted": False})
|
||||
if res:
|
||||
res["id"] = str(res.pop("_id"))
|
||||
return Post(**res)
|
||||
return None
|
||||
|
||||
async def get_posts(
|
||||
self,
|
||||
project_id: Optional[str],
|
||||
user_id: str,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
date_from: Optional[datetime] = None,
|
||||
date_to: Optional[datetime] = None,
|
||||
) -> List[Post]:
|
||||
if project_id:
|
||||
match = {"project_id": project_id, "is_deleted": False}
|
||||
else:
|
||||
match = {"created_by": user_id, "project_id": None, "is_deleted": False}
|
||||
|
||||
if date_from or date_to:
|
||||
date_filter = {}
|
||||
if date_from:
|
||||
date_filter["$gte"] = date_from
|
||||
if date_to:
|
||||
date_filter["$lte"] = date_to
|
||||
match["date"] = date_filter
|
||||
|
||||
cursor = (
|
||||
self.collection.find(match)
|
||||
.sort("date", -1)
|
||||
.skip(offset)
|
||||
.limit(limit)
|
||||
)
|
||||
posts = []
|
||||
async for doc in cursor:
|
||||
doc["id"] = str(doc.pop("_id"))
|
||||
posts.append(Post(**doc))
|
||||
return posts
|
||||
|
||||
async def update_post(self, post_id: str, data: dict) -> bool:
|
||||
if not ObjectId.is_valid(post_id):
|
||||
return False
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(post_id)},
|
||||
{"$set": data},
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def delete_post(self, post_id: str) -> bool:
|
||||
if not ObjectId.is_valid(post_id):
|
||||
return False
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(post_id)},
|
||||
{"$set": {"is_deleted": True}},
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def add_generations(self, post_id: str, generation_ids: List[str]) -> bool:
|
||||
if not ObjectId.is_valid(post_id):
|
||||
return False
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(post_id)},
|
||||
{"$addToSet": {"generation_ids": {"$each": generation_ids}}},
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def remove_generation(self, post_id: str, generation_id: str) -> bool:
|
||||
if not ObjectId.is_valid(post_id):
|
||||
return False
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(post_id)},
|
||||
{"$pull": {"generation_ids": generation_id}},
|
||||
)
|
||||
return res.modified_count > 0
|
||||
Reference in New Issue
Block a user