Compare commits
19 Commits
8a89b27624
...
video
| Author | SHA1 | Date | |
|---|---|---|---|
| 32ff77e04b | |||
| d1f67c773f | |||
| c63b51ef75 | |||
| 456562ec1d | |||
| 0d0fbdf7d6 | |||
| f63bcedb13 | |||
| be92c766ac | |||
| 482bc1d9b7 | |||
| a2321cf070 | |||
| 29ccd5743e | |||
| d9de2f48d2 | |||
| 1ddeb0af46 | |||
| a7c2319f13 | |||
| 00e83b8561 | |||
| a9d24c725e | |||
| 458b6ebfc3 | |||
| 668aadcdc9 | |||
| 4461964791 | |||
| fa3e1bb05f |
19
.dockerignore
Normal file
19
.dockerignore
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
.git
|
||||||
|
.gitignore
|
||||||
|
__pycache__
|
||||||
|
*.pyc
|
||||||
|
*.pyo
|
||||||
|
*.pyd
|
||||||
|
.Python
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
.venv/
|
||||||
|
node_modules/
|
||||||
|
tmp/
|
||||||
|
logs/
|
||||||
|
*.log
|
||||||
|
dist/
|
||||||
|
build/
|
||||||
|
.cache/
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
3
.env
3
.env
@@ -8,3 +8,6 @@ MINIO_ACCESS_KEY=admin
|
|||||||
MINIO_SECRET_KEY=SuperSecretPassword123!
|
MINIO_SECRET_KEY=SuperSecretPassword123!
|
||||||
MINIO_BUCKET=ai-char
|
MINIO_BUCKET=ai-char
|
||||||
MODE=production
|
MODE=production
|
||||||
|
EXTERNAL_API_SECRET=Gt9TyQ8OAYhcELh2YCbKjdHLflZGufKHJZcG338MQDW
|
||||||
|
KLING_ACCESS_KEY=AngRfYYeLhPQB3pmr9CpHfgHPCrmeeM4
|
||||||
|
KLING_SECRET_KEY=ndJfyayfQgbg4bMnE49yHnkACPChKMp4
|
||||||
6
.vscode/launch.json
vendored
6
.vscode/launch.json
vendored
@@ -7,10 +7,12 @@
|
|||||||
"request": "launch",
|
"request": "launch",
|
||||||
"module": "uvicorn",
|
"module": "uvicorn",
|
||||||
"args": [
|
"args": [
|
||||||
"main:app",
|
"aiws:app",
|
||||||
"--reload",
|
"--reload",
|
||||||
"--port",
|
"--port",
|
||||||
"8090"
|
"8090",
|
||||||
|
"--host",
|
||||||
|
"0.0.0.0"
|
||||||
],
|
],
|
||||||
"jinja": true,
|
"jinja": true,
|
||||||
"justMyCode": true
|
"justMyCode": true
|
||||||
|
|||||||
@@ -11,4 +11,4 @@ RUN pip install --no-cache-dir -r requirements.txt
|
|||||||
COPY . .
|
COPY . .
|
||||||
|
|
||||||
# Запуск приложения (замени app.py на свой файл)
|
# Запуск приложения (замени app.py на свой файл)
|
||||||
CMD ["python", "main.py"]
|
CMD ["uvicorn", "aiws:app", "--host", "0.0.0.0", "--port", "8090"]
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
165
adapters/kling_adapter.py
Normal file
165
adapters/kling_adapter.py
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import asyncio
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import jwt
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
KLING_API_BASE = "https://api.klingai.com"
|
||||||
|
|
||||||
|
|
||||||
|
class KlingApiException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class KlingAdapter:
|
||||||
|
def __init__(self, access_key: str, secret_key: str):
|
||||||
|
if not access_key or not secret_key:
|
||||||
|
raise ValueError("Kling API credentials are missing")
|
||||||
|
self.access_key = access_key
|
||||||
|
self.secret_key = secret_key
|
||||||
|
|
||||||
|
def _generate_token(self) -> str:
|
||||||
|
"""Generate a JWT token for Kling API authentication."""
|
||||||
|
now = int(time.time())
|
||||||
|
payload = {
|
||||||
|
"iss": self.access_key,
|
||||||
|
"exp": now + 1800, # 30 minutes
|
||||||
|
"iat": now - 5, # небольшой запас назад
|
||||||
|
"nbf": now - 5,
|
||||||
|
}
|
||||||
|
return jwt.encode(payload, self.secret_key, algorithm="HS256",
|
||||||
|
headers={"typ": "JWT", "alg": "HS256"})
|
||||||
|
|
||||||
|
def _headers(self) -> dict:
|
||||||
|
return {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {self._generate_token()}"
|
||||||
|
}
|
||||||
|
|
||||||
|
async def create_video_task(
|
||||||
|
self,
|
||||||
|
image_url: str,
|
||||||
|
prompt: str = "",
|
||||||
|
negative_prompt: str = "",
|
||||||
|
model_name: str = "kling-v2-6",
|
||||||
|
duration: int = 5,
|
||||||
|
mode: str = "std",
|
||||||
|
cfg_scale: float = 0.5,
|
||||||
|
aspect_ratio: str = "16:9",
|
||||||
|
callback_url: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Create an image-to-video generation task.
|
||||||
|
Returns the full task data dict including task_id.
|
||||||
|
"""
|
||||||
|
body: Dict[str, Any] = {
|
||||||
|
"model_name": model_name,
|
||||||
|
"image": image_url,
|
||||||
|
"prompt": prompt,
|
||||||
|
"negative_prompt": negative_prompt,
|
||||||
|
"duration": str(duration),
|
||||||
|
"mode": mode,
|
||||||
|
"cfg_scale": cfg_scale,
|
||||||
|
"aspect_ratio": aspect_ratio,
|
||||||
|
}
|
||||||
|
if callback_url:
|
||||||
|
body["callback_url"] = callback_url
|
||||||
|
|
||||||
|
logger.info(f"Creating Kling video task. Model: {model_name}, Duration: {duration}s, Mode: {mode}")
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
response = await client.post(
|
||||||
|
f"{KLING_API_BASE}/v1/videos/image2video",
|
||||||
|
headers=self._headers(),
|
||||||
|
json=body,
|
||||||
|
)
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
logger.info(f"Kling create task response: code={data.get('code')}, message={data.get('message')}")
|
||||||
|
|
||||||
|
if response.status_code != 200 or data.get("code") != 0:
|
||||||
|
error_msg = data.get("message", "Unknown Kling API error")
|
||||||
|
raise KlingApiException(f"Failed to create video task: {error_msg} (code={data.get('code')})")
|
||||||
|
|
||||||
|
task_data = data.get("data", {})
|
||||||
|
task_id = task_data.get("task_id")
|
||||||
|
if not task_id:
|
||||||
|
raise KlingApiException("No task_id returned from Kling API")
|
||||||
|
|
||||||
|
logger.info(f"Kling video task created: task_id={task_id}")
|
||||||
|
return task_data
|
||||||
|
|
||||||
|
async def get_task_status(self, task_id: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Query the status of a video generation task.
|
||||||
|
Returns the full task data dict.
|
||||||
|
"""
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
response = await client.get(
|
||||||
|
f"{KLING_API_BASE}/v1/videos/image2video/{task_id}",
|
||||||
|
headers=self._headers(),
|
||||||
|
)
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
if response.status_code != 200 or data.get("code") != 0:
|
||||||
|
error_msg = data.get("message", "Unknown error")
|
||||||
|
raise KlingApiException(f"Failed to query task {task_id}: {error_msg}")
|
||||||
|
|
||||||
|
return data.get("data", {})
|
||||||
|
|
||||||
|
async def wait_for_completion(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
poll_interval: int = 10,
|
||||||
|
timeout: int = 600,
|
||||||
|
progress_callback=None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Poll the task status until completion.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: Kling task ID
|
||||||
|
poll_interval: seconds between polls
|
||||||
|
timeout: max seconds to wait
|
||||||
|
progress_callback: async callable(progress_pct: int) to report progress
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Final task data dict with video URL on success.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KlingApiException on failure or timeout.
|
||||||
|
"""
|
||||||
|
start = time.time()
|
||||||
|
attempt = 0
|
||||||
|
|
||||||
|
while True:
|
||||||
|
elapsed = time.time() - start
|
||||||
|
if elapsed > timeout:
|
||||||
|
raise KlingApiException(f"Video generation timed out after {timeout}s for task {task_id}")
|
||||||
|
|
||||||
|
task_data = await self.get_task_status(task_id)
|
||||||
|
status = task_data.get("task_status")
|
||||||
|
|
||||||
|
logger.info(f"Kling task {task_id}: status={status}, elapsed={elapsed:.0f}s")
|
||||||
|
|
||||||
|
if status == "succeed":
|
||||||
|
logger.info(f"Kling task {task_id} completed successfully")
|
||||||
|
return task_data
|
||||||
|
|
||||||
|
if status == "failed":
|
||||||
|
fail_reason = task_data.get("task_status_msg", "Unknown failure")
|
||||||
|
raise KlingApiException(f"Video generation failed: {fail_reason}")
|
||||||
|
|
||||||
|
# Report progress estimate (linear approximation based on typical time)
|
||||||
|
if progress_callback:
|
||||||
|
# Estimate: typical gen is ~120s, cap at 90%
|
||||||
|
estimated_progress = min(int((elapsed / 120) * 90), 90)
|
||||||
|
attempt += 1
|
||||||
|
await progress_callback(estimated_progress)
|
||||||
|
|
||||||
|
await asyncio.sleep(poll_interval)
|
||||||
@@ -12,10 +12,13 @@ from aiogram.fsm.storage.mongo import MongoStorage
|
|||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from motor.motor_asyncio import AsyncIOMotorClient
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
from prometheus_client import Info
|
||||||
from starlette.middleware.cors import CORSMiddleware
|
from starlette.middleware.cors import CORSMiddleware
|
||||||
|
from prometheus_fastapi_instrumentator import Instrumentator
|
||||||
|
|
||||||
# --- ИМПОРТЫ ПРОЕКТА ---
|
# --- ИМПОРТЫ ПРОЕКТА ---
|
||||||
from adapters.google_adapter import GoogleAdapter
|
from adapters.google_adapter import GoogleAdapter
|
||||||
|
from adapters.kling_adapter import KlingAdapter
|
||||||
from adapters.s3_adapter import S3Adapter
|
from adapters.s3_adapter import S3Adapter
|
||||||
from api.service.generation_service import GenerationService
|
from api.service.generation_service import GenerationService
|
||||||
from api.service.album_service import AlbumService
|
from api.service.album_service import AlbumService
|
||||||
@@ -40,6 +43,7 @@ from api.endpoints.generation_router import router as api_gen_router
|
|||||||
from api.endpoints.auth import router as api_auth_router
|
from api.endpoints.auth import router as api_auth_router
|
||||||
from api.endpoints.admin import router as api_admin_router
|
from api.endpoints.admin import router as api_admin_router
|
||||||
from api.endpoints.album_router import router as api_album_router
|
from api.endpoints.album_router import router as api_album_router
|
||||||
|
from api.endpoints.project_router import router as project_api_router
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -80,7 +84,18 @@ s3_adapter = S3Adapter(
|
|||||||
|
|
||||||
dao = DAO(mongo_client, s3_adapter) # Главный DAO для бота
|
dao = DAO(mongo_client, s3_adapter) # Главный DAO для бота
|
||||||
gemini = GoogleAdapter(api_key=GEMINI_API_KEY)
|
gemini = GoogleAdapter(api_key=GEMINI_API_KEY)
|
||||||
generation_service = GenerationService(dao, gemini, bot)
|
|
||||||
|
# Kling Adapter (optional, for video generation)
|
||||||
|
kling_access_key = os.getenv("KLING_ACCESS_KEY", "")
|
||||||
|
kling_secret_key = os.getenv("KLING_SECRET_KEY", "")
|
||||||
|
kling_adapter = None
|
||||||
|
if kling_access_key and kling_secret_key:
|
||||||
|
kling_adapter = KlingAdapter(access_key=kling_access_key, secret_key=kling_secret_key)
|
||||||
|
logger.info("Kling adapter initialized")
|
||||||
|
else:
|
||||||
|
logger.warning("KLING_ACCESS_KEY / KLING_SECRET_KEY not set — video generation disabled")
|
||||||
|
|
||||||
|
generation_service = GenerationService(dao, gemini, s3_adapter, bot, kling_adapter)
|
||||||
album_service = AlbumService(dao)
|
album_service = AlbumService(dao)
|
||||||
|
|
||||||
# Dispatcher
|
# Dispatcher
|
||||||
@@ -135,6 +150,7 @@ async def lifespan(app: FastAPI):
|
|||||||
app.state.gemini_client = gemini
|
app.state.gemini_client = gemini
|
||||||
app.state.bot = bot
|
app.state.bot = bot
|
||||||
app.state.s3_adapter = s3_adapter
|
app.state.s3_adapter = s3_adapter
|
||||||
|
app.state.kling_adapter = kling_adapter
|
||||||
app.state.album_service = album_service
|
app.state.album_service = album_service
|
||||||
app.state.users_repo = users_repo # Добавляем репозиторий в state
|
app.state.users_repo = users_repo # Добавляем репозиторий в state
|
||||||
|
|
||||||
@@ -143,10 +159,10 @@ async def lifespan(app: FastAPI):
|
|||||||
# 2. ЗАПУСК БОТА (в фоне)
|
# 2. ЗАПУСК БОТА (в фоне)
|
||||||
# Важно: handle_signals=False, чтобы бот не перехватывал сигналы остановки у uvicorn
|
# Важно: handle_signals=False, чтобы бот не перехватывал сигналы остановки у uvicorn
|
||||||
# Мы НЕ передаем сюда dao=..., так как он уже подключен через Middleware выше
|
# Мы НЕ передаем сюда dao=..., так как он уже подключен через Middleware выше
|
||||||
polling_task = asyncio.create_task(
|
# polling_task = asyncio.create_task(
|
||||||
dp.start_polling(bot, handle_signals=False)
|
# dp.start_polling(bot, handle_signals=False)
|
||||||
)
|
# )
|
||||||
print("🤖 Bot polling started")
|
# print("🤖 Bot polling started")
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
@@ -177,17 +193,26 @@ app.add_middleware(
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Подключаем роутер API
|
# Подключаем роутеры API
|
||||||
from api.endpoints.auth import router as auth_api_router
|
app.include_router(api_auth_router)
|
||||||
from api.endpoints.admin import router as admin_api_router
|
app.include_router(api_admin_router)
|
||||||
app.include_router(auth_api_router)
|
|
||||||
app.include_router(admin_api_router)
|
|
||||||
app.include_router(api_assets_router)
|
app.include_router(api_assets_router)
|
||||||
app.include_router(api_char_router)
|
app.include_router(api_char_router)
|
||||||
app.include_router(api_gen_router)
|
app.include_router(api_gen_router)
|
||||||
app.include_router(api_album_router)
|
app.include_router(api_album_router)
|
||||||
app.include_router(api_admin_router)
|
app.include_router(project_api_router)
|
||||||
app.include_router(api_auth_router)
|
|
||||||
|
# Prometheus Metrics (Instrument after all routers are added)
|
||||||
|
Instrumentator(
|
||||||
|
should_group_status_codes=False, # 200/201/204 отдельно (по желанию)
|
||||||
|
should_ignore_untemplated=False, # НЕ игнорировать "сырые" пути
|
||||||
|
# should_group_untemplated=False, # (опционально) не схлопывать untemplated в "none"
|
||||||
|
).instrument(
|
||||||
|
app,
|
||||||
|
metric_namespace="ai_bot",
|
||||||
|
).expose(app, endpoint="/metrics", include_in_schema=False)
|
||||||
|
app_info = Info("fastapi_app_info", "FastAPI application info")
|
||||||
|
app_info.info({"app_name": "ai-bot"})
|
||||||
|
|
||||||
|
|
||||||
# --- ХЕНДЛЕРЫ БОТА (Main Router) ---
|
# --- ХЕНДЛЕРЫ БОТА (Main Router) ---
|
||||||
Binary file not shown.
Binary file not shown.
@@ -3,6 +3,7 @@ from fastapi import Request, Depends
|
|||||||
from motor.motor_asyncio import AsyncIOMotorClient
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
|
||||||
from adapters.google_adapter import GoogleAdapter
|
from adapters.google_adapter import GoogleAdapter
|
||||||
|
from adapters.kling_adapter import KlingAdapter
|
||||||
from api.service.generation_service import GenerationService
|
from api.service.generation_service import GenerationService
|
||||||
from repos.dao import DAO
|
from repos.dao import DAO
|
||||||
|
|
||||||
@@ -36,11 +37,20 @@ def get_dao(
|
|||||||
# так что DAO создастся один раз за запрос.
|
# так что DAO создастся один раз за запрос.
|
||||||
return DAO(mongo_client, s3_adapter)
|
return DAO(mongo_client, s3_adapter)
|
||||||
|
|
||||||
|
def get_kling_adapter(request: Request) -> Optional[KlingAdapter]:
|
||||||
|
return request.app.state.kling_adapter
|
||||||
|
|
||||||
# Провайдер сервиса (собирается из DAO и Gemini)
|
# Провайдер сервиса (собирается из DAO и Gemini)
|
||||||
def get_generation_service(
|
def get_generation_service(
|
||||||
dao: DAO = Depends(get_dao),
|
dao: DAO = Depends(get_dao),
|
||||||
gemini: GoogleAdapter = Depends(get_gemini_client),
|
gemini: GoogleAdapter = Depends(get_gemini_client),
|
||||||
s3_adapter: S3Adapter = Depends(get_s3_adapter),
|
s3_adapter: S3Adapter = Depends(get_s3_adapter),
|
||||||
bot: Bot = Depends(get_bot_client),
|
bot: Bot = Depends(get_bot_client),
|
||||||
|
kling_adapter: Optional[KlingAdapter] = Depends(get_kling_adapter),
|
||||||
) -> GenerationService:
|
) -> GenerationService:
|
||||||
return GenerationService(dao, gemini, s3_adapter, bot)
|
return GenerationService(dao, gemini, s3_adapter, bot, kling_adapter=kling_adapter)
|
||||||
|
|
||||||
|
from fastapi import Header
|
||||||
|
|
||||||
|
async def get_project_id(x_project_id: Optional[str] = Header(None, alias="X-Project-ID")) -> Optional[str]:
|
||||||
|
return x_project_id
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -23,13 +23,13 @@ class AlbumResponse(BaseModel):
|
|||||||
generation_ids: List[str] = []
|
generation_ids: List[str] = []
|
||||||
cover_asset_id: Optional[str] = None # Not implemented yet
|
cover_asset_id: Optional[str] = None # Not implemented yet
|
||||||
|
|
||||||
@router.post("/", response_model=AlbumResponse)
|
@router.post("", response_model=AlbumResponse)
|
||||||
async def create_album(request: Request, album_in: AlbumCreateRequest):
|
async def create_album(request: Request, album_in: AlbumCreateRequest):
|
||||||
service: AlbumService = request.app.state.album_service
|
service: AlbumService = request.app.state.album_service
|
||||||
album = await service.create_album(name=album_in.name, description=album_in.description)
|
album = await service.create_album(name=album_in.name, description=album_in.description)
|
||||||
return AlbumResponse(**album.model_dump())
|
return AlbumResponse(**album.model_dump())
|
||||||
|
|
||||||
@router.get("/", response_model=List[AlbumResponse])
|
@router.get("", response_model=List[AlbumResponse])
|
||||||
async def get_albums(request: Request, limit: int = 10, offset: int = 0):
|
async def get_albums(request: Request, limit: int = 10, offset: int = 0):
|
||||||
service: AlbumService = request.app.state.album_service
|
service: AlbumService = request.app.state.album_service
|
||||||
albums = await service.get_albums(limit=limit, offset=offset)
|
albums = await service.get_albums(limit=limit, offset=offset)
|
||||||
|
|||||||
@@ -1,17 +1,21 @@
|
|||||||
from typing import List, Optional
|
from typing import List, Optional, Dict, Any
|
||||||
|
|
||||||
from aiogram.types import BufferedInputFile
|
from aiogram.types import BufferedInputFile
|
||||||
|
from bson import ObjectId
|
||||||
from fastapi import APIRouter, UploadFile, File, Form, Depends
|
from fastapi import APIRouter, UploadFile, File, Form, Depends
|
||||||
from fastapi.openapi.models import MediaType
|
from fastapi.openapi.models import MediaType
|
||||||
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
from pymongo import MongoClient
|
||||||
from starlette import status
|
from starlette import status
|
||||||
from starlette.exceptions import HTTPException
|
from starlette.exceptions import HTTPException
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import Response, JSONResponse
|
from starlette.responses import Response, JSONResponse
|
||||||
|
|
||||||
|
from adapters.s3_adapter import S3Adapter
|
||||||
from api.models.AssetDTO import AssetsResponse, AssetResponse
|
from api.models.AssetDTO import AssetsResponse, AssetResponse
|
||||||
from models.Asset import Asset, AssetType, AssetContentType
|
from models.Asset import Asset, AssetType, AssetContentType
|
||||||
from repos.dao import DAO
|
from repos.dao import DAO
|
||||||
from api.dependency import get_dao
|
from api.dependency import get_dao, get_mongo_client, get_s3_adapter
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@@ -19,6 +23,7 @@ import logging
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
from api.endpoints.auth import get_current_user
|
from api.endpoints.auth import get_current_user
|
||||||
|
from api.dependency import get_project_id
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/assets", tags=["Assets"])
|
router = APIRouter(prefix="/api/assets", tags=["Assets"])
|
||||||
|
|
||||||
@@ -50,6 +55,119 @@ async def get_asset(
|
|||||||
|
|
||||||
return Response(content=content, media_type=media_type, headers=headers)
|
return Response(content=content, media_type=media_type, headers=headers)
|
||||||
|
|
||||||
|
@router.delete("/orphans", dependencies=[Depends(get_current_user)])
|
||||||
|
async def delete_orphan_assets_from_minio(
|
||||||
|
mongo: AsyncIOMotorClient = Depends(get_mongo_client),
|
||||||
|
minio_client: S3Adapter = Depends(get_s3_adapter),
|
||||||
|
*,
|
||||||
|
assets_collection: str = "assets",
|
||||||
|
generations_collection: str = "generations",
|
||||||
|
asset_type: Optional[str] = "generated",
|
||||||
|
project_id: Optional[str] = None,
|
||||||
|
dry_run: bool = True,
|
||||||
|
mark_assets_deleted: bool = False,
|
||||||
|
batch_size: int = 500,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
db = mongo['bot_db'] # БД уже выбрана в get_mongo_client
|
||||||
|
assets = db[assets_collection]
|
||||||
|
|
||||||
|
match_assets: Dict[str, Any] = {}
|
||||||
|
if asset_type is not None:
|
||||||
|
match_assets["type"] = asset_type
|
||||||
|
if project_id is not None:
|
||||||
|
match_assets["project_id"] = project_id
|
||||||
|
|
||||||
|
pipeline: List[Dict[str, Any]] = [
|
||||||
|
{"$match": match_assets} if match_assets else {"$match": {}},
|
||||||
|
{
|
||||||
|
"$lookup": {
|
||||||
|
"from": generations_collection,
|
||||||
|
"let": {"assetIdStr": {"$toString": "$_id"}},
|
||||||
|
"pipeline": [
|
||||||
|
# считаем "живыми" те, где is_deleted != True (т.е. false или поля нет)
|
||||||
|
{"$match": {"is_deleted": {"$ne": True}}},
|
||||||
|
{
|
||||||
|
"$match": {
|
||||||
|
"$expr": {
|
||||||
|
"$in": [
|
||||||
|
"$$assetIdStr",
|
||||||
|
{"$ifNull": ["$result_list", []]},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{"$limit": 1},
|
||||||
|
],
|
||||||
|
"as": "alive_generations",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"$match": {
|
||||||
|
"$expr": {"$eq": [{"$size": "$alive_generations"}, 0]}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"$project": {
|
||||||
|
"_id": 1,
|
||||||
|
"minio_object_name": 1,
|
||||||
|
"minio_thumbnail_object_name": 1,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
]
|
||||||
|
print(pipeline)
|
||||||
|
cursor = assets.aggregate(pipeline, allowDiskUse=True, batchSize=batch_size)
|
||||||
|
|
||||||
|
deleted_objects = 0
|
||||||
|
deleted_assets = 0
|
||||||
|
errors: List[Dict[str, Any]] = []
|
||||||
|
orphan_asset_ids: List[ObjectId] = []
|
||||||
|
|
||||||
|
async for asset in cursor:
|
||||||
|
aid = asset["_id"]
|
||||||
|
obj = asset.get("minio_object_name")
|
||||||
|
thumb = asset.get("minio_thumbnail_object_name")
|
||||||
|
|
||||||
|
orphan_asset_ids.append(aid)
|
||||||
|
|
||||||
|
if dry_run:
|
||||||
|
print(f"[DRY RUN] orphan asset={aid} obj={obj} thumb={thumb}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
if obj:
|
||||||
|
await minio_client.delete_file(obj)
|
||||||
|
deleted_objects += 1
|
||||||
|
|
||||||
|
if thumb:
|
||||||
|
await minio_client.delete_file(thumb)
|
||||||
|
deleted_objects += 1
|
||||||
|
|
||||||
|
deleted_assets += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
errors.append({"asset_id": str(aid), "error": str(e)})
|
||||||
|
|
||||||
|
if (not dry_run) and mark_assets_deleted and orphan_asset_ids:
|
||||||
|
res = await assets.update_many(
|
||||||
|
{"_id": {"$in": orphan_asset_ids}},
|
||||||
|
{"$set": {"is_deleted": True}},
|
||||||
|
)
|
||||||
|
marked = res.modified_count
|
||||||
|
else:
|
||||||
|
marked = 0
|
||||||
|
|
||||||
|
return {
|
||||||
|
"dry_run": dry_run,
|
||||||
|
"filter": {
|
||||||
|
"asset_type": asset_type,
|
||||||
|
"project_id": project_id,
|
||||||
|
},
|
||||||
|
"orphans_found": len(orphan_asset_ids),
|
||||||
|
"deleted_assets": deleted_assets,
|
||||||
|
"deleted_objects": deleted_objects,
|
||||||
|
"marked_assets_deleted": marked,
|
||||||
|
"errors": errors,
|
||||||
|
}
|
||||||
|
|
||||||
@router.delete("/{asset_id}", status_code=status.HTTP_204_NO_CONTENT, dependencies=[Depends(get_current_user)])
|
@router.delete("/{asset_id}", status_code=status.HTTP_204_NO_CONTENT, dependencies=[Depends(get_current_user)])
|
||||||
async def delete_asset(
|
async def delete_asset(
|
||||||
@@ -68,11 +186,19 @@ async def delete_asset(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("", dependencies=[Depends(get_current_user)])
|
@router.get("", dependencies=[Depends(get_current_user)])
|
||||||
async def get_assets(request: Request, dao: DAO = Depends(get_dao), type: Optional[str] = None, limit: int = 10, offset: int = 0) -> AssetsResponse:
|
async def get_assets(request: Request, dao: DAO = Depends(get_dao), type: Optional[str] = None, limit: int = 10, offset: int = 0, current_user: dict = Depends(get_current_user), project_id: Optional[str] = Depends(get_project_id)) -> AssetsResponse:
|
||||||
logger.info(f"get_assets called. Limit: {limit}, Offset: {offset}")
|
logger.info(f"get_assets called. Limit: {limit}, Offset: {offset}")
|
||||||
assets = await dao.assets.get_assets(type, limit, offset)
|
|
||||||
|
user_id_filter = current_user["id"]
|
||||||
|
if project_id:
|
||||||
|
project = await dao.projects.get_project(project_id)
|
||||||
|
if not project or str(current_user["id"]) not in project.members:
|
||||||
|
raise HTTPException(status_code=403, detail="Project access denied")
|
||||||
|
user_id_filter = None
|
||||||
|
|
||||||
|
assets = await dao.assets.get_assets(type, limit, offset, created_by=user_id_filter, project_id=project_id)
|
||||||
# assets = await dao.assets.get_assets() # This line seemed redundant/conflicting in original code
|
# assets = await dao.assets.get_assets() # This line seemed redundant/conflicting in original code
|
||||||
total_count = await dao.assets.get_asset_count()
|
total_count = await dao.assets.get_asset_count(created_by=user_id_filter, project_id=project_id)
|
||||||
|
|
||||||
# Manually map to DTO to trigger computed fields validation if necessary,
|
# Manually map to DTO to trigger computed fields validation if necessary,
|
||||||
# but primarily to ensure valid Pydantic models for the response list.
|
# but primarily to ensure valid Pydantic models for the response list.
|
||||||
@@ -84,11 +210,13 @@ async def get_assets(request: Request, dao: DAO = Depends(get_dao), type: Option
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/upload", response_model=AssetResponse, status_code=status.HTTP_201_CREATED, dependencies=[Depends(get_current_user)])
|
@router.post("/upload", response_model=AssetResponse, status_code=status.HTTP_201_CREATED)
|
||||||
async def upload_asset(
|
async def upload_asset(
|
||||||
file: UploadFile = File(...),
|
file: UploadFile = File(...),
|
||||||
linked_char_id: Optional[str] = Form(None),
|
linked_char_id: Optional[str] = Form(None),
|
||||||
dao: DAO = Depends(get_dao),
|
dao: DAO = Depends(get_dao),
|
||||||
|
current_user: dict = Depends(get_current_user),
|
||||||
|
project_id: Optional[str] = Depends(get_project_id)
|
||||||
):
|
):
|
||||||
logger.info(f"upload_asset called. Filename: {file.filename}, ContentType: {file.content_type}, LinkedCharId: {linked_char_id}")
|
logger.info(f"upload_asset called. Filename: {file.filename}, ContentType: {file.content_type}, LinkedCharId: {linked_char_id}")
|
||||||
if not file.content_type:
|
if not file.content_type:
|
||||||
@@ -97,6 +225,11 @@ async def upload_asset(
|
|||||||
if not file.content_type.startswith("image/"):
|
if not file.content_type.startswith("image/"):
|
||||||
raise HTTPException(status_code=400, detail=f"Unsupported content type: {file.content_type}")
|
raise HTTPException(status_code=400, detail=f"Unsupported content type: {file.content_type}")
|
||||||
|
|
||||||
|
if project_id:
|
||||||
|
project = await dao.projects.get_project(project_id)
|
||||||
|
if not project or str(current_user["_id"]) not in project.members:
|
||||||
|
raise HTTPException(status_code=403, detail="Project access denied")
|
||||||
|
|
||||||
data = await file.read()
|
data = await file.read()
|
||||||
if not data:
|
if not data:
|
||||||
raise HTTPException(status_code=400, detail="Empty file")
|
raise HTTPException(status_code=400, detail="Empty file")
|
||||||
@@ -111,7 +244,9 @@ async def upload_asset(
|
|||||||
content_type=AssetContentType.IMAGE,
|
content_type=AssetContentType.IMAGE,
|
||||||
linked_char_id=linked_char_id,
|
linked_char_id=linked_char_id,
|
||||||
data=data,
|
data=data,
|
||||||
thumbnail=thumbnail_bytes
|
thumbnail=thumbnail_bytes,
|
||||||
|
created_by=str(current_user["_id"]),
|
||||||
|
project_id=project_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
asset_id = await dao.assets.create_asset(asset)
|
asset_id = await dao.assets.create_asset(asset)
|
||||||
@@ -172,3 +307,4 @@ async def migrate_to_minio(dao: DAO = Depends(get_dao)):
|
|||||||
result = await dao.assets.migrate_to_minio()
|
result = await dao.assets.migrate_to_minio()
|
||||||
logger.info(f"Migration result: {result}")
|
logger.info(f"Migration result: {result}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|||||||
@@ -59,6 +59,7 @@ class Token(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class UserResponse(BaseModel):
|
class UserResponse(BaseModel):
|
||||||
|
id: str
|
||||||
username: str
|
username: str
|
||||||
full_name: str | None = None
|
full_name: str | None = None
|
||||||
status: str
|
status: str
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import List, Any, Coroutine
|
from typing import List, Any, Coroutine, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -9,6 +9,7 @@ from api.models.AssetDTO import AssetsResponse, AssetResponse
|
|||||||
from api.models.GenerationRequest import GenerationRequest, GenerationResponse
|
from api.models.GenerationRequest import GenerationRequest, GenerationResponse
|
||||||
from models.Asset import Asset
|
from models.Asset import Asset
|
||||||
from models.Character import Character
|
from models.Character import Character
|
||||||
|
from api.models.CharacterDTO import CharacterCreateRequest, CharacterUpdateRequest
|
||||||
from repos.dao import DAO
|
from repos.dao import DAO
|
||||||
from api.dependency import get_dao
|
from api.dependency import get_dao
|
||||||
|
|
||||||
@@ -17,25 +18,49 @@ import logging
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
from api.endpoints.auth import get_current_user
|
from api.endpoints.auth import get_current_user
|
||||||
|
from api.dependency import get_project_id
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/characters", tags=["Characters"], dependencies=[Depends(get_current_user)])
|
router = APIRouter(prefix="/api/characters", tags=["Characters"], dependencies=[Depends(get_current_user)])
|
||||||
|
|
||||||
|
|
||||||
@router.get("/", response_model=List[Character])
|
@router.get("/", response_model=List[Character])
|
||||||
async def get_characters(request: Request, dao: DAO = Depends(get_dao), ) -> List[Character]:
|
async def get_characters(request: Request, dao: DAO = Depends(get_dao), current_user: dict = Depends(get_current_user), project_id: Optional[str] = Depends(get_project_id)) -> List[Character]:
|
||||||
logger.info("get_characters called")
|
logger.info("get_characters called")
|
||||||
characters = await dao.chars.get_all_characters()
|
|
||||||
|
user_id_filter = str(current_user["_id"])
|
||||||
|
if project_id:
|
||||||
|
project = await dao.projects.get_project(project_id)
|
||||||
|
if not project or str(current_user["_id"]) not in project.members:
|
||||||
|
raise HTTPException(status_code=403, detail="Project access denied")
|
||||||
|
user_id_filter = None
|
||||||
|
|
||||||
|
characters = await dao.chars.get_all_characters(created_by=user_id_filter, project_id=project_id)
|
||||||
return characters
|
return characters
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{character_id}/assets", response_model=AssetsResponse)
|
@router.get("/{character_id}/assets", response_model=AssetsResponse)
|
||||||
async def get_character_assets(character_id: str, dao: DAO = Depends(get_dao), limit: int = 10,
|
async def get_character_assets(character_id: str, dao: DAO = Depends(get_dao), limit: int = 10,
|
||||||
offset: int = 0, ) -> AssetsResponse:
|
offset: int = 0, current_user: dict = Depends(get_current_user)) -> AssetsResponse:
|
||||||
logger.info(f"get_character_assets called. CharacterID: {character_id}, Limit: {limit}, Offset: {offset}")
|
logger.info(f"get_character_assets called. CharacterID: {character_id}, Limit: {limit}, Offset: {offset}")
|
||||||
character = await dao.chars.get_character(character_id)
|
character = await dao.chars.get_character(character_id)
|
||||||
if character is None:
|
if character is None:
|
||||||
raise HTTPException(status_code=404, detail="Character not found")
|
raise HTTPException(status_code=404, detail="Character not found")
|
||||||
|
|
||||||
|
# Access Check
|
||||||
|
is_creator = character.created_by == str(current_user["_id"])
|
||||||
|
is_project_member = False
|
||||||
|
if character.project_id and character.project_id in current_user.get("project_ids", []):
|
||||||
|
is_project_member = True
|
||||||
|
|
||||||
|
if not is_creator and not is_project_member:
|
||||||
|
raise HTTPException(status_code=403, detail="Access denied")
|
||||||
|
|
||||||
assets = await dao.assets.get_assets_by_char_id(character_id, limit, offset)
|
assets = await dao.assets.get_assets_by_char_id(character_id, limit, offset)
|
||||||
|
# Filter assets by user ownership as well?
|
||||||
|
# Usually if you own character, you see its assets.
|
||||||
|
# But assets also have specific created_by.
|
||||||
|
# Let's assume if you own character you can see its assets.
|
||||||
|
|
||||||
total_count = await dao.assets.get_asset_count(character_id)
|
total_count = await dao.assets.get_asset_count(character_id)
|
||||||
|
|
||||||
asset_responses = [AssetResponse.model_validate(a.model_dump()) for a in assets]
|
asset_responses = [AssetResponse.model_validate(a.model_dump()) for a in assets]
|
||||||
@@ -43,12 +68,118 @@ async def get_character_assets(character_id: str, dao: DAO = Depends(get_dao), l
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/{character_id}", response_model=Character)
|
@router.get("/{character_id}", response_model=Character)
|
||||||
async def get_character_by_id(character_id: str, request: Request, dao: DAO = Depends(get_dao)) -> Character:
|
async def get_character_by_id(character_id: str, request: Request, dao: DAO = Depends(get_dao), current_user: dict = Depends(get_current_user)) -> Character:
|
||||||
logger.debug(f"get_character_by_id called. ID: {character_id}")
|
logger.debug(f"get_character_by_id called. ID: {character_id}")
|
||||||
character = await dao.chars.get_character(character_id)
|
character = await dao.chars.get_character(character_id)
|
||||||
|
|
||||||
|
if not character:
|
||||||
|
raise HTTPException(status_code=404, detail="Character not found")
|
||||||
|
|
||||||
|
if character:
|
||||||
|
is_creator = character.created_by == str(current_user["_id"])
|
||||||
|
is_project_member = False
|
||||||
|
if character.project_id and character.project_id in current_user.get("project_ids", []):
|
||||||
|
is_project_member = True
|
||||||
|
|
||||||
|
if not is_creator and not is_project_member:
|
||||||
|
raise HTTPException(status_code=403, detail="Access denied")
|
||||||
|
|
||||||
return character
|
return character
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/", response_model=Character)
|
||||||
|
async def create_character(
|
||||||
|
char_req: CharacterCreateRequest,
|
||||||
|
project_id: Optional[str] = Depends(get_project_id),
|
||||||
|
dao: DAO = Depends(get_dao),
|
||||||
|
current_user: dict = Depends(get_current_user)
|
||||||
|
) -> Character:
|
||||||
|
logger.info("create_character called")
|
||||||
|
char_req.project_id = project_id
|
||||||
|
char_data = char_req.model_dump()
|
||||||
|
char_data["created_by"] = str(current_user["_id"])
|
||||||
|
if "id" not in char_data:
|
||||||
|
char_data["id"] = None
|
||||||
|
|
||||||
|
if project_id:
|
||||||
|
project = await dao.projects.get_project(project_id)
|
||||||
|
if not project or str(current_user["_id"]) not in project.members:
|
||||||
|
raise HTTPException(status_code=403, detail="Project access denied")
|
||||||
|
|
||||||
|
new_char = Character(**char_data)
|
||||||
|
new_char.avatar_asset_id = new_char.avatar_image.split("/")[-1]
|
||||||
|
created_char = await dao.chars.add_character(new_char)
|
||||||
|
return created_char
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{character_id}", response_model=Character)
|
||||||
|
async def update_character(
|
||||||
|
character_id: str,
|
||||||
|
char_update: CharacterUpdateRequest,
|
||||||
|
dao: DAO = Depends(get_dao),
|
||||||
|
current_user: dict = Depends(get_current_user)
|
||||||
|
) -> Character:
|
||||||
|
logger.info(f"update_character called. ID: {character_id}")
|
||||||
|
|
||||||
|
existing_char = await dao.chars.get_character(character_id)
|
||||||
|
if not existing_char:
|
||||||
|
raise HTTPException(status_code=404, detail="Character not found")
|
||||||
|
|
||||||
|
is_creator = existing_char.created_by == str(current_user["_id"])
|
||||||
|
is_project_member = False
|
||||||
|
if existing_char.project_id and existing_char.project_id in current_user.get("project_ids", []):
|
||||||
|
is_project_member = True
|
||||||
|
|
||||||
|
if not is_creator and not is_project_member:
|
||||||
|
raise HTTPException(status_code=403, detail="Access denied")
|
||||||
|
|
||||||
|
update_data = char_update.model_dump(exclude_unset=True)
|
||||||
|
|
||||||
|
if "project_id" in update_data and update_data["project_id"]:
|
||||||
|
new_project_id = update_data["project_id"]
|
||||||
|
project = await dao.projects.get_project(new_project_id)
|
||||||
|
if not project or str(current_user["_id"]) not in project.members:
|
||||||
|
raise HTTPException(status_code=403, detail="Target project access denied")
|
||||||
|
|
||||||
|
updated_char_data = existing_char.model_dump()
|
||||||
|
updated_char_data.update(update_data)
|
||||||
|
|
||||||
|
updated_char = Character(**updated_char_data)
|
||||||
|
|
||||||
|
success = await dao.chars.update_char(character_id, updated_char)
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to update character")
|
||||||
|
|
||||||
|
return updated_char
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{character_id}", status_code=204)
|
||||||
|
async def delete_character(
|
||||||
|
character_id: str,
|
||||||
|
dao: DAO = Depends(get_dao),
|
||||||
|
current_user: dict = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
logger.info(f"delete_character called. ID: {character_id}")
|
||||||
|
|
||||||
|
existing_char = await dao.chars.get_character(character_id)
|
||||||
|
if not existing_char:
|
||||||
|
raise HTTPException(status_code=404, detail="Character not found")
|
||||||
|
|
||||||
|
is_creator = existing_char.created_by == str(current_user["_id"])
|
||||||
|
is_project_member = False
|
||||||
|
if existing_char.project_id and existing_char.project_id in current_user.get("project_ids", []):
|
||||||
|
is_project_member = True
|
||||||
|
|
||||||
|
if not is_creator and not is_project_member:
|
||||||
|
raise HTTPException(status_code=403, detail="Access denied")
|
||||||
|
|
||||||
|
success = await dao.chars.delete_character(character_id)
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to delete character")
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{character_id}/_run", response_model=GenerationResponse)
|
@router.post("/{character_id}/_run", response_model=GenerationResponse)
|
||||||
async def post_character_generation(character_id: str, generation: GenerationRequest,
|
async def post_character_generation(character_id: str, generation: GenerationRequest,
|
||||||
request: Request) -> GenerationResponse:
|
request: Request) -> GenerationResponse:
|
||||||
|
|||||||
@@ -1,13 +1,15 @@
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, UploadFile, File, Form
|
from fastapi import APIRouter, UploadFile, File, Form, Header, HTTPException
|
||||||
from fastapi.params import Depends
|
from fastapi.params import Depends
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
|
|
||||||
from api import service
|
from api import service
|
||||||
from api.dependency import get_generation_service
|
from api.dependency import get_generation_service, get_project_id, get_dao
|
||||||
|
from repos.dao import DAO
|
||||||
|
|
||||||
from api.models.GenerationRequest import GenerationResponse, GenerationRequest, GenerationsResponse, PromptResponse, PromptRequest
|
from api.models.GenerationRequest import GenerationResponse, GenerationRequest, GenerationsResponse, PromptResponse, PromptRequest
|
||||||
|
from api.models.VideoGenerationRequest import VideoGenerationRequest
|
||||||
from api.service.generation_service import GenerationService
|
from api.service.generation_service import GenerationService
|
||||||
from models.Generation import Generation
|
from models.Generation import Generation
|
||||||
|
|
||||||
@@ -19,13 +21,14 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
from api.endpoints.auth import get_current_user
|
from api.endpoints.auth import get_current_user
|
||||||
|
|
||||||
router = APIRouter(prefix='/api/generations', tags=["Generation"], dependencies=[Depends(get_current_user)])
|
router = APIRouter(prefix='/api/generations', tags=["Generation"])
|
||||||
|
|
||||||
|
|
||||||
@router.post("/prompt-assistant", response_model=PromptResponse)
|
@router.post("/prompt-assistant", response_model=PromptResponse)
|
||||||
async def ask_prompt_assistant(prompt_request: PromptRequest, request: Request,
|
async def ask_prompt_assistant(prompt_request: PromptRequest, request: Request,
|
||||||
generation_service: GenerationService = Depends(
|
generation_service: GenerationService = Depends(
|
||||||
get_generation_service)) -> PromptResponse:
|
get_generation_service),
|
||||||
|
current_user: dict = Depends(get_current_user)) -> PromptResponse:
|
||||||
logger.info(f"ask_prompt_assistant called with prompt length: {len(prompt_request.prompt)}. Linked assets: {len(prompt_request.linked_assets) if prompt_request.linked_assets else 0}")
|
logger.info(f"ask_prompt_assistant called with prompt length: {len(prompt_request.prompt)}. Linked assets: {len(prompt_request.linked_assets) if prompt_request.linked_assets else 0}")
|
||||||
generated_prompt = await generation_service.ask_prompt_assistant(prompt_request.prompt, prompt_request.linked_assets)
|
generated_prompt = await generation_service.ask_prompt_assistant(prompt_request.prompt, prompt_request.linked_assets)
|
||||||
return PromptResponse(prompt=generated_prompt)
|
return PromptResponse(prompt=generated_prompt)
|
||||||
@@ -35,7 +38,8 @@ async def ask_prompt_assistant(prompt_request: PromptRequest, request: Request,
|
|||||||
async def prompt_from_image(
|
async def prompt_from_image(
|
||||||
prompt: Optional[str] = Form(None),
|
prompt: Optional[str] = Form(None),
|
||||||
images: List[UploadFile] = File(...),
|
images: List[UploadFile] = File(...),
|
||||||
generation_service: GenerationService = Depends(get_generation_service)
|
generation_service: GenerationService = Depends(get_generation_service),
|
||||||
|
current_user: dict = Depends(get_current_user)
|
||||||
) -> PromptResponse:
|
) -> PromptResponse:
|
||||||
logger.info(f"prompt_from_image called. Images count: {len(images)}. Prompt provided: {bool(prompt)}")
|
logger.info(f"prompt_from_image called. Images count: {len(images)}. Prompt provided: {bool(prompt)}")
|
||||||
images_bytes = []
|
images_bytes = []
|
||||||
@@ -49,34 +53,139 @@ async def prompt_from_image(
|
|||||||
|
|
||||||
@router.get("", response_model=GenerationsResponse)
|
@router.get("", response_model=GenerationsResponse)
|
||||||
async def get_generations(character_id: Optional[str] = None, limit: int = 10, offset: int = 0,
|
async def get_generations(character_id: Optional[str] = None, limit: int = 10, offset: int = 0,
|
||||||
generation_service: GenerationService = Depends(get_generation_service)):
|
generation_service: GenerationService = Depends(get_generation_service),
|
||||||
|
current_user: dict = Depends(get_current_user),
|
||||||
|
project_id: Optional[str] = Depends(get_project_id),
|
||||||
|
dao: DAO = Depends(get_dao)):
|
||||||
logger.info(f"get_generations called. CharacterId: {character_id}, Limit: {limit}, Offset: {offset}")
|
logger.info(f"get_generations called. CharacterId: {character_id}, Limit: {limit}, Offset: {offset}")
|
||||||
return await generation_service.get_generations(character_id, limit=limit, offset=offset)
|
|
||||||
|
user_id_filter = str(current_user["_id"])
|
||||||
|
if project_id:
|
||||||
|
project = await dao.projects.get_project(project_id)
|
||||||
|
if not project or str(current_user["_id"]) not in project.members:
|
||||||
|
raise HTTPException(status_code=403, detail="Project access denied")
|
||||||
|
user_id_filter = None # Show all project generations
|
||||||
|
|
||||||
|
return await generation_service.get_generations(character_id, limit=limit, offset=offset, user_id=user_id_filter, project_id=project_id)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/_run", response_model=GenerationResponse)
|
@router.post("/_run", response_model=GenerationResponse)
|
||||||
async def post_generation(generation: GenerationRequest, request: Request,
|
async def post_generation(generation: GenerationRequest, request: Request,
|
||||||
generation_service: GenerationService = Depends(
|
generation_service: GenerationService = Depends(get_generation_service),
|
||||||
get_generation_service)) -> GenerationResponse:
|
current_user: dict = Depends(get_current_user),
|
||||||
|
project_id: Optional[str] = Depends(get_project_id),
|
||||||
|
dao: DAO = Depends(get_dao)) -> GenerationResponse:
|
||||||
logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}")
|
logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}")
|
||||||
return await generation_service.create_generation_task(generation)
|
|
||||||
|
if project_id:
|
||||||
|
project = await dao.projects.get_project(project_id)
|
||||||
|
if not project or str(current_user["_id"]) not in project.members:
|
||||||
|
raise HTTPException(status_code=403, detail="Project access denied")
|
||||||
|
generation.project_id = project_id
|
||||||
|
|
||||||
|
return await generation_service.create_generation_task(generation, user_id=str(current_user.get("_id")))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{generation_id}", response_model=GenerationResponse)
|
@router.get("/{generation_id}", response_model=GenerationResponse)
|
||||||
async def get_generation(generation_id: str,
|
async def get_generation(generation_id: str,
|
||||||
generation_service: GenerationService = Depends(get_generation_service)) -> GenerationResponse:
|
generation_service: GenerationService = Depends(get_generation_service),
|
||||||
|
current_user: dict = Depends(get_current_user)) -> GenerationResponse:
|
||||||
logger.debug(f"get_generation called for ID: {generation_id}")
|
logger.debug(f"get_generation called for ID: {generation_id}")
|
||||||
return await generation_service.get_generation(generation_id)
|
gen = await generation_service.get_generation(generation_id)
|
||||||
|
if gen and gen.created_by != str(current_user["_id"]):
|
||||||
|
raise HTTPException(status_code=403, detail="Access denied")
|
||||||
|
return gen
|
||||||
|
|
||||||
|
|
||||||
@router.get("/running")
|
@router.get("/running")
|
||||||
async def get_running_generations(request: Request,
|
async def get_running_generations(request: Request,
|
||||||
generation_service: GenerationService = Depends(get_generation_service)):
|
generation_service: GenerationService = Depends(get_generation_service),
|
||||||
return await generation_service.get_running_generations()
|
current_user: dict = Depends(get_current_user),
|
||||||
|
project_id: Optional[str] = Depends(get_project_id),
|
||||||
|
dao: DAO = Depends(get_dao)):
|
||||||
|
|
||||||
|
user_id_filter = str(current_user["_id"])
|
||||||
|
if project_id:
|
||||||
|
project = await dao.projects.get_project(project_id)
|
||||||
|
if not project or str(current_user["_id"]) not in project.members:
|
||||||
|
raise HTTPException(status_code=403, detail="Project access denied")
|
||||||
|
user_id_filter = None
|
||||||
|
|
||||||
|
return await generation_service.get_running_generations(user_id=user_id_filter, project_id=project_id)
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{generation_id}", status_code=status.HTTP_204_NO_CONTENT, dependencies=[Depends(get_current_user)])
|
@router.post("/video/_run", response_model=GenerationResponse)
|
||||||
async def delete_generation(generation_id: str, generation_service: GenerationService = Depends(get_generation_service)):
|
async def post_video_generation(
|
||||||
|
video_request: VideoGenerationRequest,
|
||||||
|
request: Request,
|
||||||
|
generation_service: GenerationService = Depends(get_generation_service),
|
||||||
|
current_user: dict = Depends(get_current_user),
|
||||||
|
project_id: Optional[str] = Depends(get_project_id),
|
||||||
|
dao: DAO = Depends(get_dao),
|
||||||
|
) -> GenerationResponse:
|
||||||
|
"""Start image-to-video generation using Kling AI."""
|
||||||
|
logger.info(f"post_video_generation called. AssetId: {video_request.image_asset_id}, Duration: {video_request.duration}s, Mode: {video_request.mode}")
|
||||||
|
|
||||||
|
if project_id:
|
||||||
|
project = await dao.projects.get_project(project_id)
|
||||||
|
if not project or str(current_user["_id"]) not in project.members:
|
||||||
|
raise HTTPException(status_code=403, detail="Project access denied")
|
||||||
|
video_request.project_id = project_id
|
||||||
|
|
||||||
|
return await generation_service.create_video_generation_task(video_request, user_id=str(current_user.get("_id")))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/import", response_model=GenerationResponse)
|
||||||
|
async def import_external_generation(
|
||||||
|
request: Request,
|
||||||
|
generation_service: GenerationService = Depends(get_generation_service),
|
||||||
|
x_signature: str = Header(..., alias="X-Signature")
|
||||||
|
) -> GenerationResponse:
|
||||||
|
"""
|
||||||
|
Import a generation from an external source.
|
||||||
|
Requires server-to-server authentication via HMAC signature.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
from utils.external_auth import verify_signature
|
||||||
|
from api.models.ExternalGenerationDTO import ExternalGenerationRequest
|
||||||
|
|
||||||
|
logger.info("import_external_generation called")
|
||||||
|
|
||||||
|
# Get raw request body for signature verification
|
||||||
|
body = await request.body()
|
||||||
|
|
||||||
|
# Verify signature
|
||||||
|
secret = os.getenv("EXTERNAL_API_SECRET")
|
||||||
|
if not secret:
|
||||||
|
logger.error("EXTERNAL_API_SECRET not configured")
|
||||||
|
raise HTTPException(status_code=500, detail="Server configuration error")
|
||||||
|
|
||||||
|
if not verify_signature(body, x_signature, secret):
|
||||||
|
logger.warning("Invalid signature for external generation import")
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid signature")
|
||||||
|
|
||||||
|
# Parse request body
|
||||||
|
import json
|
||||||
|
try:
|
||||||
|
data = json.loads(body.decode('utf-8'))
|
||||||
|
external_gen = ExternalGenerationRequest(**data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to parse request body: {e}")
|
||||||
|
raise HTTPException(status_code=400, detail=f"Invalid request body: {str(e)}")
|
||||||
|
|
||||||
|
# Import generation
|
||||||
|
try:
|
||||||
|
generation = await generation_service.import_external_generation(external_gen)
|
||||||
|
return GenerationResponse(**generation.model_dump())
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to import external generation: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"Import failed: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{generation_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
async def delete_generation(generation_id: str,
|
||||||
|
generation_service: GenerationService = Depends(get_generation_service),
|
||||||
|
current_user: dict = Depends(get_current_user)):
|
||||||
logger.info(f"delete_generation called for ID: {generation_id}")
|
logger.info(f"delete_generation called for ID: {generation_id}")
|
||||||
deleted = await generation_service.delete_generation(generation_id)
|
deleted = await generation_service.delete_generation(generation_id)
|
||||||
if not deleted:
|
if not deleted:
|
||||||
|
|||||||
167
api/endpoints/project_router.py
Normal file
167
api/endpoints/project_router.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from api.dependency import get_dao
|
||||||
|
from api.endpoints.auth import get_current_user
|
||||||
|
from models.Project import Project
|
||||||
|
from repos.dao import DAO
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/projects", tags=["Projects"])
|
||||||
|
|
||||||
|
class ProjectCreate(BaseModel):
|
||||||
|
name: str
|
||||||
|
description: Optional[str] = None
|
||||||
|
|
||||||
|
class ProjectResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
description: Optional[str] = None
|
||||||
|
owner_id: str
|
||||||
|
members: List[str]
|
||||||
|
is_owner: bool = False
|
||||||
|
|
||||||
|
@router.post("", response_model=ProjectResponse)
|
||||||
|
async def create_project(
|
||||||
|
project_data: ProjectCreate,
|
||||||
|
dao: DAO = Depends(get_dao),
|
||||||
|
current_user: dict = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
user_id = str(current_user["_id"])
|
||||||
|
new_project = Project(
|
||||||
|
name=project_data.name,
|
||||||
|
description=project_data.description,
|
||||||
|
owner_id=user_id,
|
||||||
|
members=[user_id]
|
||||||
|
)
|
||||||
|
project_id = await dao.projects.create_project(new_project)
|
||||||
|
|
||||||
|
# Add project to user's project list
|
||||||
|
# Assuming user_repo has a method to add project or we do it directly?
|
||||||
|
# UserRepo doesn't have add_project method yet.
|
||||||
|
# But since UserRepo is just a wrapper around collection, lets add it here or update UserRepo later?
|
||||||
|
# Better to update UserRepo. For now, let's just return success.
|
||||||
|
# But user needs to see it in list.
|
||||||
|
# Update user in DB
|
||||||
|
await dao.users.collection.update_one(
|
||||||
|
{"_id": current_user["_id"]},
|
||||||
|
{"$addToSet": {"project_ids": project_id}}
|
||||||
|
)
|
||||||
|
|
||||||
|
return ProjectResponse(
|
||||||
|
id=project_id,
|
||||||
|
name=new_project.name,
|
||||||
|
description=new_project.description,
|
||||||
|
owner_id=new_project.owner_id,
|
||||||
|
members=new_project.members,
|
||||||
|
is_owner=True
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.get("", response_model=List[ProjectResponse])
|
||||||
|
async def get_my_projects(
|
||||||
|
dao: DAO = Depends(get_dao),
|
||||||
|
current_user: dict = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
user_id = str(current_user["_id"])
|
||||||
|
projects = await dao.projects.get_projects_by_user(user_id)
|
||||||
|
|
||||||
|
responses = []
|
||||||
|
for p in projects:
|
||||||
|
responses.append(ProjectResponse(
|
||||||
|
id=p.id,
|
||||||
|
name=p.name,
|
||||||
|
description=p.description,
|
||||||
|
owner_id=p.owner_id,
|
||||||
|
members=p.members,
|
||||||
|
is_owner=(p.owner_id == user_id)
|
||||||
|
))
|
||||||
|
return responses
|
||||||
|
|
||||||
|
class MemberAdd(BaseModel):
|
||||||
|
username: str
|
||||||
|
|
||||||
|
@router.post("/{project_id}/members", dependencies=[Depends(get_current_user)])
|
||||||
|
async def add_member(
|
||||||
|
project_id: str,
|
||||||
|
member_data: MemberAdd,
|
||||||
|
dao: DAO = Depends(get_dao),
|
||||||
|
current_user: dict = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
user_id = str(current_user["_id"])
|
||||||
|
project = await dao.projects.get_project(project_id)
|
||||||
|
if not project:
|
||||||
|
raise HTTPException(status_code=404, detail="Project not found")
|
||||||
|
|
||||||
|
if project.owner_id != user_id:
|
||||||
|
raise HTTPException(status_code=403, detail="Only owner can add members")
|
||||||
|
|
||||||
|
target_user = await dao.users.get_user_by_username(member_data.username)
|
||||||
|
if not target_user:
|
||||||
|
raise HTTPException(status_code=404, detail="User not found")
|
||||||
|
|
||||||
|
target_user_id = str(target_user["_id"])
|
||||||
|
|
||||||
|
if target_user_id in project.members:
|
||||||
|
return {"message": "User already in project"}
|
||||||
|
|
||||||
|
await dao.projects.add_member(project_id, target_user_id)
|
||||||
|
|
||||||
|
# Update target user's project list
|
||||||
|
await dao.users.collection.update_one(
|
||||||
|
{"_id": target_user["_id"]},
|
||||||
|
{"$addToSet": {"project_ids": project_id}}
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"message": "Member added"}
|
||||||
|
|
||||||
|
@router.post("/{project_id}/join", dependencies=[Depends(get_current_user)])
|
||||||
|
async def join_project(
|
||||||
|
project_id: str,
|
||||||
|
dao: DAO = Depends(get_dao),
|
||||||
|
current_user: dict = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
# Retrieve project to verify it exists
|
||||||
|
project = await dao.projects.get_project(project_id)
|
||||||
|
if not project:
|
||||||
|
raise HTTPException(status_code=404, detail="Project not found")
|
||||||
|
|
||||||
|
user_id = str(current_user["_id"])
|
||||||
|
|
||||||
|
# Check if user is ALREADY in project
|
||||||
|
if user_id in project.members:
|
||||||
|
return {"message": "Already a member"}
|
||||||
|
|
||||||
|
# Add member
|
||||||
|
await dao.projects.add_member(project_id, user_id)
|
||||||
|
|
||||||
|
# Update user's project list
|
||||||
|
await dao.users.collection.update_one(
|
||||||
|
{"_id": current_user["_id"]},
|
||||||
|
{"$addToSet": {"project_ids": project_id}}
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"message": "Joined project"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{project_id}", dependencies=[Depends(get_current_user)] )
|
||||||
|
async def delete_project(
|
||||||
|
project_id: str,
|
||||||
|
dao: DAO = Depends(get_dao),
|
||||||
|
current_user: dict = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
user_id = str(current_user["_id"])
|
||||||
|
project = await dao.projects.get_project(project_id)
|
||||||
|
if not project:
|
||||||
|
raise HTTPException(status_code=404, detail="Project not found")
|
||||||
|
|
||||||
|
if project.owner_id != user_id:
|
||||||
|
raise HTTPException(status_code=403, detail="Only owner can delete project")
|
||||||
|
|
||||||
|
await dao.projects.delete_project(project_id)
|
||||||
|
|
||||||
|
# Remove project from user's project list
|
||||||
|
await dao.users.collection.update_one(
|
||||||
|
{"_id": current_user["_id"]},
|
||||||
|
{"$pull": {"project_ids": project_id}}
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"message": "Project deleted"}
|
||||||
18
api/models/CharacterDTO.py
Normal file
18
api/models/CharacterDTO.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
from typing import Optional
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
class CharacterCreateRequest(BaseModel):
|
||||||
|
name: str
|
||||||
|
character_bio: str
|
||||||
|
character_image_doc_tg_id: Optional[str] = None
|
||||||
|
avatar_image: Optional[str] = None
|
||||||
|
character_image_tg_id: Optional[str] = None
|
||||||
|
project_id: Optional[str] = None
|
||||||
|
|
||||||
|
class CharacterUpdateRequest(BaseModel):
|
||||||
|
name: Optional[str] = None
|
||||||
|
character_bio: Optional[str] = None
|
||||||
|
character_image_doc_tg_id: Optional[str] = None
|
||||||
|
avatar_image: Optional[str] = None
|
||||||
|
character_image_tg_id: Optional[str] = None
|
||||||
|
project_id: Optional[str] = None
|
||||||
37
api/models/ExternalGenerationDTO.py
Normal file
37
api/models/ExternalGenerationDTO.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
from typing import Optional
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from models.enums import AspectRatios, Quality
|
||||||
|
|
||||||
|
|
||||||
|
class ExternalGenerationRequest(BaseModel):
|
||||||
|
"""Request model for importing external generations."""
|
||||||
|
|
||||||
|
prompt: str
|
||||||
|
tech_prompt: Optional[str] = None
|
||||||
|
|
||||||
|
# Image can be provided as base64 string OR URL (one must be provided)
|
||||||
|
image_data: Optional[str] = Field(None, description="Base64-encoded image data")
|
||||||
|
image_url: Optional[str] = Field(None, description="URL to download image from")
|
||||||
|
|
||||||
|
# Generation metadata
|
||||||
|
aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN
|
||||||
|
quality: Quality = Quality.ONEK
|
||||||
|
|
||||||
|
# Optional linking
|
||||||
|
linked_character_id: Optional[str] = None
|
||||||
|
created_by: str = Field(..., description="User ID from external system")
|
||||||
|
project_id: Optional[str] = None
|
||||||
|
|
||||||
|
# Performance metrics
|
||||||
|
execution_time_seconds: Optional[float] = None
|
||||||
|
api_execution_time_seconds: Optional[float] = None
|
||||||
|
token_usage: Optional[int] = None
|
||||||
|
input_token_usage: Optional[int] = None
|
||||||
|
output_token_usage: Optional[int] = None
|
||||||
|
|
||||||
|
def validate_image_source(self):
|
||||||
|
"""Ensure at least one image source is provided."""
|
||||||
|
if not self.image_data and not self.image_url:
|
||||||
|
raise ValueError("Either image_data or image_url must be provided")
|
||||||
|
if self.image_data and self.image_url:
|
||||||
|
raise ValueError("Only one of image_data or image_url should be provided")
|
||||||
@@ -16,6 +16,7 @@ class GenerationRequest(BaseModel):
|
|||||||
telegram_id: Optional[int] = None
|
telegram_id: Optional[int] = None
|
||||||
use_profile_image: bool = True
|
use_profile_image: bool = True
|
||||||
assets_list: List[str]
|
assets_list: List[str]
|
||||||
|
project_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class GenerationsResponse(BaseModel):
|
class GenerationsResponse(BaseModel):
|
||||||
@@ -26,6 +27,7 @@ class GenerationsResponse(BaseModel):
|
|||||||
class GenerationResponse(BaseModel):
|
class GenerationResponse(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
status: GenerationStatus
|
status: GenerationStatus
|
||||||
|
gen_type: GenType = GenType.IMAGE
|
||||||
failed_reason: Optional[str] = None
|
failed_reason: Optional[str] = None
|
||||||
|
|
||||||
linked_character_id: Optional[str] = None
|
linked_character_id: Optional[str] = None
|
||||||
@@ -42,6 +44,12 @@ class GenerationResponse(BaseModel):
|
|||||||
input_token_usage: Optional[int] = None
|
input_token_usage: Optional[int] = None
|
||||||
output_token_usage: Optional[int] = None
|
output_token_usage: Optional[int] = None
|
||||||
progress: int = 0
|
progress: int = 0
|
||||||
|
cost: Optional[float] = None
|
||||||
|
created_by: Optional[str] = None
|
||||||
|
# Video-specific
|
||||||
|
kling_task_id: Optional[str] = None
|
||||||
|
video_duration: Optional[int] = None
|
||||||
|
video_mode: Optional[str] = None
|
||||||
created_at: datetime = datetime.now(UTC)
|
created_at: datetime = datetime.now(UTC)
|
||||||
updated_at: datetime = datetime.now(UTC)
|
updated_at: datetime = datetime.now(UTC)
|
||||||
|
|
||||||
|
|||||||
16
api/models/VideoGenerationRequest.py
Normal file
16
api/models/VideoGenerationRequest.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class VideoGenerationRequest(BaseModel):
|
||||||
|
prompt: str = ""
|
||||||
|
negative_prompt: Optional[str] = ""
|
||||||
|
image_asset_id: str # ID ассета-картинки для source image
|
||||||
|
duration: int = 5 # 5 or 10 seconds
|
||||||
|
mode: str = "std" # "std" or "pro"
|
||||||
|
model_name: str = "kling-v2-1"
|
||||||
|
cfg_scale: float = 0.5
|
||||||
|
aspect_ratio: str = "16:9"
|
||||||
|
linked_character_id: Optional[str] = None
|
||||||
|
project_id: Optional[str] = None
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,15 +1,19 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
|
import base64
|
||||||
from datetime import datetime, UTC
|
from datetime import datetime, UTC
|
||||||
from typing import List, Optional, Tuple, Any, Dict
|
from typing import List, Optional, Tuple, Any, Dict
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
import httpx
|
||||||
|
|
||||||
from aiogram import Bot
|
from aiogram import Bot
|
||||||
from aiogram.types import BufferedInputFile
|
from aiogram.types import BufferedInputFile
|
||||||
from adapters.Exception import GoogleGenerationException
|
from adapters.Exception import GoogleGenerationException
|
||||||
from adapters.google_adapter import GoogleAdapter
|
from adapters.google_adapter import GoogleAdapter
|
||||||
|
from adapters.kling_adapter import KlingAdapter, KlingApiException
|
||||||
from api.models.GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse
|
from api.models.GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse
|
||||||
|
from api.models.VideoGenerationRequest import VideoGenerationRequest
|
||||||
# Импортируйте ваши модели DAO, Asset, Generation корректно
|
# Импортируйте ваши модели DAO, Asset, Generation корректно
|
||||||
from models.Asset import Asset, AssetType, AssetContentType
|
from models.Asset import Asset, AssetType, AssetContentType
|
||||||
from models.Generation import Generation, GenerationStatus
|
from models.Generation import Generation, GenerationStatus
|
||||||
@@ -62,11 +66,12 @@ async def generate_image_task(
|
|||||||
return images_bytes, metrics
|
return images_bytes, metrics
|
||||||
|
|
||||||
class GenerationService:
|
class GenerationService:
|
||||||
def __init__(self, dao: DAO, gemini: GoogleAdapter, s3_adapter: S3Adapter, bot: Optional[Bot] = None):
|
def __init__(self, dao: DAO, gemini: GoogleAdapter, s3_adapter: S3Adapter, bot: Optional[Bot] = None, kling_adapter: Optional[KlingAdapter] = None):
|
||||||
self.dao = dao
|
self.dao = dao
|
||||||
self.gemini = gemini
|
self.gemini = gemini
|
||||||
self.s3_adapter = s3_adapter
|
self.s3_adapter = s3_adapter
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
|
self.kling_adapter = kling_adapter
|
||||||
|
|
||||||
|
|
||||||
async def ask_prompt_assistant(self, prompt: str, assets: List[str] = None) -> str:
|
async def ask_prompt_assistant(self, prompt: str, assets: List[str] = None) -> str:
|
||||||
@@ -92,10 +97,10 @@ class GenerationService:
|
|||||||
|
|
||||||
return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, images_list=images)
|
return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, images_list=images)
|
||||||
|
|
||||||
async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0) -> List[
|
async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[
|
||||||
Generation]:
|
Generation]:
|
||||||
generations = await self.dao.generations.get_generations(character_id = character_id,limit=limit, offset=offset)
|
generations = await self.dao.generations.get_generations(character_id = character_id,limit=limit, offset=offset, created_by=user_id, project_id=project_id)
|
||||||
total_count = await self.dao.generations.count_generations(character_id = character_id)
|
total_count = await self.dao.generations.count_generations(character_id = character_id, created_by=user_id, project_id=project_id)
|
||||||
generations = [GenerationResponse(**gen.model_dump()) for gen in generations]
|
generations = [GenerationResponse(**gen.model_dump()) for gen in generations]
|
||||||
return GenerationsResponse(generations=generations, total_count=total_count)
|
return GenerationsResponse(generations=generations, total_count=total_count)
|
||||||
|
|
||||||
@@ -106,15 +111,18 @@ class GenerationService:
|
|||||||
else:
|
else:
|
||||||
return GenerationResponse(**gen.model_dump())
|
return GenerationResponse(**gen.model_dump())
|
||||||
|
|
||||||
async def get_running_generations(self) -> List[Generation]:
|
async def get_running_generations(self, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
|
||||||
return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING)
|
return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING, created_by=user_id, project_id=project_id)
|
||||||
|
|
||||||
async def create_generation_task(self, generation_request: GenerationRequest) -> GenerationResponse:
|
async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None) -> GenerationResponse:
|
||||||
gen_id = None
|
gen_id = None
|
||||||
generation_model = None
|
generation_model = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
generation_model = Generation(**generation_request.model_dump())
|
generation_model = Generation(**generation_request.model_dump())
|
||||||
|
if user_id:
|
||||||
|
generation_model.created_by = user_id
|
||||||
|
|
||||||
gen_id = await self.dao.generations.create_generation(generation_model)
|
gen_id = await self.dao.generations.create_generation(generation_model)
|
||||||
generation_model.id = gen_id
|
generation_model.id = gen_id
|
||||||
|
|
||||||
@@ -155,22 +163,25 @@ class GenerationService:
|
|||||||
# 2. Получаем ассеты-референсы (если они есть)
|
# 2. Получаем ассеты-референсы (если они есть)
|
||||||
reference_assets: List[Asset] = []
|
reference_assets: List[Asset] = []
|
||||||
media_group_bytes: List[bytes] = []
|
media_group_bytes: List[bytes] = []
|
||||||
generation_prompt = f"""
|
generation_prompt = generation.prompt
|
||||||
|
# generation_prompt = f"""
|
||||||
|
|
||||||
Create detailed image of character in scene.
|
# Create detailed image of character in scene.
|
||||||
|
|
||||||
SCENE DESCRIPTION: {generation.prompt}
|
# SCENE DESCRIPTION: {generation.prompt}
|
||||||
|
|
||||||
Rules:
|
# Rules:
|
||||||
- Integrate the character's appearance naturally into the scene description.
|
# - Integrate the character's appearance naturally into the scene description.
|
||||||
- Focus on lighting, texture, and composition.
|
# - Focus on lighting, texture, and composition.
|
||||||
"""
|
# """
|
||||||
if generation.linked_character_id is not None:
|
if generation.linked_character_id is not None:
|
||||||
char_info = await self.dao.chars.get_character(generation.linked_character_id, with_image_data=True)
|
char_info = await self.dao.chars.get_character(generation.linked_character_id)
|
||||||
if char_info is None:
|
if char_info is None:
|
||||||
raise Exception(f"Character ID {generation.linked_character_id} not found")
|
raise Exception(f"Character ID {generation.linked_character_id} not found")
|
||||||
if generation.use_profile_image:
|
if generation.use_profile_image:
|
||||||
media_group_bytes.append(char_info.character_image_data)
|
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
|
||||||
|
if avatar_asset:
|
||||||
|
media_group_bytes.append(avatar_asset.data)
|
||||||
# generation_prompt = generation_prompt.replace("$char_bio_inserted", f"1. CHARACTER BIO (Must be strictly followed): {char_info.character_bio}")
|
# generation_prompt = generation_prompt.replace("$char_bio_inserted", f"1. CHARACTER BIO (Must be strictly followed): {char_info.character_bio}")
|
||||||
|
|
||||||
reference_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
|
reference_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
|
||||||
@@ -258,7 +269,9 @@ class GenerationService:
|
|||||||
data=None, # Not storing bytes in DB anymore
|
data=None, # Not storing bytes in DB anymore
|
||||||
minio_object_name=filename,
|
minio_object_name=filename,
|
||||||
minio_bucket=self.s3_adapter.bucket_name,
|
minio_bucket=self.s3_adapter.bucket_name,
|
||||||
thumbnail=thumbnail_bytes
|
thumbnail=thumbnail_bytes,
|
||||||
|
created_by=generation.created_by,
|
||||||
|
project_id=generation.project_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# Сохраняем в БД
|
# Сохраняем в БД
|
||||||
@@ -325,6 +338,261 @@ class GenerationService:
|
|||||||
logger.error(f"Error in progress simulation: {e}")
|
logger.error(f"Error in progress simulation: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
async def import_external_generation(self, external_gen) -> Generation:
|
||||||
|
"""
|
||||||
|
Import a generation from an external source.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
external_gen: ExternalGenerationRequest with generation data and image
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Created Generation object
|
||||||
|
"""
|
||||||
|
from api.models.ExternalGenerationDTO import ExternalGenerationRequest
|
||||||
|
|
||||||
|
# Validate image source
|
||||||
|
external_gen.validate_image_source()
|
||||||
|
|
||||||
|
logger.info(f"Importing external generation for user: {external_gen.created_by}")
|
||||||
|
|
||||||
|
# 1. Process image (download or decode)
|
||||||
|
image_bytes = None
|
||||||
|
|
||||||
|
if external_gen.image_url:
|
||||||
|
# Download image from URL
|
||||||
|
logger.info(f"Downloading image from URL: {external_gen.image_url}")
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(external_gen.image_url, timeout=30.0)
|
||||||
|
response.raise_for_status()
|
||||||
|
image_bytes = response.content
|
||||||
|
elif external_gen.image_data:
|
||||||
|
# Decode base64 image
|
||||||
|
logger.info("Decoding base64 image data")
|
||||||
|
image_bytes = base64.b64decode(external_gen.image_data)
|
||||||
|
|
||||||
|
if not image_bytes:
|
||||||
|
raise ValueError("Failed to process image data")
|
||||||
|
|
||||||
|
# 2. Generate thumbnail
|
||||||
|
from utils.image_utils import create_thumbnail
|
||||||
|
thumbnail_bytes = await asyncio.to_thread(create_thumbnail, image_bytes)
|
||||||
|
|
||||||
|
# 3. Save to S3
|
||||||
|
filename = f"external/{external_gen.created_by}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.png"
|
||||||
|
await self.s3_adapter.upload_file(filename, image_bytes, content_type="image/png")
|
||||||
|
|
||||||
|
# 4. Create Asset
|
||||||
|
new_asset = Asset(
|
||||||
|
name=f"External_Generated_{external_gen.linked_character_id or 'no_char'}",
|
||||||
|
type=AssetType.GENERATED,
|
||||||
|
content_type=AssetContentType.IMAGE,
|
||||||
|
linked_char_id=external_gen.linked_character_id,
|
||||||
|
data=None, # Not storing bytes in DB
|
||||||
|
minio_object_name=filename,
|
||||||
|
minio_bucket=self.s3_adapter.bucket_name,
|
||||||
|
thumbnail=thumbnail_bytes,
|
||||||
|
created_by=external_gen.created_by,
|
||||||
|
project_id=external_gen.project_id
|
||||||
|
)
|
||||||
|
|
||||||
|
asset_id = await self.dao.assets.create_asset(new_asset)
|
||||||
|
new_asset.id = str(asset_id)
|
||||||
|
|
||||||
|
logger.info(f"Created asset {asset_id} for external generation")
|
||||||
|
|
||||||
|
# 5. Create Generation record
|
||||||
|
generation = Generation(
|
||||||
|
status=GenerationStatus.DONE,
|
||||||
|
linked_character_id=external_gen.linked_character_id,
|
||||||
|
aspect_ratio=external_gen.aspect_ratio,
|
||||||
|
quality=external_gen.quality,
|
||||||
|
prompt=external_gen.prompt,
|
||||||
|
tech_prompt=external_gen.tech_prompt,
|
||||||
|
result_list=[new_asset.id],
|
||||||
|
result=new_asset.id,
|
||||||
|
progress=100,
|
||||||
|
execution_time_seconds=external_gen.execution_time_seconds,
|
||||||
|
api_execution_time_seconds=external_gen.api_execution_time_seconds,
|
||||||
|
token_usage=external_gen.token_usage,
|
||||||
|
input_token_usage=external_gen.input_token_usage,
|
||||||
|
output_token_usage=external_gen.output_token_usage,
|
||||||
|
created_by=external_gen.created_by,
|
||||||
|
project_id=external_gen.project_id,
|
||||||
|
created_at=datetime.now(UTC),
|
||||||
|
updated_at=datetime.now(UTC)
|
||||||
|
)
|
||||||
|
|
||||||
|
gen_id = await self.dao.generations.create_generation(generation)
|
||||||
|
generation.id = gen_id
|
||||||
|
|
||||||
|
logger.info(f"Created generation {gen_id} from external source")
|
||||||
|
|
||||||
|
return generation
|
||||||
|
|
||||||
|
# === VIDEO GENERATION (Kling) ===
|
||||||
|
|
||||||
|
async def create_video_generation_task(self, request: VideoGenerationRequest, user_id: Optional[str] = None) -> GenerationResponse:
|
||||||
|
"""Create a video generation task (async, returns immediately)."""
|
||||||
|
if not self.kling_adapter:
|
||||||
|
raise Exception("Kling adapter is not configured")
|
||||||
|
|
||||||
|
generation = Generation(
|
||||||
|
status=GenerationStatus.RUNNING,
|
||||||
|
gen_type=GenType.VIDEO,
|
||||||
|
linked_character_id=request.linked_character_id,
|
||||||
|
aspect_ratio=AspectRatios.SIXTEENNINE, # default for video
|
||||||
|
quality=Quality.ONEK,
|
||||||
|
prompt=request.prompt,
|
||||||
|
assets_list=[request.image_asset_id],
|
||||||
|
video_duration=request.duration,
|
||||||
|
video_mode=request.mode,
|
||||||
|
project_id=request.project_id,
|
||||||
|
)
|
||||||
|
if user_id:
|
||||||
|
generation.created_by = user_id
|
||||||
|
|
||||||
|
gen_id = await self.dao.generations.create_generation(generation)
|
||||||
|
generation.id = gen_id
|
||||||
|
|
||||||
|
async def runner(gen, req):
|
||||||
|
logger.info(f"Starting background video generation task for ID: {gen.id}")
|
||||||
|
try:
|
||||||
|
await self.create_video_generation(gen, req)
|
||||||
|
logger.info(f"Background video generation task finished for ID: {gen.id}")
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
db_gen = await self.dao.generations.get_generation(gen.id)
|
||||||
|
if db_gen and db_gen.status != GenerationStatus.FAILED:
|
||||||
|
db_gen.status = GenerationStatus.FAILED
|
||||||
|
db_gen.updated_at = datetime.now(UTC)
|
||||||
|
await self.dao.generations.update_generation(db_gen)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to mark video generation as FAILED")
|
||||||
|
logger.exception("create_video_generation task failed")
|
||||||
|
|
||||||
|
asyncio.create_task(runner(generation, request))
|
||||||
|
return GenerationResponse(**generation.model_dump())
|
||||||
|
|
||||||
|
async def create_video_generation(self, generation: Generation, request: VideoGenerationRequest):
|
||||||
|
"""Background video generation: call Kling API, poll, download result, save asset."""
|
||||||
|
start_time = datetime.now()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 1. Get source image presigned URL
|
||||||
|
asset = await self.dao.assets.get_asset(request.image_asset_id)
|
||||||
|
if not asset:
|
||||||
|
raise Exception(f"Asset {request.image_asset_id} not found")
|
||||||
|
|
||||||
|
if not asset.minio_object_name:
|
||||||
|
raise Exception(f"Asset {request.image_asset_id} has no S3 object")
|
||||||
|
|
||||||
|
presigned_url = await self.s3_adapter.get_presigned_url(asset.minio_object_name, expiration=3600)
|
||||||
|
if not presigned_url:
|
||||||
|
raise Exception("Failed to generate presigned URL for source image")
|
||||||
|
|
||||||
|
logger.info(f"Video gen {generation.id}: got presigned URL for asset {request.image_asset_id}")
|
||||||
|
|
||||||
|
# 2. Create Kling task
|
||||||
|
task_data = await self.kling_adapter.create_video_task(
|
||||||
|
image_url=presigned_url,
|
||||||
|
prompt=request.prompt,
|
||||||
|
negative_prompt=request.negative_prompt or "",
|
||||||
|
model_name=request.model_name,
|
||||||
|
duration=request.duration,
|
||||||
|
mode=request.mode,
|
||||||
|
cfg_scale=request.cfg_scale,
|
||||||
|
aspect_ratio=request.aspect_ratio,
|
||||||
|
)
|
||||||
|
|
||||||
|
task_id = task_data.get("task_id")
|
||||||
|
generation.kling_task_id = task_id
|
||||||
|
await self.dao.generations.update_generation(generation)
|
||||||
|
|
||||||
|
logger.info(f"Video gen {generation.id}: Kling task created, task_id={task_id}")
|
||||||
|
|
||||||
|
# 3. Poll for completion with progress updates
|
||||||
|
async def progress_callback(progress_pct: int):
|
||||||
|
generation.progress = progress_pct
|
||||||
|
generation.updated_at = datetime.now(UTC)
|
||||||
|
await self.dao.generations.update_generation(generation)
|
||||||
|
|
||||||
|
result = await self.kling_adapter.wait_for_completion(
|
||||||
|
task_id=task_id,
|
||||||
|
poll_interval=10,
|
||||||
|
timeout=600,
|
||||||
|
progress_callback=progress_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Extract video URL and download
|
||||||
|
works = result.get("task_result", {}).get("videos", [])
|
||||||
|
if not works:
|
||||||
|
raise Exception("No video in Kling result")
|
||||||
|
|
||||||
|
video_url = works[0].get("url")
|
||||||
|
video_duration = works[0].get("duration", request.duration)
|
||||||
|
if not video_url:
|
||||||
|
raise Exception("No video URL in Kling result")
|
||||||
|
|
||||||
|
logger.info(f"Video gen {generation.id}: downloading video from {video_url}")
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||||
|
video_response = await client.get(video_url)
|
||||||
|
video_response.raise_for_status()
|
||||||
|
video_bytes = video_response.content
|
||||||
|
|
||||||
|
logger.info(f"Video gen {generation.id}: downloaded {len(video_bytes)} bytes")
|
||||||
|
|
||||||
|
# 5. Upload to S3
|
||||||
|
filename = f"generated_video/{generation.linked_character_id or 'no_char'}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.mp4"
|
||||||
|
await self.s3_adapter.upload_file(filename, video_bytes, content_type="video/mp4")
|
||||||
|
|
||||||
|
# 6. Create Asset
|
||||||
|
new_asset = Asset(
|
||||||
|
name=f"Video_{generation.linked_character_id or 'gen'}",
|
||||||
|
type=AssetType.GENERATED,
|
||||||
|
content_type=AssetContentType.VIDEO,
|
||||||
|
linked_char_id=generation.linked_character_id,
|
||||||
|
data=None,
|
||||||
|
minio_object_name=filename,
|
||||||
|
minio_bucket=self.s3_adapter.bucket_name,
|
||||||
|
thumbnail=None, # видео thumbnails можно добавить позже
|
||||||
|
created_by=generation.created_by,
|
||||||
|
project_id=generation.project_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
asset_id = await self.dao.assets.create_asset(new_asset)
|
||||||
|
new_asset.id = str(asset_id)
|
||||||
|
|
||||||
|
# 7. Finalize generation
|
||||||
|
end_time = datetime.now()
|
||||||
|
generation.result_list = [new_asset.id]
|
||||||
|
generation.result = new_asset.id
|
||||||
|
generation.status = GenerationStatus.DONE
|
||||||
|
generation.progress = 100
|
||||||
|
generation.video_duration = video_duration
|
||||||
|
generation.execution_time_seconds = (end_time - start_time).total_seconds()
|
||||||
|
generation.updated_at = datetime.now(UTC)
|
||||||
|
await self.dao.generations.update_generation(generation)
|
||||||
|
|
||||||
|
logger.info(f"Video generation {generation.id} completed. Asset: {new_asset.id}, Time: {generation.execution_time_seconds:.1f}s")
|
||||||
|
|
||||||
|
except KlingApiException as e:
|
||||||
|
logger.error(f"Kling API error for generation {generation.id}: {e}")
|
||||||
|
generation.status = GenerationStatus.FAILED
|
||||||
|
generation.failed_reason = str(e)
|
||||||
|
generation.updated_at = datetime.now(UTC)
|
||||||
|
await self.dao.generations.update_generation(generation)
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Video generation {generation.id} failed: {e}")
|
||||||
|
generation.status = GenerationStatus.FAILED
|
||||||
|
generation.failed_reason = str(e)
|
||||||
|
generation.updated_at = datetime.now(UTC)
|
||||||
|
await self.dao.generations.update_generation(generation)
|
||||||
|
raise
|
||||||
|
|
||||||
async def delete_generation(self, generation_id: str) -> bool:
|
async def delete_generation(self, generation_id: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Soft delete generation by marking it as deleted.
|
Soft delete generation by marking it as deleted.
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -7,6 +7,7 @@ from pydantic import BaseModel, computed_field, Field, model_validator
|
|||||||
|
|
||||||
class AssetContentType(str, Enum):
|
class AssetContentType(str, Enum):
|
||||||
IMAGE = 'image'
|
IMAGE = 'image'
|
||||||
|
VIDEO = 'video'
|
||||||
PROMPT = 'prompt'
|
PROMPT = 'prompt'
|
||||||
|
|
||||||
class AssetType(str, Enum):
|
class AssetType(str, Enum):
|
||||||
@@ -28,6 +29,8 @@ class Asset(BaseModel):
|
|||||||
minio_thumbnail_object_name: Optional[str] = None
|
minio_thumbnail_object_name: Optional[str] = None
|
||||||
thumbnail: Optional[bytes] = None
|
thumbnail: Optional[bytes] = None
|
||||||
tags: List[str] = []
|
tags: List[str] = []
|
||||||
|
created_by: Optional[str] = None
|
||||||
|
project_id: Optional[str] = None
|
||||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||||
|
|
||||||
|
|||||||
@@ -5,11 +5,13 @@ from pydantic_core.core_schema import computed_field
|
|||||||
|
|
||||||
|
|
||||||
class Character(BaseModel):
|
class Character(BaseModel):
|
||||||
id: str | None
|
id: Optional[str] = None
|
||||||
name: str
|
name: str
|
||||||
|
avatar_asset_id: Optional[str] = None
|
||||||
avatar_image: Optional[str] = None
|
avatar_image: Optional[str] = None
|
||||||
character_image_data: Optional[bytes] = None
|
character_image_data: Optional[bytes] = None
|
||||||
character_image_doc_tg_id: str
|
character_image_doc_tg_id: Optional[str] = None
|
||||||
character_image_tg_id: str | None
|
character_image_tg_id: Optional[str] = None
|
||||||
character_bio: str
|
character_bio: Optional[str] = None
|
||||||
|
created_by: Optional[str] = None
|
||||||
|
project_id: Optional[str] = None
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from datetime import datetime, UTC
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, computed_field
|
||||||
|
|
||||||
from models.Asset import Asset
|
from models.Asset import Asset
|
||||||
from models.enums import AspectRatios, Quality, GenType
|
from models.enums import AspectRatios, Quality, GenType
|
||||||
@@ -16,6 +16,7 @@ class GenerationStatus(str, Enum):
|
|||||||
class Generation(BaseModel):
|
class Generation(BaseModel):
|
||||||
id: Optional[str] = None
|
id: Optional[str] = None
|
||||||
status: GenerationStatus = GenerationStatus.RUNNING
|
status: GenerationStatus = GenerationStatus.RUNNING
|
||||||
|
gen_type: GenType = GenType.IMAGE
|
||||||
failed_reason: Optional[str] = None
|
failed_reason: Optional[str] = None
|
||||||
linked_character_id: Optional[str] = None
|
linked_character_id: Optional[str] = None
|
||||||
telegram_id: Optional[int] = None
|
telegram_id: Optional[int] = None
|
||||||
@@ -34,5 +35,20 @@ class Generation(BaseModel):
|
|||||||
input_token_usage: Optional[int] = None
|
input_token_usage: Optional[int] = None
|
||||||
output_token_usage: Optional[int] = None
|
output_token_usage: Optional[int] = None
|
||||||
is_deleted: bool = False
|
is_deleted: bool = False
|
||||||
|
album_id: Optional[str] = None
|
||||||
|
created_by: Optional[str] = None # Stores User ID (Telegram ID or Web User ObjectId)
|
||||||
|
project_id: Optional[str] = None
|
||||||
|
# Video-specific fields
|
||||||
|
kling_task_id: Optional[str] = None
|
||||||
|
video_duration: Optional[int] = None # 5 or 10 seconds
|
||||||
|
video_mode: Optional[str] = None # "std" or "pro"
|
||||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||||
|
|
||||||
|
@computed_field
|
||||||
|
def cost(self) -> float:
|
||||||
|
if self.status == GenerationStatus.DONE and self.input_token_usage and self.output_token_usage:
|
||||||
|
cost_input = self.input_token_usage * 0.000002
|
||||||
|
cost_output = self.output_token_usage * 0.00012
|
||||||
|
return round(cost_input + cost_output, 3)
|
||||||
|
return 0.0
|
||||||
12
models/Project.py
Normal file
12
models/Project.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Optional
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
class Project(BaseModel):
|
||||||
|
id: Optional[str] = None
|
||||||
|
name: str
|
||||||
|
description: Optional[str] = None
|
||||||
|
owner_id: str
|
||||||
|
members: List[str] = [] # List of User IDs
|
||||||
|
is_deleted: bool = False
|
||||||
|
created_at: datetime = Field(default_factory=datetime.now)
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -34,10 +34,12 @@ class Quality(str, Enum):
|
|||||||
class GenType(str, Enum):
|
class GenType(str, Enum):
|
||||||
TEXT = 'Text'
|
TEXT = 'Text'
|
||||||
IMAGE = 'Image'
|
IMAGE = 'Image'
|
||||||
|
VIDEO = 'Video'
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def value_type(self) -> str:
|
def value_type(self) -> str:
|
||||||
return {
|
return {
|
||||||
GenType.TEXT: 'Text',
|
GenType.TEXT: 'Text',
|
||||||
GenType.IMAGE: 'Image',
|
GenType.IMAGE: 'Image',
|
||||||
|
GenType.VIDEO: 'Video',
|
||||||
}[self]
|
}[self]
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -46,7 +46,7 @@ class AssetsRepo:
|
|||||||
res = await self.collection.insert_one(asset.model_dump())
|
res = await self.collection.insert_one(asset.model_dump())
|
||||||
return str(res.inserted_id)
|
return str(res.inserted_id)
|
||||||
|
|
||||||
async def get_assets(self, asset_type: Optional[str] = None, limit: int = 10, offset: int = 0, with_data: bool = False) -> List[Asset]:
|
async def get_assets(self, asset_type: Optional[str] = None, limit: int = 10, offset: int = 0, with_data: bool = False, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Asset]:
|
||||||
filter = {}
|
filter = {}
|
||||||
if asset_type:
|
if asset_type:
|
||||||
filter["type"] = asset_type
|
filter["type"] = asset_type
|
||||||
@@ -70,6 +70,12 @@ class AssetsRepo:
|
|||||||
# if not with_data: args["data"] = 0; args["thumbnail"] = 0
|
# if not with_data: args["data"] = 0; args["thumbnail"] = 0
|
||||||
# So list DOES NOT return thumbnails by default.
|
# So list DOES NOT return thumbnails by default.
|
||||||
args["thumbnail"] = 0
|
args["thumbnail"] = 0
|
||||||
|
if created_by:
|
||||||
|
filter["created_by"] = created_by
|
||||||
|
filter['project_id'] = None
|
||||||
|
|
||||||
|
if project_id:
|
||||||
|
filter["project_id"] = project_id
|
||||||
|
|
||||||
res = await self.collection.find(filter, args).sort("created_at", -1).skip(offset).limit(limit).to_list(None)
|
res = await self.collection.find(filter, args).sort("created_at", -1).skip(offset).limit(limit).to_list(None)
|
||||||
assets = []
|
assets = []
|
||||||
@@ -157,8 +163,15 @@ class AssetsRepo:
|
|||||||
assets.append(Asset(**doc))
|
assets.append(Asset(**doc))
|
||||||
return assets
|
return assets
|
||||||
|
|
||||||
async def get_asset_count(self, character_id: Optional[str] = None) -> int:
|
async def get_asset_count(self, character_id: Optional[str] = None, created_by: Optional[str] = None, project_id: Optional[str] = None) -> int:
|
||||||
return await self.collection.count_documents({"linked_char_id": character_id} if character_id else {})
|
filter = {}
|
||||||
|
if character_id:
|
||||||
|
filter["linked_char_id"] = character_id
|
||||||
|
if created_by:
|
||||||
|
filter["created_by"] = created_by
|
||||||
|
if project_id:
|
||||||
|
filter["project_id"] = project_id
|
||||||
|
return await self.collection.count_documents(filter)
|
||||||
|
|
||||||
async def get_assets_by_ids(self, asset_ids: List[str]) -> List[Asset]:
|
async def get_assets_by_ids(self, asset_ids: List[str]) -> List[Asset]:
|
||||||
object_ids = [ObjectId(asset_id) for asset_id in asset_ids]
|
object_ids = [ObjectId(asset_id) for asset_id in asset_ids]
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import List
|
from typing import List, Optional
|
||||||
|
|
||||||
from bson import ObjectId
|
from bson import ObjectId
|
||||||
from motor.motor_asyncio import AsyncIOMotorClient
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
@@ -12,7 +12,7 @@ class CharacterRepo:
|
|||||||
|
|
||||||
async def add_character(self, character: Character) -> Character:
|
async def add_character(self, character: Character) -> Character:
|
||||||
op = await self.collection.insert_one(character.model_dump())
|
op = await self.collection.insert_one(character.model_dump())
|
||||||
character.id = op.inserted_id
|
character.id = str(op.inserted_id)
|
||||||
return character
|
return character
|
||||||
|
|
||||||
async def get_character(self, character_id: str, with_image_data: bool = False) -> Character | None:
|
async def get_character(self, character_id: str, with_image_data: bool = False) -> Character | None:
|
||||||
@@ -26,18 +26,25 @@ class CharacterRepo:
|
|||||||
res["id"] = str(res.pop("_id"))
|
res["id"] = str(res.pop("_id"))
|
||||||
return Character(**res)
|
return Character(**res)
|
||||||
|
|
||||||
async def get_all_characters(self) -> List[Character]:
|
async def get_all_characters(self, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Character]:
|
||||||
docs = await self.collection.find({}, {"character_image_data": 0}).to_list(None)
|
filter = {}
|
||||||
|
if created_by:
|
||||||
|
filter["created_by"] = created_by
|
||||||
|
if project_id:
|
||||||
|
filter["project_id"] = project_id
|
||||||
|
|
||||||
characters = []
|
args = {"character_image_data": 0} # don't return image data for list
|
||||||
for doc in docs:
|
res = await self.collection.find(filter, args).to_list(None)
|
||||||
# Конвертируем ObjectId в строку и кладем в поле id
|
chars = []
|
||||||
|
for doc in res:
|
||||||
doc["id"] = str(doc.pop("_id"))
|
doc["id"] = str(doc.pop("_id"))
|
||||||
|
chars.append(Character(**doc))
|
||||||
|
return chars
|
||||||
|
|
||||||
# Создаем объект
|
async def update_char(self, char_id: str, character: Character) -> bool:
|
||||||
characters.append(Character(**doc))
|
result = await self.collection.update_one({"_id": ObjectId(char_id)}, {"$set": character.model_dump()})
|
||||||
|
return result.modified_count > 0
|
||||||
|
|
||||||
return characters
|
async def delete_character(self, char_id: str) -> bool:
|
||||||
|
result = await self.collection.delete_one({"_id": ObjectId(char_id)})
|
||||||
async def update_char(self, char_id: str, character: Character) -> None:
|
return result.deleted_count > 0
|
||||||
await self.collection.update_one({"_id": ObjectId(char_id)}, {"$set": character.model_dump()})
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from repos.char_repo import CharacterRepo
|
|||||||
from repos.generation_repo import GenerationRepo
|
from repos.generation_repo import GenerationRepo
|
||||||
from repos.user_repo import UsersRepo
|
from repos.user_repo import UsersRepo
|
||||||
from repos.albums_repo import AlbumsRepo
|
from repos.albums_repo import AlbumsRepo
|
||||||
|
from repos.project_repo import ProjectRepo
|
||||||
|
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@@ -16,3 +17,5 @@ class DAO:
|
|||||||
self.assets = AssetsRepo(client, s3_adapter, db_name)
|
self.assets = AssetsRepo(client, s3_adapter, db_name)
|
||||||
self.generations = GenerationRepo(client, db_name)
|
self.generations = GenerationRepo(client, db_name)
|
||||||
self.albums = AlbumsRepo(client, db_name)
|
self.albums = AlbumsRepo(client, db_name)
|
||||||
|
self.projects = ProjectRepo(client, db_name)
|
||||||
|
self.users = UsersRepo(client, db_name)
|
||||||
|
|||||||
@@ -25,13 +25,19 @@ class GenerationRepo:
|
|||||||
return Generation(**res)
|
return Generation(**res)
|
||||||
|
|
||||||
async def get_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
|
async def get_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
|
||||||
limit: int = 10, offset: int = 10) -> List[Generation]:
|
limit: int = 10, offset: int = 10, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
|
||||||
|
|
||||||
filter = {"is_deleted": False}
|
filter = {"is_deleted": False}
|
||||||
if character_id is not None:
|
if character_id is not None:
|
||||||
filter["linked_character_id"] = character_id
|
filter["linked_character_id"] = character_id
|
||||||
if status is not None:
|
if status is not None:
|
||||||
filter["status"] = status
|
filter["status"] = status
|
||||||
|
if created_by is not None:
|
||||||
|
filter["created_by"] = created_by
|
||||||
|
filter["project_id"] = None
|
||||||
|
if project_id is not None:
|
||||||
|
filter["project_id"] = project_id
|
||||||
|
|
||||||
res = await self.collection.find(filter).sort("created_at", -1).skip(
|
res = await self.collection.find(filter).sort("created_at", -1).skip(
|
||||||
offset).limit(limit).to_list(None)
|
offset).limit(limit).to_list(None)
|
||||||
generations: List[Generation] = []
|
generations: List[Generation] = []
|
||||||
@@ -40,12 +46,17 @@ class GenerationRepo:
|
|||||||
generations.append(Generation(**generation))
|
generations.append(Generation(**generation))
|
||||||
return generations
|
return generations
|
||||||
|
|
||||||
async def count_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None, album_id: Optional[str] = None) -> int:
|
async def count_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
|
||||||
|
album_id: Optional[str] = None, created_by: Optional[str] = None, project_id: Optional[str] = None) -> int:
|
||||||
args = {}
|
args = {}
|
||||||
if character_id is not None:
|
if character_id is not None:
|
||||||
args["linked_character_id"] = character_id
|
args["linked_character_id"] = character_id
|
||||||
if status is not None:
|
if status is not None:
|
||||||
args["status"] = status
|
args["status"] = status
|
||||||
|
if created_by is not None:
|
||||||
|
args["created_by"] = created_by
|
||||||
|
if project_id is not None:
|
||||||
|
args["project_id"] = project_id
|
||||||
return await self.collection.count_documents(args)
|
return await self.collection.count_documents(args)
|
||||||
|
|
||||||
async def get_generations_by_ids(self, generation_ids: List[str]) -> List[Generation]:
|
async def get_generations_by_ids(self, generation_ids: List[str]) -> List[Generation]:
|
||||||
|
|||||||
62
repos/project_repo.py
Normal file
62
repos/project_repo.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
from bson import ObjectId
|
||||||
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
from models.Project import Project
|
||||||
|
|
||||||
|
class ProjectRepo:
|
||||||
|
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
|
||||||
|
self.collection = client[db_name]["projects"]
|
||||||
|
|
||||||
|
async def create_project(self, project: Project) -> str:
|
||||||
|
res = await self.collection.insert_one(project.model_dump())
|
||||||
|
return str(res.inserted_id)
|
||||||
|
|
||||||
|
async def get_project(self, project_id: str) -> Optional[Project]:
|
||||||
|
if not ObjectId.is_valid(project_id):
|
||||||
|
return None
|
||||||
|
res = await self.collection.find_one({"_id": ObjectId(project_id)})
|
||||||
|
if res:
|
||||||
|
res["id"] = str(res.pop("_id"))
|
||||||
|
return Project(**res)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_projects_by_user(self, user_id: str) -> List[Project]:
|
||||||
|
# Find projects where user is owner OR in members
|
||||||
|
filter = {
|
||||||
|
"$or": [
|
||||||
|
{"owner_id": user_id},
|
||||||
|
{"members": user_id}
|
||||||
|
],
|
||||||
|
"is_deleted": False
|
||||||
|
}
|
||||||
|
cursor = self.collection.find(filter).sort("created_at", -1)
|
||||||
|
projects = []
|
||||||
|
async for doc in cursor:
|
||||||
|
doc["id"] = str(doc.pop("_id"))
|
||||||
|
projects.append(Project(**doc))
|
||||||
|
return projects
|
||||||
|
|
||||||
|
async def add_member(self, project_id: str, user_id: str) -> bool:
|
||||||
|
res = await self.collection.update_one(
|
||||||
|
{"_id": ObjectId(project_id)},
|
||||||
|
{"$addToSet": {"members": user_id}}
|
||||||
|
)
|
||||||
|
return res.modified_count > 0
|
||||||
|
|
||||||
|
async def remove_member(self, project_id: str, user_id: str) -> bool:
|
||||||
|
res = await self.collection.update_one(
|
||||||
|
{"_id": ObjectId(project_id)},
|
||||||
|
{"$pull": {"members": user_id}}
|
||||||
|
)
|
||||||
|
return res.modified_count > 0
|
||||||
|
|
||||||
|
async def update_project(self, project_id: str, updates: dict) -> bool:
|
||||||
|
res = await self.collection.update_one(
|
||||||
|
{"_id": ObjectId(project_id)},
|
||||||
|
{"$set": updates}
|
||||||
|
)
|
||||||
|
return res.modified_count > 0
|
||||||
|
|
||||||
|
async def delete_project(self, project_id: str) -> bool:
|
||||||
|
res = await self.collection.update_one({"_id": ObjectId(project_id)}, {"$set": {"is_deleted": True}})
|
||||||
|
return res.modified_count > 0
|
||||||
@@ -19,10 +19,16 @@ class UsersRepo:
|
|||||||
self.collection = client[db_name]["users"]
|
self.collection = client[db_name]["users"]
|
||||||
|
|
||||||
async def get_user(self, user_id: int):
|
async def get_user(self, user_id: int):
|
||||||
return await self.collection.find_one({"user_id": user_id})
|
user = await self.collection.find_one({"user_id": user_id})
|
||||||
|
if user:
|
||||||
|
user["id"] = str(user["_id"])
|
||||||
|
return user
|
||||||
|
|
||||||
async def get_user_by_username(self, username: str):
|
async def get_user_by_username(self, username: str):
|
||||||
return await self.collection.find_one({"username": username})
|
user = await self.collection.find_one({"username": username})
|
||||||
|
if user:
|
||||||
|
user["id"] = str(user["_id"])
|
||||||
|
return user
|
||||||
|
|
||||||
async def create_user(self, username: str, password: str, full_name: Optional[str] = None):
|
async def create_user(self, username: str, password: str, full_name: Optional[str] = None):
|
||||||
"""Создает нового пользователя с username/паролем"""
|
"""Создает нового пользователя с username/паролем"""
|
||||||
@@ -38,15 +44,23 @@ class UsersRepo:
|
|||||||
"created_at": datetime.now(),
|
"created_at": datetime.now(),
|
||||||
"is_email_user": False, # Теперь это просто "обычный" юзер, не телеграм (хотя поле можно переименовать)
|
"is_email_user": False, # Теперь это просто "обычный" юзер, не телеграм (хотя поле можно переименовать)
|
||||||
"is_web_user": True,
|
"is_web_user": True,
|
||||||
"is_admin": False
|
"is_admin": False,
|
||||||
|
"project_ids": [],
|
||||||
|
"current_project_id": None
|
||||||
}
|
}
|
||||||
result = await self.collection.insert_one(user_doc)
|
result = await self.collection.insert_one(user_doc)
|
||||||
return await self.collection.find_one({"_id": result.inserted_id})
|
user = await self.collection.find_one({"_id": result.inserted_id})
|
||||||
|
if user:
|
||||||
|
user["id"] = str(user["_id"])
|
||||||
|
return user
|
||||||
|
|
||||||
async def get_pending_users(self):
|
async def get_pending_users(self):
|
||||||
"""Возвращает список пользователей со статусом PENDING"""
|
"""Возвращает список пользователей со статусом PENDING"""
|
||||||
cursor = self.collection.find({"status": UserStatus.PENDING})
|
cursor = self.collection.find({"status": UserStatus.PENDING})
|
||||||
return await cursor.to_list(length=100)
|
users = await cursor.to_list(length=100)
|
||||||
|
for user in users:
|
||||||
|
user["id"] = str(user["_id"])
|
||||||
|
return users
|
||||||
|
|
||||||
async def approve_user(self, username: str):
|
async def approve_user(self, username: str):
|
||||||
await self.collection.update_one(
|
await self.collection.update_one(
|
||||||
|
|||||||
@@ -50,3 +50,5 @@ passlib[argon2]==1.7.4
|
|||||||
python-jose[cryptography]==3.3.0
|
python-jose[cryptography]==3.3.0
|
||||||
python-multipart==0.0.22
|
python-multipart==0.0.22
|
||||||
email-validator
|
email-validator
|
||||||
|
prometheus-fastapi-instrumentator
|
||||||
|
PyJWT
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -63,7 +63,8 @@ async def new_char_bio(message: Message, state: FSMContext, dao: DAO, bot: Bot):
|
|||||||
character_image_data=file_io.read(),
|
character_image_data=file_io.read(),
|
||||||
character_image_tg_id=None,
|
character_image_tg_id=None,
|
||||||
character_image_doc_tg_id=file_id,
|
character_image_doc_tg_id=file_id,
|
||||||
character_bio=bio
|
character_bio=bio,
|
||||||
|
created_by=str(message.from_user.id)
|
||||||
)
|
)
|
||||||
file_io.close()
|
file_io.close()
|
||||||
|
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ async def generate_image_cmd(message: Message, state: FSMContext, dao: DAO, gemi
|
|||||||
await wait_msg.delete()
|
await wait_msg.delete()
|
||||||
doc = await message.answer_document(res[0], caption="Generated result 💫")
|
doc = await message.answer_document(res[0], caption="Generated result 💫")
|
||||||
await dao.assets.create_asset(Asset(id=None, name=res[0].filename, type=AssetType.GENERATED, content_type=AssetContentType.IMAGE, data=res[0].data,
|
await dao.assets.create_asset(Asset(id=None, name=res[0].filename, type=AssetType.GENERATED, content_type=AssetContentType.IMAGE, data=res[0].data,
|
||||||
tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None, linked_char_id=None))
|
tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None, linked_char_id=None, created_by=str(message.from_user.id)))
|
||||||
|
|
||||||
|
|
||||||
@router.message(Command("gen_mode"))
|
@router.message(Command("gen_mode"))
|
||||||
@@ -259,7 +259,8 @@ async def handle_album(
|
|||||||
doc = await message.answer_document(file, caption="✨ Generated result")
|
doc = await message.answer_document(file, caption="✨ Generated result")
|
||||||
await dao.assets.create_asset(Asset(id=None, name=file.filename, type=AssetType.GENERATED, content_type=AssetContentType.IMAGE, data=file.data,
|
await dao.assets.create_asset(Asset(id=None, name=file.filename, type=AssetType.GENERATED, content_type=AssetContentType.IMAGE, data=file.data,
|
||||||
tg_doc_file_id = doc.document.file_id, tg_photo_file_id = None,
|
tg_doc_file_id = doc.document.file_id, tg_photo_file_id = None,
|
||||||
linked_char_id = data["char_id"]))
|
linked_char_id = data["char_id"],
|
||||||
|
created_by=str(message.from_user.id)))
|
||||||
else:
|
else:
|
||||||
await message.answer("❌ Генерация не вернула изображений.")
|
await message.answer("❌ Генерация не вернула изображений.")
|
||||||
await gen_mode_base_msg(message=message, state=state, dao=dao, call_type="start")
|
await gen_mode_base_msg(message=message, state=state, dao=dao, call_type="start")
|
||||||
@@ -314,7 +315,8 @@ async def gen_mode_start(
|
|||||||
doc = await message.answer_document(file, caption="✨ Generated result")
|
doc = await message.answer_document(file, caption="✨ Generated result")
|
||||||
await dao.assets.create_asset(Asset(id=None, name=file.filename, type=AssetType.GENERATED, content_type=AssetContentType.IMAGE, data=file.data,
|
await dao.assets.create_asset(Asset(id=None, name=file.filename, type=AssetType.GENERATED, content_type=AssetContentType.IMAGE, data=file.data,
|
||||||
tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None,
|
tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None,
|
||||||
linked_char_id=data["char_id"]))
|
linked_char_id=data["char_id"],
|
||||||
|
created_by=str(message.from_user.id)))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
await message.answer("❌ Ничего не сгенерировалось.")
|
await message.answer("❌ Ничего не сгенерировалось.")
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
101
tests/test_character_crud.py
Normal file
101
tests/test_character_crud.py
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from main import app
|
||||||
|
from api.endpoints.auth import get_current_user
|
||||||
|
from api.dependency import get_dao
|
||||||
|
from repos.dao import DAO
|
||||||
|
from models.Character import Character
|
||||||
|
|
||||||
|
# Config for test DB
|
||||||
|
MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://admin:super_secure_password@31.59.58.220:27017")
|
||||||
|
DB_NAME = "bot_db_test_chars"
|
||||||
|
|
||||||
|
# Mock User
|
||||||
|
MOCK_USER_ID = "507f1f77bcf86cd799439011"
|
||||||
|
MOCK_USER = {
|
||||||
|
"_id": MOCK_USER_ID,
|
||||||
|
"username": "testuser",
|
||||||
|
"is_admin": False,
|
||||||
|
"status": "allowed"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Override get_current_user to bypass auth
|
||||||
|
def mock_get_current_user():
|
||||||
|
return MOCK_USER
|
||||||
|
|
||||||
|
app.dependency_overrides[get_current_user] = mock_get_current_user
|
||||||
|
|
||||||
|
# Setup Real DAO with Test DB
|
||||||
|
client_mongo = AsyncIOMotorClient(MONGO_HOST)
|
||||||
|
dao = DAO(client_mongo, db_name=DB_NAME)
|
||||||
|
|
||||||
|
def mock_get_dao():
|
||||||
|
return dao
|
||||||
|
|
||||||
|
app.dependency_overrides[get_dao] = mock_get_dao
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module", autouse=True)
|
||||||
|
def setup_teardown():
|
||||||
|
# Setup: Ensure clean state
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
loop.run_until_complete(client_mongo[DB_NAME]["characters"].drop())
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
# Teardown
|
||||||
|
loop.run_until_complete(client_mongo[DB_NAME]["characters"].drop())
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
def test_character_crud_flow():
|
||||||
|
# 1. Create Character
|
||||||
|
create_payload = {
|
||||||
|
"name": "Test Character",
|
||||||
|
"character_bio": "A bio for test character",
|
||||||
|
"character_image_doc_tg_id": "file_123",
|
||||||
|
"avatar_image": "http://example.com/avatar.jpg"
|
||||||
|
}
|
||||||
|
|
||||||
|
response = client.post("/api/characters/", json=create_payload)
|
||||||
|
assert response.status_code == 200, response.text
|
||||||
|
char_data = response.json()
|
||||||
|
assert char_data["name"] == create_payload["name"]
|
||||||
|
assert char_data["created_by"] == MOCK_USER_ID
|
||||||
|
char_id = char_data["id"]
|
||||||
|
assert char_id is not None
|
||||||
|
|
||||||
|
# 2. Get Character
|
||||||
|
response = client.get(f"/api/characters/{char_id}")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["id"] == char_id
|
||||||
|
|
||||||
|
# 3. Update Character
|
||||||
|
update_payload = {
|
||||||
|
"name": "Updated Name",
|
||||||
|
"character_bio": "Updated bio"
|
||||||
|
}
|
||||||
|
response = client.put(f"/api/characters/{char_id}", json=update_payload)
|
||||||
|
assert response.status_code == 200
|
||||||
|
updated_data = response.json()
|
||||||
|
assert updated_data["name"] == "Updated Name"
|
||||||
|
assert updated_data["character_bio"] == "Updated bio"
|
||||||
|
|
||||||
|
# Verify update persistent
|
||||||
|
response = client.get(f"/api/characters/{char_id}")
|
||||||
|
assert response.json()["name"] == "Updated Name"
|
||||||
|
|
||||||
|
# 4. Delete Character
|
||||||
|
response = client.delete(f"/api/characters/{char_id}")
|
||||||
|
assert response.status_code == 204
|
||||||
|
|
||||||
|
# Verify deletion
|
||||||
|
response = client.get(f"/api/characters/{char_id}")
|
||||||
|
assert response.status_code == 404, "Deleted character should return 404"
|
||||||
64
tests/test_character_integration.py
Normal file
64
tests/test_character_integration.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
# 1. Set Auth Bypass and Test Config
|
||||||
|
os.environ["DB_NAME"] = "bot_db_test_integration"
|
||||||
|
# We keep MONGO_HOST as is (it works in verified script)
|
||||||
|
|
||||||
|
# 2. Import app AFTER setting env
|
||||||
|
from main import app
|
||||||
|
from api.endpoints.auth import get_current_user
|
||||||
|
|
||||||
|
# 3. Override Auth
|
||||||
|
MOCK_USER_ID = "507f1f77bcf86cd799439011"
|
||||||
|
MOCK_USER = {
|
||||||
|
"_id": MOCK_USER_ID,
|
||||||
|
"username": "testuser",
|
||||||
|
"is_admin": False,
|
||||||
|
"status": "allowed",
|
||||||
|
"project_ids": []
|
||||||
|
}
|
||||||
|
|
||||||
|
def mock_get_current_user():
|
||||||
|
return MOCK_USER
|
||||||
|
|
||||||
|
app.dependency_overrides[get_current_user] = mock_get_current_user
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
def test_character_crud_lifecycle():
|
||||||
|
# 1. Create
|
||||||
|
create_payload = {
|
||||||
|
"name": "Integration Test Char",
|
||||||
|
"character_bio": "Testing with real app structure",
|
||||||
|
"character_image_doc_tg_id": "doc_123",
|
||||||
|
"avatar_image": "http://example.com/img.jpg"
|
||||||
|
}
|
||||||
|
|
||||||
|
response = client.post("/api/characters/", json=create_payload)
|
||||||
|
assert response.status_code == 200, response.text
|
||||||
|
char_data = response.json()
|
||||||
|
assert char_data["name"] == create_payload["name"]
|
||||||
|
char_id = char_data["id"]
|
||||||
|
|
||||||
|
# 2. Get
|
||||||
|
response = client.get(f"/api/characters/{char_id}")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["id"] == char_id
|
||||||
|
|
||||||
|
# 3. Update
|
||||||
|
update_payload = {"name": "Updated Int Name"}
|
||||||
|
response = client.put(f"/api/characters/{char_id}", json=update_payload)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["name"] == "Updated Int Name"
|
||||||
|
|
||||||
|
# 4. Delete
|
||||||
|
response = client.delete(f"/api/characters/{char_id}")
|
||||||
|
assert response.status_code == 204
|
||||||
|
|
||||||
|
# 5. Verify Delete
|
||||||
|
response = client.get(f"/api/characters/{char_id}")
|
||||||
|
assert response.status_code == 404
|
||||||
63
tests/test_external_import.py
Executable file
63
tests/test_external_import.py
Executable file
@@ -0,0 +1,63 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Test script for external generation import API.
|
||||||
|
This script demonstrates how to call the import endpoint with proper HMAC signature.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hmac
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import requests
|
||||||
|
import base64
|
||||||
|
import os
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
API_URL = "http://localhost:8090/api/generations/import"
|
||||||
|
SECRET = os.getenv("EXTERNAL_API_SECRET", "your_super_secret_key_change_this_in_production")
|
||||||
|
|
||||||
|
# Sample generation data
|
||||||
|
generation_data = {
|
||||||
|
"prompt": "A beautiful sunset over mountains",
|
||||||
|
"tech_prompt": "High quality landscape photography",
|
||||||
|
"image_url": "https://picsum.photos/512/512", # Sample image URL
|
||||||
|
# OR use base64:
|
||||||
|
# "image_data": "base64_encoded_image_string_here",
|
||||||
|
"aspect_ratio": "9:16",
|
||||||
|
"quality": "1k",
|
||||||
|
"created_by": "external_user_123",
|
||||||
|
"execution_time_seconds": 5.2,
|
||||||
|
"token_usage": 1000,
|
||||||
|
"input_token_usage": 200,
|
||||||
|
"output_token_usage": 800
|
||||||
|
}
|
||||||
|
|
||||||
|
# Convert to JSON
|
||||||
|
body = json.dumps(generation_data).encode('utf-8')
|
||||||
|
|
||||||
|
# Compute HMAC signature
|
||||||
|
signature = hmac.new(
|
||||||
|
SECRET.encode('utf-8'),
|
||||||
|
body,
|
||||||
|
hashlib.sha256
|
||||||
|
).hexdigest()
|
||||||
|
|
||||||
|
# Make request
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"X-Signature": signature
|
||||||
|
}
|
||||||
|
|
||||||
|
print(f"Sending request to {API_URL}")
|
||||||
|
print(f"Signature: {signature}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.post(API_URL, data=body, headers=headers)
|
||||||
|
print(f"\nStatus Code: {response.status_code}")
|
||||||
|
print(f"Response: {json.dumps(response.json(), indent=2)}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: {e}")
|
||||||
|
if hasattr(e, 'response'):
|
||||||
|
print(f"Response text: {e.response.text}")
|
||||||
@@ -12,7 +12,7 @@ from models.Generation import Generation, GenerationStatus
|
|||||||
from models.enums import AspectRatios, Quality
|
from models.enums import AspectRatios, Quality
|
||||||
|
|
||||||
# Mock config
|
# Mock config
|
||||||
# Use the same host as main.py but different DB
|
# Use the same host as aiws.py but different DB
|
||||||
MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://admin:super_secure_password@31.59.58.220:27017")
|
MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://admin:super_secure_password@31.59.58.220:27017")
|
||||||
DB_NAME = "bot_db_test_albums"
|
DB_NAME = "bot_db_test_albums"
|
||||||
|
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
46
utils/external_auth.py
Normal file
46
utils/external_auth.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
import hmac
|
||||||
|
import hashlib
|
||||||
|
import os
|
||||||
|
from fastapi import Header, HTTPException
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
def verify_signature(body: bytes, signature: str, secret: str) -> bool:
|
||||||
|
"""
|
||||||
|
Verify HMAC-SHA256 signature.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
body: Raw request body bytes
|
||||||
|
signature: Signature from X-Signature header
|
||||||
|
secret: Shared secret key
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if signature is valid, False otherwise
|
||||||
|
"""
|
||||||
|
expected_signature = hmac.new(
|
||||||
|
secret.encode('utf-8'),
|
||||||
|
body,
|
||||||
|
hashlib.sha256
|
||||||
|
).hexdigest()
|
||||||
|
|
||||||
|
return hmac.compare_digest(signature, expected_signature)
|
||||||
|
|
||||||
|
|
||||||
|
async def verify_external_signature(
|
||||||
|
x_signature: Optional[str] = Header(None, alias="X-Signature")
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
FastAPI dependency to verify external API signature.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If signature is missing or invalid
|
||||||
|
"""
|
||||||
|
if not x_signature:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail="Missing X-Signature header"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Note: We'll need to access the raw request body in the endpoint
|
||||||
|
# This dependency just validates the header exists
|
||||||
|
# Actual signature verification happens in the endpoint
|
||||||
|
return x_signature
|
||||||
Reference in New Issue
Block a user