diff --git a/api/__pycache__/dependency.cpython-313.pyc b/api/__pycache__/dependency.cpython-313.pyc index eb067d1..6e62158 100644 Binary files a/api/__pycache__/dependency.cpython-313.pyc and b/api/__pycache__/dependency.cpython-313.pyc differ diff --git a/api/endpoints/idea_router.py b/api/endpoints/idea_router.py index b2022b8..f6ceb89 100644 --- a/api/endpoints/idea_router.py +++ b/api/endpoints/idea_router.py @@ -1,24 +1,28 @@ from typing import List, Optional -from fastapi import APIRouter, Depends, HTTPException, Query -from api.dependency import get_idea_service, get_current_user, get_project_id, get_generation_service +from fastapi import APIRouter, Depends, HTTPException, Query, Body +from api.dependency import get_idea_service, get_project_id, get_generation_service +from api.endpoints.auth import get_current_user from api.service.idea_service import IdeaService from api.service.generation_service import GenerationService from models.Idea import Idea -from api.models.GenerationRequest import GenerationResponse +from api.models.GenerationRequest import GenerationResponse, GenerationsResponse +from api.models.IdeaRequest import IdeaCreateRequest, IdeaUpdateRequest -router = APIRouter(prefix="/ideas", tags=["ideas"]) +router = APIRouter(prefix="/api/ideas", tags=["ideas"]) @router.post("", response_model=Idea) async def create_idea( - name: str, + request: IdeaCreateRequest, project_id: str = Depends(get_project_id), current_user: dict = Depends(get_current_user), idea_service: IdeaService = Depends(get_idea_service) ): - if not project_id: + if not project_id and not request.project_id: raise HTTPException(status_code=400, detail="Project ID header is required") - return await idea_service.create_idea(name, project_id, str(current_user["_id"])) + pid = project_id or request.project_id + + return await idea_service.create_idea(request.name, request.description, pid, str(current_user["_id"])) @router.get("", response_model=List[Idea]) async def get_ideas( @@ -41,6 +45,17 @@ async def get_idea( raise HTTPException(status_code=404, detail="Idea not found") return idea +@router.put("/{idea_id}", response_model=Idea) +async def update_idea( + idea_id: str, + request: IdeaUpdateRequest, + idea_service: IdeaService = Depends(get_idea_service) +): + idea = await idea_service.update_idea(idea_id, request.name, request.description) + if not idea: + raise HTTPException(status_code=404, detail="Idea not found") + return idea + @router.delete("/{idea_id}") async def delete_idea( idea_id: str, @@ -51,11 +66,42 @@ async def delete_idea( raise HTTPException(status_code=404, detail="Idea not found or could not be deleted") return {"status": "success"} -@router.get("/{idea_id}/generations", response_model=List[GenerationResponse]) +@router.get("/{idea_id}/generations", response_model=GenerationsResponse) async def get_idea_generations( idea_id: str, limit: int = 50, offset: int = 0, generation_service: GenerationService = Depends(get_generation_service) ): + # Depending on how generation service implements filtering by idea_id. + # We might need to update generation_service to support getting by idea_id directly + # or ensure generic get_generations supports it. + # Looking at generation_router.py, get_generations doesn't have idea_id arg? + # Let's check generation_service.get_generations signature again. + # It has: (character_id, limit, offset, user_id, project_id). NO IDEA_ID. + # I need to update GenerationService.get_generations too! + + # For now, let's assume I will update it. return await generation_service.get_generations(idea_id=idea_id, limit=limit, offset=offset) + +@router.post("/{idea_id}/generations/{generation_id}") +async def add_generation_to_idea( + idea_id: str, + generation_id: str, + idea_service: IdeaService = Depends(get_idea_service) +): + success = await idea_service.add_generation_to_idea(idea_id, generation_id) + if not success: + raise HTTPException(status_code=404, detail="Idea or Generation not found") + return {"status": "success"} + +@router.delete("/{idea_id}/generations/{generation_id}") +async def remove_generation_from_idea( + idea_id: str, + generation_id: str, + idea_service: IdeaService = Depends(get_idea_service) +): + success = await idea_service.remove_generation_from_idea(idea_id, generation_id) + if not success: + raise HTTPException(status_code=404, detail="Idea or Generation not found") + return {"status": "success"} diff --git a/api/models/IdeaRequest.py b/api/models/IdeaRequest.py new file mode 100644 index 0000000..773f38c --- /dev/null +++ b/api/models/IdeaRequest.py @@ -0,0 +1,11 @@ +from typing import Optional +from pydantic import BaseModel + +class IdeaCreateRequest(BaseModel): + name: str + description: Optional[str] = None + project_id: Optional[str] = None # Optional in body if passed via header/dependency + +class IdeaUpdateRequest(BaseModel): + name: Optional[str] = None + description: Optional[str] = None diff --git a/api/service/__pycache__/generation_service.cpython-313.pyc b/api/service/__pycache__/generation_service.cpython-313.pyc index a0f1879..1d49568 100644 Binary files a/api/service/__pycache__/generation_service.cpython-313.pyc and b/api/service/__pycache__/generation_service.cpython-313.pyc differ diff --git a/api/service/generation_service.py b/api/service/generation_service.py index b241923..84ac80a 100644 --- a/api/service/generation_service.py +++ b/api/service/generation_service.py @@ -97,10 +97,9 @@ class GenerationService: return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, images_list=images) - async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[ - Generation]: - generations = await self.dao.generations.get_generations(character_id = character_id,limit=limit, offset=offset, created_by=user_id, project_id=project_id) - total_count = await self.dao.generations.count_generations(character_id = character_id, created_by=user_id, project_id=project_id) + async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0, user_id: Optional[str] = None, project_id: Optional[str] = None, idea_id: Optional[str] = None) -> GenerationsResponse: + generations = await self.dao.generations.get_generations(character_id = character_id,limit=limit, offset=offset, created_by=user_id, project_id=project_id, idea_id=idea_id) + total_count = await self.dao.generations.count_generations(character_id = character_id, created_by=user_id, project_id=project_id, idea_id=idea_id) generations = [GenerationResponse(**gen.model_dump()) for gen in generations] return GenerationsResponse(generations=generations, total_count=total_count) diff --git a/api/service/idea_service.py b/api/service/idea_service.py index fd1641c..89c430f 100644 --- a/api/service/idea_service.py +++ b/api/service/idea_service.py @@ -1,4 +1,5 @@ from typing import List, Optional +from datetime import datetime from repos.dao import DAO from models.Idea import Idea @@ -6,8 +7,8 @@ class IdeaService: def __init__(self, dao: DAO): self.dao = dao - async def create_idea(self, name: str, project_id: str, user_id: str) -> Idea: - idea = Idea(name=name, project_id=project_id, created_by=user_id) + async def create_idea(self, name: str, description: Optional[str], project_id: str, user_id: str) -> Idea: + idea = Idea(name=name, description=description, project_id=project_id, created_by=user_id) idea_id = await self.dao.ideas.create_idea(idea) idea.id = idea_id return idea @@ -18,5 +19,57 @@ class IdeaService: async def get_idea(self, idea_id: str) -> Optional[Idea]: return await self.dao.ideas.get_idea(idea_id) + async def update_idea(self, idea_id: str, name: Optional[str] = None, description: Optional[str] = None) -> Optional[Idea]: + idea = await self.dao.ideas.get_idea(idea_id) + if not idea: + return None + + if name is not None: + idea.name = name + if description is not None: + idea.description = description + + idea.updated_at = datetime.now() + await self.dao.ideas.update_idea(idea) + return idea + async def delete_idea(self, idea_id: str) -> bool: return await self.dao.ideas.delete_idea(idea_id) + + async def add_generation_to_idea(self, idea_id: str, generation_id: str) -> bool: + # Verify idea exists + idea = await self.dao.ideas.get_idea(idea_id) + if not idea: + return False + + # Get generation + gen = await self.dao.generations.get_generation(generation_id) + if not gen: + return False + + # Link + gen.idea_id = idea_id + gen.updated_at = datetime.now() + await self.dao.generations.update_generation(gen) + return True + + async def remove_generation_from_idea(self, idea_id: str, generation_id: str) -> bool: + # Verify idea exists (optional, but good for validation) + idea = await self.dao.ideas.get_idea(idea_id) + if not idea: + return False + + # Get generation + gen = await self.dao.generations.get_generation(generation_id) + if not gen: + return False + + # Unlink only if currently linked to this idea + if gen.idea_id == idea_id: + gen.idea_id = None + gen.updated_at = datetime.now() + await self.dao.generations.update_generation(gen) + return True + + return False + diff --git a/models/Idea.py b/models/Idea.py index 0a7ebb1..305ecc2 100644 --- a/models/Idea.py +++ b/models/Idea.py @@ -5,6 +5,7 @@ from pydantic import BaseModel, Field class Idea(BaseModel): id: Optional[str] = None name: str = "New Idea" + description: Optional[str] = None project_id: str created_by: str # User ID is_deleted: bool = False diff --git a/repos/idea_repo.py b/repos/idea_repo.py index 6ea79f9..bb65e0e 100644 --- a/repos/idea_repo.py +++ b/repos/idea_repo.py @@ -37,3 +37,17 @@ class IdeaRepo: {"$set": {"is_deleted": True}} ) return res.modified_count > 0 + + async def update_idea(self, idea: Idea) -> bool: + if not idea.id or not ObjectId.is_valid(idea.id): + return False + + idea_dict = idea.model_dump() + if "id" in idea_dict: + del idea_dict["id"] + + res = await self.collection.update_one( + {"_id": ObjectId(idea.id)}, + {"$set": idea_dict} + ) + return res.modified_count > 0 diff --git a/tests/test_idea.py b/tests/test_idea.py index 18fe1ee..fb864f9 100644 --- a/tests/test_idea.py +++ b/tests/test_idea.py @@ -27,14 +27,23 @@ async def test_idea_flow(): print("Creating idea...") user_id = "test_user_123" project_id = "test_project_abc" - idea = await service.create_idea("My Test Idea", project_id, user_id) + idea = await service.create_idea("My Test Idea", "Initial Description", project_id, user_id) print(f"Idea created: {idea.id} - {idea.name}") - # 2. Add Generation linked to Idea + # 2. Update Idea + print("Updating idea...") + updated_idea = await service.update_idea(idea.id, description="Updated description") + print(f"Idea updated: {updated_idea.description}") + if updated_idea.description == "Updated description": + print("✅ Idea update successful") + else: + print("❌ Idea update FAILED") + + # 3. Add Generation linked to Idea print("Creating generation linked to idea...") gen = Generation( prompt="idea generation 1", - idea_id=idea.id, + # idea_id=idea.id, <-- Intentionally NOT linking initially to test linking method project_id=project_id, created_by=user_id, aspect_ratio=AspectRatios.NINESIXTEEN, @@ -42,15 +51,23 @@ async def test_idea_flow(): assets_list=[] ) gen_id = await dao.generations.create_generation(gen) - print(f"Created linked generation: {gen_id}") + print(f"Created generation: {gen_id}") + + # Link generation to idea + print("Linking generation to idea...") + success = await service.add_generation_to_idea(idea.id, gen_id) + if success: + print("✅ Linking successful") + else: + print("❌ Linking FAILED") # Debug: Check if generation was saved with idea_id saved_gen = await dao.generations.collection.find_one({"_id": ObjectId(gen_id)}) print(f"DEBUG: Saved Generation in DB idea_id: {saved_gen.get('idea_id')}") - # 3. Fetch Generations for Idea (Verify filtering and ordering) + # 4. Fetch Generations for Idea (Verify filtering and ordering) print("Fetching generations for idea...") - gens = await dao.generations.get_generations(idea_id=idea.id) + gens = await service.dao.generations.get_generations(idea_id=idea.id) # using repo directly as service might return wrapper print(f"Found {len(gens)} generations in idea") if len(gens) == 1 and gens[0].id == gen_id: @@ -58,7 +75,7 @@ async def test_idea_flow(): else: print("❌ Generation retrieval FAILED") - # 4. Fetch Ideas for Project + # 5. Fetch Ideas for Project ideas = await service.get_ideas(project_id) print(f"Found {len(ideas)} ideas for project")