Compare commits
1 Commits
ideas
...
e2c050515d
| Author | SHA1 | Date | |
|---|---|---|---|
| e2c050515d |
4
aiws.py
4
aiws.py
@@ -43,7 +43,6 @@ from api.endpoints.auth import router as api_auth_router
|
||||
from api.endpoints.admin import router as api_admin_router
|
||||
from api.endpoints.album_router import router as api_album_router
|
||||
from api.endpoints.project_router import router as project_api_router
|
||||
from api.endpoints.idea_router import router as idea_api_router
|
||||
|
||||
load_dotenv()
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -130,7 +129,7 @@ async def start_scheduler(service: GenerationService):
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Scheduler error: {e}")
|
||||
await asyncio.sleep(60) # Check every 10 minutes
|
||||
await asyncio.sleep(600) # Check every 10 minutes
|
||||
|
||||
# --- LIFESPAN (Запуск FastAPI + Bot) ---
|
||||
@asynccontextmanager
|
||||
@@ -211,7 +210,6 @@ app.include_router(api_char_router)
|
||||
app.include_router(api_gen_router)
|
||||
app.include_router(api_album_router)
|
||||
app.include_router(project_api_router)
|
||||
app.include_router(idea_api_router)
|
||||
|
||||
# Prometheus Metrics (Instrument after all routers are added)
|
||||
Instrumentator(
|
||||
|
||||
Binary file not shown.
@@ -45,11 +45,6 @@ def get_generation_service(
|
||||
) -> GenerationService:
|
||||
return GenerationService(dao, gemini, s3_adapter, bot)
|
||||
|
||||
from api.service.idea_service import IdeaService
|
||||
|
||||
def get_idea_service(dao: DAO = Depends(get_dao)) -> IdeaService:
|
||||
return IdeaService(dao)
|
||||
|
||||
from fastapi import Header
|
||||
|
||||
async def get_project_id(x_project_id: Optional[str] = Header(None, alias="X-Project-ID")) -> Optional[str]:
|
||||
|
||||
@@ -1,103 +0,0 @@
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Body
|
||||
from api.dependency import get_idea_service, get_project_id, get_generation_service
|
||||
from api.endpoints.auth import get_current_user
|
||||
from api.service.idea_service import IdeaService
|
||||
from api.service.generation_service import GenerationService
|
||||
from models.Idea import Idea
|
||||
from api.models.GenerationRequest import GenerationResponse, GenerationsResponse
|
||||
from api.models.IdeaRequest import IdeaCreateRequest, IdeaUpdateRequest, IdeaResponse
|
||||
|
||||
router = APIRouter(prefix="/api/ideas", tags=["ideas"])
|
||||
|
||||
@router.post("", response_model=Idea)
|
||||
async def create_idea(
|
||||
request: IdeaCreateRequest,
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
idea_service: IdeaService = Depends(get_idea_service)
|
||||
):
|
||||
pid = project_id or request.project_id
|
||||
|
||||
return await idea_service.create_idea(request.name, request.description, pid, str(current_user["_id"]))
|
||||
|
||||
@router.get("", response_model=List[IdeaResponse])
|
||||
async def get_ideas(
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
current_user: dict = Depends(get_current_user),
|
||||
idea_service: IdeaService = Depends(get_idea_service)
|
||||
):
|
||||
return await idea_service.get_ideas(project_id, str(current_user["_id"]), limit, offset)
|
||||
|
||||
@router.get("/{idea_id}", response_model=Idea)
|
||||
async def get_idea(
|
||||
idea_id: str,
|
||||
idea_service: IdeaService = Depends(get_idea_service)
|
||||
):
|
||||
idea = await idea_service.get_idea(idea_id)
|
||||
if not idea:
|
||||
raise HTTPException(status_code=404, detail="Idea not found")
|
||||
return idea
|
||||
|
||||
@router.put("/{idea_id}", response_model=Idea)
|
||||
async def update_idea(
|
||||
idea_id: str,
|
||||
request: IdeaUpdateRequest,
|
||||
idea_service: IdeaService = Depends(get_idea_service)
|
||||
):
|
||||
idea = await idea_service.update_idea(idea_id, request.name, request.description)
|
||||
if not idea:
|
||||
raise HTTPException(status_code=404, detail="Idea not found")
|
||||
return idea
|
||||
|
||||
@router.delete("/{idea_id}")
|
||||
async def delete_idea(
|
||||
idea_id: str,
|
||||
idea_service: IdeaService = Depends(get_idea_service)
|
||||
):
|
||||
success = await idea_service.delete_idea(idea_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Idea not found or could not be deleted")
|
||||
return {"status": "success"}
|
||||
|
||||
@router.get("/{idea_id}/generations", response_model=GenerationsResponse)
|
||||
async def get_idea_generations(
|
||||
idea_id: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
generation_service: GenerationService = Depends(get_generation_service)
|
||||
):
|
||||
# Depending on how generation service implements filtering by idea_id.
|
||||
# We might need to update generation_service to support getting by idea_id directly
|
||||
# or ensure generic get_generations supports it.
|
||||
# Looking at generation_router.py, get_generations doesn't have idea_id arg?
|
||||
# Let's check generation_service.get_generations signature again.
|
||||
# It has: (character_id, limit, offset, user_id, project_id). NO IDEA_ID.
|
||||
# I need to update GenerationService.get_generations too!
|
||||
|
||||
# For now, let's assume I will update it.
|
||||
return await generation_service.get_generations(idea_id=idea_id, limit=limit, offset=offset)
|
||||
|
||||
@router.post("/{idea_id}/generations/{generation_id}")
|
||||
async def add_generation_to_idea(
|
||||
idea_id: str,
|
||||
generation_id: str,
|
||||
idea_service: IdeaService = Depends(get_idea_service)
|
||||
):
|
||||
success = await idea_service.add_generation_to_idea(idea_id, generation_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Idea or Generation not found")
|
||||
return {"status": "success"}
|
||||
|
||||
@router.delete("/{idea_id}/generations/{generation_id}")
|
||||
async def remove_generation_from_idea(
|
||||
idea_id: str,
|
||||
generation_id: str,
|
||||
idea_service: IdeaService = Depends(get_idea_service)
|
||||
):
|
||||
success = await idea_service.remove_generation_from_idea(idea_id, generation_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Idea or Generation not found")
|
||||
return {"status": "success"}
|
||||
@@ -17,7 +17,6 @@ class GenerationRequest(BaseModel):
|
||||
use_profile_image: bool = True
|
||||
assets_list: List[str]
|
||||
project_id: Optional[str] = None
|
||||
idea_id: Optional[str] = None
|
||||
count: int = Field(default=1, ge=1, le=10)
|
||||
|
||||
|
||||
@@ -48,7 +47,6 @@ class GenerationResponse(BaseModel):
|
||||
cost: Optional[float] = None
|
||||
created_by: Optional[str] = None
|
||||
generation_group_id: Optional[str] = None
|
||||
idea_id: Optional[str] = None
|
||||
created_at: datetime = datetime.now(UTC)
|
||||
updated_at: datetime = datetime.now(UTC)
|
||||
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
from models.Idea import Idea
|
||||
from api.models.GenerationRequest import GenerationResponse
|
||||
|
||||
class IdeaCreateRequest(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
project_id: Optional[str] = None # Optional in body if passed via header/dependency
|
||||
|
||||
class IdeaUpdateRequest(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
|
||||
class IdeaResponse(Idea):
|
||||
last_generation: Optional[GenerationResponse] = None
|
||||
Binary file not shown.
Binary file not shown.
@@ -22,6 +22,9 @@ from adapters.s3_adapter import S3Adapter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Limit concurrent generations to 4
|
||||
generation_semaphore = asyncio.Semaphore(4)
|
||||
|
||||
|
||||
# --- Вспомогательная функция генерации ---
|
||||
async def generate_image_task(
|
||||
@@ -97,9 +100,10 @@ class GenerationService:
|
||||
|
||||
return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, images_list=images)
|
||||
|
||||
async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0, user_id: Optional[str] = None, project_id: Optional[str] = None, idea_id: Optional[str] = None) -> GenerationsResponse:
|
||||
generations = await self.dao.generations.get_generations(character_id = character_id,limit=limit, offset=offset, created_by=user_id, project_id=project_id, idea_id=idea_id)
|
||||
total_count = await self.dao.generations.count_generations(character_id = character_id, created_by=user_id, project_id=project_id, idea_id=idea_id)
|
||||
async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[
|
||||
Generation]:
|
||||
generations = await self.dao.generations.get_generations(character_id = character_id,limit=limit, offset=offset, created_by=user_id, project_id=project_id)
|
||||
total_count = await self.dao.generations.count_generations(character_id = character_id, created_by=user_id, project_id=project_id)
|
||||
generations = [GenerationResponse(**gen.model_dump()) for gen in generations]
|
||||
return GenerationsResponse(generations=generations, total_count=total_count)
|
||||
|
||||
@@ -136,18 +140,16 @@ class GenerationService:
|
||||
if generation_group_id:
|
||||
generation_model.generation_group_id = generation_group_id
|
||||
|
||||
# Explicitly set idea_id from request if present (already in model_dump, but ensuring clarity)
|
||||
if generation_request.idea_id:
|
||||
generation_model.idea_id = generation_request.idea_id
|
||||
|
||||
gen_id = await self.dao.generations.create_generation(generation_model)
|
||||
generation_model.id = gen_id
|
||||
|
||||
async def runner(gen):
|
||||
logger.info(f"Starting background generation task for ID: {gen.id}")
|
||||
logger.info(f"Generation {gen.id} entered queue (waiting for slot)...")
|
||||
try:
|
||||
await self.create_generation(gen)
|
||||
logger.info(f"Background generation task finished for ID: {gen.id}")
|
||||
async with generation_semaphore:
|
||||
logger.info(f"Starting background generation task for ID: {gen.id}")
|
||||
await self.create_generation(gen)
|
||||
logger.info(f"Background generation task finished for ID: {gen.id}")
|
||||
except Exception:
|
||||
# если генерация уже пошла и упала — пометим FAILED
|
||||
try:
|
||||
|
||||
@@ -1,75 +0,0 @@
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from repos.dao import DAO
|
||||
from models.Idea import Idea
|
||||
|
||||
class IdeaService:
|
||||
def __init__(self, dao: DAO):
|
||||
self.dao = dao
|
||||
|
||||
async def create_idea(self, name: str, description: Optional[str], project_id: Optional[str], user_id: str) -> Idea:
|
||||
idea = Idea(name=name, description=description, project_id=project_id, created_by=user_id)
|
||||
idea_id = await self.dao.ideas.create_idea(idea)
|
||||
idea.id = idea_id
|
||||
return idea
|
||||
|
||||
async def get_ideas(self, project_id: Optional[str], user_id: str, limit: int = 20, offset: int = 0) -> List[dict]:
|
||||
return await self.dao.ideas.get_ideas(project_id, user_id, limit, offset)
|
||||
|
||||
async def get_idea(self, idea_id: str) -> Optional[Idea]:
|
||||
return await self.dao.ideas.get_idea(idea_id)
|
||||
|
||||
async def update_idea(self, idea_id: str, name: Optional[str] = None, description: Optional[str] = None) -> Optional[Idea]:
|
||||
idea = await self.dao.ideas.get_idea(idea_id)
|
||||
if not idea:
|
||||
return None
|
||||
|
||||
if name is not None:
|
||||
idea.name = name
|
||||
if description is not None:
|
||||
idea.description = description
|
||||
|
||||
idea.updated_at = datetime.now()
|
||||
await self.dao.ideas.update_idea(idea)
|
||||
return idea
|
||||
|
||||
async def delete_idea(self, idea_id: str) -> bool:
|
||||
return await self.dao.ideas.delete_idea(idea_id)
|
||||
|
||||
async def add_generation_to_idea(self, idea_id: str, generation_id: str) -> bool:
|
||||
# Verify idea exists
|
||||
idea = await self.dao.ideas.get_idea(idea_id)
|
||||
if not idea:
|
||||
return False
|
||||
|
||||
# Get generation
|
||||
gen = await self.dao.generations.get_generation(generation_id)
|
||||
if not gen:
|
||||
return False
|
||||
|
||||
# Link
|
||||
gen.idea_id = idea_id
|
||||
gen.updated_at = datetime.now()
|
||||
await self.dao.generations.update_generation(gen)
|
||||
return True
|
||||
|
||||
async def remove_generation_from_idea(self, idea_id: str, generation_id: str) -> bool:
|
||||
# Verify idea exists (optional, but good for validation)
|
||||
idea = await self.dao.ideas.get_idea(idea_id)
|
||||
if not idea:
|
||||
return False
|
||||
|
||||
# Get generation
|
||||
gen = await self.dao.generations.get_generation(generation_id)
|
||||
if not gen:
|
||||
return False
|
||||
|
||||
# Unlink only if currently linked to this idea
|
||||
if gen.idea_id == idea_id:
|
||||
gen.idea_id = None
|
||||
gen.updated_at = datetime.now()
|
||||
await self.dao.generations.update_generation(gen)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@@ -38,7 +38,6 @@ class Generation(BaseModel):
|
||||
generation_group_id: Optional[str] = None
|
||||
created_by: Optional[str] = None # Stores User ID (Telegram ID or Web User ObjectId)
|
||||
project_id: Optional[str] = None
|
||||
idea_id: Optional[str] = None
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class Idea(BaseModel):
|
||||
id: Optional[str] = None
|
||||
name: str = "New Idea"
|
||||
description: Optional[str] = None
|
||||
project_id: Optional[str] = None
|
||||
created_by: str # User ID
|
||||
is_deleted: bool = False
|
||||
created_at: datetime = Field(default_factory=datetime.now)
|
||||
updated_at: datetime = Field(default_factory=datetime.now)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,7 +1,6 @@
|
||||
from typing import List, Optional
|
||||
import logging
|
||||
from bson import ObjectId
|
||||
from uuid import uuid4
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
|
||||
from models.Asset import Asset
|
||||
@@ -20,8 +19,7 @@ class AssetsRepo:
|
||||
# Main data
|
||||
if asset.data:
|
||||
ts = int(asset.created_at.timestamp())
|
||||
uid = uuid4().hex[:8]
|
||||
object_name = f"{asset.type.value}/{ts}_{uid}_{asset.name}"
|
||||
object_name = f"{asset.type.value}/{ts}_{asset.name}"
|
||||
|
||||
uploaded = await self.s3.upload_file(object_name, asset.data)
|
||||
if uploaded:
|
||||
@@ -34,8 +32,7 @@ class AssetsRepo:
|
||||
# Thumbnail
|
||||
if asset.thumbnail:
|
||||
ts = int(asset.created_at.timestamp())
|
||||
uid = uuid4().hex[:8]
|
||||
thumb_name = f"{asset.type.value}/thumbs/{ts}_{uid}_{asset.name}_thumb.jpg"
|
||||
thumb_name = f"{asset.type.value}/thumbs/{ts}_{asset.name}_thumb.jpg"
|
||||
|
||||
uploaded_thumb = await self.s3.upload_file(thumb_name, asset.thumbnail)
|
||||
if uploaded_thumb:
|
||||
@@ -137,8 +134,7 @@ class AssetsRepo:
|
||||
if self.s3:
|
||||
if asset.data:
|
||||
ts = int(asset.created_at.timestamp())
|
||||
uid = uuid4().hex[:8]
|
||||
object_name = f"{asset.type.value}/{ts}_{uid}_{asset.name}"
|
||||
object_name = f"{asset.type.value}/{ts}_{asset.name}"
|
||||
if await self.s3.upload_file(object_name, asset.data):
|
||||
asset.minio_object_name = object_name
|
||||
asset.minio_bucket = self.s3.bucket_name
|
||||
@@ -146,8 +142,7 @@ class AssetsRepo:
|
||||
|
||||
if asset.thumbnail:
|
||||
ts = int(asset.created_at.timestamp())
|
||||
uid = uuid4().hex[:8]
|
||||
thumb_name = f"{asset.type.value}/thumbs/{ts}_{uid}_{asset.name}_thumb.jpg"
|
||||
thumb_name = f"{asset.type.value}/thumbs/{ts}_{asset.name}_thumb.jpg"
|
||||
if await self.s3.upload_file(thumb_name, asset.thumbnail):
|
||||
asset.minio_thumbnail_object_name = thumb_name
|
||||
asset.thumbnail = None
|
||||
@@ -221,8 +216,7 @@ class AssetsRepo:
|
||||
created_at = doc.get("created_at")
|
||||
ts = int(created_at.timestamp()) if created_at else 0
|
||||
|
||||
uid = uuid4().hex[:8]
|
||||
object_name = f"{type_}/{ts}_{uid}_{asset_id}_{name}"
|
||||
object_name = f"{type_}/{ts}_{asset_id}_{name}"
|
||||
if await self.s3.upload_file(object_name, data):
|
||||
await self.collection.update_one(
|
||||
{"_id": asset_id},
|
||||
@@ -249,8 +243,7 @@ class AssetsRepo:
|
||||
created_at = doc.get("created_at")
|
||||
ts = int(created_at.timestamp()) if created_at else 0
|
||||
|
||||
uid = uuid4().hex[:8]
|
||||
thumb_name = f"{type_}/thumbs/{ts}_{uid}_{asset_id}_{name}_thumb.jpg"
|
||||
thumb_name = f"{type_}/thumbs/{ts}_{asset_id}_{name}_thumb.jpg"
|
||||
if await self.s3.upload_file(thumb_name, thumb):
|
||||
await self.collection.update_one(
|
||||
{"_id": asset_id},
|
||||
|
||||
@@ -6,7 +6,6 @@ from repos.generation_repo import GenerationRepo
|
||||
from repos.user_repo import UsersRepo
|
||||
from repos.albums_repo import AlbumsRepo
|
||||
from repos.project_repo import ProjectRepo
|
||||
from repos.idea_repo import IdeaRepo
|
||||
|
||||
|
||||
from typing import Optional
|
||||
@@ -20,4 +19,3 @@ class DAO:
|
||||
self.albums = AlbumsRepo(client, db_name)
|
||||
self.projects = ProjectRepo(client, db_name)
|
||||
self.users = UsersRepo(client, db_name)
|
||||
self.ideas = IdeaRepo(client, db_name)
|
||||
|
||||
@@ -26,7 +26,7 @@ class GenerationRepo:
|
||||
return Generation(**res)
|
||||
|
||||
async def get_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
|
||||
limit: int = 10, offset: int = 0, created_by: Optional[str] = None, project_id: Optional[str] = None, idea_id: Optional[str] = None) -> List[Generation]:
|
||||
limit: int = 10, offset: int = 10, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
|
||||
|
||||
filter = {"is_deleted": False}
|
||||
if character_id is not None:
|
||||
@@ -35,20 +35,11 @@ class GenerationRepo:
|
||||
filter["status"] = status
|
||||
if created_by is not None:
|
||||
filter["created_by"] = created_by
|
||||
# If filtering by created_by user (e.g. "My Generations"), we typically imply personal scope if project_id is None.
|
||||
# But if project_id is passed, we filter by that.
|
||||
if project_id is None:
|
||||
filter["project_id"] = None
|
||||
filter["project_id"] = None
|
||||
if project_id is not None:
|
||||
filter["project_id"] = project_id
|
||||
if idea_id is not None:
|
||||
filter["idea_id"] = idea_id
|
||||
|
||||
# If fetching for an idea, sort by created_at ascending (cronological)
|
||||
# Otherwise typically descending (newest first)
|
||||
sort_order = 1 if idea_id else -1
|
||||
|
||||
res = await self.collection.find(filter).sort("created_at", sort_order).skip(
|
||||
res = await self.collection.find(filter).sort("created_at", -1).skip(
|
||||
offset).limit(limit).to_list(None)
|
||||
generations: List[Generation] = []
|
||||
for generation in res:
|
||||
@@ -57,7 +48,7 @@ class GenerationRepo:
|
||||
return generations
|
||||
|
||||
async def count_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
|
||||
album_id: Optional[str] = None, created_by: Optional[str] = None, project_id: Optional[str] = None, idea_id: Optional[str] = None) -> int:
|
||||
album_id: Optional[str] = None, created_by: Optional[str] = None, project_id: Optional[str] = None) -> int:
|
||||
args = {}
|
||||
if character_id is not None:
|
||||
args["linked_character_id"] = character_id
|
||||
@@ -67,8 +58,6 @@ class GenerationRepo:
|
||||
args["created_by"] = created_by
|
||||
if project_id is not None:
|
||||
args["project_id"] = project_id
|
||||
if idea_id is not None:
|
||||
args["idea_id"] = idea_id
|
||||
return await self.collection.count_documents(args)
|
||||
|
||||
async def get_generations_by_ids(self, generation_ids: List[str]) -> List[Generation]:
|
||||
@@ -98,7 +87,7 @@ class GenerationRepo:
|
||||
generations.append(Generation(**generation))
|
||||
return generations
|
||||
|
||||
async def cancel_stale_generations(self, timeout_minutes: int = 5) -> int:
|
||||
async def cancel_stale_generations(self, timeout_minutes: int = 60) -> int:
|
||||
cutoff_time = datetime.now(UTC) - timedelta(minutes=timeout_minutes)
|
||||
res = await self.collection.update_many(
|
||||
{
|
||||
|
||||
@@ -1,82 +0,0 @@
|
||||
from typing import Optional, List
|
||||
from bson import ObjectId
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from models.Idea import Idea
|
||||
|
||||
class IdeaRepo:
|
||||
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
|
||||
self.collection = client[db_name]["ideas"]
|
||||
|
||||
async def create_idea(self, idea: Idea) -> str:
|
||||
res = await self.collection.insert_one(idea.model_dump())
|
||||
return str(res.inserted_id)
|
||||
|
||||
async def get_idea(self, idea_id: str) -> Optional[Idea]:
|
||||
if not ObjectId.is_valid(idea_id):
|
||||
return None
|
||||
res = await self.collection.find_one({"_id": ObjectId(idea_id)})
|
||||
if res:
|
||||
res["id"] = str(res.pop("_id"))
|
||||
return Idea(**res)
|
||||
return None
|
||||
|
||||
async def get_ideas(self, project_id: Optional[str], user_id: str, limit: int = 20, offset: int = 0) -> List[dict]:
|
||||
if project_id:
|
||||
match_stage = {"project_id": project_id, "is_deleted": False}
|
||||
else:
|
||||
match_stage = {"created_by": user_id, "project_id": None, "is_deleted": False}
|
||||
|
||||
pipeline = [
|
||||
{"$match": match_stage},
|
||||
{"$sort": {"updated_at": -1}},
|
||||
{"$skip": offset},
|
||||
{"$limit": limit},
|
||||
# Add string id field for lookup
|
||||
{"$addFields": {"str_id": {"$toString": "$_id"}}},
|
||||
# Lookup generations
|
||||
{
|
||||
"$lookup": {
|
||||
"from": "generations",
|
||||
"let": {"idea_id": "$str_id"},
|
||||
"pipeline": [
|
||||
{"$match": {"$expr": {"$eq": ["$idea_id", "$$idea_id"]}}},
|
||||
{"$sort": {"created_at": -1}}, # Ensure we get the latest
|
||||
{"$limit": 1}
|
||||
],
|
||||
"as": "generations"
|
||||
}
|
||||
},
|
||||
# Unwind generations array (preserve ideas without generations)
|
||||
{"$unwind": {"path": "$generations", "preserveNullAndEmptyArrays": True}},
|
||||
# Rename for clarity
|
||||
{"$addFields": {
|
||||
"last_generation": "$generations",
|
||||
"id": "$str_id"
|
||||
}},
|
||||
{"$project": {"generations": 0, "str_id": 0, "_id": 0}}
|
||||
]
|
||||
|
||||
return await self.collection.aggregate(pipeline).to_list(None)
|
||||
|
||||
async def delete_idea(self, idea_id: str) -> bool:
|
||||
if not ObjectId.is_valid(idea_id):
|
||||
return False
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(idea_id)},
|
||||
{"$set": {"is_deleted": True}}
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def update_idea(self, idea: Idea) -> bool:
|
||||
if not idea.id or not ObjectId.is_valid(idea.id):
|
||||
return False
|
||||
|
||||
idea_dict = idea.model_dump()
|
||||
if "id" in idea_dict:
|
||||
del idea_dict["id"]
|
||||
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(idea.id)},
|
||||
{"$set": idea_dict}
|
||||
)
|
||||
return res.modified_count > 0
|
||||
@@ -1,97 +0,0 @@
|
||||
import asyncio
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from bson import ObjectId
|
||||
|
||||
# Import from project root (requires PYTHONPATH=.)
|
||||
from api.service.idea_service import IdeaService
|
||||
from repos.dao import DAO
|
||||
from models.Idea import Idea
|
||||
from models.Generation import Generation, GenerationStatus
|
||||
from models.enums import AspectRatios, Quality
|
||||
|
||||
load_dotenv()
|
||||
|
||||
MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://localhost:27017")
|
||||
DB_NAME = os.getenv("DB_NAME", "bot_db")
|
||||
|
||||
print(f"Connecting to MongoDB: {MONGO_HOST}, DB: {DB_NAME}")
|
||||
|
||||
async def test_idea_flow():
|
||||
client = AsyncIOMotorClient(MONGO_HOST)
|
||||
dao = DAO(client, db_name=DB_NAME)
|
||||
service = IdeaService(dao)
|
||||
|
||||
# 1. Create an Idea
|
||||
print("Creating idea...")
|
||||
user_id = "test_user_123"
|
||||
project_id = "test_project_abc"
|
||||
idea = await service.create_idea("My Test Idea", "Initial Description", project_id, user_id)
|
||||
print(f"Idea created: {idea.id} - {idea.name}")
|
||||
|
||||
# 2. Update Idea
|
||||
print("Updating idea...")
|
||||
updated_idea = await service.update_idea(idea.id, description="Updated description")
|
||||
print(f"Idea updated: {updated_idea.description}")
|
||||
if updated_idea.description == "Updated description":
|
||||
print("✅ Idea update successful")
|
||||
else:
|
||||
print("❌ Idea update FAILED")
|
||||
|
||||
# 3. Add Generation linked to Idea
|
||||
print("Creating generation linked to idea...")
|
||||
gen = Generation(
|
||||
prompt="idea generation 1",
|
||||
# idea_id=idea.id, <-- Intentionally NOT linking initially to test linking method
|
||||
project_id=project_id,
|
||||
created_by=user_id,
|
||||
aspect_ratio=AspectRatios.NINESIXTEEN,
|
||||
quality=Quality.ONEK,
|
||||
assets_list=[]
|
||||
)
|
||||
gen_id = await dao.generations.create_generation(gen)
|
||||
print(f"Created generation: {gen_id}")
|
||||
|
||||
# Link generation to idea
|
||||
print("Linking generation to idea...")
|
||||
success = await service.add_generation_to_idea(idea.id, gen_id)
|
||||
if success:
|
||||
print("✅ Linking successful")
|
||||
else:
|
||||
print("❌ Linking FAILED")
|
||||
|
||||
# Debug: Check if generation was saved with idea_id
|
||||
saved_gen = await dao.generations.collection.find_one({"_id": ObjectId(gen_id)})
|
||||
print(f"DEBUG: Saved Generation in DB idea_id: {saved_gen.get('idea_id')}")
|
||||
|
||||
# 4. Fetch Generations for Idea (Verify filtering and ordering)
|
||||
print("Fetching generations for idea...")
|
||||
gens = await service.dao.generations.get_generations(idea_id=idea.id) # using repo directly as service might return wrapper
|
||||
print(f"Found {len(gens)} generations in idea")
|
||||
|
||||
if len(gens) == 1 and gens[0].id == gen_id:
|
||||
print("✅ Generation retrieval successful")
|
||||
else:
|
||||
print("❌ Generation retrieval FAILED")
|
||||
|
||||
# 5. Fetch Ideas for Project
|
||||
ideas = await service.get_ideas(project_id)
|
||||
print(f"Found {len(ideas)} ideas for project")
|
||||
|
||||
# Cleaning up
|
||||
print("Cleaning up...")
|
||||
await service.delete_idea(idea.id)
|
||||
await dao.generations.collection.delete_one({"_id": ObjectId(gen_id)})
|
||||
|
||||
# Verify deletion
|
||||
deleted_idea = await service.get_idea(idea.id)
|
||||
# IdeaRepo.delete_idea logic sets is_deleted=True
|
||||
if deleted_idea and deleted_idea.is_deleted:
|
||||
print(f"✅ Idea deleted successfully")
|
||||
|
||||
# Hard delete for cleanup
|
||||
await dao.ideas.collection.delete_one({"_id": ObjectId(idea.id)})
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_idea_flow())
|
||||
Reference in New Issue
Block a user