This commit is contained in:
xds
2026-02-15 10:26:01 +03:00
parent 2d3da59de9
commit 97483b7030
16 changed files with 245 additions and 4 deletions

View File

@@ -43,6 +43,7 @@ from api.endpoints.auth import router as api_auth_router
from api.endpoints.admin import router as api_admin_router from api.endpoints.admin import router as api_admin_router
from api.endpoints.album_router import router as api_album_router from api.endpoints.album_router import router as api_album_router
from api.endpoints.project_router import router as project_api_router from api.endpoints.project_router import router as project_api_router
from api.endpoints.idea_router import router as idea_api_router
load_dotenv() load_dotenv()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -210,6 +211,7 @@ app.include_router(api_char_router)
app.include_router(api_gen_router) app.include_router(api_gen_router)
app.include_router(api_album_router) app.include_router(api_album_router)
app.include_router(project_api_router) app.include_router(project_api_router)
app.include_router(idea_api_router)
# Prometheus Metrics (Instrument after all routers are added) # Prometheus Metrics (Instrument after all routers are added)
Instrumentator( Instrumentator(

View File

@@ -45,6 +45,11 @@ def get_generation_service(
) -> GenerationService: ) -> GenerationService:
return GenerationService(dao, gemini, s3_adapter, bot) return GenerationService(dao, gemini, s3_adapter, bot)
from api.service.idea_service import IdeaService
def get_idea_service(dao: DAO = Depends(get_dao)) -> IdeaService:
return IdeaService(dao)
from fastapi import Header from fastapi import Header
async def get_project_id(x_project_id: Optional[str] = Header(None, alias="X-Project-ID")) -> Optional[str]: async def get_project_id(x_project_id: Optional[str] = Header(None, alias="X-Project-ID")) -> Optional[str]:

View File

@@ -0,0 +1,61 @@
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from api.dependency import get_idea_service, get_current_user, get_project_id, get_generation_service
from api.service.idea_service import IdeaService
from api.service.generation_service import GenerationService
from models.Idea import Idea
from api.models.GenerationRequest import GenerationResponse
router = APIRouter(prefix="/ideas", tags=["ideas"])
@router.post("", response_model=Idea)
async def create_idea(
name: str,
project_id: str = Depends(get_project_id),
current_user: dict = Depends(get_current_user),
idea_service: IdeaService = Depends(get_idea_service)
):
if not project_id:
raise HTTPException(status_code=400, detail="Project ID header is required")
return await idea_service.create_idea(name, project_id, str(current_user["_id"]))
@router.get("", response_model=List[Idea])
async def get_ideas(
project_id: str = Depends(get_project_id),
limit: int = 20,
offset: int = 0,
idea_service: IdeaService = Depends(get_idea_service)
):
if not project_id:
raise HTTPException(status_code=400, detail="Project ID header is required")
return await idea_service.get_ideas(project_id, limit, offset)
@router.get("/{idea_id}", response_model=Idea)
async def get_idea(
idea_id: str,
idea_service: IdeaService = Depends(get_idea_service)
):
idea = await idea_service.get_idea(idea_id)
if not idea:
raise HTTPException(status_code=404, detail="Idea not found")
return idea
@router.delete("/{idea_id}")
async def delete_idea(
idea_id: str,
idea_service: IdeaService = Depends(get_idea_service)
):
success = await idea_service.delete_idea(idea_id)
if not success:
raise HTTPException(status_code=404, detail="Idea not found or could not be deleted")
return {"status": "success"}
@router.get("/{idea_id}/generations", response_model=List[GenerationResponse])
async def get_idea_generations(
idea_id: str,
limit: int = 50,
offset: int = 0,
generation_service: GenerationService = Depends(get_generation_service)
):
return await generation_service.get_generations(idea_id=idea_id, limit=limit, offset=offset)

View File

@@ -17,6 +17,7 @@ class GenerationRequest(BaseModel):
use_profile_image: bool = True use_profile_image: bool = True
assets_list: List[str] assets_list: List[str]
project_id: Optional[str] = None project_id: Optional[str] = None
idea_id: Optional[str] = None
count: int = Field(default=1, ge=1, le=10) count: int = Field(default=1, ge=1, le=10)
@@ -47,6 +48,7 @@ class GenerationResponse(BaseModel):
cost: Optional[float] = None cost: Optional[float] = None
created_by: Optional[str] = None created_by: Optional[str] = None
generation_group_id: Optional[str] = None generation_group_id: Optional[str] = None
idea_id: Optional[str] = None
created_at: datetime = datetime.now(UTC) created_at: datetime = datetime.now(UTC)
updated_at: datetime = datetime.now(UTC) updated_at: datetime = datetime.now(UTC)

View File

@@ -137,6 +137,10 @@ class GenerationService:
if generation_group_id: if generation_group_id:
generation_model.generation_group_id = generation_group_id generation_model.generation_group_id = generation_group_id
# Explicitly set idea_id from request if present (already in model_dump, but ensuring clarity)
if generation_request.idea_id:
generation_model.idea_id = generation_request.idea_id
gen_id = await self.dao.generations.create_generation(generation_model) gen_id = await self.dao.generations.create_generation(generation_model)
generation_model.id = gen_id generation_model.id = gen_id

View File

@@ -0,0 +1,22 @@
from typing import List, Optional
from repos.dao import DAO
from models.Idea import Idea
class IdeaService:
def __init__(self, dao: DAO):
self.dao = dao
async def create_idea(self, name: str, project_id: str, user_id: str) -> Idea:
idea = Idea(name=name, project_id=project_id, created_by=user_id)
idea_id = await self.dao.ideas.create_idea(idea)
idea.id = idea_id
return idea
async def get_ideas(self, project_id: str, limit: int = 20, offset: int = 0) -> List[Idea]:
return await self.dao.ideas.get_ideas(project_id, limit, offset)
async def get_idea(self, idea_id: str) -> Optional[Idea]:
return await self.dao.ideas.get_idea(idea_id)
async def delete_idea(self, idea_id: str) -> bool:
return await self.dao.ideas.delete_idea(idea_id)

View File

@@ -38,6 +38,7 @@ class Generation(BaseModel):
generation_group_id: Optional[str] = None generation_group_id: Optional[str] = None
created_by: Optional[str] = None # Stores User ID (Telegram ID or Web User ObjectId) created_by: Optional[str] = None # Stores User ID (Telegram ID or Web User ObjectId)
project_id: Optional[str] = None project_id: Optional[str] = None
idea_id: Optional[str] = None
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))

12
models/Idea.py Normal file
View File

@@ -0,0 +1,12 @@
from datetime import datetime
from typing import Optional, List
from pydantic import BaseModel, Field
class Idea(BaseModel):
id: Optional[str] = None
name: str = "New Idea"
project_id: str
created_by: str # User ID
is_deleted: bool = False
created_at: datetime = Field(default_factory=datetime.now)
updated_at: datetime = Field(default_factory=datetime.now)

View File

@@ -6,6 +6,7 @@ from repos.generation_repo import GenerationRepo
from repos.user_repo import UsersRepo from repos.user_repo import UsersRepo
from repos.albums_repo import AlbumsRepo from repos.albums_repo import AlbumsRepo
from repos.project_repo import ProjectRepo from repos.project_repo import ProjectRepo
from repos.idea_repo import IdeaRepo
from typing import Optional from typing import Optional
@@ -19,3 +20,4 @@ class DAO:
self.albums = AlbumsRepo(client, db_name) self.albums = AlbumsRepo(client, db_name)
self.projects = ProjectRepo(client, db_name) self.projects = ProjectRepo(client, db_name)
self.users = UsersRepo(client, db_name) self.users = UsersRepo(client, db_name)
self.ideas = IdeaRepo(client, db_name)

View File

@@ -26,7 +26,7 @@ class GenerationRepo:
return Generation(**res) return Generation(**res)
async def get_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None, async def get_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
limit: int = 10, offset: int = 10, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]: limit: int = 10, offset: int = 0, created_by: Optional[str] = None, project_id: Optional[str] = None, idea_id: Optional[str] = None) -> List[Generation]:
filter = {"is_deleted": False} filter = {"is_deleted": False}
if character_id is not None: if character_id is not None:
@@ -35,11 +35,20 @@ class GenerationRepo:
filter["status"] = status filter["status"] = status
if created_by is not None: if created_by is not None:
filter["created_by"] = created_by filter["created_by"] = created_by
# If filtering by created_by user (e.g. "My Generations"), we typically imply personal scope if project_id is None.
# But if project_id is passed, we filter by that.
if project_id is None:
filter["project_id"] = None filter["project_id"] = None
if project_id is not None: if project_id is not None:
filter["project_id"] = project_id filter["project_id"] = project_id
if idea_id is not None:
filter["idea_id"] = idea_id
res = await self.collection.find(filter).sort("created_at", -1).skip( # If fetching for an idea, sort by created_at ascending (cronological)
# Otherwise typically descending (newest first)
sort_order = 1 if idea_id else -1
res = await self.collection.find(filter).sort("created_at", sort_order).skip(
offset).limit(limit).to_list(None) offset).limit(limit).to_list(None)
generations: List[Generation] = [] generations: List[Generation] = []
for generation in res: for generation in res:
@@ -48,7 +57,7 @@ class GenerationRepo:
return generations return generations
async def count_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None, async def count_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
album_id: Optional[str] = None, created_by: Optional[str] = None, project_id: Optional[str] = None) -> int: album_id: Optional[str] = None, created_by: Optional[str] = None, project_id: Optional[str] = None, idea_id: Optional[str] = None) -> int:
args = {} args = {}
if character_id is not None: if character_id is not None:
args["linked_character_id"] = character_id args["linked_character_id"] = character_id
@@ -58,6 +67,8 @@ class GenerationRepo:
args["created_by"] = created_by args["created_by"] = created_by
if project_id is not None: if project_id is not None:
args["project_id"] = project_id args["project_id"] = project_id
if idea_id is not None:
args["idea_id"] = idea_id
return await self.collection.count_documents(args) return await self.collection.count_documents(args)
async def get_generations_by_ids(self, generation_ids: List[str]) -> List[Generation]: async def get_generations_by_ids(self, generation_ids: List[str]) -> List[Generation]:

39
repos/idea_repo.py Normal file
View File

@@ -0,0 +1,39 @@
from typing import Optional, List
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient
from models.Idea import Idea
class IdeaRepo:
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
self.collection = client[db_name]["ideas"]
async def create_idea(self, idea: Idea) -> str:
res = await self.collection.insert_one(idea.model_dump())
return str(res.inserted_id)
async def get_idea(self, idea_id: str) -> Optional[Idea]:
if not ObjectId.is_valid(idea_id):
return None
res = await self.collection.find_one({"_id": ObjectId(idea_id)})
if res:
res["id"] = str(res.pop("_id"))
return Idea(**res)
return None
async def get_ideas(self, project_id: str, limit: int = 20, offset: int = 0) -> List[Idea]:
filter = {"project_id": project_id, "is_deleted": False}
res = await self.collection.find(filter).sort("updated_at", -1).skip(offset).limit(limit).to_list(None)
ideas = []
for doc in res:
doc["id"] = str(doc.pop("_id"))
ideas.append(Idea(**doc))
return ideas
async def delete_idea(self, idea_id: str) -> bool:
if not ObjectId.is_valid(idea_id):
return False
res = await self.collection.update_one(
{"_id": ObjectId(idea_id)},
{"$set": {"is_deleted": True}}
)
return res.modified_count > 0

80
tests/test_idea.py Normal file
View File

@@ -0,0 +1,80 @@
import asyncio
import os
from dotenv import load_dotenv
from motor.motor_asyncio import AsyncIOMotorClient
from bson import ObjectId
# Import from project root (requires PYTHONPATH=.)
from api.service.idea_service import IdeaService
from repos.dao import DAO
from models.Idea import Idea
from models.Generation import Generation, GenerationStatus
from models.enums import AspectRatios, Quality
load_dotenv()
MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://localhost:27017")
DB_NAME = os.getenv("DB_NAME", "bot_db")
print(f"Connecting to MongoDB: {MONGO_HOST}, DB: {DB_NAME}")
async def test_idea_flow():
client = AsyncIOMotorClient(MONGO_HOST)
dao = DAO(client, db_name=DB_NAME)
service = IdeaService(dao)
# 1. Create an Idea
print("Creating idea...")
user_id = "test_user_123"
project_id = "test_project_abc"
idea = await service.create_idea("My Test Idea", project_id, user_id)
print(f"Idea created: {idea.id} - {idea.name}")
# 2. Add Generation linked to Idea
print("Creating generation linked to idea...")
gen = Generation(
prompt="idea generation 1",
idea_id=idea.id,
project_id=project_id,
created_by=user_id,
aspect_ratio=AspectRatios.NINESIXTEEN,
quality=Quality.ONEK,
assets_list=[]
)
gen_id = await dao.generations.create_generation(gen)
print(f"Created linked generation: {gen_id}")
# Debug: Check if generation was saved with idea_id
saved_gen = await dao.generations.collection.find_one({"_id": ObjectId(gen_id)})
print(f"DEBUG: Saved Generation in DB idea_id: {saved_gen.get('idea_id')}")
# 3. Fetch Generations for Idea (Verify filtering and ordering)
print("Fetching generations for idea...")
gens = await dao.generations.get_generations(idea_id=idea.id)
print(f"Found {len(gens)} generations in idea")
if len(gens) == 1 and gens[0].id == gen_id:
print("✅ Generation retrieval successful")
else:
print("❌ Generation retrieval FAILED")
# 4. Fetch Ideas for Project
ideas = await service.get_ideas(project_id)
print(f"Found {len(ideas)} ideas for project")
# Cleaning up
print("Cleaning up...")
await service.delete_idea(idea.id)
await dao.generations.collection.delete_one({"_id": ObjectId(gen_id)})
# Verify deletion
deleted_idea = await service.get_idea(idea.id)
# IdeaRepo.delete_idea logic sets is_deleted=True
if deleted_idea and deleted_idea.is_deleted:
print(f"✅ Idea deleted successfully")
# Hard delete for cleanup
await dao.ideas.collection.delete_one({"_id": ObjectId(idea.id)})
if __name__ == "__main__":
asyncio.run(test_idea_flow())