21 Commits

Author SHA1 Message Date
xds
32ff77e04b feat: Implement video generation functionality and integrate with Kling API. 2026-02-12 10:27:07 +03:00
xds
d1f67c773f 123 2026-02-12 00:25:08 +03:00
xds
c63b51ef75 123
er the commit message for your changes. Lines starting
2026-02-12 00:24:43 +03:00
xds
456562ec1d main -> aiws 2026-02-12 00:13:06 +03:00
xds
0d0fbdf7d6 main -> aiws 2026-02-11 12:56:51 +03:00
xds
f63bcedb13 main -> aiws 2026-02-11 12:46:57 +03:00
xds
be92c766ac main -> aiws 2026-02-11 12:46:35 +03:00
xds
482bc1d9b7 main -> aiws 2026-02-11 12:30:05 +03:00
xds
a2321cf070 + prometheus 2026-02-11 11:56:08 +03:00
xds
29ccd5743e main -> aiws 2026-02-11 11:37:04 +03:00
xds
d9de2f48d2 main -> aiws 2026-02-11 11:19:50 +03:00
xds
1ddeb0af46 main -> aiws 2026-02-11 11:15:21 +03:00
xds
a7c2319f13 feat: Implement external generation import API secured by HMAC-SHA256 signature verification. 2026-02-10 14:06:37 +03:00
xds
00e83b8561 fix 2026-02-09 17:01:48 +03:00
xds
a9d24c725e Update user repository implementation. 2026-02-09 16:16:55 +03:00
xds
458b6ebfc3 feat: Implement project management with new models, repositories, and API endpoints, and enhance character management with project association and DTOs. 2026-02-09 16:06:54 +03:00
xds
668aadcdc9 fix 2026-02-09 09:49:49 +03:00
xds
4461964791 feat: Add created_by and cost fields to generation models, populate created_by from the authenticated user, and implement cost calculation. 2026-02-09 01:52:23 +03:00
xds
fa3e1bb05f refactor: Remove trailing slashes from album router endpoint paths. 2026-02-09 00:47:54 +03:00
xds
8a89b27624 feat: Add album management functionality with new data model, repository, service, API, and generation integration. 2026-02-08 23:13:31 +03:00
xds
c17c47ccc1 catch exception123 2026-02-08 22:56:08 +03:00
86 changed files with 1994 additions and 106 deletions

19
.dockerignore Normal file
View File

@@ -0,0 +1,19 @@
.git
.gitignore
__pycache__
*.pyc
*.pyo
*.pyd
.Python
env/
venv/
.venv/
node_modules/
tmp/
logs/
*.log
dist/
build/
.cache/
.idea/
.vscode/

3
.env
View File

@@ -8,3 +8,6 @@ MINIO_ACCESS_KEY=admin
MINIO_SECRET_KEY=SuperSecretPassword123! MINIO_SECRET_KEY=SuperSecretPassword123!
MINIO_BUCKET=ai-char MINIO_BUCKET=ai-char
MODE=production MODE=production
EXTERNAL_API_SECRET=Gt9TyQ8OAYhcELh2YCbKjdHLflZGufKHJZcG338MQDW
KLING_ACCESS_KEY=AngRfYYeLhPQB3pmr9CpHfgHPCrmeeM4
KLING_SECRET_KEY=ndJfyayfQgbg4bMnE49yHnkACPChKMp4

12
.gitignore vendored
View File

@@ -1,15 +1,11 @@
minio_backup.tar.gz minio_backup.tar.gz
.DS_Store .DS_Store
.DS_Store
.DS_Store
**/__pycache__/ **/__pycache__/
# Игнорируем файлы скомпилированного байт-кода напрямую
*.py[cod] *.py[cod]
*$py.class *$py.class
# Игнорируем расширения CPython конкретно
*.cpython-*.pyc *.cpython-*.pyc
# Игнорируем файлы .DS_Store на всех уровнях
**/.*.DS_Store
**/.DS_Store **/.DS_Store
.idea/ai-char-bot.iml
.idea
.venv
.vscode

6
.vscode/launch.json vendored
View File

@@ -7,10 +7,12 @@
"request": "launch", "request": "launch",
"module": "uvicorn", "module": "uvicorn",
"args": [ "args": [
"main:app", "aiws:app",
"--reload", "--reload",
"--port", "--port",
"8090" "8090",
"--host",
"0.0.0.0"
], ],
"jinja": true, "jinja": true,
"justMyCode": true "justMyCode": true

View File

@@ -11,4 +11,4 @@ RUN pip install --no-cache-dir -r requirements.txt
COPY . . COPY . .
# Запуск приложения (замени app.py на свой файл) # Запуск приложения (замени app.py на свой файл)
CMD ["python", "main.py"] CMD ["uvicorn", "aiws:app", "--host", "0.0.0.0", "--port", "8090"]

Binary file not shown.

Binary file not shown.

165
adapters/kling_adapter.py Normal file
View File

@@ -0,0 +1,165 @@
import logging
import time
import asyncio
from typing import Optional, Dict, Any
import httpx
import jwt
logger = logging.getLogger(__name__)
KLING_API_BASE = "https://api.klingai.com"
class KlingApiException(Exception):
pass
class KlingAdapter:
def __init__(self, access_key: str, secret_key: str):
if not access_key or not secret_key:
raise ValueError("Kling API credentials are missing")
self.access_key = access_key
self.secret_key = secret_key
def _generate_token(self) -> str:
"""Generate a JWT token for Kling API authentication."""
now = int(time.time())
payload = {
"iss": self.access_key,
"exp": now + 1800, # 30 minutes
"iat": now - 5, # небольшой запас назад
"nbf": now - 5,
}
return jwt.encode(payload, self.secret_key, algorithm="HS256",
headers={"typ": "JWT", "alg": "HS256"})
def _headers(self) -> dict:
return {
"Content-Type": "application/json",
"Authorization": f"Bearer {self._generate_token()}"
}
async def create_video_task(
self,
image_url: str,
prompt: str = "",
negative_prompt: str = "",
model_name: str = "kling-v2-6",
duration: int = 5,
mode: str = "std",
cfg_scale: float = 0.5,
aspect_ratio: str = "16:9",
callback_url: Optional[str] = None,
) -> Dict[str, Any]:
"""
Create an image-to-video generation task.
Returns the full task data dict including task_id.
"""
body: Dict[str, Any] = {
"model_name": model_name,
"image": image_url,
"prompt": prompt,
"negative_prompt": negative_prompt,
"duration": str(duration),
"mode": mode,
"cfg_scale": cfg_scale,
"aspect_ratio": aspect_ratio,
}
if callback_url:
body["callback_url"] = callback_url
logger.info(f"Creating Kling video task. Model: {model_name}, Duration: {duration}s, Mode: {mode}")
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.post(
f"{KLING_API_BASE}/v1/videos/image2video",
headers=self._headers(),
json=body,
)
data = response.json()
logger.info(f"Kling create task response: code={data.get('code')}, message={data.get('message')}")
if response.status_code != 200 or data.get("code") != 0:
error_msg = data.get("message", "Unknown Kling API error")
raise KlingApiException(f"Failed to create video task: {error_msg} (code={data.get('code')})")
task_data = data.get("data", {})
task_id = task_data.get("task_id")
if not task_id:
raise KlingApiException("No task_id returned from Kling API")
logger.info(f"Kling video task created: task_id={task_id}")
return task_data
async def get_task_status(self, task_id: str) -> Dict[str, Any]:
"""
Query the status of a video generation task.
Returns the full task data dict.
"""
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(
f"{KLING_API_BASE}/v1/videos/image2video/{task_id}",
headers=self._headers(),
)
data = response.json()
if response.status_code != 200 or data.get("code") != 0:
error_msg = data.get("message", "Unknown error")
raise KlingApiException(f"Failed to query task {task_id}: {error_msg}")
return data.get("data", {})
async def wait_for_completion(
self,
task_id: str,
poll_interval: int = 10,
timeout: int = 600,
progress_callback=None,
) -> Dict[str, Any]:
"""
Poll the task status until completion.
Args:
task_id: Kling task ID
poll_interval: seconds between polls
timeout: max seconds to wait
progress_callback: async callable(progress_pct: int) to report progress
Returns:
Final task data dict with video URL on success.
Raises:
KlingApiException on failure or timeout.
"""
start = time.time()
attempt = 0
while True:
elapsed = time.time() - start
if elapsed > timeout:
raise KlingApiException(f"Video generation timed out after {timeout}s for task {task_id}")
task_data = await self.get_task_status(task_id)
status = task_data.get("task_status")
logger.info(f"Kling task {task_id}: status={status}, elapsed={elapsed:.0f}s")
if status == "succeed":
logger.info(f"Kling task {task_id} completed successfully")
return task_data
if status == "failed":
fail_reason = task_data.get("task_status_msg", "Unknown failure")
raise KlingApiException(f"Video generation failed: {fail_reason}")
# Report progress estimate (linear approximation based on typical time)
if progress_callback:
# Estimate: typical gen is ~120s, cap at 90%
estimated_progress = min(int((elapsed / 120) * 90), 90)
attempt += 1
await progress_callback(estimated_progress)
await asyncio.sleep(poll_interval)

View File

@@ -12,12 +12,16 @@ from aiogram.fsm.storage.mongo import MongoStorage
from dotenv import load_dotenv from dotenv import load_dotenv
from fastapi import FastAPI from fastapi import FastAPI
from motor.motor_asyncio import AsyncIOMotorClient from motor.motor_asyncio import AsyncIOMotorClient
from prometheus_client import Info
from starlette.middleware.cors import CORSMiddleware from starlette.middleware.cors import CORSMiddleware
from prometheus_fastapi_instrumentator import Instrumentator
# --- ИМПОРТЫ ПРОЕКТА --- # --- ИМПОРТЫ ПРОЕКТА ---
from adapters.google_adapter import GoogleAdapter from adapters.google_adapter import GoogleAdapter
from adapters.kling_adapter import KlingAdapter
from adapters.s3_adapter import S3Adapter from adapters.s3_adapter import S3Adapter
from api.service.generation_service import GenerationService from api.service.generation_service import GenerationService
from api.service.album_service import AlbumService
from middlewares.album import AlbumMiddleware from middlewares.album import AlbumMiddleware
from middlewares.auth import AuthMiddleware from middlewares.auth import AuthMiddleware
from middlewares.dao import DaoMiddleware from middlewares.dao import DaoMiddleware
@@ -38,6 +42,8 @@ from api.endpoints.character_router import router as api_char_router # Роут
from api.endpoints.generation_router import router as api_gen_router from api.endpoints.generation_router import router as api_gen_router
from api.endpoints.auth import router as api_auth_router from api.endpoints.auth import router as api_auth_router
from api.endpoints.admin import router as api_admin_router from api.endpoints.admin import router as api_admin_router
from api.endpoints.album_router import router as api_album_router
from api.endpoints.project_router import router as project_api_router
load_dotenv() load_dotenv()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -78,7 +84,19 @@ s3_adapter = S3Adapter(
dao = DAO(mongo_client, s3_adapter) # Главный DAO для бота dao = DAO(mongo_client, s3_adapter) # Главный DAO для бота
gemini = GoogleAdapter(api_key=GEMINI_API_KEY) gemini = GoogleAdapter(api_key=GEMINI_API_KEY)
generation_service = GenerationService(dao, gemini, bot)
# Kling Adapter (optional, for video generation)
kling_access_key = os.getenv("KLING_ACCESS_KEY", "")
kling_secret_key = os.getenv("KLING_SECRET_KEY", "")
kling_adapter = None
if kling_access_key and kling_secret_key:
kling_adapter = KlingAdapter(access_key=kling_access_key, secret_key=kling_secret_key)
logger.info("Kling adapter initialized")
else:
logger.warning("KLING_ACCESS_KEY / KLING_SECRET_KEY not set — video generation disabled")
generation_service = GenerationService(dao, gemini, s3_adapter, bot, kling_adapter)
album_service = AlbumService(dao)
# Dispatcher # Dispatcher
dp = Dispatcher(storage=MongoStorage(mongo_client, db_name=DB_NAME)) dp = Dispatcher(storage=MongoStorage(mongo_client, db_name=DB_NAME))
@@ -132,6 +150,8 @@ async def lifespan(app: FastAPI):
app.state.gemini_client = gemini app.state.gemini_client = gemini
app.state.bot = bot app.state.bot = bot
app.state.s3_adapter = s3_adapter app.state.s3_adapter = s3_adapter
app.state.kling_adapter = kling_adapter
app.state.album_service = album_service
app.state.users_repo = users_repo # Добавляем репозиторий в state app.state.users_repo = users_repo # Добавляем репозиторий в state
print("✅ DB & DAO initialized") print("✅ DB & DAO initialized")
@@ -139,10 +159,10 @@ async def lifespan(app: FastAPI):
# 2. ЗАПУСК БОТА (в фоне) # 2. ЗАПУСК БОТА (в фоне)
# Важно: handle_signals=False, чтобы бот не перехватывал сигналы остановки у uvicorn # Важно: handle_signals=False, чтобы бот не перехватывал сигналы остановки у uvicorn
# Мы НЕ передаем сюда dao=..., так как он уже подключен через Middleware выше # Мы НЕ передаем сюда dao=..., так как он уже подключен через Middleware выше
polling_task = asyncio.create_task( # polling_task = asyncio.create_task(
dp.start_polling(bot, handle_signals=False) # dp.start_polling(bot, handle_signals=False)
) # )
print("🤖 Bot polling started") # print("🤖 Bot polling started")
yield yield
@@ -173,16 +193,26 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
# Подключаем роутер API # Подключаем роутеры API
from api.endpoints.auth import router as auth_api_router app.include_router(api_auth_router)
from api.endpoints.admin import router as admin_api_router app.include_router(api_admin_router)
app.include_router(auth_api_router)
app.include_router(admin_api_router)
app.include_router(api_assets_router) app.include_router(api_assets_router)
app.include_router(api_char_router) app.include_router(api_char_router)
app.include_router(api_gen_router) app.include_router(api_gen_router)
app.include_router(api_admin_router) app.include_router(api_album_router)
app.include_router(api_auth_router) app.include_router(project_api_router)
# Prometheus Metrics (Instrument after all routers are added)
Instrumentator(
should_group_status_codes=False, # 200/201/204 отдельно (по желанию)
should_ignore_untemplated=False, # НЕ игнорировать "сырые" пути
# should_group_untemplated=False, # (опционально) не схлопывать untemplated в "none"
).instrument(
app,
metric_namespace="ai_bot",
).expose(app, endpoint="/metrics", include_in_schema=False)
app_info = Info("fastapi_app_info", "FastAPI application info")
app_info.info({"app_name": "ai-bot"})
# --- ХЕНДЛЕРЫ БОТА (Main Router) --- # --- ХЕНДЛЕРЫ БОТА (Main Router) ---

View File

@@ -3,6 +3,7 @@ from fastapi import Request, Depends
from motor.motor_asyncio import AsyncIOMotorClient from motor.motor_asyncio import AsyncIOMotorClient
from adapters.google_adapter import GoogleAdapter from adapters.google_adapter import GoogleAdapter
from adapters.kling_adapter import KlingAdapter
from api.service.generation_service import GenerationService from api.service.generation_service import GenerationService
from repos.dao import DAO from repos.dao import DAO
@@ -36,11 +37,20 @@ def get_dao(
# так что DAO создастся один раз за запрос. # так что DAO создастся один раз за запрос.
return DAO(mongo_client, s3_adapter) return DAO(mongo_client, s3_adapter)
def get_kling_adapter(request: Request) -> Optional[KlingAdapter]:
return request.app.state.kling_adapter
# Провайдер сервиса (собирается из DAO и Gemini) # Провайдер сервиса (собирается из DAO и Gemini)
def get_generation_service( def get_generation_service(
dao: DAO = Depends(get_dao), dao: DAO = Depends(get_dao),
gemini: GoogleAdapter = Depends(get_gemini_client), gemini: GoogleAdapter = Depends(get_gemini_client),
s3_adapter: S3Adapter = Depends(get_s3_adapter), s3_adapter: S3Adapter = Depends(get_s3_adapter),
bot: Bot = Depends(get_bot_client), bot: Bot = Depends(get_bot_client),
kling_adapter: Optional[KlingAdapter] = Depends(get_kling_adapter),
) -> GenerationService: ) -> GenerationService:
return GenerationService(dao, gemini, s3_adapter, bot) return GenerationService(dao, gemini, s3_adapter, bot, kling_adapter=kling_adapter)
from fastapi import Header
async def get_project_id(x_project_id: Optional[str] = Header(None, alias="X-Project-ID")) -> Optional[str]:
return x_project_id

View File

@@ -0,0 +1,81 @@
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, status, Request
from pydantic import BaseModel
from api.models.GenerationRequest import GenerationResponse
from models.Album import Album
from repos.dao import DAO
router = APIRouter(prefix="/api/albums", tags=["Albums"])
class AlbumCreateRequest(BaseModel):
name: str
description: Optional[str] = None
class AlbumUpdateRequest(BaseModel):
name: Optional[str] = None
description: Optional[str] = None
class AlbumResponse(BaseModel):
id: str
name: str
description: Optional[str] = None
generation_ids: List[str] = []
cover_asset_id: Optional[str] = None # Not implemented yet
@router.post("", response_model=AlbumResponse)
async def create_album(request: Request, album_in: AlbumCreateRequest):
service: AlbumService = request.app.state.album_service
album = await service.create_album(name=album_in.name, description=album_in.description)
return AlbumResponse(**album.model_dump())
@router.get("", response_model=List[AlbumResponse])
async def get_albums(request: Request, limit: int = 10, offset: int = 0):
service: AlbumService = request.app.state.album_service
albums = await service.get_albums(limit=limit, offset=offset)
return [AlbumResponse(**album.model_dump()) for album in albums]
@router.get("/{album_id}", response_model=AlbumResponse)
async def get_album(request: Request, album_id: str):
service: AlbumService = request.app.state.album_service
album = await service.get_album(album_id)
if not album:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album not found")
return AlbumResponse(**album.model_dump())
@router.put("/{album_id}", response_model=AlbumResponse)
async def update_album(request: Request, album_id: str, album_in: AlbumUpdateRequest):
service: AlbumService = request.app.state.album_service
album = await service.update_album(album_id, name=album_in.name, description=album_in.description)
if not album:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album not found")
return AlbumResponse(**album.model_dump())
@router.delete("/{album_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_album(request: Request, album_id: str):
service: AlbumService = request.app.state.album_service
deleted = await service.delete_album(album_id)
if not deleted:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album not found")
@router.post("/{album_id}/generations/{generation_id}")
async def add_generation_to_album(request: Request, album_id: str, generation_id: str):
service: AlbumService = request.app.state.album_service
success = await service.add_generation_to_album(album_id, generation_id)
if not success:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album or Generation not found")
return {"status": "success"}
@router.delete("/{album_id}/generations/{generation_id}")
async def remove_generation_from_album(request: Request, album_id: str, generation_id: str):
service: AlbumService = request.app.state.album_service
success = await service.remove_generation_from_album(album_id, generation_id)
if not success:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album or Generation not found")
return {"status": "success"}
@router.get("/{album_id}/generations", response_model=List[GenerationResponse])
async def get_album_generations(request: Request, album_id: str, limit: int = 10, offset: int = 0):
service: AlbumService = request.app.state.album_service
generations = await service.get_generations_by_album(album_id, limit=limit, offset=offset)
return [GenerationResponse(**gen.model_dump()) for gen in generations]

View File

@@ -1,17 +1,21 @@
from typing import List, Optional from typing import List, Optional, Dict, Any
from aiogram.types import BufferedInputFile from aiogram.types import BufferedInputFile
from bson import ObjectId
from fastapi import APIRouter, UploadFile, File, Form, Depends from fastapi import APIRouter, UploadFile, File, Form, Depends
from fastapi.openapi.models import MediaType from fastapi.openapi.models import MediaType
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import MongoClient
from starlette import status from starlette import status
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import Response, JSONResponse from starlette.responses import Response, JSONResponse
from adapters.s3_adapter import S3Adapter
from api.models.AssetDTO import AssetsResponse, AssetResponse from api.models.AssetDTO import AssetsResponse, AssetResponse
from models.Asset import Asset, AssetType, AssetContentType from models.Asset import Asset, AssetType, AssetContentType
from repos.dao import DAO from repos.dao import DAO
from api.dependency import get_dao from api.dependency import get_dao, get_mongo_client, get_s3_adapter
import asyncio import asyncio
import logging import logging
@@ -19,6 +23,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from api.endpoints.auth import get_current_user from api.endpoints.auth import get_current_user
from api.dependency import get_project_id
router = APIRouter(prefix="/api/assets", tags=["Assets"]) router = APIRouter(prefix="/api/assets", tags=["Assets"])
@@ -50,6 +55,119 @@ async def get_asset(
return Response(content=content, media_type=media_type, headers=headers) return Response(content=content, media_type=media_type, headers=headers)
@router.delete("/orphans", dependencies=[Depends(get_current_user)])
async def delete_orphan_assets_from_minio(
mongo: AsyncIOMotorClient = Depends(get_mongo_client),
minio_client: S3Adapter = Depends(get_s3_adapter),
*,
assets_collection: str = "assets",
generations_collection: str = "generations",
asset_type: Optional[str] = "generated",
project_id: Optional[str] = None,
dry_run: bool = True,
mark_assets_deleted: bool = False,
batch_size: int = 500,
) -> Dict[str, Any]:
db = mongo['bot_db'] # БД уже выбрана в get_mongo_client
assets = db[assets_collection]
match_assets: Dict[str, Any] = {}
if asset_type is not None:
match_assets["type"] = asset_type
if project_id is not None:
match_assets["project_id"] = project_id
pipeline: List[Dict[str, Any]] = [
{"$match": match_assets} if match_assets else {"$match": {}},
{
"$lookup": {
"from": generations_collection,
"let": {"assetIdStr": {"$toString": "$_id"}},
"pipeline": [
# считаем "живыми" те, где is_deleted != True (т.е. false или поля нет)
{"$match": {"is_deleted": {"$ne": True}}},
{
"$match": {
"$expr": {
"$in": [
"$$assetIdStr",
{"$ifNull": ["$result_list", []]},
]
}
}
},
{"$limit": 1},
],
"as": "alive_generations",
}
},
{
"$match": {
"$expr": {"$eq": [{"$size": "$alive_generations"}, 0]}
}
},
{
"$project": {
"_id": 1,
"minio_object_name": 1,
"minio_thumbnail_object_name": 1,
}
},
]
print(pipeline)
cursor = assets.aggregate(pipeline, allowDiskUse=True, batchSize=batch_size)
deleted_objects = 0
deleted_assets = 0
errors: List[Dict[str, Any]] = []
orphan_asset_ids: List[ObjectId] = []
async for asset in cursor:
aid = asset["_id"]
obj = asset.get("minio_object_name")
thumb = asset.get("minio_thumbnail_object_name")
orphan_asset_ids.append(aid)
if dry_run:
print(f"[DRY RUN] orphan asset={aid} obj={obj} thumb={thumb}")
continue
try:
if obj:
await minio_client.delete_file(obj)
deleted_objects += 1
if thumb:
await minio_client.delete_file(thumb)
deleted_objects += 1
deleted_assets += 1
except Exception as e:
errors.append({"asset_id": str(aid), "error": str(e)})
if (not dry_run) and mark_assets_deleted and orphan_asset_ids:
res = await assets.update_many(
{"_id": {"$in": orphan_asset_ids}},
{"$set": {"is_deleted": True}},
)
marked = res.modified_count
else:
marked = 0
return {
"dry_run": dry_run,
"filter": {
"asset_type": asset_type,
"project_id": project_id,
},
"orphans_found": len(orphan_asset_ids),
"deleted_assets": deleted_assets,
"deleted_objects": deleted_objects,
"marked_assets_deleted": marked,
"errors": errors,
}
@router.delete("/{asset_id}", status_code=status.HTTP_204_NO_CONTENT, dependencies=[Depends(get_current_user)]) @router.delete("/{asset_id}", status_code=status.HTTP_204_NO_CONTENT, dependencies=[Depends(get_current_user)])
async def delete_asset( async def delete_asset(
@@ -68,11 +186,19 @@ async def delete_asset(
@router.get("", dependencies=[Depends(get_current_user)]) @router.get("", dependencies=[Depends(get_current_user)])
async def get_assets(request: Request, dao: DAO = Depends(get_dao), type: Optional[str] = None, limit: int = 10, offset: int = 0) -> AssetsResponse: async def get_assets(request: Request, dao: DAO = Depends(get_dao), type: Optional[str] = None, limit: int = 10, offset: int = 0, current_user: dict = Depends(get_current_user), project_id: Optional[str] = Depends(get_project_id)) -> AssetsResponse:
logger.info(f"get_assets called. Limit: {limit}, Offset: {offset}") logger.info(f"get_assets called. Limit: {limit}, Offset: {offset}")
assets = await dao.assets.get_assets(type, limit, offset)
user_id_filter = current_user["id"]
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
user_id_filter = None
assets = await dao.assets.get_assets(type, limit, offset, created_by=user_id_filter, project_id=project_id)
# assets = await dao.assets.get_assets() # This line seemed redundant/conflicting in original code # assets = await dao.assets.get_assets() # This line seemed redundant/conflicting in original code
total_count = await dao.assets.get_asset_count() total_count = await dao.assets.get_asset_count(created_by=user_id_filter, project_id=project_id)
# Manually map to DTO to trigger computed fields validation if necessary, # Manually map to DTO to trigger computed fields validation if necessary,
# but primarily to ensure valid Pydantic models for the response list. # but primarily to ensure valid Pydantic models for the response list.
@@ -84,11 +210,13 @@ async def get_assets(request: Request, dao: DAO = Depends(get_dao), type: Option
@router.post("/upload", response_model=AssetResponse, status_code=status.HTTP_201_CREATED, dependencies=[Depends(get_current_user)]) @router.post("/upload", response_model=AssetResponse, status_code=status.HTTP_201_CREATED)
async def upload_asset( async def upload_asset(
file: UploadFile = File(...), file: UploadFile = File(...),
linked_char_id: Optional[str] = Form(None), linked_char_id: Optional[str] = Form(None),
dao: DAO = Depends(get_dao), dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id)
): ):
logger.info(f"upload_asset called. Filename: {file.filename}, ContentType: {file.content_type}, LinkedCharId: {linked_char_id}") logger.info(f"upload_asset called. Filename: {file.filename}, ContentType: {file.content_type}, LinkedCharId: {linked_char_id}")
if not file.content_type: if not file.content_type:
@@ -97,6 +225,11 @@ async def upload_asset(
if not file.content_type.startswith("image/"): if not file.content_type.startswith("image/"):
raise HTTPException(status_code=400, detail=f"Unsupported content type: {file.content_type}") raise HTTPException(status_code=400, detail=f"Unsupported content type: {file.content_type}")
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
data = await file.read() data = await file.read()
if not data: if not data:
raise HTTPException(status_code=400, detail="Empty file") raise HTTPException(status_code=400, detail="Empty file")
@@ -111,7 +244,9 @@ async def upload_asset(
content_type=AssetContentType.IMAGE, content_type=AssetContentType.IMAGE,
linked_char_id=linked_char_id, linked_char_id=linked_char_id,
data=data, data=data,
thumbnail=thumbnail_bytes thumbnail=thumbnail_bytes,
created_by=str(current_user["_id"]),
project_id=project_id,
) )
asset_id = await dao.assets.create_asset(asset) asset_id = await dao.assets.create_asset(asset)
@@ -172,3 +307,4 @@ async def migrate_to_minio(dao: DAO = Depends(get_dao)):
result = await dao.assets.migrate_to_minio() result = await dao.assets.migrate_to_minio()
logger.info(f"Migration result: {result}") logger.info(f"Migration result: {result}")
return result return result

View File

@@ -59,6 +59,7 @@ class Token(BaseModel):
class UserResponse(BaseModel): class UserResponse(BaseModel):
id: str
username: str username: str
full_name: str | None = None full_name: str | None = None
status: str status: str

View File

@@ -1,4 +1,4 @@
from typing import List, Any, Coroutine from typing import List, Any, Coroutine, Optional
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from pydantic import BaseModel from pydantic import BaseModel
@@ -9,6 +9,7 @@ from api.models.AssetDTO import AssetsResponse, AssetResponse
from api.models.GenerationRequest import GenerationRequest, GenerationResponse from api.models.GenerationRequest import GenerationRequest, GenerationResponse
from models.Asset import Asset from models.Asset import Asset
from models.Character import Character from models.Character import Character
from api.models.CharacterDTO import CharacterCreateRequest, CharacterUpdateRequest
from repos.dao import DAO from repos.dao import DAO
from api.dependency import get_dao from api.dependency import get_dao
@@ -17,25 +18,49 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from api.endpoints.auth import get_current_user from api.endpoints.auth import get_current_user
from api.dependency import get_project_id
router = APIRouter(prefix="/api/characters", tags=["Characters"], dependencies=[Depends(get_current_user)]) router = APIRouter(prefix="/api/characters", tags=["Characters"], dependencies=[Depends(get_current_user)])
@router.get("/", response_model=List[Character]) @router.get("/", response_model=List[Character])
async def get_characters(request: Request, dao: DAO = Depends(get_dao), ) -> List[Character]: async def get_characters(request: Request, dao: DAO = Depends(get_dao), current_user: dict = Depends(get_current_user), project_id: Optional[str] = Depends(get_project_id)) -> List[Character]:
logger.info("get_characters called") logger.info("get_characters called")
characters = await dao.chars.get_all_characters()
user_id_filter = str(current_user["_id"])
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
user_id_filter = None
characters = await dao.chars.get_all_characters(created_by=user_id_filter, project_id=project_id)
return characters return characters
@router.get("/{character_id}/assets", response_model=AssetsResponse) @router.get("/{character_id}/assets", response_model=AssetsResponse)
async def get_character_assets(character_id: str, dao: DAO = Depends(get_dao), limit: int = 10, async def get_character_assets(character_id: str, dao: DAO = Depends(get_dao), limit: int = 10,
offset: int = 0, ) -> AssetsResponse: offset: int = 0, current_user: dict = Depends(get_current_user)) -> AssetsResponse:
logger.info(f"get_character_assets called. CharacterID: {character_id}, Limit: {limit}, Offset: {offset}") logger.info(f"get_character_assets called. CharacterID: {character_id}, Limit: {limit}, Offset: {offset}")
character = await dao.chars.get_character(character_id) character = await dao.chars.get_character(character_id)
if character is None: if character is None:
raise HTTPException(status_code=404, detail="Character not found") raise HTTPException(status_code=404, detail="Character not found")
# Access Check
is_creator = character.created_by == str(current_user["_id"])
is_project_member = False
if character.project_id and character.project_id in current_user.get("project_ids", []):
is_project_member = True
if not is_creator and not is_project_member:
raise HTTPException(status_code=403, detail="Access denied")
assets = await dao.assets.get_assets_by_char_id(character_id, limit, offset) assets = await dao.assets.get_assets_by_char_id(character_id, limit, offset)
# Filter assets by user ownership as well?
# Usually if you own character, you see its assets.
# But assets also have specific created_by.
# Let's assume if you own character you can see its assets.
total_count = await dao.assets.get_asset_count(character_id) total_count = await dao.assets.get_asset_count(character_id)
asset_responses = [AssetResponse.model_validate(a.model_dump()) for a in assets] asset_responses = [AssetResponse.model_validate(a.model_dump()) for a in assets]
@@ -43,12 +68,118 @@ async def get_character_assets(character_id: str, dao: DAO = Depends(get_dao), l
@router.get("/{character_id}", response_model=Character) @router.get("/{character_id}", response_model=Character)
async def get_character_by_id(character_id: str, request: Request, dao: DAO = Depends(get_dao)) -> Character: async def get_character_by_id(character_id: str, request: Request, dao: DAO = Depends(get_dao), current_user: dict = Depends(get_current_user)) -> Character:
logger.debug(f"get_character_by_id called. ID: {character_id}") logger.debug(f"get_character_by_id called. ID: {character_id}")
character = await dao.chars.get_character(character_id) character = await dao.chars.get_character(character_id)
if not character:
raise HTTPException(status_code=404, detail="Character not found")
if character:
is_creator = character.created_by == str(current_user["_id"])
is_project_member = False
if character.project_id and character.project_id in current_user.get("project_ids", []):
is_project_member = True
if not is_creator and not is_project_member:
raise HTTPException(status_code=403, detail="Access denied")
return character return character
@router.post("/", response_model=Character)
async def create_character(
char_req: CharacterCreateRequest,
project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
) -> Character:
logger.info("create_character called")
char_req.project_id = project_id
char_data = char_req.model_dump()
char_data["created_by"] = str(current_user["_id"])
if "id" not in char_data:
char_data["id"] = None
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
new_char = Character(**char_data)
new_char.avatar_asset_id = new_char.avatar_image.split("/")[-1]
created_char = await dao.chars.add_character(new_char)
return created_char
@router.put("/{character_id}", response_model=Character)
async def update_character(
character_id: str,
char_update: CharacterUpdateRequest,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
) -> Character:
logger.info(f"update_character called. ID: {character_id}")
existing_char = await dao.chars.get_character(character_id)
if not existing_char:
raise HTTPException(status_code=404, detail="Character not found")
is_creator = existing_char.created_by == str(current_user["_id"])
is_project_member = False
if existing_char.project_id and existing_char.project_id in current_user.get("project_ids", []):
is_project_member = True
if not is_creator and not is_project_member:
raise HTTPException(status_code=403, detail="Access denied")
update_data = char_update.model_dump(exclude_unset=True)
if "project_id" in update_data and update_data["project_id"]:
new_project_id = update_data["project_id"]
project = await dao.projects.get_project(new_project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Target project access denied")
updated_char_data = existing_char.model_dump()
updated_char_data.update(update_data)
updated_char = Character(**updated_char_data)
success = await dao.chars.update_char(character_id, updated_char)
if not success:
raise HTTPException(status_code=500, detail="Failed to update character")
return updated_char
@router.delete("/{character_id}", status_code=204)
async def delete_character(
character_id: str,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
):
logger.info(f"delete_character called. ID: {character_id}")
existing_char = await dao.chars.get_character(character_id)
if not existing_char:
raise HTTPException(status_code=404, detail="Character not found")
is_creator = existing_char.created_by == str(current_user["_id"])
is_project_member = False
if existing_char.project_id and existing_char.project_id in current_user.get("project_ids", []):
is_project_member = True
if not is_creator and not is_project_member:
raise HTTPException(status_code=403, detail="Access denied")
success = await dao.chars.delete_character(character_id)
if not success:
raise HTTPException(status_code=500, detail="Failed to delete character")
return
@router.post("/{character_id}/_run", response_model=GenerationResponse) @router.post("/{character_id}/_run", response_model=GenerationResponse)
async def post_character_generation(character_id: str, generation: GenerationRequest, async def post_character_generation(character_id: str, generation: GenerationRequest,
request: Request) -> GenerationResponse: request: Request) -> GenerationResponse:

View File

@@ -1,13 +1,15 @@
from typing import List, Optional from typing import List, Optional
from fastapi import APIRouter, UploadFile, File, Form from fastapi import APIRouter, UploadFile, File, Form, Header, HTTPException
from fastapi.params import Depends from fastapi.params import Depends
from starlette.requests import Request from starlette.requests import Request
from api import service from api import service
from api.dependency import get_generation_service from api.dependency import get_generation_service, get_project_id, get_dao
from repos.dao import DAO
from api.models.GenerationRequest import GenerationResponse, GenerationRequest, GenerationsResponse, PromptResponse, PromptRequest from api.models.GenerationRequest import GenerationResponse, GenerationRequest, GenerationsResponse, PromptResponse, PromptRequest
from api.models.VideoGenerationRequest import VideoGenerationRequest
from api.service.generation_service import GenerationService from api.service.generation_service import GenerationService
from models.Generation import Generation from models.Generation import Generation
@@ -19,13 +21,14 @@ logger = logging.getLogger(__name__)
from api.endpoints.auth import get_current_user from api.endpoints.auth import get_current_user
router = APIRouter(prefix='/api/generations', tags=["Generation"], dependencies=[Depends(get_current_user)]) router = APIRouter(prefix='/api/generations', tags=["Generation"])
@router.post("/prompt-assistant", response_model=PromptResponse) @router.post("/prompt-assistant", response_model=PromptResponse)
async def ask_prompt_assistant(prompt_request: PromptRequest, request: Request, async def ask_prompt_assistant(prompt_request: PromptRequest, request: Request,
generation_service: GenerationService = Depends( generation_service: GenerationService = Depends(
get_generation_service)) -> PromptResponse: get_generation_service),
current_user: dict = Depends(get_current_user)) -> PromptResponse:
logger.info(f"ask_prompt_assistant called with prompt length: {len(prompt_request.prompt)}. Linked assets: {len(prompt_request.linked_assets) if prompt_request.linked_assets else 0}") logger.info(f"ask_prompt_assistant called with prompt length: {len(prompt_request.prompt)}. Linked assets: {len(prompt_request.linked_assets) if prompt_request.linked_assets else 0}")
generated_prompt = await generation_service.ask_prompt_assistant(prompt_request.prompt, prompt_request.linked_assets) generated_prompt = await generation_service.ask_prompt_assistant(prompt_request.prompt, prompt_request.linked_assets)
return PromptResponse(prompt=generated_prompt) return PromptResponse(prompt=generated_prompt)
@@ -35,7 +38,8 @@ async def ask_prompt_assistant(prompt_request: PromptRequest, request: Request,
async def prompt_from_image( async def prompt_from_image(
prompt: Optional[str] = Form(None), prompt: Optional[str] = Form(None),
images: List[UploadFile] = File(...), images: List[UploadFile] = File(...),
generation_service: GenerationService = Depends(get_generation_service) generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)
) -> PromptResponse: ) -> PromptResponse:
logger.info(f"prompt_from_image called. Images count: {len(images)}. Prompt provided: {bool(prompt)}") logger.info(f"prompt_from_image called. Images count: {len(images)}. Prompt provided: {bool(prompt)}")
images_bytes = [] images_bytes = []
@@ -49,34 +53,139 @@ async def prompt_from_image(
@router.get("", response_model=GenerationsResponse) @router.get("", response_model=GenerationsResponse)
async def get_generations(character_id: Optional[str] = None, limit: int = 10, offset: int = 0, async def get_generations(character_id: Optional[str] = None, limit: int = 10, offset: int = 0,
generation_service: GenerationService = Depends(get_generation_service)): generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao)):
logger.info(f"get_generations called. CharacterId: {character_id}, Limit: {limit}, Offset: {offset}") logger.info(f"get_generations called. CharacterId: {character_id}, Limit: {limit}, Offset: {offset}")
return await generation_service.get_generations(character_id, limit=limit, offset=offset)
user_id_filter = str(current_user["_id"])
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
user_id_filter = None # Show all project generations
return await generation_service.get_generations(character_id, limit=limit, offset=offset, user_id=user_id_filter, project_id=project_id)
@router.post("/_run", response_model=GenerationResponse) @router.post("/_run", response_model=GenerationResponse)
async def post_generation(generation: GenerationRequest, request: Request, async def post_generation(generation: GenerationRequest, request: Request,
generation_service: GenerationService = Depends( generation_service: GenerationService = Depends(get_generation_service),
get_generation_service)) -> GenerationResponse: current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao)) -> GenerationResponse:
logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}") logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}")
return await generation_service.create_generation_task(generation)
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
generation.project_id = project_id
return await generation_service.create_generation_task(generation, user_id=str(current_user.get("_id")))
@router.get("/{generation_id}", response_model=GenerationResponse) @router.get("/{generation_id}", response_model=GenerationResponse)
async def get_generation(generation_id: str, async def get_generation(generation_id: str,
generation_service: GenerationService = Depends(get_generation_service)) -> GenerationResponse: generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)) -> GenerationResponse:
logger.debug(f"get_generation called for ID: {generation_id}") logger.debug(f"get_generation called for ID: {generation_id}")
return await generation_service.get_generation(generation_id) gen = await generation_service.get_generation(generation_id)
if gen and gen.created_by != str(current_user["_id"]):
raise HTTPException(status_code=403, detail="Access denied")
return gen
@router.get("/running") @router.get("/running")
async def get_running_generations(request: Request, async def get_running_generations(request: Request,
generation_service: GenerationService = Depends(get_generation_service)): generation_service: GenerationService = Depends(get_generation_service),
return await generation_service.get_running_generations() current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao)):
user_id_filter = str(current_user["_id"])
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
user_id_filter = None
return await generation_service.get_running_generations(user_id=user_id_filter, project_id=project_id)
@router.delete("/{generation_id}", status_code=status.HTTP_204_NO_CONTENT, dependencies=[Depends(get_current_user)]) @router.post("/video/_run", response_model=GenerationResponse)
async def delete_generation(generation_id: str, generation_service: GenerationService = Depends(get_generation_service)): async def post_video_generation(
video_request: VideoGenerationRequest,
request: Request,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao),
) -> GenerationResponse:
"""Start image-to-video generation using Kling AI."""
logger.info(f"post_video_generation called. AssetId: {video_request.image_asset_id}, Duration: {video_request.duration}s, Mode: {video_request.mode}")
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
video_request.project_id = project_id
return await generation_service.create_video_generation_task(video_request, user_id=str(current_user.get("_id")))
@router.post("/import", response_model=GenerationResponse)
async def import_external_generation(
request: Request,
generation_service: GenerationService = Depends(get_generation_service),
x_signature: str = Header(..., alias="X-Signature")
) -> GenerationResponse:
"""
Import a generation from an external source.
Requires server-to-server authentication via HMAC signature.
"""
import os
from utils.external_auth import verify_signature
from api.models.ExternalGenerationDTO import ExternalGenerationRequest
logger.info("import_external_generation called")
# Get raw request body for signature verification
body = await request.body()
# Verify signature
secret = os.getenv("EXTERNAL_API_SECRET")
if not secret:
logger.error("EXTERNAL_API_SECRET not configured")
raise HTTPException(status_code=500, detail="Server configuration error")
if not verify_signature(body, x_signature, secret):
logger.warning("Invalid signature for external generation import")
raise HTTPException(status_code=401, detail="Invalid signature")
# Parse request body
import json
try:
data = json.loads(body.decode('utf-8'))
external_gen = ExternalGenerationRequest(**data)
except Exception as e:
logger.error(f"Failed to parse request body: {e}")
raise HTTPException(status_code=400, detail=f"Invalid request body: {str(e)}")
# Import generation
try:
generation = await generation_service.import_external_generation(external_gen)
return GenerationResponse(**generation.model_dump())
except Exception as e:
logger.error(f"Failed to import external generation: {e}")
raise HTTPException(status_code=500, detail=f"Import failed: {str(e)}")
@router.delete("/{generation_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_generation(generation_id: str,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)):
logger.info(f"delete_generation called for ID: {generation_id}") logger.info(f"delete_generation called for ID: {generation_id}")
deleted = await generation_service.delete_generation(generation_id) deleted = await generation_service.delete_generation(generation_id)
if not deleted: if not deleted:

View File

@@ -0,0 +1,167 @@
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from api.dependency import get_dao
from api.endpoints.auth import get_current_user
from models.Project import Project
from repos.dao import DAO
router = APIRouter(prefix="/api/projects", tags=["Projects"])
class ProjectCreate(BaseModel):
name: str
description: Optional[str] = None
class ProjectResponse(BaseModel):
id: str
name: str
description: Optional[str] = None
owner_id: str
members: List[str]
is_owner: bool = False
@router.post("", response_model=ProjectResponse)
async def create_project(
project_data: ProjectCreate,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
):
user_id = str(current_user["_id"])
new_project = Project(
name=project_data.name,
description=project_data.description,
owner_id=user_id,
members=[user_id]
)
project_id = await dao.projects.create_project(new_project)
# Add project to user's project list
# Assuming user_repo has a method to add project or we do it directly?
# UserRepo doesn't have add_project method yet.
# But since UserRepo is just a wrapper around collection, lets add it here or update UserRepo later?
# Better to update UserRepo. For now, let's just return success.
# But user needs to see it in list.
# Update user in DB
await dao.users.collection.update_one(
{"_id": current_user["_id"]},
{"$addToSet": {"project_ids": project_id}}
)
return ProjectResponse(
id=project_id,
name=new_project.name,
description=new_project.description,
owner_id=new_project.owner_id,
members=new_project.members,
is_owner=True
)
@router.get("", response_model=List[ProjectResponse])
async def get_my_projects(
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
):
user_id = str(current_user["_id"])
projects = await dao.projects.get_projects_by_user(user_id)
responses = []
for p in projects:
responses.append(ProjectResponse(
id=p.id,
name=p.name,
description=p.description,
owner_id=p.owner_id,
members=p.members,
is_owner=(p.owner_id == user_id)
))
return responses
class MemberAdd(BaseModel):
username: str
@router.post("/{project_id}/members", dependencies=[Depends(get_current_user)])
async def add_member(
project_id: str,
member_data: MemberAdd,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
):
user_id = str(current_user["_id"])
project = await dao.projects.get_project(project_id)
if not project:
raise HTTPException(status_code=404, detail="Project not found")
if project.owner_id != user_id:
raise HTTPException(status_code=403, detail="Only owner can add members")
target_user = await dao.users.get_user_by_username(member_data.username)
if not target_user:
raise HTTPException(status_code=404, detail="User not found")
target_user_id = str(target_user["_id"])
if target_user_id in project.members:
return {"message": "User already in project"}
await dao.projects.add_member(project_id, target_user_id)
# Update target user's project list
await dao.users.collection.update_one(
{"_id": target_user["_id"]},
{"$addToSet": {"project_ids": project_id}}
)
return {"message": "Member added"}
@router.post("/{project_id}/join", dependencies=[Depends(get_current_user)])
async def join_project(
project_id: str,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
):
# Retrieve project to verify it exists
project = await dao.projects.get_project(project_id)
if not project:
raise HTTPException(status_code=404, detail="Project not found")
user_id = str(current_user["_id"])
# Check if user is ALREADY in project
if user_id in project.members:
return {"message": "Already a member"}
# Add member
await dao.projects.add_member(project_id, user_id)
# Update user's project list
await dao.users.collection.update_one(
{"_id": current_user["_id"]},
{"$addToSet": {"project_ids": project_id}}
)
return {"message": "Joined project"}
@router.delete("/{project_id}", dependencies=[Depends(get_current_user)] )
async def delete_project(
project_id: str,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
):
user_id = str(current_user["_id"])
project = await dao.projects.get_project(project_id)
if not project:
raise HTTPException(status_code=404, detail="Project not found")
if project.owner_id != user_id:
raise HTTPException(status_code=403, detail="Only owner can delete project")
await dao.projects.delete_project(project_id)
# Remove project from user's project list
await dao.users.collection.update_one(
{"_id": current_user["_id"]},
{"$pull": {"project_ids": project_id}}
)
return {"message": "Project deleted"}

View File

@@ -0,0 +1,18 @@
from typing import Optional
from pydantic import BaseModel
class CharacterCreateRequest(BaseModel):
name: str
character_bio: str
character_image_doc_tg_id: Optional[str] = None
avatar_image: Optional[str] = None
character_image_tg_id: Optional[str] = None
project_id: Optional[str] = None
class CharacterUpdateRequest(BaseModel):
name: Optional[str] = None
character_bio: Optional[str] = None
character_image_doc_tg_id: Optional[str] = None
avatar_image: Optional[str] = None
character_image_tg_id: Optional[str] = None
project_id: Optional[str] = None

View File

@@ -0,0 +1,37 @@
from typing import Optional
from pydantic import BaseModel, Field
from models.enums import AspectRatios, Quality
class ExternalGenerationRequest(BaseModel):
"""Request model for importing external generations."""
prompt: str
tech_prompt: Optional[str] = None
# Image can be provided as base64 string OR URL (one must be provided)
image_data: Optional[str] = Field(None, description="Base64-encoded image data")
image_url: Optional[str] = Field(None, description="URL to download image from")
# Generation metadata
aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN
quality: Quality = Quality.ONEK
# Optional linking
linked_character_id: Optional[str] = None
created_by: str = Field(..., description="User ID from external system")
project_id: Optional[str] = None
# Performance metrics
execution_time_seconds: Optional[float] = None
api_execution_time_seconds: Optional[float] = None
token_usage: Optional[int] = None
input_token_usage: Optional[int] = None
output_token_usage: Optional[int] = None
def validate_image_source(self):
"""Ensure at least one image source is provided."""
if not self.image_data and not self.image_url:
raise ValueError("Either image_data or image_url must be provided")
if self.image_data and self.image_url:
raise ValueError("Only one of image_data or image_url should be provided")

View File

@@ -16,6 +16,7 @@ class GenerationRequest(BaseModel):
telegram_id: Optional[int] = None telegram_id: Optional[int] = None
use_profile_image: bool = True use_profile_image: bool = True
assets_list: List[str] assets_list: List[str]
project_id: Optional[str] = None
class GenerationsResponse(BaseModel): class GenerationsResponse(BaseModel):
@@ -26,6 +27,7 @@ class GenerationsResponse(BaseModel):
class GenerationResponse(BaseModel): class GenerationResponse(BaseModel):
id: str id: str
status: GenerationStatus status: GenerationStatus
gen_type: GenType = GenType.IMAGE
failed_reason: Optional[str] = None failed_reason: Optional[str] = None
linked_character_id: Optional[str] = None linked_character_id: Optional[str] = None
@@ -42,6 +44,12 @@ class GenerationResponse(BaseModel):
input_token_usage: Optional[int] = None input_token_usage: Optional[int] = None
output_token_usage: Optional[int] = None output_token_usage: Optional[int] = None
progress: int = 0 progress: int = 0
cost: Optional[float] = None
created_by: Optional[str] = None
# Video-specific
kling_task_id: Optional[str] = None
video_duration: Optional[int] = None
video_mode: Optional[str] = None
created_at: datetime = datetime.now(UTC) created_at: datetime = datetime.now(UTC)
updated_at: datetime = datetime.now(UTC) updated_at: datetime = datetime.now(UTC)

View File

@@ -0,0 +1,16 @@
from typing import Optional
from pydantic import BaseModel
class VideoGenerationRequest(BaseModel):
prompt: str = ""
negative_prompt: Optional[str] = ""
image_asset_id: str # ID ассета-картинки для source image
duration: int = 5 # 5 or 10 seconds
mode: str = "std" # "std" or "pro"
model_name: str = "kling-v2-1"
cfg_scale: float = 0.5
aspect_ratio: str = "16:9"
linked_character_id: Optional[str] = None
project_id: Optional[str] = None

View File

@@ -0,0 +1,85 @@
from typing import List, Optional
from models.Album import Album
from models.Generation import Generation
from repos.dao import DAO
class AlbumService:
def __init__(self, dao: DAO):
self.dao = dao
async def create_album(self, name: str, description: Optional[str] = None) -> Album:
album = Album(name=name, description=description)
album_id = await self.dao.albums.create_album(album)
album.id = album_id
return album
async def get_albums(self, limit: int = 10, offset: int = 0) -> List[Album]:
return await self.dao.albums.get_albums(limit=limit, offset=offset)
async def get_album(self, album_id: str) -> Optional[Album]:
return await self.dao.albums.get_album(album_id)
async def update_album(self, album_id: str, name: Optional[str] = None, description: Optional[str] = None) -> Optional[Album]:
album = await self.dao.albums.get_album(album_id)
if not album:
return None
if name:
album.name = name
if description is not None:
album.description = description
await self.dao.albums.update_album(album_id, album)
return album
async def delete_album(self, album_id: str) -> bool:
return await self.dao.albums.delete_album(album_id)
async def add_generation_to_album(self, album_id: str, generation_id: str) -> bool:
# Verify album exists
album = await self.dao.albums.get_album(album_id)
if not album:
return False
# Verify generation exists (optional but good practice)
gen = await self.dao.generations.get_generation(generation_id)
if not gen:
return False
if album.cover_asset_id is None and gen.status == 'done':
album.cover_asset_id = gen.result_list[0]
return await self.dao.albums.add_generation(album_id, generation_id, album.cover_asset_id)
async def remove_generation_from_album(self, album_id: str, generation_id: str) -> bool:
return await self.dao.albums.remove_generation(album_id, generation_id)
async def get_generations_by_album(self, album_id: str, limit: int = 10, offset: int = 0) -> List[Generation]:
album = await self.dao.albums.get_album(album_id)
if not album or not album.generation_ids:
return []
# Slice the generation IDs (simple pagination on ID list)
# Note: This pagination is on IDs, then we fetch objects.
# Ideally, fetch only slice.
# Reverse to show newest first? Or just follow list order?
# Assuming list order is insertion order (which usually is what we want for manual sorting or chronological if always appended).
# Let's assume user wants same order as in list.
sliced_ids = album.generation_ids[offset : offset + limit]
if not sliced_ids:
return []
# Fetch generations by IDs
# We need a method in GenerationRepo to fetch by IDs.
# Currently we only have get_generations with filters.
# We can add get_generations_by_ids to GenerationRepo or use loop (inefficient).
# Let's add get_generations_by_ids to GenerationRepo.
# For now, I will use a loop if I can't modify Repo immediately,
# but I SHOULD modify GenerationRepo.
# Or I can use get_generations(filter={"_id": {"$in": [ObjectId(id) for id in sliced_ids]}})
# But get_generations doesn't support generic filter passing.
# I'll update GenerationRepo to support fetching by IDs.
return await self.dao.generations.get_generations_by_ids(sliced_ids)

View File

@@ -1,15 +1,19 @@
import asyncio import asyncio
import logging import logging
import random import random
import base64
from datetime import datetime, UTC from datetime import datetime, UTC
from typing import List, Optional, Tuple, Any, Dict from typing import List, Optional, Tuple, Any, Dict
from io import BytesIO from io import BytesIO
import httpx
from aiogram import Bot from aiogram import Bot
from aiogram.types import BufferedInputFile from aiogram.types import BufferedInputFile
from adapters.Exception import GoogleGenerationException from adapters.Exception import GoogleGenerationException
from adapters.google_adapter import GoogleAdapter from adapters.google_adapter import GoogleAdapter
from adapters.kling_adapter import KlingAdapter, KlingApiException
from api.models.GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse from api.models.GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse
from api.models.VideoGenerationRequest import VideoGenerationRequest
# Импортируйте ваши модели DAO, Asset, Generation корректно # Импортируйте ваши модели DAO, Asset, Generation корректно
from models.Asset import Asset, AssetType, AssetContentType from models.Asset import Asset, AssetType, AssetContentType
from models.Generation import Generation, GenerationStatus from models.Generation import Generation, GenerationStatus
@@ -62,11 +66,12 @@ async def generate_image_task(
return images_bytes, metrics return images_bytes, metrics
class GenerationService: class GenerationService:
def __init__(self, dao: DAO, gemini: GoogleAdapter, s3_adapter: S3Adapter, bot: Optional[Bot] = None): def __init__(self, dao: DAO, gemini: GoogleAdapter, s3_adapter: S3Adapter, bot: Optional[Bot] = None, kling_adapter: Optional[KlingAdapter] = None):
self.dao = dao self.dao = dao
self.gemini = gemini self.gemini = gemini
self.s3_adapter = s3_adapter self.s3_adapter = s3_adapter
self.bot = bot self.bot = bot
self.kling_adapter = kling_adapter
async def ask_prompt_assistant(self, prompt: str, assets: List[str] = None) -> str: async def ask_prompt_assistant(self, prompt: str, assets: List[str] = None) -> str:
@@ -92,10 +97,10 @@ class GenerationService:
return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, images_list=images) return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, images_list=images)
async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0) -> List[ async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[
Generation]: Generation]:
generations = await self.dao.generations.get_generations(character_id = character_id,limit=limit, offset=offset) generations = await self.dao.generations.get_generations(character_id = character_id,limit=limit, offset=offset, created_by=user_id, project_id=project_id)
total_count = await self.dao.generations.count_generations(character_id = character_id) total_count = await self.dao.generations.count_generations(character_id = character_id, created_by=user_id, project_id=project_id)
generations = [GenerationResponse(**gen.model_dump()) for gen in generations] generations = [GenerationResponse(**gen.model_dump()) for gen in generations]
return GenerationsResponse(generations=generations, total_count=total_count) return GenerationsResponse(generations=generations, total_count=total_count)
@@ -106,15 +111,18 @@ class GenerationService:
else: else:
return GenerationResponse(**gen.model_dump()) return GenerationResponse(**gen.model_dump())
async def get_running_generations(self) -> List[Generation]: async def get_running_generations(self, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING) return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING, created_by=user_id, project_id=project_id)
async def create_generation_task(self, generation_request: GenerationRequest) -> GenerationResponse: async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None) -> GenerationResponse:
gen_id = None gen_id = None
generation_model = None generation_model = None
try: try:
generation_model = Generation(**generation_request.model_dump()) generation_model = Generation(**generation_request.model_dump())
if user_id:
generation_model.created_by = user_id
gen_id = await self.dao.generations.create_generation(generation_model) gen_id = await self.dao.generations.create_generation(generation_model)
generation_model.id = gen_id generation_model.id = gen_id
@@ -155,22 +163,25 @@ class GenerationService:
# 2. Получаем ассеты-референсы (если они есть) # 2. Получаем ассеты-референсы (если они есть)
reference_assets: List[Asset] = [] reference_assets: List[Asset] = []
media_group_bytes: List[bytes] = [] media_group_bytes: List[bytes] = []
generation_prompt = f""" generation_prompt = generation.prompt
# generation_prompt = f"""
Create detailed image of character in scene. # Create detailed image of character in scene.
SCENE DESCRIPTION: {generation.prompt} # SCENE DESCRIPTION: {generation.prompt}
Rules: # Rules:
- Integrate the character's appearance naturally into the scene description. # - Integrate the character's appearance naturally into the scene description.
- Focus on lighting, texture, and composition. # - Focus on lighting, texture, and composition.
""" # """
if generation.linked_character_id is not None: if generation.linked_character_id is not None:
char_info = await self.dao.chars.get_character(generation.linked_character_id, with_image_data=True) char_info = await self.dao.chars.get_character(generation.linked_character_id)
if char_info is None: if char_info is None:
raise Exception(f"Character ID {generation.linked_character_id} not found") raise Exception(f"Character ID {generation.linked_character_id} not found")
if generation.use_profile_image: if generation.use_profile_image:
media_group_bytes.append(char_info.character_image_data) avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
if avatar_asset:
media_group_bytes.append(avatar_asset.data)
# generation_prompt = generation_prompt.replace("$char_bio_inserted", f"1. CHARACTER BIO (Must be strictly followed): {char_info.character_bio}") # generation_prompt = generation_prompt.replace("$char_bio_inserted", f"1. CHARACTER BIO (Must be strictly followed): {char_info.character_bio}")
reference_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list) reference_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
@@ -258,7 +269,9 @@ class GenerationService:
data=None, # Not storing bytes in DB anymore data=None, # Not storing bytes in DB anymore
minio_object_name=filename, minio_object_name=filename,
minio_bucket=self.s3_adapter.bucket_name, minio_bucket=self.s3_adapter.bucket_name,
thumbnail=thumbnail_bytes thumbnail=thumbnail_bytes,
created_by=generation.created_by,
project_id=generation.project_id
) )
# Сохраняем в БД # Сохраняем в БД
@@ -325,6 +338,261 @@ class GenerationService:
logger.error(f"Error in progress simulation: {e}") logger.error(f"Error in progress simulation: {e}")
async def import_external_generation(self, external_gen) -> Generation:
"""
Import a generation from an external source.
Args:
external_gen: ExternalGenerationRequest with generation data and image
Returns:
Created Generation object
"""
from api.models.ExternalGenerationDTO import ExternalGenerationRequest
# Validate image source
external_gen.validate_image_source()
logger.info(f"Importing external generation for user: {external_gen.created_by}")
# 1. Process image (download or decode)
image_bytes = None
if external_gen.image_url:
# Download image from URL
logger.info(f"Downloading image from URL: {external_gen.image_url}")
async with httpx.AsyncClient() as client:
response = await client.get(external_gen.image_url, timeout=30.0)
response.raise_for_status()
image_bytes = response.content
elif external_gen.image_data:
# Decode base64 image
logger.info("Decoding base64 image data")
image_bytes = base64.b64decode(external_gen.image_data)
if not image_bytes:
raise ValueError("Failed to process image data")
# 2. Generate thumbnail
from utils.image_utils import create_thumbnail
thumbnail_bytes = await asyncio.to_thread(create_thumbnail, image_bytes)
# 3. Save to S3
filename = f"external/{external_gen.created_by}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.png"
await self.s3_adapter.upload_file(filename, image_bytes, content_type="image/png")
# 4. Create Asset
new_asset = Asset(
name=f"External_Generated_{external_gen.linked_character_id or 'no_char'}",
type=AssetType.GENERATED,
content_type=AssetContentType.IMAGE,
linked_char_id=external_gen.linked_character_id,
data=None, # Not storing bytes in DB
minio_object_name=filename,
minio_bucket=self.s3_adapter.bucket_name,
thumbnail=thumbnail_bytes,
created_by=external_gen.created_by,
project_id=external_gen.project_id
)
asset_id = await self.dao.assets.create_asset(new_asset)
new_asset.id = str(asset_id)
logger.info(f"Created asset {asset_id} for external generation")
# 5. Create Generation record
generation = Generation(
status=GenerationStatus.DONE,
linked_character_id=external_gen.linked_character_id,
aspect_ratio=external_gen.aspect_ratio,
quality=external_gen.quality,
prompt=external_gen.prompt,
tech_prompt=external_gen.tech_prompt,
result_list=[new_asset.id],
result=new_asset.id,
progress=100,
execution_time_seconds=external_gen.execution_time_seconds,
api_execution_time_seconds=external_gen.api_execution_time_seconds,
token_usage=external_gen.token_usage,
input_token_usage=external_gen.input_token_usage,
output_token_usage=external_gen.output_token_usage,
created_by=external_gen.created_by,
project_id=external_gen.project_id,
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC)
)
gen_id = await self.dao.generations.create_generation(generation)
generation.id = gen_id
logger.info(f"Created generation {gen_id} from external source")
return generation
# === VIDEO GENERATION (Kling) ===
async def create_video_generation_task(self, request: VideoGenerationRequest, user_id: Optional[str] = None) -> GenerationResponse:
"""Create a video generation task (async, returns immediately)."""
if not self.kling_adapter:
raise Exception("Kling adapter is not configured")
generation = Generation(
status=GenerationStatus.RUNNING,
gen_type=GenType.VIDEO,
linked_character_id=request.linked_character_id,
aspect_ratio=AspectRatios.SIXTEENNINE, # default for video
quality=Quality.ONEK,
prompt=request.prompt,
assets_list=[request.image_asset_id],
video_duration=request.duration,
video_mode=request.mode,
project_id=request.project_id,
)
if user_id:
generation.created_by = user_id
gen_id = await self.dao.generations.create_generation(generation)
generation.id = gen_id
async def runner(gen, req):
logger.info(f"Starting background video generation task for ID: {gen.id}")
try:
await self.create_video_generation(gen, req)
logger.info(f"Background video generation task finished for ID: {gen.id}")
except Exception:
try:
db_gen = await self.dao.generations.get_generation(gen.id)
if db_gen and db_gen.status != GenerationStatus.FAILED:
db_gen.status = GenerationStatus.FAILED
db_gen.updated_at = datetime.now(UTC)
await self.dao.generations.update_generation(db_gen)
except Exception:
logger.exception("Failed to mark video generation as FAILED")
logger.exception("create_video_generation task failed")
asyncio.create_task(runner(generation, request))
return GenerationResponse(**generation.model_dump())
async def create_video_generation(self, generation: Generation, request: VideoGenerationRequest):
"""Background video generation: call Kling API, poll, download result, save asset."""
start_time = datetime.now()
try:
# 1. Get source image presigned URL
asset = await self.dao.assets.get_asset(request.image_asset_id)
if not asset:
raise Exception(f"Asset {request.image_asset_id} not found")
if not asset.minio_object_name:
raise Exception(f"Asset {request.image_asset_id} has no S3 object")
presigned_url = await self.s3_adapter.get_presigned_url(asset.minio_object_name, expiration=3600)
if not presigned_url:
raise Exception("Failed to generate presigned URL for source image")
logger.info(f"Video gen {generation.id}: got presigned URL for asset {request.image_asset_id}")
# 2. Create Kling task
task_data = await self.kling_adapter.create_video_task(
image_url=presigned_url,
prompt=request.prompt,
negative_prompt=request.negative_prompt or "",
model_name=request.model_name,
duration=request.duration,
mode=request.mode,
cfg_scale=request.cfg_scale,
aspect_ratio=request.aspect_ratio,
)
task_id = task_data.get("task_id")
generation.kling_task_id = task_id
await self.dao.generations.update_generation(generation)
logger.info(f"Video gen {generation.id}: Kling task created, task_id={task_id}")
# 3. Poll for completion with progress updates
async def progress_callback(progress_pct: int):
generation.progress = progress_pct
generation.updated_at = datetime.now(UTC)
await self.dao.generations.update_generation(generation)
result = await self.kling_adapter.wait_for_completion(
task_id=task_id,
poll_interval=10,
timeout=600,
progress_callback=progress_callback,
)
# 4. Extract video URL and download
works = result.get("task_result", {}).get("videos", [])
if not works:
raise Exception("No video in Kling result")
video_url = works[0].get("url")
video_duration = works[0].get("duration", request.duration)
if not video_url:
raise Exception("No video URL in Kling result")
logger.info(f"Video gen {generation.id}: downloading video from {video_url}")
async with httpx.AsyncClient(timeout=120.0) as client:
video_response = await client.get(video_url)
video_response.raise_for_status()
video_bytes = video_response.content
logger.info(f"Video gen {generation.id}: downloaded {len(video_bytes)} bytes")
# 5. Upload to S3
filename = f"generated_video/{generation.linked_character_id or 'no_char'}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.mp4"
await self.s3_adapter.upload_file(filename, video_bytes, content_type="video/mp4")
# 6. Create Asset
new_asset = Asset(
name=f"Video_{generation.linked_character_id or 'gen'}",
type=AssetType.GENERATED,
content_type=AssetContentType.VIDEO,
linked_char_id=generation.linked_character_id,
data=None,
minio_object_name=filename,
minio_bucket=self.s3_adapter.bucket_name,
thumbnail=None, # видео thumbnails можно добавить позже
created_by=generation.created_by,
project_id=generation.project_id,
)
asset_id = await self.dao.assets.create_asset(new_asset)
new_asset.id = str(asset_id)
# 7. Finalize generation
end_time = datetime.now()
generation.result_list = [new_asset.id]
generation.result = new_asset.id
generation.status = GenerationStatus.DONE
generation.progress = 100
generation.video_duration = video_duration
generation.execution_time_seconds = (end_time - start_time).total_seconds()
generation.updated_at = datetime.now(UTC)
await self.dao.generations.update_generation(generation)
logger.info(f"Video generation {generation.id} completed. Asset: {new_asset.id}, Time: {generation.execution_time_seconds:.1f}s")
except KlingApiException as e:
logger.error(f"Kling API error for generation {generation.id}: {e}")
generation.status = GenerationStatus.FAILED
generation.failed_reason = str(e)
generation.updated_at = datetime.now(UTC)
await self.dao.generations.update_generation(generation)
raise
except Exception as e:
logger.error(f"Video generation {generation.id} failed: {e}")
generation.status = GenerationStatus.FAILED
generation.failed_reason = str(e)
generation.updated_at = datetime.now(UTC)
await self.dao.generations.update_generation(generation)
raise
async def delete_generation(self, generation_id: str) -> bool: async def delete_generation(self, generation_id: str) -> bool:
""" """
Soft delete generation by marking it as deleted. Soft delete generation by marking it as deleted.

12
models/Album.py Normal file
View File

@@ -0,0 +1,12 @@
from datetime import datetime, UTC
from typing import Optional, List
from pydantic import BaseModel, Field
class Album(BaseModel):
id: Optional[str] = None
name: str
description: Optional[str] = None
cover_asset_id: Optional[str] = None
generation_ids: List[str] = []
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))

View File

@@ -7,6 +7,7 @@ from pydantic import BaseModel, computed_field, Field, model_validator
class AssetContentType(str, Enum): class AssetContentType(str, Enum):
IMAGE = 'image' IMAGE = 'image'
VIDEO = 'video'
PROMPT = 'prompt' PROMPT = 'prompt'
class AssetType(str, Enum): class AssetType(str, Enum):
@@ -28,6 +29,8 @@ class Asset(BaseModel):
minio_thumbnail_object_name: Optional[str] = None minio_thumbnail_object_name: Optional[str] = None
thumbnail: Optional[bytes] = None thumbnail: Optional[bytes] = None
tags: List[str] = [] tags: List[str] = []
created_by: Optional[str] = None
project_id: Optional[str] = None
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))

View File

@@ -5,11 +5,13 @@ from pydantic_core.core_schema import computed_field
class Character(BaseModel): class Character(BaseModel):
id: str | None id: Optional[str] = None
name: str name: str
avatar_asset_id: Optional[str] = None
avatar_image: Optional[str] = None avatar_image: Optional[str] = None
character_image_data: Optional[bytes] = None character_image_data: Optional[bytes] = None
character_image_doc_tg_id: str character_image_doc_tg_id: Optional[str] = None
character_image_tg_id: str | None character_image_tg_id: Optional[str] = None
character_bio: str character_bio: Optional[str] = None
created_by: Optional[str] = None
project_id: Optional[str] = None

View File

@@ -2,7 +2,7 @@ from datetime import datetime, UTC
from enum import Enum from enum import Enum
from typing import List, Optional from typing import List, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, computed_field
from models.Asset import Asset from models.Asset import Asset
from models.enums import AspectRatios, Quality, GenType from models.enums import AspectRatios, Quality, GenType
@@ -16,6 +16,7 @@ class GenerationStatus(str, Enum):
class Generation(BaseModel): class Generation(BaseModel):
id: Optional[str] = None id: Optional[str] = None
status: GenerationStatus = GenerationStatus.RUNNING status: GenerationStatus = GenerationStatus.RUNNING
gen_type: GenType = GenType.IMAGE
failed_reason: Optional[str] = None failed_reason: Optional[str] = None
linked_character_id: Optional[str] = None linked_character_id: Optional[str] = None
telegram_id: Optional[int] = None telegram_id: Optional[int] = None
@@ -34,5 +35,20 @@ class Generation(BaseModel):
input_token_usage: Optional[int] = None input_token_usage: Optional[int] = None
output_token_usage: Optional[int] = None output_token_usage: Optional[int] = None
is_deleted: bool = False is_deleted: bool = False
album_id: Optional[str] = None
created_by: Optional[str] = None # Stores User ID (Telegram ID or Web User ObjectId)
project_id: Optional[str] = None
# Video-specific fields
kling_task_id: Optional[str] = None
video_duration: Optional[int] = None # 5 or 10 seconds
video_mode: Optional[str] = None # "std" or "pro"
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
@computed_field
def cost(self) -> float:
if self.status == GenerationStatus.DONE and self.input_token_usage and self.output_token_usage:
cost_input = self.input_token_usage * 0.000002
cost_output = self.output_token_usage * 0.00012
return round(cost_input + cost_output, 3)
return 0.0

12
models/Project.py Normal file
View File

@@ -0,0 +1,12 @@
from datetime import datetime
from typing import List, Optional
from pydantic import BaseModel, Field
class Project(BaseModel):
id: Optional[str] = None
name: str
description: Optional[str] = None
owner_id: str
members: List[str] = [] # List of User IDs
is_deleted: bool = False
created_at: datetime = Field(default_factory=datetime.now)

View File

@@ -34,10 +34,12 @@ class Quality(str, Enum):
class GenType(str, Enum): class GenType(str, Enum):
TEXT = 'Text' TEXT = 'Text'
IMAGE = 'Image' IMAGE = 'Image'
VIDEO = 'Video'
@property @property
def value_type(self) -> str: def value_type(self) -> str:
return { return {
GenType.TEXT: 'Text', GenType.TEXT: 'Text',
GenType.IMAGE: 'Image', GenType.IMAGE: 'Image',
GenType.VIDEO: 'Video',
}[self] }[self]

61
repos/albums_repo.py Normal file
View File

@@ -0,0 +1,61 @@
from typing import List, Optional
import logging
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient
from models.Album import Album
logger = logging.getLogger(__name__)
class AlbumsRepo:
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
self.collection = client[db_name]["albums"]
async def create_album(self, album: Album) -> str:
res = await self.collection.insert_one(album.model_dump())
return str(res.inserted_id)
async def get_album(self, album_id: str) -> Optional[Album]:
try:
res = await self.collection.find_one({"_id": ObjectId(album_id)})
if not res:
return None
res["id"] = str(res.pop("_id"))
return Album(**res)
except Exception:
return None
async def get_albums(self, limit: int = 10, offset: int = 0) -> List[Album]:
res = await self.collection.find().sort("created_at", -1).skip(offset).limit(limit).to_list(None)
albums = []
for doc in res:
doc["id"] = str(doc.pop("_id"))
albums.append(Album(**doc))
return albums
async def update_album(self, album_id: str, album: Album) -> bool:
if not album.id:
album.id = album_id
model_dump = album.model_dump()
res = await self.collection.update_one({"_id": ObjectId(album_id)}, {"$set": model_dump})
return res.modified_count > 0
async def delete_album(self, album_id: str) -> bool:
res = await self.collection.delete_one({"_id": ObjectId(album_id)})
return res.deleted_count > 0
async def add_generation(self, album_id: str, generation_id: str, cover_asset_id: Optional[str] = None) -> bool:
res = await self.collection.update_one(
{"_id": ObjectId(album_id)},
{"$addToSet": {"generation_ids": generation_id}, "$set": {"cover_asset_id": cover_asset_id}}
)
return res.modified_count > 0
async def remove_generation(self, album_id: str, generation_id: str) -> bool:
res = await self.collection.update_one(
{"_id": ObjectId(album_id)},
{"$pull": {"generation_ids": generation_id}}
)
return res.modified_count > 0

View File

@@ -46,7 +46,7 @@ class AssetsRepo:
res = await self.collection.insert_one(asset.model_dump()) res = await self.collection.insert_one(asset.model_dump())
return str(res.inserted_id) return str(res.inserted_id)
async def get_assets(self, asset_type: Optional[str] = None, limit: int = 10, offset: int = 0, with_data: bool = False) -> List[Asset]: async def get_assets(self, asset_type: Optional[str] = None, limit: int = 10, offset: int = 0, with_data: bool = False, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Asset]:
filter = {} filter = {}
if asset_type: if asset_type:
filter["type"] = asset_type filter["type"] = asset_type
@@ -70,6 +70,12 @@ class AssetsRepo:
# if not with_data: args["data"] = 0; args["thumbnail"] = 0 # if not with_data: args["data"] = 0; args["thumbnail"] = 0
# So list DOES NOT return thumbnails by default. # So list DOES NOT return thumbnails by default.
args["thumbnail"] = 0 args["thumbnail"] = 0
if created_by:
filter["created_by"] = created_by
filter['project_id'] = None
if project_id:
filter["project_id"] = project_id
res = await self.collection.find(filter, args).sort("created_at", -1).skip(offset).limit(limit).to_list(None) res = await self.collection.find(filter, args).sort("created_at", -1).skip(offset).limit(limit).to_list(None)
assets = [] assets = []
@@ -157,8 +163,15 @@ class AssetsRepo:
assets.append(Asset(**doc)) assets.append(Asset(**doc))
return assets return assets
async def get_asset_count(self, character_id: Optional[str] = None) -> int: async def get_asset_count(self, character_id: Optional[str] = None, created_by: Optional[str] = None, project_id: Optional[str] = None) -> int:
return await self.collection.count_documents({"linked_char_id": character_id} if character_id else {}) filter = {}
if character_id:
filter["linked_char_id"] = character_id
if created_by:
filter["created_by"] = created_by
if project_id:
filter["project_id"] = project_id
return await self.collection.count_documents(filter)
async def get_assets_by_ids(self, asset_ids: List[str]) -> List[Asset]: async def get_assets_by_ids(self, asset_ids: List[str]) -> List[Asset]:
object_ids = [ObjectId(asset_id) for asset_id in asset_ids] object_ids = [ObjectId(asset_id) for asset_id in asset_ids]

View File

@@ -1,4 +1,4 @@
from typing import List from typing import List, Optional
from bson import ObjectId from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient from motor.motor_asyncio import AsyncIOMotorClient
@@ -12,7 +12,7 @@ class CharacterRepo:
async def add_character(self, character: Character) -> Character: async def add_character(self, character: Character) -> Character:
op = await self.collection.insert_one(character.model_dump()) op = await self.collection.insert_one(character.model_dump())
character.id = op.inserted_id character.id = str(op.inserted_id)
return character return character
async def get_character(self, character_id: str, with_image_data: bool = False) -> Character | None: async def get_character(self, character_id: str, with_image_data: bool = False) -> Character | None:
@@ -26,18 +26,25 @@ class CharacterRepo:
res["id"] = str(res.pop("_id")) res["id"] = str(res.pop("_id"))
return Character(**res) return Character(**res)
async def get_all_characters(self) -> List[Character]: async def get_all_characters(self, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Character]:
docs = await self.collection.find({}, {"character_image_data": 0}).to_list(None) filter = {}
if created_by:
filter["created_by"] = created_by
if project_id:
filter["project_id"] = project_id
characters = [] args = {"character_image_data": 0} # don't return image data for list
for doc in docs: res = await self.collection.find(filter, args).to_list(None)
# Конвертируем ObjectId в строку и кладем в поле id chars = []
for doc in res:
doc["id"] = str(doc.pop("_id")) doc["id"] = str(doc.pop("_id"))
chars.append(Character(**doc))
return chars
# Создаем объект async def update_char(self, char_id: str, character: Character) -> bool:
characters.append(Character(**doc)) result = await self.collection.update_one({"_id": ObjectId(char_id)}, {"$set": character.model_dump()})
return result.modified_count > 0
return characters async def delete_character(self, char_id: str) -> bool:
result = await self.collection.delete_one({"_id": ObjectId(char_id)})
async def update_char(self, char_id: str, character: Character) -> None: return result.deleted_count > 0
await self.collection.update_one({"_id": ObjectId(char_id)}, {"$set": character.model_dump()})

View File

@@ -4,6 +4,8 @@ from repos.assets_repo import AssetsRepo
from repos.char_repo import CharacterRepo from repos.char_repo import CharacterRepo
from repos.generation_repo import GenerationRepo from repos.generation_repo import GenerationRepo
from repos.user_repo import UsersRepo from repos.user_repo import UsersRepo
from repos.albums_repo import AlbumsRepo
from repos.project_repo import ProjectRepo
from typing import Optional from typing import Optional
@@ -14,3 +16,6 @@ class DAO:
self.chars = CharacterRepo(client, db_name) self.chars = CharacterRepo(client, db_name)
self.assets = AssetsRepo(client, s3_adapter, db_name) self.assets = AssetsRepo(client, s3_adapter, db_name)
self.generations = GenerationRepo(client, db_name) self.generations = GenerationRepo(client, db_name)
self.albums = AlbumsRepo(client, db_name)
self.projects = ProjectRepo(client, db_name)
self.users = UsersRepo(client, db_name)

View File

@@ -25,13 +25,19 @@ class GenerationRepo:
return Generation(**res) return Generation(**res)
async def get_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None, async def get_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
limit: int = 10, offset: int = 10) -> List[Generation]: limit: int = 10, offset: int = 10, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
filter = {"is_deleted": False} filter = {"is_deleted": False}
if character_id is not None: if character_id is not None:
filter["linked_character_id"] = character_id filter["linked_character_id"] = character_id
if status is not None: if status is not None:
filter["status"] = status filter["status"] = status
if created_by is not None:
filter["created_by"] = created_by
filter["project_id"] = None
if project_id is not None:
filter["project_id"] = project_id
res = await self.collection.find(filter).sort("created_at", -1).skip( res = await self.collection.find(filter).sort("created_at", -1).skip(
offset).limit(limit).to_list(None) offset).limit(limit).to_list(None)
generations: List[Generation] = [] generations: List[Generation] = []
@@ -40,13 +46,34 @@ class GenerationRepo:
generations.append(Generation(**generation)) generations.append(Generation(**generation))
return generations return generations
async def count_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None) -> int: async def count_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
album_id: Optional[str] = None, created_by: Optional[str] = None, project_id: Optional[str] = None) -> int:
args = {} args = {}
if character_id is not None: if character_id is not None:
args["linked_character_id"] = character_id args["linked_character_id"] = character_id
if status is not None: if status is not None:
args["status"] = status args["status"] = status
if created_by is not None:
args["created_by"] = created_by
if project_id is not None:
args["project_id"] = project_id
return await self.collection.count_documents(args) return await self.collection.count_documents(args)
async def get_generations_by_ids(self, generation_ids: List[str]) -> List[Generation]:
object_ids = [ObjectId(gen_id) for gen_id in generation_ids if ObjectId.is_valid(gen_id)]
res = await self.collection.find({"_id": {"$in": object_ids}}).to_list(None)
generations: List[Generation] = []
# Maintain order of generation_ids
gen_map = {str(doc["_id"]): doc for doc in res}
for gen_id in generation_ids:
doc = gen_map.get(gen_id)
if doc:
doc["id"] = str(doc.pop("_id"))
generations.append(Generation(**doc))
return generations
async def update_generation(self, generation: Generation, ): async def update_generation(self, generation: Generation, ):
res = await self.collection.update_one({"_id": ObjectId(generation.id)}, {"$set": generation.model_dump()}) res = await self.collection.update_one({"_id": ObjectId(generation.id)}, {"$set": generation.model_dump()})

62
repos/project_repo.py Normal file
View File

@@ -0,0 +1,62 @@
from typing import List, Optional
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient
from models.Project import Project
class ProjectRepo:
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
self.collection = client[db_name]["projects"]
async def create_project(self, project: Project) -> str:
res = await self.collection.insert_one(project.model_dump())
return str(res.inserted_id)
async def get_project(self, project_id: str) -> Optional[Project]:
if not ObjectId.is_valid(project_id):
return None
res = await self.collection.find_one({"_id": ObjectId(project_id)})
if res:
res["id"] = str(res.pop("_id"))
return Project(**res)
return None
async def get_projects_by_user(self, user_id: str) -> List[Project]:
# Find projects where user is owner OR in members
filter = {
"$or": [
{"owner_id": user_id},
{"members": user_id}
],
"is_deleted": False
}
cursor = self.collection.find(filter).sort("created_at", -1)
projects = []
async for doc in cursor:
doc["id"] = str(doc.pop("_id"))
projects.append(Project(**doc))
return projects
async def add_member(self, project_id: str, user_id: str) -> bool:
res = await self.collection.update_one(
{"_id": ObjectId(project_id)},
{"$addToSet": {"members": user_id}}
)
return res.modified_count > 0
async def remove_member(self, project_id: str, user_id: str) -> bool:
res = await self.collection.update_one(
{"_id": ObjectId(project_id)},
{"$pull": {"members": user_id}}
)
return res.modified_count > 0
async def update_project(self, project_id: str, updates: dict) -> bool:
res = await self.collection.update_one(
{"_id": ObjectId(project_id)},
{"$set": updates}
)
return res.modified_count > 0
async def delete_project(self, project_id: str) -> bool:
res = await self.collection.update_one({"_id": ObjectId(project_id)}, {"$set": {"is_deleted": True}})
return res.modified_count > 0

View File

@@ -19,10 +19,16 @@ class UsersRepo:
self.collection = client[db_name]["users"] self.collection = client[db_name]["users"]
async def get_user(self, user_id: int): async def get_user(self, user_id: int):
return await self.collection.find_one({"user_id": user_id}) user = await self.collection.find_one({"user_id": user_id})
if user:
user["id"] = str(user["_id"])
return user
async def get_user_by_username(self, username: str): async def get_user_by_username(self, username: str):
return await self.collection.find_one({"username": username}) user = await self.collection.find_one({"username": username})
if user:
user["id"] = str(user["_id"])
return user
async def create_user(self, username: str, password: str, full_name: Optional[str] = None): async def create_user(self, username: str, password: str, full_name: Optional[str] = None):
"""Создает нового пользователя с username/паролем""" """Создает нового пользователя с username/паролем"""
@@ -38,15 +44,23 @@ class UsersRepo:
"created_at": datetime.now(), "created_at": datetime.now(),
"is_email_user": False, # Теперь это просто "обычный" юзер, не телеграм (хотя поле можно переименовать) "is_email_user": False, # Теперь это просто "обычный" юзер, не телеграм (хотя поле можно переименовать)
"is_web_user": True, "is_web_user": True,
"is_admin": False "is_admin": False,
"project_ids": [],
"current_project_id": None
} }
result = await self.collection.insert_one(user_doc) result = await self.collection.insert_one(user_doc)
return await self.collection.find_one({"_id": result.inserted_id}) user = await self.collection.find_one({"_id": result.inserted_id})
if user:
user["id"] = str(user["_id"])
return user
async def get_pending_users(self): async def get_pending_users(self):
"""Возвращает список пользователей со статусом PENDING""" """Возвращает список пользователей со статусом PENDING"""
cursor = self.collection.find({"status": UserStatus.PENDING}) cursor = self.collection.find({"status": UserStatus.PENDING})
return await cursor.to_list(length=100) users = await cursor.to_list(length=100)
for user in users:
user["id"] = str(user["_id"])
return users
async def approve_user(self, username: str): async def approve_user(self, username: str):
await self.collection.update_one( await self.collection.update_one(

View File

@@ -50,3 +50,5 @@ passlib[argon2]==1.7.4
python-jose[cryptography]==3.3.0 python-jose[cryptography]==3.3.0
python-multipart==0.0.22 python-multipart==0.0.22
email-validator email-validator
prometheus-fastapi-instrumentator
PyJWT

View File

@@ -63,7 +63,8 @@ async def new_char_bio(message: Message, state: FSMContext, dao: DAO, bot: Bot):
character_image_data=file_io.read(), character_image_data=file_io.read(),
character_image_tg_id=None, character_image_tg_id=None,
character_image_doc_tg_id=file_id, character_image_doc_tg_id=file_id,
character_bio=bio character_bio=bio,
created_by=str(message.from_user.id)
) )
file_io.close() file_io.close()

View File

@@ -51,7 +51,7 @@ async def generate_image_cmd(message: Message, state: FSMContext, dao: DAO, gemi
await wait_msg.delete() await wait_msg.delete()
doc = await message.answer_document(res[0], caption="Generated result 💫") doc = await message.answer_document(res[0], caption="Generated result 💫")
await dao.assets.create_asset(Asset(id=None, name=res[0].filename, type=AssetType.GENERATED, content_type=AssetContentType.IMAGE, data=res[0].data, await dao.assets.create_asset(Asset(id=None, name=res[0].filename, type=AssetType.GENERATED, content_type=AssetContentType.IMAGE, data=res[0].data,
tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None, linked_char_id=None)) tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None, linked_char_id=None, created_by=str(message.from_user.id)))
@router.message(Command("gen_mode")) @router.message(Command("gen_mode"))
@@ -259,7 +259,8 @@ async def handle_album(
doc = await message.answer_document(file, caption="✨ Generated result") doc = await message.answer_document(file, caption="✨ Generated result")
await dao.assets.create_asset(Asset(id=None, name=file.filename, type=AssetType.GENERATED, content_type=AssetContentType.IMAGE, data=file.data, await dao.assets.create_asset(Asset(id=None, name=file.filename, type=AssetType.GENERATED, content_type=AssetContentType.IMAGE, data=file.data,
tg_doc_file_id = doc.document.file_id, tg_photo_file_id = None, tg_doc_file_id = doc.document.file_id, tg_photo_file_id = None,
linked_char_id = data["char_id"])) linked_char_id = data["char_id"],
created_by=str(message.from_user.id)))
else: else:
await message.answer("❌ Генерация не вернула изображений.") await message.answer("❌ Генерация не вернула изображений.")
await gen_mode_base_msg(message=message, state=state, dao=dao, call_type="start") await gen_mode_base_msg(message=message, state=state, dao=dao, call_type="start")
@@ -314,7 +315,8 @@ async def gen_mode_start(
doc = await message.answer_document(file, caption="✨ Generated result") doc = await message.answer_document(file, caption="✨ Generated result")
await dao.assets.create_asset(Asset(id=None, name=file.filename, type=AssetType.GENERATED, content_type=AssetContentType.IMAGE, data=file.data, await dao.assets.create_asset(Asset(id=None, name=file.filename, type=AssetType.GENERATED, content_type=AssetContentType.IMAGE, data=file.data,
tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None, tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None,
linked_char_id=data["char_id"])) linked_char_id=data["char_id"],
created_by=str(message.from_user.id)))
else: else:
await message.answer("❌ Ничего не сгенерировалось.") await message.answer("❌ Ничего не сгенерировалось.")

View File

@@ -0,0 +1,101 @@
import pytest
from fastapi.testclient import TestClient
from unittest.mock import MagicMock
from motor.motor_asyncio import AsyncIOMotorClient
import os
import asyncio
from main import app
from api.endpoints.auth import get_current_user
from api.dependency import get_dao
from repos.dao import DAO
from models.Character import Character
# Config for test DB
MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://admin:super_secure_password@31.59.58.220:27017")
DB_NAME = "bot_db_test_chars"
# Mock User
MOCK_USER_ID = "507f1f77bcf86cd799439011"
MOCK_USER = {
"_id": MOCK_USER_ID,
"username": "testuser",
"is_admin": False,
"status": "allowed"
}
# Override get_current_user to bypass auth
def mock_get_current_user():
return MOCK_USER
app.dependency_overrides[get_current_user] = mock_get_current_user
# Setup Real DAO with Test DB
client_mongo = AsyncIOMotorClient(MONGO_HOST)
dao = DAO(client_mongo, db_name=DB_NAME)
def mock_get_dao():
return dao
app.dependency_overrides[get_dao] = mock_get_dao
client = TestClient(app)
@pytest.fixture(scope="module", autouse=True)
def setup_teardown():
# Setup: Ensure clean state
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(client_mongo[DB_NAME]["characters"].drop())
yield
# Teardown
loop.run_until_complete(client_mongo[DB_NAME]["characters"].drop())
loop.close()
def test_character_crud_flow():
# 1. Create Character
create_payload = {
"name": "Test Character",
"character_bio": "A bio for test character",
"character_image_doc_tg_id": "file_123",
"avatar_image": "http://example.com/avatar.jpg"
}
response = client.post("/api/characters/", json=create_payload)
assert response.status_code == 200, response.text
char_data = response.json()
assert char_data["name"] == create_payload["name"]
assert char_data["created_by"] == MOCK_USER_ID
char_id = char_data["id"]
assert char_id is not None
# 2. Get Character
response = client.get(f"/api/characters/{char_id}")
assert response.status_code == 200
assert response.json()["id"] == char_id
# 3. Update Character
update_payload = {
"name": "Updated Name",
"character_bio": "Updated bio"
}
response = client.put(f"/api/characters/{char_id}", json=update_payload)
assert response.status_code == 200
updated_data = response.json()
assert updated_data["name"] == "Updated Name"
assert updated_data["character_bio"] == "Updated bio"
# Verify update persistent
response = client.get(f"/api/characters/{char_id}")
assert response.json()["name"] == "Updated Name"
# 4. Delete Character
response = client.delete(f"/api/characters/{char_id}")
assert response.status_code == 204
# Verify deletion
response = client.get(f"/api/characters/{char_id}")
assert response.status_code == 404, "Deleted character should return 404"

View File

@@ -0,0 +1,64 @@
import os
import pytest
from fastapi.testclient import TestClient
from unittest.mock import MagicMock
# 1. Set Auth Bypass and Test Config
os.environ["DB_NAME"] = "bot_db_test_integration"
# We keep MONGO_HOST as is (it works in verified script)
# 2. Import app AFTER setting env
from main import app
from api.endpoints.auth import get_current_user
# 3. Override Auth
MOCK_USER_ID = "507f1f77bcf86cd799439011"
MOCK_USER = {
"_id": MOCK_USER_ID,
"username": "testuser",
"is_admin": False,
"status": "allowed",
"project_ids": []
}
def mock_get_current_user():
return MOCK_USER
app.dependency_overrides[get_current_user] = mock_get_current_user
client = TestClient(app)
def test_character_crud_lifecycle():
# 1. Create
create_payload = {
"name": "Integration Test Char",
"character_bio": "Testing with real app structure",
"character_image_doc_tg_id": "doc_123",
"avatar_image": "http://example.com/img.jpg"
}
response = client.post("/api/characters/", json=create_payload)
assert response.status_code == 200, response.text
char_data = response.json()
assert char_data["name"] == create_payload["name"]
char_id = char_data["id"]
# 2. Get
response = client.get(f"/api/characters/{char_id}")
assert response.status_code == 200
assert response.json()["id"] == char_id
# 3. Update
update_payload = {"name": "Updated Int Name"}
response = client.put(f"/api/characters/{char_id}", json=update_payload)
assert response.status_code == 200
assert response.json()["name"] == "Updated Int Name"
# 4. Delete
response = client.delete(f"/api/characters/{char_id}")
assert response.status_code == 204
# 5. Verify Delete
response = client.get(f"/api/characters/{char_id}")
assert response.status_code == 404

63
tests/test_external_import.py Executable file
View File

@@ -0,0 +1,63 @@
#!/usr/bin/env python3
"""
Test script for external generation import API.
This script demonstrates how to call the import endpoint with proper HMAC signature.
"""
import hmac
import hashlib
import json
import requests
import base64
import os
from dotenv import load_dotenv
load_dotenv()
# Configuration
API_URL = "http://localhost:8090/api/generations/import"
SECRET = os.getenv("EXTERNAL_API_SECRET", "your_super_secret_key_change_this_in_production")
# Sample generation data
generation_data = {
"prompt": "A beautiful sunset over mountains",
"tech_prompt": "High quality landscape photography",
"image_url": "https://picsum.photos/512/512", # Sample image URL
# OR use base64:
# "image_data": "base64_encoded_image_string_here",
"aspect_ratio": "9:16",
"quality": "1k",
"created_by": "external_user_123",
"execution_time_seconds": 5.2,
"token_usage": 1000,
"input_token_usage": 200,
"output_token_usage": 800
}
# Convert to JSON
body = json.dumps(generation_data).encode('utf-8')
# Compute HMAC signature
signature = hmac.new(
SECRET.encode('utf-8'),
body,
hashlib.sha256
).hexdigest()
# Make request
headers = {
"Content-Type": "application/json",
"X-Signature": signature
}
print(f"Sending request to {API_URL}")
print(f"Signature: {signature}")
try:
response = requests.post(API_URL, data=body, headers=headers)
print(f"\nStatus Code: {response.status_code}")
print(f"Response: {json.dumps(response.json(), indent=2)}")
except Exception as e:
print(f"Error: {e}")
if hasattr(e, 'response'):
print(f"Response text: {e.response.text}")

View File

@@ -0,0 +1,91 @@
import asyncio
import os
import sys
# Add project root to path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from motor.motor_asyncio import AsyncIOMotorClient
from repos.dao import DAO
from models.Album import Album
from models.Generation import Generation, GenerationStatus
from models.enums import AspectRatios, Quality
# Mock config
# Use the same host as aiws.py but different DB
MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://admin:super_secure_password@31.59.58.220:27017")
DB_NAME = "bot_db_test_albums"
async def test_albums():
print(f"🚀 Starting Album Manual Verification using {MONGO_HOST}...")
# Needs to run inside a loop from main
client = AsyncIOMotorClient(MONGO_HOST)
dao = DAO(client, db_name=DB_NAME)
try:
# 1. Clean up
await client[DB_NAME]["albums"].drop()
await client[DB_NAME]["generations"].drop()
print("✅ Cleaned up test database")
# 2. Create Album
album = Album(name="Test Album", description="A test album")
print("Creating album...")
album_id = await dao.albums.create_album(album)
print(f"✅ Created Album: {album_id}")
# 3. Create Generations
gen1 = Generation(prompt="Gen 1", aspect_ratio=AspectRatios.NINESIXTEEN, quality=Quality.ONEK)
gen2 = Generation(prompt="Gen 2", aspect_ratio=AspectRatios.NINESIXTEEN, quality=Quality.ONEK)
print("Creating generations...")
gen1_id = await dao.generations.create_generation(gen1)
gen2_id = await dao.generations.create_generation(gen2)
print(f"✅ Created Generations: {gen1_id}, {gen2_id}")
# 4. Add generations to album
print("Adding generations to album...")
await dao.albums.add_generation(album_id, gen1_id)
await dao.albums.add_generation(album_id, gen2_id)
print("✅ Added generations to album")
# 5. Fetch album and check generation_ids
album_fetched = await dao.albums.get_album(album_id)
assert album_fetched is not None
assert len(album_fetched.generation_ids) == 2
assert gen1_id in album_fetched.generation_ids
assert gen2_id in album_fetched.generation_ids
print("✅ Verified generations in album")
# 6. Fetch generations by IDs via GenerationRepo
generations = await dao.generations.get_generations_by_ids([gen1_id, gen2_id])
assert len(generations) == 2
# Ensure ID type match (str vs ObjectId handling in repo)
gen_ids_fetched = [g.id for g in generations]
assert gen1_id in gen_ids_fetched
assert gen2_id in gen_ids_fetched
print("✅ Verified fetching generations by IDs")
# 7. Remove generation
print("Removing generation...")
await dao.albums.remove_generation(album_id, gen1_id)
album_fetched = await dao.albums.get_album(album_id)
assert len(album_fetched.generation_ids) == 1
assert album_fetched.generation_ids[0] == gen2_id
print("✅ Verified removing generation from album")
print("🎉 Album Verification SUCCESS")
finally:
# Cleanup client
client.close()
if __name__ == "__main__":
from dotenv import load_dotenv
load_dotenv()
try:
asyncio.run(test_albums())
except Exception as e:
print(f"Error: {e}")

46
utils/external_auth.py Normal file
View File

@@ -0,0 +1,46 @@
import hmac
import hashlib
import os
from fastapi import Header, HTTPException
from typing import Optional
def verify_signature(body: bytes, signature: str, secret: str) -> bool:
"""
Verify HMAC-SHA256 signature.
Args:
body: Raw request body bytes
signature: Signature from X-Signature header
secret: Shared secret key
Returns:
True if signature is valid, False otherwise
"""
expected_signature = hmac.new(
secret.encode('utf-8'),
body,
hashlib.sha256
).hexdigest()
return hmac.compare_digest(signature, expected_signature)
async def verify_external_signature(
x_signature: Optional[str] = Header(None, alias="X-Signature")
):
"""
FastAPI dependency to verify external API signature.
Raises:
HTTPException: If signature is missing or invalid
"""
if not x_signature:
raise HTTPException(
status_code=401,
detail="Missing X-Signature header"
)
# Note: We'll need to access the raw request body in the endpoint
# This dependency just validates the header exists
# Actual signature verification happens in the endpoint
return x_signature