7 Commits

36 changed files with 923 additions and 41 deletions

1
.gitignore vendored
View File

@@ -23,3 +23,4 @@ services/*.pyc
utils/__pycache__/ utils/__pycache__/
utils/*.pyc utils/*.pyc
.vscode/launch.json .vscode/launch.json
repos/__pycache__/assets_repo.cpython-313.pyc

View File

@@ -23,10 +23,10 @@ class GoogleAdapter:
self.TEXT_MODEL = "gemini-3-pro-preview" self.TEXT_MODEL = "gemini-3-pro-preview"
self.IMAGE_MODEL = "gemini-3-pro-image-preview" self.IMAGE_MODEL = "gemini-3-pro-image-preview"
def _prepare_contents(self, prompt: str, images_list: List[bytes] = None) -> tuple: def _prepare_contents(self, prompt: str, images_list: List[bytes] | None = None) -> tuple:
"""Вспомогательный метод для подготовки контента (текст + картинки). """Вспомогательный метод для подготовки контента (текст + картинки).
Returns (contents, opened_images) — caller MUST close opened_images after use.""" Returns (contents, opened_images) — caller MUST close opened_images after use."""
contents = [prompt] contents : list [Any]= [prompt]
opened_images = [] opened_images = []
if images_list: if images_list:
logger.info(f"Preparing content with {len(images_list)} images") logger.info(f"Preparing content with {len(images_list)} images")
@@ -41,7 +41,7 @@ class GoogleAdapter:
logger.info("Preparing content with no images") logger.info("Preparing content with no images")
return contents, opened_images return contents, opened_images
def generate_text(self, prompt: str, images_list: List[bytes] = None) -> str: def generate_text(self, prompt: str, images_list: List[bytes] | None = None) -> str:
""" """
Генерация текста (Чат или Vision). Генерация текста (Чат или Vision).
Возвращает строку с ответом. Возвращает строку с ответом.
@@ -74,7 +74,7 @@ class GoogleAdapter:
for img in opened_images: for img in opened_images:
img.close() img.close()
def generate_image(self, prompt: str, aspect_ratio: AspectRatios, quality: Quality, images_list: List[bytes] = None, ) -> Tuple[List[io.BytesIO], Dict[str, Any]]: def generate_image(self, prompt: str, aspect_ratio: AspectRatios, quality: Quality, images_list: List[bytes] | None = None, ) -> Tuple[List[io.BytesIO], Dict[str, Any]]:
""" """
Генерация изображений (Text-to-Image или Image-to-Image). Генерация изображений (Text-to-Image или Image-to-Image).
Возвращает список байтовых потоков (готовых к отправке). Возвращает список байтовых потоков (готовых к отправке).
@@ -130,7 +130,9 @@ class GoogleAdapter:
try: try:
# 1. Берем сырые байты # 1. Берем сырые байты
raw_data = part.inline_data.data raw_data = part.inline_data.data
byte_arr = io.BytesIO(raw_data) if raw_data is None:
raise GoogleGenerationException("Generation returned no data")
byte_arr : io.BytesIO = io.BytesIO(raw_data)
# 2. Нейминг (формально, для TG) # 2. Нейминг (формально, для TG)
timestamp = datetime.now().timestamp() timestamp = datetime.now().timestamp()

View File

@@ -18,7 +18,7 @@ class S3Adapter:
@asynccontextmanager @asynccontextmanager
async def _get_client(self): async def _get_client(self):
async with self.session.client( async with self.session.client( # type: ignore[reportGeneralTypeIssues]
"s3", "s3",
endpoint_url=self.endpoint_url, endpoint_url=self.endpoint_url,
aws_access_key_id=self.aws_access_key_id, aws_access_key_id=self.aws_access_key_id,

15
aiws.py
View File

@@ -43,6 +43,8 @@ from api.endpoints.auth import router as api_auth_router
from api.endpoints.admin import router as api_admin_router from api.endpoints.admin import router as api_admin_router
from api.endpoints.album_router import router as api_album_router from api.endpoints.album_router import router as api_album_router
from api.endpoints.project_router import router as project_api_router from api.endpoints.project_router import router as project_api_router
from api.endpoints.idea_router import router as idea_api_router
from api.endpoints.post_router import router as post_api_router
load_dotenv() load_dotenv()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -63,6 +65,8 @@ def setup_logging():
# --- ИНИЦИАЛИЗАЦИЯ ЗАВИСИМОСТЕЙ --- # --- ИНИЦИАЛИЗАЦИЯ ЗАВИСИМОСТЕЙ ---
if BOT_TOKEN is None:
raise ValueError("BOT_TOKEN is not set")
bot = Bot(token=BOT_TOKEN, default=DefaultBotProperties(parse_mode=ParseMode.HTML)) bot = Bot(token=BOT_TOKEN, default=DefaultBotProperties(parse_mode=ParseMode.HTML))
# Клиент БД создаем глобально, чтобы он был доступен и боту (Storage), и API # Клиент БД создаем глобально, чтобы он был доступен и боту (Storage), и API
@@ -82,8 +86,12 @@ s3_adapter = S3Adapter(
) )
dao = DAO(mongo_client, s3_adapter) # Главный DAO для бота dao = DAO(mongo_client, s3_adapter) # Главный DAO для бота
if GEMINI_API_KEY is None:
raise ValueError("GEMINI_API_KEY is not set")
gemini = GoogleAdapter(api_key=GEMINI_API_KEY) gemini = GoogleAdapter(api_key=GEMINI_API_KEY)
generation_service = GenerationService(dao, gemini, bot) if bot is None:
raise ValueError("bot is not set")
generation_service = GenerationService(dao=dao, gemini=gemini, s3_adapter=s3_adapter, bot=bot)
album_service = AlbumService(dao) album_service = AlbumService(dao)
# Dispatcher # Dispatcher
@@ -125,11 +133,12 @@ async def start_scheduler(service: GenerationService):
try: try:
logger.info("Running scheduler for stacked generation killing") logger.info("Running scheduler for stacked generation killing")
await service.cleanup_stale_generations() await service.cleanup_stale_generations()
await service.cleanup_old_data(days=2)
except asyncio.CancelledError: except asyncio.CancelledError:
break break
except Exception as e: except Exception as e:
logger.error(f"Scheduler error: {e}") logger.error(f"Scheduler error: {e}")
await asyncio.sleep(600) # Check every 10 minutes await asyncio.sleep(60) # Check every 60 seconds
# --- LIFESPAN (Запуск FastAPI + Bot) --- # --- LIFESPAN (Запуск FastAPI + Bot) ---
@asynccontextmanager @asynccontextmanager
@@ -210,6 +219,8 @@ app.include_router(api_char_router)
app.include_router(api_gen_router) app.include_router(api_gen_router)
app.include_router(api_album_router) app.include_router(api_album_router)
app.include_router(project_api_router) app.include_router(project_api_router)
app.include_router(idea_api_router)
app.include_router(post_api_router)
# Prometheus Metrics (Instrument after all routers are added) # Prometheus Metrics (Instrument after all routers are added)
Instrumentator( Instrumentator(

View File

@@ -5,6 +5,7 @@ from motor.motor_asyncio import AsyncIOMotorClient
from adapters.google_adapter import GoogleAdapter from adapters.google_adapter import GoogleAdapter
from api.service.generation_service import GenerationService from api.service.generation_service import GenerationService
from repos.dao import DAO from repos.dao import DAO
from api.service.album_service import AlbumService
# ... ваши импорты ... # ... ваши импорты ...
@@ -45,7 +46,20 @@ def get_generation_service(
) -> GenerationService: ) -> GenerationService:
return GenerationService(dao, gemini, s3_adapter, bot) return GenerationService(dao, gemini, s3_adapter, bot)
from api.service.idea_service import IdeaService
def get_idea_service(dao: DAO = Depends(get_dao)) -> IdeaService:
return IdeaService(dao)
from fastapi import Header from fastapi import Header
async def get_project_id(x_project_id: Optional[str] = Header(None, alias="X-Project-ID")) -> Optional[str]: async def get_project_id(x_project_id: Optional[str] = Header(None, alias="X-Project-ID")) -> Optional[str]:
return x_project_id return x_project_id
async def get_album_service(dao: DAO = Depends(get_dao)) -> AlbumService:
return AlbumService(dao)
from api.service.post_service import PostService
def get_post_service(dao: DAO = Depends(get_dao)) -> PostService:
return PostService(dao)

View File

@@ -23,7 +23,7 @@ async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)], repo:
) )
try: try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub") username: str | None = payload.get("sub")
if username is None: if username is None:
raise credentials_exception raise credentials_exception
except JWTError: except JWTError:

View File

@@ -1,10 +1,13 @@
from typing import List, Optional from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, status, Request from fastapi import APIRouter, HTTPException, status, Request
from pydantic import BaseModel from pydantic import BaseModel
from api.models.GenerationRequest import GenerationResponse from api.models.GenerationRequest import GenerationResponse
from models.Album import Album from models.Album import Album
from repos.dao import DAO from repos.dao import DAO
from api.dependency import get_album_service
from api.service.album_service import AlbumService
router = APIRouter(prefix="/api/albums", tags=["Albums"]) router = APIRouter(prefix="/api/albums", tags=["Albums"])

View File

@@ -0,0 +1,103 @@
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Body
from api.dependency import get_idea_service, get_project_id, get_generation_service
from api.endpoints.auth import get_current_user
from api.service.idea_service import IdeaService
from api.service.generation_service import GenerationService
from models.Idea import Idea
from api.models.GenerationRequest import GenerationResponse, GenerationsResponse
from api.models.IdeaRequest import IdeaCreateRequest, IdeaUpdateRequest, IdeaResponse
router = APIRouter(prefix="/api/ideas", tags=["ideas"])
@router.post("", response_model=Idea)
async def create_idea(
request: IdeaCreateRequest,
project_id: Optional[str] = Depends(get_project_id),
current_user: dict = Depends(get_current_user),
idea_service: IdeaService = Depends(get_idea_service)
):
pid = project_id or request.project_id
return await idea_service.create_idea(request.name, request.description, pid, str(current_user["_id"]))
@router.get("", response_model=List[IdeaResponse])
async def get_ideas(
project_id: Optional[str] = Depends(get_project_id),
limit: int = 20,
offset: int = 0,
current_user: dict = Depends(get_current_user),
idea_service: IdeaService = Depends(get_idea_service)
):
return await idea_service.get_ideas(project_id, str(current_user["_id"]), limit, offset)
@router.get("/{idea_id}", response_model=Idea)
async def get_idea(
idea_id: str,
idea_service: IdeaService = Depends(get_idea_service)
):
idea = await idea_service.get_idea(idea_id)
if not idea:
raise HTTPException(status_code=404, detail="Idea not found")
return idea
@router.put("/{idea_id}", response_model=Idea)
async def update_idea(
idea_id: str,
request: IdeaUpdateRequest,
idea_service: IdeaService = Depends(get_idea_service)
):
idea = await idea_service.update_idea(idea_id, request.name, request.description)
if not idea:
raise HTTPException(status_code=404, detail="Idea not found")
return idea
@router.delete("/{idea_id}")
async def delete_idea(
idea_id: str,
idea_service: IdeaService = Depends(get_idea_service)
):
success = await idea_service.delete_idea(idea_id)
if not success:
raise HTTPException(status_code=404, detail="Idea not found or could not be deleted")
return {"status": "success"}
@router.get("/{idea_id}/generations", response_model=GenerationsResponse)
async def get_idea_generations(
idea_id: str,
limit: int = 50,
offset: int = 0,
generation_service: GenerationService = Depends(get_generation_service)
):
# Depending on how generation service implements filtering by idea_id.
# We might need to update generation_service to support getting by idea_id directly
# or ensure generic get_generations supports it.
# Looking at generation_router.py, get_generations doesn't have idea_id arg?
# Let's check generation_service.get_generations signature again.
# It has: (character_id, limit, offset, user_id, project_id). NO IDEA_ID.
# I need to update GenerationService.get_generations too!
# For now, let's assume I will update it.
return await generation_service.get_generations(idea_id=idea_id, limit=limit, offset=offset)
@router.post("/{idea_id}/generations/{generation_id}")
async def add_generation_to_idea(
idea_id: str,
generation_id: str,
idea_service: IdeaService = Depends(get_idea_service)
):
success = await idea_service.add_generation_to_idea(idea_id, generation_id)
if not success:
raise HTTPException(status_code=404, detail="Idea or Generation not found")
return {"status": "success"}
@router.delete("/{idea_id}/generations/{generation_id}")
async def remove_generation_from_idea(
idea_id: str,
generation_id: str,
idea_service: IdeaService = Depends(get_idea_service)
):
success = await idea_service.remove_generation_from_idea(idea_id, generation_id)
if not success:
raise HTTPException(status_code=404, detail="Idea or Generation not found")
return {"status": "success"}

View File

@@ -0,0 +1,99 @@
from typing import List, Optional
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException
from api.dependency import get_post_service, get_project_id
from api.endpoints.auth import get_current_user
from api.service.post_service import PostService
from api.models.PostRequest import PostCreateRequest, PostUpdateRequest, AddGenerationsRequest
from models.Post import Post
router = APIRouter(prefix="/api/posts", tags=["posts"])
@router.post("", response_model=Post)
async def create_post(
request: PostCreateRequest,
project_id: Optional[str] = Depends(get_project_id),
current_user: dict = Depends(get_current_user),
post_service: PostService = Depends(get_post_service),
):
pid = project_id or request.project_id
return await post_service.create_post(
date=request.date,
topic=request.topic,
generation_ids=request.generation_ids,
project_id=pid,
user_id=str(current_user["_id"]),
)
@router.get("", response_model=List[Post])
async def get_posts(
project_id: Optional[str] = Depends(get_project_id),
limit: int = 200,
offset: int = 0,
date_from: Optional[datetime] = None,
date_to: Optional[datetime] = None,
current_user: dict = Depends(get_current_user),
post_service: PostService = Depends(get_post_service),
):
return await post_service.get_posts(project_id, str(current_user["_id"]), limit, offset, date_from, date_to)
@router.get("/{post_id}", response_model=Post)
async def get_post(
post_id: str,
post_service: PostService = Depends(get_post_service),
):
post = await post_service.get_post(post_id)
if not post:
raise HTTPException(status_code=404, detail="Post not found")
return post
@router.put("/{post_id}", response_model=Post)
async def update_post(
post_id: str,
request: PostUpdateRequest,
post_service: PostService = Depends(get_post_service),
):
post = await post_service.update_post(post_id, date=request.date, topic=request.topic)
if not post:
raise HTTPException(status_code=404, detail="Post not found")
return post
@router.delete("/{post_id}")
async def delete_post(
post_id: str,
post_service: PostService = Depends(get_post_service),
):
success = await post_service.delete_post(post_id)
if not success:
raise HTTPException(status_code=404, detail="Post not found or could not be deleted")
return {"status": "success"}
@router.post("/{post_id}/generations")
async def add_generations(
post_id: str,
request: AddGenerationsRequest,
post_service: PostService = Depends(get_post_service),
):
success = await post_service.add_generations(post_id, request.generation_ids)
if not success:
raise HTTPException(status_code=404, detail="Post not found")
return {"status": "success"}
@router.delete("/{post_id}/generations/{generation_id}")
async def remove_generation(
post_id: str,
generation_id: str,
post_service: PostService = Depends(get_post_service),
):
success = await post_service.remove_generation(post_id, generation_id)
if not success:
raise HTTPException(status_code=404, detail="Post not found or generation not linked")
return {"status": "success"}

View File

@@ -17,6 +17,7 @@ class GenerationRequest(BaseModel):
use_profile_image: bool = True use_profile_image: bool = True
assets_list: List[str] assets_list: List[str]
project_id: Optional[str] = None project_id: Optional[str] = None
idea_id: Optional[str] = None
count: int = Field(default=1, ge=1, le=10) count: int = Field(default=1, ge=1, le=10)
@@ -47,6 +48,7 @@ class GenerationResponse(BaseModel):
cost: Optional[float] = None cost: Optional[float] = None
created_by: Optional[str] = None created_by: Optional[str] = None
generation_group_id: Optional[str] = None generation_group_id: Optional[str] = None
idea_id: Optional[str] = None
created_at: datetime = datetime.now(UTC) created_at: datetime = datetime.now(UTC)
updated_at: datetime = datetime.now(UTC) updated_at: datetime = datetime.now(UTC)

16
api/models/IdeaRequest.py Normal file
View File

@@ -0,0 +1,16 @@
from typing import Optional
from pydantic import BaseModel
from models.Idea import Idea
from api.models.GenerationRequest import GenerationResponse
class IdeaCreateRequest(BaseModel):
name: str
description: Optional[str] = None
project_id: Optional[str] = None # Optional in body if passed via header/dependency
class IdeaUpdateRequest(BaseModel):
name: Optional[str] = None
description: Optional[str] = None
class IdeaResponse(Idea):
last_generation: Optional[GenerationResponse] = None

19
api/models/PostRequest.py Normal file
View File

@@ -0,0 +1,19 @@
from datetime import datetime
from typing import Optional, List
from pydantic import BaseModel
class PostCreateRequest(BaseModel):
date: datetime
topic: str
generation_ids: List[str] = []
project_id: Optional[str] = None
class PostUpdateRequest(BaseModel):
date: Optional[datetime] = None
topic: Optional[str] = None
class AddGenerationsRequest(BaseModel):
generation_ids: List[str]

View File

@@ -77,7 +77,7 @@ class GenerationService:
self.bot = bot self.bot = bot
async def ask_prompt_assistant(self, prompt: str, assets: List[str] = None) -> str: async def ask_prompt_assistant(self, prompt: str, assets: list[str] | None = None) -> str:
future_prompt = """You are an prompt-assistant. You improving user-entered prompts for image generation. User may upload reference image too. future_prompt = """You are an prompt-assistant. You improving user-entered prompts for image generation. User may upload reference image too.
I will provide sources prompt entered by user. Understand user needs and generate best variation of prompt. I will provide sources prompt entered by user. Understand user needs and generate best variation of prompt.
ANSWER ONLY PROMPT STRING!!! USER_ENTERED_PROMPT: """ ANSWER ONLY PROMPT STRING!!! USER_ENTERED_PROMPT: """
@@ -100,10 +100,9 @@ class GenerationService:
return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, images_list=images) return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, images_list=images)
async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[ async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0, user_id: Optional[str] = None, project_id: Optional[str] = None, idea_id: Optional[str] = None) -> GenerationsResponse:
Generation]: generations = await self.dao.generations.get_generations(character_id = character_id,limit=limit, offset=offset, created_by=user_id, project_id=project_id, idea_id=idea_id)
generations = await self.dao.generations.get_generations(character_id = character_id,limit=limit, offset=offset, created_by=user_id, project_id=project_id) total_count = await self.dao.generations.count_generations(character_id = character_id, created_by=user_id, project_id=project_id, idea_id=idea_id)
total_count = await self.dao.generations.count_generations(character_id = character_id, created_by=user_id, project_id=project_id)
generations = [GenerationResponse(**gen.model_dump()) for gen in generations] generations = [GenerationResponse(**gen.model_dump()) for gen in generations]
return GenerationsResponse(generations=generations, total_count=total_count) return GenerationsResponse(generations=generations, total_count=total_count)
@@ -140,6 +139,10 @@ class GenerationService:
if generation_group_id: if generation_group_id:
generation_model.generation_group_id = generation_group_id generation_model.generation_group_id = generation_group_id
# Explicitly set idea_id from request if present (already in model_dump, but ensuring clarity)
if generation_request.idea_id:
generation_model.idea_id = generation_request.idea_id
gen_id = await self.dao.generations.create_generation(generation_model) gen_id = await self.dao.generations.create_generation(generation_model)
generation_model.id = gen_id generation_model.id = gen_id
@@ -154,8 +157,9 @@ class GenerationService:
# если генерация уже пошла и упала — пометим FAILED # если генерация уже пошла и упала — пометим FAILED
try: try:
db_gen = await self.dao.generations.get_generation(gen.id) db_gen = await self.dao.generations.get_generation(gen.id)
db_gen.status = GenerationStatus.FAILED if db_gen is not None:
await self.dao.generations.update_generation(db_gen) db_gen.status = GenerationStatus.FAILED
await self.dao.generations.update_generation(db_gen)
except Exception: except Exception:
logger.exception("Failed to mark generation as FAILED") logger.exception("Failed to mark generation as FAILED")
logger.exception("create_generation task failed") logger.exception("create_generation task failed")
@@ -169,8 +173,9 @@ class GenerationService:
if gen_id is not None: if gen_id is not None:
try: try:
gen = await self.dao.generations.get_generation(gen_id) gen = await self.dao.generations.get_generation(gen_id)
gen.status = GenerationStatus.FAILED if gen is not None:
await self.dao.generations.update_generation(gen) gen.status = GenerationStatus.FAILED
await self.dao.generations.update_generation(gen)
except Exception: except Exception:
logger.exception("Failed to mark generation as FAILED in create_generation_task") logger.exception("Failed to mark generation as FAILED in create_generation_task")
raise raise
@@ -198,9 +203,10 @@ class GenerationService:
if char_info is None: if char_info is None:
raise Exception(f"Character ID {generation.linked_character_id} not found") raise Exception(f"Character ID {generation.linked_character_id} not found")
if generation.use_profile_image: if generation.use_profile_image:
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id) if char_info.avatar_asset_id is not None:
if avatar_asset: avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
media_group_bytes.append(avatar_asset.data) if avatar_asset and avatar_asset.data:
media_group_bytes.append(avatar_asset.data)
# generation_prompt = generation_prompt.replace("$char_bio_inserted", f"1. CHARACTER BIO (Must be strictly followed): {char_info.character_bio}") # generation_prompt = generation_prompt.replace("$char_bio_inserted", f"1. CHARACTER BIO (Must be strictly followed): {char_info.character_bio}")
reference_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list) reference_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
@@ -301,7 +307,9 @@ class GenerationService:
# 5. (Опционально) Обновляем запись генерации ссылками на результаты # 5. (Опционально) Обновляем запись генерации ссылками на результаты
# Предполагаем, что у модели Generation есть поле result_asset_ids # Предполагаем, что у модели Generation есть поле result_asset_ids
result_ids = [a.id for a in created_assets] result_ids = []
for a in created_assets:
result_ids.append(a.id)
generation.result_list = result_ids generation.result_list = result_ids
generation.status = GenerationStatus.DONE generation.status = GenerationStatus.DONE
@@ -477,3 +485,25 @@ class GenerationService:
logger.info(f"Cleaned up {count} stale generations (timeout)") logger.info(f"Cleaned up {count} stale generations (timeout)")
except Exception as e: except Exception as e:
logger.error(f"Error cleaning up stale generations: {e}") logger.error(f"Error cleaning up stale generations: {e}")
async def cleanup_old_data(self, days: int = 2):
"""
Очистка старых данных:
1. Мягко удаляет генерации старше N дней
2. Мягко удаляет связанные ассеты + жёстко удаляет файлы из S3
"""
try:
# 1. Мягко удаляем генерации и собираем asset IDs
gen_count, asset_ids = await self.dao.generations.soft_delete_old_generations(days=days)
if gen_count > 0:
logger.info(f"Soft-deleted {gen_count} generations older than {days} days. "
f"Found {len(asset_ids)} associated asset IDs.")
# 2. Мягко удаляем ассеты + жёстко удаляем файлы из S3
if asset_ids:
purged = await self.dao.assets.soft_delete_and_purge_assets(asset_ids)
logger.info(f"Purged {purged} assets (soft-deleted + S3 files removed).")
except Exception as e:
logger.error(f"Error during old data cleanup: {e}")

View File

@@ -0,0 +1,75 @@
from typing import List, Optional
from datetime import datetime
from repos.dao import DAO
from models.Idea import Idea
class IdeaService:
def __init__(self, dao: DAO):
self.dao = dao
async def create_idea(self, name: str, description: Optional[str], project_id: Optional[str], user_id: str) -> Idea:
idea = Idea(name=name, description=description, project_id=project_id, created_by=user_id)
idea_id = await self.dao.ideas.create_idea(idea)
idea.id = idea_id
return idea
async def get_ideas(self, project_id: Optional[str], user_id: str, limit: int = 20, offset: int = 0) -> List[dict]:
return await self.dao.ideas.get_ideas(project_id, user_id, limit, offset)
async def get_idea(self, idea_id: str) -> Optional[Idea]:
return await self.dao.ideas.get_idea(idea_id)
async def update_idea(self, idea_id: str, name: Optional[str] = None, description: Optional[str] = None) -> Optional[Idea]:
idea = await self.dao.ideas.get_idea(idea_id)
if not idea:
return None
if name is not None:
idea.name = name
if description is not None:
idea.description = description
idea.updated_at = datetime.now()
await self.dao.ideas.update_idea(idea)
return idea
async def delete_idea(self, idea_id: str) -> bool:
return await self.dao.ideas.delete_idea(idea_id)
async def add_generation_to_idea(self, idea_id: str, generation_id: str) -> bool:
# Verify idea exists
idea = await self.dao.ideas.get_idea(idea_id)
if not idea:
return False
# Get generation
gen = await self.dao.generations.get_generation(generation_id)
if not gen:
return False
# Link
gen.idea_id = idea_id
gen.updated_at = datetime.now()
await self.dao.generations.update_generation(gen)
return True
async def remove_generation_from_idea(self, idea_id: str, generation_id: str) -> bool:
# Verify idea exists (optional, but good for validation)
idea = await self.dao.ideas.get_idea(idea_id)
if not idea:
return False
# Get generation
gen = await self.dao.generations.get_generation(generation_id)
if not gen:
return False
# Unlink only if currently linked to this idea
if gen.idea_id == idea_id:
gen.idea_id = None
gen.updated_at = datetime.now()
await self.dao.generations.update_generation(gen)
return True
return False

View File

@@ -0,0 +1,79 @@
from typing import List, Optional
from datetime import datetime, UTC
from repos.dao import DAO
from models.Post import Post
class PostService:
def __init__(self, dao: DAO):
self.dao = dao
async def create_post(
self,
date: datetime,
topic: str,
generation_ids: List[str],
project_id: Optional[str],
user_id: str,
) -> Post:
post = Post(
date=date,
topic=topic,
generation_ids=generation_ids,
project_id=project_id,
created_by=user_id,
)
post_id = await self.dao.posts.create_post(post)
post.id = post_id
return post
async def get_post(self, post_id: str) -> Optional[Post]:
return await self.dao.posts.get_post(post_id)
async def get_posts(
self,
project_id: Optional[str],
user_id: str,
limit: int = 20,
offset: int = 0,
date_from: Optional[datetime] = None,
date_to: Optional[datetime] = None,
) -> List[Post]:
return await self.dao.posts.get_posts(project_id, user_id, limit, offset, date_from, date_to)
async def update_post(
self,
post_id: str,
date: Optional[datetime] = None,
topic: Optional[str] = None,
) -> Optional[Post]:
post = await self.dao.posts.get_post(post_id)
if not post:
return None
updates: dict = {"updated_at": datetime.now(UTC)}
if date is not None:
updates["date"] = date
if topic is not None:
updates["topic"] = topic
await self.dao.posts.update_post(post_id, updates)
# Return refreshed post
return await self.dao.posts.get_post(post_id)
async def delete_post(self, post_id: str) -> bool:
return await self.dao.posts.delete_post(post_id)
async def add_generations(self, post_id: str, generation_ids: List[str]) -> bool:
post = await self.dao.posts.get_post(post_id)
if not post:
return False
return await self.dao.posts.add_generations(post_id, generation_ids)
async def remove_generation(self, post_id: str, generation_id: str) -> bool:
post = await self.dao.posts.get_post(post_id)
if not post:
return False
return await self.dao.posts.remove_generation(post_id, generation_id)

View File

@@ -30,6 +30,7 @@ class Asset(BaseModel):
tags: List[str] = [] tags: List[str] = []
created_by: Optional[str] = None created_by: Optional[str] = None
project_id: Optional[str] = None project_id: Optional[str] = None
is_deleted: bool = False
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))

View File

@@ -38,6 +38,7 @@ class Generation(BaseModel):
generation_group_id: Optional[str] = None generation_group_id: Optional[str] = None
created_by: Optional[str] = None # Stores User ID (Telegram ID or Web User ObjectId) created_by: Optional[str] = None # Stores User ID (Telegram ID or Web User ObjectId)
project_id: Optional[str] = None project_id: Optional[str] = None
idea_id: Optional[str] = None
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))

13
models/Idea.py Normal file
View File

@@ -0,0 +1,13 @@
from datetime import datetime
from typing import Optional, List
from pydantic import BaseModel, Field
class Idea(BaseModel):
id: Optional[str] = None
name: str = "New Idea"
description: Optional[str] = None
project_id: Optional[str] = None
created_by: str # User ID
is_deleted: bool = False
created_at: datetime = Field(default_factory=datetime.now)
updated_at: datetime = Field(default_factory=datetime.now)

23
models/Post.py Normal file
View File

@@ -0,0 +1,23 @@
from datetime import datetime, timezone, UTC
from typing import Optional, List
from pydantic import BaseModel, Field, model_validator
class Post(BaseModel):
id: Optional[str] = None
date: datetime
topic: str
generation_ids: List[str] = Field(default_factory=list)
project_id: Optional[str] = None
created_by: str
is_deleted: bool = False
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
@model_validator(mode="after")
def ensure_tz_aware(self):
for field in ("date", "created_at", "updated_at"):
val = getattr(self, field)
if val is not None and val.tzinfo is None:
setattr(self, field, val.replace(tzinfo=timezone.utc))
return self

View File

@@ -1,6 +1,8 @@
from typing import List, Optional from typing import Any, List, Optional
import logging import logging
from datetime import datetime, UTC
from bson import ObjectId from bson import ObjectId
from uuid import uuid4
from motor.motor_asyncio import AsyncIOMotorClient from motor.motor_asyncio import AsyncIOMotorClient
from models.Asset import Asset from models.Asset import Asset
@@ -19,7 +21,8 @@ class AssetsRepo:
# Main data # Main data
if asset.data: if asset.data:
ts = int(asset.created_at.timestamp()) ts = int(asset.created_at.timestamp())
object_name = f"{asset.type.value}/{ts}_{asset.name}" uid = uuid4().hex[:8]
object_name = f"{asset.type.value}/{ts}_{uid}_{asset.name}"
uploaded = await self.s3.upload_file(object_name, asset.data) uploaded = await self.s3.upload_file(object_name, asset.data)
if uploaded: if uploaded:
@@ -32,7 +35,8 @@ class AssetsRepo:
# Thumbnail # Thumbnail
if asset.thumbnail: if asset.thumbnail:
ts = int(asset.created_at.timestamp()) ts = int(asset.created_at.timestamp())
thumb_name = f"{asset.type.value}/thumbs/{ts}_{asset.name}_thumb.jpg" uid = uuid4().hex[:8]
thumb_name = f"{asset.type.value}/thumbs/{ts}_{uid}_{asset.name}_thumb.jpg"
uploaded_thumb = await self.s3.upload_file(thumb_name, asset.thumbnail) uploaded_thumb = await self.s3.upload_file(thumb_name, asset.thumbnail)
if uploaded_thumb: if uploaded_thumb:
@@ -47,7 +51,7 @@ class AssetsRepo:
return str(res.inserted_id) return str(res.inserted_id)
async def get_assets(self, asset_type: Optional[str] = None, limit: int = 10, offset: int = 0, with_data: bool = False, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Asset]: async def get_assets(self, asset_type: Optional[str] = None, limit: int = 10, offset: int = 0, with_data: bool = False, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Asset]:
filter = {} filter: dict[str, Any]= {"is_deleted": {"$ne": True}}
if asset_type: if asset_type:
filter["type"] = asset_type filter["type"] = asset_type
args = {} args = {}
@@ -134,7 +138,8 @@ class AssetsRepo:
if self.s3: if self.s3:
if asset.data: if asset.data:
ts = int(asset.created_at.timestamp()) ts = int(asset.created_at.timestamp())
object_name = f"{asset.type.value}/{ts}_{asset.name}" uid = uuid4().hex[:8]
object_name = f"{asset.type.value}/{ts}_{uid}_{asset.name}"
if await self.s3.upload_file(object_name, asset.data): if await self.s3.upload_file(object_name, asset.data):
asset.minio_object_name = object_name asset.minio_object_name = object_name
asset.minio_bucket = self.s3.bucket_name asset.minio_bucket = self.s3.bucket_name
@@ -142,7 +147,8 @@ class AssetsRepo:
if asset.thumbnail: if asset.thumbnail:
ts = int(asset.created_at.timestamp()) ts = int(asset.created_at.timestamp())
thumb_name = f"{asset.type.value}/thumbs/{ts}_{asset.name}_thumb.jpg" uid = uuid4().hex[:8]
thumb_name = f"{asset.type.value}/thumbs/{ts}_{uid}_{asset.name}_thumb.jpg"
if await self.s3.upload_file(thumb_name, asset.thumbnail): if await self.s3.upload_file(thumb_name, asset.thumbnail):
asset.minio_thumbnail_object_name = thumb_name asset.minio_thumbnail_object_name = thumb_name
asset.thumbnail = None asset.thumbnail = None
@@ -197,6 +203,61 @@ class AssetsRepo:
res = await self.collection.delete_one({"_id": ObjectId(asset_id)}) res = await self.collection.delete_one({"_id": ObjectId(asset_id)})
return res.deleted_count > 0 return res.deleted_count > 0
async def soft_delete_and_purge_assets(self, asset_ids: List[str]) -> int:
"""
Мягко удаляет ассеты и жёстко удаляет их файлы из S3.
Возвращает количество обработанных ассетов.
"""
if not asset_ids:
return 0
object_ids = [ObjectId(aid) for aid in asset_ids if ObjectId.is_valid(aid)]
if not object_ids:
return 0
# Находим ассеты, которые ещё не удалены
cursor = self.collection.find(
{"_id": {"$in": object_ids}, "is_deleted": {"$ne": True}},
{"minio_object_name": 1, "minio_thumbnail_object_name": 1}
)
purged_count = 0
ids_to_update = []
async for doc in cursor:
ids_to_update.append(doc["_id"])
# Жёсткое удаление файлов из S3
if self.s3:
if doc.get("minio_object_name"):
try:
await self.s3.delete_file(doc["minio_object_name"])
except Exception as e:
logger.error(f"Failed to delete S3 object {doc['minio_object_name']}: {e}")
if doc.get("minio_thumbnail_object_name"):
try:
await self.s3.delete_file(doc["minio_thumbnail_object_name"])
except Exception as e:
logger.error(f"Failed to delete S3 thumbnail {doc['minio_thumbnail_object_name']}: {e}")
purged_count += 1
# Мягкое удаление + очистка ссылок на S3
if ids_to_update:
await self.collection.update_many(
{"_id": {"$in": ids_to_update}},
{
"$set": {
"is_deleted": True,
"minio_object_name": None,
"minio_thumbnail_object_name": None,
"updated_at": datetime.now(UTC)
}
}
)
return purged_count
async def migrate_to_minio(self) -> dict: async def migrate_to_minio(self) -> dict:
"""Переносит данные и thumbnails из Mongo в MinIO.""" """Переносит данные и thumbnails из Mongo в MinIO."""
if not self.s3: if not self.s3:
@@ -216,7 +277,8 @@ class AssetsRepo:
created_at = doc.get("created_at") created_at = doc.get("created_at")
ts = int(created_at.timestamp()) if created_at else 0 ts = int(created_at.timestamp()) if created_at else 0
object_name = f"{type_}/{ts}_{asset_id}_{name}" uid = uuid4().hex[:8]
object_name = f"{type_}/{ts}_{uid}_{asset_id}_{name}"
if await self.s3.upload_file(object_name, data): if await self.s3.upload_file(object_name, data):
await self.collection.update_one( await self.collection.update_one(
{"_id": asset_id}, {"_id": asset_id},
@@ -243,7 +305,8 @@ class AssetsRepo:
created_at = doc.get("created_at") created_at = doc.get("created_at")
ts = int(created_at.timestamp()) if created_at else 0 ts = int(created_at.timestamp()) if created_at else 0
thumb_name = f"{type_}/thumbs/{ts}_{asset_id}_{name}_thumb.jpg" uid = uuid4().hex[:8]
thumb_name = f"{type_}/thumbs/{ts}_{uid}_{asset_id}_{name}_thumb.jpg"
if await self.s3.upload_file(thumb_name, thumb): if await self.s3.upload_file(thumb_name, thumb):
await self.collection.update_one( await self.collection.update_one(
{"_id": asset_id}, {"_id": asset_id},

View File

@@ -6,6 +6,8 @@ from repos.generation_repo import GenerationRepo
from repos.user_repo import UsersRepo from repos.user_repo import UsersRepo
from repos.albums_repo import AlbumsRepo from repos.albums_repo import AlbumsRepo
from repos.project_repo import ProjectRepo from repos.project_repo import ProjectRepo
from repos.idea_repo import IdeaRepo
from repos.post_repo import PostRepo
from typing import Optional from typing import Optional
@@ -19,3 +21,5 @@ class DAO:
self.albums = AlbumsRepo(client, db_name) self.albums = AlbumsRepo(client, db_name)
self.projects = ProjectRepo(client, db_name) self.projects = ProjectRepo(client, db_name)
self.users = UsersRepo(client, db_name) self.users = UsersRepo(client, db_name)
self.ideas = IdeaRepo(client, db_name)
self.posts = PostRepo(client, db_name)

View File

@@ -1,4 +1,4 @@
from typing import Optional, List from typing import Any, Optional, List
from datetime import datetime, timedelta, UTC from datetime import datetime, timedelta, UTC
from PIL.ImageChops import offset from PIL.ImageChops import offset
@@ -17,7 +17,7 @@ class GenerationRepo:
res = await self.collection.insert_one(generation.model_dump()) res = await self.collection.insert_one(generation.model_dump())
return str(res.inserted_id) return str(res.inserted_id)
async def get_generation(self, generation_id: str) -> Optional[Generation]: async def get_generation(self, generation_id: str) -> Generation | None:
res = await self.collection.find_one({"_id": ObjectId(generation_id)}) res = await self.collection.find_one({"_id": ObjectId(generation_id)})
if res is None: if res is None:
return None return None
@@ -26,20 +26,29 @@ class GenerationRepo:
return Generation(**res) return Generation(**res)
async def get_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None, async def get_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
limit: int = 10, offset: int = 10, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]: limit: int = 10, offset: int = 0, created_by: Optional[str] = None, project_id: Optional[str] = None, idea_id: Optional[str] = None) -> List[Generation]:
filter = {"is_deleted": False} filter: dict[str, Any] = {"is_deleted": False}
if character_id is not None: if character_id is not None:
filter["linked_character_id"] = character_id filter["linked_character_id"] = character_id
if status is not None: if status is not None:
filter["status"] = status filter["status"] = status
if created_by is not None: if created_by is not None:
filter["created_by"] = created_by filter["created_by"] = created_by
filter["project_id"] = None # If filtering by created_by user (e.g. "My Generations"), we typically imply personal scope if project_id is None.
# But if project_id is passed, we filter by that.
if project_id is None:
filter["project_id"] = None
if project_id is not None: if project_id is not None:
filter["project_id"] = project_id filter["project_id"] = project_id
if idea_id is not None:
filter["idea_id"] = idea_id
res = await self.collection.find(filter).sort("created_at", -1).skip( # If fetching for an idea, sort by created_at ascending (cronological)
# Otherwise typically descending (newest first)
sort_order = 1 if idea_id else -1
res = await self.collection.find(filter).sort("created_at", sort_order).skip(
offset).limit(limit).to_list(None) offset).limit(limit).to_list(None)
generations: List[Generation] = [] generations: List[Generation] = []
for generation in res: for generation in res:
@@ -48,7 +57,7 @@ class GenerationRepo:
return generations return generations
async def count_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None, async def count_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
album_id: Optional[str] = None, created_by: Optional[str] = None, project_id: Optional[str] = None) -> int: album_id: Optional[str] = None, created_by: Optional[str] = None, project_id: Optional[str] = None, idea_id: Optional[str] = None) -> int:
args = {} args = {}
if character_id is not None: if character_id is not None:
args["linked_character_id"] = character_id args["linked_character_id"] = character_id
@@ -58,6 +67,10 @@ class GenerationRepo:
args["created_by"] = created_by args["created_by"] = created_by
if project_id is not None: if project_id is not None:
args["project_id"] = project_id args["project_id"] = project_id
if idea_id is not None:
args["idea_id"] = idea_id
if album_id is not None:
args["album_id"] = album_id
return await self.collection.count_documents(args) return await self.collection.count_documents(args)
async def get_generations_by_ids(self, generation_ids: List[str]) -> List[Generation]: async def get_generations_by_ids(self, generation_ids: List[str]) -> List[Generation]:
@@ -87,7 +100,7 @@ class GenerationRepo:
generations.append(Generation(**generation)) generations.append(Generation(**generation))
return generations return generations
async def cancel_stale_generations(self, timeout_minutes: int = 60) -> int: async def cancel_stale_generations(self, timeout_minutes: int = 5) -> int:
cutoff_time = datetime.now(UTC) - timedelta(minutes=timeout_minutes) cutoff_time = datetime.now(UTC) - timedelta(minutes=timeout_minutes)
res = await self.collection.update_many( res = await self.collection.update_many(
{ {
@@ -103,3 +116,37 @@ class GenerationRepo:
} }
) )
return res.modified_count return res.modified_count
async def soft_delete_old_generations(self, days: int = 2) -> tuple[int, List[str]]:
"""
Мягко удаляет генерации старше N дней.
Возвращает (количество удалённых, список asset IDs для очистки).
"""
cutoff_time = datetime.now(UTC) - timedelta(days=days)
filter_query = {
"is_deleted": False,
"status": {"$in": [GenerationStatus.DONE, GenerationStatus.FAILED]},
"created_at": {"$lt": cutoff_time}
}
# Сначала собираем asset IDs из удаляемых генераций
asset_ids: List[str] = []
cursor = self.collection.find(filter_query, {"result_list": 1, "assets_list": 1})
async for doc in cursor:
asset_ids.extend(doc.get("result_list", []))
asset_ids.extend(doc.get("assets_list", []))
# Мягкое удаление
res = await self.collection.update_many(
filter_query,
{
"$set": {
"is_deleted": True,
"updated_at": datetime.now(UTC)
}
}
)
# Убираем дубликаты
unique_asset_ids = list(set(asset_ids))
return res.modified_count, unique_asset_ids

82
repos/idea_repo.py Normal file
View File

@@ -0,0 +1,82 @@
from typing import Optional, List
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient
from models.Idea import Idea
class IdeaRepo:
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
self.collection = client[db_name]["ideas"]
async def create_idea(self, idea: Idea) -> str:
res = await self.collection.insert_one(idea.model_dump())
return str(res.inserted_id)
async def get_idea(self, idea_id: str) -> Optional[Idea]:
if not ObjectId.is_valid(idea_id):
return None
res = await self.collection.find_one({"_id": ObjectId(idea_id)})
if res:
res["id"] = str(res.pop("_id"))
return Idea(**res)
return None
async def get_ideas(self, project_id: Optional[str], user_id: str, limit: int = 20, offset: int = 0) -> List[dict]:
if project_id:
match_stage = {"project_id": project_id, "is_deleted": False}
else:
match_stage = {"created_by": user_id, "project_id": None, "is_deleted": False}
pipeline = [
{"$match": match_stage},
{"$sort": {"updated_at": -1}},
{"$skip": offset},
{"$limit": limit},
# Add string id field for lookup
{"$addFields": {"str_id": {"$toString": "$_id"}}},
# Lookup generations
{
"$lookup": {
"from": "generations",
"let": {"idea_id": "$str_id"},
"pipeline": [
{"$match": {"$expr": {"$eq": ["$idea_id", "$$idea_id"]}}},
{"$sort": {"created_at": -1}}, # Ensure we get the latest
{"$limit": 1}
],
"as": "generations"
}
},
# Unwind generations array (preserve ideas without generations)
{"$unwind": {"path": "$generations", "preserveNullAndEmptyArrays": True}},
# Rename for clarity
{"$addFields": {
"last_generation": "$generations",
"id": "$str_id"
}},
{"$project": {"generations": 0, "str_id": 0, "_id": 0}}
]
return await self.collection.aggregate(pipeline).to_list(None)
async def delete_idea(self, idea_id: str) -> bool:
if not ObjectId.is_valid(idea_id):
return False
res = await self.collection.update_one(
{"_id": ObjectId(idea_id)},
{"$set": {"is_deleted": True}}
)
return res.modified_count > 0
async def update_idea(self, idea: Idea) -> bool:
if not idea.id or not ObjectId.is_valid(idea.id):
return False
idea_dict = idea.model_dump()
if "id" in idea_dict:
del idea_dict["id"]
res = await self.collection.update_one(
{"_id": ObjectId(idea.id)},
{"$set": idea_dict}
)
return res.modified_count > 0

97
repos/post_repo.py Normal file
View File

@@ -0,0 +1,97 @@
from typing import List, Optional
from datetime import datetime
import logging
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient
from models.Post import Post
logger = logging.getLogger(__name__)
class PostRepo:
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
self.collection = client[db_name]["posts"]
async def create_post(self, post: Post) -> str:
res = await self.collection.insert_one(post.model_dump())
return str(res.inserted_id)
async def get_post(self, post_id: str) -> Optional[Post]:
if not ObjectId.is_valid(post_id):
return None
res = await self.collection.find_one({"_id": ObjectId(post_id), "is_deleted": False})
if res:
res["id"] = str(res.pop("_id"))
return Post(**res)
return None
async def get_posts(
self,
project_id: Optional[str],
user_id: str,
limit: int = 20,
offset: int = 0,
date_from: Optional[datetime] = None,
date_to: Optional[datetime] = None,
) -> List[Post]:
if project_id:
match = {"project_id": project_id, "is_deleted": False}
else:
match = {"created_by": user_id, "project_id": None, "is_deleted": False}
if date_from or date_to:
date_filter = {}
if date_from:
date_filter["$gte"] = date_from
if date_to:
date_filter["$lte"] = date_to
match["date"] = date_filter
cursor = (
self.collection.find(match)
.sort("date", -1)
.skip(offset)
.limit(limit)
)
posts = []
async for doc in cursor:
doc["id"] = str(doc.pop("_id"))
posts.append(Post(**doc))
return posts
async def update_post(self, post_id: str, data: dict) -> bool:
if not ObjectId.is_valid(post_id):
return False
res = await self.collection.update_one(
{"_id": ObjectId(post_id)},
{"$set": data},
)
return res.modified_count > 0
async def delete_post(self, post_id: str) -> bool:
if not ObjectId.is_valid(post_id):
return False
res = await self.collection.update_one(
{"_id": ObjectId(post_id)},
{"$set": {"is_deleted": True}},
)
return res.modified_count > 0
async def add_generations(self, post_id: str, generation_ids: List[str]) -> bool:
if not ObjectId.is_valid(post_id):
return False
res = await self.collection.update_one(
{"_id": ObjectId(post_id)},
{"$addToSet": {"generation_ids": {"$each": generation_ids}}},
)
return res.modified_count > 0
async def remove_generation(self, post_id: str, generation_id: str) -> bool:
if not ObjectId.is_valid(post_id):
return False
res = await self.collection.update_one(
{"_id": ObjectId(post_id)},
{"$pull": {"generation_ids": generation_id}},
)
return res.modified_count > 0

97
tests/test_idea.py Normal file
View File

@@ -0,0 +1,97 @@
import asyncio
import os
from dotenv import load_dotenv
from motor.motor_asyncio import AsyncIOMotorClient
from bson import ObjectId
# Import from project root (requires PYTHONPATH=.)
from api.service.idea_service import IdeaService
from repos.dao import DAO
from models.Idea import Idea
from models.Generation import Generation, GenerationStatus
from models.enums import AspectRatios, Quality
load_dotenv()
MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://localhost:27017")
DB_NAME = os.getenv("DB_NAME", "bot_db")
print(f"Connecting to MongoDB: {MONGO_HOST}, DB: {DB_NAME}")
async def test_idea_flow():
client = AsyncIOMotorClient(MONGO_HOST)
dao = DAO(client, db_name=DB_NAME)
service = IdeaService(dao)
# 1. Create an Idea
print("Creating idea...")
user_id = "test_user_123"
project_id = "test_project_abc"
idea = await service.create_idea("My Test Idea", "Initial Description", project_id, user_id)
print(f"Idea created: {idea.id} - {idea.name}")
# 2. Update Idea
print("Updating idea...")
updated_idea = await service.update_idea(idea.id, description="Updated description")
print(f"Idea updated: {updated_idea.description}")
if updated_idea.description == "Updated description":
print("✅ Idea update successful")
else:
print("❌ Idea update FAILED")
# 3. Add Generation linked to Idea
print("Creating generation linked to idea...")
gen = Generation(
prompt="idea generation 1",
# idea_id=idea.id, <-- Intentionally NOT linking initially to test linking method
project_id=project_id,
created_by=user_id,
aspect_ratio=AspectRatios.NINESIXTEEN,
quality=Quality.ONEK,
assets_list=[]
)
gen_id = await dao.generations.create_generation(gen)
print(f"Created generation: {gen_id}")
# Link generation to idea
print("Linking generation to idea...")
success = await service.add_generation_to_idea(idea.id, gen_id)
if success:
print("✅ Linking successful")
else:
print("❌ Linking FAILED")
# Debug: Check if generation was saved with idea_id
saved_gen = await dao.generations.collection.find_one({"_id": ObjectId(gen_id)})
print(f"DEBUG: Saved Generation in DB idea_id: {saved_gen.get('idea_id')}")
# 4. Fetch Generations for Idea (Verify filtering and ordering)
print("Fetching generations for idea...")
gens = await service.dao.generations.get_generations(idea_id=idea.id) # using repo directly as service might return wrapper
print(f"Found {len(gens)} generations in idea")
if len(gens) == 1 and gens[0].id == gen_id:
print("✅ Generation retrieval successful")
else:
print("❌ Generation retrieval FAILED")
# 5. Fetch Ideas for Project
ideas = await service.get_ideas(project_id)
print(f"Found {len(ideas)} ideas for project")
# Cleaning up
print("Cleaning up...")
await service.delete_idea(idea.id)
await dao.generations.collection.delete_one({"_id": ObjectId(gen_id)})
# Verify deletion
deleted_idea = await service.get_idea(idea.id)
# IdeaRepo.delete_idea logic sets is_deleted=True
if deleted_idea and deleted_idea.is_deleted:
print(f"✅ Idea deleted successfully")
# Hard delete for cleanup
await dao.ideas.collection.delete_one({"_id": ObjectId(idea.id)})
if __name__ == "__main__":
asyncio.run(test_idea_flow())