15 Commits

Author SHA1 Message Date
xds
d820d9145b feat: introduce post resource with full CRUD operations and generation linking. 2026-02-17 15:54:01 +03:00
xds
c93e577bcf feat: Implement asset soft deletion with S3 file purging, enhance type safety, and improve error handling in generation and adapter services. 2026-02-17 12:51:40 +03:00
xds
c5d4849bff Update compiled bytecode for assets_repo.py. 2026-02-16 16:41:53 +03:00
xds
9abfbef871 Merge branch 'ideas' 2026-02-16 16:41:13 +03:00
xds
68a3f529cb feat: Enhance idea retrieval to include the latest generation and support user-specific ideas not tied to a project, while also improving asset storage uniqueness and adjusting generation cancellation timeout. 2026-02-16 16:35:26 +03:00
xds
e2c050515d feat: Limit concurrent image generations to 4 using an asyncio semaphore. 2026-02-16 00:30:34 +03:00
xds
5e7dc19bf3 ideas 2026-02-15 12:42:15 +03:00
xds
97483b7030 + ideas 2026-02-15 10:26:01 +03:00
xds
2d3da59de9 12 2026-02-13 17:31:14 +03:00
xds
279cb5c6f6 feat: Implement cancellation of stale generations in the service and repository, along with a new test. 2026-02-13 17:30:11 +03:00
xds
30138bab38 feat: Introduce generation grouping, enabling multiple generations per request via a new count parameter and retrieval by group ID. 2026-02-13 11:18:11 +03:00
xds
977cab92f8 fix 2026-02-12 18:41:01 +03:00
xds
dcab238d3e fix 2026-02-12 15:50:43 +03:00
xds
9d2e4e47de fix 2026-02-12 15:32:28 +03:00
xds
c6142715d9 feat: Add image, update VS Code launch configuration, and enhance gitignore rules for build artifacts. 2026-02-12 14:02:36 +03:00
54 changed files with 1194 additions and 117 deletions

15
.gitignore vendored
View File

@@ -9,3 +9,18 @@ minio_backup.tar.gz
.idea
.venv
.vscode
.vscode/launch.json
middlewares/__pycache__/
middlewares/*.pyc
api/__pycache__/
api/*.pyc
repos/__pycache__/
repos/*.pyc
adapters/__pycache__/
adapters/*.pyc
services/__pycache__/
services/*.pyc
utils/__pycache__/
utils/*.pyc
.vscode/launch.json
repos/__pycache__/assets_repo.cpython-313.pyc

27
.vscode/launch.json vendored
View File

@@ -7,7 +7,7 @@
"request": "launch",
"module": "uvicorn",
"args": [
"main:app",
"aiws:app",
"--reload",
"--port",
"8090",
@@ -16,31 +16,6 @@
],
"jinja": true,
"justMyCode": true
},
{
"name": "Python: Current File",
"type": "debugpy",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"justMyCode": true,
"env": {
"PYTHONPATH": "${workspaceFolder}"
}
},
{
"name": "Debug Tests: Current File",
"type": "debugpy",
"request": "launch",
"module": "pytest",
"args": [
"${file}"
],
"console": "integratedTerminal",
"justMyCode": true,
"env": {
"PYTHONPATH": "${workspaceFolder}"
}
}
]
}

View File

@@ -23,28 +23,30 @@ class GoogleAdapter:
self.TEXT_MODEL = "gemini-3-pro-preview"
self.IMAGE_MODEL = "gemini-3-pro-image-preview"
def _prepare_contents(self, prompt: str, images_list: List[bytes] = None) -> list:
"""Вспомогательный метод для подготовки контента (текст + картинки)"""
contents = [prompt]
def _prepare_contents(self, prompt: str, images_list: List[bytes] | None = None) -> tuple:
"""Вспомогательный метод для подготовки контента (текст + картинки).
Returns (contents, opened_images) — caller MUST close opened_images after use."""
contents : list [Any]= [prompt]
opened_images = []
if images_list:
logger.info(f"Preparing content with {len(images_list)} images")
for img_bytes in images_list:
try:
# Gemini API требует PIL Image на входе
image = Image.open(io.BytesIO(img_bytes))
contents.append(image)
opened_images.append(image)
except Exception as e:
logger.error(f"Error processing input image: {e}")
else:
logger.info("Preparing content with no images")
return contents
return contents, opened_images
def generate_text(self, prompt: str, images_list: List[bytes] = None) -> str:
def generate_text(self, prompt: str, images_list: List[bytes] | None = None) -> str:
"""
Генерация текста (Чат или Vision).
Возвращает строку с ответом.
"""
contents = self._prepare_contents(prompt, images_list)
contents, opened_images = self._prepare_contents(prompt, images_list)
logger.info(f"Generating text: {prompt}")
try:
response = self.client.models.generate_content(
@@ -68,14 +70,17 @@ class GoogleAdapter:
except Exception as e:
logger.error(f"Gemini Text API Error: {e}")
raise GoogleGenerationException(f"Gemini Text API Error: {e}")
finally:
for img in opened_images:
img.close()
def generate_image(self, prompt: str, aspect_ratio: AspectRatios, quality: Quality, images_list: List[bytes] = None, ) -> Tuple[List[io.BytesIO], Dict[str, Any]]:
def generate_image(self, prompt: str, aspect_ratio: AspectRatios, quality: Quality, images_list: List[bytes] | None = None, ) -> Tuple[List[io.BytesIO], Dict[str, Any]]:
"""
Генерация изображений (Text-to-Image или Image-to-Image).
Возвращает список байтовых потоков (готовых к отправке).
"""
contents = self._prepare_contents(prompt, images_list)
contents, opened_images = self._prepare_contents(prompt, images_list)
logger.info(f"Generating image. Prompt length: {len(prompt)}, Ratio: {aspect_ratio}, Quality: {quality}")
start_time = datetime.now()
@@ -101,8 +106,20 @@ class GoogleAdapter:
if response.usage_metadata:
token_usage = response.usage_metadata.total_token_count
if response.parts is None and response.candidates[0].finish_reason is not None:
raise GoogleGenerationException(f"Generation blocked in cause of {response.candidates[0].finish_reason.value}")
# Check prompt-level block (e.g. PROHIBITED_CONTENT) — no candidates in this case
if response.prompt_feedback and response.prompt_feedback.block_reason:
raise GoogleGenerationException(
f"Generation blocked at prompt level: {response.prompt_feedback.block_reason.value}"
)
# Check candidate-level block
if response.parts is None:
response_reason = (
response.candidates[0].finish_reason
if response.candidates and len(response.candidates) > 0
else "Unknown"
)
raise GoogleGenerationException(f"Generation blocked: {response_reason}")
generated_images = []
@@ -113,7 +130,9 @@ class GoogleAdapter:
try:
# 1. Берем сырые байты
raw_data = part.inline_data.data
byte_arr = io.BytesIO(raw_data)
if raw_data is None:
raise GoogleGenerationException("Generation returned no data")
byte_arr : io.BytesIO = io.BytesIO(raw_data)
# 2. Нейминг (формально, для TG)
timestamp = datetime.now().timestamp()
@@ -148,3 +167,7 @@ class GoogleAdapter:
except Exception as e:
logger.error(f"Gemini Image API Error: {e}")
raise GoogleGenerationException(f"Gemini Image API Error: {e}")
finally:
for img in opened_images:
img.close()
del contents

View File

@@ -18,7 +18,7 @@ class S3Adapter:
@asynccontextmanager
async def _get_client(self):
async with self.session.client(
async with self.session.client( # type: ignore[reportGeneralTypeIssues]
"s3",
endpoint_url=self.endpoint_url,
aws_access_key_id=self.aws_access_key_id,
@@ -56,6 +56,21 @@ class S3Adapter:
print(f"Error downloading from S3: {e}")
return None
async def stream_file(self, object_name: str, chunk_size: int = 65536):
"""Streams a file from S3 yielding chunks. Memory-efficient for large files."""
try:
async with self._get_client() as client:
response = await client.get_object(Bucket=self.bucket_name, Key=object_name)
# aioboto3 Body is an aiohttp StreamReader wrapper
body = response['Body']
data = await body.read()
# Yield in chunks to avoid holding entire response in StreamingResponse buffer
for i in range(0, len(data), chunk_size):
yield data[i:i + chunk_size]
except ClientError as e:
print(f"Error streaming from S3: {e}")
return
async def delete_file(self, object_name: str):
"""Deletes a file from S3."""
try:

43
aiws.py
View File

@@ -43,6 +43,8 @@ from api.endpoints.auth import router as api_auth_router
from api.endpoints.admin import router as api_admin_router
from api.endpoints.album_router import router as api_album_router
from api.endpoints.project_router import router as project_api_router
from api.endpoints.idea_router import router as idea_api_router
from api.endpoints.post_router import router as post_api_router
load_dotenv()
logger = logging.getLogger(__name__)
@@ -63,6 +65,8 @@ def setup_logging():
# --- ИНИЦИАЛИЗАЦИЯ ЗАВИСИМОСТЕЙ ---
if BOT_TOKEN is None:
raise ValueError("BOT_TOKEN is not set")
bot = Bot(token=BOT_TOKEN, default=DefaultBotProperties(parse_mode=ParseMode.HTML))
# Клиент БД создаем глобально, чтобы он был доступен и боту (Storage), и API
@@ -82,8 +86,12 @@ s3_adapter = S3Adapter(
)
dao = DAO(mongo_client, s3_adapter) # Главный DAO для бота
if GEMINI_API_KEY is None:
raise ValueError("GEMINI_API_KEY is not set")
gemini = GoogleAdapter(api_key=GEMINI_API_KEY)
generation_service = GenerationService(dao, gemini, bot)
if bot is None:
raise ValueError("bot is not set")
generation_service = GenerationService(dao=dao, gemini=gemini, s3_adapter=s3_adapter, bot=bot)
album_service = AlbumService(dao)
# Dispatcher
@@ -120,6 +128,18 @@ assets_router.message.middleware(AuthMiddleware(repo=users_repo, admin_id=ADMIN_
gen_router.message.middleware(AlbumMiddleware(latency=0.8))
async def start_scheduler(service: GenerationService):
    """Background maintenance loop for the generation service.

    Every 60 seconds: cancels stale (stuck) generations and purges data
    older than two days. Exits cleanly when the task is cancelled.
    """
    while True:
        try:
            logger.info("Running scheduler for stacked generation killing")
            await service.cleanup_stale_generations()
            await service.cleanup_old_data(days=2)
        except asyncio.CancelledError:
            # Shutdown requested while a cleanup call was in flight — stop looping.
            break
        except Exception as e:
            # Keep the loop alive on transient failures; just record them.
            logger.error(f"Scheduler error: {e}")
        await asyncio.sleep(60)  # Check every 60 seconds
# --- LIFESPAN (Запуск FastAPI + Bot) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
@@ -151,17 +171,28 @@ async def lifespan(app: FastAPI):
# )
# print("🤖 Bot polling started")
# 3. ЗАПУСК ШЕДУЛЕРА
scheduler_task = asyncio.create_task(start_scheduler(generation_service))
print("⏰ Scheduler started")
yield
# --- SHUTDOWN ---
print("🛑 Shutting down...")
# 3. Остановка бота
polling_task.cancel()
# 4. Остановка шедулера
scheduler_task.cancel()
try:
await polling_task
await scheduler_task
except asyncio.CancelledError:
print("🤖 Bot polling stopped")
print("⏰ Scheduler stopped")
# 3. Остановка бота
# polling_task.cancel()
# try:
# await polling_task
# except asyncio.CancelledError:
# print("🤖 Bot polling stopped")
# 4. Отключение БД
# Обычно Motor закрывать не обязательно при выходе, но хорошим тоном считается
@@ -188,6 +219,8 @@ app.include_router(api_char_router)
app.include_router(api_gen_router)
app.include_router(api_album_router)
app.include_router(project_api_router)
app.include_router(idea_api_router)
app.include_router(post_api_router)
# Prometheus Metrics (Instrument after all routers are added)
Instrumentator(

View File

@@ -5,6 +5,7 @@ from motor.motor_asyncio import AsyncIOMotorClient
from adapters.google_adapter import GoogleAdapter
from api.service.generation_service import GenerationService
from repos.dao import DAO
from api.service.album_service import AlbumService
# ... ваши импорты ...
@@ -45,7 +46,20 @@ def get_generation_service(
) -> GenerationService:
return GenerationService(dao, gemini, s3_adapter, bot)
from api.service.idea_service import IdeaService
def get_idea_service(dao: DAO = Depends(get_dao)) -> IdeaService:
    """FastAPI dependency provider: builds an IdeaService over the shared DAO."""
    return IdeaService(dao)
from fastapi import Header
async def get_project_id(x_project_id: Optional[str] = Header(None, alias="X-Project-ID")) -> Optional[str]:
    """FastAPI dependency: reads the optional X-Project-ID request header (None when absent)."""
    return x_project_id
async def get_album_service(dao: DAO = Depends(get_dao)) -> AlbumService:
    """FastAPI dependency provider: builds an AlbumService over the shared DAO."""
    return AlbumService(dao)
from api.service.post_service import PostService
def get_post_service(dao: DAO = Depends(get_dao)) -> PostService:
    """FastAPI dependency provider: builds a PostService over the shared DAO."""
    return PostService(dao)

View File

@@ -23,7 +23,7 @@ async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)], repo:
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
username: str | None = payload.get("sub")
if username is None:
raise credentials_exception
except JWTError:

View File

@@ -1,10 +1,13 @@
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, status, Request
from fastapi import APIRouter, HTTPException, status, Request
from pydantic import BaseModel
from api.models.GenerationRequest import GenerationResponse
from models.Album import Album
from repos.dao import DAO
from api.dependency import get_album_service
from api.service.album_service import AlbumService
router = APIRouter(prefix="/api/albums", tags=["Albums"])

View File

@@ -9,7 +9,7 @@ from pymongo import MongoClient
from starlette import status
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import Response, JSONResponse
from starlette.responses import Response, JSONResponse, StreamingResponse
from adapters.s3_adapter import S3Adapter
from api.models.AssetDTO import AssetsResponse, AssetResponse
@@ -33,27 +33,46 @@ async def get_asset(
asset_id: str,
request: Request,
thumbnail: bool = False,
dao: DAO = Depends(get_dao)
dao: DAO = Depends(get_dao),
s3_adapter: S3Adapter = Depends(get_s3_adapter),
) -> Response:
logger.debug(f"get_asset called for ID: {asset_id}, thumbnail={thumbnail}")
asset = await dao.assets.get_asset(asset_id)
# 2. Проверка на существование
# Загружаем только метаданные (без data/thumbnail bytes)
asset = await dao.assets.get_asset(asset_id, with_data=False)
if not asset:
raise HTTPException(status_code=404, detail="Asset not found")
headers = {
# Кэшировать на 1 год (31536000 сек)
"Cache-Control": "public, max-age=31536000, immutable"
}
content = asset.data
media_type = "image/png" # Default, or detect
# Thumbnail: маленький, можно грузить в RAM
if thumbnail:
if asset.minio_thumbnail_object_name and s3_adapter:
thumb_bytes = await s3_adapter.get_file(asset.minio_thumbnail_object_name)
if thumb_bytes:
return Response(content=thumb_bytes, media_type="image/jpeg", headers=headers)
# Fallback: thumbnail in DB
if asset.thumbnail:
return Response(content=asset.thumbnail, media_type="image/jpeg", headers=headers)
# No thumbnail available — fall through to main content
if thumbnail and asset.thumbnail:
content = asset.thumbnail
media_type = "image/jpeg"
# Main content: стримим из S3 без загрузки в RAM
if asset.minio_object_name and s3_adapter:
content_type = "image/png"
# if asset.content_type == AssetContentType.VIDEO:
# content_type = "video/mp4"
return StreamingResponse(
s3_adapter.stream_file(asset.minio_object_name),
media_type=content_type,
headers=headers,
)
return Response(content=content, media_type=media_type, headers=headers)
# Fallback: data stored in DB (legacy)
if asset.data:
return Response(content=asset.data, media_type="image/png", headers=headers)
raise HTTPException(status_code=404, detail="Asset data not found")
@router.delete("/orphans", dependencies=[Depends(get_current_user)])
async def delete_orphan_assets_from_minio(

View File

@@ -8,7 +8,7 @@ from api import service
from api.dependency import get_generation_service, get_project_id, get_dao
from repos.dao import DAO
from api.models.GenerationRequest import GenerationResponse, GenerationRequest, GenerationsResponse, PromptResponse, PromptRequest
from api.models.GenerationRequest import GenerationResponse, GenerationRequest, GenerationsResponse, PromptResponse, PromptRequest, GenerationGroupResponse
from api.service.generation_service import GenerationService
from models.Generation import Generation
@@ -68,12 +68,12 @@ async def get_generations(character_id: Optional[str] = None, limit: int = 10, o
return await generation_service.get_generations(character_id, limit=limit, offset=offset, user_id=user_id_filter, project_id=project_id)
@router.post("/_run", response_model=GenerationResponse)
@router.post("/_run", response_model=GenerationGroupResponse)
async def post_generation(generation: GenerationRequest, request: Request,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao)) -> GenerationResponse:
dao: DAO = Depends(get_dao)) -> GenerationGroupResponse:
logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}")
if project_id:
@@ -85,16 +85,6 @@ async def post_generation(generation: GenerationRequest, request: Request,
return await generation_service.create_generation_task(generation, user_id=str(current_user.get("_id")))
@router.get("/{generation_id}", response_model=GenerationResponse)
async def get_generation(generation_id: str,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)) -> GenerationResponse:
logger.debug(f"get_generation called for ID: {generation_id}")
gen = await generation_service.get_generation(generation_id)
if gen and gen.created_by != str(current_user["_id"]):
raise HTTPException(status_code=403, detail="Access denied")
return gen
@router.get("/running")
async def get_running_generations(request: Request,
@@ -113,6 +103,27 @@ async def get_running_generations(request: Request,
return await generation_service.get_running_generations(user_id=user_id_filter, project_id=project_id)
@router.get("/group/{group_id}", response_model=GenerationGroupResponse)
async def get_generation_group(group_id: str,
                               generation_service: GenerationService = Depends(get_generation_service),
                               current_user: dict = Depends(get_current_user)):
    """Return all generations belonging to one generation group.

    Fix: only the requester's own generations are returned, mirroring the
    ownership check in get_generation — previously any authenticated user
    could enumerate another user's generations via a guessed group id.
    """
    logger.info(f"get_generation_group called for group_id: {group_id}")
    generations = await generation_service.dao.generations.get_generations_by_group(group_id)
    requester_id = str(current_user["_id"])
    gen_responses = [
        GenerationResponse(**gen.model_dump())
        for gen in generations
        if gen.created_by == requester_id
    ]
    return GenerationGroupResponse(generation_group_id=group_id, generations=gen_responses)
@router.get("/{generation_id}", response_model=GenerationResponse)
async def get_generation(generation_id: str,
                         generation_service: GenerationService = Depends(get_generation_service),
                         current_user: dict = Depends(get_current_user)) -> GenerationResponse:
    """Fetch a single generation by id.

    Raises:
        404 when the generation does not exist.
        403 when it was created by a different user.
    """
    logger.debug(f"get_generation called for ID: {generation_id}")
    gen = await generation_service.get_generation(generation_id)
    # Fix: a missing generation previously fell through and returned None,
    # which failed response_model validation and surfaced as a 500.
    if gen is None:
        raise HTTPException(status_code=404, detail="Generation not found")
    if gen.created_by != str(current_user["_id"]):
        raise HTTPException(status_code=403, detail="Access denied")
    return gen
@router.post("/import", response_model=GenerationResponse)

View File

@@ -0,0 +1,103 @@
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Body
from api.dependency import get_idea_service, get_project_id, get_generation_service
from api.endpoints.auth import get_current_user
from api.service.idea_service import IdeaService
from api.service.generation_service import GenerationService
from models.Idea import Idea
from api.models.GenerationRequest import GenerationResponse, GenerationsResponse
from api.models.IdeaRequest import IdeaCreateRequest, IdeaUpdateRequest, IdeaResponse
router = APIRouter(prefix="/api/ideas", tags=["ideas"])
@router.post("", response_model=Idea)
async def create_idea(
    request: IdeaCreateRequest,
    project_id: Optional[str] = Depends(get_project_id),
    current_user: dict = Depends(get_current_user),
    idea_service: IdeaService = Depends(get_idea_service)
):
    """Create a new idea.

    The owning project is taken from the X-Project-ID header when present,
    otherwise from the request body.
    """
    effective_project = project_id if project_id else request.project_id
    owner_id = str(current_user["_id"])
    return await idea_service.create_idea(request.name, request.description, effective_project, owner_id)
@router.get("", response_model=List[IdeaResponse])
async def get_ideas(
    project_id: Optional[str] = Depends(get_project_id),
    limit: int = 20,
    offset: int = 0,
    current_user: dict = Depends(get_current_user),
    idea_service: IdeaService = Depends(get_idea_service)
):
    """List the current user's ideas, optionally scoped to a project (paginated)."""
    requester_id = str(current_user["_id"])
    return await idea_service.get_ideas(project_id, requester_id, limit, offset)
@router.get("/{idea_id}", response_model=Idea)
async def get_idea(
    idea_id: str,
    idea_service: IdeaService = Depends(get_idea_service)
):
    """Fetch one idea by id, or 404.

    NOTE(review): no ownership/auth check here, unlike the list endpoint —
    confirm this is intentional.
    """
    found = await idea_service.get_idea(idea_id)
    if not found:
        raise HTTPException(status_code=404, detail="Idea not found")
    return found
@router.put("/{idea_id}", response_model=Idea)
async def update_idea(
    idea_id: str,
    request: IdeaUpdateRequest,
    idea_service: IdeaService = Depends(get_idea_service)
):
    """Update an idea's name and/or description; 404 when it does not exist."""
    updated = await idea_service.update_idea(idea_id, request.name, request.description)
    if not updated:
        raise HTTPException(status_code=404, detail="Idea not found")
    return updated
@router.delete("/{idea_id}")
async def delete_idea(
    idea_id: str,
    idea_service: IdeaService = Depends(get_idea_service)
):
    """Delete an idea; 404 when it is missing or deletion failed."""
    if not await idea_service.delete_idea(idea_id):
        raise HTTPException(status_code=404, detail="Idea not found or could not be deleted")
    return {"status": "success"}
@router.get("/{idea_id}/generations", response_model=GenerationsResponse)
async def get_idea_generations(
    idea_id: str,
    limit: int = 50,
    offset: int = 0,
    generation_service: GenerationService = Depends(get_generation_service)
):
    """List generations linked to an idea (paginated).

    Delegates to GenerationService.get_generations, which accepts an
    idea_id filter.
    """
    # Fix: removed leftover stream-of-consciousness development notes that
    # shipped in this body; the idea_id filter they questioned is in place.
    return await generation_service.get_generations(idea_id=idea_id, limit=limit, offset=offset)
@router.post("/{idea_id}/generations/{generation_id}")
async def add_generation_to_idea(
    idea_id: str,
    generation_id: str,
    idea_service: IdeaService = Depends(get_idea_service)
):
    """Link an existing generation to an idea; 404 when either is missing."""
    linked = await idea_service.add_generation_to_idea(idea_id, generation_id)
    if not linked:
        raise HTTPException(status_code=404, detail="Idea or Generation not found")
    return {"status": "success"}
@router.delete("/{idea_id}/generations/{generation_id}")
async def remove_generation_from_idea(
    idea_id: str,
    generation_id: str,
    idea_service: IdeaService = Depends(get_idea_service)
):
    """Unlink a generation from an idea; 404 when either is missing."""
    unlinked = await idea_service.remove_generation_from_idea(idea_id, generation_id)
    if not unlinked:
        raise HTTPException(status_code=404, detail="Idea or Generation not found")
    return {"status": "success"}

View File

@@ -0,0 +1,99 @@
from typing import List, Optional
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException
from api.dependency import get_post_service, get_project_id
from api.endpoints.auth import get_current_user
from api.service.post_service import PostService
from api.models.PostRequest import PostCreateRequest, PostUpdateRequest, AddGenerationsRequest
from models.Post import Post
router = APIRouter(prefix="/api/posts", tags=["posts"])
@router.post("", response_model=Post)
async def create_post(
    request: PostCreateRequest,
    project_id: Optional[str] = Depends(get_project_id),
    current_user: dict = Depends(get_current_user),
    post_service: PostService = Depends(get_post_service),
):
    """Create a post linked to zero or more generations.

    The owning project is taken from the X-Project-ID header when present,
    otherwise from the request body.
    """
    effective_project = project_id if project_id else request.project_id
    return await post_service.create_post(
        date=request.date,
        topic=request.topic,
        generation_ids=request.generation_ids,
        project_id=effective_project,
        user_id=str(current_user["_id"]),
    )
@router.get("", response_model=List[Post])
async def get_posts(
    project_id: Optional[str] = Depends(get_project_id),
    limit: int = 200,
    offset: int = 0,
    date_from: Optional[datetime] = None,
    date_to: Optional[datetime] = None,
    current_user: dict = Depends(get_current_user),
    post_service: PostService = Depends(get_post_service),
):
    """List the current user's posts, optionally windowed by date (paginated)."""
    requester_id = str(current_user["_id"])
    return await post_service.get_posts(project_id, requester_id, limit, offset, date_from, date_to)
@router.get("/{post_id}", response_model=Post)
async def get_post(
    post_id: str,
    post_service: PostService = Depends(get_post_service),
):
    """Fetch one post by id, or 404.

    NOTE(review): no ownership/auth check here — confirm this is intentional.
    """
    found = await post_service.get_post(post_id)
    if not found:
        raise HTTPException(status_code=404, detail="Post not found")
    return found
@router.put("/{post_id}", response_model=Post)
async def update_post(
    post_id: str,
    request: PostUpdateRequest,
    post_service: PostService = Depends(get_post_service),
):
    """Update a post's date and/or topic; 404 when it does not exist."""
    updated = await post_service.update_post(post_id, date=request.date, topic=request.topic)
    if not updated:
        raise HTTPException(status_code=404, detail="Post not found")
    return updated
@router.delete("/{post_id}")
async def delete_post(
    post_id: str,
    post_service: PostService = Depends(get_post_service),
):
    """Delete a post; 404 when it is missing or deletion failed."""
    if not await post_service.delete_post(post_id):
        raise HTTPException(status_code=404, detail="Post not found or could not be deleted")
    return {"status": "success"}
@router.post("/{post_id}/generations")
async def add_generations(
    post_id: str,
    request: AddGenerationsRequest,
    post_service: PostService = Depends(get_post_service),
):
    """Attach a batch of generations to a post; 404 when the post is missing."""
    attached = await post_service.add_generations(post_id, request.generation_ids)
    if not attached:
        raise HTTPException(status_code=404, detail="Post not found")
    return {"status": "success"}
@router.delete("/{post_id}/generations/{generation_id}")
async def remove_generation(
    post_id: str,
    generation_id: str,
    post_service: PostService = Depends(get_post_service),
):
    """Detach one generation from a post; 404 when not found or not linked."""
    detached = await post_service.remove_generation(post_id, generation_id)
    if not detached:
        raise HTTPException(status_code=404, detail="Post not found or generation not linked")
    return {"status": "success"}

View File

@@ -1,7 +1,7 @@
from datetime import datetime, UTC
from typing import List, Optional
from pydantic import BaseModel
from pydantic import BaseModel, Field
from models.Asset import Asset
from models.Generation import GenerationStatus
@@ -17,6 +17,8 @@ class GenerationRequest(BaseModel):
use_profile_image: bool = True
assets_list: List[str]
project_id: Optional[str] = None
idea_id: Optional[str] = None
count: int = Field(default=1, ge=1, le=10)
class GenerationsResponse(BaseModel):
@@ -45,10 +47,16 @@ class GenerationResponse(BaseModel):
progress: int = 0
cost: Optional[float] = None
created_by: Optional[str] = None
generation_group_id: Optional[str] = None
idea_id: Optional[str] = None
created_at: datetime = datetime.now(UTC)
updated_at: datetime = datetime.now(UTC)
class GenerationGroupResponse(BaseModel):
generation_group_id: str
generations: List[GenerationResponse]
class PromptRequest(BaseModel):
prompt: str

16
api/models/IdeaRequest.py Normal file
View File

@@ -0,0 +1,16 @@
from typing import Optional
from pydantic import BaseModel
from models.Idea import Idea
from api.models.GenerationRequest import GenerationResponse
class IdeaCreateRequest(BaseModel):
    """Request body for creating an idea."""
    name: str
    description: Optional[str] = None
    project_id: Optional[str] = None  # Optional in body if passed via header/dependency
class IdeaUpdateRequest(BaseModel):
    """Request body for partially updating an idea; omitted fields are left unchanged."""
    name: Optional[str] = None
    description: Optional[str] = None
class IdeaResponse(Idea):
    """Idea enriched with its most recent generation, when one exists."""
    last_generation: Optional[GenerationResponse] = None

19
api/models/PostRequest.py Normal file
View File

@@ -0,0 +1,19 @@
from datetime import datetime
from typing import Optional, List
from pydantic import BaseModel
class PostCreateRequest(BaseModel):
    """Request body for creating a post."""
    date: datetime
    topic: str
    # Mutable default is safe here: pydantic copies field defaults per instance.
    generation_ids: List[str] = []
    project_id: Optional[str] = None  # may instead arrive via the X-Project-ID header
class PostUpdateRequest(BaseModel):
    """Request body for partially updating a post; omitted fields are left unchanged."""
    date: Optional[datetime] = None
    topic: Optional[str] = None
class AddGenerationsRequest(BaseModel):
    """Request body for attaching a batch of generations to a post."""
    generation_ids: List[str]

View File

@@ -5,13 +5,14 @@ import base64
from datetime import datetime, UTC
from typing import List, Optional, Tuple, Any, Dict
from io import BytesIO
from uuid import uuid4
import httpx
from aiogram import Bot
from aiogram.types import BufferedInputFile
from adapters.Exception import GoogleGenerationException
from adapters.google_adapter import GoogleAdapter
from api.models.GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse
from api.models.GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse
# Импортируйте ваши модели DAO, Asset, Generation корректно
from models.Asset import Asset, AssetType, AssetContentType
from models.Generation import Generation, GenerationStatus
@@ -21,6 +22,9 @@ from adapters.s3_adapter import S3Adapter
logger = logging.getLogger(__name__)
# Limit concurrent generations to 4
generation_semaphore = asyncio.Semaphore(4)
# --- Вспомогательная функция генерации ---
async def generate_image_task(
@@ -50,16 +54,18 @@ async def generate_image_task(
logger.info(f"generate_image_task completed, received {len(generated_images_io) if generated_images_io else 0} images")
except GoogleGenerationException as e:
raise e
finally:
# Освобождаем входные данные — они больше не нужны
del media_group_bytes
images_bytes = []
if generated_images_io:
for img_io in generated_images_io:
# Читаем байты из BytesIO
img_io.seek(0)
content = img_io.read()
images_bytes.append(content)
# Закрываем поток
images_bytes.append(img_io.read())
img_io.close()
# Освобождаем список BytesIO сразу
del generated_images_io
return images_bytes, metrics
@@ -71,7 +77,7 @@ class GenerationService:
self.bot = bot
async def ask_prompt_assistant(self, prompt: str, assets: List[str] = None) -> str:
async def ask_prompt_assistant(self, prompt: str, assets: list[str] | None = None) -> str:
future_prompt = """You are an prompt-assistant. You improving user-entered prompts for image generation. User may upload reference image too.
I will provide sources prompt entered by user. Understand user needs and generate best variation of prompt.
ANSWER ONLY PROMPT STRING!!! USER_ENTERED_PROMPT: """
@@ -94,10 +100,9 @@ class GenerationService:
return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, images_list=images)
async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[
Generation]:
generations = await self.dao.generations.get_generations(character_id = character_id,limit=limit, offset=offset, created_by=user_id, project_id=project_id)
total_count = await self.dao.generations.count_generations(character_id = character_id, created_by=user_id, project_id=project_id)
async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0, user_id: Optional[str] = None, project_id: Optional[str] = None, idea_id: Optional[str] = None) -> GenerationsResponse:
generations = await self.dao.generations.get_generations(character_id = character_id,limit=limit, offset=offset, created_by=user_id, project_id=project_id, idea_id=idea_id)
total_count = await self.dao.generations.count_generations(character_id = character_id, created_by=user_id, project_id=project_id, idea_id=idea_id)
generations = [GenerationResponse(**gen.model_dump()) for gen in generations]
return GenerationsResponse(generations=generations, total_count=total_count)
@@ -111,29 +116,50 @@ class GenerationService:
async def get_running_generations(self, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING, created_by=user_id, project_id=project_id)
async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None) -> GenerationResponse:
async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationGroupResponse:
count = generation_request.count
if generation_group_id is None:
generation_group_id = str(uuid4())
results = []
for _ in range(count):
gen_response = await self._create_single_generation(generation_request, user_id, generation_group_id)
results.append(gen_response)
return GenerationGroupResponse(generation_group_id=generation_group_id, generations=results)
async def _create_single_generation(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationResponse:
gen_id = None
generation_model = None
try:
generation_model = Generation(**generation_request.model_dump())
generation_model = Generation(**generation_request.model_dump(exclude={'count'}))
if user_id:
generation_model.created_by = user_id
if generation_group_id:
generation_model.generation_group_id = generation_group_id
# Explicitly set idea_id from request if present (already in model_dump, but ensuring clarity)
if generation_request.idea_id:
generation_model.idea_id = generation_request.idea_id
gen_id = await self.dao.generations.create_generation(generation_model)
generation_model.id = gen_id
async def runner(gen):
logger.info(f"Starting background generation task for ID: {gen.id}")
logger.info(f"Generation {gen.id} entered queue (waiting for slot)...")
try:
await self.create_generation(gen)
logger.info(f"Background generation task finished for ID: {gen.id}")
async with generation_semaphore:
logger.info(f"Starting background generation task for ID: {gen.id}")
await self.create_generation(gen)
logger.info(f"Background generation task finished for ID: {gen.id}")
except Exception:
# если генерация уже пошла и упала — пометим FAILED
try:
db_gen = await self.dao.generations.get_generation(gen.id)
db_gen.status = GenerationStatus.FAILED
await self.dao.generations.update_generation(db_gen)
if db_gen is not None:
db_gen.status = GenerationStatus.FAILED
await self.dao.generations.update_generation(db_gen)
except Exception:
logger.exception("Failed to mark generation as FAILED")
logger.exception("create_generation task failed")
@@ -147,8 +173,9 @@ class GenerationService:
if gen_id is not None:
try:
gen = await self.dao.generations.get_generation(gen_id)
gen.status = GenerationStatus.FAILED
await self.dao.generations.update_generation(gen)
if gen is not None:
gen.status = GenerationStatus.FAILED
await self.dao.generations.update_generation(gen)
except Exception:
logger.exception("Failed to mark generation as FAILED in create_generation_task")
raise
@@ -176,9 +203,10 @@ class GenerationService:
if char_info is None:
raise Exception(f"Character ID {generation.linked_character_id} not found")
if generation.use_profile_image:
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
if avatar_asset:
media_group_bytes.append(avatar_asset.data)
if char_info.avatar_asset_id is not None:
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
if avatar_asset and avatar_asset.data:
media_group_bytes.append(avatar_asset.data)
# generation_prompt = generation_prompt.replace("$char_bio_inserted", f"1. CHARACTER BIO (Must be strictly followed): {char_info.character_bio}")
reference_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
@@ -279,7 +307,9 @@ class GenerationService:
# 5. (Опционально) Обновляем запись генерации ссылками на результаты
# Предполагаем, что у модели Generation есть поле result_asset_ids
result_ids = [a.id for a in created_assets]
result_ids = []
for a in created_assets:
result_ids.append(a.id)
generation.result_list = result_ids
generation.status = GenerationStatus.DONE
@@ -444,3 +474,36 @@ class GenerationService:
except Exception as e:
logger.error(f"Error deleting generation {generation_id}: {e}")
return False
async def cleanup_stale_generations(self):
    """
    Cancels generations that have been running for more than 1 hour.
    """
    try:
        stale_count = await self.dao.generations.cancel_stale_generations(timeout_minutes=60)
    except Exception as e:
        # Best effort: a failed sweep is logged, never raised to the scheduler.
        logger.error(f"Error cleaning up stale generations: {e}")
    else:
        if stale_count > 0:
            logger.info(f"Cleaned up {stale_count} stale generations (timeout)")
async def cleanup_old_data(self, days: int = 2):
    """Purge stale data in two stages.

    1. Soft-delete generations older than `days` days, collecting the asset
       ids they referenced.
    2. Soft-delete those assets and hard-delete their backing files from S3.

    Any failure is logged and swallowed so the periodic job keeps running.
    """
    try:
        # Stage 1: soft-delete old generations; the repo hands back the asset ids.
        gen_count, asset_ids = await self.dao.generations.soft_delete_old_generations(days=days)
        if gen_count > 0:
            logger.info(f"Soft-deleted {gen_count} generations older than {days} days. "
                        f"Found {len(asset_ids)} associated asset IDs.")
        # Stage 2: soft-delete the assets and remove their S3 objects.
        if asset_ids:
            purged = await self.dao.assets.soft_delete_and_purge_assets(asset_ids)
            logger.info(f"Purged {purged} assets (soft-deleted + S3 files removed).")
    except Exception as e:
        logger.error(f"Error during old data cleanup: {e}")

View File

@@ -0,0 +1,75 @@
from typing import List, Optional
from datetime import datetime
from repos.dao import DAO
from models.Idea import Idea
class IdeaService:
    """Business logic around Idea entities and the generations linked to them."""

    def __init__(self, dao: DAO):
        self.dao = dao

    async def create_idea(self, name: str, description: Optional[str], project_id: Optional[str], user_id: str) -> Idea:
        """Persist a new idea and return it with its freshly assigned id."""
        new_idea = Idea(name=name, description=description, project_id=project_id, created_by=user_id)
        new_idea.id = await self.dao.ideas.create_idea(new_idea)
        return new_idea

    async def get_ideas(self, project_id: Optional[str], user_id: str, limit: int = 20, offset: int = 0) -> List[dict]:
        """List ideas (raw dicts, including a `last_generation` entry) for the project or personal scope."""
        return await self.dao.ideas.get_ideas(project_id, user_id, limit, offset)

    async def get_idea(self, idea_id: str) -> Optional[Idea]:
        """Fetch a single idea by id, or None when absent."""
        return await self.dao.ideas.get_idea(idea_id)

    async def update_idea(self, idea_id: str, name: Optional[str] = None, description: Optional[str] = None) -> Optional[Idea]:
        """Apply a partial update; returns the updated idea, or None when it does not exist."""
        existing = await self.dao.ideas.get_idea(idea_id)
        if existing is None:
            return None
        if name is not None:
            existing.name = name
        if description is not None:
            existing.description = description
        # NOTE(review): naive local timestamp; other models use UTC-aware datetimes — confirm intended.
        existing.updated_at = datetime.now()
        await self.dao.ideas.update_idea(existing)
        return existing

    async def delete_idea(self, idea_id: str) -> bool:
        """Soft-delete an idea; True on success."""
        return await self.dao.ideas.delete_idea(idea_id)

    async def add_generation_to_idea(self, idea_id: str, generation_id: str) -> bool:
        """Attach a generation to an idea. False when either record is missing."""
        if await self.dao.ideas.get_idea(idea_id) is None:
            return False
        generation = await self.dao.generations.get_generation(generation_id)
        if generation is None:
            return False
        generation.idea_id = idea_id
        generation.updated_at = datetime.now()
        await self.dao.generations.update_generation(generation)
        return True

    async def remove_generation_from_idea(self, idea_id: str, generation_id: str) -> bool:
        """Detach a generation from an idea; only unlinks when it is currently linked to this idea."""
        if await self.dao.ideas.get_idea(idea_id) is None:
            return False
        generation = await self.dao.generations.get_generation(generation_id)
        if generation is None:
            return False
        if generation.idea_id != idea_id:
            return False
        generation.idea_id = None
        generation.updated_at = datetime.now()
        await self.dao.generations.update_generation(generation)
        return True

View File

@@ -0,0 +1,79 @@
from typing import List, Optional
from datetime import datetime, UTC
from repos.dao import DAO
from models.Post import Post
class PostService:
    """Business logic for scheduled posts and the generations attached to them."""

    def __init__(self, dao: DAO):
        self.dao = dao

    async def create_post(
        self,
        date: datetime,
        topic: str,
        generation_ids: List[str],
        project_id: Optional[str],
        user_id: str,
    ) -> Post:
        """Persist a new post and return it with its freshly assigned id."""
        new_post = Post(
            date=date,
            topic=topic,
            generation_ids=generation_ids,
            project_id=project_id,
            created_by=user_id,
        )
        new_post.id = await self.dao.posts.create_post(new_post)
        return new_post

    async def get_post(self, post_id: str) -> Optional[Post]:
        """Fetch a single post by id, or None when absent."""
        return await self.dao.posts.get_post(post_id)

    async def get_posts(
        self,
        project_id: Optional[str],
        user_id: str,
        limit: int = 20,
        offset: int = 0,
        date_from: Optional[datetime] = None,
        date_to: Optional[datetime] = None,
    ) -> List[Post]:
        """List posts for the project or personal scope, optionally bounded by a date range."""
        return await self.dao.posts.get_posts(project_id, user_id, limit, offset, date_from, date_to)

    async def update_post(
        self,
        post_id: str,
        date: Optional[datetime] = None,
        topic: Optional[str] = None,
    ) -> Optional[Post]:
        """Apply a partial update; returns the refreshed post, or None when it does not exist."""
        if await self.dao.posts.get_post(post_id) is None:
            return None
        changes: dict = {"updated_at": datetime.now(UTC)}
        if date is not None:
            changes["date"] = date
        if topic is not None:
            changes["topic"] = topic
        await self.dao.posts.update_post(post_id, changes)
        # Re-read so the caller sees exactly what is now stored.
        return await self.dao.posts.get_post(post_id)

    async def delete_post(self, post_id: str) -> bool:
        """Soft-delete a post; True on success."""
        return await self.dao.posts.delete_post(post_id)

    async def add_generations(self, post_id: str, generation_ids: List[str]) -> bool:
        """Attach generations to a post; False when the post is missing."""
        if await self.dao.posts.get_post(post_id) is None:
            return False
        return await self.dao.posts.add_generations(post_id, generation_ids)

    async def remove_generation(self, post_id: str, generation_id: str) -> bool:
        """Detach one generation from a post; False when the post is missing."""
        if await self.dao.posts.get_post(post_id) is None:
            return False
        return await self.dao.posts.remove_generation(post_id, generation_id)

View File

@@ -30,6 +30,7 @@ class Asset(BaseModel):
tags: List[str] = []
created_by: Optional[str] = None
project_id: Optional[str] = None
is_deleted: bool = False
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))

View File

@@ -35,8 +35,10 @@ class Generation(BaseModel):
output_token_usage: Optional[int] = None
is_deleted: bool = False
album_id: Optional[str] = None
generation_group_id: Optional[str] = None
created_by: Optional[str] = None # Stores User ID (Telegram ID or Web User ObjectId)
project_id: Optional[str] = None
idea_id: Optional[str] = None
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))

13
models/Idea.py Normal file
View File

@@ -0,0 +1,13 @@
from datetime import datetime
from typing import Optional, List
from pydantic import BaseModel, Field
class Idea(BaseModel):
    """A saved content idea, optionally scoped to a project, owned by a user."""
    # Mongo ObjectId as a string; None until the document is inserted.
    id: Optional[str] = None
    name: str = "New Idea"
    description: Optional[str] = None
    # Ideas may belong to a project or be personal (project_id is None).
    project_id: Optional[str] = None
    created_by: str # User ID
    # Soft-delete flag; IdeaRepo.delete_idea flips this instead of removing the document.
    is_deleted: bool = False
    # NOTE(review): these defaults are naive local datetimes, while Post/Asset/Generation
    # use datetime.now(UTC) — confirm whether the inconsistency is intended.
    created_at: datetime = Field(default_factory=datetime.now)
    updated_at: datetime = Field(default_factory=datetime.now)

23
models/Post.py Normal file
View File

@@ -0,0 +1,23 @@
from datetime import datetime, timezone, UTC
from typing import Optional, List
from pydantic import BaseModel, Field, model_validator
class Post(BaseModel):
    """A scheduled post: a dated topic with an ordered set of linked generations."""
    # Mongo ObjectId as a string; None until the document is inserted.
    id: Optional[str] = None
    # Scheduled publication date; normalized to be timezone-aware by the validator below.
    date: datetime
    topic: str
    generation_ids: List[str] = Field(default_factory=list)
    # Posts may belong to a project or be personal (project_id is None).
    project_id: Optional[str] = None
    created_by: str
    # Soft-delete flag; PostRepo.delete_post flips this instead of removing the document.
    is_deleted: bool = False
    created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
    updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
    @model_validator(mode="after")
    def ensure_tz_aware(self):
        """Attach UTC to any naive datetime field so aware/naive values never mix."""
        for field in ("date", "created_at", "updated_at"):
            val = getattr(self, field)
            if val is not None and val.tzinfo is None:
                # Naive values (e.g. loaded from storage) are interpreted as UTC.
                setattr(self, field, val.replace(tzinfo=timezone.utc))
        return self

View File

@@ -1,6 +1,8 @@
from typing import List, Optional
from typing import Any, List, Optional
import logging
from datetime import datetime, UTC
from bson import ObjectId
from uuid import uuid4
from motor.motor_asyncio import AsyncIOMotorClient
from models.Asset import Asset
@@ -19,7 +21,8 @@ class AssetsRepo:
# Main data
if asset.data:
ts = int(asset.created_at.timestamp())
object_name = f"{asset.type.value}/{ts}_{asset.name}"
uid = uuid4().hex[:8]
object_name = f"{asset.type.value}/{ts}_{uid}_{asset.name}"
uploaded = await self.s3.upload_file(object_name, asset.data)
if uploaded:
@@ -32,7 +35,8 @@ class AssetsRepo:
# Thumbnail
if asset.thumbnail:
ts = int(asset.created_at.timestamp())
thumb_name = f"{asset.type.value}/thumbs/{ts}_{asset.name}_thumb.jpg"
uid = uuid4().hex[:8]
thumb_name = f"{asset.type.value}/thumbs/{ts}_{uid}_{asset.name}_thumb.jpg"
uploaded_thumb = await self.s3.upload_file(thumb_name, asset.thumbnail)
if uploaded_thumb:
@@ -47,7 +51,7 @@ class AssetsRepo:
return str(res.inserted_id)
async def get_assets(self, asset_type: Optional[str] = None, limit: int = 10, offset: int = 0, with_data: bool = False, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Asset]:
filter = {}
filter: dict[str, Any]= {"is_deleted": {"$ne": True}}
if asset_type:
filter["type"] = asset_type
args = {}
@@ -134,7 +138,8 @@ class AssetsRepo:
if self.s3:
if asset.data:
ts = int(asset.created_at.timestamp())
object_name = f"{asset.type.value}/{ts}_{asset.name}"
uid = uuid4().hex[:8]
object_name = f"{asset.type.value}/{ts}_{uid}_{asset.name}"
if await self.s3.upload_file(object_name, asset.data):
asset.minio_object_name = object_name
asset.minio_bucket = self.s3.bucket_name
@@ -142,7 +147,8 @@ class AssetsRepo:
if asset.thumbnail:
ts = int(asset.created_at.timestamp())
thumb_name = f"{asset.type.value}/thumbs/{ts}_{asset.name}_thumb.jpg"
uid = uuid4().hex[:8]
thumb_name = f"{asset.type.value}/thumbs/{ts}_{uid}_{asset.name}_thumb.jpg"
if await self.s3.upload_file(thumb_name, asset.thumbnail):
asset.minio_thumbnail_object_name = thumb_name
asset.thumbnail = None
@@ -197,6 +203,61 @@ class AssetsRepo:
res = await self.collection.delete_one({"_id": ObjectId(asset_id)})
return res.deleted_count > 0
async def soft_delete_and_purge_assets(self, asset_ids: List[str]) -> int:
    """
    Soft-deletes assets and hard-deletes their files from S3.
    Returns the number of assets processed.
    """
    if not asset_ids:
        return 0
    # Silently drop malformed ids; only valid ObjectIds reach the query.
    object_ids = [ObjectId(aid) for aid in asset_ids if ObjectId.is_valid(aid)]
    if not object_ids:
        return 0
    # Find the assets that are not yet deleted (projection keeps only the S3 keys).
    cursor = self.collection.find(
        {"_id": {"$in": object_ids}, "is_deleted": {"$ne": True}},
        {"minio_object_name": 1, "minio_thumbnail_object_name": 1}
    )
    purged_count = 0
    ids_to_update = []
    async for doc in cursor:
        ids_to_update.append(doc["_id"])
        # Hard-delete the files from S3 — best effort: failures are logged, never raised.
        if self.s3:
            if doc.get("minio_object_name"):
                try:
                    await self.s3.delete_file(doc["minio_object_name"])
                except Exception as e:
                    logger.error(f"Failed to delete S3 object {doc['minio_object_name']}: {e}")
            if doc.get("minio_thumbnail_object_name"):
                try:
                    await self.s3.delete_file(doc["minio_thumbnail_object_name"])
                except Exception as e:
                    logger.error(f"Failed to delete S3 thumbnail {doc['minio_thumbnail_object_name']}: {e}")
        purged_count += 1
    # Soft-delete in Mongo and clear the now-dangling S3 references in one pass.
    if ids_to_update:
        await self.collection.update_many(
            {"_id": {"$in": ids_to_update}},
            {
                "$set": {
                    "is_deleted": True,
                    "minio_object_name": None,
                    "minio_thumbnail_object_name": None,
                    "updated_at": datetime.now(UTC)
                }
            }
        )
    return purged_count
async def migrate_to_minio(self) -> dict:
"""Переносит данные и thumbnails из Mongo в MinIO."""
if not self.s3:
@@ -216,7 +277,8 @@ class AssetsRepo:
created_at = doc.get("created_at")
ts = int(created_at.timestamp()) if created_at else 0
object_name = f"{type_}/{ts}_{asset_id}_{name}"
uid = uuid4().hex[:8]
object_name = f"{type_}/{ts}_{uid}_{asset_id}_{name}"
if await self.s3.upload_file(object_name, data):
await self.collection.update_one(
{"_id": asset_id},
@@ -243,7 +305,8 @@ class AssetsRepo:
created_at = doc.get("created_at")
ts = int(created_at.timestamp()) if created_at else 0
thumb_name = f"{type_}/thumbs/{ts}_{asset_id}_{name}_thumb.jpg"
uid = uuid4().hex[:8]
thumb_name = f"{type_}/thumbs/{ts}_{uid}_{asset_id}_{name}_thumb.jpg"
if await self.s3.upload_file(thumb_name, thumb):
await self.collection.update_one(
{"_id": asset_id},

View File

@@ -6,6 +6,8 @@ from repos.generation_repo import GenerationRepo
from repos.user_repo import UsersRepo
from repos.albums_repo import AlbumsRepo
from repos.project_repo import ProjectRepo
from repos.idea_repo import IdeaRepo
from repos.post_repo import PostRepo
from typing import Optional
@@ -19,3 +21,5 @@ class DAO:
self.albums = AlbumsRepo(client, db_name)
self.projects = ProjectRepo(client, db_name)
self.users = UsersRepo(client, db_name)
self.ideas = IdeaRepo(client, db_name)
self.posts = PostRepo(client, db_name)

View File

@@ -1,4 +1,5 @@
from typing import Optional, List
from typing import Any, Optional, List
from datetime import datetime, timedelta, UTC
from PIL.ImageChops import offset
from bson import ObjectId
@@ -16,7 +17,7 @@ class GenerationRepo:
res = await self.collection.insert_one(generation.model_dump())
return str(res.inserted_id)
async def get_generation(self, generation_id: str) -> Optional[Generation]:
async def get_generation(self, generation_id: str) -> Generation | None:
res = await self.collection.find_one({"_id": ObjectId(generation_id)})
if res is None:
return None
@@ -25,20 +26,29 @@ class GenerationRepo:
return Generation(**res)
async def get_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
limit: int = 10, offset: int = 10, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
limit: int = 10, offset: int = 0, created_by: Optional[str] = None, project_id: Optional[str] = None, idea_id: Optional[str] = None) -> List[Generation]:
filter = {"is_deleted": False}
filter: dict[str, Any] = {"is_deleted": False}
if character_id is not None:
filter["linked_character_id"] = character_id
if status is not None:
filter["status"] = status
if created_by is not None:
filter["created_by"] = created_by
filter["project_id"] = None
# If filtering by created_by user (e.g. "My Generations"), we typically imply personal scope if project_id is None.
# But if project_id is passed, we filter by that.
if project_id is None:
filter["project_id"] = None
if project_id is not None:
filter["project_id"] = project_id
if idea_id is not None:
filter["idea_id"] = idea_id
res = await self.collection.find(filter).sort("created_at", -1).skip(
# If fetching for an idea, sort by created_at ascending (cronological)
# Otherwise typically descending (newest first)
sort_order = 1 if idea_id else -1
res = await self.collection.find(filter).sort("created_at", sort_order).skip(
offset).limit(limit).to_list(None)
generations: List[Generation] = []
for generation in res:
@@ -47,7 +57,7 @@ class GenerationRepo:
return generations
async def count_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
album_id: Optional[str] = None, created_by: Optional[str] = None, project_id: Optional[str] = None) -> int:
album_id: Optional[str] = None, created_by: Optional[str] = None, project_id: Optional[str] = None, idea_id: Optional[str] = None) -> int:
args = {}
if character_id is not None:
args["linked_character_id"] = character_id
@@ -57,6 +67,10 @@ class GenerationRepo:
args["created_by"] = created_by
if project_id is not None:
args["project_id"] = project_id
if idea_id is not None:
args["idea_id"] = idea_id
if album_id is not None:
args["album_id"] = album_id
return await self.collection.count_documents(args)
async def get_generations_by_ids(self, generation_ids: List[str]) -> List[Generation]:
@@ -77,3 +91,62 @@ class GenerationRepo:
async def update_generation(self, generation: Generation) -> None:
    """Overwrite the stored document's fields with the model's current values.

    Keyed by generation.id (must be a valid ObjectId string). The previous
    version assigned the UpdateResult to an unused local and carried a stray
    trailing comma in the signature; both are removed.
    """
    await self.collection.update_one({"_id": ObjectId(generation.id)}, {"$set": generation.model_dump()})
async def get_generations_by_group(self, group_id: str) -> List[Generation]:
    """Return all non-deleted generations in a group, oldest first."""
    docs = await self.collection.find(
        {"generation_group_id": group_id, "is_deleted": False}
    ).sort("created_at", 1).to_list(None)

    def _to_model(doc) -> Generation:
        # Mongo's _id becomes the model's string id.
        doc["id"] = str(doc.pop("_id"))
        return Generation(**doc)

    return [_to_model(doc) for doc in docs]
async def cancel_stale_generations(self, timeout_minutes: int = 5) -> int:
    """Mark every RUNNING generation older than the timeout as FAILED.

    Returns the number of documents updated.
    """
    cutoff = datetime.now(UTC) - timedelta(minutes=timeout_minutes)
    stale_filter = {
        "status": GenerationStatus.RUNNING,
        "created_at": {"$lt": cutoff},
    }
    failure_update = {
        "$set": {
            "status": GenerationStatus.FAILED,
            "failed_reason": "Timeout: Execution time limit exceeded",
            "updated_at": datetime.now(UTC),
        }
    }
    res = await self.collection.update_many(stale_filter, failure_update)
    return res.modified_count
async def soft_delete_old_generations(self, days: int = 2) -> tuple[int, List[str]]:
    """
    Soft-deletes generations older than N days.
    Returns (number deleted, list of asset IDs to clean up).
    """
    cutoff_time = datetime.now(UTC) - timedelta(days=days)
    filter_query = {
        "is_deleted": False,
        # Only finished work is reaped; pending/running generations are left alone.
        "status": {"$in": [GenerationStatus.DONE, GenerationStatus.FAILED]},
        "created_at": {"$lt": cutoff_time}
    }
    # First pass: collect the asset IDs referenced by the generations about to be deleted.
    asset_ids: List[str] = []
    cursor = self.collection.find(filter_query, {"result_list": 1, "assets_list": 1})
    async for doc in cursor:
        asset_ids.extend(doc.get("result_list", []))
        asset_ids.extend(doc.get("assets_list", []))
    # Second pass: soft-delete (same filter, so at most the scanned documents are flagged).
    res = await self.collection.update_many(
        filter_query,
        {
            "$set": {
                "is_deleted": True,
                "updated_at": datetime.now(UTC)
            }
        }
    )
    # Deduplicate — the same asset may appear in both result_list and assets_list.
    unique_asset_ids = list(set(asset_ids))
    return res.modified_count, unique_asset_ids

82
repos/idea_repo.py Normal file
View File

@@ -0,0 +1,82 @@
from typing import Optional, List
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient
from models.Idea import Idea
class IdeaRepo:
    """Mongo persistence layer for Idea documents (collection: "ideas")."""
    def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
        self.collection = client[db_name]["ideas"]
    async def create_idea(self, idea: Idea) -> str:
        """Insert an idea and return the new document id as a string."""
        res = await self.collection.insert_one(idea.model_dump())
        return str(res.inserted_id)
    async def get_idea(self, idea_id: str) -> Optional[Idea]:
        """Fetch an idea by id; None for invalid ids or misses.

        NOTE(review): does not filter is_deleted, so soft-deleted ideas are
        still returned — confirm this is intended (callers appear to rely on it).
        """
        if not ObjectId.is_valid(idea_id):
            return None
        res = await self.collection.find_one({"_id": ObjectId(idea_id)})
        if res:
            res["id"] = str(res.pop("_id"))
            return Idea(**res)
        return None
    async def get_ideas(self, project_id: Optional[str], user_id: str, limit: int = 20, offset: int = 0) -> List[dict]:
        """List non-deleted ideas as raw dicts, each enriched with its most recent
        linked generation under the "last_generation" key (absent/None when the
        idea has no generations).

        Scope: ideas of the project when project_id is set; otherwise the user's
        personal ideas (project_id is None). Sorted by updated_at descending.
        """
        if project_id:
            match_stage = {"project_id": project_id, "is_deleted": False}
        else:
            match_stage = {"created_by": user_id, "project_id": None, "is_deleted": False}
        pipeline = [
            {"$match": match_stage},
            {"$sort": {"updated_at": -1}},
            {"$skip": offset},
            {"$limit": limit},
            # Add string id field for lookup (generations store idea_id as a string)
            {"$addFields": {"str_id": {"$toString": "$_id"}}},
            # Lookup generations
            {
                "$lookup": {
                    "from": "generations",
                    "let": {"idea_id": "$str_id"},
                    "pipeline": [
                        {"$match": {"$expr": {"$eq": ["$idea_id", "$$idea_id"]}}},
                        {"$sort": {"created_at": -1}}, # Ensure we get the latest
                        {"$limit": 1}
                    ],
                    "as": "generations"
                }
            },
            # Unwind generations array (preserve ideas without generations)
            {"$unwind": {"path": "$generations", "preserveNullAndEmptyArrays": True}},
            # Rename for clarity
            {"$addFields": {
                "last_generation": "$generations",
                "id": "$str_id"
            }},
            {"$project": {"generations": 0, "str_id": 0, "_id": 0}}
        ]
        return await self.collection.aggregate(pipeline).to_list(None)
    async def delete_idea(self, idea_id: str) -> bool:
        """Soft-delete an idea; True when a document was modified."""
        if not ObjectId.is_valid(idea_id):
            return False
        res = await self.collection.update_one(
            {"_id": ObjectId(idea_id)},
            {"$set": {"is_deleted": True}}
        )
        return res.modified_count > 0
    async def update_idea(self, idea: Idea) -> bool:
        """Overwrite the stored document with the model's fields (the "id" key is stripped first)."""
        if not idea.id or not ObjectId.is_valid(idea.id):
            return False
        idea_dict = idea.model_dump()
        if "id" in idea_dict:
            del idea_dict["id"]
        res = await self.collection.update_one(
            {"_id": ObjectId(idea.id)},
            {"$set": idea_dict}
        )
        return res.modified_count > 0

97
repos/post_repo.py Normal file
View File

@@ -0,0 +1,97 @@
from typing import List, Optional
from datetime import datetime
import logging
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient
from models.Post import Post
logger = logging.getLogger(__name__)
class PostRepo:
    """Mongo persistence layer for Post documents (collection: "posts")."""

    def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
        self.collection = client[db_name]["posts"]

    async def create_post(self, post: Post) -> str:
        """Insert a post and return the new document id as a string."""
        inserted = await self.collection.insert_one(post.model_dump())
        return str(inserted.inserted_id)

    async def get_post(self, post_id: str) -> Optional[Post]:
        """Fetch a non-deleted post by id; None for invalid ids or misses."""
        if not ObjectId.is_valid(post_id):
            return None
        doc = await self.collection.find_one({"_id": ObjectId(post_id), "is_deleted": False})
        if doc is None:
            return None
        doc["id"] = str(doc.pop("_id"))
        return Post(**doc)

    async def get_posts(
        self,
        project_id: Optional[str],
        user_id: str,
        limit: int = 20,
        offset: int = 0,
        date_from: Optional[datetime] = None,
        date_to: Optional[datetime] = None,
    ) -> List[Post]:
        """List posts newest-first by date.

        Scope: the project's posts when project_id is set; otherwise the
        user's personal posts (project_id is None). Optional inclusive date
        bounds narrow the result.
        """
        if project_id:
            query = {"project_id": project_id, "is_deleted": False}
        else:
            query = {"created_by": user_id, "project_id": None, "is_deleted": False}
        date_bounds = {}
        if date_from:
            date_bounds["$gte"] = date_from
        if date_to:
            date_bounds["$lte"] = date_to
        if date_bounds:
            query["date"] = date_bounds
        cursor = self.collection.find(query).sort("date", -1).skip(offset).limit(limit)
        results: List[Post] = []
        async for doc in cursor:
            doc["id"] = str(doc.pop("_id"))
            results.append(Post(**doc))
        return results

    async def update_post(self, post_id: str, data: dict) -> bool:
        """$set the given fields on a post; True when a document was modified."""
        if not ObjectId.is_valid(post_id):
            return False
        outcome = await self.collection.update_one(
            {"_id": ObjectId(post_id)},
            {"$set": data},
        )
        return outcome.modified_count > 0

    async def delete_post(self, post_id: str) -> bool:
        """Soft-delete by flipping is_deleted; True when a document was modified."""
        if not ObjectId.is_valid(post_id):
            return False
        outcome = await self.collection.update_one(
            {"_id": ObjectId(post_id)},
            {"$set": {"is_deleted": True}},
        )
        return outcome.modified_count > 0

    async def add_generations(self, post_id: str, generation_ids: List[str]) -> bool:
        """Append generation ids without duplicates ($addToSet); True when the document changed."""
        if not ObjectId.is_valid(post_id):
            return False
        outcome = await self.collection.update_one(
            {"_id": ObjectId(post_id)},
            {"$addToSet": {"generation_ids": {"$each": generation_ids}}},
        )
        return outcome.modified_count > 0

    async def remove_generation(self, post_id: str, generation_id: str) -> bool:
        """Remove one generation id ($pull); True when the document changed."""
        if not ObjectId.is_valid(post_id):
            return False
        outcome = await self.collection.update_one(
            {"_id": ObjectId(post_id)},
            {"$pull": {"generation_ids": generation_id}},
        )
        return outcome.modified_count > 0

97
tests/test_idea.py Normal file
View File

@@ -0,0 +1,97 @@
import asyncio
import os
from dotenv import load_dotenv
from motor.motor_asyncio import AsyncIOMotorClient
from bson import ObjectId
# Import from project root (requires PYTHONPATH=.)
from api.service.idea_service import IdeaService
from repos.dao import DAO
from models.Idea import Idea
from models.Generation import Generation, GenerationStatus
from models.enums import AspectRatios, Quality
load_dotenv()
MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://localhost:27017")
DB_NAME = os.getenv("DB_NAME", "bot_db")
print(f"Connecting to MongoDB: {MONGO_HOST}, DB: {DB_NAME}")
async def test_idea_flow():
    """End-to-end check of IdeaService CRUD + generation linking against a live MongoDB."""
    client = AsyncIOMotorClient(MONGO_HOST)
    dao = DAO(client, db_name=DB_NAME)
    service = IdeaService(dao)
    # 1. Create an Idea
    print("Creating idea...")
    user_id = "test_user_123"
    project_id = "test_project_abc"
    idea = await service.create_idea("My Test Idea", "Initial Description", project_id, user_id)
    print(f"Idea created: {idea.id} - {idea.name}")
    # 2. Update Idea
    print("Updating idea...")
    updated_idea = await service.update_idea(idea.id, description="Updated description")
    # Guard: update_idea returns None when the idea is missing; the old code
    # would have crashed with AttributeError before reporting the failure.
    if updated_idea is not None and updated_idea.description == "Updated description":
        print(f"Idea updated: {updated_idea.description}")
        print("✅ Idea update successful")
    else:
        print("❌ Idea update FAILED")
    # 3. Add Generation linked to Idea
    print("Creating generation linked to idea...")
    gen = Generation(
        prompt="idea generation 1",
        # idea_id=idea.id, <-- Intentionally NOT linking initially to test linking method
        project_id=project_id,
        created_by=user_id,
        aspect_ratio=AspectRatios.NINESIXTEEN,
        quality=Quality.ONEK,
        assets_list=[]
    )
    gen_id = await dao.generations.create_generation(gen)
    print(f"Created generation: {gen_id}")
    # Link generation to idea
    print("Linking generation to idea...")
    success = await service.add_generation_to_idea(idea.id, gen_id)
    if success:
        print("✅ Linking successful")
    else:
        print("❌ Linking FAILED")
    # Debug: Check if generation was saved with idea_id (guard against a missing doc)
    saved_gen = await dao.generations.collection.find_one({"_id": ObjectId(gen_id)})
    print(f"DEBUG: Saved Generation in DB idea_id: {saved_gen.get('idea_id') if saved_gen else None}")
    # 4. Fetch Generations for Idea (Verify filtering and ordering)
    print("Fetching generations for idea...")
    gens = await service.dao.generations.get_generations(idea_id=idea.id)  # using repo directly as service might return wrapper
    print(f"Found {len(gens)} generations in idea")
    if len(gens) == 1 and gens[0].id == gen_id:
        print("✅ Generation retrieval successful")
    else:
        print("❌ Generation retrieval FAILED")
    # 5. Fetch Ideas for Project
    # BUG FIX: IdeaService.get_ideas(project_id, user_id, ...) requires user_id;
    # the original call omitted it and raised TypeError at runtime.
    ideas = await service.get_ideas(project_id, user_id)
    print(f"Found {len(ideas)} ideas for project")
    # Cleaning up
    print("Cleaning up...")
    await service.delete_idea(idea.id)
    await dao.generations.collection.delete_one({"_id": ObjectId(gen_id)})
    # Verify deletion
    deleted_idea = await service.get_idea(idea.id)
    # IdeaRepo.delete_idea logic sets is_deleted=True
    if deleted_idea and deleted_idea.is_deleted:
        print(f"✅ Idea deleted successfully")
        # Hard delete for cleanup
        await dao.ideas.collection.delete_one({"_id": ObjectId(idea.id)})
if __name__ == "__main__":
asyncio.run(test_idea_flow())

52
tests/test_scheduler.py Normal file
View File

@@ -0,0 +1,52 @@
import asyncio
import os
from datetime import datetime, timedelta, UTC
from motor.motor_asyncio import AsyncIOMotorClient
from models.Generation import Generation, GenerationStatus
from repos.generation_repo import GenerationRepo
from dotenv import load_dotenv
load_dotenv()
# Mock configs if not present in env
MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://localhost:27017")
DB_NAME = os.getenv("DB_NAME", "bot_db")
print(f"Connecting to MongoDB: {MONGO_HOST}, DB: {DB_NAME}")
async def test_scheduler():
    """Integration check: cancel_stale_generations marks an old RUNNING generation as FAILED."""
    client = AsyncIOMotorClient(MONGO_HOST)
    repo = GenerationRepo(client, db_name=DB_NAME)
    # 1. Create a "stale" generation (2 hours ago)
    stale_gen = Generation(
        prompt="stale test",
        status=GenerationStatus.RUNNING,
        created_at=datetime.now(UTC) - timedelta(minutes=120),
        assets_list=[],
        aspect_ratio="NINESIXTEEN",
        quality="ONEK"
    )
    gen_id = await repo.create_generation(stale_gen)
    print(f"Created stale generation: {gen_id}")
    # 2. Run cleanup
    print("Running cleanup...")
    count = await repo.cancel_stale_generations(timeout_minutes=60)
    print(f"Cleaned up {count} generations")
    # 3. Verify status (guard: get_generation returns None for a missing doc)
    updated_gen = await repo.get_generation(gen_id)
    if updated_gen is None:
        print("❌ FAILURE: Generation not found after cleanup")
    else:
        print(f"Generation status: {updated_gen.status}")
        print(f"Failed reason: {updated_gen.failed_reason}")
        if updated_gen.status == GenerationStatus.FAILED:
            print("✅ SUCCESS: Generation marked as FAILED")
        else:
            print("❌ FAILURE: Generation status not updated")
    # BUG FIX: Mongo stores _id as an ObjectId, but get_generation() returns it as a
    # string — the original filter {"_id": updated_gen.id} never matched, so the
    # test document leaked. Delete by the real ObjectId instead.
    from bson import ObjectId  # local import: bson ships with the Mongo driver the repo already uses
    await repo.collection.delete_one({"_id": ObjectId(gen_id)})  # Remove test data
if __name__ == "__main__":
asyncio.run(test_scheduler())