3 Commits

Author SHA1 Message Date
xds
32ff77e04b feat: Implement video generation functionality and integrate with Kling API. 2026-02-12 10:27:07 +03:00
xds
d1f67c773f 123 2026-02-12 00:25:08 +03:00
xds
c63b51ef75 123
Please enter the commit message for your changes. Lines starting with '#' will be ignored, and an empty message aborts the commit.
2026-02-12 00:24:43 +03:00
94 changed files with 664 additions and 1635 deletions

4
.env
View File

@@ -8,4 +8,6 @@ MINIO_ACCESS_KEY=admin
MINIO_SECRET_KEY=SuperSecretPassword123! MINIO_SECRET_KEY=SuperSecretPassword123!
MINIO_BUCKET=ai-char MINIO_BUCKET=ai-char
MODE=production MODE=production
EXTERNAL_API_SECRET=Gt9TyQ8OAYhcELh2YCbKjdHLflZGufKHJZcG338MQDW EXTERNAL_API_SECRET=Gt9TyQ8OAYhcELh2YCbKjdHLflZGufKHJZcG338MQDW
KLING_ACCESS_KEY=AngRfYYeLhPQB3pmr9CpHfgHPCrmeeM4
KLING_SECRET_KEY=ndJfyayfQgbg4bMnE49yHnkACPChKMp4

17
.gitignore vendored
View File

@@ -8,19 +8,4 @@ minio_backup.tar.gz
.idea/ai-char-bot.iml .idea/ai-char-bot.iml
.idea .idea
.venv .venv
.vscode .vscode
.vscode/launch.json
middlewares/__pycache__/
middlewares/*.pyc
api/__pycache__/
api/*.pyc
repos/__pycache__/
repos/*.pyc
adapters/__pycache__/
adapters/*.pyc
services/__pycache__/
services/*.pyc
utils/__pycache__/
utils/*.pyc
.vscode/launch.json
repos/__pycache__/assets_repo.cpython-313.pyc

25
.vscode/launch.json vendored
View File

@@ -16,6 +16,31 @@
], ],
"jinja": true, "jinja": true,
"justMyCode": true "justMyCode": true
},
{
"name": "Python: Current File",
"type": "debugpy",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"justMyCode": true,
"env": {
"PYTHONPATH": "${workspaceFolder}"
}
},
{
"name": "Debug Tests: Current File",
"type": "debugpy",
"request": "launch",
"module": "pytest",
"args": [
"${file}"
],
"console": "integratedTerminal",
"justMyCode": true,
"env": {
"PYTHONPATH": "${workspaceFolder}"
}
} }
] ]
} }

Binary file not shown.

Binary file not shown.

View File

@@ -23,30 +23,28 @@ class GoogleAdapter:
self.TEXT_MODEL = "gemini-3-pro-preview" self.TEXT_MODEL = "gemini-3-pro-preview"
self.IMAGE_MODEL = "gemini-3-pro-image-preview" self.IMAGE_MODEL = "gemini-3-pro-image-preview"
def _prepare_contents(self, prompt: str, images_list: List[bytes] | None = None) -> tuple: def _prepare_contents(self, prompt: str, images_list: List[bytes] = None) -> list:
"""Вспомогательный метод для подготовки контента (текст + картинки). """Вспомогательный метод для подготовки контента (текст + картинки)"""
Returns (contents, opened_images) — caller MUST close opened_images after use.""" contents = [prompt]
contents : list [Any]= [prompt]
opened_images = []
if images_list: if images_list:
logger.info(f"Preparing content with {len(images_list)} images") logger.info(f"Preparing content with {len(images_list)} images")
for img_bytes in images_list: for img_bytes in images_list:
try: try:
# Gemini API требует PIL Image на входе
image = Image.open(io.BytesIO(img_bytes)) image = Image.open(io.BytesIO(img_bytes))
contents.append(image) contents.append(image)
opened_images.append(image)
except Exception as e: except Exception as e:
logger.error(f"Error processing input image: {e}") logger.error(f"Error processing input image: {e}")
else: else:
logger.info("Preparing content with no images") logger.info("Preparing content with no images")
return contents, opened_images return contents
def generate_text(self, prompt: str, images_list: List[bytes] | None = None) -> str: def generate_text(self, prompt: str, images_list: List[bytes] = None) -> str:
""" """
Генерация текста (Чат или Vision). Генерация текста (Чат или Vision).
Возвращает строку с ответом. Возвращает строку с ответом.
""" """
contents, opened_images = self._prepare_contents(prompt, images_list) contents = self._prepare_contents(prompt, images_list)
logger.info(f"Generating text: {prompt}") logger.info(f"Generating text: {prompt}")
try: try:
response = self.client.models.generate_content( response = self.client.models.generate_content(
@@ -70,17 +68,14 @@ class GoogleAdapter:
except Exception as e: except Exception as e:
logger.error(f"Gemini Text API Error: {e}") logger.error(f"Gemini Text API Error: {e}")
raise GoogleGenerationException(f"Gemini Text API Error: {e}") raise GoogleGenerationException(f"Gemini Text API Error: {e}")
finally:
for img in opened_images:
img.close()
def generate_image(self, prompt: str, aspect_ratio: AspectRatios, quality: Quality, images_list: List[bytes] | None = None, ) -> Tuple[List[io.BytesIO], Dict[str, Any]]: def generate_image(self, prompt: str, aspect_ratio: AspectRatios, quality: Quality, images_list: List[bytes] = None, ) -> Tuple[List[io.BytesIO], Dict[str, Any]]:
""" """
Генерация изображений (Text-to-Image или Image-to-Image). Генерация изображений (Text-to-Image или Image-to-Image).
Возвращает список байтовых потоков (готовых к отправке). Возвращает список байтовых потоков (готовых к отправке).
""" """
contents, opened_images = self._prepare_contents(prompt, images_list) contents = self._prepare_contents(prompt, images_list)
logger.info(f"Generating image. Prompt length: {len(prompt)}, Ratio: {aspect_ratio}, Quality: {quality}") logger.info(f"Generating image. Prompt length: {len(prompt)}, Ratio: {aspect_ratio}, Quality: {quality}")
start_time = datetime.now() start_time = datetime.now()
@@ -105,21 +100,9 @@ class GoogleAdapter:
if response.usage_metadata: if response.usage_metadata:
token_usage = response.usage_metadata.total_token_count token_usage = response.usage_metadata.total_token_count
# Check prompt-level block (e.g. PROHIBITED_CONTENT) — no candidates in this case if response.parts is None and response.candidates[0].finish_reason is not None:
if response.prompt_feedback and response.prompt_feedback.block_reason: raise GoogleGenerationException(f"Generation blocked in cause of {response.candidates[0].finish_reason.value}")
raise GoogleGenerationException(
f"Generation blocked at prompt level: {response.prompt_feedback.block_reason.value}"
)
# Check candidate-level block
if response.parts is None:
response_reason = (
response.candidates[0].finish_reason
if response.candidates and len(response.candidates) > 0
else "Unknown"
)
raise GoogleGenerationException(f"Generation blocked: {response_reason}")
generated_images = [] generated_images = []
@@ -130,9 +113,7 @@ class GoogleAdapter:
try: try:
# 1. Берем сырые байты # 1. Берем сырые байты
raw_data = part.inline_data.data raw_data = part.inline_data.data
if raw_data is None: byte_arr = io.BytesIO(raw_data)
raise GoogleGenerationException("Generation returned no data")
byte_arr : io.BytesIO = io.BytesIO(raw_data)
# 2. Нейминг (формально, для TG) # 2. Нейминг (формально, для TG)
timestamp = datetime.now().timestamp() timestamp = datetime.now().timestamp()
@@ -166,8 +147,4 @@ class GoogleAdapter:
except Exception as e: except Exception as e:
logger.error(f"Gemini Image API Error: {e}") logger.error(f"Gemini Image API Error: {e}")
raise GoogleGenerationException(f"Gemini Image API Error: {e}") raise GoogleGenerationException(f"Gemini Image API Error: {e}")
finally:
for img in opened_images:
img.close()
del contents

165
adapters/kling_adapter.py Normal file
View File

@@ -0,0 +1,165 @@
import logging
import time
import asyncio
from typing import Optional, Dict, Any
import httpx
import jwt
logger = logging.getLogger(__name__)
KLING_API_BASE = "https://api.klingai.com"
class KlingApiException(Exception):
    """Raised when the Kling API reports an error, or a task fails or times out."""
class KlingAdapter:
    """Async client for the Kling image-to-video REST API.

    Authentication uses short-lived HS256 JWT tokens built from the
    access/secret key pair; a fresh token is generated per request.
    All API-level and transport-level failures surface as KlingApiException.
    """

    def __init__(self, access_key: str, secret_key: str):
        """Store credentials.

        Raises:
            ValueError: if either key is empty/missing.
        """
        if not access_key or not secret_key:
            raise ValueError("Kling API credentials are missing")
        self.access_key = access_key
        self.secret_key = secret_key

    def _generate_token(self) -> str:
        """Generate a JWT token for Kling API authentication."""
        now = int(time.time())
        payload = {
            "iss": self.access_key,
            "exp": now + 1800,  # 30 minutes
            "iat": now - 5,  # small backward skew to tolerate clock drift
            "nbf": now - 5,
        }
        return jwt.encode(payload, self.secret_key, algorithm="HS256",
                          headers={"typ": "JWT", "alg": "HS256"})

    def _headers(self) -> dict:
        """Build request headers with a freshly signed bearer token."""
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self._generate_token()}"
        }

    @staticmethod
    def _decode_json(response, context: str) -> Dict[str, Any]:
        """Decode a response body as JSON.

        Raises KlingApiException (instead of a raw JSONDecodeError) when the
        body is not JSON — e.g. an HTML error page returned by a gateway.
        """
        try:
            return response.json()
        except ValueError as e:
            raise KlingApiException(
                f"{context}: non-JSON response (HTTP {response.status_code})"
            ) from e

    async def create_video_task(
        self,
        image_url: str,
        prompt: str = "",
        negative_prompt: str = "",
        model_name: str = "kling-v2-6",
        duration: int = 5,
        mode: str = "std",
        cfg_scale: float = 0.5,
        aspect_ratio: str = "16:9",
        callback_url: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Create an image-to-video generation task.

        Returns the full task data dict including task_id.

        Raises:
            KlingApiException: on HTTP/API errors or a missing task_id.
        """
        body: Dict[str, Any] = {
            "model_name": model_name,
            "image": image_url,
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            # API expects the duration as a string
            "duration": str(duration),
            "mode": mode,
            "cfg_scale": cfg_scale,
            "aspect_ratio": aspect_ratio,
        }
        if callback_url:
            body["callback_url"] = callback_url

        logger.info(f"Creating Kling video task. Model: {model_name}, Duration: {duration}s, Mode: {mode}")

        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                f"{KLING_API_BASE}/v1/videos/image2video",
                headers=self._headers(),
                json=body,
            )

        data = self._decode_json(response, "Failed to create video task")
        logger.info(f"Kling create task response: code={data.get('code')}, message={data.get('message')}")

        if response.status_code != 200 or data.get("code") != 0:
            error_msg = data.get("message", "Unknown Kling API error")
            raise KlingApiException(f"Failed to create video task: {error_msg} (code={data.get('code')})")

        task_data = data.get("data", {})
        task_id = task_data.get("task_id")
        if not task_id:
            raise KlingApiException("No task_id returned from Kling API")

        logger.info(f"Kling video task created: task_id={task_id}")
        return task_data

    async def get_task_status(self, task_id: str) -> Dict[str, Any]:
        """
        Query the status of a video generation task.

        Returns the full task data dict.

        Raises:
            KlingApiException: on HTTP/API errors.
        """
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.get(
                f"{KLING_API_BASE}/v1/videos/image2video/{task_id}",
                headers=self._headers(),
            )

        data = self._decode_json(response, f"Failed to query task {task_id}")
        if response.status_code != 200 or data.get("code") != 0:
            error_msg = data.get("message", "Unknown error")
            raise KlingApiException(f"Failed to query task {task_id}: {error_msg}")
        return data.get("data", {})

    async def wait_for_completion(
        self,
        task_id: str,
        poll_interval: int = 10,
        timeout: int = 600,
        progress_callback=None,
    ) -> Dict[str, Any]:
        """
        Poll the task status until completion.

        Args:
            task_id: Kling task ID
            poll_interval: seconds between polls
            timeout: max seconds to wait
            progress_callback: async callable(progress_pct: int) to report progress

        Returns:
            Final task data dict with video URL on success.

        Raises:
            KlingApiException on failure or timeout.
        """
        start = time.time()
        while True:
            elapsed = time.time() - start
            if elapsed > timeout:
                raise KlingApiException(f"Video generation timed out after {timeout}s for task {task_id}")

            task_data = await self.get_task_status(task_id)
            status = task_data.get("task_status")
            logger.info(f"Kling task {task_id}: status={status}, elapsed={elapsed:.0f}s")

            if status == "succeed":
                logger.info(f"Kling task {task_id} completed successfully")
                return task_data
            if status == "failed":
                fail_reason = task_data.get("task_status_msg", "Unknown failure")
                raise KlingApiException(f"Video generation failed: {fail_reason}")

            if progress_callback:
                # Linear estimate against a typical ~120s generation, capped at 90%
                estimated_progress = min(int((elapsed / 120) * 90), 90)
                try:
                    await progress_callback(estimated_progress)
                except Exception as e:
                    # Progress reporting is best-effort; never abort the poll loop
                    logger.warning(f"Progress callback failed: {e}")

            await asyncio.sleep(poll_interval)

View File

@@ -18,7 +18,7 @@ class S3Adapter:
@asynccontextmanager @asynccontextmanager
async def _get_client(self): async def _get_client(self):
async with self.session.client( # type: ignore[reportGeneralTypeIssues] async with self.session.client(
"s3", "s3",
endpoint_url=self.endpoint_url, endpoint_url=self.endpoint_url,
aws_access_key_id=self.aws_access_key_id, aws_access_key_id=self.aws_access_key_id,
@@ -56,23 +56,6 @@ class S3Adapter:
print(f"Error downloading from S3: {e}") print(f"Error downloading from S3: {e}")
return None return None
async def stream_file(self, object_name: str, chunk_size: int = 65536):
    """Yield the contents of an S3 object chunk by chunk.

    Keeps memory bounded for large files by never buffering the whole object.
    """
    try:
        async with self._get_client() as client:
            obj = await client.get_object(Bucket=self.bucket_name, Key=object_name)
            # 'Body' is a streaming reader (aiohttp StreamReader wrapper);
            # read until it is exhausted.
            reader = obj['Body']
            while chunk := await reader.read(chunk_size):
                yield chunk
    except ClientError as e:
        print(f"Error streaming from S3: {e}")
        return
async def delete_file(self, object_name: str): async def delete_file(self, object_name: str):
"""Deletes a file from S3.""" """Deletes a file from S3."""
try: try:

86
aiws.py
View File

@@ -1,5 +1,6 @@
import asyncio import asyncio
import logging import logging
import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from aiogram import Bot, Dispatcher, Router, F from aiogram import Bot, Dispatcher, Router, F
@@ -8,6 +9,7 @@ from aiogram.enums import ParseMode
from aiogram.filters import CommandStart, Command from aiogram.filters import CommandStart, Command
from aiogram.types import Message from aiogram.types import Message
from aiogram.fsm.storage.mongo import MongoStorage from aiogram.fsm.storage.mongo import MongoStorage
from dotenv import load_dotenv
from fastapi import FastAPI from fastapi import FastAPI
from motor.motor_asyncio import AsyncIOMotorClient from motor.motor_asyncio import AsyncIOMotorClient
from prometheus_client import Info from prometheus_client import Info
@@ -15,8 +17,8 @@ from starlette.middleware.cors import CORSMiddleware
from prometheus_fastapi_instrumentator import Instrumentator from prometheus_fastapi_instrumentator import Instrumentator
# --- ИМПОРТЫ ПРОЕКТА --- # --- ИМПОРТЫ ПРОЕКТА ---
from config import settings
from adapters.google_adapter import GoogleAdapter from adapters.google_adapter import GoogleAdapter
from adapters.kling_adapter import KlingAdapter
from adapters.s3_adapter import S3Adapter from adapters.s3_adapter import S3Adapter
from api.service.generation_service import GenerationService from api.service.generation_service import GenerationService
from api.service.album_service import AlbumService from api.service.album_service import AlbumService
@@ -42,20 +44,17 @@ from api.endpoints.auth import router as api_auth_router
from api.endpoints.admin import router as api_admin_router from api.endpoints.admin import router as api_admin_router
from api.endpoints.album_router import router as api_album_router from api.endpoints.album_router import router as api_album_router
from api.endpoints.project_router import router as project_api_router from api.endpoints.project_router import router as project_api_router
from api.endpoints.idea_router import router as idea_api_router
from api.endpoints.post_router import router as post_api_router
load_dotenv()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# --- КОНФИГУРАЦИЯ --- # --- КОНФИГУРАЦИЯ ---
# Настройки теперь берутся из config.py BOT_TOKEN = os.getenv("BOT_TOKEN")
BOT_TOKEN = settings.BOT_TOKEN GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
GEMINI_API_KEY = settings.GEMINI_API_KEY
MONGO_HOST = settings.MONGO_HOST MONGO_HOST = os.getenv("MONGO_HOST") # Например: mongodb://localhost:27017
DB_NAME = settings.DB_NAME DB_NAME = os.getenv("DB_NAME", "my_bot_db") # Имя базы данных
ADMIN_ID = settings.ADMIN_ID ADMIN_ID = int(os.getenv("ADMIN_ID", 0))
def setup_logging(): def setup_logging():
@@ -65,8 +64,6 @@ def setup_logging():
# --- ИНИЦИАЛИЗАЦИЯ ЗАВИСИМОСТЕЙ --- # --- ИНИЦИАЛИЗАЦИЯ ЗАВИСИМОСТЕЙ ---
if BOT_TOKEN is None:
raise ValueError("BOT_TOKEN is not set")
bot = Bot(token=BOT_TOKEN, default=DefaultBotProperties(parse_mode=ParseMode.HTML)) bot = Bot(token=BOT_TOKEN, default=DefaultBotProperties(parse_mode=ParseMode.HTML))
# Клиент БД создаем глобально, чтобы он был доступен и боту (Storage), и API # Клиент БД создаем глобально, чтобы он был доступен и боту (Storage), и API
@@ -79,19 +76,26 @@ char_repo = CharacterRepo(mongo_client)
# S3 Adapter # S3 Adapter
s3_adapter = S3Adapter( s3_adapter = S3Adapter(
endpoint_url=settings.MINIO_ENDPOINT, endpoint_url=os.getenv("MINIO_ENDPOINT", "http://31.59.58.220:9000"),
aws_access_key_id=settings.MINIO_ACCESS_KEY, aws_access_key_id=os.getenv("MINIO_ACCESS_KEY", "minioadmin"),
aws_secret_access_key=settings.MINIO_SECRET_KEY, aws_secret_access_key=os.getenv("MINIO_SECRET_KEY", "minioadmin"),
bucket_name=settings.MINIO_BUCKET bucket_name=os.getenv("MINIO_BUCKET", "ai-char")
) )
dao = DAO(mongo_client, s3_adapter) # Главный DAO для бота dao = DAO(mongo_client, s3_adapter) # Главный DAO для бота
if GEMINI_API_KEY is None:
raise ValueError("GEMINI_API_KEY is not set")
gemini = GoogleAdapter(api_key=GEMINI_API_KEY) gemini = GoogleAdapter(api_key=GEMINI_API_KEY)
if bot is None:
raise ValueError("bot is not set") # Kling Adapter (optional, for video generation)
generation_service = GenerationService(dao=dao, gemini=gemini, s3_adapter=s3_adapter, bot=bot) kling_access_key = os.getenv("KLING_ACCESS_KEY", "")
kling_secret_key = os.getenv("KLING_SECRET_KEY", "")
kling_adapter = None
if kling_access_key and kling_secret_key:
kling_adapter = KlingAdapter(access_key=kling_access_key, secret_key=kling_secret_key)
logger.info("Kling adapter initialized")
else:
logger.warning("KLING_ACCESS_KEY / KLING_SECRET_KEY not set — video generation disabled")
generation_service = GenerationService(dao, gemini, s3_adapter, bot, kling_adapter)
album_service = AlbumService(dao) album_service = AlbumService(dao)
# Dispatcher # Dispatcher
@@ -128,18 +132,6 @@ assets_router.message.middleware(AuthMiddleware(repo=users_repo, admin_id=ADMIN_
gen_router.message.middleware(AlbumMiddleware(latency=0.8)) gen_router.message.middleware(AlbumMiddleware(latency=0.8))
async def start_scheduler(service: GenerationService):
    """Background loop: periodically purge stale and old generation data.

    Runs until the surrounding task is cancelled; any other error is logged
    and the loop keeps going.
    """
    running = True
    while running:
        try:
            logger.info("Running scheduler for stacked generation killing")
            await service.cleanup_stale_generations()
            await service.cleanup_old_data(days=2)
        except asyncio.CancelledError:
            running = False
            continue
        except Exception as e:
            logger.error(f"Scheduler error: {e}")
        await asyncio.sleep(60)  # Check every 60 seconds
# --- LIFESPAN (Запуск FastAPI + Bot) --- # --- LIFESPAN (Запуск FastAPI + Bot) ---
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
@@ -158,6 +150,7 @@ async def lifespan(app: FastAPI):
app.state.gemini_client = gemini app.state.gemini_client = gemini
app.state.bot = bot app.state.bot = bot
app.state.s3_adapter = s3_adapter app.state.s3_adapter = s3_adapter
app.state.kling_adapter = kling_adapter
app.state.album_service = album_service app.state.album_service = album_service
app.state.users_repo = users_repo # Добавляем репозиторий в state app.state.users_repo = users_repo # Добавляем репозиторий в state
@@ -171,28 +164,17 @@ async def lifespan(app: FastAPI):
# ) # )
# print("🤖 Bot polling started") # print("🤖 Bot polling started")
# 3. ЗАПУСК ШЕДУЛЕРА
scheduler_task = asyncio.create_task(start_scheduler(generation_service))
print("⏰ Scheduler started")
yield yield
# --- SHUTDOWN --- # --- SHUTDOWN ---
print("🛑 Shutting down...") print("🛑 Shutting down...")
# 4. Остановка шедулера
scheduler_task.cancel()
try:
await scheduler_task
except asyncio.CancelledError:
print("⏰ Scheduler stopped")
# 3. Остановка бота # 3. Остановка бота
# polling_task.cancel() polling_task.cancel()
# try: try:
# await polling_task await polling_task
# except asyncio.CancelledError: except asyncio.CancelledError:
# print("🤖 Bot polling stopped") print("🤖 Bot polling stopped")
# 4. Отключение БД # 4. Отключение БД
# Обычно Motor закрывать не обязательно при выходе, но хорошим тоном считается # Обычно Motor закрывать не обязательно при выходе, но хорошим тоном считается
@@ -219,8 +201,6 @@ app.include_router(api_char_router)
app.include_router(api_gen_router) app.include_router(api_gen_router)
app.include_router(api_album_router) app.include_router(api_album_router)
app.include_router(project_api_router) app.include_router(project_api_router)
app.include_router(idea_api_router)
app.include_router(post_api_router)
# Prometheus Metrics (Instrument after all routers are added) # Prometheus Metrics (Instrument after all routers are added)
Instrumentator( Instrumentator(
@@ -259,7 +239,7 @@ if __name__ == "__main__":
async def main(): async def main():
# Создаем конфигурацию uvicorn вручную # Создаем конфигурацию uvicorn вручную
# loop="asyncio" заставляет использовать стандартный цикл # loop="asyncio" заставляет использовать стандартный цикл
config = uvicorn.Config(app, host="0.0.0.0", port=8090, loop="asyncio", timeout_keep_alive=120) config = uvicorn.Config(app, host="0.0.0.0", port=8090, loop="asyncio", timeout_keep_alive=120, env_file=".env.development")
server = uvicorn.Server(config) server = uvicorn.Server(config)
# Запускаем сервер (lifespan запустится внутри) # Запускаем сервер (lifespan запустится внутри)

View File

@@ -3,9 +3,9 @@ from fastapi import Request, Depends
from motor.motor_asyncio import AsyncIOMotorClient from motor.motor_asyncio import AsyncIOMotorClient
from adapters.google_adapter import GoogleAdapter from adapters.google_adapter import GoogleAdapter
from adapters.kling_adapter import KlingAdapter
from api.service.generation_service import GenerationService from api.service.generation_service import GenerationService
from repos.dao import DAO from repos.dao import DAO
from api.service.album_service import AlbumService
# ... ваши импорты ... # ... ваши импорты ...
@@ -37,29 +37,20 @@ def get_dao(
# так что DAO создастся один раз за запрос. # так что DAO создастся один раз за запрос.
return DAO(mongo_client, s3_adapter) return DAO(mongo_client, s3_adapter)
def get_kling_adapter(request: Request) -> Optional[KlingAdapter]:
return request.app.state.kling_adapter
# Провайдер сервиса (собирается из DAO и Gemini) # Провайдер сервиса (собирается из DAO и Gemini)
def get_generation_service( def get_generation_service(
dao: DAO = Depends(get_dao), dao: DAO = Depends(get_dao),
gemini: GoogleAdapter = Depends(get_gemini_client), gemini: GoogleAdapter = Depends(get_gemini_client),
s3_adapter: S3Adapter = Depends(get_s3_adapter), s3_adapter: S3Adapter = Depends(get_s3_adapter),
bot: Bot = Depends(get_bot_client), bot: Bot = Depends(get_bot_client),
kling_adapter: Optional[KlingAdapter] = Depends(get_kling_adapter),
) -> GenerationService: ) -> GenerationService:
return GenerationService(dao, gemini, s3_adapter, bot) return GenerationService(dao, gemini, s3_adapter, bot, kling_adapter=kling_adapter)
from api.service.idea_service import IdeaService
def get_idea_service(dao: DAO = Depends(get_dao)) -> IdeaService:
return IdeaService(dao)
from fastapi import Header from fastapi import Header
async def get_project_id(x_project_id: Optional[str] = Header(None, alias="X-Project-ID")) -> Optional[str]: async def get_project_id(x_project_id: Optional[str] = Header(None, alias="X-Project-ID")) -> Optional[str]:
return x_project_id return x_project_id
async def get_album_service(dao: DAO = Depends(get_dao)) -> AlbumService:
return AlbumService(dao)
from api.service.post_service import PostService
def get_post_service(dao: DAO = Depends(get_dao)) -> PostService:
return PostService(dao)

View File

@@ -5,8 +5,6 @@ from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel from pydantic import BaseModel
from repos.user_repo import UsersRepo, UserStatus from repos.user_repo import UsersRepo, UserStatus
from api.dependency import get_dao
from repos.dao import DAO
from utils.security import verify_password, create_access_token, ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, SECRET_KEY from utils.security import verify_password, create_access_token, ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, SECRET_KEY
from jose import JWTError, jwt from jose import JWTError, jwt
from starlette.requests import Request from starlette.requests import Request
@@ -25,7 +23,7 @@ async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)], repo:
) )
try: try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str | None = payload.get("sub") username: str = payload.get("sub")
if username is None: if username is None:
raise credentials_exception raise credentials_exception
except JWTError: except JWTError:

View File

@@ -1,13 +1,10 @@
from typing import List, Optional from typing import List, Optional
from fastapi import APIRouter, HTTPException, status, Request from fastapi import APIRouter, Depends, HTTPException, status, Request
from pydantic import BaseModel from pydantic import BaseModel
from api.models.GenerationRequest import GenerationResponse from api.models.GenerationRequest import GenerationResponse
from models.Album import Album from models.Album import Album
from repos.dao import DAO from repos.dao import DAO
from api.dependency import get_album_service
from api.service.album_service import AlbumService
router = APIRouter(prefix="/api/albums", tags=["Albums"]) router = APIRouter(prefix="/api/albums", tags=["Albums"])

View File

@@ -9,10 +9,10 @@ from pymongo import MongoClient
from starlette import status from starlette import status
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import Response, JSONResponse, StreamingResponse from starlette.responses import Response, JSONResponse
from adapters.s3_adapter import S3Adapter from adapters.s3_adapter import S3Adapter
from api.models import AssetsResponse, AssetResponse from api.models.AssetDTO import AssetsResponse, AssetResponse
from models.Asset import Asset, AssetType, AssetContentType from models.Asset import Asset, AssetType, AssetContentType
from repos.dao import DAO from repos.dao import DAO
from api.dependency import get_dao, get_mongo_client, get_s3_adapter from api.dependency import get_dao, get_mongo_client, get_s3_adapter
@@ -33,46 +33,27 @@ async def get_asset(
asset_id: str, asset_id: str,
request: Request, request: Request,
thumbnail: bool = False, thumbnail: bool = False,
dao: DAO = Depends(get_dao), dao: DAO = Depends(get_dao)
s3_adapter: S3Adapter = Depends(get_s3_adapter),
) -> Response: ) -> Response:
logger.debug(f"get_asset called for ID: {asset_id}, thumbnail={thumbnail}") logger.debug(f"get_asset called for ID: {asset_id}, thumbnail={thumbnail}")
# Загружаем только метаданные (без data/thumbnail bytes) asset = await dao.assets.get_asset(asset_id)
asset = await dao.assets.get_asset(asset_id, with_data=False) # 2. Проверка на существование
if not asset: if not asset:
raise HTTPException(status_code=404, detail="Asset not found") raise HTTPException(status_code=404, detail="Asset not found")
headers = { headers = {
# Кэшировать на 1 год (31536000 сек)
"Cache-Control": "public, max-age=31536000, immutable" "Cache-Control": "public, max-age=31536000, immutable"
} }
# Thumbnail: маленький, можно грузить в RAM content = asset.data
if thumbnail: media_type = "image/png" # Default, or detect
if asset.minio_thumbnail_object_name and s3_adapter:
thumb_bytes = await s3_adapter.get_file(asset.minio_thumbnail_object_name)
if thumb_bytes:
return Response(content=thumb_bytes, media_type="image/jpeg", headers=headers)
# Fallback: thumbnail in DB
if asset.thumbnail:
return Response(content=asset.thumbnail, media_type="image/jpeg", headers=headers)
# No thumbnail available — fall through to main content
# Main content: стримим из S3 без загрузки в RAM if thumbnail and asset.thumbnail:
if asset.minio_object_name and s3_adapter: content = asset.thumbnail
content_type = "image/png" media_type = "image/jpeg"
# if asset.content_type == AssetContentType.VIDEO:
# content_type = "video/mp4" return Response(content=content, media_type=media_type, headers=headers)
return StreamingResponse(
s3_adapter.stream_file(asset.minio_object_name),
media_type=content_type,
headers=headers,
)
# Fallback: data stored in DB (legacy)
if asset.data:
return Response(content=asset.data, media_type="image/png", headers=headers)
raise HTTPException(status_code=404, detail="Asset data not found")
@router.delete("/orphans", dependencies=[Depends(get_current_user)]) @router.delete("/orphans", dependencies=[Depends(get_current_user)])
async def delete_orphan_assets_from_minio( async def delete_orphan_assets_from_minio(
@@ -278,7 +259,8 @@ async def upload_asset(
type=asset.type.value if hasattr(asset.type, "value") else asset.type, type=asset.type.value if hasattr(asset.type, "value") else asset.type,
content_type=asset.content_type.value if hasattr(asset.content_type, "value") else asset.content_type, content_type=asset.content_type.value if hasattr(asset.content_type, "value") else asset.content_type,
linked_char_id=asset.linked_char_id, linked_char_id=asset.linked_char_id,
created_at=asset.created_at created_at=asset.created_at,
url=asset.url
) )

View File

@@ -5,11 +5,11 @@ from pydantic import BaseModel
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import Request
from api.models import AssetsResponse, AssetResponse from api.models.AssetDTO import AssetsResponse, AssetResponse
from api.models import GenerationRequest, GenerationResponse from api.models.GenerationRequest import GenerationRequest, GenerationResponse
from models.Asset import Asset from models.Asset import Asset
from models.Character import Character from models.Character import Character
from api.models import CharacterCreateRequest, CharacterUpdateRequest from api.models.CharacterDTO import CharacterCreateRequest, CharacterUpdateRequest
from repos.dao import DAO from repos.dao import DAO
from api.dependency import get_dao from api.dependency import get_dao
@@ -24,15 +24,8 @@ router = APIRouter(prefix="/api/characters", tags=["Characters"], dependencies=[
@router.get("/", response_model=List[Character]) @router.get("/", response_model=List[Character])
async def get_characters( async def get_characters(request: Request, dao: DAO = Depends(get_dao), current_user: dict = Depends(get_current_user), project_id: Optional[str] = Depends(get_project_id)) -> List[Character]:
request: Request, logger.info("get_characters called")
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
limit: int = 100,
offset: int = 0
) -> List[Character]:
logger.info(f"get_characters called. Limit: {limit}, Offset: {offset}")
user_id_filter = str(current_user["_id"]) user_id_filter = str(current_user["_id"])
if project_id: if project_id:
@@ -41,12 +34,7 @@ async def get_characters(
raise HTTPException(status_code=403, detail="Project access denied") raise HTTPException(status_code=403, detail="Project access denied")
user_id_filter = None user_id_filter = None
characters = await dao.chars.get_all_characters( characters = await dao.chars.get_all_characters(created_by=user_id_filter, project_id=project_id)
created_by=user_id_filter,
project_id=project_id,
limit=limit,
offset=offset
)
return characters return characters
@@ -190,3 +178,10 @@ async def delete_character(
raise HTTPException(status_code=500, detail="Failed to delete character") raise HTTPException(status_code=500, detail="Failed to delete character")
return return
@router.post("/{character_id}/_run", response_model=GenerationResponse)
async def post_character_generation(character_id: str, generation: GenerationRequest,
request: Request) -> GenerationResponse:
logger.info(f"post_character_generation called. CharacterID: {character_id}")
generation_service = request.app.state.generation_service

View File

@@ -1,32 +1,26 @@
import logging
import os
import json
from typing import List, Optional from typing import List, Optional
from fastapi import APIRouter, UploadFile, File, Form, Header, HTTPException from fastapi import APIRouter, UploadFile, File, Form, Header, HTTPException
from fastapi.params import Depends from fastapi.params import Depends
from starlette import status
from starlette.requests import Request from starlette.requests import Request
from config import settings from api import service
from api.dependency import get_generation_service, get_project_id, get_dao from api.dependency import get_generation_service, get_project_id, get_dao
from api.endpoints.auth import get_current_user
from api.models import (
GenerationResponse,
GenerationRequest,
GenerationsResponse,
PromptResponse,
PromptRequest,
GenerationGroupResponse,
FinancialReport,
ExternalGenerationRequest
)
from api.service.generation_service import GenerationService
from repos.dao import DAO from repos.dao import DAO
from utils.external_auth import verify_signature
from api.models.GenerationRequest import GenerationResponse, GenerationRequest, GenerationsResponse, PromptResponse, PromptRequest
from api.models.VideoGenerationRequest import VideoGenerationRequest
from api.service.generation_service import GenerationService
from models.Generation import Generation
from starlette import status
import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from api.endpoints.auth import get_current_user
router = APIRouter(prefix='/api/generations', tags=["Generation"]) router = APIRouter(prefix='/api/generations', tags=["Generation"])
@@ -75,53 +69,12 @@ async def get_generations(character_id: Optional[str] = None, limit: int = 10, o
return await generation_service.get_generations(character_id, limit=limit, offset=offset, user_id=user_id_filter, project_id=project_id) return await generation_service.get_generations(character_id, limit=limit, offset=offset, user_id=user_id_filter, project_id=project_id)
@router.get("/usage", response_model=FinancialReport) @router.post("/_run", response_model=GenerationResponse)
async def get_usage_report(
breakdown: Optional[str] = None, # "user" or "project"
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao)
) -> FinancialReport:
"""
Returns usage statistics (runs, tokens, cost) for the current user or project.
If project_id is provided, returns stats for that project.
Otherwise, returns stats for the current user.
"""
user_id_filter = str(current_user["_id"])
breakdown_by = None
if project_id:
# Permission check
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
user_id_filter = None # If we are in project, we see stats for the WHOLE project by default
if breakdown == "user":
breakdown_by = "created_by"
elif breakdown == "project":
breakdown_by = "project_id"
else:
# Default: Stats for current user
if breakdown == "project":
breakdown_by = "project_id"
elif breakdown == "user":
# This would breakdown personal usage by user (yourself), but could be useful if it included collaborators?
# No, if project_id is None, it's personal.
breakdown_by = "created_by"
return await generation_service.get_financial_report(
user_id=user_id_filter,
project_id=project_id,
breakdown_by=breakdown_by
)
@router.post("/_run", response_model=GenerationGroupResponse)
async def post_generation(generation: GenerationRequest, request: Request, async def post_generation(generation: GenerationRequest, request: Request,
generation_service: GenerationService = Depends(get_generation_service), generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user), current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id), project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao)) -> GenerationGroupResponse: dao: DAO = Depends(get_dao)) -> GenerationResponse:
logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}") logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}")
if project_id: if project_id:
@@ -133,6 +86,16 @@ async def post_generation(generation: GenerationRequest, request: Request,
return await generation_service.create_generation_task(generation, user_id=str(current_user.get("_id"))) return await generation_service.create_generation_task(generation, user_id=str(current_user.get("_id")))
@router.get("/{generation_id}", response_model=GenerationResponse)
async def get_generation(generation_id: str,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)) -> GenerationResponse:
logger.debug(f"get_generation called for ID: {generation_id}")
gen = await generation_service.get_generation(generation_id)
if gen and gen.created_by != str(current_user["_id"]):
raise HTTPException(status_code=403, detail="Access denied")
return gen
@router.get("/running") @router.get("/running")
async def get_running_generations(request: Request, async def get_running_generations(request: Request,
@@ -151,35 +114,25 @@ async def get_running_generations(request: Request,
return await generation_service.get_running_generations(user_id=user_id_filter, project_id=project_id) return await generation_service.get_running_generations(user_id=user_id_filter, project_id=project_id)
@router.get("/group/{group_id}", response_model=GenerationGroupResponse) @router.post("/video/_run", response_model=GenerationResponse)
async def get_generation_group(group_id: str, async def post_video_generation(
generation_service: GenerationService = Depends(get_generation_service), video_request: VideoGenerationRequest,
current_user: dict = Depends(get_current_user)): request: Request,
logger.info(f"get_generation_group called for group_id: {group_id}") generation_service: GenerationService = Depends(get_generation_service),
generations = await generation_service.dao.generations.get_generations_by_group(group_id) current_user: dict = Depends(get_current_user),
gen_responses = [GenerationResponse(**gen.model_dump()) for gen in generations] project_id: Optional[str] = Depends(get_project_id),
return GenerationGroupResponse(generation_group_id=group_id, generations=gen_responses) dao: DAO = Depends(get_dao),
) -> GenerationResponse:
"""Start image-to-video generation using Kling AI."""
@router.get("/{generation_id}", response_model=GenerationResponse) logger.info(f"post_video_generation called. AssetId: {video_request.image_asset_id}, Duration: {video_request.duration}s, Mode: {video_request.mode}")
async def get_generation(generation_id: str,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)) -> GenerationResponse:
logger.debug(f"get_generation called for ID: {generation_id}")
gen = await generation_service.get_generation(generation_id)
if gen and gen.created_by != str(current_user["_id"]):
# Check project membership
is_member = False
if gen.project_id:
project = await generation_service.dao.projects.get_project(gen.project_id)
if project and str(current_user["_id"]) in project.members:
is_member = True
if not is_member:
raise HTTPException(status_code=403, detail="Access denied")
return gen
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
video_request.project_id = project_id
return await generation_service.create_video_generation_task(video_request, user_id=str(current_user.get("_id")))
@router.post("/import", response_model=GenerationResponse) @router.post("/import", response_model=GenerationResponse)
@@ -192,13 +145,17 @@ async def import_external_generation(
Import a generation from an external source. Import a generation from an external source.
Requires server-to-server authentication via HMAC signature. Requires server-to-server authentication via HMAC signature.
""" """
import os
from utils.external_auth import verify_signature
from api.models.ExternalGenerationDTO import ExternalGenerationRequest
logger.info("import_external_generation called") logger.info("import_external_generation called")
# Get raw request body for signature verification # Get raw request body for signature verification
body = await request.body() body = await request.body()
# Verify signature # Verify signature
secret = settings.EXTERNAL_API_SECRET secret = os.getenv("EXTERNAL_API_SECRET")
if not secret: if not secret:
logger.error("EXTERNAL_API_SECRET not configured") logger.error("EXTERNAL_API_SECRET not configured")
raise HTTPException(status_code=500, detail="Server configuration error") raise HTTPException(status_code=500, detail="Server configuration error")
@@ -208,6 +165,7 @@ async def import_external_generation(
raise HTTPException(status_code=401, detail="Invalid signature") raise HTTPException(status_code=401, detail="Invalid signature")
# Parse request body # Parse request body
import json
try: try:
data = json.loads(body.decode('utf-8')) data = json.loads(body.decode('utf-8'))
external_gen = ExternalGenerationRequest(**data) external_gen = ExternalGenerationRequest(**data)
@@ -232,4 +190,4 @@ async def delete_generation(generation_id: str,
deleted = await generation_service.delete_generation(generation_id) deleted = await generation_service.delete_generation(generation_id)
if not deleted: if not deleted:
raise HTTPException(status_code=404, detail="Generation not found") raise HTTPException(status_code=404, detail="Generation not found")
return None return None

View File

@@ -1,104 +0,0 @@
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Body
from api.dependency import get_idea_service, get_project_id, get_generation_service
from api.endpoints.auth import get_current_user
from api.service.idea_service import IdeaService
from api.service.generation_service import GenerationService
from models.Idea import Idea
from api.models import GenerationResponse, GenerationsResponse
from api.models import IdeaRequest, PostRequest # Adjusting for general model usage
from api.models.IdeaRequest import IdeaCreateRequest, IdeaUpdateRequest, IdeaResponse
router = APIRouter(prefix="/api/ideas", tags=["ideas"])
@router.post("", response_model=Idea)
async def create_idea(
request: IdeaCreateRequest,
project_id: Optional[str] = Depends(get_project_id),
current_user: dict = Depends(get_current_user),
idea_service: IdeaService = Depends(get_idea_service)
):
pid = project_id or request.project_id
return await idea_service.create_idea(request.name, request.description, pid, str(current_user["_id"]))
@router.get("", response_model=List[IdeaResponse])
async def get_ideas(
project_id: Optional[str] = Depends(get_project_id),
limit: int = 20,
offset: int = 0,
current_user: dict = Depends(get_current_user),
idea_service: IdeaService = Depends(get_idea_service)
):
return await idea_service.get_ideas(project_id, str(current_user["_id"]), limit, offset)
@router.get("/{idea_id}", response_model=Idea)
async def get_idea(
idea_id: str,
idea_service: IdeaService = Depends(get_idea_service)
):
idea = await idea_service.get_idea(idea_id)
if not idea:
raise HTTPException(status_code=404, detail="Idea not found")
return idea
@router.put("/{idea_id}", response_model=Idea)
async def update_idea(
idea_id: str,
request: IdeaUpdateRequest,
idea_service: IdeaService = Depends(get_idea_service)
):
idea = await idea_service.update_idea(idea_id, request.name, request.description)
if not idea:
raise HTTPException(status_code=404, detail="Idea not found")
return idea
@router.delete("/{idea_id}")
async def delete_idea(
idea_id: str,
idea_service: IdeaService = Depends(get_idea_service)
):
success = await idea_service.delete_idea(idea_id)
if not success:
raise HTTPException(status_code=404, detail="Idea not found or could not be deleted")
return {"status": "success"}
@router.get("/{idea_id}/generations", response_model=GenerationsResponse)
async def get_idea_generations(
idea_id: str,
limit: int = 50,
offset: int = 0,
generation_service: GenerationService = Depends(get_generation_service)
):
# Depending on how generation service implements filtering by idea_id.
# We might need to update generation_service to support getting by idea_id directly
# or ensure generic get_generations supports it.
# Looking at generation_router.py, get_generations doesn't have idea_id arg?
# Let's check generation_service.get_generations signature again.
# It has: (character_id, limit, offset, user_id, project_id). NO IDEA_ID.
# I need to update GenerationService.get_generations too!
# For now, let's assume I will update it.
return await generation_service.get_generations(idea_id=idea_id, limit=limit, offset=offset)
@router.post("/{idea_id}/generations/{generation_id}")
async def add_generation_to_idea(
idea_id: str,
generation_id: str,
idea_service: IdeaService = Depends(get_idea_service)
):
success = await idea_service.add_generation_to_idea(idea_id, generation_id)
if not success:
raise HTTPException(status_code=404, detail="Idea or Generation not found")
return {"status": "success"}
@router.delete("/{idea_id}/generations/{generation_id}")
async def remove_generation_from_idea(
idea_id: str,
generation_id: str,
idea_service: IdeaService = Depends(get_idea_service)
):
success = await idea_service.remove_generation_from_idea(idea_id, generation_id)
if not success:
raise HTTPException(status_code=404, detail="Idea or Generation not found")
return {"status": "success"}

View File

@@ -1,99 +0,0 @@
from typing import List, Optional
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException
from api.dependency import get_post_service, get_project_id
from api.endpoints.auth import get_current_user
from api.service.post_service import PostService
from api.models import PostRequest, PostCreateRequest, PostUpdateRequest, AddGenerationsRequest
from models.Post import Post
router = APIRouter(prefix="/api/posts", tags=["posts"])
@router.post("", response_model=Post)
async def create_post(
request: PostCreateRequest,
project_id: Optional[str] = Depends(get_project_id),
current_user: dict = Depends(get_current_user),
post_service: PostService = Depends(get_post_service),
):
pid = project_id or request.project_id
return await post_service.create_post(
date=request.date,
topic=request.topic,
generation_ids=request.generation_ids,
project_id=pid,
user_id=str(current_user["_id"]),
)
@router.get("", response_model=List[Post])
async def get_posts(
project_id: Optional[str] = Depends(get_project_id),
limit: int = 200,
offset: int = 0,
date_from: Optional[datetime] = None,
date_to: Optional[datetime] = None,
current_user: dict = Depends(get_current_user),
post_service: PostService = Depends(get_post_service),
):
return await post_service.get_posts(project_id, str(current_user["_id"]), limit, offset, date_from, date_to)
@router.get("/{post_id}", response_model=Post)
async def get_post(
post_id: str,
post_service: PostService = Depends(get_post_service),
):
post = await post_service.get_post(post_id)
if not post:
raise HTTPException(status_code=404, detail="Post not found")
return post
@router.put("/{post_id}", response_model=Post)
async def update_post(
post_id: str,
request: PostUpdateRequest,
post_service: PostService = Depends(get_post_service),
):
post = await post_service.update_post(post_id, date=request.date, topic=request.topic)
if not post:
raise HTTPException(status_code=404, detail="Post not found")
return post
@router.delete("/{post_id}")
async def delete_post(
post_id: str,
post_service: PostService = Depends(get_post_service),
):
success = await post_service.delete_post(post_id)
if not success:
raise HTTPException(status_code=404, detail="Post not found or could not be deleted")
return {"status": "success"}
@router.post("/{post_id}/generations")
async def add_generations(
post_id: str,
request: AddGenerationsRequest,
post_service: PostService = Depends(get_post_service),
):
success = await post_service.add_generations(post_id, request.generation_ids)
if not success:
raise HTTPException(status_code=404, detail="Post not found")
return {"status": "success"}
@router.delete("/{post_id}/generations/{generation_id}")
async def remove_generation(
post_id: str,
generation_id: str,
post_service: PostService = Depends(get_post_service),
):
success = await post_service.remove_generation(post_id, generation_id)
if not success:
raise HTTPException(status_code=404, detail="Post not found or generation not linked")
return {"status": "success"}

View File

@@ -1,6 +1,4 @@
from typing import List, Optional from typing import List, Optional
from bson import ObjectId
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel from pydantic import BaseModel
from api.dependency import get_dao from api.dependency import get_dao
@@ -14,46 +12,14 @@ class ProjectCreate(BaseModel):
name: str name: str
description: Optional[str] = None description: Optional[str] = None
class ProjectMemberResponse(BaseModel):
id: str
username: str
class ProjectResponse(BaseModel): class ProjectResponse(BaseModel):
id: str id: str
name: str name: str
description: Optional[str] = None description: Optional[str] = None
owner_id: str owner_id: str
members: List[ProjectMemberResponse] members: List[str]
is_owner: bool = False is_owner: bool = False
async def _get_project_response(project: Project, current_user_id: str, dao: DAO) -> ProjectResponse:
member_responses = []
for member_id in project.members:
# We need a way to get user by ID. Let's check UsersRepo for get_user by ObjectId or string.
# Currently UsersRepo has get_user(user_id: int) for Telegram IDs.
# But for Web users we might need to search by _id.
# Let's try to get user info.
# Since project.members contains strings (ObjectIds as strings), we search by _id.
user_doc = await dao.users.collection.find_one({"_id": ObjectId(member_id)})
if not user_doc and member_id.isdigit():
# Fallback for telegram IDs if they are stored as strings of digits
user_doc = await dao.users.get_user(int(member_id))
username = "unknown"
if user_doc:
username = user_doc.get("username", "unknown")
member_responses.append(ProjectMemberResponse(id=member_id, username=username))
return ProjectResponse(
id=project.id,
name=project.name,
description=project.description,
owner_id=project.owner_id,
members=member_responses,
is_owner=(project.owner_id == current_user_id)
)
@router.post("", response_model=ProjectResponse) @router.post("", response_model=ProjectResponse)
async def create_project( async def create_project(
project_data: ProjectCreate, project_data: ProjectCreate,
@@ -68,15 +34,27 @@ async def create_project(
members=[user_id] members=[user_id]
) )
project_id = await dao.projects.create_project(new_project) project_id = await dao.projects.create_project(new_project)
new_project.id = project_id
# Add project to user's project list # Add project to user's project list
# Assuming user_repo has a method to add project or we do it directly?
# UserRepo doesn't have add_project method yet.
# But since UserRepo is just a wrapper around collection, lets add it here or update UserRepo later?
# Better to update UserRepo. For now, let's just return success.
# But user needs to see it in list.
# Update user in DB
await dao.users.collection.update_one( await dao.users.collection.update_one(
{"_id": current_user["_id"]}, {"_id": current_user["_id"]},
{"$addToSet": {"project_ids": project_id}} {"$addToSet": {"project_ids": project_id}}
) )
return await _get_project_response(new_project, user_id, dao) return ProjectResponse(
id=project_id,
name=new_project.name,
description=new_project.description,
owner_id=new_project.owner_id,
members=new_project.members,
is_owner=True
)
@router.get("", response_model=List[ProjectResponse]) @router.get("", response_model=List[ProjectResponse])
async def get_my_projects( async def get_my_projects(
@@ -88,7 +66,14 @@ async def get_my_projects(
responses = [] responses = []
for p in projects: for p in projects:
responses.append(await _get_project_response(p, user_id, dao)) responses.append(ProjectResponse(
id=p.id,
name=p.name,
description=p.description,
owner_id=p.owner_id,
members=p.members,
is_owner=(p.owner_id == user_id)
))
return responses return responses
class MemberAdd(BaseModel): class MemberAdd(BaseModel):

View File

@@ -1,18 +0,0 @@
from pydantic import BaseModel
from typing import List, Optional
class UsageStats(BaseModel):
total_runs: int
total_tokens: int
total_input_tokens: int
total_output_tokens: int
total_cost: float
class UsageByEntity(BaseModel):
entity_id: Optional[str] = None
stats: UsageStats
class FinancialReport(BaseModel):
summary: UsageStats
by_user: Optional[List[UsageByEntity]] = None
by_project: Optional[List[UsageByEntity]] = None

View File

@@ -1,7 +1,7 @@
from datetime import datetime, UTC from datetime import datetime, UTC
from typing import List, Optional from typing import List, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel
from models.Asset import Asset from models.Asset import Asset
from models.Generation import GenerationStatus from models.Generation import GenerationStatus
@@ -17,8 +17,6 @@ class GenerationRequest(BaseModel):
use_profile_image: bool = True use_profile_image: bool = True
assets_list: List[str] assets_list: List[str]
project_id: Optional[str] = None project_id: Optional[str] = None
idea_id: Optional[str] = None
count: int = Field(default=1, ge=1, le=10)
class GenerationsResponse(BaseModel): class GenerationsResponse(BaseModel):
@@ -29,6 +27,7 @@ class GenerationsResponse(BaseModel):
class GenerationResponse(BaseModel): class GenerationResponse(BaseModel):
id: str id: str
status: GenerationStatus status: GenerationStatus
gen_type: GenType = GenType.IMAGE
failed_reason: Optional[str] = None failed_reason: Optional[str] = None
linked_character_id: Optional[str] = None linked_character_id: Optional[str] = None
@@ -47,16 +46,14 @@ class GenerationResponse(BaseModel):
progress: int = 0 progress: int = 0
cost: Optional[float] = None cost: Optional[float] = None
created_by: Optional[str] = None created_by: Optional[str] = None
generation_group_id: Optional[str] = None # Video-specific
idea_id: Optional[str] = None kling_task_id: Optional[str] = None
video_duration: Optional[int] = None
video_mode: Optional[str] = None
created_at: datetime = datetime.now(UTC) created_at: datetime = datetime.now(UTC)
updated_at: datetime = datetime.now(UTC) updated_at: datetime = datetime.now(UTC)
class GenerationGroupResponse(BaseModel):
generation_group_id: str
generations: List[GenerationResponse]
class PromptRequest(BaseModel): class PromptRequest(BaseModel):
prompt: str prompt: str

View File

@@ -1,16 +0,0 @@
from typing import Optional
from pydantic import BaseModel
from models.Idea import Idea
from api.models.GenerationRequest import GenerationResponse
class IdeaCreateRequest(BaseModel):
name: str
description: Optional[str] = None
project_id: Optional[str] = None # Optional in body if passed via header/dependency
class IdeaUpdateRequest(BaseModel):
name: Optional[str] = None
description: Optional[str] = None
class IdeaResponse(Idea):
last_generation: Optional[GenerationResponse] = None

View File

@@ -1,19 +0,0 @@
from datetime import datetime
from typing import Optional, List
from pydantic import BaseModel
class PostCreateRequest(BaseModel):
date: datetime
topic: str
generation_ids: List[str] = []
project_id: Optional[str] = None
class PostUpdateRequest(BaseModel):
date: Optional[datetime] = None
topic: Optional[str] = None
class AddGenerationsRequest(BaseModel):
generation_ids: List[str]

View File

@@ -0,0 +1,16 @@
from typing import Optional
from pydantic import BaseModel
class VideoGenerationRequest(BaseModel):
prompt: str = ""
negative_prompt: Optional[str] = ""
image_asset_id: str # ID ассета-картинки для source image
duration: int = 5 # 5 or 10 seconds
mode: str = "std" # "std" or "pro"
model_name: str = "kling-v2-1"
cfg_scale: float = 0.5
aspect_ratio: str = "16:9"
linked_character_id: Optional[str] = None
project_id: Optional[str] = None

View File

@@ -1,7 +0,0 @@
from .AssetDTO import AssetResponse, AssetsResponse
from .CharacterDTO import CharacterCreateRequest, CharacterUpdateRequest
from .ExternalGenerationDTO import ExternalGenerationRequest
from .FinancialUsageDTO import FinancialReport, UsageStats, UsageByEntity
from .GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse, PromptRequest, PromptResponse
from .IdeaRequest import IdeaCreateRequest, IdeaUpdateRequest, IdeaResponse
from .PostRequest import PostCreateRequest, PostUpdateRequest, AddGenerationsRequest

View File

@@ -1,31 +1,28 @@
import asyncio import asyncio
import base64
import logging import logging
import random import random
import base64
from datetime import datetime, UTC from datetime import datetime, UTC
from typing import List, Optional, Tuple, Any, Dict from typing import List, Optional, Tuple, Any, Dict
from uuid import uuid4 from io import BytesIO
import httpx import httpx
from aiogram import Bot from aiogram import Bot
from aiogram.types import BufferedInputFile from aiogram.types import BufferedInputFile
from adapters.Exception import GoogleGenerationException from adapters.Exception import GoogleGenerationException
from adapters.google_adapter import GoogleAdapter from adapters.google_adapter import GoogleAdapter
from adapters.s3_adapter import S3Adapter from adapters.kling_adapter import KlingAdapter, KlingApiException
from api.models import FinancialReport, UsageStats, UsageByEntity from api.models.GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse
from api.models import GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse from api.models.VideoGenerationRequest import VideoGenerationRequest
# Импортируйте ваши модели DAO, Asset, Generation корректно # Импортируйте ваши модели DAO, Asset, Generation корректно
from models.Asset import Asset, AssetType, AssetContentType from models.Asset import Asset, AssetType, AssetContentType
from models.Generation import Generation, GenerationStatus from models.Generation import Generation, GenerationStatus
from models.enums import AspectRatios, Quality from models.enums import AspectRatios, Quality, GenType
from repos.dao import DAO from repos.dao import DAO
from adapters.s3_adapter import S3Adapter
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Limit concurrent generations to 4
generation_semaphore = asyncio.Semaphore(4)
# --- Вспомогательная функция генерации --- # --- Вспомогательная функция генерации ---
async def generate_image_task( async def generate_image_task(
@@ -55,30 +52,29 @@ async def generate_image_task(
logger.info(f"generate_image_task completed, received {len(generated_images_io) if generated_images_io else 0} images") logger.info(f"generate_image_task completed, received {len(generated_images_io) if generated_images_io else 0} images")
except GoogleGenerationException as e: except GoogleGenerationException as e:
raise e raise e
finally:
# Освобождаем входные данные — они больше не нужны
del media_group_bytes
images_bytes = [] images_bytes = []
if generated_images_io: if generated_images_io:
for img_io in generated_images_io: for img_io in generated_images_io:
# Читаем байты из BytesIO
img_io.seek(0) img_io.seek(0)
images_bytes.append(img_io.read()) content = img_io.read()
images_bytes.append(content)
# Закрываем поток
img_io.close() img_io.close()
# Освобождаем список BytesIO сразу
del generated_images_io
return images_bytes, metrics return images_bytes, metrics
class GenerationService: class GenerationService:
def __init__(self, dao: DAO, gemini: GoogleAdapter, s3_adapter: S3Adapter, bot: Optional[Bot] = None): def __init__(self, dao: DAO, gemini: GoogleAdapter, s3_adapter: S3Adapter, bot: Optional[Bot] = None, kling_adapter: Optional[KlingAdapter] = None):
self.dao = dao self.dao = dao
self.gemini = gemini self.gemini = gemini
self.s3_adapter = s3_adapter self.s3_adapter = s3_adapter
self.bot = bot self.bot = bot
self.kling_adapter = kling_adapter
async def ask_prompt_assistant(self, prompt: str, assets: list[str] | None = None) -> str: async def ask_prompt_assistant(self, prompt: str, assets: List[str] = None) -> str:
future_prompt = """You are an prompt-assistant. You improving user-entered prompts for image generation. User may upload reference image too. future_prompt = """You are an prompt-assistant. You improving user-entered prompts for image generation. User may upload reference image too.
I will provide sources prompt entered by user. Understand user needs and generate best variation of prompt. I will provide sources prompt entered by user. Understand user needs and generate best variation of prompt.
ANSWER ONLY PROMPT STRING!!! USER_ENTERED_PROMPT: """ ANSWER ONLY PROMPT STRING!!! USER_ENTERED_PROMPT: """
@@ -101,9 +97,10 @@ class GenerationService:
return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, images_list=images) return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, images_list=images)
async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0, user_id: Optional[str] = None, project_id: Optional[str] = None, idea_id: Optional[str] = None) -> GenerationsResponse: async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[
generations = await self.dao.generations.get_generations(character_id = character_id,limit=limit, offset=offset, created_by=user_id, project_id=project_id, idea_id=idea_id) Generation]:
total_count = await self.dao.generations.count_generations(character_id = character_id, created_by=user_id, project_id=project_id, idea_id=idea_id) generations = await self.dao.generations.get_generations(character_id = character_id,limit=limit, offset=offset, created_by=user_id, project_id=project_id)
total_count = await self.dao.generations.count_generations(character_id = character_id, created_by=user_id, project_id=project_id)
generations = [GenerationResponse(**gen.model_dump()) for gen in generations] generations = [GenerationResponse(**gen.model_dump()) for gen in generations]
return GenerationsResponse(generations=generations, total_count=total_count) return GenerationsResponse(generations=generations, total_count=total_count)
@@ -117,50 +114,29 @@ class GenerationService:
async def get_running_generations(self, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]: async def get_running_generations(self, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING, created_by=user_id, project_id=project_id) return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING, created_by=user_id, project_id=project_id)
async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationGroupResponse: async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None) -> GenerationResponse:
count = generation_request.count
if generation_group_id is None:
generation_group_id = str(uuid4())
results = []
for _ in range(count):
gen_response = await self._create_single_generation(generation_request, user_id, generation_group_id)
results.append(gen_response)
return GenerationGroupResponse(generation_group_id=generation_group_id, generations=results)
async def _create_single_generation(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationResponse:
gen_id = None gen_id = None
generation_model = None generation_model = None
try: try:
generation_model = Generation(**generation_request.model_dump(exclude={'count'})) generation_model = Generation(**generation_request.model_dump())
if user_id: if user_id:
generation_model.created_by = user_id generation_model.created_by = user_id
if generation_group_id:
generation_model.generation_group_id = generation_group_id
# Explicitly set idea_id from request if present (already in model_dump, but ensuring clarity)
if generation_request.idea_id:
generation_model.idea_id = generation_request.idea_id
gen_id = await self.dao.generations.create_generation(generation_model) gen_id = await self.dao.generations.create_generation(generation_model)
generation_model.id = gen_id generation_model.id = gen_id
async def runner(gen): async def runner(gen):
logger.info(f"Generation {gen.id} entered queue (waiting for slot)...") logger.info(f"Starting background generation task for ID: {gen.id}")
try: try:
async with generation_semaphore: await self.create_generation(gen)
logger.info(f"Starting background generation task for ID: {gen.id}") logger.info(f"Background generation task finished for ID: {gen.id}")
await self.create_generation(gen)
logger.info(f"Background generation task finished for ID: {gen.id}")
except Exception: except Exception:
# если генерация уже пошла и упала — пометим FAILED # если генерация уже пошла и упала — пометим FAILED
try: try:
db_gen = await self.dao.generations.get_generation(gen.id) db_gen = await self.dao.generations.get_generation(gen.id)
if db_gen is not None: db_gen.status = GenerationStatus.FAILED
db_gen.status = GenerationStatus.FAILED await self.dao.generations.update_generation(db_gen)
await self.dao.generations.update_generation(db_gen)
except Exception: except Exception:
logger.exception("Failed to mark generation as FAILED") logger.exception("Failed to mark generation as FAILED")
logger.exception("create_generation task failed") logger.exception("create_generation task failed")
@@ -174,9 +150,8 @@ class GenerationService:
if gen_id is not None: if gen_id is not None:
try: try:
gen = await self.dao.generations.get_generation(gen_id) gen = await self.dao.generations.get_generation(gen_id)
if gen is not None: gen.status = GenerationStatus.FAILED
gen.status = GenerationStatus.FAILED await self.dao.generations.update_generation(gen)
await self.dao.generations.update_generation(gen)
except Exception: except Exception:
logger.exception("Failed to mark generation as FAILED in create_generation_task") logger.exception("Failed to mark generation as FAILED in create_generation_task")
raise raise
@@ -204,10 +179,9 @@ class GenerationService:
if char_info is None: if char_info is None:
raise Exception(f"Character ID {generation.linked_character_id} not found") raise Exception(f"Character ID {generation.linked_character_id} not found")
if generation.use_profile_image: if generation.use_profile_image:
if char_info.avatar_asset_id is not None: avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id) if avatar_asset:
if avatar_asset and avatar_asset.data: media_group_bytes.append(avatar_asset.data)
media_group_bytes.append(avatar_asset.data)
# generation_prompt = generation_prompt.replace("$char_bio_inserted", f"1. CHARACTER BIO (Must be strictly followed): {char_info.character_bio}") # generation_prompt = generation_prompt.replace("$char_bio_inserted", f"1. CHARACTER BIO (Must be strictly followed): {char_info.character_bio}")
reference_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list) reference_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
@@ -308,9 +282,7 @@ class GenerationService:
# 5. (Опционально) Обновляем запись генерации ссылками на результаты # 5. (Опционально) Обновляем запись генерации ссылками на результаты
# Предполагаем, что у модели Generation есть поле result_asset_ids # Предполагаем, что у модели Generation есть поле result_asset_ids
result_ids = [] result_ids = [a.id for a in created_assets]
for a in created_assets:
result_ids.append(a.id)
generation.result_list = result_ids generation.result_list = result_ids
generation.status = GenerationStatus.DONE generation.status = GenerationStatus.DONE
@@ -378,7 +350,8 @@ class GenerationService:
Returns: Returns:
Created Generation object Created Generation object
""" """
from api.models.ExternalGenerationDTO import ExternalGenerationRequest
# Validate image source # Validate image source
external_gen.validate_image_source() external_gen.validate_image_source()
@@ -458,6 +431,168 @@ class GenerationService:
return generation return generation
# === VIDEO GENERATION (Kling) ===
async def create_video_generation_task(self, request: VideoGenerationRequest, user_id: Optional[str] = None) -> GenerationResponse:
    """Create a video generation task and return immediately.

    Persists a RUNNING Generation record, then launches the actual Kling
    image-to-video generation in a background asyncio task. The returned
    response is a snapshot of the freshly created (still running) generation.

    Args:
        request: Video generation parameters (source image asset, prompt,
            duration, mode, project scoping).
        user_id: Optional ID of the requesting user; stored as created_by.

    Returns:
        GenerationResponse for the new generation (status RUNNING).

    Raises:
        Exception: If the Kling adapter is not configured.
    """
    if not self.kling_adapter:
        raise Exception("Kling adapter is not configured")
    generation = Generation(
        status=GenerationStatus.RUNNING,
        gen_type=GenType.VIDEO,
        linked_character_id=request.linked_character_id,
        aspect_ratio=AspectRatios.SIXTEENNINE,  # default for video
        quality=Quality.ONEK,
        prompt=request.prompt,
        assets_list=[request.image_asset_id],
        video_duration=request.duration,
        video_mode=request.mode,
        project_id=request.project_id,
    )
    if user_id:
        generation.created_by = user_id
    gen_id = await self.dao.generations.create_generation(generation)
    generation.id = gen_id

    async def runner(gen, req):
        logger.info(f"Starting background video generation task for ID: {gen.id}")
        try:
            await self.create_video_generation(gen, req)
            logger.info(f"Background video generation task finished for ID: {gen.id}")
        except Exception:
            # Best-effort: mark the generation FAILED in the DB so it does not
            # stay RUNNING forever after a crash.
            try:
                db_gen = await self.dao.generations.get_generation(gen.id)
                if db_gen and db_gen.status != GenerationStatus.FAILED:
                    db_gen.status = GenerationStatus.FAILED
                    db_gen.updated_at = datetime.now(UTC)
                    await self.dao.generations.update_generation(db_gen)
            except Exception:
                logger.exception("Failed to mark video generation as FAILED")
            logger.exception("create_video_generation task failed")

    # BUGFIX: keep a strong reference to the background task. The event loop
    # only holds a weak reference to tasks, so an un-referenced task created
    # via asyncio.create_task() can be garbage-collected before it finishes.
    task = asyncio.create_task(runner(generation, request))
    if not hasattr(self, "_background_video_tasks"):
        self._background_video_tasks = set()
    self._background_video_tasks.add(task)
    task.add_done_callback(self._background_video_tasks.discard)
    return GenerationResponse(**generation.model_dump())
async def create_video_generation(self, generation: Generation, request: VideoGenerationRequest):
    """Background video generation: call Kling API, poll, download result, save asset.

    Flow: presign the source image from S3 -> create a Kling image-to-video
    task -> poll until completion (persisting progress on the generation) ->
    download the resulting MP4 -> upload it to S3 and create an Asset record
    -> mark the generation DONE. On any failure the generation is marked
    FAILED with a reason, and the exception is re-raised to the caller
    (the background runner in create_video_generation_task).

    Args:
        generation: Persisted Generation record (already RUNNING, has an id).
        request: The original video request (source asset id, prompt, Kling options).

    Raises:
        KlingApiException: When the Kling API reports an error.
        Exception: For any other failure (missing asset, download error, ...).
    """
    start_time = datetime.now()
    try:
        # 1. Get a presigned URL for the source image stored in S3
        asset = await self.dao.assets.get_asset(request.image_asset_id)
        if not asset:
            raise Exception(f"Asset {request.image_asset_id} not found")
        if not asset.minio_object_name:
            raise Exception(f"Asset {request.image_asset_id} has no S3 object")
        presigned_url = await self.s3_adapter.get_presigned_url(asset.minio_object_name, expiration=3600)
        if not presigned_url:
            raise Exception("Failed to generate presigned URL for source image")
        logger.info(f"Video gen {generation.id}: got presigned URL for asset {request.image_asset_id}")
        # 2. Create the Kling image-to-video task
        task_data = await self.kling_adapter.create_video_task(
            image_url=presigned_url,
            prompt=request.prompt,
            negative_prompt=request.negative_prompt or "",
            model_name=request.model_name,
            duration=request.duration,
            mode=request.mode,
            cfg_scale=request.cfg_scale,
            aspect_ratio=request.aspect_ratio,
        )
        task_id = task_data.get("task_id")
        # Persist the Kling task id right away so the task remains traceable
        generation.kling_task_id = task_id
        await self.dao.generations.update_generation(generation)
        logger.info(f"Video gen {generation.id}: Kling task created, task_id={task_id}")
        # 3. Poll for completion; each progress tick is persisted so clients
        # polling the generation record can see live progress.
        async def progress_callback(progress_pct: int):
            generation.progress = progress_pct
            generation.updated_at = datetime.now(UTC)
            await self.dao.generations.update_generation(generation)
        result = await self.kling_adapter.wait_for_completion(
            task_id=task_id,
            poll_interval=10,
            timeout=600,
            progress_callback=progress_callback,
        )
        # 4. Extract the video URL from the result and download the file
        works = result.get("task_result", {}).get("videos", [])
        if not works:
            raise Exception("No video in Kling result")
        video_url = works[0].get("url")
        # NOTE(review): assumes Kling reports duration as seconds compatible
        # with Generation.video_duration — confirm the API's actual type.
        video_duration = works[0].get("duration", request.duration)
        if not video_url:
            raise Exception("No video URL in Kling result")
        logger.info(f"Video gen {generation.id}: downloading video from {video_url}")
        async with httpx.AsyncClient(timeout=120.0) as client:
            video_response = await client.get(video_url)
            video_response.raise_for_status()
            video_bytes = video_response.content
        logger.info(f"Video gen {generation.id}: downloaded {len(video_bytes)} bytes")
        # 5. Upload the MP4 to S3 under a timestamped, randomized object name
        filename = f"generated_video/{generation.linked_character_id or 'no_char'}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.mp4"
        await self.s3_adapter.upload_file(filename, video_bytes, content_type="video/mp4")
        # 6. Create the Asset record pointing at the uploaded file
        new_asset = Asset(
            name=f"Video_{generation.linked_character_id or 'gen'}",
            type=AssetType.GENERATED,
            content_type=AssetContentType.VIDEO,
            linked_char_id=generation.linked_character_id,
            data=None,
            minio_object_name=filename,
            minio_bucket=self.s3_adapter.bucket_name,
            thumbnail=None,  # video thumbnails can be added later
            created_by=generation.created_by,
            project_id=generation.project_id,
        )
        asset_id = await self.dao.assets.create_asset(new_asset)
        new_asset.id = str(asset_id)
        # 7. Finalize the generation record with result references and timing
        end_time = datetime.now()
        generation.result_list = [new_asset.id]
        generation.result = new_asset.id
        generation.status = GenerationStatus.DONE
        generation.progress = 100
        generation.video_duration = video_duration
        generation.execution_time_seconds = (end_time - start_time).total_seconds()
        generation.updated_at = datetime.now(UTC)
        await self.dao.generations.update_generation(generation)
        logger.info(f"Video generation {generation.id} completed. Asset: {new_asset.id}, Time: {generation.execution_time_seconds:.1f}s")
    except KlingApiException as e:
        # Kling-specific failure: record the reason, then propagate to the runner
        logger.error(f"Kling API error for generation {generation.id}: {e}")
        generation.status = GenerationStatus.FAILED
        generation.failed_reason = str(e)
        generation.updated_at = datetime.now(UTC)
        await self.dao.generations.update_generation(generation)
        raise
    except Exception as e:
        # Any other failure: record the reason, then propagate to the runner
        logger.error(f"Video generation {generation.id} failed: {e}")
        generation.status = GenerationStatus.FAILED
        generation.failed_reason = str(e)
        generation.updated_at = datetime.now(UTC)
        await self.dao.generations.update_generation(generation)
        raise
async def delete_generation(self, generation_id: str) -> bool: async def delete_generation(self, generation_id: str) -> bool:
""" """
Soft delete generation by marking it as deleted. Soft delete generation by marking it as deleted.
@@ -473,62 +608,4 @@ class GenerationService:
return True return True
except Exception as e: except Exception as e:
logger.error(f"Error deleting generation {generation_id}: {e}") logger.error(f"Error deleting generation {generation_id}: {e}")
return False return False
async def cleanup_stale_generations(self):
    """Cancel generations that have been stuck in RUNNING for more than an hour."""
    try:
        cancelled = await self.dao.generations.cancel_stale_generations(timeout_minutes=60)
        if cancelled:
            logger.info(f"Cleaned up {cancelled} stale generations (timeout)")
    except Exception as e:
        logger.error(f"Error cleaning up stale generations: {e}")
async def cleanup_old_data(self, days: int = 2):
    """Clean up old data.

    1. Soft-deletes generations older than ``days`` days.
    2. Soft-deletes the associated assets and hard-deletes their files from S3.

    Errors are logged, never raised — this is a best-effort maintenance job.

    Args:
        days: Age threshold in days; anything older is purged.
    """
    try:
        # 1. Soft-delete old generations and collect the referenced asset IDs
        gen_count, asset_ids = await self.dao.generations.soft_delete_old_generations(days=days)
        if gen_count > 0:
            logger.info(f"Soft-deleted {gen_count} generations older than {days} days. "
                        f"Found {len(asset_ids)} associated asset IDs.")
        # 2. Soft-delete the assets and hard-delete their files from S3
        if asset_ids:
            purged = await self.dao.assets.soft_delete_and_purge_assets(asset_ids)
            logger.info(f"Purged {purged} assets (soft-deleted + S3 files removed).")
    except Exception as e:
        logger.error(f"Error during old data cleanup: {e}")
async def get_financial_report(self, user_id: Optional[str] = None, project_id: Optional[str] = None, breakdown_by: Optional[str] = None) -> FinancialReport:
    """Build a financial usage report for a user and/or project.

    'breakdown_by' can be 'created_by' or 'project_id'; when set, the report
    additionally includes per-entity usage rows grouped by that field.
    """
    stats = await self.dao.generations.get_usage_stats(created_by=user_id, project_id=project_id)
    breakdown_rows = None
    if breakdown_by in ("created_by", "project_id"):
        raw_rows = await self.dao.generations.get_usage_breakdown(
            group_by=breakdown_by, project_id=project_id, created_by=user_id
        )
        breakdown_rows = [UsageByEntity(**row) for row in raw_rows]
    return FinancialReport(
        summary=UsageStats(**stats),
        by_user=breakdown_rows if breakdown_by == "created_by" else None,
        by_project=breakdown_rows if breakdown_by == "project_id" else None,
    )

View File

@@ -1,75 +0,0 @@
from typing import List, Optional
from datetime import datetime
from repos.dao import DAO
from models.Idea import Idea
class IdeaService:
    """Service layer for Idea entities: CRUD plus generation linking."""

    def __init__(self, dao: DAO):
        self.dao = dao

    async def create_idea(self, name: str, description: Optional[str], project_id: Optional[str], user_id: str) -> Idea:
        """Persist a new idea and return it with its assigned id."""
        new_idea = Idea(name=name, description=description, project_id=project_id, created_by=user_id)
        new_idea.id = await self.dao.ideas.create_idea(new_idea)
        return new_idea

    async def get_ideas(self, project_id: Optional[str], user_id: str, limit: int = 20, offset: int = 0) -> List[dict]:
        """List ideas for the user, optionally scoped to a project, with paging."""
        return await self.dao.ideas.get_ideas(project_id, user_id, limit, offset)

    async def get_idea(self, idea_id: str) -> Optional[Idea]:
        """Fetch a single idea by id, or None when it does not exist."""
        return await self.dao.ideas.get_idea(idea_id)

    async def update_idea(self, idea_id: str, name: Optional[str] = None, description: Optional[str] = None) -> Optional[Idea]:
        """Apply partial updates to an idea; returns None when it is missing."""
        existing = await self.dao.ideas.get_idea(idea_id)
        if existing is None:
            return None
        if name is not None:
            existing.name = name
        if description is not None:
            existing.description = description
        existing.updated_at = datetime.now()
        await self.dao.ideas.update_idea(existing)
        return existing

    async def delete_idea(self, idea_id: str) -> bool:
        """Delete an idea; returns True on success."""
        return await self.dao.ideas.delete_idea(idea_id)

    async def add_generation_to_idea(self, idea_id: str, generation_id: str) -> bool:
        """Link a generation to an idea; False when either record is missing."""
        if await self.dao.ideas.get_idea(idea_id) is None:
            return False
        generation = await self.dao.generations.get_generation(generation_id)
        if generation is None:
            return False
        generation.idea_id = idea_id
        generation.updated_at = datetime.now()
        await self.dao.generations.update_generation(generation)
        return True

    async def remove_generation_from_idea(self, idea_id: str, generation_id: str) -> bool:
        """Unlink a generation from an idea; only succeeds if currently linked."""
        if await self.dao.ideas.get_idea(idea_id) is None:
            return False
        generation = await self.dao.generations.get_generation(generation_id)
        if generation is None:
            return False
        if generation.idea_id != idea_id:
            return False
        generation.idea_id = None
        generation.updated_at = datetime.now()
        await self.dao.generations.update_generation(generation)
        return True

View File

@@ -1,79 +0,0 @@
from typing import List, Optional
from datetime import datetime, UTC
from repos.dao import DAO
from models.Post import Post
class PostService:
    """Service layer for Post entities: CRUD and generation attachment."""

    def __init__(self, dao: DAO):
        self.dao = dao

    async def create_post(
        self,
        date: datetime,
        topic: str,
        generation_ids: List[str],
        project_id: Optional[str],
        user_id: str,
    ) -> Post:
        """Persist a new post and return it with its assigned id."""
        new_post = Post(
            date=date,
            topic=topic,
            generation_ids=generation_ids,
            project_id=project_id,
            created_by=user_id,
        )
        new_post.id = await self.dao.posts.create_post(new_post)
        return new_post

    async def get_post(self, post_id: str) -> Optional[Post]:
        """Fetch a single post by id, or None when not found."""
        return await self.dao.posts.get_post(post_id)

    async def get_posts(
        self,
        project_id: Optional[str],
        user_id: str,
        limit: int = 20,
        offset: int = 0,
        date_from: Optional[datetime] = None,
        date_to: Optional[datetime] = None,
    ) -> List[Post]:
        """List posts for the user/project with paging and an optional date range."""
        return await self.dao.posts.get_posts(project_id, user_id, limit, offset, date_from, date_to)

    async def update_post(
        self,
        post_id: str,
        date: Optional[datetime] = None,
        topic: Optional[str] = None,
    ) -> Optional[Post]:
        """Apply partial updates to a post; returns the refreshed record or None."""
        if await self.dao.posts.get_post(post_id) is None:
            return None
        changes: dict = {"updated_at": datetime.now(UTC)}
        if date is not None:
            changes["date"] = date
        if topic is not None:
            changes["topic"] = topic
        await self.dao.posts.update_post(post_id, changes)
        # Re-read so the caller sees exactly what is stored
        return await self.dao.posts.get_post(post_id)

    async def delete_post(self, post_id: str) -> bool:
        """Delete a post; returns True on success."""
        return await self.dao.posts.delete_post(post_id)

    async def add_generations(self, post_id: str, generation_ids: List[str]) -> bool:
        """Attach generations to a post; False when the post is missing."""
        if await self.dao.posts.get_post(post_id) is None:
            return False
        return await self.dao.posts.add_generations(post_id, generation_ids)

    async def remove_generation(self, post_id: str, generation_id: str) -> bool:
        """Detach one generation from a post; False when the post is missing."""
        if await self.dao.posts.get_post(post_id) is None:
            return False
        return await self.dao.posts.remove_generation(post_id, generation_id)

View File

@@ -1,39 +0,0 @@
import os
from typing import Optional
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
    """Application configuration resolved from the environment / .env file.

    Values are loaded by pydantic-settings from process environment variables,
    falling back to the file named by ENV_FILE (default ".env").  Unknown
    environment variables are ignored (extra="ignore").  Fields without a
    default (BOT_TOKEN, GEMINI_API_KEY) are required at startup.
    """
    # Telegram Bot
    BOT_TOKEN: str
    ADMIN_ID: int = 0
    # AI Service
    GEMINI_API_KEY: str
    # Database
    MONGO_HOST: str = "mongodb://localhost:27017"
    DB_NAME: str = "my_bot_db"
    # S3 Storage (Minio)
    MINIO_ENDPOINT: str = "http://localhost:9000"
    MINIO_ACCESS_KEY: str = "minioadmin"
    MINIO_SECRET_KEY: str = "minioadmin"
    MINIO_BUCKET: str = "ai-char"
    # External API
    EXTERNAL_API_SECRET: Optional[str] = None
    # JWT Security — SECRET_KEY default is a placeholder and must be overridden
    # in any real deployment.
    SECRET_KEY: str = "CHANGE_ME_TO_A_SUPER_SECRET_KEY"
    ALGORITHM: str = "HS256"
    ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 * 24 * 60  # 30 days
    model_config = SettingsConfigDict(
        env_file=os.getenv("ENV_FILE", ".env"),
        env_file_encoding="utf-8",
        extra="ignore"
    )
# Module-level singleton imported by the rest of the application.
settings = Settings()

View File

@@ -27,19 +27,19 @@ class AlbumMiddleware(BaseMiddleware):
# Ждем сбора остальных частей # Ждем сбора остальных частей
await asyncio.sleep(self.latency) await asyncio.sleep(self.latency)
# Проверяем, что ключ все еще существует # Проверяем, что ключ все еще существует (на всякий случай)
if group_id in self.album_data: if group_id in self.album_data:
# Передаем собранный альбом в хендлер # Передаем собранный альбом в хендлер
# Сортируем по message_id, чтобы порядок был верным # Сортируем по message_id, чтобы порядок был верным
current_album = self.album_data[group_id] self.album_data[group_id].sort(key=lambda x: x.message_id)
current_album.sort(key=lambda x: x.message_id) data["album"] = self.album_data[group_id]
data["album"] = current_album
return await handler(event, data) return await handler(event, data)
finally: finally:
# ЧИСТКА: Удаляем запись после обработки или таймаута # ЧИСТКА: Удаляем всегда, если это "головной" поток, который создал запись
# Используем pop() с дефолтом, чтобы избежать KeyError # Проверяем, что мы удаляем именно то, что создали, и ключ существует
self.album_data.pop(group_id, None) if group_id in self.album_data and self.album_data[group_id][0] == event:
del self.album_data[group_id]
else: else:
# Если группа уже собирается - просто добавляем и выходим # Если группа уже собирается - просто добавляем и выходим

View File

@@ -7,6 +7,7 @@ from pydantic import BaseModel, computed_field, Field, model_validator
class AssetContentType(str, Enum): class AssetContentType(str, Enum):
IMAGE = 'image' IMAGE = 'image'
VIDEO = 'video'
PROMPT = 'prompt' PROMPT = 'prompt'
class AssetType(str, Enum): class AssetType(str, Enum):
@@ -30,7 +31,6 @@ class Asset(BaseModel):
tags: List[str] = [] tags: List[str] = []
created_by: Optional[str] = None created_by: Optional[str] = None
project_id: Optional[str] = None project_id: Optional[str] = None
is_deleted: bool = False
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
@@ -63,7 +63,6 @@ class Asset(BaseModel):
# --- CALCULATED FIELD --- # --- CALCULATED FIELD ---
@computed_field @computed_field
@property
def url(self) -> str: def url(self) -> str:
""" """
Это поле автоматически вычислится и попадет в model_dump() / .json() Это поле автоматически вычислится и попадет в model_dump() / .json()

View File

@@ -9,6 +9,7 @@ class Character(BaseModel):
name: str name: str
avatar_asset_id: Optional[str] = None avatar_asset_id: Optional[str] = None
avatar_image: Optional[str] = None avatar_image: Optional[str] = None
character_image_data: Optional[bytes] = None
character_image_doc_tg_id: Optional[str] = None character_image_doc_tg_id: Optional[str] = None
character_image_tg_id: Optional[str] = None character_image_tg_id: Optional[str] = None
character_bio: Optional[str] = None character_bio: Optional[str] = None

View File

@@ -16,6 +16,7 @@ class GenerationStatus(str, Enum):
class Generation(BaseModel): class Generation(BaseModel):
id: Optional[str] = None id: Optional[str] = None
status: GenerationStatus = GenerationStatus.RUNNING status: GenerationStatus = GenerationStatus.RUNNING
gen_type: GenType = GenType.IMAGE
failed_reason: Optional[str] = None failed_reason: Optional[str] = None
linked_character_id: Optional[str] = None linked_character_id: Optional[str] = None
telegram_id: Optional[int] = None telegram_id: Optional[int] = None
@@ -35,10 +36,12 @@ class Generation(BaseModel):
output_token_usage: Optional[int] = None output_token_usage: Optional[int] = None
is_deleted: bool = False is_deleted: bool = False
album_id: Optional[str] = None album_id: Optional[str] = None
generation_group_id: Optional[str] = None
created_by: Optional[str] = None # Stores User ID (Telegram ID or Web User ObjectId) created_by: Optional[str] = None # Stores User ID (Telegram ID or Web User ObjectId)
project_id: Optional[str] = None project_id: Optional[str] = None
idea_id: Optional[str] = None # Video-specific fields
kling_task_id: Optional[str] = None
video_duration: Optional[int] = None # 5 or 10 seconds
video_mode: Optional[str] = None # "std" or "pro"
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))

View File

@@ -1,13 +0,0 @@
from datetime import datetime
from typing import Optional, List
from pydantic import BaseModel, Field
class Idea(BaseModel):
    """A content idea, optionally scoped to a project.

    NOTE(review): created_at/updated_at default to naive datetime.now() here,
    while sibling models (e.g. Post) use timezone-aware datetime.now(UTC) —
    confirm whether these should be made UTC-aware for consistency.
    """
    id: Optional[str] = None
    name: str = "New Idea"
    description: Optional[str] = None
    project_id: Optional[str] = None
    created_by: str # User ID
    is_deleted: bool = False  # soft-delete flag
    created_at: datetime = Field(default_factory=datetime.now)
    updated_at: datetime = Field(default_factory=datetime.now)

View File

@@ -1,23 +0,0 @@
from datetime import datetime, timezone, UTC
from typing import Optional, List
from pydantic import BaseModel, Field, model_validator
class Post(BaseModel):
    """A scheduled content post grouping one or more generations."""
    id: Optional[str] = None
    date: datetime  # scheduled publication date
    topic: str
    generation_ids: List[str] = Field(default_factory=list)
    project_id: Optional[str] = None
    created_by: str  # user ID of the author
    is_deleted: bool = False  # soft-delete flag
    created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
    updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))

    @model_validator(mode="after")
    def ensure_tz_aware(self):
        """Normalize naive datetimes to UTC so comparisons never mix naive/aware."""
        for field in ("date", "created_at", "updated_at"):
            val = getattr(self, field)
            if val is not None and val.tzinfo is None:
                setattr(self, field, val.replace(tzinfo=timezone.utc))
        return self

View File

@@ -34,10 +34,12 @@ class Quality(str, Enum):
class GenType(str, Enum): class GenType(str, Enum):
TEXT = 'Text' TEXT = 'Text'
IMAGE = 'Image' IMAGE = 'Image'
VIDEO = 'Video'
@property @property
def value_type(self) -> str: def value_type(self) -> str:
return { return {
GenType.TEXT: 'Text', GenType.TEXT: 'Text',
GenType.IMAGE: 'Image', GenType.IMAGE: 'Image',
GenType.VIDEO: 'Video',
}[self] }[self]

View File

@@ -1,8 +1,6 @@
from typing import Any, List, Optional from typing import List, Optional
import logging import logging
from datetime import datetime, UTC
from bson import ObjectId from bson import ObjectId
from uuid import uuid4
from motor.motor_asyncio import AsyncIOMotorClient from motor.motor_asyncio import AsyncIOMotorClient
from models.Asset import Asset from models.Asset import Asset
@@ -21,8 +19,7 @@ class AssetsRepo:
# Main data # Main data
if asset.data: if asset.data:
ts = int(asset.created_at.timestamp()) ts = int(asset.created_at.timestamp())
uid = uuid4().hex[:8] object_name = f"{asset.type.value}/{ts}_{asset.name}"
object_name = f"{asset.type.value}/{ts}_{uid}_{asset.name}"
uploaded = await self.s3.upload_file(object_name, asset.data) uploaded = await self.s3.upload_file(object_name, asset.data)
if uploaded: if uploaded:
@@ -35,8 +32,7 @@ class AssetsRepo:
# Thumbnail # Thumbnail
if asset.thumbnail: if asset.thumbnail:
ts = int(asset.created_at.timestamp()) ts = int(asset.created_at.timestamp())
uid = uuid4().hex[:8] thumb_name = f"{asset.type.value}/thumbs/{ts}_{asset.name}_thumb.jpg"
thumb_name = f"{asset.type.value}/thumbs/{ts}_{uid}_{asset.name}_thumb.jpg"
uploaded_thumb = await self.s3.upload_file(thumb_name, asset.thumbnail) uploaded_thumb = await self.s3.upload_file(thumb_name, asset.thumbnail)
if uploaded_thumb: if uploaded_thumb:
@@ -51,7 +47,7 @@ class AssetsRepo:
return str(res.inserted_id) return str(res.inserted_id)
async def get_assets(self, asset_type: Optional[str] = None, limit: int = 10, offset: int = 0, with_data: bool = False, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Asset]: async def get_assets(self, asset_type: Optional[str] = None, limit: int = 10, offset: int = 0, with_data: bool = False, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Asset]:
filter: dict[str, Any]= {"is_deleted": {"$ne": True}} filter = {}
if asset_type: if asset_type:
filter["type"] = asset_type filter["type"] = asset_type
args = {} args = {}
@@ -138,8 +134,7 @@ class AssetsRepo:
if self.s3: if self.s3:
if asset.data: if asset.data:
ts = int(asset.created_at.timestamp()) ts = int(asset.created_at.timestamp())
uid = uuid4().hex[:8] object_name = f"{asset.type.value}/{ts}_{asset.name}"
object_name = f"{asset.type.value}/{ts}_{uid}_{asset.name}"
if await self.s3.upload_file(object_name, asset.data): if await self.s3.upload_file(object_name, asset.data):
asset.minio_object_name = object_name asset.minio_object_name = object_name
asset.minio_bucket = self.s3.bucket_name asset.minio_bucket = self.s3.bucket_name
@@ -147,8 +142,7 @@ class AssetsRepo:
if asset.thumbnail: if asset.thumbnail:
ts = int(asset.created_at.timestamp()) ts = int(asset.created_at.timestamp())
uid = uuid4().hex[:8] thumb_name = f"{asset.type.value}/thumbs/{ts}_{asset.name}_thumb.jpg"
thumb_name = f"{asset.type.value}/thumbs/{ts}_{uid}_{asset.name}_thumb.jpg"
if await self.s3.upload_file(thumb_name, asset.thumbnail): if await self.s3.upload_file(thumb_name, asset.thumbnail):
asset.minio_thumbnail_object_name = thumb_name asset.minio_thumbnail_object_name = thumb_name
asset.thumbnail = None asset.thumbnail = None
@@ -175,8 +169,6 @@ class AssetsRepo:
filter["linked_char_id"] = character_id filter["linked_char_id"] = character_id
if created_by: if created_by:
filter["created_by"] = created_by filter["created_by"] = created_by
if project_id is None:
filter["project_id"] = None
if project_id: if project_id:
filter["project_id"] = project_id filter["project_id"] = project_id
return await self.collection.count_documents(filter) return await self.collection.count_documents(filter)
@@ -205,61 +197,6 @@ class AssetsRepo:
res = await self.collection.delete_one({"_id": ObjectId(asset_id)}) res = await self.collection.delete_one({"_id": ObjectId(asset_id)})
return res.deleted_count > 0 return res.deleted_count > 0
async def soft_delete_and_purge_assets(self, asset_ids: List[str]) -> int:
    """Soft-delete assets in Mongo and hard-delete their files from S3.

    Assets are marked is_deleted=True and their S3 object references are
    cleared, while the underlying S3 files (main object and thumbnail) are
    permanently removed.  S3 deletion failures are logged and do not stop
    the run.

    Args:
        asset_ids: Asset ObjectId strings; invalid ids are silently skipped.

    Returns:
        Number of (not already deleted) assets processed.
    """
    if not asset_ids:
        return 0
    object_ids = [ObjectId(aid) for aid in asset_ids if ObjectId.is_valid(aid)]
    if not object_ids:
        return 0
    # Find the assets that have not been soft-deleted yet
    cursor = self.collection.find(
        {"_id": {"$in": object_ids}, "is_deleted": {"$ne": True}},
        {"minio_object_name": 1, "minio_thumbnail_object_name": 1}
    )
    purged_count = 0
    ids_to_update = []
    async for doc in cursor:
        ids_to_update.append(doc["_id"])
        # Hard-delete the files from S3 (best effort: log and continue on error)
        if self.s3:
            if doc.get("minio_object_name"):
                try:
                    await self.s3.delete_file(doc["minio_object_name"])
                except Exception as e:
                    logger.error(f"Failed to delete S3 object {doc['minio_object_name']}: {e}")
            if doc.get("minio_thumbnail_object_name"):
                try:
                    await self.s3.delete_file(doc["minio_thumbnail_object_name"])
                except Exception as e:
                    logger.error(f"Failed to delete S3 thumbnail {doc['minio_thumbnail_object_name']}: {e}")
        purged_count += 1
    # Soft-delete the documents and clear their S3 references
    if ids_to_update:
        await self.collection.update_many(
            {"_id": {"$in": ids_to_update}},
            {
                "$set": {
                    "is_deleted": True,
                    "minio_object_name": None,
                    "minio_thumbnail_object_name": None,
                    "updated_at": datetime.now(UTC)
                }
            }
        )
    return purged_count
async def migrate_to_minio(self) -> dict: async def migrate_to_minio(self) -> dict:
"""Переносит данные и thumbnails из Mongo в MinIO.""" """Переносит данные и thumbnails из Mongo в MinIO."""
if not self.s3: if not self.s3:
@@ -279,8 +216,7 @@ class AssetsRepo:
created_at = doc.get("created_at") created_at = doc.get("created_at")
ts = int(created_at.timestamp()) if created_at else 0 ts = int(created_at.timestamp()) if created_at else 0
uid = uuid4().hex[:8] object_name = f"{type_}/{ts}_{asset_id}_{name}"
object_name = f"{type_}/{ts}_{uid}_{asset_id}_{name}"
if await self.s3.upload_file(object_name, data): if await self.s3.upload_file(object_name, data):
await self.collection.update_one( await self.collection.update_one(
{"_id": asset_id}, {"_id": asset_id},
@@ -307,8 +243,7 @@ class AssetsRepo:
created_at = doc.get("created_at") created_at = doc.get("created_at")
ts = int(created_at.timestamp()) if created_at else 0 ts = int(created_at.timestamp()) if created_at else 0
uid = uuid4().hex[:8] thumb_name = f"{type_}/thumbs/{ts}_{asset_id}_{name}_thumb.jpg"
thumb_name = f"{type_}/thumbs/{ts}_{uid}_{asset_id}_{name}_thumb.jpg"
if await self.s3.upload_file(thumb_name, thumb): if await self.s3.upload_file(thumb_name, thumb):
await self.collection.update_one( await self.collection.update_one(
{"_id": asset_id}, {"_id": asset_id},

View File

@@ -15,24 +15,26 @@ class CharacterRepo:
character.id = str(op.inserted_id) character.id = str(op.inserted_id)
return character return character
async def get_character(self, character_id: str) -> Character | None: async def get_character(self, character_id: str, with_image_data: bool = False) -> Character | None:
res = await self.collection.find_one({"_id": ObjectId(character_id)}) args = {}
if not with_image_data:
args["character_image_data"] = 0
res = await self.collection.find_one({"_id": ObjectId(character_id)}, args)
if res is None: if res is None:
return None return None
else: else:
res["id"] = str(res.pop("_id")) res["id"] = str(res.pop("_id"))
return Character(**res) return Character(**res)
async def get_all_characters(self, created_by: Optional[str] = None, project_id: Optional[str] = None, limit: int = 100, offset: int = 0) -> List[Character]: async def get_all_characters(self, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Character]:
filter = {} filter = {}
if created_by: if created_by:
filter["created_by"] = created_by filter["created_by"] = created_by
if project_id is None:
filter["project_id"] = None
if project_id: if project_id:
filter["project_id"] = project_id filter["project_id"] = project_id
res = await self.collection.find(filter).skip(offset).limit(limit).to_list(None) args = {"character_image_data": 0} # don't return image data for list
res = await self.collection.find(filter, args).to_list(None)
chars = [] chars = []
for doc in res: for doc in res:
doc["id"] = str(doc.pop("_id")) doc["id"] = str(doc.pop("_id"))

View File

@@ -6,8 +6,6 @@ from repos.generation_repo import GenerationRepo
from repos.user_repo import UsersRepo from repos.user_repo import UsersRepo
from repos.albums_repo import AlbumsRepo from repos.albums_repo import AlbumsRepo
from repos.project_repo import ProjectRepo from repos.project_repo import ProjectRepo
from repos.idea_repo import IdeaRepo
from repos.post_repo import PostRepo
from typing import Optional from typing import Optional
@@ -21,5 +19,3 @@ class DAO:
self.albums = AlbumsRepo(client, db_name) self.albums = AlbumsRepo(client, db_name)
self.projects = ProjectRepo(client, db_name) self.projects = ProjectRepo(client, db_name)
self.users = UsersRepo(client, db_name) self.users = UsersRepo(client, db_name)
self.ideas = IdeaRepo(client, db_name)
self.posts = PostRepo(client, db_name)

View File

@@ -1,5 +1,4 @@
from typing import Any, Optional, List from typing import Optional, List
from datetime import datetime, timedelta, UTC
from PIL.ImageChops import offset from PIL.ImageChops import offset
from bson import ObjectId from bson import ObjectId
@@ -17,7 +16,7 @@ class GenerationRepo:
res = await self.collection.insert_one(generation.model_dump()) res = await self.collection.insert_one(generation.model_dump())
return str(res.inserted_id) return str(res.inserted_id)
async def get_generation(self, generation_id: str) -> Generation | None: async def get_generation(self, generation_id: str) -> Optional[Generation]:
res = await self.collection.find_one({"_id": ObjectId(generation_id)}) res = await self.collection.find_one({"_id": ObjectId(generation_id)})
if res is None: if res is None:
return None return None
@@ -26,29 +25,20 @@ class GenerationRepo:
return Generation(**res) return Generation(**res)
async def get_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None, async def get_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
limit: int = 10, offset: int = 0, created_by: Optional[str] = None, project_id: Optional[str] = None, idea_id: Optional[str] = None) -> List[Generation]: limit: int = 10, offset: int = 10, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
filter: dict[str, Any] = {"is_deleted": False} filter = {"is_deleted": False}
if character_id is not None: if character_id is not None:
filter["linked_character_id"] = character_id filter["linked_character_id"] = character_id
if status is not None: if status is not None:
filter["status"] = status filter["status"] = status
if created_by is not None: if created_by is not None:
filter["created_by"] = created_by filter["created_by"] = created_by
# If filtering by created_by user (e.g. "My Generations"), we typically imply personal scope if project_id is None. filter["project_id"] = None
# But if project_id is passed, we filter by that.
if project_id is None:
filter["project_id"] = None
if project_id is not None: if project_id is not None:
filter["project_id"] = project_id filter["project_id"] = project_id
if idea_id is not None:
filter["idea_id"] = idea_id
# If fetching for an idea, sort by created_at ascending (cronological) res = await self.collection.find(filter).sort("created_at", -1).skip(
# Otherwise typically descending (newest first)
sort_order = 1 if idea_id else -1
res = await self.collection.find(filter).sort("created_at", sort_order).skip(
offset).limit(limit).to_list(None) offset).limit(limit).to_list(None)
generations: List[Generation] = [] generations: List[Generation] = []
for generation in res: for generation in res:
@@ -57,7 +47,7 @@ class GenerationRepo:
return generations return generations
async def count_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None, async def count_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
album_id: Optional[str] = None, created_by: Optional[str] = None, project_id: Optional[str] = None, idea_id: Optional[str] = None) -> int: album_id: Optional[str] = None, created_by: Optional[str] = None, project_id: Optional[str] = None) -> int:
args = {} args = {}
if character_id is not None: if character_id is not None:
args["linked_character_id"] = character_id args["linked_character_id"] = character_id
@@ -65,14 +55,8 @@ class GenerationRepo:
args["status"] = status args["status"] = status
if created_by is not None: if created_by is not None:
args["created_by"] = created_by args["created_by"] = created_by
if project_id is None:
args["project_id"] = None
if project_id is not None: if project_id is not None:
args["project_id"] = project_id args["project_id"] = project_id
if idea_id is not None:
args["idea_id"] = idea_id
if album_id is not None:
args["album_id"] = album_id
return await self.collection.count_documents(args) return await self.collection.count_documents(args)
async def get_generations_by_ids(self, generation_ids: List[str]) -> List[Generation]: async def get_generations_by_ids(self, generation_ids: List[str]) -> List[Generation]:
@@ -93,177 +77,3 @@ class GenerationRepo:
async def update_generation(self, generation: Generation, ): async def update_generation(self, generation: Generation, ):
res = await self.collection.update_one({"_id": ObjectId(generation.id)}, {"$set": generation.model_dump()}) res = await self.collection.update_one({"_id": ObjectId(generation.id)}, {"$set": generation.model_dump()})
async def get_usage_stats(self, created_by: Optional[str] = None, project_id: Optional[str] = None) -> dict:
"""
Calculates usage statistics (runs, tokens, cost) using MongoDB aggregation.
"""
pipeline = []
# 1. Match active done generations
match_stage = {"is_deleted": False, "status": GenerationStatus.DONE}
if created_by:
match_stage["created_by"] = created_by
if project_id:
match_stage["project_id"] = project_id
pipeline.append({"$match": match_stage})
# 2. Group by null (total)
pipeline.append({
"$group": {
"_id": None,
"total_runs": {"$sum": 1},
"total_tokens": {
"$sum": {
"$cond": [
{"$and": [{"$gt": ["$input_token_usage", 0]}, {"$gt": ["$output_token_usage", 0]}]},
{"$add": ["$input_token_usage", "$output_token_usage"]},
{"$ifNull": ["$token_usage", 0]}
]
}
},
"total_input_tokens": {"$sum": {"$ifNull": ["$input_token_usage", 0]}},
"total_output_tokens": {"$sum": {"$ifNull": ["$output_token_usage", 0]}},
"total_cost": {
"$sum": {
"$add": [
{"$multiply": [{"$ifNull": ["$input_token_usage", 0]}, 0.000002]},
{"$multiply": [{"$ifNull": ["$output_token_usage", 0]}, 0.00012]}
]
}
}
}
})
cursor = self.collection.aggregate(pipeline)
res = await cursor.to_list(1)
if not res:
return {
"total_runs": 0,
"total_tokens": 0,
"total_input_tokens": 0,
"total_output_tokens": 0,
"total_cost": 0.0
}
result = res[0]
result.pop("_id")
result["total_cost"] = round(result["total_cost"], 4)
return result
async def get_usage_breakdown(self, group_by: str = "created_by", project_id: Optional[str] = None, created_by: Optional[str] = None) -> List[dict]:
"""
Returns usage statistics grouped by user or project.
"""
pipeline = []
match_stage = {"is_deleted": False, "status": GenerationStatus.DONE}
if project_id:
match_stage["project_id"] = project_id
if created_by:
match_stage["created_by"] = created_by
pipeline.append({"$match": match_stage})
pipeline.append({
"$group": {
"_id": f"${group_by}",
"total_runs": {"$sum": 1},
"total_tokens": {
"$sum": {
"$cond": [
{"$and": [{"$gt": ["$input_token_usage", 0]}, {"$gt": ["$output_token_usage", 0]}]},
{"$add": ["$input_token_usage", "$output_token_usage"]},
{"$ifNull": ["$token_usage", 0]}
]
}
},
"total_input_tokens": {"$sum": {"$ifNull": ["$input_token_usage", 0]}},
"total_output_tokens": {"$sum": {"$ifNull": ["$output_token_usage", 0]}},
"total_cost": {
"$sum": {
"$add": [
{"$multiply": [{"$ifNull": ["$input_token_usage", 0]}, 0.000002]},
{"$multiply": [{"$ifNull": ["$output_token_usage", 0]}, 0.00012]}
]
}
}
}
})
pipeline.append({"$sort": {"total_cost": -1}})
cursor = self.collection.aggregate(pipeline)
res = await cursor.to_list(None)
results = []
for item in res:
entity_id = item.pop("_id")
item["total_cost"] = round(item["total_cost"], 4)
results.append({
"entity_id": str(entity_id) if entity_id else "unknown",
"stats": item
})
return results
async def get_generations_by_group(self, group_id: str) -> List[Generation]:
res = await self.collection.find({"generation_group_id": group_id, "is_deleted": False}).sort("created_at", 1).to_list(None)
generations: List[Generation] = []
for generation in res:
generation["id"] = str(generation.pop("_id"))
generations.append(Generation(**generation))
return generations
async def cancel_stale_generations(self, timeout_minutes: int = 5) -> int:
cutoff_time = datetime.now(UTC) - timedelta(minutes=timeout_minutes)
res = await self.collection.update_many(
{
"status": GenerationStatus.RUNNING,
"created_at": {"$lt": cutoff_time}
},
{
"$set": {
"status": GenerationStatus.FAILED,
"failed_reason": "Timeout: Execution time limit exceeded",
"updated_at": datetime.now(UTC)
}
}
)
return res.modified_count
async def soft_delete_old_generations(self, days: int = 2) -> tuple[int, List[str]]:
"""
Мягко удаляет генерации старше N дней.
Возвращает (количество удалённых, список asset IDs для очистки).
"""
cutoff_time = datetime.now(UTC) - timedelta(days=days)
filter_query = {
"is_deleted": False,
"status": {"$in": [GenerationStatus.DONE, GenerationStatus.FAILED]},
"created_at": {"$lt": cutoff_time}
}
# Сначала собираем asset IDs из удаляемых генераций
asset_ids: List[str] = []
cursor = self.collection.find(filter_query, {"result_list": 1, "assets_list": 1})
async for doc in cursor:
asset_ids.extend(doc.get("result_list", []))
asset_ids.extend(doc.get("assets_list", []))
# Мягкое удаление
res = await self.collection.update_many(
filter_query,
{
"$set": {
"is_deleted": True,
"updated_at": datetime.now(UTC)
}
}
)
# Убираем дубликаты
unique_asset_ids = list(set(asset_ids))
return res.modified_count, unique_asset_ids

View File

@@ -1,91 +0,0 @@
from typing import Optional, List
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient
from models.Idea import Idea
class IdeaRepo:
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
self.collection = client[db_name]["ideas"]
async def create_idea(self, idea: Idea) -> str:
res = await self.collection.insert_one(idea.model_dump())
return str(res.inserted_id)
async def get_idea(self, idea_id: str) -> Optional[Idea]:
if not ObjectId.is_valid(idea_id):
return None
res = await self.collection.find_one({"_id": ObjectId(idea_id)})
if res:
res["id"] = str(res.pop("_id"))
return Idea(**res)
return None
async def get_ideas(self, project_id: Optional[str], user_id: str, limit: int = 20, offset: int = 0) -> List[dict]:
if project_id:
match_stage = {"project_id": project_id, "is_deleted": False}
else:
match_stage = {"created_by": user_id, "project_id": None, "is_deleted": False}
pipeline = [
{"$match": match_stage},
{"$sort": {"updated_at": -1}},
{"$skip": offset},
{"$limit": limit},
# Add string id field for lookup
{"$addFields": {"str_id": {"$toString": "$_id"}}},
# Lookup generations
{
"$lookup": {
"from": "generations",
"let": {"idea_id": "$str_id"},
"pipeline": [
{
"$match": {
"$and": [
{"$expr": {"$eq": ["$idea_id", "$$idea_id"]}},
{"status": "done"},
{"result_list": {"$exists": True, "$not": {"$size": 0}}},
{"is_deleted": False}
]
}
},
{"$sort": {"created_at": -1}}, # Ensure we get the latest successful
{"$limit": 1}
],
"as": "generations"
}
},
# Unwind generations array (preserve ideas without generations)
{"$unwind": {"path": "$generations", "preserveNullAndEmptyArrays": True}},
# Rename for clarity
{"$addFields": {
"last_generation": "$generations",
"id": "$str_id"
}},
{"$project": {"generations": 0, "str_id": 0, "_id": 0}}
]
return await self.collection.aggregate(pipeline).to_list(None)
async def delete_idea(self, idea_id: str) -> bool:
if not ObjectId.is_valid(idea_id):
return False
res = await self.collection.update_one(
{"_id": ObjectId(idea_id)},
{"$set": {"is_deleted": True}}
)
return res.modified_count > 0
async def update_idea(self, idea: Idea) -> bool:
if not idea.id or not ObjectId.is_valid(idea.id):
return False
idea_dict = idea.model_dump()
if "id" in idea_dict:
del idea_dict["id"]
res = await self.collection.update_one(
{"_id": ObjectId(idea.id)},
{"$set": idea_dict}
)
return res.modified_count > 0

View File

@@ -1,97 +0,0 @@
from typing import List, Optional
from datetime import datetime
import logging
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient
from models.Post import Post
logger = logging.getLogger(__name__)
class PostRepo:
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
self.collection = client[db_name]["posts"]
async def create_post(self, post: Post) -> str:
res = await self.collection.insert_one(post.model_dump())
return str(res.inserted_id)
async def get_post(self, post_id: str) -> Optional[Post]:
if not ObjectId.is_valid(post_id):
return None
res = await self.collection.find_one({"_id": ObjectId(post_id), "is_deleted": False})
if res:
res["id"] = str(res.pop("_id"))
return Post(**res)
return None
async def get_posts(
self,
project_id: Optional[str],
user_id: str,
limit: int = 20,
offset: int = 0,
date_from: Optional[datetime] = None,
date_to: Optional[datetime] = None,
) -> List[Post]:
if project_id:
match = {"project_id": project_id, "is_deleted": False}
else:
match = {"created_by": user_id, "project_id": None, "is_deleted": False}
if date_from or date_to:
date_filter = {}
if date_from:
date_filter["$gte"] = date_from
if date_to:
date_filter["$lte"] = date_to
match["date"] = date_filter
cursor = (
self.collection.find(match)
.sort("date", -1)
.skip(offset)
.limit(limit)
)
posts = []
async for doc in cursor:
doc["id"] = str(doc.pop("_id"))
posts.append(Post(**doc))
return posts
async def update_post(self, post_id: str, data: dict) -> bool:
if not ObjectId.is_valid(post_id):
return False
res = await self.collection.update_one(
{"_id": ObjectId(post_id)},
{"$set": data},
)
return res.modified_count > 0
async def delete_post(self, post_id: str) -> bool:
if not ObjectId.is_valid(post_id):
return False
res = await self.collection.update_one(
{"_id": ObjectId(post_id)},
{"$set": {"is_deleted": True}},
)
return res.modified_count > 0
async def add_generations(self, post_id: str, generation_ids: List[str]) -> bool:
if not ObjectId.is_valid(post_id):
return False
res = await self.collection.update_one(
{"_id": ObjectId(post_id)},
{"$addToSet": {"generation_ids": {"$each": generation_ids}}},
)
return res.modified_count > 0
async def remove_generation(self, post_id: str, generation_id: str) -> bool:
if not ObjectId.is_valid(post_id):
return False
res = await self.collection.update_one(
{"_id": ObjectId(post_id)},
{"$pull": {"generation_ids": generation_id}},
)
return res.modified_count > 0

View File

@@ -51,4 +51,4 @@ python-jose[cryptography]==3.3.0
python-multipart==0.0.22 python-multipart==0.0.22
email-validator email-validator
prometheus-fastapi-instrumentator prometheus-fastapi-instrumentator
pydantic-settings==2.13.0 PyJWT

View File

@@ -51,66 +51,57 @@ async def new_char_bio(message: Message, state: FSMContext, dao: DAO, bot: Bot):
wait_msg = await message.answer("💾 Сохраняю персонажа...") wait_msg = await message.answer("💾 Сохраняю персонажа...")
try: try:
# 1. Скачиваем файл (один раз) # ВОТ ТУТ скачиваем файл (прямо перед сохранением)
# TODO: Для больших файлов лучше использовать streaming или сохранять во временный файл
file_io = await bot.download(file_id) file_io = await bot.download(file_id)
file_bytes = file_io.read() # photo_bytes = file_io.getvalue() # Получаем байты
# 2. Создаем Character (сначала без ассета, чтобы получить ID)
# Создаем модель
char = Character( char = Character(
id=None, id=None,
name=name, name=name,
character_image_data=file_io.read(),
character_image_tg_id=None, character_image_tg_id=None,
character_image_doc_tg_id=file_id, character_image_doc_tg_id=file_id,
character_bio=bio, character_bio=bio,
created_by=str(message.from_user.id) created_by=str(message.from_user.id)
) )
file_io.close()
# Сохраняем, чтобы получить ID
# Сохраняем через DAO
await dao.chars.add_character(char) await dao.chars.add_character(char)
file_info = await bot.get_file(char.character_image_doc_tg_id)
# 3. Создаем Asset (связанный с персонажем) file_bytes = await bot.download_file(file_info.file_path)
avatar_asset_id = await dao.assets.create_asset( file_io = file_bytes.read()
Asset( avatar_asset = await dao.assets.create_asset(
name="avatar.png", Asset(name="avatar.png", type=AssetType.UPLOADED, content_type=AssetContentType.IMAGE, linked_char_id=str(char.id), data=file_io,
type=AssetType.UPLOADED, tg_doc_file_id=file_id))
content_type=AssetContentType.IMAGE, char.avatar_image = avatar_asset.link
linked_char_id=str(char.id),
data=file_bytes,
tg_doc_file_id=file_id
)
)
# 4. Обновляем персонажа ссылками на ассет
char.avatar_asset_id = avatar_asset_id
char.avatar_image = f"/api/assets/{avatar_asset_id}" # Формируем ссылку вручную или используем метод, если появится
# Отправляем подтверждение # Отправляем подтверждение
# Используем байты для отправки обратно
photo_msg = await message.answer_photo( photo_msg = await message.answer_photo(
photo=BufferedInputFile(file_bytes, filename="char.jpg"), photo=BufferedInputFile(file_io,
filename="char.jpg") if not char.character_image_tg_id else char.character_image_tg_id,
caption=( caption=(
"🎉 <b>Персонаж создан!</b>\n\n" "🎉 <b>Персонаж создан!</b>\n\n"
f"👤 <b>Имя:</b> {char.name}\n" f"👤 <b>Имя:</b> {char.name}\n"
f"📝 <b>Био:</b> {char.character_bio}" f"📝 <b>Био:</b> {char.character_bio}"
) )
) )
file_bytes.close()
# Сохраняем TG ID фото (которое отправили как фото, а не документ) char.character_image_tg_id = photo_msg.photo[0].file_id
char.character_image_tg_id = photo_msg.photo[-1].file_id
# Финальное обновление персонажа
await dao.chars.update_char(char.id, char) await dao.chars.update_char(char.id, char)
await wait_msg.delete() await wait_msg.delete()
file_io.close()
# Сбрасываем состояние # Сбрасываем состояние
await state.clear() await state.clear()
except Exception as e: except Exception as e:
logger.error(f"Error creating character: {e}") logging.error(e)
traceback.print_exc()
await wait_msg.edit_text(f"❌ Ошибка при сохранении: {e}") await wait_msg.edit_text(f"❌ Ошибка при сохранении: {e}")
# Не сбрасываем стейт, даем возможность попробовать ввести био снова или начать заново
@router.message(Command("chars")) @router.message(Command("chars"))

View File

@@ -3,17 +3,17 @@ import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from unittest.mock import MagicMock from unittest.mock import MagicMock
from motor.motor_asyncio import AsyncIOMotorClient from motor.motor_asyncio import AsyncIOMotorClient
import os
import asyncio import asyncio
from config import settings
from aiws import app from main import app
from api.endpoints.auth import get_current_user from api.endpoints.auth import get_current_user
from api.dependency import get_dao from api.dependency import get_dao
from repos.dao import DAO from repos.dao import DAO
from models.Character import Character from models.Character import Character
# Config for test DB # Config for test DB
MONGO_HOST = settings.MONGO_HOST MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://admin:super_secure_password@31.59.58.220:27017")
DB_NAME = "bot_db_test_chars" DB_NAME = "bot_db_test_chars"
# Mock User # Mock User

View File

@@ -10,13 +10,13 @@ import json
import requests import requests
import base64 import base64
import os import os
from config import settings from dotenv import load_dotenv
# Load env is not needed as settings handles it load_dotenv()
# Configuration # Configuration
API_URL = "http://localhost:8090/api/generations/import" API_URL = "http://localhost:8090/api/generations/import"
SECRET = settings.EXTERNAL_API_SECRET or "your_super_secret_key_change_this_in_production" SECRET = os.getenv("EXTERNAL_API_SECRET", "your_super_secret_key_change_this_in_production")
# Sample generation data # Sample generation data
generation_data = { generation_data = {

View File

@@ -1,96 +0,0 @@
import asyncio
import os
from dotenv import load_dotenv
from motor.motor_asyncio import AsyncIOMotorClient
from bson import ObjectId
# Import from project root (requires PYTHONPATH=.)
from api.service.idea_service import IdeaService
from repos.dao import DAO
from models.Idea import Idea
from models.Generation import Generation, GenerationStatus
from models.enums import AspectRatios, Quality
from config import settings
MONGO_HOST = settings.MONGO_HOST
DB_NAME = settings.DB_NAME
print(f"Connecting to MongoDB: {MONGO_HOST}, DB: {DB_NAME}")
async def test_idea_flow():
client = AsyncIOMotorClient(MONGO_HOST)
dao = DAO(client, db_name=DB_NAME)
service = IdeaService(dao)
# 1. Create an Idea
print("Creating idea...")
user_id = "test_user_123"
project_id = "test_project_abc"
idea = await service.create_idea("My Test Idea", "Initial Description", project_id, user_id)
print(f"Idea created: {idea.id} - {idea.name}")
# 2. Update Idea
print("Updating idea...")
updated_idea = await service.update_idea(idea.id, description="Updated description")
print(f"Idea updated: {updated_idea.description}")
if updated_idea.description == "Updated description":
print("✅ Idea update successful")
else:
print("❌ Idea update FAILED")
# 3. Add Generation linked to Idea
print("Creating generation linked to idea...")
gen = Generation(
prompt="idea generation 1",
# idea_id=idea.id, <-- Intentionally NOT linking initially to test linking method
project_id=project_id,
created_by=user_id,
aspect_ratio=AspectRatios.NINESIXTEEN,
quality=Quality.ONEK,
assets_list=[]
)
gen_id = await dao.generations.create_generation(gen)
print(f"Created generation: {gen_id}")
# Link generation to idea
print("Linking generation to idea...")
success = await service.add_generation_to_idea(idea.id, gen_id)
if success:
print("✅ Linking successful")
else:
print("❌ Linking FAILED")
# Debug: Check if generation was saved with idea_id
saved_gen = await dao.generations.collection.find_one({"_id": ObjectId(gen_id)})
print(f"DEBUG: Saved Generation in DB idea_id: {saved_gen.get('idea_id')}")
# 4. Fetch Generations for Idea (Verify filtering and ordering)
print("Fetching generations for idea...")
gens = await service.dao.generations.get_generations(idea_id=idea.id) # using repo directly as service might return wrapper
print(f"Found {len(gens)} generations in idea")
if len(gens) == 1 and gens[0].id == gen_id:
print("✅ Generation retrieval successful")
else:
print("❌ Generation retrieval FAILED")
# 5. Fetch Ideas for Project
ideas = await service.get_ideas(project_id)
print(f"Found {len(ideas)} ideas for project")
# Cleaning up
print("Cleaning up...")
await service.delete_idea(idea.id)
await dao.generations.collection.delete_one({"_id": ObjectId(gen_id)})
# Verify deletion
deleted_idea = await service.get_idea(idea.id)
# IdeaRepo.delete_idea logic sets is_deleted=True
if deleted_idea and deleted_idea.is_deleted:
print(f"✅ Idea deleted successfully")
# Hard delete for cleanup
await dao.ideas.collection.delete_one({"_id": ObjectId(idea.id)})
if __name__ == "__main__":
asyncio.run(test_idea_flow())

View File

@@ -1,14 +1,15 @@
import asyncio import asyncio
import os import os
from config import settings from dotenv import load_dotenv
from adapters.s3_adapter import S3Adapter from adapters.s3_adapter import S3Adapter
async def test_s3(): async def test_s3():
load_dotenv()
endpoint = settings.MINIO_ENDPOINT endpoint = os.getenv("MINIO_ENDPOINT", "http://localhost:9000")
access_key = settings.MINIO_ACCESS_KEY access_key = os.getenv("MINIO_ACCESS_KEY")
secret_key = settings.MINIO_SECRET_KEY secret_key = os.getenv("MINIO_SECRET_KEY")
bucket = settings.MINIO_BUCKET bucket = os.getenv("MINIO_BUCKET")
print(f"Connecting to {endpoint}, bucket: {bucket}") print(f"Connecting to {endpoint}, bucket: {bucket}")

View File

@@ -1,50 +0,0 @@
import asyncio
import os
from datetime import datetime, timedelta, UTC
from motor.motor_asyncio import AsyncIOMotorClient
from models.Generation import Generation, GenerationStatus
from repos.generation_repo import GenerationRepo
from config import settings
# Mock configs if not present in env
MONGO_HOST = settings.MONGO_HOST
DB_NAME = settings.DB_NAME
print(f"Connecting to MongoDB: {MONGO_HOST}, DB: {DB_NAME}")
async def test_scheduler():
client = AsyncIOMotorClient(MONGO_HOST)
repo = GenerationRepo(client, db_name=DB_NAME)
# 1. Create a "stale" generation (2 hours ago)
stale_gen = Generation(
prompt="stale test",
status=GenerationStatus.RUNNING,
created_at=datetime.now(UTC) - timedelta(minutes=120),
assets_list=[],
aspect_ratio="NINESIXTEEN",
quality="ONEK"
)
gen_id = await repo.create_generation(stale_gen)
print(f"Created stale generation: {gen_id}")
# 2. Run cleanup
print("Running cleanup...")
count = await repo.cancel_stale_generations(timeout_minutes=60)
print(f"Cleaned up {count} generations")
# 3. Verify status
updated_gen = await repo.get_generation(gen_id)
print(f"Generation status: {updated_gen.status}")
print(f"Failed reason: {updated_gen.failed_reason}")
if updated_gen.status == GenerationStatus.FAILED:
print("✅ SUCCESS: Generation marked as FAILED")
else:
print("❌ FAILURE: Generation status not updated")
# Cleanup
await repo.collection.delete_one({"_id": updated_gen.id}) # Remove test data
if __name__ == "__main__":
asyncio.run(test_scheduler())

View File

@@ -10,11 +10,10 @@ from repos.dao import DAO
from models.Album import Album from models.Album import Album
from models.Generation import Generation, GenerationStatus from models.Generation import Generation, GenerationStatus
from models.enums import AspectRatios, Quality from models.enums import AspectRatios, Quality
from config import settings
# Mock config # Mock config
# Use the same host as aiws.py but different DB # Use the same host as aiws.py but different DB
MONGO_HOST = settings.MONGO_HOST MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://admin:super_secure_password@31.59.58.220:27017")
DB_NAME = "bot_db_test_albums" DB_NAME = "bot_db_test_albums"
async def test_albums(): async def test_albums():
@@ -84,6 +83,8 @@ async def test_albums():
client.close() client.close()
if __name__ == "__main__": if __name__ == "__main__":
from dotenv import load_dotenv
load_dotenv()
try: try:
asyncio.run(test_albums()) asyncio.run(test_albums())
except Exception as e: except Exception as e:

View File

@@ -1,28 +1,29 @@
import asyncio import asyncio
import os import os
from datetime import datetime from datetime import datetime
from dotenv import load_dotenv
from motor.motor_asyncio import AsyncIOMotorClient from motor.motor_asyncio import AsyncIOMotorClient
from config import settings
from models.Asset import Asset, AssetType from models.Asset import Asset, AssetType
from repos.assets_repo import AssetsRepo from repos.assets_repo import AssetsRepo
from adapters.s3_adapter import S3Adapter from adapters.s3_adapter import S3Adapter
# Load env is not needed as settings handles it # Load env to get credentials
load_dotenv()
async def test_integration(): async def test_integration():
print("🚀 Starting integration test...") print("🚀 Starting integration test...")
# 1. Setup Dependencies # 1. Setup Dependencies
mongo_uri = settings.MONGO_HOST mongo_uri = os.getenv("MONGO_HOST", "mongodb://localhost:27017")
client = AsyncIOMotorClient(mongo_uri) client = AsyncIOMotorClient(mongo_uri)
db_name = settings.DB_NAME + "_test" db_name = os.getenv("DB_NAME", "bot_db_test")
s3_adapter = S3Adapter( s3_adapter = S3Adapter(
endpoint_url=settings.MINIO_ENDPOINT, endpoint_url=os.getenv("MINIO_ENDPOINT", "http://localhost:9000"),
aws_access_key_id=settings.MINIO_ACCESS_KEY, aws_access_key_id=os.getenv("MINIO_ACCESS_KEY", "admin"),
aws_secret_access_key=settings.MINIO_SECRET_KEY, aws_secret_access_key=os.getenv("MINIO_SECRET_KEY", "SuperSecretPassword123!"),
bucket_name=settings.MINIO_BUCKET bucket_name=os.getenv("MINIO_BUCKET", "ai-char")
) )
repo = AssetsRepo(client, s3_adapter, db_name=db_name) repo = AssetsRepo(client, s3_adapter, db_name=db_name)

View File

@@ -3,12 +3,12 @@ from typing import Optional, Union, Any
from jose import jwt from jose import jwt
from passlib.context import CryptContext from passlib.context import CryptContext
from config import settings
# Настройки безопасности берутся из config.py # Настройки безопасности (лучше вынести в config/env, но для старта здесь)
SECRET_KEY = settings.SECRET_KEY # SECRET_KEY должен быть сложным и секретным в продакшене!
ALGORITHM = settings.ALGORITHM SECRET_KEY = "CHANGE_ME_TO_A_SUPER_SECRET_KEY"
ACCESS_TOKEN_EXPIRE_MINUTES = settings.ACCESS_TOKEN_EXPIRE_MINUTES ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30 * 24 * 60 # 30 дней, например
pwd_context = CryptContext(schemes=["argon2"], deprecated="auto") pwd_context = CryptContext(schemes=["argon2"], deprecated="auto")