Compare commits
44 Commits
c25b029006
...
enviroment
| Author | SHA1 | Date | |
|---|---|---|---|
| 5aa6391dc8 | |||
| ffb0463fe0 | |||
| dd0f8a1cb6 | |||
| 4af5134726 | |||
| 7488665d04 | |||
| ecc88aca62 | |||
| 70f50170fc | |||
| f4207fc4c1 | |||
| c50d2c8ad9 | |||
| 4586daac38 | |||
| 198ac44960 | |||
| d820d9145b | |||
| c93e577bcf | |||
| c5d4849bff | |||
| 9abfbef871 | |||
| 68a3f529cb | |||
| e2c050515d | |||
| 5e7dc19bf3 | |||
| 97483b7030 | |||
| 2d3da59de9 | |||
| 279cb5c6f6 | |||
| 30138bab38 | |||
| 977cab92f8 | |||
| dcab238d3e | |||
| 9d2e4e47de | |||
| c6142715d9 | |||
| 456562ec1d | |||
| 0d0fbdf7d6 | |||
| f63bcedb13 | |||
| be92c766ac | |||
| 482bc1d9b7 | |||
| a2321cf070 | |||
| 29ccd5743e | |||
| d9de2f48d2 | |||
| 1ddeb0af46 | |||
| a7c2319f13 | |||
| 00e83b8561 | |||
| a9d24c725e | |||
| 458b6ebfc3 | |||
| 668aadcdc9 | |||
| 4461964791 | |||
| fa3e1bb05f | |||
| 8a89b27624 | |||
| c17c47ccc1 |
19
.dockerignore
Normal file
19
.dockerignore
Normal file
@@ -0,0 +1,19 @@
|
||||
.git
|
||||
.gitignore
|
||||
__pycache__
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.pyd
|
||||
.Python
|
||||
env/
|
||||
venv/
|
||||
.venv/
|
||||
node_modules/
|
||||
tmp/
|
||||
logs/
|
||||
*.log
|
||||
dist/
|
||||
build/
|
||||
.cache/
|
||||
.idea/
|
||||
.vscode/
|
||||
3
.env
3
.env
@@ -7,4 +7,5 @@ MINIO_ENDPOINT=http://31.59.58.220:9000
|
||||
MINIO_ACCESS_KEY=admin
|
||||
MINIO_SECRET_KEY=SuperSecretPassword123!
|
||||
MINIO_BUCKET=ai-char
|
||||
MODE=production
|
||||
MODE=production
|
||||
EXTERNAL_API_SECRET=Gt9TyQ8OAYhcELh2YCbKjdHLflZGufKHJZcG338MQDW
|
||||
29
.gitignore
vendored
29
.gitignore
vendored
@@ -1,15 +1,26 @@
|
||||
minio_backup.tar.gz
|
||||
.DS_Store
|
||||
.DS_Store
|
||||
.DS_Store
|
||||
**/__pycache__/
|
||||
# Игнорируем файлы скомпилированного байт-кода напрямую
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# Игнорируем расширения CPython конкретно
|
||||
*.cpython-*.pyc
|
||||
|
||||
# Игнорируем файлы .DS_Store на всех уровнях
|
||||
**/.*.DS_Store
|
||||
**/.DS_Store
|
||||
**/.DS_Store
|
||||
.idea/ai-char-bot.iml
|
||||
.idea
|
||||
.venv
|
||||
.vscode
|
||||
.vscode/launch.json
|
||||
middlewares/__pycache__/
|
||||
middlewares/*.pyc
|
||||
api/__pycache__/
|
||||
api/*.pyc
|
||||
repos/__pycache__/
|
||||
repos/*.pyc
|
||||
adapters/__pycache__/
|
||||
adapters/*.pyc
|
||||
services/__pycache__/
|
||||
services/*.pyc
|
||||
utils/__pycache__/
|
||||
utils/*.pyc
|
||||
.vscode/launch.json
|
||||
repos/__pycache__/assets_repo.cpython-313.pyc
|
||||
|
||||
31
.vscode/launch.json
vendored
31
.vscode/launch.json
vendored
@@ -7,38 +7,15 @@
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"args": [
|
||||
"main:app",
|
||||
"aiws:app",
|
||||
"--reload",
|
||||
"--port",
|
||||
"8090"
|
||||
"8090",
|
||||
"--host",
|
||||
"0.0.0.0"
|
||||
],
|
||||
"jinja": true,
|
||||
"justMyCode": true
|
||||
},
|
||||
{
|
||||
"name": "Python: Current File",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "${file}",
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": true,
|
||||
"env": {
|
||||
"PYTHONPATH": "${workspaceFolder}"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Debug Tests: Current File",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "pytest",
|
||||
"args": [
|
||||
"${file}"
|
||||
],
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": true,
|
||||
"env": {
|
||||
"PYTHONPATH": "${workspaceFolder}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -11,4 +11,4 @@ RUN pip install --no-cache-dir -r requirements.txt
|
||||
COPY . .
|
||||
|
||||
# Запуск приложения (замени app.py на свой файл)
|
||||
CMD ["python", "main.py"]
|
||||
CMD ["uvicorn", "aiws:app", "--host", "0.0.0.0", "--port", "8090"]
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -23,28 +23,30 @@ class GoogleAdapter:
|
||||
self.TEXT_MODEL = "gemini-3-pro-preview"
|
||||
self.IMAGE_MODEL = "gemini-3-pro-image-preview"
|
||||
|
||||
def _prepare_contents(self, prompt: str, images_list: List[bytes] = None) -> list:
|
||||
"""Вспомогательный метод для подготовки контента (текст + картинки)"""
|
||||
contents = [prompt]
|
||||
def _prepare_contents(self, prompt: str, images_list: List[bytes] | None = None) -> tuple:
|
||||
"""Вспомогательный метод для подготовки контента (текст + картинки).
|
||||
Returns (contents, opened_images) — caller MUST close opened_images after use."""
|
||||
contents : list [Any]= [prompt]
|
||||
opened_images = []
|
||||
if images_list:
|
||||
logger.info(f"Preparing content with {len(images_list)} images")
|
||||
for img_bytes in images_list:
|
||||
try:
|
||||
# Gemini API требует PIL Image на входе
|
||||
image = Image.open(io.BytesIO(img_bytes))
|
||||
contents.append(image)
|
||||
opened_images.append(image)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing input image: {e}")
|
||||
else:
|
||||
logger.info("Preparing content with no images")
|
||||
return contents
|
||||
return contents, opened_images
|
||||
|
||||
def generate_text(self, prompt: str, images_list: List[bytes] = None) -> str:
|
||||
def generate_text(self, prompt: str, images_list: List[bytes] | None = None) -> str:
|
||||
"""
|
||||
Генерация текста (Чат или Vision).
|
||||
Возвращает строку с ответом.
|
||||
"""
|
||||
contents = self._prepare_contents(prompt, images_list)
|
||||
contents, opened_images = self._prepare_contents(prompt, images_list)
|
||||
logger.info(f"Generating text: {prompt}")
|
||||
try:
|
||||
response = self.client.models.generate_content(
|
||||
@@ -68,14 +70,17 @@ class GoogleAdapter:
|
||||
except Exception as e:
|
||||
logger.error(f"Gemini Text API Error: {e}")
|
||||
raise GoogleGenerationException(f"Gemini Text API Error: {e}")
|
||||
finally:
|
||||
for img in opened_images:
|
||||
img.close()
|
||||
|
||||
def generate_image(self, prompt: str, aspect_ratio: AspectRatios, quality: Quality, images_list: List[bytes] = None, ) -> Tuple[List[io.BytesIO], Dict[str, Any]]:
|
||||
def generate_image(self, prompt: str, aspect_ratio: AspectRatios, quality: Quality, images_list: List[bytes] | None = None, ) -> Tuple[List[io.BytesIO], Dict[str, Any]]:
|
||||
"""
|
||||
Генерация изображений (Text-to-Image или Image-to-Image).
|
||||
Возвращает список байтовых потоков (готовых к отправке).
|
||||
"""
|
||||
|
||||
contents = self._prepare_contents(prompt, images_list)
|
||||
contents, opened_images = self._prepare_contents(prompt, images_list)
|
||||
logger.info(f"Generating image. Prompt length: {len(prompt)}, Ratio: {aspect_ratio}, Quality: {quality}")
|
||||
|
||||
start_time = datetime.now()
|
||||
@@ -100,9 +105,21 @@ class GoogleAdapter:
|
||||
|
||||
if response.usage_metadata:
|
||||
token_usage = response.usage_metadata.total_token_count
|
||||
|
||||
if response.parts is None and response.candidates[0].finish_reason is not None:
|
||||
raise GoogleGenerationException(f"Generation blocked in cause of {response.candidates[0].finish_reason.value}")
|
||||
|
||||
# Check prompt-level block (e.g. PROHIBITED_CONTENT) — no candidates in this case
|
||||
if response.prompt_feedback and response.prompt_feedback.block_reason:
|
||||
raise GoogleGenerationException(
|
||||
f"Generation blocked at prompt level: {response.prompt_feedback.block_reason.value}"
|
||||
)
|
||||
|
||||
# Check candidate-level block
|
||||
if response.parts is None:
|
||||
response_reason = (
|
||||
response.candidates[0].finish_reason
|
||||
if response.candidates and len(response.candidates) > 0
|
||||
else "Unknown"
|
||||
)
|
||||
raise GoogleGenerationException(f"Generation blocked: {response_reason}")
|
||||
|
||||
generated_images = []
|
||||
|
||||
@@ -113,7 +130,9 @@ class GoogleAdapter:
|
||||
try:
|
||||
# 1. Берем сырые байты
|
||||
raw_data = part.inline_data.data
|
||||
byte_arr = io.BytesIO(raw_data)
|
||||
if raw_data is None:
|
||||
raise GoogleGenerationException("Generation returned no data")
|
||||
byte_arr : io.BytesIO = io.BytesIO(raw_data)
|
||||
|
||||
# 2. Нейминг (формально, для TG)
|
||||
timestamp = datetime.now().timestamp()
|
||||
@@ -147,4 +166,8 @@ class GoogleAdapter:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Gemini Image API Error: {e}")
|
||||
raise GoogleGenerationException(f"Gemini Image API Error: {e}")
|
||||
raise GoogleGenerationException(f"Gemini Image API Error: {e}")
|
||||
finally:
|
||||
for img in opened_images:
|
||||
img.close()
|
||||
del contents
|
||||
@@ -18,7 +18,7 @@ class S3Adapter:
|
||||
|
||||
@asynccontextmanager
|
||||
async def _get_client(self):
|
||||
async with self.session.client(
|
||||
async with self.session.client( # type: ignore[reportGeneralTypeIssues]
|
||||
"s3",
|
||||
endpoint_url=self.endpoint_url,
|
||||
aws_access_key_id=self.aws_access_key_id,
|
||||
@@ -56,6 +56,23 @@ class S3Adapter:
|
||||
print(f"Error downloading from S3: {e}")
|
||||
return None
|
||||
|
||||
async def stream_file(self, object_name: str, chunk_size: int = 65536):
|
||||
"""Streams a file from S3 yielding chunks. Memory-efficient for large files."""
|
||||
try:
|
||||
async with self._get_client() as client:
|
||||
response = await client.get_object(Bucket=self.bucket_name, Key=object_name)
|
||||
# aioboto3 Body is an aiohttp StreamReader wrapper
|
||||
body = response['Body']
|
||||
|
||||
while True:
|
||||
chunk = await body.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
yield chunk
|
||||
except ClientError as e:
|
||||
print(f"Error streaming from S3: {e}")
|
||||
return
|
||||
|
||||
async def delete_file(self, object_name: str):
|
||||
"""Deletes a file from S3."""
|
||||
try:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from aiogram import Bot, Dispatcher, Router, F
|
||||
@@ -9,15 +8,18 @@ from aiogram.enums import ParseMode
|
||||
from aiogram.filters import CommandStart, Command
|
||||
from aiogram.types import Message
|
||||
from aiogram.fsm.storage.mongo import MongoStorage
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import FastAPI
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from prometheus_client import Info
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
|
||||
# --- ИМПОРТЫ ПРОЕКТА ---
|
||||
from config import settings
|
||||
from adapters.google_adapter import GoogleAdapter
|
||||
from adapters.s3_adapter import S3Adapter
|
||||
from api.service.generation_service import GenerationService
|
||||
from api.service.album_service import AlbumService
|
||||
from middlewares.album import AlbumMiddleware
|
||||
from middlewares.auth import AuthMiddleware
|
||||
from middlewares.dao import DaoMiddleware
|
||||
@@ -38,17 +40,22 @@ from api.endpoints.character_router import router as api_char_router # Роут
|
||||
from api.endpoints.generation_router import router as api_gen_router
|
||||
from api.endpoints.auth import router as api_auth_router
|
||||
from api.endpoints.admin import router as api_admin_router
|
||||
from api.endpoints.album_router import router as api_album_router
|
||||
from api.endpoints.project_router import router as project_api_router
|
||||
from api.endpoints.idea_router import router as idea_api_router
|
||||
from api.endpoints.post_router import router as post_api_router
|
||||
from api.endpoints.environment_router import router as environment_api_router
|
||||
|
||||
load_dotenv()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# --- КОНФИГУРАЦИЯ ---
|
||||
BOT_TOKEN = os.getenv("BOT_TOKEN")
|
||||
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
|
||||
# Настройки теперь берутся из config.py
|
||||
BOT_TOKEN = settings.BOT_TOKEN
|
||||
GEMINI_API_KEY = settings.GEMINI_API_KEY
|
||||
|
||||
MONGO_HOST = os.getenv("MONGO_HOST") # Например: mongodb://localhost:27017
|
||||
DB_NAME = os.getenv("DB_NAME", "my_bot_db") # Имя базы данных
|
||||
ADMIN_ID = int(os.getenv("ADMIN_ID", 0))
|
||||
MONGO_HOST = settings.MONGO_HOST
|
||||
DB_NAME = settings.DB_NAME
|
||||
ADMIN_ID = settings.ADMIN_ID
|
||||
|
||||
|
||||
def setup_logging():
|
||||
@@ -58,6 +65,8 @@ def setup_logging():
|
||||
|
||||
|
||||
# --- ИНИЦИАЛИЗАЦИЯ ЗАВИСИМОСТЕЙ ---
|
||||
if BOT_TOKEN is None:
|
||||
raise ValueError("BOT_TOKEN is not set")
|
||||
bot = Bot(token=BOT_TOKEN, default=DefaultBotProperties(parse_mode=ParseMode.HTML))
|
||||
|
||||
# Клиент БД создаем глобально, чтобы он был доступен и боту (Storage), и API
|
||||
@@ -70,15 +79,20 @@ char_repo = CharacterRepo(mongo_client)
|
||||
|
||||
# S3 Adapter
|
||||
s3_adapter = S3Adapter(
|
||||
endpoint_url=os.getenv("MINIO_ENDPOINT", "http://31.59.58.220:9000"),
|
||||
aws_access_key_id=os.getenv("MINIO_ACCESS_KEY", "minioadmin"),
|
||||
aws_secret_access_key=os.getenv("MINIO_SECRET_KEY", "minioadmin"),
|
||||
bucket_name=os.getenv("MINIO_BUCKET", "ai-char")
|
||||
endpoint_url=settings.MINIO_ENDPOINT,
|
||||
aws_access_key_id=settings.MINIO_ACCESS_KEY,
|
||||
aws_secret_access_key=settings.MINIO_SECRET_KEY,
|
||||
bucket_name=settings.MINIO_BUCKET
|
||||
)
|
||||
|
||||
dao = DAO(mongo_client, s3_adapter) # Главный DAO для бота
|
||||
if GEMINI_API_KEY is None:
|
||||
raise ValueError("GEMINI_API_KEY is not set")
|
||||
gemini = GoogleAdapter(api_key=GEMINI_API_KEY)
|
||||
generation_service = GenerationService(dao, gemini, bot)
|
||||
if bot is None:
|
||||
raise ValueError("bot is not set")
|
||||
generation_service = GenerationService(dao=dao, gemini=gemini, s3_adapter=s3_adapter, bot=bot)
|
||||
album_service = AlbumService(dao)
|
||||
|
||||
# Dispatcher
|
||||
dp = Dispatcher(storage=MongoStorage(mongo_client, db_name=DB_NAME))
|
||||
@@ -114,6 +128,18 @@ assets_router.message.middleware(AuthMiddleware(repo=users_repo, admin_id=ADMIN_
|
||||
gen_router.message.middleware(AlbumMiddleware(latency=0.8))
|
||||
|
||||
|
||||
async def start_scheduler(service: GenerationService):
|
||||
while True:
|
||||
try:
|
||||
logger.info("Running scheduler for stacked generation killing")
|
||||
await service.cleanup_stale_generations()
|
||||
await service.cleanup_old_data(days=2)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Scheduler error: {e}")
|
||||
await asyncio.sleep(60) # Check every 60 seconds
|
||||
|
||||
# --- LIFESPAN (Запуск FastAPI + Bot) ---
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
@@ -132,6 +158,7 @@ async def lifespan(app: FastAPI):
|
||||
app.state.gemini_client = gemini
|
||||
app.state.bot = bot
|
||||
app.state.s3_adapter = s3_adapter
|
||||
app.state.album_service = album_service
|
||||
app.state.users_repo = users_repo # Добавляем репозиторий в state
|
||||
|
||||
print("✅ DB & DAO initialized")
|
||||
@@ -139,22 +166,33 @@ async def lifespan(app: FastAPI):
|
||||
# 2. ЗАПУСК БОТА (в фоне)
|
||||
# Важно: handle_signals=False, чтобы бот не перехватывал сигналы остановки у uvicorn
|
||||
# Мы НЕ передаем сюда dao=..., так как он уже подключен через Middleware выше
|
||||
polling_task = asyncio.create_task(
|
||||
dp.start_polling(bot, handle_signals=False)
|
||||
)
|
||||
print("🤖 Bot polling started")
|
||||
# polling_task = asyncio.create_task(
|
||||
# dp.start_polling(bot, handle_signals=False)
|
||||
# )
|
||||
# print("🤖 Bot polling started")
|
||||
|
||||
# 3. ЗАПУСК ШЕДУЛЕРА
|
||||
scheduler_task = asyncio.create_task(start_scheduler(generation_service))
|
||||
print("⏰ Scheduler started")
|
||||
|
||||
yield
|
||||
|
||||
# --- SHUTDOWN ---
|
||||
print("🛑 Shutting down...")
|
||||
|
||||
# 3. Остановка бота
|
||||
polling_task.cancel()
|
||||
|
||||
# 4. Остановка шедулера
|
||||
scheduler_task.cancel()
|
||||
try:
|
||||
await polling_task
|
||||
await scheduler_task
|
||||
except asyncio.CancelledError:
|
||||
print("🤖 Bot polling stopped")
|
||||
print("⏰ Scheduler stopped")
|
||||
|
||||
# 3. Остановка бота
|
||||
# polling_task.cancel()
|
||||
# try:
|
||||
# await polling_task
|
||||
# except asyncio.CancelledError:
|
||||
# print("🤖 Bot polling stopped")
|
||||
|
||||
# 4. Отключение БД
|
||||
# Обычно Motor закрывать не обязательно при выходе, но хорошим тоном считается
|
||||
@@ -173,16 +211,29 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Подключаем роутер API
|
||||
from api.endpoints.auth import router as auth_api_router
|
||||
from api.endpoints.admin import router as admin_api_router
|
||||
app.include_router(auth_api_router)
|
||||
app.include_router(admin_api_router)
|
||||
# Подключаем роутеры API
|
||||
app.include_router(api_auth_router)
|
||||
app.include_router(api_admin_router)
|
||||
app.include_router(api_assets_router)
|
||||
app.include_router(api_char_router)
|
||||
app.include_router(api_gen_router)
|
||||
app.include_router(api_admin_router)
|
||||
app.include_router(api_auth_router)
|
||||
app.include_router(api_album_router)
|
||||
app.include_router(project_api_router)
|
||||
app.include_router(idea_api_router)
|
||||
app.include_router(post_api_router)
|
||||
app.include_router(environment_api_router)
|
||||
|
||||
# Prometheus Metrics (Instrument after all routers are added)
|
||||
Instrumentator(
|
||||
should_group_status_codes=False, # 200/201/204 отдельно (по желанию)
|
||||
should_ignore_untemplated=False, # НЕ игнорировать "сырые" пути
|
||||
# should_group_untemplated=False, # (опционально) не схлопывать untemplated в "none"
|
||||
).instrument(
|
||||
app,
|
||||
metric_namespace="ai_bot",
|
||||
).expose(app, endpoint="/metrics", include_in_schema=False)
|
||||
app_info = Info("fastapi_app_info", "FastAPI application info")
|
||||
app_info.info({"app_name": "ai-bot"})
|
||||
|
||||
|
||||
# --- ХЕНДЛЕРЫ БОТА (Main Router) ---
|
||||
@@ -209,7 +260,7 @@ if __name__ == "__main__":
|
||||
async def main():
|
||||
# Создаем конфигурацию uvicorn вручную
|
||||
# loop="asyncio" заставляет использовать стандартный цикл
|
||||
config = uvicorn.Config(app, host="0.0.0.0", port=8090, loop="asyncio", timeout_keep_alive=120, env_file=".env.development")
|
||||
config = uvicorn.Config(app, host="0.0.0.0", port=8090, loop="asyncio", timeout_keep_alive=120)
|
||||
server = uvicorn.Server(config)
|
||||
|
||||
# Запускаем сервер (lifespan запустится внутри)
|
||||
Binary file not shown.
@@ -5,6 +5,7 @@ from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from adapters.google_adapter import GoogleAdapter
|
||||
from api.service.generation_service import GenerationService
|
||||
from repos.dao import DAO
|
||||
from api.service.album_service import AlbumService
|
||||
|
||||
|
||||
# ... ваши импорты ...
|
||||
@@ -43,4 +44,22 @@ def get_generation_service(
|
||||
s3_adapter: S3Adapter = Depends(get_s3_adapter),
|
||||
bot: Bot = Depends(get_bot_client),
|
||||
) -> GenerationService:
|
||||
return GenerationService(dao, gemini, s3_adapter, bot)
|
||||
return GenerationService(dao, gemini, s3_adapter, bot)
|
||||
|
||||
from api.service.idea_service import IdeaService
|
||||
|
||||
def get_idea_service(dao: DAO = Depends(get_dao)) -> IdeaService:
|
||||
return IdeaService(dao)
|
||||
|
||||
from fastapi import Header
|
||||
|
||||
async def get_project_id(x_project_id: Optional[str] = Header(None, alias="X-Project-ID")) -> Optional[str]:
|
||||
return x_project_id
|
||||
|
||||
async def get_album_service(dao: DAO = Depends(get_dao)) -> AlbumService:
|
||||
return AlbumService(dao)
|
||||
|
||||
from api.service.post_service import PostService
|
||||
|
||||
def get_post_service(dao: DAO = Depends(get_dao)) -> PostService:
|
||||
return PostService(dao)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -5,6 +5,8 @@ from fastapi.security import OAuth2PasswordBearer
|
||||
from pydantic import BaseModel
|
||||
|
||||
from repos.user_repo import UsersRepo, UserStatus
|
||||
from api.dependency import get_dao
|
||||
from repos.dao import DAO
|
||||
from utils.security import verify_password, create_access_token, ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, SECRET_KEY
|
||||
from jose import JWTError, jwt
|
||||
from starlette.requests import Request
|
||||
@@ -23,7 +25,7 @@ async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)], repo:
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
username: str | None = payload.get("sub")
|
||||
if username is None:
|
||||
raise credentials_exception
|
||||
except JWTError:
|
||||
|
||||
84
api/endpoints/album_router.py
Normal file
84
api/endpoints/album_router.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, HTTPException, status, Request
|
||||
from pydantic import BaseModel
|
||||
|
||||
from api.models.GenerationRequest import GenerationResponse
|
||||
from models.Album import Album
|
||||
from repos.dao import DAO
|
||||
from api.dependency import get_album_service
|
||||
from api.service.album_service import AlbumService
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/albums", tags=["Albums"])
|
||||
|
||||
class AlbumCreateRequest(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
|
||||
class AlbumUpdateRequest(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
|
||||
class AlbumResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
generation_ids: List[str] = []
|
||||
cover_asset_id: Optional[str] = None # Not implemented yet
|
||||
|
||||
@router.post("", response_model=AlbumResponse)
|
||||
async def create_album(request: Request, album_in: AlbumCreateRequest):
|
||||
service: AlbumService = request.app.state.album_service
|
||||
album = await service.create_album(name=album_in.name, description=album_in.description)
|
||||
return AlbumResponse(**album.model_dump())
|
||||
|
||||
@router.get("", response_model=List[AlbumResponse])
|
||||
async def get_albums(request: Request, limit: int = 10, offset: int = 0):
|
||||
service: AlbumService = request.app.state.album_service
|
||||
albums = await service.get_albums(limit=limit, offset=offset)
|
||||
return [AlbumResponse(**album.model_dump()) for album in albums]
|
||||
|
||||
@router.get("/{album_id}", response_model=AlbumResponse)
|
||||
async def get_album(request: Request, album_id: str):
|
||||
service: AlbumService = request.app.state.album_service
|
||||
album = await service.get_album(album_id)
|
||||
if not album:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album not found")
|
||||
return AlbumResponse(**album.model_dump())
|
||||
|
||||
@router.put("/{album_id}", response_model=AlbumResponse)
|
||||
async def update_album(request: Request, album_id: str, album_in: AlbumUpdateRequest):
|
||||
service: AlbumService = request.app.state.album_service
|
||||
album = await service.update_album(album_id, name=album_in.name, description=album_in.description)
|
||||
if not album:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album not found")
|
||||
return AlbumResponse(**album.model_dump())
|
||||
|
||||
@router.delete("/{album_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_album(request: Request, album_id: str):
|
||||
service: AlbumService = request.app.state.album_service
|
||||
deleted = await service.delete_album(album_id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album not found")
|
||||
|
||||
@router.post("/{album_id}/generations/{generation_id}")
|
||||
async def add_generation_to_album(request: Request, album_id: str, generation_id: str):
|
||||
service: AlbumService = request.app.state.album_service
|
||||
success = await service.add_generation_to_album(album_id, generation_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album or Generation not found")
|
||||
return {"status": "success"}
|
||||
|
||||
@router.delete("/{album_id}/generations/{generation_id}")
|
||||
async def remove_generation_from_album(request: Request, album_id: str, generation_id: str):
|
||||
service: AlbumService = request.app.state.album_service
|
||||
success = await service.remove_generation_from_album(album_id, generation_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album or Generation not found")
|
||||
return {"status": "success"}
|
||||
|
||||
@router.get("/{album_id}/generations", response_model=List[GenerationResponse])
|
||||
async def get_album_generations(request: Request, album_id: str, limit: int = 10, offset: int = 0):
|
||||
service: AlbumService = request.app.state.album_service
|
||||
generations = await service.get_generations_by_album(album_id, limit=limit, offset=offset)
|
||||
return [GenerationResponse(**gen.model_dump()) for gen in generations]
|
||||
@@ -1,17 +1,21 @@
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
from aiogram.types import BufferedInputFile
|
||||
from bson import ObjectId
|
||||
from fastapi import APIRouter, UploadFile, File, Form, Depends
|
||||
from fastapi.openapi.models import MediaType
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from pymongo import MongoClient
|
||||
from starlette import status
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response, JSONResponse
|
||||
from starlette.responses import Response, JSONResponse, StreamingResponse
|
||||
|
||||
from api.models.AssetDTO import AssetsResponse, AssetResponse
|
||||
from adapters.s3_adapter import S3Adapter
|
||||
from api.models import AssetsResponse, AssetResponse
|
||||
from models.Asset import Asset, AssetType, AssetContentType
|
||||
from repos.dao import DAO
|
||||
from api.dependency import get_dao
|
||||
from api.dependency import get_dao, get_mongo_client, get_s3_adapter
|
||||
import asyncio
|
||||
|
||||
import logging
|
||||
@@ -19,6 +23,7 @@ import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from api.endpoints.auth import get_current_user
|
||||
from api.dependency import get_project_id
|
||||
|
||||
router = APIRouter(prefix="/api/assets", tags=["Assets"])
|
||||
|
||||
@@ -28,28 +33,160 @@ async def get_asset(
|
||||
asset_id: str,
|
||||
request: Request,
|
||||
thumbnail: bool = False,
|
||||
dao: DAO = Depends(get_dao)
|
||||
dao: DAO = Depends(get_dao),
|
||||
s3_adapter: S3Adapter = Depends(get_s3_adapter),
|
||||
) -> Response:
|
||||
logger.debug(f"get_asset called for ID: {asset_id}, thumbnail={thumbnail}")
|
||||
asset = await dao.assets.get_asset(asset_id)
|
||||
# 2. Проверка на существование
|
||||
# Загружаем только метаданные (без data/thumbnail bytes)
|
||||
asset = await dao.assets.get_asset(asset_id, with_data=False)
|
||||
if not asset:
|
||||
raise HTTPException(status_code=404, detail="Asset not found")
|
||||
|
||||
headers = {
|
||||
# Кэшировать на 1 год (31536000 сек)
|
||||
"Cache-Control": "public, max-age=31536000, immutable"
|
||||
}
|
||||
|
||||
content = asset.data
|
||||
media_type = "image/png" # Default, or detect
|
||||
# Thumbnail: маленький, можно грузить в RAM
|
||||
if thumbnail:
|
||||
if asset.minio_thumbnail_object_name and s3_adapter:
|
||||
thumb_bytes = await s3_adapter.get_file(asset.minio_thumbnail_object_name)
|
||||
if thumb_bytes:
|
||||
return Response(content=thumb_bytes, media_type="image/jpeg", headers=headers)
|
||||
# Fallback: thumbnail in DB
|
||||
if asset.thumbnail:
|
||||
return Response(content=asset.thumbnail, media_type="image/jpeg", headers=headers)
|
||||
# No thumbnail available — fall through to main content
|
||||
|
||||
if thumbnail and asset.thumbnail:
|
||||
content = asset.thumbnail
|
||||
media_type = "image/jpeg"
|
||||
|
||||
return Response(content=content, media_type=media_type, headers=headers)
|
||||
# Main content: стримим из S3 без загрузки в RAM
|
||||
if asset.minio_object_name and s3_adapter:
|
||||
content_type = "image/png"
|
||||
# if asset.content_type == AssetContentType.VIDEO:
|
||||
# content_type = "video/mp4"
|
||||
return StreamingResponse(
|
||||
s3_adapter.stream_file(asset.minio_object_name),
|
||||
media_type=content_type,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
# Fallback: data stored in DB (legacy)
|
||||
if asset.data:
|
||||
return Response(content=asset.data, media_type="image/png", headers=headers)
|
||||
|
||||
raise HTTPException(status_code=404, detail="Asset data not found")
|
||||
|
||||
@router.delete("/orphans", dependencies=[Depends(get_current_user)])
|
||||
async def delete_orphan_assets_from_minio(
|
||||
mongo: AsyncIOMotorClient = Depends(get_mongo_client),
|
||||
minio_client: S3Adapter = Depends(get_s3_adapter),
|
||||
*,
|
||||
assets_collection: str = "assets",
|
||||
generations_collection: str = "generations",
|
||||
asset_type: Optional[str] = "generated",
|
||||
project_id: Optional[str] = None,
|
||||
dry_run: bool = True,
|
||||
mark_assets_deleted: bool = False,
|
||||
batch_size: int = 500,
|
||||
) -> Dict[str, Any]:
|
||||
db = mongo['bot_db'] # БД уже выбрана в get_mongo_client
|
||||
assets = db[assets_collection]
|
||||
|
||||
match_assets: Dict[str, Any] = {}
|
||||
if asset_type is not None:
|
||||
match_assets["type"] = asset_type
|
||||
if project_id is not None:
|
||||
match_assets["project_id"] = project_id
|
||||
|
||||
pipeline: List[Dict[str, Any]] = [
|
||||
{"$match": match_assets} if match_assets else {"$match": {}},
|
||||
{
|
||||
"$lookup": {
|
||||
"from": generations_collection,
|
||||
"let": {"assetIdStr": {"$toString": "$_id"}},
|
||||
"pipeline": [
|
||||
# считаем "живыми" те, где is_deleted != True (т.е. false или поля нет)
|
||||
{"$match": {"is_deleted": {"$ne": True}}},
|
||||
{
|
||||
"$match": {
|
||||
"$expr": {
|
||||
"$in": [
|
||||
"$$assetIdStr",
|
||||
{"$ifNull": ["$result_list", []]},
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{"$limit": 1},
|
||||
],
|
||||
"as": "alive_generations",
|
||||
}
|
||||
},
|
||||
{
|
||||
"$match": {
|
||||
"$expr": {"$eq": [{"$size": "$alive_generations"}, 0]}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$project": {
|
||||
"_id": 1,
|
||||
"minio_object_name": 1,
|
||||
"minio_thumbnail_object_name": 1,
|
||||
}
|
||||
},
|
||||
]
|
||||
print(pipeline)
|
||||
cursor = assets.aggregate(pipeline, allowDiskUse=True, batchSize=batch_size)
|
||||
|
||||
deleted_objects = 0
|
||||
deleted_assets = 0
|
||||
errors: List[Dict[str, Any]] = []
|
||||
orphan_asset_ids: List[ObjectId] = []
|
||||
|
||||
async for asset in cursor:
|
||||
aid = asset["_id"]
|
||||
obj = asset.get("minio_object_name")
|
||||
thumb = asset.get("minio_thumbnail_object_name")
|
||||
|
||||
orphan_asset_ids.append(aid)
|
||||
|
||||
if dry_run:
|
||||
print(f"[DRY RUN] orphan asset={aid} obj={obj} thumb={thumb}")
|
||||
continue
|
||||
|
||||
try:
|
||||
if obj:
|
||||
await minio_client.delete_file(obj)
|
||||
deleted_objects += 1
|
||||
|
||||
if thumb:
|
||||
await minio_client.delete_file(thumb)
|
||||
deleted_objects += 1
|
||||
|
||||
deleted_assets += 1
|
||||
|
||||
except Exception as e:
|
||||
errors.append({"asset_id": str(aid), "error": str(e)})
|
||||
|
||||
if (not dry_run) and mark_assets_deleted and orphan_asset_ids:
|
||||
res = await assets.update_many(
|
||||
{"_id": {"$in": orphan_asset_ids}},
|
||||
{"$set": {"is_deleted": True}},
|
||||
)
|
||||
marked = res.modified_count
|
||||
else:
|
||||
marked = 0
|
||||
|
||||
return {
|
||||
"dry_run": dry_run,
|
||||
"filter": {
|
||||
"asset_type": asset_type,
|
||||
"project_id": project_id,
|
||||
},
|
||||
"orphans_found": len(orphan_asset_ids),
|
||||
"deleted_assets": deleted_assets,
|
||||
"deleted_objects": deleted_objects,
|
||||
"marked_assets_deleted": marked,
|
||||
"errors": errors,
|
||||
}
|
||||
|
||||
@router.delete("/{asset_id}", status_code=status.HTTP_204_NO_CONTENT, dependencies=[Depends(get_current_user)])
|
||||
async def delete_asset(
|
||||
@@ -68,11 +205,19 @@ async def delete_asset(
|
||||
|
||||
|
||||
@router.get("", dependencies=[Depends(get_current_user)])
|
||||
async def get_assets(request: Request, dao: DAO = Depends(get_dao), type: Optional[str] = None, limit: int = 10, offset: int = 0) -> AssetsResponse:
|
||||
async def get_assets(request: Request, dao: DAO = Depends(get_dao), type: Optional[str] = None, limit: int = 10, offset: int = 0, current_user: dict = Depends(get_current_user), project_id: Optional[str] = Depends(get_project_id)) -> AssetsResponse:
|
||||
logger.info(f"get_assets called. Limit: {limit}, Offset: {offset}")
|
||||
assets = await dao.assets.get_assets(type, limit, offset)
|
||||
|
||||
user_id_filter = current_user["id"]
|
||||
if project_id:
|
||||
project = await dao.projects.get_project(project_id)
|
||||
if not project or str(current_user["id"]) not in project.members:
|
||||
raise HTTPException(status_code=403, detail="Project access denied")
|
||||
user_id_filter = None
|
||||
|
||||
assets = await dao.assets.get_assets(type, limit, offset, created_by=user_id_filter, project_id=project_id)
|
||||
# assets = await dao.assets.get_assets() # This line seemed redundant/conflicting in original code
|
||||
total_count = await dao.assets.get_asset_count()
|
||||
total_count = await dao.assets.get_asset_count(created_by=user_id_filter, project_id=project_id)
|
||||
|
||||
# Manually map to DTO to trigger computed fields validation if necessary,
|
||||
# but primarily to ensure valid Pydantic models for the response list.
|
||||
@@ -84,11 +229,13 @@ async def get_assets(request: Request, dao: DAO = Depends(get_dao), type: Option
|
||||
|
||||
|
||||
|
||||
@router.post("/upload", response_model=AssetResponse, status_code=status.HTTP_201_CREATED, dependencies=[Depends(get_current_user)])
|
||||
@router.post("/upload", response_model=AssetResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def upload_asset(
|
||||
file: UploadFile = File(...),
|
||||
linked_char_id: Optional[str] = Form(None),
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
project_id: Optional[str] = Depends(get_project_id)
|
||||
):
|
||||
logger.info(f"upload_asset called. Filename: {file.filename}, ContentType: {file.content_type}, LinkedCharId: {linked_char_id}")
|
||||
if not file.content_type:
|
||||
@@ -96,6 +243,11 @@ async def upload_asset(
|
||||
|
||||
if not file.content_type.startswith("image/"):
|
||||
raise HTTPException(status_code=400, detail=f"Unsupported content type: {file.content_type}")
|
||||
|
||||
if project_id:
|
||||
project = await dao.projects.get_project(project_id)
|
||||
if not project or str(current_user["_id"]) not in project.members:
|
||||
raise HTTPException(status_code=403, detail="Project access denied")
|
||||
|
||||
data = await file.read()
|
||||
if not data:
|
||||
@@ -111,7 +263,9 @@ async def upload_asset(
|
||||
content_type=AssetContentType.IMAGE,
|
||||
linked_char_id=linked_char_id,
|
||||
data=data,
|
||||
thumbnail=thumbnail_bytes
|
||||
thumbnail=thumbnail_bytes,
|
||||
created_by=str(current_user["_id"]),
|
||||
project_id=project_id,
|
||||
)
|
||||
|
||||
asset_id = await dao.assets.create_asset(asset)
|
||||
@@ -124,8 +278,7 @@ async def upload_asset(
|
||||
type=asset.type.value if hasattr(asset.type, "value") else asset.type,
|
||||
content_type=asset.content_type.value if hasattr(asset.content_type, "value") else asset.content_type,
|
||||
linked_char_id=asset.linked_char_id,
|
||||
created_at=asset.created_at,
|
||||
url=asset.url
|
||||
created_at=asset.created_at
|
||||
)
|
||||
|
||||
|
||||
@@ -171,4 +324,5 @@ async def migrate_to_minio(dao: DAO = Depends(get_dao)):
|
||||
logger.info("Starting migration to MinIO")
|
||||
result = await dao.assets.migrate_to_minio()
|
||||
logger.info(f"Migration result: {result}")
|
||||
return result
|
||||
return result
|
||||
|
||||
|
||||
@@ -59,6 +59,7 @@ class Token(BaseModel):
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
id: str
|
||||
username: str
|
||||
full_name: str | None = None
|
||||
status: str
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
from typing import List, Any, Coroutine
|
||||
from typing import List, Any, Coroutine, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
|
||||
from api.models.AssetDTO import AssetsResponse, AssetResponse
|
||||
from api.models.GenerationRequest import GenerationRequest, GenerationResponse
|
||||
from api.models import AssetsResponse, AssetResponse
|
||||
from api.models import GenerationRequest, GenerationResponse
|
||||
from models.Asset import Asset
|
||||
from models.Character import Character
|
||||
from api.models import CharacterCreateRequest, CharacterUpdateRequest
|
||||
from repos.dao import DAO
|
||||
from api.dependency import get_dao
|
||||
|
||||
@@ -17,25 +18,61 @@ import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from api.endpoints.auth import get_current_user
|
||||
from api.dependency import get_project_id
|
||||
|
||||
router = APIRouter(prefix="/api/characters", tags=["Characters"], dependencies=[Depends(get_current_user)])
|
||||
|
||||
|
||||
@router.get("/", response_model=List[Character])
|
||||
async def get_characters(request: Request, dao: DAO = Depends(get_dao), ) -> List[Character]:
|
||||
logger.info("get_characters called")
|
||||
characters = await dao.chars.get_all_characters()
|
||||
async def get_characters(
|
||||
request: Request,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
limit: int = 100,
|
||||
offset: int = 0
|
||||
) -> List[Character]:
|
||||
logger.info(f"get_characters called. Limit: {limit}, Offset: {offset}")
|
||||
|
||||
user_id_filter = str(current_user["_id"])
|
||||
if project_id:
|
||||
project = await dao.projects.get_project(project_id)
|
||||
if not project or str(current_user["_id"]) not in project.members:
|
||||
raise HTTPException(status_code=403, detail="Project access denied")
|
||||
user_id_filter = None
|
||||
|
||||
characters = await dao.chars.get_all_characters(
|
||||
created_by=user_id_filter,
|
||||
project_id=project_id,
|
||||
limit=limit,
|
||||
offset=offset
|
||||
)
|
||||
return characters
|
||||
|
||||
|
||||
@router.get("/{character_id}/assets", response_model=AssetsResponse)
|
||||
async def get_character_assets(character_id: str, dao: DAO = Depends(get_dao), limit: int = 10,
|
||||
offset: int = 0, ) -> AssetsResponse:
|
||||
offset: int = 0, current_user: dict = Depends(get_current_user)) -> AssetsResponse:
|
||||
logger.info(f"get_character_assets called. CharacterID: {character_id}, Limit: {limit}, Offset: {offset}")
|
||||
character = await dao.chars.get_character(character_id)
|
||||
if character is None:
|
||||
raise HTTPException(status_code=404, detail="Character not found")
|
||||
|
||||
# Access Check
|
||||
is_creator = character.created_by == str(current_user["_id"])
|
||||
is_project_member = False
|
||||
if character.project_id and character.project_id in current_user.get("project_ids", []):
|
||||
is_project_member = True
|
||||
|
||||
if not is_creator and not is_project_member:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
assets = await dao.assets.get_assets_by_char_id(character_id, limit, offset)
|
||||
# Filter assets by user ownership as well?
|
||||
# Usually if you own character, you see its assets.
|
||||
# But assets also have specific created_by.
|
||||
# Let's assume if you own character you can see its assets.
|
||||
|
||||
total_count = await dao.assets.get_asset_count(character_id)
|
||||
|
||||
asset_responses = [AssetResponse.model_validate(a.model_dump()) for a in assets]
|
||||
@@ -43,14 +80,113 @@ async def get_character_assets(character_id: str, dao: DAO = Depends(get_dao), l
|
||||
|
||||
|
||||
@router.get("/{character_id}", response_model=Character)
|
||||
async def get_character_by_id(character_id: str, request: Request, dao: DAO = Depends(get_dao)) -> Character:
|
||||
async def get_character_by_id(character_id: str, request: Request, dao: DAO = Depends(get_dao), current_user: dict = Depends(get_current_user)) -> Character:
|
||||
logger.debug(f"get_character_by_id called. ID: {character_id}")
|
||||
character = await dao.chars.get_character(character_id)
|
||||
|
||||
if not character:
|
||||
raise HTTPException(status_code=404, detail="Character not found")
|
||||
|
||||
if character:
|
||||
is_creator = character.created_by == str(current_user["_id"])
|
||||
is_project_member = False
|
||||
if character.project_id and character.project_id in current_user.get("project_ids", []):
|
||||
is_project_member = True
|
||||
|
||||
if not is_creator and not is_project_member:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
return character
|
||||
|
||||
|
||||
@router.post("/{character_id}/_run", response_model=GenerationResponse)
|
||||
async def post_character_generation(character_id: str, generation: GenerationRequest,
|
||||
request: Request) -> GenerationResponse:
|
||||
logger.info(f"post_character_generation called. CharacterID: {character_id}")
|
||||
generation_service = request.app.state.generation_service
|
||||
@router.post("/", response_model=Character)
|
||||
async def create_character(
|
||||
char_req: CharacterCreateRequest,
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
) -> Character:
|
||||
logger.info("create_character called")
|
||||
char_req.project_id = project_id
|
||||
char_data = char_req.model_dump()
|
||||
char_data["created_by"] = str(current_user["_id"])
|
||||
if "id" not in char_data:
|
||||
char_data["id"] = None
|
||||
|
||||
if project_id:
|
||||
project = await dao.projects.get_project(project_id)
|
||||
if not project or str(current_user["_id"]) not in project.members:
|
||||
raise HTTPException(status_code=403, detail="Project access denied")
|
||||
|
||||
new_char = Character(**char_data)
|
||||
new_char.avatar_asset_id = new_char.avatar_image.split("/")[-1]
|
||||
created_char = await dao.chars.add_character(new_char)
|
||||
return created_char
|
||||
|
||||
|
||||
@router.put("/{character_id}", response_model=Character)
|
||||
async def update_character(
|
||||
character_id: str,
|
||||
char_update: CharacterUpdateRequest,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
) -> Character:
|
||||
logger.info(f"update_character called. ID: {character_id}")
|
||||
|
||||
existing_char = await dao.chars.get_character(character_id)
|
||||
if not existing_char:
|
||||
raise HTTPException(status_code=404, detail="Character not found")
|
||||
|
||||
is_creator = existing_char.created_by == str(current_user["_id"])
|
||||
is_project_member = False
|
||||
if existing_char.project_id and existing_char.project_id in current_user.get("project_ids", []):
|
||||
is_project_member = True
|
||||
|
||||
if not is_creator and not is_project_member:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
update_data = char_update.model_dump(exclude_unset=True)
|
||||
|
||||
if "project_id" in update_data and update_data["project_id"]:
|
||||
new_project_id = update_data["project_id"]
|
||||
project = await dao.projects.get_project(new_project_id)
|
||||
if not project or str(current_user["_id"]) not in project.members:
|
||||
raise HTTPException(status_code=403, detail="Target project access denied")
|
||||
|
||||
updated_char_data = existing_char.model_dump()
|
||||
updated_char_data.update(update_data)
|
||||
|
||||
updated_char = Character(**updated_char_data)
|
||||
|
||||
success = await dao.chars.update_char(character_id, updated_char)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to update character")
|
||||
|
||||
return updated_char
|
||||
|
||||
|
||||
@router.delete("/{character_id}", status_code=204)
|
||||
async def delete_character(
|
||||
character_id: str,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
logger.info(f"delete_character called. ID: {character_id}")
|
||||
|
||||
existing_char = await dao.chars.get_character(character_id)
|
||||
if not existing_char:
|
||||
raise HTTPException(status_code=404, detail="Character not found")
|
||||
|
||||
is_creator = existing_char.created_by == str(current_user["_id"])
|
||||
is_project_member = False
|
||||
if existing_char.project_id and existing_char.project_id in current_user.get("project_ids", []):
|
||||
is_project_member = True
|
||||
|
||||
if not is_creator and not is_project_member:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
success = await dao.chars.delete_character(character_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete character")
|
||||
|
||||
return
|
||||
|
||||
180
api/endpoints/environment_router.py
Normal file
180
api/endpoints/environment_router.py
Normal file
@@ -0,0 +1,180 @@
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from starlette import status
|
||||
|
||||
from api.dependency import get_dao
|
||||
from api.endpoints.auth import get_current_user
|
||||
from api.models.EnvironmentRequest import EnvironmentCreate, EnvironmentUpdate, AssetToEnvironment, AssetsToEnvironment
|
||||
from models.Environment import Environment
|
||||
from repos.dao import DAO
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/environments", tags=["Environments"], dependencies=[Depends(get_current_user)])
|
||||
|
||||
|
||||
async def check_character_access(character_id: str, current_user: dict, dao: DAO):
|
||||
character = await dao.chars.get_character(character_id)
|
||||
if not character:
|
||||
raise HTTPException(status_code=404, detail="Character not found")
|
||||
|
||||
is_creator = character.created_by == str(current_user["_id"])
|
||||
is_project_member = False
|
||||
if character.project_id and character.project_id in current_user.get("project_ids", []):
|
||||
is_project_member = True
|
||||
|
||||
if not is_creator and not is_project_member:
|
||||
raise HTTPException(status_code=403, detail="Access denied to character")
|
||||
return character
|
||||
|
||||
|
||||
@router.post("/", response_model=Environment)
|
||||
async def create_environment(
|
||||
env_req: EnvironmentCreate,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
logger.info(f"Creating environment '{env_req.name}' for character {env_req.character_id}")
|
||||
await check_character_access(env_req.character_id, current_user, dao)
|
||||
|
||||
# Verify assets exist if provided
|
||||
if env_req.asset_ids:
|
||||
for aid in env_req.asset_ids:
|
||||
asset = await dao.assets.get_asset(aid)
|
||||
if not asset:
|
||||
raise HTTPException(status_code=400, detail=f"Asset {aid} not found")
|
||||
|
||||
new_env = Environment(**env_req.model_dump())
|
||||
created_env = await dao.environments.create_env(new_env)
|
||||
return created_env
|
||||
|
||||
|
||||
@router.get("/character/{character_id}", response_model=List[Environment])
|
||||
async def get_character_environments(
|
||||
character_id: str,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
logger.info(f"Getting environments for character {character_id}")
|
||||
await check_character_access(character_id, current_user, dao)
|
||||
return await dao.environments.get_character_envs(character_id)
|
||||
|
||||
|
||||
@router.get("/{env_id}", response_model=Environment)
|
||||
async def get_environment(
|
||||
env_id: str,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
env = await dao.environments.get_env(env_id)
|
||||
if not env:
|
||||
raise HTTPException(status_code=404, detail="Environment not found")
|
||||
|
||||
await check_character_access(env.character_id, current_user, dao)
|
||||
return env
|
||||
|
||||
|
||||
@router.put("/{env_id}", response_model=Environment)
|
||||
async def update_environment(
|
||||
env_id: str,
|
||||
env_update: EnvironmentUpdate,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
env = await dao.environments.get_env(env_id)
|
||||
if not env:
|
||||
raise HTTPException(status_code=404, detail="Environment not found")
|
||||
|
||||
await check_character_access(env.character_id, current_user, dao)
|
||||
|
||||
update_data = env_update.model_dump(exclude_unset=True)
|
||||
if not update_data:
|
||||
return env
|
||||
|
||||
success = await dao.environments.update_env(env_id, update_data)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to update environment")
|
||||
|
||||
return await dao.environments.get_env(env_id)
|
||||
|
||||
|
||||
@router.delete("/{env_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_environment(
|
||||
env_id: str,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
env = await dao.environments.get_env(env_id)
|
||||
if not env:
|
||||
raise HTTPException(status_code=404, detail="Environment not found")
|
||||
|
||||
await check_character_access(env.character_id, current_user, dao)
|
||||
|
||||
success = await dao.environments.delete_env(env_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete environment")
|
||||
return None
|
||||
|
||||
|
||||
@router.post("/{env_id}/assets", status_code=status.HTTP_200_OK)
|
||||
async def add_asset_to_environment(
|
||||
env_id: str,
|
||||
req: AssetToEnvironment,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
env = await dao.environments.get_env(env_id)
|
||||
if not env:
|
||||
raise HTTPException(status_code=404, detail="Environment not found")
|
||||
|
||||
await check_character_access(env.character_id, current_user, dao)
|
||||
|
||||
# Verify asset exists
|
||||
asset = await dao.assets.get_asset(req.asset_id)
|
||||
if not asset:
|
||||
raise HTTPException(status_code=404, detail="Asset not found")
|
||||
|
||||
success = await dao.environments.add_asset(env_id, req.asset_id)
|
||||
return {"success": success}
|
||||
|
||||
|
||||
@router.post("/{env_id}/assets/batch", status_code=status.HTTP_200_OK)
|
||||
async def add_assets_batch_to_environment(
|
||||
env_id: str,
|
||||
req: AssetsToEnvironment,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
env = await dao.environments.get_env(env_id)
|
||||
if not env:
|
||||
raise HTTPException(status_code=404, detail="Environment not found")
|
||||
|
||||
await check_character_access(env.character_id, current_user, dao)
|
||||
|
||||
# Verify all assets exist
|
||||
assets = await dao.assets.get_assets_by_ids(req.asset_ids)
|
||||
if len(assets) != len(req.asset_ids):
|
||||
found_ids = {a.id for a in assets}
|
||||
missing_ids = [aid for aid in req.asset_ids if aid not in found_ids]
|
||||
raise HTTPException(status_code=404, detail=f"Some assets not found: {missing_ids}")
|
||||
|
||||
success = await dao.environments.add_assets(env_id, req.asset_ids)
|
||||
return {"success": success}
|
||||
|
||||
|
||||
@router.delete("/{env_id}/assets/{asset_id}", status_code=status.HTTP_200_OK)
|
||||
async def remove_asset_from_environment(
|
||||
env_id: str,
|
||||
asset_id: str,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
env = await dao.environments.get_env(env_id)
|
||||
if not env:
|
||||
raise HTTPException(status_code=404, detail="Environment not found")
|
||||
|
||||
await check_character_access(env.character_id, current_user, dao)
|
||||
|
||||
success = await dao.environments.remove_asset(env_id, asset_id)
|
||||
return {"success": success}
|
||||
@@ -1,31 +1,40 @@
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, UploadFile, File, Form
|
||||
from fastapi import APIRouter, UploadFile, File, Form, Header, HTTPException
|
||||
from fastapi.params import Depends
|
||||
from starlette import status
|
||||
from starlette.requests import Request
|
||||
|
||||
from api import service
|
||||
from api.dependency import get_generation_service
|
||||
|
||||
from api.models.GenerationRequest import GenerationResponse, GenerationRequest, GenerationsResponse, PromptResponse, PromptRequest
|
||||
from config import settings
|
||||
from api.dependency import get_generation_service, get_project_id, get_dao
|
||||
from api.endpoints.auth import get_current_user
|
||||
from api.models import (
|
||||
GenerationResponse,
|
||||
GenerationRequest,
|
||||
GenerationsResponse,
|
||||
PromptResponse,
|
||||
PromptRequest,
|
||||
GenerationGroupResponse,
|
||||
FinancialReport,
|
||||
ExternalGenerationRequest
|
||||
)
|
||||
from api.service.generation_service import GenerationService
|
||||
from models.Generation import Generation
|
||||
|
||||
from starlette import status
|
||||
|
||||
import logging
|
||||
from repos.dao import DAO
|
||||
from utils.external_auth import verify_signature
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from api.endpoints.auth import get_current_user
|
||||
|
||||
router = APIRouter(prefix='/api/generations', tags=["Generation"], dependencies=[Depends(get_current_user)])
|
||||
router = APIRouter(prefix='/api/generations', tags=["Generation"])
|
||||
|
||||
|
||||
@router.post("/prompt-assistant", response_model=PromptResponse)
|
||||
async def ask_prompt_assistant(prompt_request: PromptRequest, request: Request,
|
||||
generation_service: GenerationService = Depends(
|
||||
get_generation_service)) -> PromptResponse:
|
||||
get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)) -> PromptResponse:
|
||||
logger.info(f"ask_prompt_assistant called with prompt length: {len(prompt_request.prompt)}. Linked assets: {len(prompt_request.linked_assets) if prompt_request.linked_assets else 0}")
|
||||
generated_prompt = await generation_service.ask_prompt_assistant(prompt_request.prompt, prompt_request.linked_assets)
|
||||
return PromptResponse(prompt=generated_prompt)
|
||||
@@ -35,7 +44,8 @@ async def ask_prompt_assistant(prompt_request: PromptRequest, request: Request,
|
||||
async def prompt_from_image(
|
||||
prompt: Optional[str] = Form(None),
|
||||
images: List[UploadFile] = File(...),
|
||||
generation_service: GenerationService = Depends(get_generation_service)
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
) -> PromptResponse:
|
||||
logger.info(f"prompt_from_image called. Images count: {len(images)}. Prompt provided: {bool(prompt)}")
|
||||
images_bytes = []
|
||||
@@ -49,36 +59,177 @@ async def prompt_from_image(
|
||||
|
||||
@router.get("", response_model=GenerationsResponse)
|
||||
async def get_generations(character_id: Optional[str] = None, limit: int = 10, offset: int = 0,
|
||||
generation_service: GenerationService = Depends(get_generation_service)):
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
dao: DAO = Depends(get_dao)):
|
||||
logger.info(f"get_generations called. CharacterId: {character_id}, Limit: {limit}, Offset: {offset}")
|
||||
return await generation_service.get_generations(character_id, limit=limit, offset=offset)
|
||||
|
||||
user_id_filter = str(current_user["_id"])
|
||||
if project_id:
|
||||
project = await dao.projects.get_project(project_id)
|
||||
if not project or str(current_user["_id"]) not in project.members:
|
||||
raise HTTPException(status_code=403, detail="Project access denied")
|
||||
user_id_filter = None # Show all project generations
|
||||
|
||||
return await generation_service.get_generations(character_id, limit=limit, offset=offset, user_id=user_id_filter, project_id=project_id)
|
||||
|
||||
|
||||
@router.post("/_run", response_model=GenerationResponse)
|
||||
@router.get("/usage", response_model=FinancialReport)
|
||||
async def get_usage_report(
|
||||
breakdown: Optional[str] = None, # "user" or "project"
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
dao: DAO = Depends(get_dao)
|
||||
) -> FinancialReport:
|
||||
"""
|
||||
Returns usage statistics (runs, tokens, cost) for the current user or project.
|
||||
If project_id is provided, returns stats for that project.
|
||||
Otherwise, returns stats for the current user.
|
||||
"""
|
||||
user_id_filter = str(current_user["_id"])
|
||||
breakdown_by = None
|
||||
|
||||
if project_id:
|
||||
# Permission check
|
||||
project = await dao.projects.get_project(project_id)
|
||||
if not project or str(current_user["_id"]) not in project.members:
|
||||
raise HTTPException(status_code=403, detail="Project access denied")
|
||||
user_id_filter = None # If we are in project, we see stats for the WHOLE project by default
|
||||
if breakdown == "user":
|
||||
breakdown_by = "created_by"
|
||||
elif breakdown == "project":
|
||||
breakdown_by = "project_id"
|
||||
else:
|
||||
# Default: Stats for current user
|
||||
if breakdown == "project":
|
||||
breakdown_by = "project_id"
|
||||
elif breakdown == "user":
|
||||
# This would breakdown personal usage by user (yourself), but could be useful if it included collaborators?
|
||||
# No, if project_id is None, it's personal.
|
||||
breakdown_by = "created_by"
|
||||
|
||||
return await generation_service.get_financial_report(
|
||||
user_id=user_id_filter,
|
||||
project_id=project_id,
|
||||
breakdown_by=breakdown_by
|
||||
)
|
||||
|
||||
@router.post("/_run", response_model=GenerationGroupResponse)
|
||||
async def post_generation(generation: GenerationRequest, request: Request,
|
||||
generation_service: GenerationService = Depends(
|
||||
get_generation_service)) -> GenerationResponse:
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
dao: DAO = Depends(get_dao)) -> GenerationGroupResponse:
|
||||
logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}")
|
||||
return await generation_service.create_generation_task(generation)
|
||||
|
||||
if project_id:
|
||||
project = await dao.projects.get_project(project_id)
|
||||
if not project or str(current_user["_id"]) not in project.members:
|
||||
raise HTTPException(status_code=403, detail="Project access denied")
|
||||
generation.project_id = project_id
|
||||
|
||||
return await generation_service.create_generation_task(generation, user_id=str(current_user.get("_id")))
|
||||
|
||||
@router.get("/{generation_id}", response_model=GenerationResponse)
|
||||
async def get_generation(generation_id: str,
|
||||
generation_service: GenerationService = Depends(get_generation_service)) -> GenerationResponse:
|
||||
logger.debug(f"get_generation called for ID: {generation_id}")
|
||||
return await generation_service.get_generation(generation_id)
|
||||
|
||||
|
||||
@router.get("/running")
|
||||
async def get_running_generations(request: Request,
|
||||
generation_service: GenerationService = Depends(get_generation_service)):
|
||||
return await generation_service.get_running_generations()
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
dao: DAO = Depends(get_dao)):
|
||||
|
||||
user_id_filter = str(current_user["_id"])
|
||||
if project_id:
|
||||
project = await dao.projects.get_project(project_id)
|
||||
if not project or str(current_user["_id"]) not in project.members:
|
||||
raise HTTPException(status_code=403, detail="Project access denied")
|
||||
user_id_filter = None
|
||||
|
||||
return await generation_service.get_running_generations(user_id=user_id_filter, project_id=project_id)
|
||||
|
||||
|
||||
@router.delete("/{generation_id}", status_code=status.HTTP_204_NO_CONTENT, dependencies=[Depends(get_current_user)])
|
||||
async def delete_generation(generation_id: str, generation_service: GenerationService = Depends(get_generation_service)):
|
||||
@router.get("/group/{group_id}", response_model=GenerationGroupResponse)
|
||||
async def get_generation_group(group_id: str,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)):
|
||||
logger.info(f"get_generation_group called for group_id: {group_id}")
|
||||
generations = await generation_service.dao.generations.get_generations_by_group(group_id)
|
||||
gen_responses = [GenerationResponse(**gen.model_dump()) for gen in generations]
|
||||
return GenerationGroupResponse(generation_group_id=group_id, generations=gen_responses)
|
||||
|
||||
|
||||
@router.get("/{generation_id}", response_model=GenerationResponse)
|
||||
async def get_generation(generation_id: str,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)) -> GenerationResponse:
|
||||
logger.debug(f"get_generation called for ID: {generation_id}")
|
||||
gen = await generation_service.get_generation(generation_id)
|
||||
if gen and gen.created_by != str(current_user["_id"]):
|
||||
# Check project membership
|
||||
is_member = False
|
||||
if gen.project_id:
|
||||
project = await generation_service.dao.projects.get_project(gen.project_id)
|
||||
if project and str(current_user["_id"]) in project.members:
|
||||
is_member = True
|
||||
|
||||
if not is_member:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
return gen
|
||||
|
||||
|
||||
|
||||
|
||||
@router.post("/import", response_model=GenerationResponse)
|
||||
async def import_external_generation(
|
||||
request: Request,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
x_signature: str = Header(..., alias="X-Signature")
|
||||
) -> GenerationResponse:
|
||||
"""
|
||||
Import a generation from an external source.
|
||||
Requires server-to-server authentication via HMAC signature.
|
||||
"""
|
||||
|
||||
logger.info("import_external_generation called")
|
||||
# Get raw request body for signature verification
|
||||
body = await request.body()
|
||||
|
||||
# Verify signature
|
||||
secret = settings.EXTERNAL_API_SECRET
|
||||
if not secret:
|
||||
logger.error("EXTERNAL_API_SECRET not configured")
|
||||
raise HTTPException(status_code=500, detail="Server configuration error")
|
||||
|
||||
if not verify_signature(body, x_signature, secret):
|
||||
logger.warning("Invalid signature for external generation import")
|
||||
raise HTTPException(status_code=401, detail="Invalid signature")
|
||||
|
||||
# Parse request body
|
||||
try:
|
||||
data = json.loads(body.decode('utf-8'))
|
||||
external_gen = ExternalGenerationRequest(**data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse request body: {e}")
|
||||
raise HTTPException(status_code=400, detail=f"Invalid request body: {str(e)}")
|
||||
|
||||
# Import generation
|
||||
try:
|
||||
generation = await generation_service.import_external_generation(external_gen)
|
||||
return GenerationResponse(**generation.model_dump())
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to import external generation: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Import failed: {str(e)}")
|
||||
|
||||
|
||||
@router.delete("/{generation_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_generation(generation_id: str,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)):
|
||||
logger.info(f"delete_generation called for ID: {generation_id}")
|
||||
deleted = await generation_service.delete_generation(generation_id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Generation not found")
|
||||
return None
|
||||
return None
|
||||
|
||||
104
api/endpoints/idea_router.py
Normal file
104
api/endpoints/idea_router.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Body
|
||||
from api.dependency import get_idea_service, get_project_id, get_generation_service
|
||||
from api.endpoints.auth import get_current_user
|
||||
from api.service.idea_service import IdeaService
|
||||
from api.service.generation_service import GenerationService
|
||||
from models.Idea import Idea
|
||||
from api.models import GenerationResponse, GenerationsResponse
|
||||
from api.models import IdeaRequest, PostRequest # Adjusting for general model usage
|
||||
from api.models.IdeaRequest import IdeaCreateRequest, IdeaUpdateRequest, IdeaResponse
|
||||
|
||||
router = APIRouter(prefix="/api/ideas", tags=["ideas"])
|
||||
|
||||
@router.post("", response_model=Idea)
|
||||
async def create_idea(
|
||||
request: IdeaCreateRequest,
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
idea_service: IdeaService = Depends(get_idea_service)
|
||||
):
|
||||
pid = project_id or request.project_id
|
||||
|
||||
return await idea_service.create_idea(request.name, request.description, pid, str(current_user["_id"]))
|
||||
|
||||
@router.get("", response_model=List[IdeaResponse])
|
||||
async def get_ideas(
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
current_user: dict = Depends(get_current_user),
|
||||
idea_service: IdeaService = Depends(get_idea_service)
|
||||
):
|
||||
return await idea_service.get_ideas(project_id, str(current_user["_id"]), limit, offset)
|
||||
|
||||
@router.get("/{idea_id}", response_model=Idea)
|
||||
async def get_idea(
|
||||
idea_id: str,
|
||||
idea_service: IdeaService = Depends(get_idea_service)
|
||||
):
|
||||
idea = await idea_service.get_idea(idea_id)
|
||||
if not idea:
|
||||
raise HTTPException(status_code=404, detail="Idea not found")
|
||||
return idea
|
||||
|
||||
@router.put("/{idea_id}", response_model=Idea)
|
||||
async def update_idea(
|
||||
idea_id: str,
|
||||
request: IdeaUpdateRequest,
|
||||
idea_service: IdeaService = Depends(get_idea_service)
|
||||
):
|
||||
idea = await idea_service.update_idea(idea_id, request.name, request.description)
|
||||
if not idea:
|
||||
raise HTTPException(status_code=404, detail="Idea not found")
|
||||
return idea
|
||||
|
||||
@router.delete("/{idea_id}")
|
||||
async def delete_idea(
|
||||
idea_id: str,
|
||||
idea_service: IdeaService = Depends(get_idea_service)
|
||||
):
|
||||
success = await idea_service.delete_idea(idea_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Idea not found or could not be deleted")
|
||||
return {"status": "success"}
|
||||
|
||||
@router.get("/{idea_id}/generations", response_model=GenerationsResponse)
|
||||
async def get_idea_generations(
|
||||
idea_id: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
generation_service: GenerationService = Depends(get_generation_service)
|
||||
):
|
||||
# Depending on how generation service implements filtering by idea_id.
|
||||
# We might need to update generation_service to support getting by idea_id directly
|
||||
# or ensure generic get_generations supports it.
|
||||
# Looking at generation_router.py, get_generations doesn't have idea_id arg?
|
||||
# Let's check generation_service.get_generations signature again.
|
||||
# It has: (character_id, limit, offset, user_id, project_id). NO IDEA_ID.
|
||||
# I need to update GenerationService.get_generations too!
|
||||
|
||||
# For now, let's assume I will update it.
|
||||
return await generation_service.get_generations(idea_id=idea_id, limit=limit, offset=offset)
|
||||
|
||||
@router.post("/{idea_id}/generations/{generation_id}")
|
||||
async def add_generation_to_idea(
|
||||
idea_id: str,
|
||||
generation_id: str,
|
||||
idea_service: IdeaService = Depends(get_idea_service)
|
||||
):
|
||||
success = await idea_service.add_generation_to_idea(idea_id, generation_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Idea or Generation not found")
|
||||
return {"status": "success"}
|
||||
|
||||
@router.delete("/{idea_id}/generations/{generation_id}")
|
||||
async def remove_generation_from_idea(
|
||||
idea_id: str,
|
||||
generation_id: str,
|
||||
idea_service: IdeaService = Depends(get_idea_service)
|
||||
):
|
||||
success = await idea_service.remove_generation_from_idea(idea_id, generation_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Idea or Generation not found")
|
||||
return {"status": "success"}
|
||||
99
api/endpoints/post_router.py
Normal file
99
api/endpoints/post_router.py
Normal file
@@ -0,0 +1,99 @@
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from api.dependency import get_post_service, get_project_id
|
||||
from api.endpoints.auth import get_current_user
|
||||
from api.service.post_service import PostService
|
||||
from api.models import PostRequest, PostCreateRequest, PostUpdateRequest, AddGenerationsRequest
|
||||
from models.Post import Post
|
||||
|
||||
router = APIRouter(prefix="/api/posts", tags=["posts"])
|
||||
|
||||
|
||||
@router.post("", response_model=Post)
|
||||
async def create_post(
|
||||
request: PostCreateRequest,
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
post_service: PostService = Depends(get_post_service),
|
||||
):
|
||||
pid = project_id or request.project_id
|
||||
return await post_service.create_post(
|
||||
date=request.date,
|
||||
topic=request.topic,
|
||||
generation_ids=request.generation_ids,
|
||||
project_id=pid,
|
||||
user_id=str(current_user["_id"]),
|
||||
)
|
||||
|
||||
|
||||
@router.get("", response_model=List[Post])
|
||||
async def get_posts(
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
limit: int = 200,
|
||||
offset: int = 0,
|
||||
date_from: Optional[datetime] = None,
|
||||
date_to: Optional[datetime] = None,
|
||||
current_user: dict = Depends(get_current_user),
|
||||
post_service: PostService = Depends(get_post_service),
|
||||
):
|
||||
return await post_service.get_posts(project_id, str(current_user["_id"]), limit, offset, date_from, date_to)
|
||||
|
||||
|
||||
@router.get("/{post_id}", response_model=Post)
|
||||
async def get_post(
|
||||
post_id: str,
|
||||
post_service: PostService = Depends(get_post_service),
|
||||
):
|
||||
post = await post_service.get_post(post_id)
|
||||
if not post:
|
||||
raise HTTPException(status_code=404, detail="Post not found")
|
||||
return post
|
||||
|
||||
|
||||
@router.put("/{post_id}", response_model=Post)
|
||||
async def update_post(
|
||||
post_id: str,
|
||||
request: PostUpdateRequest,
|
||||
post_service: PostService = Depends(get_post_service),
|
||||
):
|
||||
post = await post_service.update_post(post_id, date=request.date, topic=request.topic)
|
||||
if not post:
|
||||
raise HTTPException(status_code=404, detail="Post not found")
|
||||
return post
|
||||
|
||||
|
||||
@router.delete("/{post_id}")
|
||||
async def delete_post(
|
||||
post_id: str,
|
||||
post_service: PostService = Depends(get_post_service),
|
||||
):
|
||||
success = await post_service.delete_post(post_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Post not found or could not be deleted")
|
||||
return {"status": "success"}
|
||||
|
||||
|
||||
@router.post("/{post_id}/generations")
|
||||
async def add_generations(
|
||||
post_id: str,
|
||||
request: AddGenerationsRequest,
|
||||
post_service: PostService = Depends(get_post_service),
|
||||
):
|
||||
success = await post_service.add_generations(post_id, request.generation_ids)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Post not found")
|
||||
return {"status": "success"}
|
||||
|
||||
|
||||
@router.delete("/{post_id}/generations/{generation_id}")
|
||||
async def remove_generation(
|
||||
post_id: str,
|
||||
generation_id: str,
|
||||
post_service: PostService = Depends(get_post_service),
|
||||
):
|
||||
success = await post_service.remove_generation(post_id, generation_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Post not found or generation not linked")
|
||||
return {"status": "success"}
|
||||
182
api/endpoints/project_router.py
Normal file
182
api/endpoints/project_router.py
Normal file
@@ -0,0 +1,182 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from bson import ObjectId
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel
|
||||
from api.dependency import get_dao
|
||||
from api.endpoints.auth import get_current_user
|
||||
from models.Project import Project
|
||||
from repos.dao import DAO
|
||||
|
||||
router = APIRouter(prefix="/api/projects", tags=["Projects"])
|
||||
|
||||
class ProjectCreate(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
|
||||
class ProjectMemberResponse(BaseModel):
|
||||
id: str
|
||||
username: str
|
||||
|
||||
class ProjectResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
owner_id: str
|
||||
members: List[ProjectMemberResponse]
|
||||
is_owner: bool = False
|
||||
|
||||
async def _get_project_response(project: Project, current_user_id: str, dao: DAO) -> ProjectResponse:
|
||||
member_responses = []
|
||||
for member_id in project.members:
|
||||
# We need a way to get user by ID. Let's check UsersRepo for get_user by ObjectId or string.
|
||||
# Currently UsersRepo has get_user(user_id: int) for Telegram IDs.
|
||||
# But for Web users we might need to search by _id.
|
||||
# Let's try to get user info.
|
||||
# Since project.members contains strings (ObjectIds as strings), we search by _id.
|
||||
user_doc = await dao.users.collection.find_one({"_id": ObjectId(member_id)})
|
||||
if not user_doc and member_id.isdigit():
|
||||
# Fallback for telegram IDs if they are stored as strings of digits
|
||||
user_doc = await dao.users.get_user(int(member_id))
|
||||
|
||||
username = "unknown"
|
||||
if user_doc:
|
||||
username = user_doc.get("username", "unknown")
|
||||
|
||||
member_responses.append(ProjectMemberResponse(id=member_id, username=username))
|
||||
|
||||
return ProjectResponse(
|
||||
id=project.id,
|
||||
name=project.name,
|
||||
description=project.description,
|
||||
owner_id=project.owner_id,
|
||||
members=member_responses,
|
||||
is_owner=(project.owner_id == current_user_id)
|
||||
)
|
||||
|
||||
@router.post("", response_model=ProjectResponse)
|
||||
async def create_project(
|
||||
project_data: ProjectCreate,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
user_id = str(current_user["_id"])
|
||||
new_project = Project(
|
||||
name=project_data.name,
|
||||
description=project_data.description,
|
||||
owner_id=user_id,
|
||||
members=[user_id]
|
||||
)
|
||||
project_id = await dao.projects.create_project(new_project)
|
||||
new_project.id = project_id
|
||||
|
||||
# Add project to user's project list
|
||||
await dao.users.collection.update_one(
|
||||
{"_id": current_user["_id"]},
|
||||
{"$addToSet": {"project_ids": project_id}}
|
||||
)
|
||||
|
||||
return await _get_project_response(new_project, user_id, dao)
|
||||
|
||||
@router.get("", response_model=List[ProjectResponse])
|
||||
async def get_my_projects(
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
user_id = str(current_user["_id"])
|
||||
projects = await dao.projects.get_projects_by_user(user_id)
|
||||
|
||||
responses = []
|
||||
for p in projects:
|
||||
responses.append(await _get_project_response(p, user_id, dao))
|
||||
return responses
|
||||
|
||||
class MemberAdd(BaseModel):
|
||||
username: str
|
||||
|
||||
@router.post("/{project_id}/members", dependencies=[Depends(get_current_user)])
|
||||
async def add_member(
|
||||
project_id: str,
|
||||
member_data: MemberAdd,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
user_id = str(current_user["_id"])
|
||||
project = await dao.projects.get_project(project_id)
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
if project.owner_id != user_id:
|
||||
raise HTTPException(status_code=403, detail="Only owner can add members")
|
||||
|
||||
target_user = await dao.users.get_user_by_username(member_data.username)
|
||||
if not target_user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
target_user_id = str(target_user["_id"])
|
||||
|
||||
if target_user_id in project.members:
|
||||
return {"message": "User already in project"}
|
||||
|
||||
await dao.projects.add_member(project_id, target_user_id)
|
||||
|
||||
# Update target user's project list
|
||||
await dao.users.collection.update_one(
|
||||
{"_id": target_user["_id"]},
|
||||
{"$addToSet": {"project_ids": project_id}}
|
||||
)
|
||||
|
||||
return {"message": "Member added"}
|
||||
|
||||
@router.post("/{project_id}/join", dependencies=[Depends(get_current_user)])
|
||||
async def join_project(
|
||||
project_id: str,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
# Retrieve project to verify it exists
|
||||
project = await dao.projects.get_project(project_id)
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
user_id = str(current_user["_id"])
|
||||
|
||||
# Check if user is ALREADY in project
|
||||
if user_id in project.members:
|
||||
return {"message": "Already a member"}
|
||||
|
||||
# Add member
|
||||
await dao.projects.add_member(project_id, user_id)
|
||||
|
||||
# Update user's project list
|
||||
await dao.users.collection.update_one(
|
||||
{"_id": current_user["_id"]},
|
||||
{"$addToSet": {"project_ids": project_id}}
|
||||
)
|
||||
|
||||
return {"message": "Joined project"}
|
||||
|
||||
|
||||
@router.delete("/{project_id}", dependencies=[Depends(get_current_user)] )
|
||||
async def delete_project(
|
||||
project_id: str,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
user_id = str(current_user["_id"])
|
||||
project = await dao.projects.get_project(project_id)
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
if project.owner_id != user_id:
|
||||
raise HTTPException(status_code=403, detail="Only owner can delete project")
|
||||
|
||||
await dao.projects.delete_project(project_id)
|
||||
|
||||
# Remove project from user's project list
|
||||
await dao.users.collection.update_one(
|
||||
{"_id": current_user["_id"]},
|
||||
{"$pull": {"project_ids": project_id}}
|
||||
)
|
||||
|
||||
return {"message": "Project deleted"}
|
||||
18
api/models/CharacterDTO.py
Normal file
18
api/models/CharacterDTO.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
class CharacterCreateRequest(BaseModel):
|
||||
name: str
|
||||
character_bio: str
|
||||
character_image_doc_tg_id: Optional[str] = None
|
||||
avatar_image: Optional[str] = None
|
||||
character_image_tg_id: Optional[str] = None
|
||||
project_id: Optional[str] = None
|
||||
|
||||
class CharacterUpdateRequest(BaseModel):
|
||||
name: Optional[str] = None
|
||||
character_bio: Optional[str] = None
|
||||
character_image_doc_tg_id: Optional[str] = None
|
||||
avatar_image: Optional[str] = None
|
||||
character_image_tg_id: Optional[str] = None
|
||||
project_id: Optional[str] = None
|
||||
22
api/models/EnvironmentRequest.py
Normal file
22
api/models/EnvironmentRequest.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class EnvironmentCreate(BaseModel):
|
||||
character_id: str
|
||||
name: str = Field(..., min_length=1)
|
||||
description: Optional[str] = None
|
||||
asset_ids: Optional[List[str]] = []
|
||||
|
||||
|
||||
class EnvironmentUpdate(BaseModel):
|
||||
name: Optional[str] = Field(None, min_length=1)
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class AssetToEnvironment(BaseModel):
|
||||
asset_id: str
|
||||
|
||||
|
||||
class AssetsToEnvironment(BaseModel):
|
||||
asset_ids: List[str]
|
||||
37
api/models/ExternalGenerationDTO.py
Normal file
37
api/models/ExternalGenerationDTO.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from models.enums import AspectRatios, Quality
|
||||
|
||||
|
||||
class ExternalGenerationRequest(BaseModel):
|
||||
"""Request model for importing external generations."""
|
||||
|
||||
prompt: str
|
||||
tech_prompt: Optional[str] = None
|
||||
|
||||
# Image can be provided as base64 string OR URL (one must be provided)
|
||||
image_data: Optional[str] = Field(None, description="Base64-encoded image data")
|
||||
image_url: Optional[str] = Field(None, description="URL to download image from")
|
||||
|
||||
# Generation metadata
|
||||
aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN
|
||||
quality: Quality = Quality.ONEK
|
||||
|
||||
# Optional linking
|
||||
linked_character_id: Optional[str] = None
|
||||
created_by: str = Field(..., description="User ID from external system")
|
||||
project_id: Optional[str] = None
|
||||
|
||||
# Performance metrics
|
||||
execution_time_seconds: Optional[float] = None
|
||||
api_execution_time_seconds: Optional[float] = None
|
||||
token_usage: Optional[int] = None
|
||||
input_token_usage: Optional[int] = None
|
||||
output_token_usage: Optional[int] = None
|
||||
|
||||
def validate_image_source(self):
|
||||
"""Ensure at least one image source is provided."""
|
||||
if not self.image_data and not self.image_url:
|
||||
raise ValueError("Either image_data or image_url must be provided")
|
||||
if self.image_data and self.image_url:
|
||||
raise ValueError("Only one of image_data or image_url should be provided")
|
||||
18
api/models/FinancialUsageDTO.py
Normal file
18
api/models/FinancialUsageDTO.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
|
||||
class UsageStats(BaseModel):
|
||||
total_runs: int
|
||||
total_tokens: int
|
||||
total_input_tokens: int
|
||||
total_output_tokens: int
|
||||
total_cost: float
|
||||
|
||||
class UsageByEntity(BaseModel):
|
||||
entity_id: Optional[str] = None
|
||||
stats: UsageStats
|
||||
|
||||
class FinancialReport(BaseModel):
|
||||
summary: UsageStats
|
||||
by_user: Optional[List[UsageByEntity]] = None
|
||||
by_project: Optional[List[UsageByEntity]] = None
|
||||
@@ -1,7 +1,7 @@
|
||||
from datetime import datetime, UTC
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from models.Asset import Asset
|
||||
from models.Generation import GenerationStatus
|
||||
@@ -16,6 +16,10 @@ class GenerationRequest(BaseModel):
|
||||
telegram_id: Optional[int] = None
|
||||
use_profile_image: bool = True
|
||||
assets_list: List[str]
|
||||
environment_id: Optional[str] = None
|
||||
project_id: Optional[str] = None
|
||||
idea_id: Optional[str] = None
|
||||
count: int = Field(default=1, ge=1, le=10)
|
||||
|
||||
|
||||
class GenerationsResponse(BaseModel):
|
||||
@@ -42,10 +46,18 @@ class GenerationResponse(BaseModel):
|
||||
input_token_usage: Optional[int] = None
|
||||
output_token_usage: Optional[int] = None
|
||||
progress: int = 0
|
||||
cost: Optional[float] = None
|
||||
created_by: Optional[str] = None
|
||||
generation_group_id: Optional[str] = None
|
||||
idea_id: Optional[str] = None
|
||||
created_at: datetime = datetime.now(UTC)
|
||||
updated_at: datetime = datetime.now(UTC)
|
||||
|
||||
|
||||
class GenerationGroupResponse(BaseModel):
|
||||
generation_group_id: str
|
||||
generations: List[GenerationResponse]
|
||||
|
||||
|
||||
class PromptRequest(BaseModel):
|
||||
prompt: str
|
||||
|
||||
16
api/models/IdeaRequest.py
Normal file
16
api/models/IdeaRequest.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
from models.Idea import Idea
|
||||
from api.models.GenerationRequest import GenerationResponse
|
||||
|
||||
class IdeaCreateRequest(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
project_id: Optional[str] = None # Optional in body if passed via header/dependency
|
||||
|
||||
class IdeaUpdateRequest(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
|
||||
class IdeaResponse(Idea):
|
||||
last_generation: Optional[GenerationResponse] = None
|
||||
19
api/models/PostRequest.py
Normal file
19
api/models/PostRequest.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class PostCreateRequest(BaseModel):
|
||||
date: datetime
|
||||
topic: str
|
||||
generation_ids: List[str] = []
|
||||
project_id: Optional[str] = None
|
||||
|
||||
|
||||
class PostUpdateRequest(BaseModel):
|
||||
date: Optional[datetime] = None
|
||||
topic: Optional[str] = None
|
||||
|
||||
|
||||
class AddGenerationsRequest(BaseModel):
|
||||
generation_ids: List[str]
|
||||
@@ -0,0 +1,7 @@
|
||||
from .AssetDTO import AssetResponse, AssetsResponse
|
||||
from .CharacterDTO import CharacterCreateRequest, CharacterUpdateRequest
|
||||
from .ExternalGenerationDTO import ExternalGenerationRequest
|
||||
from .FinancialUsageDTO import FinancialReport, UsageStats, UsageByEntity
|
||||
from .GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse, PromptRequest, PromptResponse
|
||||
from .IdeaRequest import IdeaCreateRequest, IdeaUpdateRequest, IdeaResponse
|
||||
from .PostRequest import PostCreateRequest, PostUpdateRequest, AddGenerationsRequest
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
85
api/service/album_service.py
Normal file
85
api/service/album_service.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from typing import List, Optional
|
||||
from models.Album import Album
|
||||
from models.Generation import Generation
|
||||
from repos.dao import DAO
|
||||
|
||||
class AlbumService:
|
||||
def __init__(self, dao: DAO):
|
||||
self.dao = dao
|
||||
|
||||
async def create_album(self, name: str, description: Optional[str] = None) -> Album:
|
||||
album = Album(name=name, description=description)
|
||||
album_id = await self.dao.albums.create_album(album)
|
||||
album.id = album_id
|
||||
return album
|
||||
|
||||
async def get_albums(self, limit: int = 10, offset: int = 0) -> List[Album]:
|
||||
return await self.dao.albums.get_albums(limit=limit, offset=offset)
|
||||
|
||||
async def get_album(self, album_id: str) -> Optional[Album]:
|
||||
return await self.dao.albums.get_album(album_id)
|
||||
|
||||
async def update_album(self, album_id: str, name: Optional[str] = None, description: Optional[str] = None) -> Optional[Album]:
|
||||
album = await self.dao.albums.get_album(album_id)
|
||||
if not album:
|
||||
return None
|
||||
|
||||
if name:
|
||||
album.name = name
|
||||
if description is not None:
|
||||
album.description = description
|
||||
|
||||
await self.dao.albums.update_album(album_id, album)
|
||||
return album
|
||||
|
||||
async def delete_album(self, album_id: str) -> bool:
|
||||
return await self.dao.albums.delete_album(album_id)
|
||||
|
||||
async def add_generation_to_album(self, album_id: str, generation_id: str) -> bool:
|
||||
# Verify album exists
|
||||
album = await self.dao.albums.get_album(album_id)
|
||||
if not album:
|
||||
return False
|
||||
|
||||
# Verify generation exists (optional but good practice)
|
||||
gen = await self.dao.generations.get_generation(generation_id)
|
||||
if not gen:
|
||||
return False
|
||||
if album.cover_asset_id is None and gen.status == 'done':
|
||||
album.cover_asset_id = gen.result_list[0]
|
||||
return await self.dao.albums.add_generation(album_id, generation_id, album.cover_asset_id)
|
||||
|
||||
async def remove_generation_from_album(self, album_id: str, generation_id: str) -> bool:
|
||||
return await self.dao.albums.remove_generation(album_id, generation_id)
|
||||
|
||||
async def get_generations_by_album(self, album_id: str, limit: int = 10, offset: int = 0) -> List[Generation]:
|
||||
album = await self.dao.albums.get_album(album_id)
|
||||
if not album or not album.generation_ids:
|
||||
return []
|
||||
|
||||
# Slice the generation IDs (simple pagination on ID list)
|
||||
# Note: This pagination is on IDs, then we fetch objects.
|
||||
# Ideally, fetch only slice.
|
||||
|
||||
# Reverse to show newest first? Or just follow list order?
|
||||
# Assuming list order is insertion order (which usually is what we want for manual sorting or chronological if always appended).
|
||||
# Let's assume user wants same order as in list.
|
||||
|
||||
sliced_ids = album.generation_ids[offset : offset + limit]
|
||||
if not sliced_ids:
|
||||
return []
|
||||
|
||||
# Fetch generations by IDs
|
||||
# We need a method in GenerationRepo to fetch by IDs.
|
||||
# Currently we only have get_generations with filters.
|
||||
# We can add get_generations_by_ids to GenerationRepo or use loop (inefficient).
|
||||
# Let's add get_generations_by_ids to GenerationRepo.
|
||||
|
||||
# For now, I will use a loop if I can't modify Repo immediately,
|
||||
# but I SHOULD modify GenerationRepo.
|
||||
|
||||
# Or I can use get_generations(filter={"_id": {"$in": [ObjectId(id) for id in sliced_ids]}})
|
||||
# But get_generations doesn't support generic filter passing.
|
||||
|
||||
# I'll update GenerationRepo to support fetching by IDs.
|
||||
return await self.dao.generations.get_generations_by_ids(sliced_ids)
|
||||
@@ -1,24 +1,32 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import random
|
||||
from datetime import datetime, UTC
|
||||
from typing import List, Optional, Tuple, Any, Dict
|
||||
from io import BytesIO
|
||||
from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
from aiogram import Bot
|
||||
from aiogram.types import BufferedInputFile
|
||||
from fastapi import HTTPException
|
||||
|
||||
from adapters.Exception import GoogleGenerationException
|
||||
from adapters.google_adapter import GoogleAdapter
|
||||
from api.models.GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse
|
||||
from adapters.s3_adapter import S3Adapter
|
||||
from api.models import FinancialReport, UsageStats, UsageByEntity
|
||||
from api.models import GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse
|
||||
# Импортируйте ваши модели DAO, Asset, Generation корректно
|
||||
from models.Asset import Asset, AssetType, AssetContentType
|
||||
from models.Generation import Generation, GenerationStatus
|
||||
from models.enums import AspectRatios, Quality, GenType
|
||||
from models.enums import AspectRatios, Quality
|
||||
from repos.dao import DAO
|
||||
from adapters.s3_adapter import S3Adapter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Limit concurrent generations to 4
|
||||
generation_semaphore = asyncio.Semaphore(4)
|
||||
|
||||
|
||||
# --- Вспомогательная функция генерации ---
|
||||
async def generate_image_task(
|
||||
@@ -48,16 +56,18 @@ async def generate_image_task(
|
||||
logger.info(f"generate_image_task completed, received {len(generated_images_io) if generated_images_io else 0} images")
|
||||
except GoogleGenerationException as e:
|
||||
raise e
|
||||
finally:
|
||||
# Освобождаем входные данные — они больше не нужны
|
||||
del media_group_bytes
|
||||
|
||||
images_bytes = []
|
||||
if generated_images_io:
|
||||
for img_io in generated_images_io:
|
||||
# Читаем байты из BytesIO
|
||||
img_io.seek(0)
|
||||
content = img_io.read()
|
||||
images_bytes.append(content)
|
||||
|
||||
# Закрываем поток
|
||||
images_bytes.append(img_io.read())
|
||||
img_io.close()
|
||||
# Освобождаем список BytesIO сразу
|
||||
del generated_images_io
|
||||
|
||||
return images_bytes, metrics
|
||||
|
||||
@@ -69,7 +79,7 @@ class GenerationService:
|
||||
self.bot = bot
|
||||
|
||||
|
||||
async def ask_prompt_assistant(self, prompt: str, assets: List[str] = None) -> str:
|
||||
async def ask_prompt_assistant(self, prompt: str, assets: list[str] | None = None) -> str:
|
||||
future_prompt = """You are an prompt-assistant. You improving user-entered prompts for image generation. User may upload reference image too.
|
||||
I will provide sources prompt entered by user. Understand user needs and generate best variation of prompt.
|
||||
ANSWER ONLY PROMPT STRING!!! USER_ENTERED_PROMPT: """
|
||||
@@ -92,10 +102,9 @@ class GenerationService:
|
||||
|
||||
return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, images_list=images)
|
||||
|
||||
async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0) -> List[
|
||||
Generation]:
|
||||
generations = await self.dao.generations.get_generations(character_id = character_id,limit=limit, offset=offset)
|
||||
total_count = await self.dao.generations.count_generations(character_id = character_id)
|
||||
async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0, user_id: Optional[str] = None, project_id: Optional[str] = None, idea_id: Optional[str] = None) -> GenerationsResponse:
|
||||
generations = await self.dao.generations.get_generations(character_id = character_id,limit=limit, offset=offset, created_by=user_id, project_id=project_id, idea_id=idea_id)
|
||||
total_count = await self.dao.generations.count_generations(character_id = character_id, created_by=user_id, project_id=project_id, idea_id=idea_id)
|
||||
generations = [GenerationResponse(**gen.model_dump()) for gen in generations]
|
||||
return GenerationsResponse(generations=generations, total_count=total_count)
|
||||
|
||||
@@ -106,29 +115,56 @@ class GenerationService:
|
||||
else:
|
||||
return GenerationResponse(**gen.model_dump())
|
||||
|
||||
async def get_running_generations(self) -> List[Generation]:
|
||||
return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING)
|
||||
async def get_running_generations(self, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
|
||||
return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING, created_by=user_id, project_id=project_id)
|
||||
|
||||
async def create_generation_task(self, generation_request: GenerationRequest) -> GenerationResponse:
|
||||
async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationGroupResponse:
|
||||
count = generation_request.count
|
||||
|
||||
if generation_group_id is None:
|
||||
generation_group_id = str(uuid4())
|
||||
|
||||
results = []
|
||||
for _ in range(count):
|
||||
gen_response = await self._create_single_generation(generation_request, user_id, generation_group_id)
|
||||
results.append(gen_response)
|
||||
return GenerationGroupResponse(generation_group_id=generation_group_id, generations=results)
|
||||
|
||||
async def _create_single_generation(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationResponse:
|
||||
gen_id = None
|
||||
generation_model = None
|
||||
|
||||
if generation_request.environment_id and not generation_request.linked_character_id:
|
||||
raise HTTPException(status_code=400, detail="environment_id can only be used when linked_character_id is provided")
|
||||
|
||||
try:
|
||||
generation_model = Generation(**generation_request.model_dump())
|
||||
generation_model = Generation(**generation_request.model_dump(exclude={'count'}))
|
||||
if user_id:
|
||||
generation_model.created_by = user_id
|
||||
if generation_group_id:
|
||||
generation_model.generation_group_id = generation_group_id
|
||||
|
||||
# Explicitly set idea_id from request if present (already in model_dump, but ensuring clarity)
|
||||
if generation_request.idea_id:
|
||||
generation_model.idea_id = generation_request.idea_id
|
||||
|
||||
gen_id = await self.dao.generations.create_generation(generation_model)
|
||||
generation_model.id = gen_id
|
||||
|
||||
async def runner(gen):
|
||||
logger.info(f"Starting background generation task for ID: {gen.id}")
|
||||
logger.info(f"Generation {gen.id} entered queue (waiting for slot)...")
|
||||
try:
|
||||
await self.create_generation(gen)
|
||||
logger.info(f"Background generation task finished for ID: {gen.id}")
|
||||
async with generation_semaphore:
|
||||
logger.info(f"Starting background generation task for ID: {gen.id}")
|
||||
await self.create_generation(gen)
|
||||
logger.info(f"Background generation task finished for ID: {gen.id}")
|
||||
except Exception:
|
||||
# если генерация уже пошла и упала — пометим FAILED
|
||||
try:
|
||||
db_gen = await self.dao.generations.get_generation(gen.id)
|
||||
db_gen.status = GenerationStatus.FAILED
|
||||
await self.dao.generations.update_generation(db_gen)
|
||||
if db_gen is not None:
|
||||
db_gen.status = GenerationStatus.FAILED
|
||||
await self.dao.generations.update_generation(db_gen)
|
||||
except Exception:
|
||||
logger.exception("Failed to mark generation as FAILED")
|
||||
logger.exception("create_generation task failed")
|
||||
@@ -142,8 +178,9 @@ class GenerationService:
|
||||
if gen_id is not None:
|
||||
try:
|
||||
gen = await self.dao.generations.get_generation(gen_id)
|
||||
gen.status = GenerationStatus.FAILED
|
||||
await self.dao.generations.update_generation(gen)
|
||||
if gen is not None:
|
||||
gen.status = GenerationStatus.FAILED
|
||||
await self.dao.generations.update_generation(gen)
|
||||
except Exception:
|
||||
logger.exception("Failed to mark generation as FAILED in create_generation_task")
|
||||
raise
|
||||
@@ -153,41 +190,40 @@ class GenerationService:
|
||||
logger.info(f"Processing generation {generation.id}. Character ID: {generation.linked_character_id}")
|
||||
|
||||
# 2. Получаем ассеты-референсы (если они есть)
|
||||
reference_assets: List[Asset] = []
|
||||
media_group_bytes: List[bytes] = []
|
||||
generation_prompt = f"""
|
||||
|
||||
Create detailed image of character in scene.
|
||||
|
||||
SCENE DESCRIPTION: {generation.prompt}
|
||||
|
||||
Rules:
|
||||
- Integrate the character's appearance naturally into the scene description.
|
||||
- Focus on lighting, texture, and composition.
|
||||
"""
|
||||
generation_prompt = generation.prompt
|
||||
|
||||
# 2.1 Аватар персонажа (всегда первый, если включен)
|
||||
if generation.linked_character_id is not None:
|
||||
char_info = await self.dao.chars.get_character(generation.linked_character_id, with_image_data=True)
|
||||
char_info = await self.dao.chars.get_character(generation.linked_character_id)
|
||||
if char_info is None:
|
||||
raise Exception(f"Character ID {generation.linked_character_id} not found")
|
||||
if generation.use_profile_image:
|
||||
media_group_bytes.append(char_info.character_image_data)
|
||||
# generation_prompt = generation_prompt.replace("$char_bio_inserted", f"1. CHARACTER BIO (Must be strictly followed): {char_info.character_bio}")
|
||||
|
||||
if generation.use_profile_image and char_info.avatar_asset_id:
|
||||
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
|
||||
if avatar_asset:
|
||||
img_data = await self._get_asset_data(avatar_asset)
|
||||
if img_data:
|
||||
media_group_bytes.append(img_data)
|
||||
|
||||
reference_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
|
||||
|
||||
# Извлекаем данные (bytes) из ассетов для отправки в Gemini
|
||||
for asset in reference_assets:
|
||||
if asset.content_type != AssetContentType.IMAGE:
|
||||
continue
|
||||
|
||||
img_data = None
|
||||
if asset.minio_object_name:
|
||||
img_data = await self.s3_adapter.get_file(asset.minio_object_name)
|
||||
elif asset.data:
|
||||
img_data = asset.data
|
||||
|
||||
if img_data:
|
||||
media_group_bytes.append(img_data)
|
||||
# 2.2 Явно указанные ассеты
|
||||
if generation.assets_list:
|
||||
explicit_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
|
||||
for asset in explicit_assets:
|
||||
ref_asset_data = await self._get_asset_data(asset)
|
||||
if ref_asset_data:
|
||||
media_group_bytes.append(ref_asset_data)
|
||||
|
||||
# 2.3 Ассеты из окружения (в самый конец)
|
||||
if generation.environment_id:
|
||||
env = await self.dao.environments.get_env(generation.environment_id)
|
||||
if env and env.asset_ids:
|
||||
logger.info(f"Loading {len(env.asset_ids)} assets from environment {env.name} ({env.id})")
|
||||
env_assets = await self.dao.assets.get_assets_by_ids(env.asset_ids)
|
||||
for asset in env_assets:
|
||||
img_data = await self._get_asset_data(asset)
|
||||
if img_data:
|
||||
media_group_bytes.append(img_data)
|
||||
|
||||
if media_group_bytes:
|
||||
generation_prompt += " \n\n[Reference Image Guidance]: Use the provided image(s) as the STRICT reference for the main character's facial features and hair, enviroment or clothes. Maintain high fidelity to the reference identity."
|
||||
@@ -258,7 +294,9 @@ class GenerationService:
|
||||
data=None, # Not storing bytes in DB anymore
|
||||
minio_object_name=filename,
|
||||
minio_bucket=self.s3_adapter.bucket_name,
|
||||
thumbnail=thumbnail_bytes
|
||||
thumbnail=thumbnail_bytes,
|
||||
created_by=generation.created_by,
|
||||
project_id=generation.project_id
|
||||
)
|
||||
|
||||
# Сохраняем в БД
|
||||
@@ -269,7 +307,9 @@ class GenerationService:
|
||||
|
||||
# 5. (Опционально) Обновляем запись генерации ссылками на результаты
|
||||
# Предполагаем, что у модели Generation есть поле result_asset_ids
|
||||
result_ids = [a.id for a in created_assets]
|
||||
result_ids = []
|
||||
for a in created_assets:
|
||||
result_ids.append(a.id)
|
||||
|
||||
generation.result_list = result_ids
|
||||
generation.status = GenerationStatus.DONE
|
||||
@@ -300,6 +340,14 @@ class GenerationService:
|
||||
logger.error(f"Failed to send assets to Telegram ID {generation.telegram_id}: {e}")
|
||||
|
||||
|
||||
async def _get_asset_data(self, asset: Asset) -> Optional[bytes]:
|
||||
if asset.content_type != AssetContentType.IMAGE:
|
||||
return None
|
||||
|
||||
if asset.minio_object_name:
|
||||
return await self.s3_adapter.get_file(asset.minio_object_name)
|
||||
return asset.data
|
||||
|
||||
async def _simulate_progress(self, generation: Generation):
|
||||
"""
|
||||
Increments progress from 0 to 90 over ~20 seconds.
|
||||
@@ -325,6 +373,98 @@ class GenerationService:
|
||||
logger.error(f"Error in progress simulation: {e}")
|
||||
|
||||
|
||||
|
||||
|
||||
async def import_external_generation(self, external_gen) -> Generation:
|
||||
"""
|
||||
Import a generation from an external source.
|
||||
|
||||
Args:
|
||||
external_gen: ExternalGenerationRequest with generation data and image
|
||||
|
||||
Returns:
|
||||
Created Generation object
|
||||
"""
|
||||
|
||||
# Validate image source
|
||||
external_gen.validate_image_source()
|
||||
|
||||
logger.info(f"Importing external generation for user: {external_gen.created_by}")
|
||||
|
||||
# 1. Process image (download or decode)
|
||||
image_bytes = None
|
||||
|
||||
if external_gen.image_url:
|
||||
# Download image from URL
|
||||
logger.info(f"Downloading image from URL: {external_gen.image_url}")
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(external_gen.image_url, timeout=30.0)
|
||||
response.raise_for_status()
|
||||
image_bytes = response.content
|
||||
elif external_gen.image_data:
|
||||
# Decode base64 image
|
||||
logger.info("Decoding base64 image data")
|
||||
image_bytes = base64.b64decode(external_gen.image_data)
|
||||
|
||||
if not image_bytes:
|
||||
raise ValueError("Failed to process image data")
|
||||
|
||||
# 2. Generate thumbnail
|
||||
from utils.image_utils import create_thumbnail
|
||||
thumbnail_bytes = await asyncio.to_thread(create_thumbnail, image_bytes)
|
||||
|
||||
# 3. Save to S3
|
||||
filename = f"external/{external_gen.created_by}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.png"
|
||||
await self.s3_adapter.upload_file(filename, image_bytes, content_type="image/png")
|
||||
|
||||
# 4. Create Asset
|
||||
new_asset = Asset(
|
||||
name=f"External_Generated_{external_gen.linked_character_id or 'no_char'}",
|
||||
type=AssetType.GENERATED,
|
||||
content_type=AssetContentType.IMAGE,
|
||||
linked_char_id=external_gen.linked_character_id,
|
||||
data=None, # Not storing bytes in DB
|
||||
minio_object_name=filename,
|
||||
minio_bucket=self.s3_adapter.bucket_name,
|
||||
thumbnail=thumbnail_bytes,
|
||||
created_by=external_gen.created_by,
|
||||
project_id=external_gen.project_id
|
||||
)
|
||||
|
||||
asset_id = await self.dao.assets.create_asset(new_asset)
|
||||
new_asset.id = str(asset_id)
|
||||
|
||||
logger.info(f"Created asset {asset_id} for external generation")
|
||||
|
||||
# 5. Create Generation record
|
||||
generation = Generation(
|
||||
status=GenerationStatus.DONE,
|
||||
linked_character_id=external_gen.linked_character_id,
|
||||
aspect_ratio=external_gen.aspect_ratio,
|
||||
quality=external_gen.quality,
|
||||
prompt=external_gen.prompt,
|
||||
tech_prompt=external_gen.tech_prompt,
|
||||
result_list=[new_asset.id],
|
||||
result=new_asset.id,
|
||||
progress=100,
|
||||
execution_time_seconds=external_gen.execution_time_seconds,
|
||||
api_execution_time_seconds=external_gen.api_execution_time_seconds,
|
||||
token_usage=external_gen.token_usage,
|
||||
input_token_usage=external_gen.input_token_usage,
|
||||
output_token_usage=external_gen.output_token_usage,
|
||||
created_by=external_gen.created_by,
|
||||
project_id=external_gen.project_id,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC)
|
||||
)
|
||||
|
||||
gen_id = await self.dao.generations.create_generation(generation)
|
||||
generation.id = gen_id
|
||||
|
||||
logger.info(f"Created generation {gen_id} from external source")
|
||||
|
||||
return generation
|
||||
|
||||
async def delete_generation(self, generation_id: str) -> bool:
|
||||
"""
|
||||
Soft delete generation by marking it as deleted.
|
||||
@@ -340,4 +480,62 @@ class GenerationService:
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting generation {generation_id}: {e}")
|
||||
return False
|
||||
return False
|
||||
|
||||
async def cleanup_stale_generations(self):
|
||||
"""
|
||||
Cancels generations that have been running for more than 1 hour.
|
||||
"""
|
||||
try:
|
||||
count = await self.dao.generations.cancel_stale_generations(timeout_minutes=60)
|
||||
if count > 0:
|
||||
logger.info(f"Cleaned up {count} stale generations (timeout)")
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up stale generations: {e}")
|
||||
|
||||
async def cleanup_old_data(self, days: int = 2):
|
||||
"""
|
||||
Очистка старых данных:
|
||||
1. Мягко удаляет генерации старше N дней
|
||||
2. Мягко удаляет связанные ассеты + жёстко удаляет файлы из S3
|
||||
"""
|
||||
try:
|
||||
# 1. Мягко удаляем генерации и собираем asset IDs
|
||||
gen_count, asset_ids = await self.dao.generations.soft_delete_old_generations(days=days)
|
||||
|
||||
if gen_count > 0:
|
||||
logger.info(f"Soft-deleted {gen_count} generations older than {days} days. "
|
||||
f"Found {len(asset_ids)} associated asset IDs.")
|
||||
|
||||
# 2. Мягко удаляем ассеты + жёстко удаляем файлы из S3
|
||||
if asset_ids:
|
||||
purged = await self.dao.assets.soft_delete_and_purge_assets(asset_ids)
|
||||
logger.info(f"Purged {purged} assets (soft-deleted + S3 files removed).")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during old data cleanup: {e}")
|
||||
|
||||
async def get_financial_report(self, user_id: Optional[str] = None, project_id: Optional[str] = None, breakdown_by: Optional[str] = None) -> FinancialReport:
|
||||
"""
|
||||
Generates a financial usage report for a specific user or project.
|
||||
'breakdown_by' can be 'created_by' or 'project_id'.
|
||||
"""
|
||||
summary_data = await self.dao.generations.get_usage_stats(created_by=user_id, project_id=project_id)
|
||||
summary = UsageStats(**summary_data)
|
||||
|
||||
by_user = None
|
||||
by_project = None
|
||||
|
||||
if breakdown_by == "created_by":
|
||||
res = await self.dao.generations.get_usage_breakdown(group_by="created_by", project_id=project_id, created_by=user_id)
|
||||
by_user = [UsageByEntity(**item) for item in res]
|
||||
|
||||
if breakdown_by == "project_id":
|
||||
res = await self.dao.generations.get_usage_breakdown(group_by="project_id", project_id=project_id, created_by=user_id)
|
||||
by_project = [UsageByEntity(**item) for item in res]
|
||||
|
||||
return FinancialReport(
|
||||
summary=summary,
|
||||
by_user=by_user,
|
||||
by_project=by_project
|
||||
)
|
||||
75
api/service/idea_service.py
Normal file
75
api/service/idea_service.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from repos.dao import DAO
|
||||
from models.Idea import Idea
|
||||
|
||||
class IdeaService:
|
||||
def __init__(self, dao: DAO):
|
||||
self.dao = dao
|
||||
|
||||
async def create_idea(self, name: str, description: Optional[str], project_id: Optional[str], user_id: str) -> Idea:
|
||||
idea = Idea(name=name, description=description, project_id=project_id, created_by=user_id)
|
||||
idea_id = await self.dao.ideas.create_idea(idea)
|
||||
idea.id = idea_id
|
||||
return idea
|
||||
|
||||
async def get_ideas(self, project_id: Optional[str], user_id: str, limit: int = 20, offset: int = 0) -> List[dict]:
|
||||
return await self.dao.ideas.get_ideas(project_id, user_id, limit, offset)
|
||||
|
||||
async def get_idea(self, idea_id: str) -> Optional[Idea]:
|
||||
return await self.dao.ideas.get_idea(idea_id)
|
||||
|
||||
async def update_idea(self, idea_id: str, name: Optional[str] = None, description: Optional[str] = None) -> Optional[Idea]:
|
||||
idea = await self.dao.ideas.get_idea(idea_id)
|
||||
if not idea:
|
||||
return None
|
||||
|
||||
if name is not None:
|
||||
idea.name = name
|
||||
if description is not None:
|
||||
idea.description = description
|
||||
|
||||
idea.updated_at = datetime.now()
|
||||
await self.dao.ideas.update_idea(idea)
|
||||
return idea
|
||||
|
||||
async def delete_idea(self, idea_id: str) -> bool:
|
||||
return await self.dao.ideas.delete_idea(idea_id)
|
||||
|
||||
async def add_generation_to_idea(self, idea_id: str, generation_id: str) -> bool:
|
||||
# Verify idea exists
|
||||
idea = await self.dao.ideas.get_idea(idea_id)
|
||||
if not idea:
|
||||
return False
|
||||
|
||||
# Get generation
|
||||
gen = await self.dao.generations.get_generation(generation_id)
|
||||
if not gen:
|
||||
return False
|
||||
|
||||
# Link
|
||||
gen.idea_id = idea_id
|
||||
gen.updated_at = datetime.now()
|
||||
await self.dao.generations.update_generation(gen)
|
||||
return True
|
||||
|
||||
async def remove_generation_from_idea(self, idea_id: str, generation_id: str) -> bool:
|
||||
# Verify idea exists (optional, but good for validation)
|
||||
idea = await self.dao.ideas.get_idea(idea_id)
|
||||
if not idea:
|
||||
return False
|
||||
|
||||
# Get generation
|
||||
gen = await self.dao.generations.get_generation(generation_id)
|
||||
if not gen:
|
||||
return False
|
||||
|
||||
# Unlink only if currently linked to this idea
|
||||
if gen.idea_id == idea_id:
|
||||
gen.idea_id = None
|
||||
gen.updated_at = datetime.now()
|
||||
await self.dao.generations.update_generation(gen)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
79
api/service/post_service.py
Normal file
79
api/service/post_service.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from typing import List, Optional
|
||||
from datetime import datetime, UTC
|
||||
|
||||
from repos.dao import DAO
|
||||
from models.Post import Post
|
||||
|
||||
|
||||
class PostService:
|
||||
def __init__(self, dao: DAO):
|
||||
self.dao = dao
|
||||
|
||||
async def create_post(
|
||||
self,
|
||||
date: datetime,
|
||||
topic: str,
|
||||
generation_ids: List[str],
|
||||
project_id: Optional[str],
|
||||
user_id: str,
|
||||
) -> Post:
|
||||
post = Post(
|
||||
date=date,
|
||||
topic=topic,
|
||||
generation_ids=generation_ids,
|
||||
project_id=project_id,
|
||||
created_by=user_id,
|
||||
)
|
||||
post_id = await self.dao.posts.create_post(post)
|
||||
post.id = post_id
|
||||
return post
|
||||
|
||||
async def get_post(self, post_id: str) -> Optional[Post]:
|
||||
return await self.dao.posts.get_post(post_id)
|
||||
|
||||
async def get_posts(
|
||||
self,
|
||||
project_id: Optional[str],
|
||||
user_id: str,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
date_from: Optional[datetime] = None,
|
||||
date_to: Optional[datetime] = None,
|
||||
) -> List[Post]:
|
||||
return await self.dao.posts.get_posts(project_id, user_id, limit, offset, date_from, date_to)
|
||||
|
||||
async def update_post(
|
||||
self,
|
||||
post_id: str,
|
||||
date: Optional[datetime] = None,
|
||||
topic: Optional[str] = None,
|
||||
) -> Optional[Post]:
|
||||
post = await self.dao.posts.get_post(post_id)
|
||||
if not post:
|
||||
return None
|
||||
|
||||
updates: dict = {"updated_at": datetime.now(UTC)}
|
||||
if date is not None:
|
||||
updates["date"] = date
|
||||
if topic is not None:
|
||||
updates["topic"] = topic
|
||||
|
||||
await self.dao.posts.update_post(post_id, updates)
|
||||
|
||||
# Return refreshed post
|
||||
return await self.dao.posts.get_post(post_id)
|
||||
|
||||
async def delete_post(self, post_id: str) -> bool:
|
||||
return await self.dao.posts.delete_post(post_id)
|
||||
|
||||
async def add_generations(self, post_id: str, generation_ids: List[str]) -> bool:
|
||||
post = await self.dao.posts.get_post(post_id)
|
||||
if not post:
|
||||
return False
|
||||
return await self.dao.posts.add_generations(post_id, generation_ids)
|
||||
|
||||
async def remove_generation(self, post_id: str, generation_id: str) -> bool:
|
||||
post = await self.dao.posts.get_post(post_id)
|
||||
if not post:
|
||||
return False
|
||||
return await self.dao.posts.remove_generation(post_id, generation_id)
|
||||
39
config.py
Normal file
39
config.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
# Telegram Bot
|
||||
BOT_TOKEN: str
|
||||
ADMIN_ID: int = 0
|
||||
|
||||
# AI Service
|
||||
GEMINI_API_KEY: str
|
||||
|
||||
# Database
|
||||
MONGO_HOST: str = "mongodb://localhost:27017"
|
||||
DB_NAME: str = "my_bot_db"
|
||||
|
||||
# S3 Storage (Minio)
|
||||
MINIO_ENDPOINT: str = "http://localhost:9000"
|
||||
MINIO_ACCESS_KEY: str = "minioadmin"
|
||||
MINIO_SECRET_KEY: str = "minioadmin"
|
||||
MINIO_BUCKET: str = "ai-char"
|
||||
|
||||
# External API
|
||||
EXTERNAL_API_SECRET: Optional[str] = None
|
||||
|
||||
# JWT Security
|
||||
SECRET_KEY: str = "CHANGE_ME_TO_A_SUPER_SECRET_KEY"
|
||||
ALGORITHM: str = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 * 24 * 60 # 30 days
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=os.getenv("ENV_FILE", ".env"),
|
||||
env_file_encoding="utf-8",
|
||||
extra="ignore"
|
||||
)
|
||||
|
||||
|
||||
settings = Settings()
|
||||
@@ -27,19 +27,19 @@ class AlbumMiddleware(BaseMiddleware):
|
||||
# Ждем сбора остальных частей
|
||||
await asyncio.sleep(self.latency)
|
||||
|
||||
# Проверяем, что ключ все еще существует (на всякий случай)
|
||||
# Проверяем, что ключ все еще существует
|
||||
if group_id in self.album_data:
|
||||
# Передаем собранный альбом в хендлер
|
||||
# Сортируем по message_id, чтобы порядок был верным
|
||||
self.album_data[group_id].sort(key=lambda x: x.message_id)
|
||||
data["album"] = self.album_data[group_id]
|
||||
current_album = self.album_data[group_id]
|
||||
current_album.sort(key=lambda x: x.message_id)
|
||||
data["album"] = current_album
|
||||
return await handler(event, data)
|
||||
|
||||
finally:
|
||||
# ЧИСТКА: Удаляем всегда, если это "головной" поток, который создал запись
|
||||
# Проверяем, что мы удаляем именно то, что создали, и ключ существует
|
||||
if group_id in self.album_data and self.album_data[group_id][0] == event:
|
||||
del self.album_data[group_id]
|
||||
# ЧИСТКА: Удаляем запись после обработки или таймаута
|
||||
# Используем pop() с дефолтом, чтобы избежать KeyError
|
||||
self.album_data.pop(group_id, None)
|
||||
|
||||
else:
|
||||
# Если группа уже собирается - просто добавляем и выходим
|
||||
|
||||
12
models/Album.py
Normal file
12
models/Album.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from datetime import datetime, UTC
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class Album(BaseModel):
|
||||
id: Optional[str] = None
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
cover_asset_id: Optional[str] = None
|
||||
generation_ids: List[str] = []
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
@@ -28,6 +28,9 @@ class Asset(BaseModel):
|
||||
minio_thumbnail_object_name: Optional[str] = None
|
||||
thumbnail: Optional[bytes] = None
|
||||
tags: List[str] = []
|
||||
created_by: Optional[str] = None
|
||||
project_id: Optional[str] = None
|
||||
is_deleted: bool = False
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
|
||||
@@ -60,6 +63,7 @@ class Asset(BaseModel):
|
||||
|
||||
# --- CALCULATED FIELD ---
|
||||
@computed_field
|
||||
@property
|
||||
def url(self) -> str:
|
||||
"""
|
||||
Это поле автоматически вычислится и попадет в model_dump() / .json()
|
||||
|
||||
@@ -5,11 +5,12 @@ from pydantic_core.core_schema import computed_field
|
||||
|
||||
|
||||
class Character(BaseModel):
|
||||
id: str | None
|
||||
id: Optional[str] = None
|
||||
name: str
|
||||
avatar_asset_id: Optional[str] = None
|
||||
avatar_image: Optional[str] = None
|
||||
character_image_data: Optional[bytes] = None
|
||||
character_image_doc_tg_id: str
|
||||
character_image_tg_id: str | None
|
||||
character_bio: str
|
||||
|
||||
character_image_doc_tg_id: Optional[str] = None
|
||||
character_image_tg_id: Optional[str] = None
|
||||
character_bio: Optional[str] = None
|
||||
created_by: Optional[str] = None
|
||||
project_id: Optional[str] = None
|
||||
|
||||
20
models/Environment.py
Normal file
20
models/Environment.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from datetime import datetime
|
||||
from bson import ObjectId
|
||||
|
||||
|
||||
class Environment(BaseModel):
|
||||
id: Optional[str] = Field(None, alias="_id")
|
||||
character_id: str
|
||||
name: str = Field(..., min_length=1)
|
||||
description: Optional[str] = None
|
||||
asset_ids: List[str] = Field(default_factory=list)
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
json_encoders={ObjectId: str},
|
||||
arbitrary_types_allowed=True
|
||||
)
|
||||
@@ -2,7 +2,7 @@ from datetime import datetime, UTC
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, computed_field
|
||||
|
||||
from models.Asset import Asset
|
||||
from models.enums import AspectRatios, Quality, GenType
|
||||
@@ -34,5 +34,19 @@ class Generation(BaseModel):
|
||||
input_token_usage: Optional[int] = None
|
||||
output_token_usage: Optional[int] = None
|
||||
is_deleted: bool = False
|
||||
album_id: Optional[str] = None
|
||||
environment_id: Optional[str] = None
|
||||
generation_group_id: Optional[str] = None
|
||||
created_by: Optional[str] = None # Stores User ID (Telegram ID or Web User ObjectId)
|
||||
project_id: Optional[str] = None
|
||||
idea_id: Optional[str] = None
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
|
||||
@computed_field
|
||||
def cost(self) -> float:
|
||||
if self.status == GenerationStatus.DONE and self.input_token_usage and self.output_token_usage:
|
||||
cost_input = self.input_token_usage * 0.000002
|
||||
cost_output = self.output_token_usage * 0.00012
|
||||
return round(cost_input + cost_output, 3)
|
||||
return 0.0
|
||||
13
models/Idea.py
Normal file
13
models/Idea.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class Idea(BaseModel):
|
||||
id: Optional[str] = None
|
||||
name: str = "New Idea"
|
||||
description: Optional[str] = None
|
||||
project_id: Optional[str] = None
|
||||
created_by: str # User ID
|
||||
is_deleted: bool = False
|
||||
created_at: datetime = Field(default_factory=datetime.now)
|
||||
updated_at: datetime = Field(default_factory=datetime.now)
|
||||
23
models/Post.py
Normal file
23
models/Post.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from datetime import datetime, timezone, UTC
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class Post(BaseModel):
|
||||
id: Optional[str] = None
|
||||
date: datetime
|
||||
topic: str
|
||||
generation_ids: List[str] = Field(default_factory=list)
|
||||
project_id: Optional[str] = None
|
||||
created_by: str
|
||||
is_deleted: bool = False
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
|
||||
@model_validator(mode="after")
|
||||
def ensure_tz_aware(self):
|
||||
for field in ("date", "created_at", "updated_at"):
|
||||
val = getattr(self, field)
|
||||
if val is not None and val.tzinfo is None:
|
||||
setattr(self, field, val.replace(tzinfo=timezone.utc))
|
||||
return self
|
||||
12
models/Project.py
Normal file
12
models/Project.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class Project(BaseModel):
|
||||
id: Optional[str] = None
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
owner_id: str
|
||||
members: List[str] = [] # List of User IDs
|
||||
is_deleted: bool = False
|
||||
created_at: datetime = Field(default_factory=datetime.now)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
61
repos/albums_repo.py
Normal file
61
repos/albums_repo.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from typing import List, Optional
|
||||
import logging
|
||||
from bson import ObjectId
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
|
||||
from models.Album import Album
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class AlbumsRepo:
|
||||
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
|
||||
self.collection = client[db_name]["albums"]
|
||||
|
||||
async def create_album(self, album: Album) -> str:
|
||||
res = await self.collection.insert_one(album.model_dump())
|
||||
return str(res.inserted_id)
|
||||
|
||||
async def get_album(self, album_id: str) -> Optional[Album]:
|
||||
try:
|
||||
res = await self.collection.find_one({"_id": ObjectId(album_id)})
|
||||
if not res:
|
||||
return None
|
||||
|
||||
res["id"] = str(res.pop("_id"))
|
||||
return Album(**res)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def get_albums(self, limit: int = 10, offset: int = 0) -> List[Album]:
|
||||
res = await self.collection.find().sort("created_at", -1).skip(offset).limit(limit).to_list(None)
|
||||
albums = []
|
||||
for doc in res:
|
||||
doc["id"] = str(doc.pop("_id"))
|
||||
albums.append(Album(**doc))
|
||||
return albums
|
||||
|
||||
async def update_album(self, album_id: str, album: Album) -> bool:
|
||||
if not album.id:
|
||||
album.id = album_id
|
||||
|
||||
model_dump = album.model_dump()
|
||||
res = await self.collection.update_one({"_id": ObjectId(album_id)}, {"$set": model_dump})
|
||||
return res.modified_count > 0
|
||||
|
||||
async def delete_album(self, album_id: str) -> bool:
|
||||
res = await self.collection.delete_one({"_id": ObjectId(album_id)})
|
||||
return res.deleted_count > 0
|
||||
|
||||
async def add_generation(self, album_id: str, generation_id: str, cover_asset_id: Optional[str] = None) -> bool:
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(album_id)},
|
||||
{"$addToSet": {"generation_ids": generation_id}, "$set": {"cover_asset_id": cover_asset_id}}
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def remove_generation(self, album_id: str, generation_id: str) -> bool:
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(album_id)},
|
||||
{"$pull": {"generation_ids": generation_id}}
|
||||
)
|
||||
return res.modified_count > 0
|
||||
@@ -1,6 +1,8 @@
|
||||
from typing import List, Optional
|
||||
from typing import Any, List, Optional
|
||||
import logging
|
||||
from datetime import datetime, UTC
|
||||
from bson import ObjectId
|
||||
from uuid import uuid4
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
|
||||
from models.Asset import Asset
|
||||
@@ -19,7 +21,8 @@ class AssetsRepo:
|
||||
# Main data
|
||||
if asset.data:
|
||||
ts = int(asset.created_at.timestamp())
|
||||
object_name = f"{asset.type.value}/{ts}_{asset.name}"
|
||||
uid = uuid4().hex[:8]
|
||||
object_name = f"{asset.type.value}/{ts}_{uid}_{asset.name}"
|
||||
|
||||
uploaded = await self.s3.upload_file(object_name, asset.data)
|
||||
if uploaded:
|
||||
@@ -32,7 +35,8 @@ class AssetsRepo:
|
||||
# Thumbnail
|
||||
if asset.thumbnail:
|
||||
ts = int(asset.created_at.timestamp())
|
||||
thumb_name = f"{asset.type.value}/thumbs/{ts}_{asset.name}_thumb.jpg"
|
||||
uid = uuid4().hex[:8]
|
||||
thumb_name = f"{asset.type.value}/thumbs/{ts}_{uid}_{asset.name}_thumb.jpg"
|
||||
|
||||
uploaded_thumb = await self.s3.upload_file(thumb_name, asset.thumbnail)
|
||||
if uploaded_thumb:
|
||||
@@ -46,8 +50,8 @@ class AssetsRepo:
|
||||
res = await self.collection.insert_one(asset.model_dump())
|
||||
return str(res.inserted_id)
|
||||
|
||||
async def get_assets(self, asset_type: Optional[str] = None, limit: int = 10, offset: int = 0, with_data: bool = False) -> List[Asset]:
|
||||
filter = {}
|
||||
async def get_assets(self, asset_type: Optional[str] = None, limit: int = 10, offset: int = 0, with_data: bool = False, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Asset]:
|
||||
filter: dict[str, Any]= {"is_deleted": {"$ne": True}}
|
||||
if asset_type:
|
||||
filter["type"] = asset_type
|
||||
args = {}
|
||||
@@ -70,6 +74,12 @@ class AssetsRepo:
|
||||
# if not with_data: args["data"] = 0; args["thumbnail"] = 0
|
||||
# So list DOES NOT return thumbnails by default.
|
||||
args["thumbnail"] = 0
|
||||
if created_by:
|
||||
filter["created_by"] = created_by
|
||||
filter['project_id'] = None
|
||||
|
||||
if project_id:
|
||||
filter["project_id"] = project_id
|
||||
|
||||
res = await self.collection.find(filter, args).sort("created_at", -1).skip(offset).limit(limit).to_list(None)
|
||||
assets = []
|
||||
@@ -128,7 +138,8 @@ class AssetsRepo:
|
||||
if self.s3:
|
||||
if asset.data:
|
||||
ts = int(asset.created_at.timestamp())
|
||||
object_name = f"{asset.type.value}/{ts}_{asset.name}"
|
||||
uid = uuid4().hex[:8]
|
||||
object_name = f"{asset.type.value}/{ts}_{uid}_{asset.name}"
|
||||
if await self.s3.upload_file(object_name, asset.data):
|
||||
asset.minio_object_name = object_name
|
||||
asset.minio_bucket = self.s3.bucket_name
|
||||
@@ -136,7 +147,8 @@ class AssetsRepo:
|
||||
|
||||
if asset.thumbnail:
|
||||
ts = int(asset.created_at.timestamp())
|
||||
thumb_name = f"{asset.type.value}/thumbs/{ts}_{asset.name}_thumb.jpg"
|
||||
uid = uuid4().hex[:8]
|
||||
thumb_name = f"{asset.type.value}/thumbs/{ts}_{uid}_{asset.name}_thumb.jpg"
|
||||
if await self.s3.upload_file(thumb_name, asset.thumbnail):
|
||||
asset.minio_thumbnail_object_name = thumb_name
|
||||
asset.thumbnail = None
|
||||
@@ -157,8 +169,17 @@ class AssetsRepo:
|
||||
assets.append(Asset(**doc))
|
||||
return assets
|
||||
|
||||
async def get_asset_count(self, character_id: Optional[str] = None) -> int:
|
||||
return await self.collection.count_documents({"linked_char_id": character_id} if character_id else {})
|
||||
async def get_asset_count(self, character_id: Optional[str] = None, created_by: Optional[str] = None, project_id: Optional[str] = None) -> int:
|
||||
filter = {}
|
||||
if character_id:
|
||||
filter["linked_char_id"] = character_id
|
||||
if created_by:
|
||||
filter["created_by"] = created_by
|
||||
if project_id is None:
|
||||
filter["project_id"] = None
|
||||
if project_id:
|
||||
filter["project_id"] = project_id
|
||||
return await self.collection.count_documents(filter)
|
||||
|
||||
async def get_assets_by_ids(self, asset_ids: List[str]) -> List[Asset]:
|
||||
object_ids = [ObjectId(asset_id) for asset_id in asset_ids]
|
||||
@@ -184,6 +205,61 @@ class AssetsRepo:
|
||||
res = await self.collection.delete_one({"_id": ObjectId(asset_id)})
|
||||
return res.deleted_count > 0
|
||||
|
||||
async def soft_delete_and_purge_assets(self, asset_ids: List[str]) -> int:
|
||||
"""
|
||||
Мягко удаляет ассеты и жёстко удаляет их файлы из S3.
|
||||
Возвращает количество обработанных ассетов.
|
||||
"""
|
||||
if not asset_ids:
|
||||
return 0
|
||||
|
||||
object_ids = [ObjectId(aid) for aid in asset_ids if ObjectId.is_valid(aid)]
|
||||
if not object_ids:
|
||||
return 0
|
||||
|
||||
# Находим ассеты, которые ещё не удалены
|
||||
cursor = self.collection.find(
|
||||
{"_id": {"$in": object_ids}, "is_deleted": {"$ne": True}},
|
||||
{"minio_object_name": 1, "minio_thumbnail_object_name": 1}
|
||||
)
|
||||
|
||||
purged_count = 0
|
||||
ids_to_update = []
|
||||
|
||||
async for doc in cursor:
|
||||
ids_to_update.append(doc["_id"])
|
||||
|
||||
# Жёсткое удаление файлов из S3
|
||||
if self.s3:
|
||||
if doc.get("minio_object_name"):
|
||||
try:
|
||||
await self.s3.delete_file(doc["minio_object_name"])
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete S3 object {doc['minio_object_name']}: {e}")
|
||||
if doc.get("minio_thumbnail_object_name"):
|
||||
try:
|
||||
await self.s3.delete_file(doc["minio_thumbnail_object_name"])
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete S3 thumbnail {doc['minio_thumbnail_object_name']}: {e}")
|
||||
|
||||
purged_count += 1
|
||||
|
||||
# Мягкое удаление + очистка ссылок на S3
|
||||
if ids_to_update:
|
||||
await self.collection.update_many(
|
||||
{"_id": {"$in": ids_to_update}},
|
||||
{
|
||||
"$set": {
|
||||
"is_deleted": True,
|
||||
"minio_object_name": None,
|
||||
"minio_thumbnail_object_name": None,
|
||||
"updated_at": datetime.now(UTC)
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
return purged_count
|
||||
|
||||
async def migrate_to_minio(self) -> dict:
|
||||
"""Переносит данные и thumbnails из Mongo в MinIO."""
|
||||
if not self.s3:
|
||||
@@ -203,7 +279,8 @@ class AssetsRepo:
|
||||
created_at = doc.get("created_at")
|
||||
ts = int(created_at.timestamp()) if created_at else 0
|
||||
|
||||
object_name = f"{type_}/{ts}_{asset_id}_{name}"
|
||||
uid = uuid4().hex[:8]
|
||||
object_name = f"{type_}/{ts}_{uid}_{asset_id}_{name}"
|
||||
if await self.s3.upload_file(object_name, data):
|
||||
await self.collection.update_one(
|
||||
{"_id": asset_id},
|
||||
@@ -230,7 +307,8 @@ class AssetsRepo:
|
||||
created_at = doc.get("created_at")
|
||||
ts = int(created_at.timestamp()) if created_at else 0
|
||||
|
||||
thumb_name = f"{type_}/thumbs/{ts}_{asset_id}_{name}_thumb.jpg"
|
||||
uid = uuid4().hex[:8]
|
||||
thumb_name = f"{type_}/thumbs/{ts}_{uid}_{asset_id}_{name}_thumb.jpg"
|
||||
if await self.s3.upload_file(thumb_name, thumb):
|
||||
await self.collection.update_one(
|
||||
{"_id": asset_id},
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from bson import ObjectId
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
@@ -12,32 +12,37 @@ class CharacterRepo:
|
||||
|
||||
async def add_character(self, character: Character) -> Character:
|
||||
op = await self.collection.insert_one(character.model_dump())
|
||||
character.id = op.inserted_id
|
||||
character.id = str(op.inserted_id)
|
||||
return character
|
||||
|
||||
async def get_character(self, character_id: str, with_image_data: bool = False) -> Character | None:
|
||||
args = {}
|
||||
if not with_image_data:
|
||||
args["character_image_data"] = 0
|
||||
res = await self.collection.find_one({"_id": ObjectId(character_id)}, args)
|
||||
async def get_character(self, character_id: str) -> Character | None:
|
||||
res = await self.collection.find_one({"_id": ObjectId(character_id)})
|
||||
if res is None:
|
||||
return None
|
||||
else:
|
||||
res["id"] = str(res.pop("_id"))
|
||||
return Character(**res)
|
||||
|
||||
async def get_all_characters(self) -> List[Character]:
|
||||
docs = await self.collection.find({}, {"character_image_data": 0}).to_list(None)
|
||||
|
||||
characters = []
|
||||
for doc in docs:
|
||||
# Конвертируем ObjectId в строку и кладем в поле id
|
||||
async def get_all_characters(self, created_by: Optional[str] = None, project_id: Optional[str] = None, limit: int = 100, offset: int = 0) -> List[Character]:
|
||||
filter = {}
|
||||
if created_by:
|
||||
filter["created_by"] = created_by
|
||||
if project_id is None:
|
||||
filter["project_id"] = None
|
||||
if project_id:
|
||||
filter["project_id"] = project_id
|
||||
|
||||
res = await self.collection.find(filter).skip(offset).limit(limit).to_list(None)
|
||||
chars = []
|
||||
for doc in res:
|
||||
doc["id"] = str(doc.pop("_id"))
|
||||
chars.append(Character(**doc))
|
||||
return chars
|
||||
|
||||
# Создаем объект
|
||||
characters.append(Character(**doc))
|
||||
async def update_char(self, char_id: str, character: Character) -> bool:
|
||||
result = await self.collection.update_one({"_id": ObjectId(char_id)}, {"$set": character.model_dump()})
|
||||
return result.modified_count > 0
|
||||
|
||||
return characters
|
||||
|
||||
async def update_char(self, char_id: str, character: Character) -> None:
|
||||
await self.collection.update_one({"_id": ObjectId(char_id)}, {"$set": character.model_dump()})
|
||||
async def delete_character(self, char_id: str) -> bool:
|
||||
result = await self.collection.delete_one({"_id": ObjectId(char_id)})
|
||||
return result.deleted_count > 0
|
||||
|
||||
11
repos/dao.py
11
repos/dao.py
@@ -4,6 +4,11 @@ from repos.assets_repo import AssetsRepo
|
||||
from repos.char_repo import CharacterRepo
|
||||
from repos.generation_repo import GenerationRepo
|
||||
from repos.user_repo import UsersRepo
|
||||
from repos.albums_repo import AlbumsRepo
|
||||
from repos.project_repo import ProjectRepo
|
||||
from repos.idea_repo import IdeaRepo
|
||||
from repos.post_repo import PostRepo
|
||||
from repos.environment_repo import EnvironmentRepo
|
||||
|
||||
|
||||
from typing import Optional
|
||||
@@ -14,3 +19,9 @@ class DAO:
|
||||
self.chars = CharacterRepo(client, db_name)
|
||||
self.assets = AssetsRepo(client, s3_adapter, db_name)
|
||||
self.generations = GenerationRepo(client, db_name)
|
||||
self.albums = AlbumsRepo(client, db_name)
|
||||
self.projects = ProjectRepo(client, db_name)
|
||||
self.users = UsersRepo(client, db_name)
|
||||
self.ideas = IdeaRepo(client, db_name)
|
||||
self.posts = PostRepo(client, db_name)
|
||||
self.environments = EnvironmentRepo(client, db_name)
|
||||
|
||||
73
repos/environment_repo.py
Normal file
73
repos/environment_repo.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from bson import ObjectId
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from models.Environment import Environment
|
||||
|
||||
|
||||
class EnvironmentRepo:
|
||||
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
|
||||
self.collection = client[db_name]["environments"]
|
||||
|
||||
async def create_env(self, env: Environment) -> Environment:
|
||||
env_dict = env.model_dump(exclude={"id"})
|
||||
res = await self.collection.insert_one(env_dict)
|
||||
env.id = str(res.inserted_id)
|
||||
return env
|
||||
|
||||
async def get_env(self, env_id: str) -> Optional[Environment]:
|
||||
res = await self.collection.find_one({"_id": ObjectId(env_id)})
|
||||
if not res:
|
||||
return None
|
||||
res["id"] = str(res.pop("_id"))
|
||||
return Environment(**res)
|
||||
|
||||
async def get_character_envs(self, character_id: str) -> List[Environment]:
|
||||
cursor = self.collection.find({"character_id": character_id})
|
||||
envs = []
|
||||
async for doc in cursor:
|
||||
doc["id"] = str(doc.pop("_id"))
|
||||
envs.append(Environment(**doc))
|
||||
return envs
|
||||
|
||||
async def update_env(self, env_id: str, update_data: dict) -> bool:
|
||||
update_data["updated_at"] = datetime.utcnow()
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(env_id)},
|
||||
{"$set": update_data}
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def delete_env(self, env_id: str) -> bool:
|
||||
res = await self.collection.delete_one({"_id": ObjectId(env_id)})
|
||||
return res.deleted_count > 0
|
||||
|
||||
async def add_asset(self, env_id: str, asset_id: str) -> bool:
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(env_id)},
|
||||
{
|
||||
"$addToSet": {"asset_ids": asset_id},
|
||||
"$set": {"updated_at": datetime.utcnow()}
|
||||
}
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def add_assets(self, env_id: str, asset_ids: List[str]) -> bool:
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(env_id)},
|
||||
{
|
||||
"$addToSet": {"asset_ids": {"$each": asset_ids}},
|
||||
"$set": {"updated_at": datetime.utcnow()}
|
||||
}
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def remove_asset(self, env_id: str, asset_id: str) -> bool:
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(env_id)},
|
||||
{
|
||||
"$pull": {"asset_ids": asset_id},
|
||||
"$set": {"updated_at": datetime.utcnow()}
|
||||
}
|
||||
)
|
||||
return res.modified_count > 0
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Optional, List
|
||||
from typing import Any, Optional, List
|
||||
from datetime import datetime, timedelta, UTC
|
||||
|
||||
from PIL.ImageChops import offset
|
||||
from bson import ObjectId
|
||||
@@ -16,7 +17,7 @@ class GenerationRepo:
|
||||
res = await self.collection.insert_one(generation.model_dump())
|
||||
return str(res.inserted_id)
|
||||
|
||||
async def get_generation(self, generation_id: str) -> Optional[Generation]:
|
||||
async def get_generation(self, generation_id: str) -> Generation | None:
|
||||
res = await self.collection.find_one({"_id": ObjectId(generation_id)})
|
||||
if res is None:
|
||||
return None
|
||||
@@ -25,14 +26,29 @@ class GenerationRepo:
|
||||
return Generation(**res)
|
||||
|
||||
async def get_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
|
||||
limit: int = 10, offset: int = 10) -> List[Generation]:
|
||||
limit: int = 10, offset: int = 0, created_by: Optional[str] = None, project_id: Optional[str] = None, idea_id: Optional[str] = None) -> List[Generation]:
|
||||
|
||||
filter = {"is_deleted": False}
|
||||
filter: dict[str, Any] = {"is_deleted": False}
|
||||
if character_id is not None:
|
||||
filter["linked_character_id"] = character_id
|
||||
if status is not None:
|
||||
filter["status"] = status
|
||||
res = await self.collection.find(filter).sort("created_at", -1).skip(
|
||||
if created_by is not None:
|
||||
filter["created_by"] = created_by
|
||||
# If filtering by created_by user (e.g. "My Generations"), we typically imply personal scope if project_id is None.
|
||||
# But if project_id is passed, we filter by that.
|
||||
if project_id is None:
|
||||
filter["project_id"] = None
|
||||
if project_id is not None:
|
||||
filter["project_id"] = project_id
|
||||
if idea_id is not None:
|
||||
filter["idea_id"] = idea_id
|
||||
|
||||
# If fetching for an idea, sort by created_at ascending (cronological)
|
||||
# Otherwise typically descending (newest first)
|
||||
sort_order = 1 if idea_id else -1
|
||||
|
||||
res = await self.collection.find(filter).sort("created_at", sort_order).skip(
|
||||
offset).limit(limit).to_list(None)
|
||||
generations: List[Generation] = []
|
||||
for generation in res:
|
||||
@@ -40,13 +56,216 @@ class GenerationRepo:
|
||||
generations.append(Generation(**generation))
|
||||
return generations
|
||||
|
||||
async def count_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None) -> int:
|
||||
async def count_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
|
||||
album_id: Optional[str] = None, created_by: Optional[str] = None, project_id: Optional[str] = None, idea_id: Optional[str] = None) -> int:
|
||||
args = {}
|
||||
if character_id is not None:
|
||||
args["linked_character_id"] = character_id
|
||||
if status is not None:
|
||||
args["status"] = status
|
||||
if created_by is not None:
|
||||
args["created_by"] = created_by
|
||||
if project_id is None:
|
||||
args["project_id"] = None
|
||||
if project_id is not None:
|
||||
args["project_id"] = project_id
|
||||
if idea_id is not None:
|
||||
args["idea_id"] = idea_id
|
||||
if album_id is not None:
|
||||
args["album_id"] = album_id
|
||||
return await self.collection.count_documents(args)
|
||||
|
||||
async def get_generations_by_ids(self, generation_ids: List[str]) -> List[Generation]:
|
||||
object_ids = [ObjectId(gen_id) for gen_id in generation_ids if ObjectId.is_valid(gen_id)]
|
||||
res = await self.collection.find({"_id": {"$in": object_ids}}).to_list(None)
|
||||
generations: List[Generation] = []
|
||||
|
||||
# Maintain order of generation_ids
|
||||
gen_map = {str(doc["_id"]): doc for doc in res}
|
||||
|
||||
for gen_id in generation_ids:
|
||||
doc = gen_map.get(gen_id)
|
||||
if doc:
|
||||
doc["id"] = str(doc.pop("_id"))
|
||||
generations.append(Generation(**doc))
|
||||
|
||||
return generations
|
||||
|
||||
async def update_generation(self, generation: Generation, ):
|
||||
res = await self.collection.update_one({"_id": ObjectId(generation.id)}, {"$set": generation.model_dump()})
|
||||
|
||||
async def get_usage_stats(self, created_by: Optional[str] = None, project_id: Optional[str] = None) -> dict:
|
||||
"""
|
||||
Calculates usage statistics (runs, tokens, cost) using MongoDB aggregation.
|
||||
Includes even soft-deleted generations to reflect actual expenditure.
|
||||
"""
|
||||
pipeline = []
|
||||
|
||||
# 1. Match all done generations (including soft-deleted)
|
||||
match_stage = {"status": GenerationStatus.DONE}
|
||||
if created_by:
|
||||
match_stage["created_by"] = created_by
|
||||
if project_id:
|
||||
match_stage["project_id"] = project_id
|
||||
|
||||
pipeline.append({"$match": match_stage})
|
||||
|
||||
# 2. Group by null (total)
|
||||
pipeline.append({
|
||||
"$group": {
|
||||
"_id": None,
|
||||
"total_runs": {"$sum": 1},
|
||||
"total_tokens": {
|
||||
"$sum": {
|
||||
"$cond": [
|
||||
{"$and": [{"$gt": ["$input_token_usage", 0]}, {"$gt": ["$output_token_usage", 0]}]},
|
||||
{"$add": ["$input_token_usage", "$output_token_usage"]},
|
||||
{"$ifNull": ["$token_usage", 0]}
|
||||
]
|
||||
}
|
||||
},
|
||||
"total_input_tokens": {"$sum": {"$ifNull": ["$input_token_usage", 0]}},
|
||||
"total_output_tokens": {"$sum": {"$ifNull": ["$output_token_usage", 0]}},
|
||||
"total_cost": {
|
||||
"$sum": {
|
||||
"$add": [
|
||||
{"$multiply": [{"$ifNull": ["$input_token_usage", 0]}, 0.000002]},
|
||||
{"$multiply": [{"$ifNull": ["$output_token_usage", 0]}, 0.00012]}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
cursor = self.collection.aggregate(pipeline)
|
||||
res = await cursor.to_list(1)
|
||||
|
||||
if not res:
|
||||
return {
|
||||
"total_runs": 0,
|
||||
"total_tokens": 0,
|
||||
"total_input_tokens": 0,
|
||||
"total_output_tokens": 0,
|
||||
"total_cost": 0.0
|
||||
}
|
||||
|
||||
result = res[0]
|
||||
result.pop("_id")
|
||||
result["total_cost"] = round(result["total_cost"], 4)
|
||||
return result
|
||||
|
||||
async def get_usage_breakdown(self, group_by: str = "created_by", project_id: Optional[str] = None, created_by: Optional[str] = None) -> List[dict]:
|
||||
"""
|
||||
Returns usage statistics grouped by user or project.
|
||||
Includes even soft-deleted generations to reflect actual expenditure.
|
||||
"""
|
||||
pipeline = []
|
||||
|
||||
match_stage = {"status": GenerationStatus.DONE}
|
||||
if project_id:
|
||||
match_stage["project_id"] = project_id
|
||||
if created_by:
|
||||
match_stage["created_by"] = created_by
|
||||
|
||||
pipeline.append({"$match": match_stage})
|
||||
|
||||
pipeline.append({
|
||||
"$group": {
|
||||
"_id": f"${group_by}",
|
||||
"total_runs": {"$sum": 1},
|
||||
"total_tokens": {
|
||||
"$sum": {
|
||||
"$cond": [
|
||||
{"$and": [{"$gt": ["$input_token_usage", 0]}, {"$gt": ["$output_token_usage", 0]}]},
|
||||
{"$add": ["$input_token_usage", "$output_token_usage"]},
|
||||
{"$ifNull": ["$token_usage", 0]}
|
||||
]
|
||||
}
|
||||
},
|
||||
"total_input_tokens": {"$sum": {"$ifNull": ["$input_token_usage", 0]}},
|
||||
"total_output_tokens": {"$sum": {"$ifNull": ["$output_token_usage", 0]}},
|
||||
"total_cost": {
|
||||
"$sum": {
|
||||
"$add": [
|
||||
{"$multiply": [{"$ifNull": ["$input_token_usage", 0]}, 0.000002]},
|
||||
{"$multiply": [{"$ifNull": ["$output_token_usage", 0]}, 0.00012]}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
pipeline.append({"$sort": {"total_cost": -1}})
|
||||
|
||||
cursor = self.collection.aggregate(pipeline)
|
||||
res = await cursor.to_list(None)
|
||||
|
||||
results = []
|
||||
for item in res:
|
||||
entity_id = item.pop("_id")
|
||||
item["total_cost"] = round(item["total_cost"], 4)
|
||||
results.append({
|
||||
"entity_id": str(entity_id) if entity_id else "unknown",
|
||||
"stats": item
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
async def get_generations_by_group(self, group_id: str) -> List[Generation]:
|
||||
res = await self.collection.find({"generation_group_id": group_id, "is_deleted": False}).sort("created_at", 1).to_list(None)
|
||||
generations: List[Generation] = []
|
||||
for generation in res:
|
||||
generation["id"] = str(generation.pop("_id"))
|
||||
generations.append(Generation(**generation))
|
||||
return generations
|
||||
|
||||
async def cancel_stale_generations(self, timeout_minutes: int = 5) -> int:
|
||||
cutoff_time = datetime.now(UTC) - timedelta(minutes=timeout_minutes)
|
||||
res = await self.collection.update_many(
|
||||
{
|
||||
"status": GenerationStatus.RUNNING,
|
||||
"created_at": {"$lt": cutoff_time}
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
"status": GenerationStatus.FAILED,
|
||||
"failed_reason": "Timeout: Execution time limit exceeded",
|
||||
"updated_at": datetime.now(UTC)
|
||||
}
|
||||
}
|
||||
)
|
||||
return res.modified_count
|
||||
|
||||
async def soft_delete_old_generations(self, days: int = 2) -> tuple[int, List[str]]:
|
||||
"""
|
||||
Мягко удаляет генерации старше N дней.
|
||||
Возвращает (количество удалённых, список asset IDs для очистки).
|
||||
"""
|
||||
cutoff_time = datetime.now(UTC) - timedelta(days=days)
|
||||
filter_query = {
|
||||
"is_deleted": False,
|
||||
"status": {"$in": [GenerationStatus.DONE, GenerationStatus.FAILED]},
|
||||
"created_at": {"$lt": cutoff_time}
|
||||
}
|
||||
|
||||
# Сначала собираем asset IDs из удаляемых генераций
|
||||
asset_ids: List[str] = []
|
||||
cursor = self.collection.find(filter_query, {"result_list": 1, "assets_list": 1})
|
||||
async for doc in cursor:
|
||||
asset_ids.extend(doc.get("result_list", []))
|
||||
asset_ids.extend(doc.get("assets_list", []))
|
||||
|
||||
# Мягкое удаление
|
||||
res = await self.collection.update_many(
|
||||
filter_query,
|
||||
{
|
||||
"$set": {
|
||||
"is_deleted": True,
|
||||
"updated_at": datetime.now(UTC)
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# Убираем дубликаты
|
||||
unique_asset_ids = list(set(asset_ids))
|
||||
return res.modified_count, unique_asset_ids
|
||||
|
||||
91
repos/idea_repo.py
Normal file
91
repos/idea_repo.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from typing import Optional, List
|
||||
from bson import ObjectId
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from models.Idea import Idea
|
||||
|
||||
class IdeaRepo:
|
||||
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
|
||||
self.collection = client[db_name]["ideas"]
|
||||
|
||||
async def create_idea(self, idea: Idea) -> str:
|
||||
res = await self.collection.insert_one(idea.model_dump())
|
||||
return str(res.inserted_id)
|
||||
|
||||
async def get_idea(self, idea_id: str) -> Optional[Idea]:
|
||||
if not ObjectId.is_valid(idea_id):
|
||||
return None
|
||||
res = await self.collection.find_one({"_id": ObjectId(idea_id)})
|
||||
if res:
|
||||
res["id"] = str(res.pop("_id"))
|
||||
return Idea(**res)
|
||||
return None
|
||||
|
||||
async def get_ideas(self, project_id: Optional[str], user_id: str, limit: int = 20, offset: int = 0) -> List[dict]:
|
||||
if project_id:
|
||||
match_stage = {"project_id": project_id, "is_deleted": False}
|
||||
else:
|
||||
match_stage = {"created_by": user_id, "project_id": None, "is_deleted": False}
|
||||
|
||||
pipeline = [
|
||||
{"$match": match_stage},
|
||||
{"$sort": {"updated_at": -1}},
|
||||
{"$skip": offset},
|
||||
{"$limit": limit},
|
||||
# Add string id field for lookup
|
||||
{"$addFields": {"str_id": {"$toString": "$_id"}}},
|
||||
# Lookup generations
|
||||
{
|
||||
"$lookup": {
|
||||
"from": "generations",
|
||||
"let": {"idea_id": "$str_id"},
|
||||
"pipeline": [
|
||||
{
|
||||
"$match": {
|
||||
"$and": [
|
||||
{"$expr": {"$eq": ["$idea_id", "$$idea_id"]}},
|
||||
{"status": "done"},
|
||||
{"result_list": {"$exists": True, "$not": {"$size": 0}}},
|
||||
{"is_deleted": False}
|
||||
]
|
||||
}
|
||||
},
|
||||
{"$sort": {"created_at": -1}}, # Ensure we get the latest successful
|
||||
{"$limit": 1}
|
||||
],
|
||||
"as": "generations"
|
||||
}
|
||||
},
|
||||
# Unwind generations array (preserve ideas without generations)
|
||||
{"$unwind": {"path": "$generations", "preserveNullAndEmptyArrays": True}},
|
||||
# Rename for clarity
|
||||
{"$addFields": {
|
||||
"last_generation": "$generations",
|
||||
"id": "$str_id"
|
||||
}},
|
||||
{"$project": {"generations": 0, "str_id": 0, "_id": 0}}
|
||||
]
|
||||
|
||||
return await self.collection.aggregate(pipeline).to_list(None)
|
||||
|
||||
async def delete_idea(self, idea_id: str) -> bool:
|
||||
if not ObjectId.is_valid(idea_id):
|
||||
return False
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(idea_id)},
|
||||
{"$set": {"is_deleted": True}}
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def update_idea(self, idea: Idea) -> bool:
|
||||
if not idea.id or not ObjectId.is_valid(idea.id):
|
||||
return False
|
||||
|
||||
idea_dict = idea.model_dump()
|
||||
if "id" in idea_dict:
|
||||
del idea_dict["id"]
|
||||
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(idea.id)},
|
||||
{"$set": idea_dict}
|
||||
)
|
||||
return res.modified_count > 0
|
||||
97
repos/post_repo.py
Normal file
97
repos/post_repo.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
import logging
|
||||
from bson import ObjectId
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
|
||||
from models.Post import Post
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PostRepo:
|
||||
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
|
||||
self.collection = client[db_name]["posts"]
|
||||
|
||||
async def create_post(self, post: Post) -> str:
|
||||
res = await self.collection.insert_one(post.model_dump())
|
||||
return str(res.inserted_id)
|
||||
|
||||
async def get_post(self, post_id: str) -> Optional[Post]:
|
||||
if not ObjectId.is_valid(post_id):
|
||||
return None
|
||||
res = await self.collection.find_one({"_id": ObjectId(post_id), "is_deleted": False})
|
||||
if res:
|
||||
res["id"] = str(res.pop("_id"))
|
||||
return Post(**res)
|
||||
return None
|
||||
|
||||
async def get_posts(
|
||||
self,
|
||||
project_id: Optional[str],
|
||||
user_id: str,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
date_from: Optional[datetime] = None,
|
||||
date_to: Optional[datetime] = None,
|
||||
) -> List[Post]:
|
||||
if project_id:
|
||||
match = {"project_id": project_id, "is_deleted": False}
|
||||
else:
|
||||
match = {"created_by": user_id, "project_id": None, "is_deleted": False}
|
||||
|
||||
if date_from or date_to:
|
||||
date_filter = {}
|
||||
if date_from:
|
||||
date_filter["$gte"] = date_from
|
||||
if date_to:
|
||||
date_filter["$lte"] = date_to
|
||||
match["date"] = date_filter
|
||||
|
||||
cursor = (
|
||||
self.collection.find(match)
|
||||
.sort("date", -1)
|
||||
.skip(offset)
|
||||
.limit(limit)
|
||||
)
|
||||
posts = []
|
||||
async for doc in cursor:
|
||||
doc["id"] = str(doc.pop("_id"))
|
||||
posts.append(Post(**doc))
|
||||
return posts
|
||||
|
||||
async def update_post(self, post_id: str, data: dict) -> bool:
|
||||
if not ObjectId.is_valid(post_id):
|
||||
return False
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(post_id)},
|
||||
{"$set": data},
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def delete_post(self, post_id: str) -> bool:
|
||||
if not ObjectId.is_valid(post_id):
|
||||
return False
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(post_id)},
|
||||
{"$set": {"is_deleted": True}},
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def add_generations(self, post_id: str, generation_ids: List[str]) -> bool:
|
||||
if not ObjectId.is_valid(post_id):
|
||||
return False
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(post_id)},
|
||||
{"$addToSet": {"generation_ids": {"$each": generation_ids}}},
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def remove_generation(self, post_id: str, generation_id: str) -> bool:
|
||||
if not ObjectId.is_valid(post_id):
|
||||
return False
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(post_id)},
|
||||
{"$pull": {"generation_ids": generation_id}},
|
||||
)
|
||||
return res.modified_count > 0
|
||||
62
repos/project_repo.py
Normal file
62
repos/project_repo.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from typing import List, Optional
|
||||
from bson import ObjectId
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from models.Project import Project
|
||||
|
||||
class ProjectRepo:
|
||||
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
|
||||
self.collection = client[db_name]["projects"]
|
||||
|
||||
async def create_project(self, project: Project) -> str:
|
||||
res = await self.collection.insert_one(project.model_dump())
|
||||
return str(res.inserted_id)
|
||||
|
||||
async def get_project(self, project_id: str) -> Optional[Project]:
|
||||
if not ObjectId.is_valid(project_id):
|
||||
return None
|
||||
res = await self.collection.find_one({"_id": ObjectId(project_id)})
|
||||
if res:
|
||||
res["id"] = str(res.pop("_id"))
|
||||
return Project(**res)
|
||||
return None
|
||||
|
||||
async def get_projects_by_user(self, user_id: str) -> List[Project]:
|
||||
# Find projects where user is owner OR in members
|
||||
filter = {
|
||||
"$or": [
|
||||
{"owner_id": user_id},
|
||||
{"members": user_id}
|
||||
],
|
||||
"is_deleted": False
|
||||
}
|
||||
cursor = self.collection.find(filter).sort("created_at", -1)
|
||||
projects = []
|
||||
async for doc in cursor:
|
||||
doc["id"] = str(doc.pop("_id"))
|
||||
projects.append(Project(**doc))
|
||||
return projects
|
||||
|
||||
async def add_member(self, project_id: str, user_id: str) -> bool:
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(project_id)},
|
||||
{"$addToSet": {"members": user_id}}
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def remove_member(self, project_id: str, user_id: str) -> bool:
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(project_id)},
|
||||
{"$pull": {"members": user_id}}
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def update_project(self, project_id: str, updates: dict) -> bool:
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(project_id)},
|
||||
{"$set": updates}
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def delete_project(self, project_id: str) -> bool:
|
||||
res = await self.collection.update_one({"_id": ObjectId(project_id)}, {"$set": {"is_deleted": True}})
|
||||
return res.modified_count > 0
|
||||
@@ -19,10 +19,16 @@ class UsersRepo:
|
||||
self.collection = client[db_name]["users"]
|
||||
|
||||
async def get_user(self, user_id: int):
|
||||
return await self.collection.find_one({"user_id": user_id})
|
||||
user = await self.collection.find_one({"user_id": user_id})
|
||||
if user:
|
||||
user["id"] = str(user["_id"])
|
||||
return user
|
||||
|
||||
async def get_user_by_username(self, username: str):
|
||||
return await self.collection.find_one({"username": username})
|
||||
user = await self.collection.find_one({"username": username})
|
||||
if user:
|
||||
user["id"] = str(user["_id"])
|
||||
return user
|
||||
|
||||
async def create_user(self, username: str, password: str, full_name: Optional[str] = None):
|
||||
"""Создает нового пользователя с username/паролем"""
|
||||
@@ -38,15 +44,23 @@ class UsersRepo:
|
||||
"created_at": datetime.now(),
|
||||
"is_email_user": False, # Теперь это просто "обычный" юзер, не телеграм (хотя поле можно переименовать)
|
||||
"is_web_user": True,
|
||||
"is_admin": False
|
||||
"is_admin": False,
|
||||
"project_ids": [],
|
||||
"current_project_id": None
|
||||
}
|
||||
result = await self.collection.insert_one(user_doc)
|
||||
return await self.collection.find_one({"_id": result.inserted_id})
|
||||
user = await self.collection.find_one({"_id": result.inserted_id})
|
||||
if user:
|
||||
user["id"] = str(user["_id"])
|
||||
return user
|
||||
|
||||
async def get_pending_users(self):
|
||||
"""Возвращает список пользователей со статусом PENDING"""
|
||||
cursor = self.collection.find({"status": UserStatus.PENDING})
|
||||
return await cursor.to_list(length=100)
|
||||
users = await cursor.to_list(length=100)
|
||||
for user in users:
|
||||
user["id"] = str(user["_id"])
|
||||
return users
|
||||
|
||||
async def approve_user(self, username: str):
|
||||
await self.collection.update_one(
|
||||
|
||||
@@ -50,3 +50,5 @@ passlib[argon2]==1.7.4
|
||||
python-jose[cryptography]==3.3.0
|
||||
python-multipart==0.0.22
|
||||
email-validator
|
||||
prometheus-fastapi-instrumentator
|
||||
pydantic-settings==2.13.0
|
||||
|
||||
Binary file not shown.
Binary file not shown.
@@ -51,56 +51,66 @@ async def new_char_bio(message: Message, state: FSMContext, dao: DAO, bot: Bot):
|
||||
wait_msg = await message.answer("💾 Сохраняю персонажа...")
|
||||
|
||||
try:
|
||||
# ВОТ ТУТ скачиваем файл (прямо перед сохранением)
|
||||
# 1. Скачиваем файл (один раз)
|
||||
# TODO: Для больших файлов лучше использовать streaming или сохранять во временный файл
|
||||
file_io = await bot.download(file_id)
|
||||
# photo_bytes = file_io.getvalue() # Получаем байты
|
||||
|
||||
|
||||
# Создаем модель
|
||||
file_bytes = file_io.read()
|
||||
|
||||
# 2. Создаем Character (сначала без ассета, чтобы получить ID)
|
||||
char = Character(
|
||||
id=None,
|
||||
name=name,
|
||||
character_image_data=file_io.read(),
|
||||
character_image_tg_id=None,
|
||||
character_image_doc_tg_id=file_id,
|
||||
character_bio=bio
|
||||
character_bio=bio,
|
||||
created_by=str(message.from_user.id)
|
||||
)
|
||||
file_io.close()
|
||||
|
||||
# Сохраняем через DAO
|
||||
|
||||
|
||||
# Сохраняем, чтобы получить ID
|
||||
await dao.chars.add_character(char)
|
||||
file_info = await bot.get_file(char.character_image_doc_tg_id)
|
||||
file_bytes = await bot.download_file(file_info.file_path)
|
||||
file_io = file_bytes.read()
|
||||
avatar_asset = await dao.assets.create_asset(
|
||||
Asset(name="avatar.png", type=AssetType.UPLOADED, content_type=AssetContentType.IMAGE, linked_char_id=str(char.id), data=file_io,
|
||||
tg_doc_file_id=file_id))
|
||||
char.avatar_image = avatar_asset.link
|
||||
|
||||
# 3. Создаем Asset (связанный с персонажем)
|
||||
avatar_asset_id = await dao.assets.create_asset(
|
||||
Asset(
|
||||
name="avatar.png",
|
||||
type=AssetType.UPLOADED,
|
||||
content_type=AssetContentType.IMAGE,
|
||||
linked_char_id=str(char.id),
|
||||
data=file_bytes,
|
||||
tg_doc_file_id=file_id
|
||||
)
|
||||
)
|
||||
|
||||
# 4. Обновляем персонажа ссылками на ассет
|
||||
char.avatar_asset_id = avatar_asset_id
|
||||
char.avatar_image = f"/api/assets/{avatar_asset_id}" # Формируем ссылку вручную или используем метод, если появится
|
||||
|
||||
# Отправляем подтверждение
|
||||
# Используем байты для отправки обратно
|
||||
photo_msg = await message.answer_photo(
|
||||
photo=BufferedInputFile(file_io,
|
||||
filename="char.jpg") if not char.character_image_tg_id else char.character_image_tg_id,
|
||||
photo=BufferedInputFile(file_bytes, filename="char.jpg"),
|
||||
caption=(
|
||||
"🎉 <b>Персонаж создан!</b>\n\n"
|
||||
f"👤 <b>Имя:</b> {char.name}\n"
|
||||
f"📝 <b>Био:</b> {char.character_bio}"
|
||||
)
|
||||
)
|
||||
file_bytes.close()
|
||||
char.character_image_tg_id = photo_msg.photo[0].file_id
|
||||
|
||||
# Сохраняем TG ID фото (которое отправили как фото, а не документ)
|
||||
char.character_image_tg_id = photo_msg.photo[-1].file_id
|
||||
|
||||
# Финальное обновление персонажа
|
||||
await dao.chars.update_char(char.id, char)
|
||||
|
||||
await wait_msg.delete()
|
||||
file_io.close()
|
||||
|
||||
# Сбрасываем состояние
|
||||
await state.clear()
|
||||
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
logger.error(f"Error creating character: {e}")
|
||||
traceback.print_exc()
|
||||
await wait_msg.edit_text(f"❌ Ошибка при сохранении: {e}")
|
||||
# Не сбрасываем стейт, даем возможность попробовать ввести био снова или начать заново
|
||||
|
||||
|
||||
@router.message(Command("chars"))
|
||||
|
||||
@@ -51,7 +51,7 @@ async def generate_image_cmd(message: Message, state: FSMContext, dao: DAO, gemi
|
||||
await wait_msg.delete()
|
||||
doc = await message.answer_document(res[0], caption="Generated result 💫")
|
||||
await dao.assets.create_asset(Asset(id=None, name=res[0].filename, type=AssetType.GENERATED, content_type=AssetContentType.IMAGE, data=res[0].data,
|
||||
tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None, linked_char_id=None))
|
||||
tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None, linked_char_id=None, created_by=str(message.from_user.id)))
|
||||
|
||||
|
||||
@router.message(Command("gen_mode"))
|
||||
@@ -259,7 +259,8 @@ async def handle_album(
|
||||
doc = await message.answer_document(file, caption="✨ Generated result")
|
||||
await dao.assets.create_asset(Asset(id=None, name=file.filename, type=AssetType.GENERATED, content_type=AssetContentType.IMAGE, data=file.data,
|
||||
tg_doc_file_id = doc.document.file_id, tg_photo_file_id = None,
|
||||
linked_char_id = data["char_id"]))
|
||||
linked_char_id = data["char_id"],
|
||||
created_by=str(message.from_user.id)))
|
||||
else:
|
||||
await message.answer("❌ Генерация не вернула изображений.")
|
||||
await gen_mode_base_msg(message=message, state=state, dao=dao, call_type="start")
|
||||
@@ -314,7 +315,8 @@ async def gen_mode_start(
|
||||
doc = await message.answer_document(file, caption="✨ Generated result")
|
||||
await dao.assets.create_asset(Asset(id=None, name=file.filename, type=AssetType.GENERATED, content_type=AssetContentType.IMAGE, data=file.data,
|
||||
tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None,
|
||||
linked_char_id=data["char_id"]))
|
||||
linked_char_id=data["char_id"],
|
||||
created_by=str(message.from_user.id)))
|
||||
|
||||
else:
|
||||
await message.answer("❌ Ничего не сгенерировалось.")
|
||||
|
||||
101
tests/test_character_crud.py
Normal file
101
tests/test_character_crud.py
Normal file
@@ -0,0 +1,101 @@
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
import asyncio
|
||||
from config import settings
|
||||
|
||||
from aiws import app
|
||||
from api.endpoints.auth import get_current_user
|
||||
from api.dependency import get_dao
|
||||
from repos.dao import DAO
|
||||
from models.Character import Character
|
||||
|
||||
# Config for test DB
|
||||
MONGO_HOST = settings.MONGO_HOST
|
||||
DB_NAME = "bot_db_test_chars"
|
||||
|
||||
# Mock User
|
||||
MOCK_USER_ID = "507f1f77bcf86cd799439011"
|
||||
MOCK_USER = {
|
||||
"_id": MOCK_USER_ID,
|
||||
"username": "testuser",
|
||||
"is_admin": False,
|
||||
"status": "allowed"
|
||||
}
|
||||
|
||||
# Override get_current_user to bypass auth
|
||||
def mock_get_current_user():
|
||||
return MOCK_USER
|
||||
|
||||
app.dependency_overrides[get_current_user] = mock_get_current_user
|
||||
|
||||
# Setup Real DAO with Test DB
|
||||
client_mongo = AsyncIOMotorClient(MONGO_HOST)
|
||||
dao = DAO(client_mongo, db_name=DB_NAME)
|
||||
|
||||
def mock_get_dao():
|
||||
return dao
|
||||
|
||||
app.dependency_overrides[get_dao] = mock_get_dao
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def setup_teardown():
|
||||
# Setup: Ensure clean state
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_until_complete(client_mongo[DB_NAME]["characters"].drop())
|
||||
|
||||
yield
|
||||
|
||||
# Teardown
|
||||
loop.run_until_complete(client_mongo[DB_NAME]["characters"].drop())
|
||||
loop.close()
|
||||
|
||||
def test_character_crud_flow():
|
||||
# 1. Create Character
|
||||
create_payload = {
|
||||
"name": "Test Character",
|
||||
"character_bio": "A bio for test character",
|
||||
"character_image_doc_tg_id": "file_123",
|
||||
"avatar_image": "http://example.com/avatar.jpg"
|
||||
}
|
||||
|
||||
response = client.post("/api/characters/", json=create_payload)
|
||||
assert response.status_code == 200, response.text
|
||||
char_data = response.json()
|
||||
assert char_data["name"] == create_payload["name"]
|
||||
assert char_data["created_by"] == MOCK_USER_ID
|
||||
char_id = char_data["id"]
|
||||
assert char_id is not None
|
||||
|
||||
# 2. Get Character
|
||||
response = client.get(f"/api/characters/{char_id}")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["id"] == char_id
|
||||
|
||||
# 3. Update Character
|
||||
update_payload = {
|
||||
"name": "Updated Name",
|
||||
"character_bio": "Updated bio"
|
||||
}
|
||||
response = client.put(f"/api/characters/{char_id}", json=update_payload)
|
||||
assert response.status_code == 200
|
||||
updated_data = response.json()
|
||||
assert updated_data["name"] == "Updated Name"
|
||||
assert updated_data["character_bio"] == "Updated bio"
|
||||
|
||||
# Verify update persistent
|
||||
response = client.get(f"/api/characters/{char_id}")
|
||||
assert response.json()["name"] == "Updated Name"
|
||||
|
||||
# 4. Delete Character
|
||||
response = client.delete(f"/api/characters/{char_id}")
|
||||
assert response.status_code == 204
|
||||
|
||||
# Verify deletion
|
||||
response = client.get(f"/api/characters/{char_id}")
|
||||
assert response.status_code == 404, "Deleted character should return 404"
|
||||
64
tests/test_character_integration.py
Normal file
64
tests/test_character_integration.py
Normal file
@@ -0,0 +1,64 @@
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
# 1. Set Auth Bypass and Test Config
|
||||
os.environ["DB_NAME"] = "bot_db_test_integration"
|
||||
# We keep MONGO_HOST as is (it works in verified script)
|
||||
|
||||
# 2. Import app AFTER setting env
|
||||
from main import app
|
||||
from api.endpoints.auth import get_current_user
|
||||
|
||||
# 3. Override Auth
|
||||
MOCK_USER_ID = "507f1f77bcf86cd799439011"
|
||||
MOCK_USER = {
|
||||
"_id": MOCK_USER_ID,
|
||||
"username": "testuser",
|
||||
"is_admin": False,
|
||||
"status": "allowed",
|
||||
"project_ids": []
|
||||
}
|
||||
|
||||
def mock_get_current_user():
|
||||
return MOCK_USER
|
||||
|
||||
app.dependency_overrides[get_current_user] = mock_get_current_user
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
def test_character_crud_lifecycle():
|
||||
# 1. Create
|
||||
create_payload = {
|
||||
"name": "Integration Test Char",
|
||||
"character_bio": "Testing with real app structure",
|
||||
"character_image_doc_tg_id": "doc_123",
|
||||
"avatar_image": "http://example.com/img.jpg"
|
||||
}
|
||||
|
||||
response = client.post("/api/characters/", json=create_payload)
|
||||
assert response.status_code == 200, response.text
|
||||
char_data = response.json()
|
||||
assert char_data["name"] == create_payload["name"]
|
||||
char_id = char_data["id"]
|
||||
|
||||
# 2. Get
|
||||
response = client.get(f"/api/characters/{char_id}")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["id"] == char_id
|
||||
|
||||
# 3. Update
|
||||
update_payload = {"name": "Updated Int Name"}
|
||||
response = client.put(f"/api/characters/{char_id}", json=update_payload)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == "Updated Int Name"
|
||||
|
||||
# 4. Delete
|
||||
response = client.delete(f"/api/characters/{char_id}")
|
||||
assert response.status_code == 204
|
||||
|
||||
# 5. Verify Delete
|
||||
response = client.get(f"/api/characters/{char_id}")
|
||||
assert response.status_code == 404
|
||||
63
tests/test_external_import.py
Executable file
63
tests/test_external_import.py
Executable file
@@ -0,0 +1,63 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for external generation import API.
|
||||
This script demonstrates how to call the import endpoint with proper HMAC signature.
|
||||
"""
|
||||
|
||||
import hmac
|
||||
import hashlib
|
||||
import json
|
||||
import requests
|
||||
import base64
|
||||
import os
|
||||
from config import settings
|
||||
|
||||
# Load env is not needed as settings handles it
|
||||
|
||||
# Configuration
|
||||
API_URL = "http://localhost:8090/api/generations/import"
|
||||
SECRET = settings.EXTERNAL_API_SECRET or "your_super_secret_key_change_this_in_production"
|
||||
|
||||
# Sample generation data
|
||||
generation_data = {
|
||||
"prompt": "A beautiful sunset over mountains",
|
||||
"tech_prompt": "High quality landscape photography",
|
||||
"image_url": "https://picsum.photos/512/512", # Sample image URL
|
||||
# OR use base64:
|
||||
# "image_data": "base64_encoded_image_string_here",
|
||||
"aspect_ratio": "9:16",
|
||||
"quality": "1k",
|
||||
"created_by": "external_user_123",
|
||||
"execution_time_seconds": 5.2,
|
||||
"token_usage": 1000,
|
||||
"input_token_usage": 200,
|
||||
"output_token_usage": 800
|
||||
}
|
||||
|
||||
# Convert to JSON
|
||||
body = json.dumps(generation_data).encode('utf-8')
|
||||
|
||||
# Compute HMAC signature
|
||||
signature = hmac.new(
|
||||
SECRET.encode('utf-8'),
|
||||
body,
|
||||
hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
# Make request
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"X-Signature": signature
|
||||
}
|
||||
|
||||
print(f"Sending request to {API_URL}")
|
||||
print(f"Signature: {signature}")
|
||||
|
||||
try:
|
||||
response = requests.post(API_URL, data=body, headers=headers)
|
||||
print(f"\nStatus Code: {response.status_code}")
|
||||
print(f"Response: {json.dumps(response.json(), indent=2)}")
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
if hasattr(e, 'response'):
|
||||
print(f"Response text: {e.response.text}")
|
||||
96
tests/test_idea.py
Normal file
96
tests/test_idea.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import asyncio
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from bson import ObjectId
|
||||
|
||||
# Import from project root (requires PYTHONPATH=.)
|
||||
from api.service.idea_service import IdeaService
|
||||
from repos.dao import DAO
|
||||
from models.Idea import Idea
|
||||
from models.Generation import Generation, GenerationStatus
|
||||
from models.enums import AspectRatios, Quality
|
||||
from config import settings
|
||||
|
||||
MONGO_HOST = settings.MONGO_HOST
|
||||
DB_NAME = settings.DB_NAME
|
||||
|
||||
print(f"Connecting to MongoDB: {MONGO_HOST}, DB: {DB_NAME}")
|
||||
|
||||
async def test_idea_flow():
|
||||
client = AsyncIOMotorClient(MONGO_HOST)
|
||||
dao = DAO(client, db_name=DB_NAME)
|
||||
service = IdeaService(dao)
|
||||
|
||||
# 1. Create an Idea
|
||||
print("Creating idea...")
|
||||
user_id = "test_user_123"
|
||||
project_id = "test_project_abc"
|
||||
idea = await service.create_idea("My Test Idea", "Initial Description", project_id, user_id)
|
||||
print(f"Idea created: {idea.id} - {idea.name}")
|
||||
|
||||
# 2. Update Idea
|
||||
print("Updating idea...")
|
||||
updated_idea = await service.update_idea(idea.id, description="Updated description")
|
||||
print(f"Idea updated: {updated_idea.description}")
|
||||
if updated_idea.description == "Updated description":
|
||||
print("✅ Idea update successful")
|
||||
else:
|
||||
print("❌ Idea update FAILED")
|
||||
|
||||
# 3. Add Generation linked to Idea
|
||||
print("Creating generation linked to idea...")
|
||||
gen = Generation(
|
||||
prompt="idea generation 1",
|
||||
# idea_id=idea.id, <-- Intentionally NOT linking initially to test linking method
|
||||
project_id=project_id,
|
||||
created_by=user_id,
|
||||
aspect_ratio=AspectRatios.NINESIXTEEN,
|
||||
quality=Quality.ONEK,
|
||||
assets_list=[]
|
||||
)
|
||||
gen_id = await dao.generations.create_generation(gen)
|
||||
print(f"Created generation: {gen_id}")
|
||||
|
||||
# Link generation to idea
|
||||
print("Linking generation to idea...")
|
||||
success = await service.add_generation_to_idea(idea.id, gen_id)
|
||||
if success:
|
||||
print("✅ Linking successful")
|
||||
else:
|
||||
print("❌ Linking FAILED")
|
||||
|
||||
# Debug: Check if generation was saved with idea_id
|
||||
saved_gen = await dao.generations.collection.find_one({"_id": ObjectId(gen_id)})
|
||||
print(f"DEBUG: Saved Generation in DB idea_id: {saved_gen.get('idea_id')}")
|
||||
|
||||
# 4. Fetch Generations for Idea (Verify filtering and ordering)
|
||||
print("Fetching generations for idea...")
|
||||
gens = await service.dao.generations.get_generations(idea_id=idea.id) # using repo directly as service might return wrapper
|
||||
print(f"Found {len(gens)} generations in idea")
|
||||
|
||||
if len(gens) == 1 and gens[0].id == gen_id:
|
||||
print("✅ Generation retrieval successful")
|
||||
else:
|
||||
print("❌ Generation retrieval FAILED")
|
||||
|
||||
# 5. Fetch Ideas for Project
|
||||
ideas = await service.get_ideas(project_id)
|
||||
print(f"Found {len(ideas)} ideas for project")
|
||||
|
||||
# Cleaning up
|
||||
print("Cleaning up...")
|
||||
await service.delete_idea(idea.id)
|
||||
await dao.generations.collection.delete_one({"_id": ObjectId(gen_id)})
|
||||
|
||||
# Verify deletion
|
||||
deleted_idea = await service.get_idea(idea.id)
|
||||
# IdeaRepo.delete_idea logic sets is_deleted=True
|
||||
if deleted_idea and deleted_idea.is_deleted:
|
||||
print(f"✅ Idea deleted successfully")
|
||||
|
||||
# Hard delete for cleanup
|
||||
await dao.ideas.collection.delete_one({"_id": ObjectId(idea.id)})
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_idea_flow())
|
||||
@@ -1,15 +1,14 @@
|
||||
import asyncio
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from config import settings
|
||||
from adapters.s3_adapter import S3Adapter
|
||||
|
||||
async def test_s3():
|
||||
load_dotenv()
|
||||
|
||||
endpoint = os.getenv("MINIO_ENDPOINT", "http://localhost:9000")
|
||||
access_key = os.getenv("MINIO_ACCESS_KEY")
|
||||
secret_key = os.getenv("MINIO_SECRET_KEY")
|
||||
bucket = os.getenv("MINIO_BUCKET")
|
||||
endpoint = settings.MINIO_ENDPOINT
|
||||
access_key = settings.MINIO_ACCESS_KEY
|
||||
secret_key = settings.MINIO_SECRET_KEY
|
||||
bucket = settings.MINIO_BUCKET
|
||||
|
||||
print(f"Connecting to {endpoint}, bucket: {bucket}")
|
||||
|
||||
|
||||
50
tests/test_scheduler.py
Normal file
50
tests/test_scheduler.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import asyncio
|
||||
import os
|
||||
from datetime import datetime, timedelta, UTC
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from models.Generation import Generation, GenerationStatus
|
||||
from repos.generation_repo import GenerationRepo
|
||||
from config import settings
|
||||
|
||||
# Mock configs if not present in env
|
||||
MONGO_HOST = settings.MONGO_HOST
|
||||
DB_NAME = settings.DB_NAME
|
||||
|
||||
print(f"Connecting to MongoDB: {MONGO_HOST}, DB: {DB_NAME}")
|
||||
|
||||
async def test_scheduler():
|
||||
client = AsyncIOMotorClient(MONGO_HOST)
|
||||
repo = GenerationRepo(client, db_name=DB_NAME)
|
||||
|
||||
# 1. Create a "stale" generation (2 hours ago)
|
||||
stale_gen = Generation(
|
||||
prompt="stale test",
|
||||
status=GenerationStatus.RUNNING,
|
||||
created_at=datetime.now(UTC) - timedelta(minutes=120),
|
||||
assets_list=[],
|
||||
aspect_ratio="NINESIXTEEN",
|
||||
quality="ONEK"
|
||||
)
|
||||
gen_id = await repo.create_generation(stale_gen)
|
||||
print(f"Created stale generation: {gen_id}")
|
||||
|
||||
# 2. Run cleanup
|
||||
print("Running cleanup...")
|
||||
count = await repo.cancel_stale_generations(timeout_minutes=60)
|
||||
print(f"Cleaned up {count} generations")
|
||||
|
||||
# 3. Verify status
|
||||
updated_gen = await repo.get_generation(gen_id)
|
||||
print(f"Generation status: {updated_gen.status}")
|
||||
print(f"Failed reason: {updated_gen.failed_reason}")
|
||||
|
||||
if updated_gen.status == GenerationStatus.FAILED:
|
||||
print("✅ SUCCESS: Generation marked as FAILED")
|
||||
else:
|
||||
print("❌ FAILURE: Generation status not updated")
|
||||
|
||||
# Cleanup
|
||||
await repo.collection.delete_one({"_id": updated_gen.id}) # Remove test data
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_scheduler())
|
||||
90
tests/verify_albums_manual.py
Normal file
90
tests/verify_albums_manual.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Add project root to path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from repos.dao import DAO
|
||||
from models.Album import Album
|
||||
from models.Generation import Generation, GenerationStatus
|
||||
from models.enums import AspectRatios, Quality
|
||||
from config import settings
|
||||
|
||||
# Mock config
|
||||
# Use the same host as aiws.py but different DB
|
||||
MONGO_HOST = settings.MONGO_HOST
|
||||
DB_NAME = "bot_db_test_albums"
|
||||
|
||||
async def test_albums():
|
||||
print(f"🚀 Starting Album Manual Verification using {MONGO_HOST}...")
|
||||
|
||||
# Needs to run inside a loop from main
|
||||
client = AsyncIOMotorClient(MONGO_HOST)
|
||||
dao = DAO(client, db_name=DB_NAME)
|
||||
|
||||
try:
|
||||
# 1. Clean up
|
||||
await client[DB_NAME]["albums"].drop()
|
||||
await client[DB_NAME]["generations"].drop()
|
||||
print("✅ Cleaned up test database")
|
||||
|
||||
# 2. Create Album
|
||||
album = Album(name="Test Album", description="A test album")
|
||||
print("Creating album...")
|
||||
album_id = await dao.albums.create_album(album)
|
||||
print(f"✅ Created Album: {album_id}")
|
||||
|
||||
# 3. Create Generations
|
||||
gen1 = Generation(prompt="Gen 1", aspect_ratio=AspectRatios.NINESIXTEEN, quality=Quality.ONEK)
|
||||
gen2 = Generation(prompt="Gen 2", aspect_ratio=AspectRatios.NINESIXTEEN, quality=Quality.ONEK)
|
||||
|
||||
print("Creating generations...")
|
||||
gen1_id = await dao.generations.create_generation(gen1)
|
||||
gen2_id = await dao.generations.create_generation(gen2)
|
||||
print(f"✅ Created Generations: {gen1_id}, {gen2_id}")
|
||||
|
||||
# 4. Add generations to album
|
||||
print("Adding generations to album...")
|
||||
await dao.albums.add_generation(album_id, gen1_id)
|
||||
await dao.albums.add_generation(album_id, gen2_id)
|
||||
print("✅ Added generations to album")
|
||||
|
||||
# 5. Fetch album and check generation_ids
|
||||
album_fetched = await dao.albums.get_album(album_id)
|
||||
assert album_fetched is not None
|
||||
assert len(album_fetched.generation_ids) == 2
|
||||
assert gen1_id in album_fetched.generation_ids
|
||||
assert gen2_id in album_fetched.generation_ids
|
||||
print("✅ Verified generations in album")
|
||||
|
||||
# 6. Fetch generations by IDs via GenerationRepo
|
||||
generations = await dao.generations.get_generations_by_ids([gen1_id, gen2_id])
|
||||
assert len(generations) == 2
|
||||
|
||||
# Ensure ID type match (str vs ObjectId handling in repo)
|
||||
gen_ids_fetched = [g.id for g in generations]
|
||||
assert gen1_id in gen_ids_fetched
|
||||
assert gen2_id in gen_ids_fetched
|
||||
print("✅ Verified fetching generations by IDs")
|
||||
|
||||
# 7. Remove generation
|
||||
print("Removing generation...")
|
||||
await dao.albums.remove_generation(album_id, gen1_id)
|
||||
album_fetched = await dao.albums.get_album(album_id)
|
||||
assert len(album_fetched.generation_ids) == 1
|
||||
assert album_fetched.generation_ids[0] == gen2_id
|
||||
print("✅ Verified removing generation from album")
|
||||
|
||||
print("🎉 Album Verification SUCCESS")
|
||||
|
||||
finally:
|
||||
# Cleanup client
|
||||
client.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(test_albums())
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
@@ -1,29 +1,28 @@
|
||||
import asyncio
|
||||
import os
|
||||
from datetime import datetime
|
||||
from dotenv import load_dotenv
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
|
||||
from config import settings
|
||||
from models.Asset import Asset, AssetType
|
||||
from repos.assets_repo import AssetsRepo
|
||||
from adapters.s3_adapter import S3Adapter
|
||||
|
||||
# Load env to get credentials
|
||||
load_dotenv()
|
||||
# Load env is not needed as settings handles it
|
||||
|
||||
async def test_integration():
|
||||
print("🚀 Starting integration test...")
|
||||
|
||||
# 1. Setup Dependencies
|
||||
mongo_uri = os.getenv("MONGO_HOST", "mongodb://localhost:27017")
|
||||
mongo_uri = settings.MONGO_HOST
|
||||
client = AsyncIOMotorClient(mongo_uri)
|
||||
db_name = os.getenv("DB_NAME", "bot_db_test")
|
||||
db_name = settings.DB_NAME + "_test"
|
||||
|
||||
s3_adapter = S3Adapter(
|
||||
endpoint_url=os.getenv("MINIO_ENDPOINT", "http://localhost:9000"),
|
||||
aws_access_key_id=os.getenv("MINIO_ACCESS_KEY", "admin"),
|
||||
aws_secret_access_key=os.getenv("MINIO_SECRET_KEY", "SuperSecretPassword123!"),
|
||||
bucket_name=os.getenv("MINIO_BUCKET", "ai-char")
|
||||
endpoint_url=settings.MINIO_ENDPOINT,
|
||||
aws_access_key_id=settings.MINIO_ACCESS_KEY,
|
||||
aws_secret_access_key=settings.MINIO_SECRET_KEY,
|
||||
bucket_name=settings.MINIO_BUCKET
|
||||
)
|
||||
|
||||
repo = AssetsRepo(client, s3_adapter, db_name=db_name)
|
||||
|
||||
Binary file not shown.
Binary file not shown.
46
utils/external_auth.py
Normal file
46
utils/external_auth.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import hmac
|
||||
import hashlib
|
||||
import os
|
||||
from fastapi import Header, HTTPException
|
||||
from typing import Optional
|
||||
|
||||
def verify_signature(body: bytes, signature: str, secret: str) -> bool:
|
||||
"""
|
||||
Verify HMAC-SHA256 signature.
|
||||
|
||||
Args:
|
||||
body: Raw request body bytes
|
||||
signature: Signature from X-Signature header
|
||||
secret: Shared secret key
|
||||
|
||||
Returns:
|
||||
True if signature is valid, False otherwise
|
||||
"""
|
||||
expected_signature = hmac.new(
|
||||
secret.encode('utf-8'),
|
||||
body,
|
||||
hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
return hmac.compare_digest(signature, expected_signature)
|
||||
|
||||
|
||||
async def verify_external_signature(
|
||||
x_signature: Optional[str] = Header(None, alias="X-Signature")
|
||||
):
|
||||
"""
|
||||
FastAPI dependency to verify external API signature.
|
||||
|
||||
Raises:
|
||||
HTTPException: If signature is missing or invalid
|
||||
"""
|
||||
if not x_signature:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Missing X-Signature header"
|
||||
)
|
||||
|
||||
# Note: We'll need to access the raw request body in the endpoint
|
||||
# This dependency just validates the header exists
|
||||
# Actual signature verification happens in the endpoint
|
||||
return x_signature
|
||||
@@ -3,12 +3,12 @@ from typing import Optional, Union, Any
|
||||
|
||||
from jose import jwt
|
||||
from passlib.context import CryptContext
|
||||
from config import settings
|
||||
|
||||
# Настройки безопасности (лучше вынести в config/env, но для старта здесь)
|
||||
# SECRET_KEY должен быть сложным и секретным в продакшене!
|
||||
SECRET_KEY = "CHANGE_ME_TO_A_SUPER_SECRET_KEY"
|
||||
ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 30 * 24 * 60 # 30 дней, например
|
||||
# Настройки безопасности берутся из config.py
|
||||
SECRET_KEY = settings.SECRET_KEY
|
||||
ALGORITHM = settings.ALGORITHM
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
|
||||
pwd_context = CryptContext(schemes=["argon2"], deprecated="auto")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user