Compare commits
56 Commits
8a89b27624
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| e011805186 | |||
| d9caececd7 | |||
| c1300b7a2d | |||
| f6001f5994 | |||
| e4a39e90c3 | |||
| e976fe1c58 | |||
| ecc8d69039 | |||
| bc9230a49b | |||
| f07105b0e5 | |||
| 9a5d54a373 | |||
| 1868864f76 | |||
| 9e0c522b5f | |||
| e1d941a2cd | |||
| c7c27197c9 | |||
| 5aa6391dc8 | |||
| ffb0463fe0 | |||
| dd0f8a1cb6 | |||
| 4af5134726 | |||
| 7488665d04 | |||
| ecc88aca62 | |||
| 70f50170fc | |||
| f4207fc4c1 | |||
| c50d2c8ad9 | |||
| 4586daac38 | |||
| 198ac44960 | |||
| d820d9145b | |||
| c93e577bcf | |||
| c5d4849bff | |||
| 9abfbef871 | |||
| 68a3f529cb | |||
| e2c050515d | |||
| 5e7dc19bf3 | |||
| 97483b7030 | |||
| 2d3da59de9 | |||
| 279cb5c6f6 | |||
| 30138bab38 | |||
| 977cab92f8 | |||
| dcab238d3e | |||
| 9d2e4e47de | |||
| c6142715d9 | |||
| 456562ec1d | |||
| 0d0fbdf7d6 | |||
| f63bcedb13 | |||
| be92c766ac | |||
| 482bc1d9b7 | |||
| a2321cf070 | |||
| 29ccd5743e | |||
| d9de2f48d2 | |||
| 1ddeb0af46 | |||
| a7c2319f13 | |||
| 00e83b8561 | |||
| a9d24c725e | |||
| 458b6ebfc3 | |||
| 668aadcdc9 | |||
| 4461964791 | |||
| fa3e1bb05f |
33
.context.md
Normal file
33
.context.md
Normal file
@@ -0,0 +1,33 @@
|
||||
# Project Context: AI Char Bot
|
||||
|
||||
## Overview
|
||||
Python backend project using FastAPI and MongoDB (Motor).
|
||||
Root: `/Users/xds/develop/py projects/ai-char-bot`
|
||||
|
||||
## Architecture
|
||||
- **API Layer**: `api/endpoints` (FastAPI routers).
|
||||
- **Service Layer**: `api/service` (Business logic).
|
||||
- **Data Layer**: `repos` (DAOs and Repositories).
|
||||
- **Models**: `models` (Domain models) and `api/models` (Request/Response DTOs).
|
||||
- **Adapters**: `adapters` (External services like S3, Google Gemini).
|
||||
|
||||
## Coding Standards & Preferences
|
||||
- **Type Hinting**: Use `Type | None` instead of `Optional[Type]` (Python 3.10+ style).
|
||||
- **Async/Await**: Extensive use of `asyncio` and asynchronous DB drivers.
|
||||
- **Error Handling**:
|
||||
- Repositories should return `None` if an entity is not found (e.g., `toggle_like`).
|
||||
- Services/Routers handle `HTTPException`.
|
||||
|
||||
## Key Features & Implementation Details
|
||||
- **Generations**:
|
||||
- Managed by `GenerationService` and `GenerationRepo`.
|
||||
- `toggle_like` returns `bool | None` (True=Liked, False=Unliked, None=Not Found).
|
||||
- `get_generations` requires `current_user_id` to correctly calculate `is_liked`.
|
||||
- **Ideas**:
|
||||
- Managed by `IdeaService` and `IdeaRepo`.
|
||||
- Can have linked generations.
|
||||
- When fetching generations for an idea, ensure `current_user_id` is passed to `GenerationService`.
|
||||
|
||||
## Recent Changes
|
||||
- Refactored `toggle_like` to handle non-existent generations and return `bool | None`.
|
||||
- Updated `IdeaRouter` to pass `current_user_id` when fetching generations to ensure `is_liked` flag is correct.
|
||||
19
.dockerignore
Normal file
19
.dockerignore
Normal file
@@ -0,0 +1,19 @@
|
||||
.git
|
||||
.gitignore
|
||||
__pycache__
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.pyd
|
||||
.Python
|
||||
env/
|
||||
venv/
|
||||
.venv/
|
||||
node_modules/
|
||||
tmp/
|
||||
logs/
|
||||
*.log
|
||||
dist/
|
||||
build/
|
||||
.cache/
|
||||
.idea/
|
||||
.vscode/
|
||||
1
.env
1
.env
@@ -8,3 +8,4 @@ MINIO_ACCESS_KEY=admin
|
||||
MINIO_SECRET_KEY=SuperSecretPassword123!
|
||||
MINIO_BUCKET=ai-char
|
||||
MODE=production
|
||||
EXTERNAL_API_SECRET=Gt9TyQ8OAYhcELh2YCbKjdHLflZGufKHJZcG338MQDW
|
||||
33
.gemini/AGENTS.md
Normal file
33
.gemini/AGENTS.md
Normal file
@@ -0,0 +1,33 @@
|
||||
# Project Context: AI Char Bot
|
||||
|
||||
## Overview
|
||||
Python backend project using FastAPI and MongoDB (Motor).
|
||||
Root: `/Users/xds/develop/py projects/ai-char-bot`
|
||||
|
||||
## Architecture
|
||||
- **API Layer**: `api/endpoints` (FastAPI routers).
|
||||
- **Service Layer**: `api/service` (Business logic).
|
||||
- **Data Layer**: `repos` (DAOs and Repositories).
|
||||
- **Models**: `models` (Domain models) and `api/models` (Request/Response DTOs).
|
||||
- **Adapters**: `adapters` (External services like S3, Google Gemini).
|
||||
|
||||
## Coding Standards & Preferences
|
||||
- **Type Hinting**: Use `Type | None` instead of `Optional[Type]` (Python 3.10+ style).
|
||||
- **Async/Await**: Extensive use of `asyncio` and asynchronous DB drivers.
|
||||
- **Error Handling**:
|
||||
- Repositories should return `None` if an entity is not found (e.g., `toggle_like`).
|
||||
- Services/Routers handle `HTTPException`.
|
||||
|
||||
## Key Features & Implementation Details
|
||||
- **Generations**:
|
||||
- Managed by `GenerationService` and `GenerationRepo`.
|
||||
- `toggle_like` returns `bool | None` (True=Liked, False=Unliked, None=Not Found).
|
||||
- `get_generations` requires `current_user_id` to correctly calculate `is_liked`.
|
||||
- **Ideas**:
|
||||
- Managed by `IdeaService` and `IdeaRepo`.
|
||||
- Can have linked generations.
|
||||
- When fetching generations for an idea, ensure `current_user_id` is passed to `GenerationService`.
|
||||
|
||||
## Recent Changes
|
||||
- Refactored `toggle_like` to handle non-existent generations and return `bool | None`.
|
||||
- Updated `IdeaRouter` to pass `current_user_id` when fetching generations to ensure `is_liked` flag is correct.
|
||||
15
.gitignore
vendored
15
.gitignore
vendored
@@ -9,3 +9,18 @@ minio_backup.tar.gz
|
||||
.idea
|
||||
.venv
|
||||
.vscode
|
||||
.vscode/launch.json
|
||||
middlewares/__pycache__/
|
||||
middlewares/*.pyc
|
||||
api/__pycache__/
|
||||
api/*.pyc
|
||||
repos/__pycache__/
|
||||
repos/*.pyc
|
||||
adapters/__pycache__/
|
||||
adapters/*.pyc
|
||||
services/__pycache__/
|
||||
services/*.pyc
|
||||
utils/__pycache__/
|
||||
utils/*.pyc
|
||||
.vscode/launch.json
|
||||
repos/__pycache__/assets_repo.cpython-313.pyc
|
||||
|
||||
31
.vscode/launch.json
vendored
31
.vscode/launch.json
vendored
@@ -7,38 +7,15 @@
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"args": [
|
||||
"main:app",
|
||||
"aiws:app",
|
||||
"--reload",
|
||||
"--port",
|
||||
"8090"
|
||||
"8090",
|
||||
"--host",
|
||||
"0.0.0.0"
|
||||
],
|
||||
"jinja": true,
|
||||
"justMyCode": true
|
||||
},
|
||||
{
|
||||
"name": "Python: Current File",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "${file}",
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": true,
|
||||
"env": {
|
||||
"PYTHONPATH": "${workspaceFolder}"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Debug Tests: Current File",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "pytest",
|
||||
"args": [
|
||||
"${file}"
|
||||
],
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": true,
|
||||
"env": {
|
||||
"PYTHONPATH": "${workspaceFolder}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -11,4 +11,4 @@ RUN pip install --no-cache-dir -r requirements.txt
|
||||
COPY . .
|
||||
|
||||
# Запуск приложения (замени app.py на свой файл)
|
||||
CMD ["python", "main.py"]
|
||||
CMD ["uvicorn", "aiws:app", "--host", "0.0.0.0", "--port", "8090"]
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -8,7 +8,7 @@ from google import genai
|
||||
from google.genai import types
|
||||
|
||||
from adapters.Exception import GoogleGenerationException
|
||||
from models.enums import AspectRatios, Quality
|
||||
from models.enums import AspectRatios, Quality, TextModel, ImageModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -19,36 +19,37 @@ class GoogleAdapter:
|
||||
raise ValueError("API Key for Gemini is missing")
|
||||
self.client = genai.Client(api_key=api_key)
|
||||
|
||||
# Константы моделей
|
||||
self.TEXT_MODEL = "gemini-3-pro-preview"
|
||||
self.IMAGE_MODEL = "gemini-3-pro-image-preview"
|
||||
|
||||
def _prepare_contents(self, prompt: str, images_list: List[bytes] = None) -> list:
|
||||
"""Вспомогательный метод для подготовки контента (текст + картинки)"""
|
||||
contents = [prompt]
|
||||
def _prepare_contents(self, prompt: str, images_list: List[bytes] | None = None) -> tuple:
|
||||
"""Вспомогательный метод для подготовки контента (текст + картинки).
|
||||
Returns (contents, opened_images) — caller MUST close opened_images after use."""
|
||||
contents : list [Any]= [prompt]
|
||||
opened_images = []
|
||||
if images_list:
|
||||
logger.info(f"Preparing content with {len(images_list)} images")
|
||||
for img_bytes in images_list:
|
||||
try:
|
||||
# Gemini API требует PIL Image на входе
|
||||
image = Image.open(io.BytesIO(img_bytes))
|
||||
contents.append(image)
|
||||
opened_images.append(image)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing input image: {e}")
|
||||
else:
|
||||
logger.info("Preparing content with no images")
|
||||
return contents
|
||||
return contents, opened_images
|
||||
|
||||
def generate_text(self, prompt: str, images_list: List[bytes] = None) -> str:
|
||||
def generate_text(self, prompt: str, model: str = "gemini-3.1-pro-preview", images_list: List[bytes] | None = None) -> str:
|
||||
"""
|
||||
Генерация текста (Чат или Vision).
|
||||
Возвращает строку с ответом.
|
||||
"""
|
||||
contents = self._prepare_contents(prompt, images_list)
|
||||
logger.info(f"Generating text: {prompt}")
|
||||
if model not in [m.value for m in TextModel]:
|
||||
raise ValueError(f"Invalid model for text generation: {model}. Expected one of: {[m.value for m in TextModel]}")
|
||||
|
||||
contents, opened_images = self._prepare_contents(prompt, images_list)
|
||||
logger.info(f"Generating text: {prompt} with model: {model}")
|
||||
try:
|
||||
response = self.client.models.generate_content(
|
||||
model=self.TEXT_MODEL,
|
||||
model=model,
|
||||
contents=contents,
|
||||
config=types.GenerateContentConfig(
|
||||
response_modalities=['TEXT'],
|
||||
@@ -68,22 +69,27 @@ class GoogleAdapter:
|
||||
except Exception as e:
|
||||
logger.error(f"Gemini Text API Error: {e}")
|
||||
raise GoogleGenerationException(f"Gemini Text API Error: {e}")
|
||||
finally:
|
||||
for img in opened_images:
|
||||
img.close()
|
||||
|
||||
def generate_image(self, prompt: str, aspect_ratio: AspectRatios, quality: Quality, images_list: List[bytes] = None, ) -> Tuple[List[io.BytesIO], Dict[str, Any]]:
|
||||
def generate_image(self, prompt: str, aspect_ratio: AspectRatios, quality: Quality, model: str = "gemini-3-pro-image-preview", images_list: List[bytes] | None = None, ) -> Tuple[List[io.BytesIO], Dict[str, Any]]:
|
||||
"""
|
||||
Генерация изображений (Text-to-Image или Image-to-Image).
|
||||
Возвращает список байтовых потоков (готовых к отправке).
|
||||
"""
|
||||
if model not in [m.value for m in ImageModel]:
|
||||
raise ValueError(f"Invalid model for image generation: {model}. Expected one of: {[m.value for m in ImageModel]}")
|
||||
|
||||
contents = self._prepare_contents(prompt, images_list)
|
||||
logger.info(f"Generating image. Prompt length: {len(prompt)}, Ratio: {aspect_ratio}, Quality: {quality}")
|
||||
contents, opened_images = self._prepare_contents(prompt, images_list)
|
||||
logger.info(f"Generating image. Prompt length: {len(prompt)}, Ratio: {aspect_ratio}, Quality: {quality}, Model: {model}")
|
||||
|
||||
start_time = datetime.now()
|
||||
token_usage = 0
|
||||
|
||||
try:
|
||||
response = self.client.models.generate_content(
|
||||
model=self.IMAGE_MODEL,
|
||||
model=model,
|
||||
contents=contents,
|
||||
config=types.GenerateContentConfig(
|
||||
response_modalities=['IMAGE'],
|
||||
@@ -101,8 +107,20 @@ class GoogleAdapter:
|
||||
if response.usage_metadata:
|
||||
token_usage = response.usage_metadata.total_token_count
|
||||
|
||||
if response.parts is None and response.candidates[0].finish_reason is not None:
|
||||
raise GoogleGenerationException(f"Generation blocked in cause of {response.candidates[0].finish_reason.value}")
|
||||
# Check prompt-level block (e.g. PROHIBITED_CONTENT) — no candidates in this case
|
||||
if response.prompt_feedback and response.prompt_feedback.block_reason:
|
||||
raise GoogleGenerationException(
|
||||
f"Generation blocked at prompt level: {response.prompt_feedback.block_reason.value}"
|
||||
)
|
||||
|
||||
# Check candidate-level block
|
||||
if response.parts is None:
|
||||
response_reason = (
|
||||
response.candidates[0].finish_reason
|
||||
if response.candidates and len(response.candidates) > 0
|
||||
else "Unknown"
|
||||
)
|
||||
raise GoogleGenerationException(f"Generation blocked: {response_reason}")
|
||||
|
||||
generated_images = []
|
||||
|
||||
@@ -113,7 +131,9 @@ class GoogleAdapter:
|
||||
try:
|
||||
# 1. Берем сырые байты
|
||||
raw_data = part.inline_data.data
|
||||
byte_arr = io.BytesIO(raw_data)
|
||||
if raw_data is None:
|
||||
raise GoogleGenerationException("Generation returned no data")
|
||||
byte_arr : io.BytesIO = io.BytesIO(raw_data)
|
||||
|
||||
# 2. Нейминг (формально, для TG)
|
||||
timestamp = datetime.now().timestamp()
|
||||
@@ -148,3 +168,7 @@ class GoogleAdapter:
|
||||
except Exception as e:
|
||||
logger.error(f"Gemini Image API Error: {e}")
|
||||
raise GoogleGenerationException(f"Gemini Image API Error: {e}")
|
||||
finally:
|
||||
for img in opened_images:
|
||||
img.close()
|
||||
del contents
|
||||
@@ -1,5 +1,5 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional, BinaryIO
|
||||
from typing import Optional, BinaryIO, AsyncGenerator
|
||||
import aioboto3
|
||||
from botocore.exceptions import ClientError
|
||||
import os
|
||||
@@ -18,7 +18,7 @@ class S3Adapter:
|
||||
|
||||
@asynccontextmanager
|
||||
async def _get_client(self):
|
||||
async with self.session.client(
|
||||
async with self.session.client( # type: ignore[reportGeneralTypeIssues]
|
||||
"s3",
|
||||
endpoint_url=self.endpoint_url,
|
||||
aws_access_key_id=self.aws_access_key_id,
|
||||
@@ -56,6 +56,37 @@ class S3Adapter:
|
||||
print(f"Error downloading from S3: {e}")
|
||||
return None
|
||||
|
||||
async def get_file_size(self, object_name: str) -> Optional[int]:
|
||||
"""Returns the size of the file in bytes."""
|
||||
try:
|
||||
async with self._get_client() as client:
|
||||
response = await client.head_object(Bucket=self.bucket_name, Key=object_name)
|
||||
return response['ContentLength']
|
||||
except ClientError as e:
|
||||
print(f"Error getting file size from S3: {e}")
|
||||
return None
|
||||
|
||||
async def stream_file(self, object_name: str, range_header: Optional[str] = None, chunk_size: int = 65536) -> AsyncGenerator[bytes, None]:
|
||||
"""Streams a file from S3 yielding chunks. Memory-efficient for large files."""
|
||||
try:
|
||||
async with self._get_client() as client:
|
||||
args = {'Bucket': self.bucket_name, 'Key': object_name}
|
||||
if range_header:
|
||||
args['Range'] = range_header
|
||||
|
||||
response = await client.get_object(**args)
|
||||
# aioboto3 Body is an aiohttp StreamReader wrapper
|
||||
body = response['Body']
|
||||
|
||||
while True:
|
||||
chunk = await body.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
yield chunk
|
||||
except ClientError as e:
|
||||
print(f"Error streaming from S3: {e}")
|
||||
return
|
||||
|
||||
async def delete_file(self, object_name: str):
|
||||
"""Deletes a file from S3."""
|
||||
try:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from aiogram import Bot, Dispatcher, Router, F
|
||||
@@ -9,12 +8,14 @@ from aiogram.enums import ParseMode
|
||||
from aiogram.filters import CommandStart, Command
|
||||
from aiogram.types import Message
|
||||
from aiogram.fsm.storage.mongo import MongoStorage
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import FastAPI
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from prometheus_client import Info
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
|
||||
# --- ИМПОРТЫ ПРОЕКТА ---
|
||||
from config import settings
|
||||
from adapters.google_adapter import GoogleAdapter
|
||||
from adapters.s3_adapter import S3Adapter
|
||||
from api.service.generation_service import GenerationService
|
||||
@@ -40,17 +41,22 @@ from api.endpoints.generation_router import router as api_gen_router
|
||||
from api.endpoints.auth import router as api_auth_router
|
||||
from api.endpoints.admin import router as api_admin_router
|
||||
from api.endpoints.album_router import router as api_album_router
|
||||
from api.endpoints.project_router import router as project_api_router
|
||||
from api.endpoints.idea_router import router as idea_api_router
|
||||
from api.endpoints.post_router import router as post_api_router
|
||||
from api.endpoints.environment_router import router as environment_api_router
|
||||
from api.endpoints.inspiration_router import router as inspiration_api_router
|
||||
|
||||
load_dotenv()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# --- КОНФИГУРАЦИЯ ---
|
||||
BOT_TOKEN = os.getenv("BOT_TOKEN")
|
||||
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
|
||||
# Настройки теперь берутся из config.py
|
||||
BOT_TOKEN = settings.BOT_TOKEN
|
||||
GEMINI_API_KEY = settings.GEMINI_API_KEY
|
||||
|
||||
MONGO_HOST = os.getenv("MONGO_HOST") # Например: mongodb://localhost:27017
|
||||
DB_NAME = os.getenv("DB_NAME", "my_bot_db") # Имя базы данных
|
||||
ADMIN_ID = int(os.getenv("ADMIN_ID", 0))
|
||||
MONGO_HOST = settings.MONGO_HOST
|
||||
DB_NAME = settings.DB_NAME
|
||||
ADMIN_ID = settings.ADMIN_ID
|
||||
|
||||
|
||||
def setup_logging():
|
||||
@@ -60,6 +66,8 @@ def setup_logging():
|
||||
|
||||
|
||||
# --- ИНИЦИАЛИЗАЦИЯ ЗАВИСИМОСТЕЙ ---
|
||||
if BOT_TOKEN is None:
|
||||
raise ValueError("BOT_TOKEN is not set")
|
||||
bot = Bot(token=BOT_TOKEN, default=DefaultBotProperties(parse_mode=ParseMode.HTML))
|
||||
|
||||
# Клиент БД создаем глобально, чтобы он был доступен и боту (Storage), и API
|
||||
@@ -72,15 +80,19 @@ char_repo = CharacterRepo(mongo_client)
|
||||
|
||||
# S3 Adapter
|
||||
s3_adapter = S3Adapter(
|
||||
endpoint_url=os.getenv("MINIO_ENDPOINT", "http://31.59.58.220:9000"),
|
||||
aws_access_key_id=os.getenv("MINIO_ACCESS_KEY", "minioadmin"),
|
||||
aws_secret_access_key=os.getenv("MINIO_SECRET_KEY", "minioadmin"),
|
||||
bucket_name=os.getenv("MINIO_BUCKET", "ai-char")
|
||||
endpoint_url=settings.MINIO_ENDPOINT,
|
||||
aws_access_key_id=settings.MINIO_ACCESS_KEY,
|
||||
aws_secret_access_key=settings.MINIO_SECRET_KEY,
|
||||
bucket_name=settings.MINIO_BUCKET
|
||||
)
|
||||
|
||||
dao = DAO(mongo_client, s3_adapter) # Главный DAO для бота
|
||||
if GEMINI_API_KEY is None:
|
||||
raise ValueError("GEMINI_API_KEY is not set")
|
||||
gemini = GoogleAdapter(api_key=GEMINI_API_KEY)
|
||||
generation_service = GenerationService(dao, gemini, bot)
|
||||
if bot is None:
|
||||
raise ValueError("bot is not set")
|
||||
generation_service = GenerationService(dao=dao, gemini=gemini, s3_adapter=s3_adapter, bot=bot)
|
||||
album_service = AlbumService(dao)
|
||||
|
||||
# Dispatcher
|
||||
@@ -117,6 +129,18 @@ assets_router.message.middleware(AuthMiddleware(repo=users_repo, admin_id=ADMIN_
|
||||
gen_router.message.middleware(AlbumMiddleware(latency=0.8))
|
||||
|
||||
|
||||
async def start_scheduler(service: GenerationService):
|
||||
while True:
|
||||
try:
|
||||
logger.info("Running scheduler for stacked generation killing")
|
||||
await service.cleanup_stale_generations()
|
||||
await service.cleanup_old_data(days=14)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Scheduler error: {e}")
|
||||
await asyncio.sleep(60) # Check every 60 seconds
|
||||
|
||||
# --- LIFESPAN (Запуск FastAPI + Bot) ---
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
@@ -143,22 +167,33 @@ async def lifespan(app: FastAPI):
|
||||
# 2. ЗАПУСК БОТА (в фоне)
|
||||
# Важно: handle_signals=False, чтобы бот не перехватывал сигналы остановки у uvicorn
|
||||
# Мы НЕ передаем сюда dao=..., так как он уже подключен через Middleware выше
|
||||
polling_task = asyncio.create_task(
|
||||
dp.start_polling(bot, handle_signals=False)
|
||||
)
|
||||
print("🤖 Bot polling started")
|
||||
# polling_task = asyncio.create_task(
|
||||
# dp.start_polling(bot, handle_signals=False)
|
||||
# )
|
||||
# print("🤖 Bot polling started")
|
||||
|
||||
# 3. ЗАПУСК ШЕДУЛЕРА
|
||||
scheduler_task = asyncio.create_task(start_scheduler(generation_service))
|
||||
print("⏰ Scheduler started")
|
||||
|
||||
yield
|
||||
|
||||
# --- SHUTDOWN ---
|
||||
print("🛑 Shutting down...")
|
||||
|
||||
# 3. Остановка бота
|
||||
polling_task.cancel()
|
||||
# 4. Остановка шедулера
|
||||
scheduler_task.cancel()
|
||||
try:
|
||||
await polling_task
|
||||
await scheduler_task
|
||||
except asyncio.CancelledError:
|
||||
print("🤖 Bot polling stopped")
|
||||
print("⏰ Scheduler stopped")
|
||||
|
||||
# 3. Остановка бота
|
||||
# polling_task.cancel()
|
||||
# try:
|
||||
# await polling_task
|
||||
# except asyncio.CancelledError:
|
||||
# print("🤖 Bot polling stopped")
|
||||
|
||||
# 4. Отключение БД
|
||||
# Обычно Motor закрывать не обязательно при выходе, но хорошим тоном считается
|
||||
@@ -177,17 +212,30 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Подключаем роутер API
|
||||
from api.endpoints.auth import router as auth_api_router
|
||||
from api.endpoints.admin import router as admin_api_router
|
||||
app.include_router(auth_api_router)
|
||||
app.include_router(admin_api_router)
|
||||
# Подключаем роутеры API
|
||||
app.include_router(api_auth_router)
|
||||
app.include_router(api_admin_router)
|
||||
app.include_router(api_assets_router)
|
||||
app.include_router(api_char_router)
|
||||
app.include_router(api_gen_router)
|
||||
app.include_router(api_album_router)
|
||||
app.include_router(api_admin_router)
|
||||
app.include_router(api_auth_router)
|
||||
app.include_router(project_api_router)
|
||||
app.include_router(idea_api_router)
|
||||
app.include_router(post_api_router)
|
||||
app.include_router(environment_api_router)
|
||||
app.include_router(inspiration_api_router)
|
||||
|
||||
# Prometheus Metrics (Instrument after all routers are added)
|
||||
Instrumentator(
|
||||
should_group_status_codes=False, # 200/201/204 отдельно (по желанию)
|
||||
should_ignore_untemplated=False, # НЕ игнорировать "сырые" пути
|
||||
# should_group_untemplated=False, # (опционально) не схлопывать untemplated в "none"
|
||||
).instrument(
|
||||
app,
|
||||
metric_namespace="ai_bot",
|
||||
).expose(app, endpoint="/metrics", include_in_schema=False)
|
||||
app_info = Info("fastapi_app_info", "FastAPI application info")
|
||||
app_info.info({"app_name": "ai-bot"})
|
||||
|
||||
|
||||
# --- ХЕНДЛЕРЫ БОТА (Main Router) ---
|
||||
@@ -214,7 +262,7 @@ if __name__ == "__main__":
|
||||
async def main():
|
||||
# Создаем конфигурацию uvicorn вручную
|
||||
# loop="asyncio" заставляет использовать стандартный цикл
|
||||
config = uvicorn.Config(app, host="0.0.0.0", port=8090, loop="asyncio", timeout_keep_alive=120, env_file=".env.development")
|
||||
config = uvicorn.Config(app, host="0.0.0.0", port=8090, loop="asyncio", timeout_keep_alive=120)
|
||||
server = uvicorn.Server(config)
|
||||
|
||||
# Запускаем сервер (lifespan запустится внутри)
|
||||
Binary file not shown.
@@ -5,6 +5,7 @@ from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from adapters.google_adapter import GoogleAdapter
|
||||
from api.service.generation_service import GenerationService
|
||||
from repos.dao import DAO
|
||||
from api.service.album_service import AlbumService
|
||||
|
||||
|
||||
# ... ваши импорты ...
|
||||
@@ -44,3 +45,26 @@ def get_generation_service(
|
||||
bot: Bot = Depends(get_bot_client),
|
||||
) -> GenerationService:
|
||||
return GenerationService(dao, gemini, s3_adapter, bot)
|
||||
|
||||
from api.service.idea_service import IdeaService
|
||||
|
||||
def get_idea_service(dao: DAO = Depends(get_dao)) -> IdeaService:
|
||||
return IdeaService(dao)
|
||||
|
||||
from fastapi import Header
|
||||
|
||||
async def get_project_id(x_project_id: Optional[str] = Header(None, alias="X-Project-ID")) -> Optional[str]:
|
||||
return x_project_id
|
||||
|
||||
async def get_album_service(dao: DAO = Depends(get_dao)) -> AlbumService:
|
||||
return AlbumService(dao)
|
||||
|
||||
from api.service.post_service import PostService
|
||||
|
||||
def get_post_service(dao: DAO = Depends(get_dao)) -> PostService:
|
||||
return PostService(dao)
|
||||
|
||||
from api.service.inspiration_service import InspirationService
|
||||
|
||||
def get_inspiration_service(dao: DAO = Depends(get_dao), s3_adapter: S3Adapter = Depends(get_s3_adapter)) -> InspirationService:
|
||||
return InspirationService(dao, s3_adapter)
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,10 +1,12 @@
|
||||
from typing import Annotated, List
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from pydantic import BaseModel
|
||||
|
||||
from repos.user_repo import UsersRepo, UserStatus
|
||||
from api.dependency import get_dao
|
||||
from repos.dao import DAO
|
||||
from utils.security import verify_password, create_access_token, ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, SECRET_KEY
|
||||
from jose import JWTError, jwt
|
||||
from starlette.requests import Request
|
||||
@@ -23,7 +25,7 @@ async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)], repo:
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str = payload.get("sub")
|
||||
username: str | None = payload.get("sub")
|
||||
if username is None:
|
||||
raise credentials_exception
|
||||
except JWTError:
|
||||
@@ -52,7 +54,7 @@ class UserResponse(BaseModel):
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
@router.get("/approvals", response_model=List[UserResponse])
|
||||
@router.get("/approvals", response_model=list[UserResponse])
|
||||
async def list_pending_users(
|
||||
admin: Annotated[dict, Depends(get_current_admin)],
|
||||
repo: Annotated[UsersRepo, Depends(get_users_repo)]
|
||||
|
||||
@@ -1,35 +1,37 @@
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from fastapi import APIRouter, HTTPException, status, Request
|
||||
from pydantic import BaseModel
|
||||
|
||||
from api.models.GenerationRequest import GenerationResponse
|
||||
from models.Album import Album
|
||||
from repos.dao import DAO
|
||||
from api.dependency import get_album_service
|
||||
from api.service.album_service import AlbumService
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/albums", tags=["Albums"])
|
||||
|
||||
class AlbumCreateRequest(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
description: str | None = None
|
||||
|
||||
class AlbumUpdateRequest(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
|
||||
class AlbumResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
generation_ids: List[str] = []
|
||||
cover_asset_id: Optional[str] = None # Not implemented yet
|
||||
description: str | None = None
|
||||
generation_ids: list[str] = []
|
||||
cover_asset_id: str | None = None # Not implemented yet
|
||||
|
||||
@router.post("/", response_model=AlbumResponse)
|
||||
@router.post("", response_model=AlbumResponse)
|
||||
async def create_album(request: Request, album_in: AlbumCreateRequest):
|
||||
service: AlbumService = request.app.state.album_service
|
||||
album = await service.create_album(name=album_in.name, description=album_in.description)
|
||||
return AlbumResponse(**album.model_dump())
|
||||
|
||||
@router.get("/", response_model=List[AlbumResponse])
|
||||
@router.get("", response_model=list[AlbumResponse])
|
||||
async def get_albums(request: Request, limit: int = 10, offset: int = 0):
|
||||
service: AlbumService = request.app.state.album_service
|
||||
albums = await service.get_albums(limit=limit, offset=offset)
|
||||
@@ -74,7 +76,7 @@ async def remove_generation_from_album(request: Request, album_id: str, generati
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album or Generation not found")
|
||||
return {"status": "success"}
|
||||
|
||||
@router.get("/{album_id}/generations", response_model=List[GenerationResponse])
|
||||
@router.get("/{album_id}/generations", response_model=list[GenerationResponse])
|
||||
async def get_album_generations(request: Request, album_id: str, limit: int = 10, offset: int = 0):
|
||||
service: AlbumService = request.app.state.album_service
|
||||
generations = await service.get_generations_by_album(album_id, limit=limit, offset=offset)
|
||||
|
||||
@@ -1,17 +1,21 @@
|
||||
from typing import List, Optional
|
||||
from typing import Any
|
||||
|
||||
from aiogram.types import BufferedInputFile
|
||||
from bson import ObjectId
|
||||
from fastapi import APIRouter, UploadFile, File, Form, Depends
|
||||
from fastapi.openapi.models import MediaType
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from pymongo import MongoClient
|
||||
from starlette import status
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response, JSONResponse
|
||||
from starlette.responses import Response, JSONResponse, StreamingResponse
|
||||
|
||||
from api.models.AssetDTO import AssetsResponse, AssetResponse
|
||||
from adapters.s3_adapter import S3Adapter
|
||||
from api.models import AssetsResponse, AssetResponse
|
||||
from models.Asset import Asset, AssetType, AssetContentType
|
||||
from repos.dao import DAO
|
||||
from api.dependency import get_dao
|
||||
from api.dependency import get_dao, get_mongo_client, get_s3_adapter
|
||||
import asyncio
|
||||
|
||||
import logging
|
||||
@@ -19,6 +23,7 @@ import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from api.endpoints.auth import get_current_user
|
||||
from api.dependency import get_project_id
|
||||
|
||||
router = APIRouter(prefix="/api/assets", tags=["Assets"])
|
||||
|
||||
@@ -28,28 +33,214 @@ async def get_asset(
|
||||
asset_id: str,
|
||||
request: Request,
|
||||
thumbnail: bool = False,
|
||||
dao: DAO = Depends(get_dao)
|
||||
dao: DAO = Depends(get_dao),
|
||||
s3_adapter: S3Adapter = Depends(get_s3_adapter),
|
||||
) -> Response:
|
||||
logger.debug(f"get_asset called for ID: {asset_id}, thumbnail={thumbnail}")
|
||||
asset = await dao.assets.get_asset(asset_id)
|
||||
# 2. Проверка на существование
|
||||
# Загружаем только метаданные (без data/thumbnail bytes)
|
||||
asset = await dao.assets.get_asset(asset_id, with_data=False)
|
||||
if not asset:
|
||||
raise HTTPException(status_code=404, detail="Asset not found")
|
||||
|
||||
headers = {
|
||||
# Кэшировать на 1 год (31536000 сек)
|
||||
"Cache-Control": "public, max-age=31536000, immutable"
|
||||
base_headers = {
|
||||
"Cache-Control": "public, max-age=31536000, immutable",
|
||||
"Accept-Ranges": "bytes"
|
||||
}
|
||||
|
||||
content = asset.data
|
||||
media_type = "image/png" # Default, or detect
|
||||
# Thumbnail: маленький, можно грузить в RAM
|
||||
if thumbnail:
|
||||
if asset.minio_thumbnail_object_name and s3_adapter:
|
||||
thumb_bytes = await s3_adapter.get_file(asset.minio_thumbnail_object_name)
|
||||
if thumb_bytes:
|
||||
return Response(content=thumb_bytes, media_type="image/jpeg", headers=base_headers)
|
||||
# Fallback: thumbnail in DB
|
||||
if asset.thumbnail:
|
||||
return Response(content=asset.thumbnail, media_type="image/jpeg", headers=base_headers)
|
||||
# No thumbnail available — fall through to main content
|
||||
|
||||
if thumbnail and asset.thumbnail:
|
||||
content = asset.thumbnail
|
||||
media_type = "image/jpeg"
|
||||
# Main content: стримим из S3 без загрузки в RAM
|
||||
if asset.minio_object_name and s3_adapter:
|
||||
content_type = "image/png"
|
||||
if asset.content_type == AssetContentType.VIDEO:
|
||||
content_type = "video/mp4" # Or detect from extension if stored
|
||||
elif asset.content_type == AssetContentType.IMAGE:
|
||||
content_type = "image/png" # Default for images
|
||||
|
||||
return Response(content=content, media_type=media_type, headers=headers)
|
||||
# Better content type detection based on extension if possible, but for now this is okay
|
||||
if asset.minio_object_name.endswith(".mp4"):
|
||||
content_type = "video/mp4"
|
||||
elif asset.minio_object_name.endswith(".mov"):
|
||||
content_type = "video/quicktime"
|
||||
elif asset.minio_object_name.endswith(".jpg") or asset.minio_object_name.endswith(".jpeg"):
|
||||
content_type = "image/jpeg"
|
||||
|
||||
# Handle Range requests for video streaming
|
||||
range_header = request.headers.get("range")
|
||||
file_size = await s3_adapter.get_file_size(asset.minio_object_name)
|
||||
|
||||
if range_header and file_size:
|
||||
try:
|
||||
# Parse Range header: bytes=start-end
|
||||
byte_range = range_header.replace("bytes=", "")
|
||||
start_str, end_str = byte_range.split("-")
|
||||
start = int(start_str)
|
||||
end = int(end_str) if end_str else file_size - 1
|
||||
|
||||
# Validate range
|
||||
if start >= file_size:
|
||||
# 416 Range Not Satisfiable
|
||||
return Response(status_code=416, headers={"Content-Range": f"bytes */{file_size}"})
|
||||
|
||||
chunk_size = end - start + 1
|
||||
|
||||
headers = base_headers.copy()
|
||||
headers.update({
|
||||
"Content-Range": f"bytes {start}-{end}/{file_size}",
|
||||
"Content-Length": str(chunk_size),
|
||||
})
|
||||
|
||||
# Pass the exact range string to S3
|
||||
s3_range = f"bytes={start}-{end}"
|
||||
|
||||
return StreamingResponse(
|
||||
s3_adapter.stream_file(asset.minio_object_name, range_header=s3_range),
|
||||
status_code=206,
|
||||
headers=headers,
|
||||
media_type=content_type
|
||||
)
|
||||
except ValueError:
|
||||
pass # Fallback to full content if range parsing fails
|
||||
|
||||
# Full content response
|
||||
headers = base_headers.copy()
|
||||
if file_size:
|
||||
headers["Content-Length"] = str(file_size)
|
||||
|
||||
return StreamingResponse(
|
||||
s3_adapter.stream_file(asset.minio_object_name),
|
||||
media_type=content_type,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
# Fallback: data stored in DB (legacy)
|
||||
if asset.data:
|
||||
return Response(content=asset.data, media_type="image/png", headers=base_headers)
|
||||
|
||||
raise HTTPException(status_code=404, detail="Asset data not found")
|
||||
|
||||
@router.delete("/orphans", dependencies=[Depends(get_current_user)])
|
||||
async def delete_orphan_assets_from_minio(
|
||||
mongo: AsyncIOMotorClient = Depends(get_mongo_client),
|
||||
minio_client: S3Adapter = Depends(get_s3_adapter),
|
||||
*,
|
||||
assets_collection: str = "assets",
|
||||
generations_collection: str = "generations",
|
||||
asset_type: str | None = "generated",
|
||||
project_id: str | None = None,
|
||||
dry_run: bool = True,
|
||||
mark_assets_deleted: bool = False,
|
||||
batch_size: int = 500,
|
||||
) -> dict[str, Any]:
|
||||
db = mongo['bot_db'] # БД уже выбрана в get_mongo_client
|
||||
assets = db[assets_collection]
|
||||
|
||||
match_assets: dict[str, Any] = {}
|
||||
if asset_type is not None:
|
||||
match_assets["type"] = asset_type
|
||||
if project_id is not None:
|
||||
match_assets["project_id"] = project_id
|
||||
|
||||
pipeline: list[dict[str, Any]] = [
|
||||
{"$match": match_assets} if match_assets else {"$match": {}},
|
||||
{
|
||||
"$lookup": {
|
||||
"from": generations_collection,
|
||||
"let": {"assetIdStr": {"$toString": "$_id"}},
|
||||
"pipeline": [
|
||||
# считаем "живыми" те, где is_deleted != True (т.е. false или поля нет)
|
||||
{"$match": {"is_deleted": {"$ne": True}}},
|
||||
{
|
||||
"$match": {
|
||||
"$expr": {
|
||||
"$in": [
|
||||
"$$assetIdStr",
|
||||
{"$ifNull": ["$result_list", []]},
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
{"$limit": 1},
|
||||
],
|
||||
"as": "alive_generations",
|
||||
}
|
||||
},
|
||||
{
|
||||
"$match": {
|
||||
"$expr": {"$eq": [{"$size": "$alive_generations"}, 0]}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$project": {
|
||||
"_id": 1,
|
||||
"minio_object_name": 1,
|
||||
"minio_thumbnail_object_name": 1,
|
||||
}
|
||||
},
|
||||
]
|
||||
print(pipeline)
|
||||
cursor = assets.aggregate(pipeline, allowDiskUse=True, batchSize=batch_size)
|
||||
|
||||
deleted_objects = 0
|
||||
deleted_assets = 0
|
||||
errors: list[dict[str, Any]] = []
|
||||
orphan_asset_ids: list[ObjectId] = []
|
||||
|
||||
async for asset in cursor:
|
||||
aid = asset["_id"]
|
||||
obj = asset.get("minio_object_name")
|
||||
thumb = asset.get("minio_thumbnail_object_name")
|
||||
|
||||
orphan_asset_ids.append(aid)
|
||||
|
||||
if dry_run:
|
||||
print(f"[DRY RUN] orphan asset={aid} obj={obj} thumb={thumb}")
|
||||
continue
|
||||
|
||||
try:
|
||||
if obj:
|
||||
await minio_client.delete_file(obj)
|
||||
deleted_objects += 1
|
||||
|
||||
if thumb:
|
||||
await minio_client.delete_file(thumb)
|
||||
deleted_objects += 1
|
||||
|
||||
deleted_assets += 1
|
||||
|
||||
except Exception as e:
|
||||
errors.append({"asset_id": str(aid), "error": str(e)})
|
||||
|
||||
if (not dry_run) and mark_assets_deleted and orphan_asset_ids:
|
||||
res = await assets.update_many(
|
||||
{"_id": {"$in": orphan_asset_ids}},
|
||||
{"$set": {"is_deleted": True}},
|
||||
)
|
||||
marked = res.modified_count
|
||||
else:
|
||||
marked = 0
|
||||
|
||||
return {
|
||||
"dry_run": dry_run,
|
||||
"filter": {
|
||||
"asset_type": asset_type,
|
||||
"project_id": project_id,
|
||||
},
|
||||
"orphans_found": len(orphan_asset_ids),
|
||||
"deleted_assets": deleted_assets,
|
||||
"deleted_objects": deleted_objects,
|
||||
"marked_assets_deleted": marked,
|
||||
"errors": errors,
|
||||
}
|
||||
|
||||
@router.delete("/{asset_id}", status_code=status.HTTP_204_NO_CONTENT, dependencies=[Depends(get_current_user)])
|
||||
async def delete_asset(
|
||||
@@ -68,11 +259,19 @@ async def delete_asset(
|
||||
|
||||
|
||||
@router.get("", dependencies=[Depends(get_current_user)])
|
||||
async def get_assets(request: Request, dao: DAO = Depends(get_dao), type: Optional[str] = None, limit: int = 10, offset: int = 0) -> AssetsResponse:
|
||||
async def get_assets(request: Request, dao: DAO = Depends(get_dao), type: str | None = None, limit: int = 10, offset: int = 0, current_user: dict = Depends(get_current_user), project_id: str | None = Depends(get_project_id)) -> AssetsResponse:
|
||||
logger.info(f"get_assets called. Limit: {limit}, Offset: {offset}")
|
||||
assets = await dao.assets.get_assets(type, limit, offset)
|
||||
|
||||
user_id_filter = current_user["id"]
|
||||
if project_id:
|
||||
project = await dao.projects.get_project(project_id)
|
||||
if not project or str(current_user["id"]) not in project.members:
|
||||
raise HTTPException(status_code=403, detail="Project access denied")
|
||||
user_id_filter = None
|
||||
|
||||
assets = await dao.assets.get_assets(type, limit, offset, created_by=user_id_filter, project_id=project_id)
|
||||
# assets = await dao.assets.get_assets() # This line seemed redundant/conflicting in original code
|
||||
total_count = await dao.assets.get_asset_count()
|
||||
total_count = await dao.assets.get_asset_count(created_by=user_id_filter, project_id=project_id)
|
||||
|
||||
# Manually map to DTO to trigger computed fields validation if necessary,
|
||||
# but primarily to ensure valid Pydantic models for the response list.
|
||||
@@ -84,11 +283,13 @@ async def get_assets(request: Request, dao: DAO = Depends(get_dao), type: Option
|
||||
|
||||
|
||||
|
||||
@router.post("/upload", response_model=AssetResponse, status_code=status.HTTP_201_CREATED, dependencies=[Depends(get_current_user)])
|
||||
@router.post("/upload", response_model=AssetResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def upload_asset(
|
||||
file: UploadFile = File(...),
|
||||
linked_char_id: Optional[str] = Form(None),
|
||||
linked_char_id: str | None = Form(None),
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
project_id: str | None = Depends(get_project_id)
|
||||
):
|
||||
logger.info(f"upload_asset called. Filename: {file.filename}, ContentType: {file.content_type}, LinkedCharId: {linked_char_id}")
|
||||
if not file.content_type:
|
||||
@@ -97,6 +298,11 @@ async def upload_asset(
|
||||
if not file.content_type.startswith("image/"):
|
||||
raise HTTPException(status_code=400, detail=f"Unsupported content type: {file.content_type}")
|
||||
|
||||
if project_id:
|
||||
project = await dao.projects.get_project(project_id)
|
||||
if not project or str(current_user["_id"]) not in project.members:
|
||||
raise HTTPException(status_code=403, detail="Project access denied")
|
||||
|
||||
data = await file.read()
|
||||
if not data:
|
||||
raise HTTPException(status_code=400, detail="Empty file")
|
||||
@@ -111,7 +317,9 @@ async def upload_asset(
|
||||
content_type=AssetContentType.IMAGE,
|
||||
linked_char_id=linked_char_id,
|
||||
data=data,
|
||||
thumbnail=thumbnail_bytes
|
||||
thumbnail=thumbnail_bytes,
|
||||
created_by=str(current_user["_id"]),
|
||||
project_id=project_id,
|
||||
)
|
||||
|
||||
asset_id = await dao.assets.create_asset(asset)
|
||||
@@ -124,8 +332,7 @@ async def upload_asset(
|
||||
type=asset.type.value if hasattr(asset.type, "value") else asset.type,
|
||||
content_type=asset.content_type.value if hasattr(asset.content_type, "value") else asset.content_type,
|
||||
linked_char_id=asset.linked_char_id,
|
||||
created_at=asset.created_at,
|
||||
url=asset.url
|
||||
created_at=asset.created_at
|
||||
)
|
||||
|
||||
|
||||
@@ -172,3 +379,4 @@ async def migrate_to_minio(dao: DAO = Depends(get_dao)):
|
||||
result = await dao.assets.migrate_to_minio()
|
||||
logger.info(f"Migration result: {result}")
|
||||
return result
|
||||
|
||||
|
||||
@@ -59,6 +59,7 @@ class Token(BaseModel):
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
id: str
|
||||
username: str
|
||||
full_name: str | None = None
|
||||
status: str
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
from typing import List, Any, Coroutine
|
||||
from typing import Any, Coroutine
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
|
||||
from api.models.AssetDTO import AssetsResponse, AssetResponse
|
||||
from api.models.GenerationRequest import GenerationRequest, GenerationResponse
|
||||
from api.models import AssetsResponse, AssetResponse
|
||||
from api.models import GenerationRequest, GenerationResponse
|
||||
from models.Asset import Asset
|
||||
from models.Character import Character
|
||||
from api.models import CharacterCreateRequest, CharacterUpdateRequest
|
||||
from repos.dao import DAO
|
||||
from api.dependency import get_dao
|
||||
|
||||
@@ -17,25 +18,61 @@ import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from api.endpoints.auth import get_current_user
|
||||
from api.dependency import get_project_id
|
||||
|
||||
router = APIRouter(prefix="/api/characters", tags=["Characters"], dependencies=[Depends(get_current_user)])
|
||||
|
||||
|
||||
@router.get("/", response_model=List[Character])
|
||||
async def get_characters(request: Request, dao: DAO = Depends(get_dao), ) -> List[Character]:
|
||||
logger.info("get_characters called")
|
||||
characters = await dao.chars.get_all_characters()
|
||||
@router.get("/", response_model=list[Character])
|
||||
async def get_characters(
|
||||
request: Request,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
project_id: str | None = Depends(get_project_id),
|
||||
limit: int = 100,
|
||||
offset: int = 0
|
||||
) -> list[Character]:
|
||||
logger.info(f"get_characters called. Limit: {limit}, Offset: {offset}")
|
||||
|
||||
user_id_filter = str(current_user["_id"])
|
||||
if project_id:
|
||||
project = await dao.projects.get_project(project_id)
|
||||
if not project or str(current_user["_id"]) not in project.members:
|
||||
raise HTTPException(status_code=403, detail="Project access denied")
|
||||
user_id_filter = None
|
||||
|
||||
characters = await dao.chars.get_all_characters(
|
||||
created_by=user_id_filter,
|
||||
project_id=project_id,
|
||||
limit=limit,
|
||||
offset=offset
|
||||
)
|
||||
return characters
|
||||
|
||||
|
||||
@router.get("/{character_id}/assets", response_model=AssetsResponse)
|
||||
async def get_character_assets(character_id: str, dao: DAO = Depends(get_dao), limit: int = 10,
|
||||
offset: int = 0, ) -> AssetsResponse:
|
||||
offset: int = 0, current_user: dict = Depends(get_current_user)) -> AssetsResponse:
|
||||
logger.info(f"get_character_assets called. CharacterID: {character_id}, Limit: {limit}, Offset: {offset}")
|
||||
character = await dao.chars.get_character(character_id)
|
||||
if character is None:
|
||||
raise HTTPException(status_code=404, detail="Character not found")
|
||||
|
||||
# Access Check
|
||||
is_creator = character.created_by == str(current_user["_id"])
|
||||
is_project_member = False
|
||||
if character.project_id and character.project_id in current_user.get("project_ids", []):
|
||||
is_project_member = True
|
||||
|
||||
if not is_creator and not is_project_member:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
assets = await dao.assets.get_assets_by_char_id(character_id, limit, offset)
|
||||
# Filter assets by user ownership as well?
|
||||
# Usually if you own character, you see its assets.
|
||||
# But assets also have specific created_by.
|
||||
# Let's assume if you own character you can see its assets.
|
||||
|
||||
total_count = await dao.assets.get_asset_count(character_id)
|
||||
|
||||
asset_responses = [AssetResponse.model_validate(a.model_dump()) for a in assets]
|
||||
@@ -43,14 +80,113 @@ async def get_character_assets(character_id: str, dao: DAO = Depends(get_dao), l
|
||||
|
||||
|
||||
@router.get("/{character_id}", response_model=Character)
|
||||
async def get_character_by_id(character_id: str, request: Request, dao: DAO = Depends(get_dao)) -> Character:
|
||||
async def get_character_by_id(character_id: str, request: Request, dao: DAO = Depends(get_dao), current_user: dict = Depends(get_current_user)) -> Character:
|
||||
logger.debug(f"get_character_by_id called. ID: {character_id}")
|
||||
character = await dao.chars.get_character(character_id)
|
||||
|
||||
if not character:
|
||||
raise HTTPException(status_code=404, detail="Character not found")
|
||||
|
||||
if character:
|
||||
is_creator = character.created_by == str(current_user["_id"])
|
||||
is_project_member = False
|
||||
if character.project_id and character.project_id in current_user.get("project_ids", []):
|
||||
is_project_member = True
|
||||
|
||||
if not is_creator and not is_project_member:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
return character
|
||||
|
||||
|
||||
@router.post("/{character_id}/_run", response_model=GenerationResponse)
|
||||
async def post_character_generation(character_id: str, generation: GenerationRequest,
|
||||
request: Request) -> GenerationResponse:
|
||||
logger.info(f"post_character_generation called. CharacterID: {character_id}")
|
||||
generation_service = request.app.state.generation_service
|
||||
@router.post("/", response_model=Character)
|
||||
async def create_character(
|
||||
char_req: CharacterCreateRequest,
|
||||
project_id: str | None = Depends(get_project_id),
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
) -> Character:
|
||||
logger.info("create_character called")
|
||||
char_req.project_id = project_id
|
||||
char_data = char_req.model_dump()
|
||||
char_data["created_by"] = str(current_user["_id"])
|
||||
if "id" not in char_data:
|
||||
char_data["id"] = None
|
||||
|
||||
if project_id:
|
||||
project = await dao.projects.get_project(project_id)
|
||||
if not project or str(current_user["_id"]) not in project.members:
|
||||
raise HTTPException(status_code=403, detail="Project access denied")
|
||||
|
||||
new_char = Character(**char_data)
|
||||
new_char.avatar_asset_id = new_char.avatar_image.split("/")[-1]
|
||||
created_char = await dao.chars.add_character(new_char)
|
||||
return created_char
|
||||
|
||||
|
||||
@router.put("/{character_id}", response_model=Character)
|
||||
async def update_character(
|
||||
character_id: str,
|
||||
char_update: CharacterUpdateRequest,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
) -> Character:
|
||||
logger.info(f"update_character called. ID: {character_id}")
|
||||
|
||||
existing_char = await dao.chars.get_character(character_id)
|
||||
if not existing_char:
|
||||
raise HTTPException(status_code=404, detail="Character not found")
|
||||
|
||||
is_creator = existing_char.created_by == str(current_user["_id"])
|
||||
is_project_member = False
|
||||
if existing_char.project_id and existing_char.project_id in current_user.get("project_ids", []):
|
||||
is_project_member = True
|
||||
|
||||
if not is_creator and not is_project_member:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
update_data = char_update.model_dump(exclude_unset=True)
|
||||
|
||||
if "project_id" in update_data and update_data["project_id"]:
|
||||
new_project_id = update_data["project_id"]
|
||||
project = await dao.projects.get_project(new_project_id)
|
||||
if not project or str(current_user["_id"]) not in project.members:
|
||||
raise HTTPException(status_code=403, detail="Target project access denied")
|
||||
|
||||
updated_char_data = existing_char.model_dump()
|
||||
updated_char_data.update(update_data)
|
||||
|
||||
updated_char = Character(**updated_char_data)
|
||||
|
||||
success = await dao.chars.update_char(character_id, updated_char)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to update character")
|
||||
|
||||
return updated_char
|
||||
|
||||
|
||||
@router.delete("/{character_id}", status_code=204)
|
||||
async def delete_character(
|
||||
character_id: str,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
logger.info(f"delete_character called. ID: {character_id}")
|
||||
|
||||
existing_char = await dao.chars.get_character(character_id)
|
||||
if not existing_char:
|
||||
raise HTTPException(status_code=404, detail="Character not found")
|
||||
|
||||
is_creator = existing_char.created_by == str(current_user["_id"])
|
||||
is_project_member = False
|
||||
if existing_char.project_id and existing_char.project_id in current_user.get("project_ids", []):
|
||||
is_project_member = True
|
||||
|
||||
if not is_creator and not is_project_member:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
success = await dao.chars.delete_character(character_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete character")
|
||||
|
||||
return
|
||||
|
||||
191
api/endpoints/environment_router.py
Normal file
191
api/endpoints/environment_router.py
Normal file
@@ -0,0 +1,191 @@
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from starlette import status
|
||||
|
||||
from api.dependency import get_dao
|
||||
from api.endpoints.auth import get_current_user
|
||||
from api.models.EnvironmentRequest import EnvironmentCreate, EnvironmentUpdate, AssetToEnvironment, AssetsToEnvironment
|
||||
from models.Environment import Environment
|
||||
from repos.dao import DAO
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/environments", tags=["Environments"], dependencies=[Depends(get_current_user)])
|
||||
|
||||
|
||||
async def check_character_access(character_id: str, current_user: dict, dao: DAO):
|
||||
character = await dao.chars.get_character(character_id)
|
||||
if not character:
|
||||
raise HTTPException(status_code=404, detail="Character not found")
|
||||
|
||||
is_creator = character.created_by == str(current_user["_id"])
|
||||
is_project_member = False
|
||||
if character.project_id and character.project_id in current_user.get("project_ids", []):
|
||||
is_project_member = True
|
||||
|
||||
if not is_creator and not is_project_member:
|
||||
raise HTTPException(status_code=403, detail="Access denied to character")
|
||||
return character
|
||||
|
||||
|
||||
@router.post("/", response_model=Environment)
|
||||
async def create_environment(
|
||||
env_req: EnvironmentCreate,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
logger.info(f"Creating environment '{env_req.name}' for character {env_req.character_id}")
|
||||
await check_character_access(env_req.character_id, current_user, dao)
|
||||
|
||||
# Verify assets exist if provided
|
||||
if env_req.asset_ids:
|
||||
for aid in env_req.asset_ids:
|
||||
asset = await dao.assets.get_asset(aid)
|
||||
if not asset:
|
||||
raise HTTPException(status_code=400, detail=f"Asset {aid} not found")
|
||||
|
||||
new_env = Environment(**env_req.model_dump())
|
||||
created_env = await dao.environments.create_env(new_env)
|
||||
return created_env
|
||||
|
||||
|
||||
@router.get("/character/{character_id}", response_model=list[Environment])
|
||||
async def get_character_environments(
|
||||
character_id: str,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
logger.info(f"Getting environments for character {character_id}")
|
||||
await check_character_access(character_id, current_user, dao)
|
||||
return await dao.environments.get_character_envs(character_id)
|
||||
|
||||
|
||||
@router.get("/{env_id}", response_model=Environment)
|
||||
async def get_environment(
|
||||
env_id: str,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
env = await dao.environments.get_env(env_id)
|
||||
if not env:
|
||||
raise HTTPException(status_code=404, detail="Environment not found")
|
||||
|
||||
await check_character_access(env.character_id, current_user, dao)
|
||||
return env
|
||||
|
||||
|
||||
@router.put("/{env_id}", response_model=Environment)
|
||||
async def update_environment(
|
||||
env_id: str,
|
||||
env_update: EnvironmentUpdate,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
env = await dao.environments.get_env(env_id)
|
||||
if not env:
|
||||
raise HTTPException(status_code=404, detail="Environment not found")
|
||||
|
||||
await check_character_access(env.character_id, current_user, dao)
|
||||
|
||||
update_data = env_update.model_dump(exclude_unset=True)
|
||||
if not update_data:
|
||||
return env
|
||||
|
||||
# Verify assets exist if provided
|
||||
if "asset_ids" in update_data:
|
||||
if update_data["asset_ids"] is None:
|
||||
del update_data["asset_ids"]
|
||||
elif update_data["asset_ids"]:
|
||||
# Verify all assets exist using batch check
|
||||
assets = await dao.assets.get_assets_by_ids(update_data["asset_ids"])
|
||||
if len(assets) != len(update_data["asset_ids"]):
|
||||
found_ids = {a.id for a in assets}
|
||||
missing_ids = [aid for aid in update_data["asset_ids"] if aid not in found_ids]
|
||||
raise HTTPException(status_code=400, detail=f"Some assets not found: {missing_ids}")
|
||||
|
||||
success = await dao.environments.update_env(env_id, update_data)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to update environment")
|
||||
|
||||
return await dao.environments.get_env(env_id)
|
||||
|
||||
|
||||
@router.delete("/{env_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_environment(
|
||||
env_id: str,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
env = await dao.environments.get_env(env_id)
|
||||
if not env:
|
||||
raise HTTPException(status_code=404, detail="Environment not found")
|
||||
|
||||
await check_character_access(env.character_id, current_user, dao)
|
||||
|
||||
success = await dao.environments.delete_env(env_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete environment")
|
||||
return None
|
||||
|
||||
|
||||
@router.post("/{env_id}/assets", status_code=status.HTTP_200_OK)
|
||||
async def add_asset_to_environment(
|
||||
env_id: str,
|
||||
req: AssetToEnvironment,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
env = await dao.environments.get_env(env_id)
|
||||
if not env:
|
||||
raise HTTPException(status_code=404, detail="Environment not found")
|
||||
|
||||
await check_character_access(env.character_id, current_user, dao)
|
||||
|
||||
# Verify asset exists
|
||||
asset = await dao.assets.get_asset(req.asset_id)
|
||||
if not asset:
|
||||
raise HTTPException(status_code=404, detail="Asset not found")
|
||||
|
||||
success = await dao.environments.add_asset(env_id, req.asset_id)
|
||||
return {"success": success}
|
||||
|
||||
|
||||
@router.post("/{env_id}/assets/batch", status_code=status.HTTP_200_OK)
|
||||
async def add_assets_batch_to_environment(
|
||||
env_id: str,
|
||||
req: AssetsToEnvironment,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
env = await dao.environments.get_env(env_id)
|
||||
if not env:
|
||||
raise HTTPException(status_code=404, detail="Environment not found")
|
||||
|
||||
await check_character_access(env.character_id, current_user, dao)
|
||||
|
||||
# Verify all assets exist
|
||||
assets = await dao.assets.get_assets_by_ids(req.asset_ids)
|
||||
if len(assets) != len(req.asset_ids):
|
||||
found_ids = {a.id for a in assets}
|
||||
missing_ids = [aid for aid in req.asset_ids if aid not in found_ids]
|
||||
raise HTTPException(status_code=404, detail=f"Some assets not found: {missing_ids}")
|
||||
|
||||
success = await dao.environments.add_assets(env_id, req.asset_ids)
|
||||
return {"success": success}
|
||||
|
||||
|
||||
@router.delete("/{env_id}/assets/{asset_id}", status_code=status.HTTP_200_OK)
|
||||
async def remove_asset_from_environment(
|
||||
env_id: str,
|
||||
asset_id: str,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
env = await dao.environments.get_env(env_id)
|
||||
if not env:
|
||||
raise HTTPException(status_code=404, detail="Environment not found")
|
||||
|
||||
await check_character_access(env.character_id, current_user, dao)
|
||||
|
||||
success = await dao.environments.remove_asset(env_id, asset_id)
|
||||
return {"success": success}
|
||||
@@ -1,84 +1,258 @@
|
||||
from typing import List, Optional
|
||||
import logging
|
||||
import json
|
||||
|
||||
from fastapi import APIRouter, UploadFile, File, Form
|
||||
from fastapi import APIRouter, UploadFile, File, Form, Header, HTTPException
|
||||
from fastapi.params import Depends
|
||||
from starlette import status
|
||||
from starlette.requests import Request
|
||||
|
||||
from api import service
|
||||
from api.dependency import get_generation_service
|
||||
|
||||
from api.models.GenerationRequest import GenerationResponse, GenerationRequest, GenerationsResponse, PromptResponse, PromptRequest
|
||||
from config import settings
|
||||
from api.dependency import get_generation_service, get_project_id, get_dao
|
||||
from api.endpoints.auth import get_current_user
|
||||
from api.models import (
|
||||
GenerationResponse,
|
||||
GenerationRequest,
|
||||
GenerationsResponse,
|
||||
PromptResponse,
|
||||
PromptRequest,
|
||||
GenerationGroupResponse,
|
||||
FinancialReport,
|
||||
ExternalGenerationRequest,
|
||||
NsfwRequest
|
||||
)
|
||||
from api.service.generation_service import GenerationService
|
||||
from models.Generation import Generation
|
||||
|
||||
from starlette import status
|
||||
|
||||
import logging
|
||||
from repos.dao import DAO
|
||||
from utils.external_auth import verify_signature
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from api.endpoints.auth import get_current_user
|
||||
router = APIRouter(prefix='/api/generations', tags=["Generation"])
|
||||
|
||||
router = APIRouter(prefix='/api/generations', tags=["Generation"], dependencies=[Depends(get_current_user)])
|
||||
|
||||
async def check_project_access(project_id: str | None, current_user: dict, dao: DAO):
|
||||
"""Helper to check if user has access to project."""
|
||||
if not project_id:
|
||||
return
|
||||
project = await dao.projects.get_project(project_id)
|
||||
if not project or str(current_user["_id"]) not in project.members:
|
||||
raise HTTPException(status_code=403, detail="Project access denied")
|
||||
|
||||
|
||||
@router.post("/prompt-assistant", response_model=PromptResponse)
|
||||
async def ask_prompt_assistant(prompt_request: PromptRequest, request: Request,
|
||||
generation_service: GenerationService = Depends(
|
||||
get_generation_service)) -> PromptResponse:
|
||||
logger.info(f"ask_prompt_assistant called with prompt length: {len(prompt_request.prompt)}. Linked assets: {len(prompt_request.linked_assets) if prompt_request.linked_assets else 0}")
|
||||
generated_prompt = await generation_service.ask_prompt_assistant(prompt_request.prompt, prompt_request.linked_assets)
|
||||
async def ask_prompt_assistant(
|
||||
prompt_request: PromptRequest,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
) -> PromptResponse:
|
||||
logger.info(f"ask_prompt_assistant: {len(prompt_request.prompt)} chars")
|
||||
generated_prompt = await generation_service.ask_prompt_assistant(
|
||||
prompt_request.prompt,
|
||||
prompt_request.linked_assets,
|
||||
prompt_request.model
|
||||
)
|
||||
return PromptResponse(prompt=generated_prompt)
|
||||
|
||||
|
||||
@router.post("/prompt-from-image", response_model=PromptResponse)
|
||||
async def prompt_from_image(
|
||||
prompt: Optional[str] = Form(None),
|
||||
images: List[UploadFile] = File(...),
|
||||
generation_service: GenerationService = Depends(get_generation_service)
|
||||
prompt: str | None = Form(None),
|
||||
model: str = Form("gemini-3.1-pro-preview"),
|
||||
images: list[UploadFile] = File(...),
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
) -> PromptResponse:
|
||||
logger.info(f"prompt_from_image called. Images count: {len(images)}. Prompt provided: {bool(prompt)}")
|
||||
images_bytes = []
|
||||
for image in images:
|
||||
content = await image.read()
|
||||
images_bytes.append(content)
|
||||
|
||||
generated_prompt = await generation_service.generate_prompt_from_images(images_bytes, prompt)
|
||||
images_bytes = [await img.read() for img in images]
|
||||
generated_prompt = await generation_service.generate_prompt_from_images(images_bytes, prompt, model)
|
||||
return PromptResponse(prompt=generated_prompt)
|
||||
|
||||
|
||||
@router.get("", response_model=GenerationsResponse)
|
||||
async def get_generations(character_id: Optional[str] = None, limit: int = 10, offset: int = 0,
|
||||
generation_service: GenerationService = Depends(get_generation_service)):
|
||||
logger.info(f"get_generations called. CharacterId: {character_id}, Limit: {limit}, Offset: {offset}")
|
||||
return await generation_service.get_generations(character_id, limit=limit, offset=offset)
|
||||
async def get_generations(
|
||||
character_id: str | None = None,
|
||||
limit: int = 10,
|
||||
offset: int = 0,
|
||||
only_liked: bool = False,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
project_id: str | None = Depends(get_project_id),
|
||||
dao: DAO = Depends(get_dao)
|
||||
):
|
||||
await check_project_access(project_id, current_user, dao)
|
||||
|
||||
# If project_id is set, we don't filter by user to show all project-wide generations
|
||||
created_by_filter = None if project_id else str(current_user["_id"])
|
||||
only_liked_by = str(current_user["_id"]) if only_liked else None
|
||||
|
||||
return await generation_service.get_generations(
|
||||
character_id=character_id,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
created_by=created_by_filter,
|
||||
project_id=project_id,
|
||||
only_liked_by=only_liked_by,
|
||||
current_user_id=str(current_user["_id"])
|
||||
)
|
||||
|
||||
|
||||
@router.post("/_run", response_model=GenerationResponse)
|
||||
async def post_generation(generation: GenerationRequest, request: Request,
|
||||
generation_service: GenerationService = Depends(
|
||||
get_generation_service)) -> GenerationResponse:
|
||||
logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}")
|
||||
return await generation_service.create_generation_task(generation)
|
||||
@router.get("/usage", response_model=FinancialReport)
|
||||
async def get_usage_report(
|
||||
breakdown: str | None = None, # "user" or "project"
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
project_id: str | None = Depends(get_project_id),
|
||||
dao: DAO = Depends(get_dao)
|
||||
) -> FinancialReport:
|
||||
await check_project_access(project_id, current_user, dao)
|
||||
|
||||
user_id_filter = str(current_user["_id"]) if not project_id else None
|
||||
breakdown_by = None
|
||||
|
||||
if breakdown == "user":
|
||||
breakdown_by = "created_by"
|
||||
elif breakdown == "project":
|
||||
breakdown_by = "project_id"
|
||||
|
||||
return await generation_service.get_financial_report(
|
||||
user_id=user_id_filter,
|
||||
project_id=project_id,
|
||||
breakdown_by=breakdown_by
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{generation_id}", response_model=GenerationResponse)
|
||||
async def get_generation(generation_id: str,
|
||||
generation_service: GenerationService = Depends(get_generation_service)) -> GenerationResponse:
|
||||
logger.debug(f"get_generation called for ID: {generation_id}")
|
||||
return await generation_service.get_generation(generation_id)
|
||||
@router.post("/_run", response_model=GenerationGroupResponse)
|
||||
async def post_generation(
|
||||
generation: GenerationRequest,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
project_id: str | None = Depends(get_project_id),
|
||||
dao: DAO = Depends(get_dao)
|
||||
) -> GenerationGroupResponse:
|
||||
await check_project_access(project_id, current_user, dao)
|
||||
if project_id:
|
||||
generation.project_id = project_id
|
||||
|
||||
return await generation_service.create_generation_task(
|
||||
generation,
|
||||
user_id=str(current_user.get("_id"))
|
||||
)
|
||||
|
||||
|
||||
@router.get("/running")
|
||||
async def get_running_generations(request: Request,
|
||||
generation_service: GenerationService = Depends(get_generation_service)):
|
||||
return await generation_service.get_running_generations()
|
||||
async def get_running_generations(
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
project_id: str | None = Depends(get_project_id),
|
||||
dao: DAO = Depends(get_dao)
|
||||
):
|
||||
await check_project_access(project_id, current_user, dao)
|
||||
user_id_filter = None if project_id else str(current_user["_id"])
|
||||
|
||||
return await generation_service.get_running_generations(
|
||||
user_id=user_id_filter,
|
||||
project_id=project_id
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{generation_id}", status_code=status.HTTP_204_NO_CONTENT, dependencies=[Depends(get_current_user)])
|
||||
async def delete_generation(generation_id: str, generation_service: GenerationService = Depends(get_generation_service)):
|
||||
logger.info(f"delete_generation called for ID: {generation_id}")
|
||||
deleted = await generation_service.delete_generation(generation_id)
|
||||
if not deleted:
|
||||
@router.get("/group/{group_id}", response_model=GenerationGroupResponse)
|
||||
async def get_generation_group(
|
||||
group_id: str,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
return await generation_service.get_generations_by_group(group_id, current_user_id=str(current_user["_id"]))
|
||||
|
||||
|
||||
@router.get("/{generation_id}", response_model=GenerationResponse)
|
||||
async def get_generation(
|
||||
generation_id: str,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
) -> GenerationResponse:
|
||||
gen = await generation_service.get_generation(generation_id, current_user_id=str(current_user["_id"]))
|
||||
if not gen:
|
||||
raise HTTPException(status_code=404, detail="Generation not found")
|
||||
|
||||
if gen.created_by != str(current_user["_id"]):
|
||||
# Check project membership
|
||||
is_member = False
|
||||
if gen.project_id:
|
||||
project = await generation_service.dao.projects.get_project(gen.project_id)
|
||||
if project and str(current_user["_id"]) in project.members:
|
||||
is_member = True
|
||||
|
||||
if not is_member:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
return gen
|
||||
|
||||
|
||||
@router.post("/{generation_id}/like", response_model=dict)
|
||||
async def toggle_like(
|
||||
generation_id: str,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
is_liked = await generation_service.toggle_like(generation_id, str(current_user["_id"]))
|
||||
if is_liked is None:
|
||||
raise HTTPException(status_code=404, detail="Generation not found")
|
||||
return {"is_liked": is_liked}
|
||||
|
||||
|
||||
@router.post("/{generation_id}/nsfw", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def mark_generation_nsfw(
|
||||
generation_id: str,
|
||||
request: NsfwRequest,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
gen = await generation_service.get_generation(generation_id, current_user_id=str(current_user["_id"]))
|
||||
if not gen:
|
||||
raise HTTPException(status_code=404, detail="Generation not found")
|
||||
|
||||
if gen.created_by != str(current_user["_id"]):
|
||||
is_member = False
|
||||
if gen.project_id:
|
||||
project = await generation_service.dao.projects.get_project(gen.project_id)
|
||||
if project and str(current_user["_id"]) in project.members:
|
||||
is_member = True
|
||||
|
||||
if not is_member:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
await generation_service.dao.generations.mark_nsfw(generation_id, request.is_nsfw)
|
||||
return None
|
||||
|
||||
|
||||
@router.post("/import", response_model=GenerationResponse)
|
||||
async def import_external_generation(
|
||||
request: Request,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
x_signature: str = Header(..., alias="X-Signature")
|
||||
) -> GenerationResponse:
|
||||
body = await request.body()
|
||||
|
||||
secret = settings.EXTERNAL_API_SECRET
|
||||
if not secret:
|
||||
raise HTTPException(status_code=500, detail="Server configuration error")
|
||||
|
||||
if not verify_signature(body, x_signature, secret):
|
||||
raise HTTPException(status_code=401, detail="Invalid signature")
|
||||
|
||||
try:
|
||||
data = json.loads(body.decode('utf-8'))
|
||||
external_gen = ExternalGenerationRequest(**data)
|
||||
generation = await generation_service.import_external_generation(external_gen)
|
||||
return GenerationResponse(**generation.model_dump())
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to import external generation: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Import failed: {str(e)}")
|
||||
|
||||
|
||||
@router.delete("/{generation_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_generation(
|
||||
generation_id: str,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
if not await generation_service.delete_generation(generation_id):
|
||||
raise HTTPException(status_code=404, detail="Generation not found")
|
||||
return None
|
||||
106
api/endpoints/idea_router.py
Normal file
106
api/endpoints/idea_router.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Body
|
||||
from api.dependency import get_idea_service, get_project_id, get_generation_service
|
||||
from api.endpoints.auth import get_current_user
|
||||
from api.service.idea_service import IdeaService
|
||||
from api.service.generation_service import GenerationService
|
||||
from models.Idea import Idea
|
||||
from api.models import GenerationResponse, GenerationsResponse
|
||||
from api.models import IdeaRequest, PostRequest # Adjusting for general model usage
|
||||
from api.models.IdeaRequest import IdeaCreateRequest, IdeaUpdateRequest, IdeaResponse
|
||||
|
||||
router = APIRouter(prefix="/api/ideas", tags=["ideas"])
|
||||
|
||||
@router.post("", response_model=Idea)
|
||||
async def create_idea(
|
||||
request: IdeaCreateRequest,
|
||||
project_id: str | None = Depends(get_project_id),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
idea_service: IdeaService = Depends(get_idea_service)
|
||||
):
|
||||
pid = project_id or request.project_id
|
||||
|
||||
return await idea_service.create_idea(
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
project_id=pid,
|
||||
user_id=str(current_user["_id"]),
|
||||
inspiration_id=request.inspiration_id
|
||||
)
|
||||
|
||||
@router.get("", response_model=list[IdeaResponse])
|
||||
async def get_ideas(
|
||||
project_id: str | None = Depends(get_project_id),
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
current_user: dict = Depends(get_current_user),
|
||||
idea_service: IdeaService = Depends(get_idea_service)
|
||||
):
|
||||
return await idea_service.get_ideas(project_id, str(current_user["_id"]), limit, offset)
|
||||
|
||||
@router.get("/{idea_id}", response_model=Idea)
|
||||
async def get_idea(
|
||||
idea_id: str,
|
||||
idea_service: IdeaService = Depends(get_idea_service)
|
||||
):
|
||||
idea = await idea_service.get_idea(idea_id)
|
||||
if not idea:
|
||||
raise HTTPException(status_code=404, detail="Idea not found")
|
||||
return idea
|
||||
|
||||
@router.put("/{idea_id}", response_model=Idea)
|
||||
async def update_idea(
|
||||
idea_id: str,
|
||||
request: IdeaUpdateRequest,
|
||||
idea_service: IdeaService = Depends(get_idea_service)
|
||||
):
|
||||
idea = await idea_service.update_idea(
|
||||
idea_id=idea_id,
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
inspiration_id=request.inspiration_id
|
||||
)
|
||||
if not idea:
|
||||
raise HTTPException(status_code=404, detail="Idea not found")
|
||||
return idea
|
||||
|
||||
@router.delete("/{idea_id}")
|
||||
async def delete_idea(
|
||||
idea_id: str,
|
||||
idea_service: IdeaService = Depends(get_idea_service)
|
||||
):
|
||||
success = await idea_service.delete_idea(idea_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Idea not found or could not be deleted")
|
||||
return {"status": "success"}
|
||||
|
||||
@router.get("/{idea_id}/generations", response_model=GenerationsResponse)
|
||||
async def get_idea_generations(
|
||||
idea_id: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
return await generation_service.get_generations(idea_id=idea_id, limit=limit, offset=offset, current_user_id=str(current_user["_id"]))
|
||||
|
||||
@router.post("/{idea_id}/generations/{generation_id}")
|
||||
async def add_generation_to_idea(
|
||||
idea_id: str,
|
||||
generation_id: str,
|
||||
idea_service: IdeaService = Depends(get_idea_service)
|
||||
):
|
||||
success = await idea_service.add_generation_to_idea(idea_id, generation_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Idea or Generation not found")
|
||||
return {"status": "success"}
|
||||
|
||||
@router.delete("/{idea_id}/generations/{generation_id}")
|
||||
async def remove_generation_from_idea(
|
||||
idea_id: str,
|
||||
generation_id: str,
|
||||
idea_service: IdeaService = Depends(get_idea_service)
|
||||
):
|
||||
success = await idea_service.remove_generation_from_idea(idea_id, generation_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Idea or Generation not found")
|
||||
return {"status": "success"}
|
||||
94
api/endpoints/inspiration_router.py
Normal file
94
api/endpoints/inspiration_router.py
Normal file
@@ -0,0 +1,94 @@
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from api.dependency import get_inspiration_service, get_project_id
|
||||
from api.endpoints.auth import get_current_user
|
||||
from api.models.InspirationRequest import InspirationCreateRequest, InspirationResponse, InspirationListResponse
|
||||
from api.service.inspiration_service import InspirationService
|
||||
from models.Inspiration import Inspiration
|
||||
|
||||
router = APIRouter(prefix="/api/inspirations", tags=["Inspirations"])
|
||||
|
||||
|
||||
@router.post("", response_model=InspirationResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_inspiration(
|
||||
request: InspirationCreateRequest,
|
||||
project_id: str | None = Depends(get_project_id),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
service: InspirationService = Depends(get_inspiration_service)
|
||||
):
|
||||
pid = project_id or request.project_id
|
||||
|
||||
inspiration = await service.create_inspiration(
|
||||
source_url=request.source_url,
|
||||
created_by=str(current_user["_id"]),
|
||||
project_id=pid,
|
||||
caption=request.caption
|
||||
)
|
||||
return inspiration
|
||||
|
||||
|
||||
@router.get("", response_model=InspirationListResponse)
|
||||
async def get_inspirations(
|
||||
project_id: str | None = Depends(get_project_id),
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
current_user: dict = Depends(get_current_user),
|
||||
service: InspirationService = Depends(get_inspiration_service)
|
||||
):
|
||||
# If project_id is provided, filter by it. Otherwise, filter by user.
|
||||
# Or maybe we want to see all user's inspirations if no project is selected?
|
||||
# Let's follow the pattern: if project_id is present, show project's inspirations.
|
||||
# If not, show user's personal inspirations (where project_id is None) OR all user's inspirations?
|
||||
# Usually "My Inspirations" means created by me.
|
||||
|
||||
# Let's assume:
|
||||
# If project_id -> filter by project_id (and maybe created_by if we want strict ownership, but usually project members share)
|
||||
# If no project_id -> filter by created_by (personal)
|
||||
|
||||
pid = project_id
|
||||
uid = str(current_user["_id"])
|
||||
|
||||
inspirations = await service.get_inspirations(project_id=pid, created_by=uid if not pid else None, limit=limit, offset=offset)
|
||||
total_count = await service.dao.inspirations.count_inspirations(project_id=pid, created_by=uid if not pid else None)
|
||||
|
||||
return InspirationListResponse(
|
||||
inspirations=[InspirationResponse(**inspiration.model_dump()) for inspiration in inspirations],
|
||||
total_count=total_count
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{inspiration_id}", response_model=InspirationResponse)
|
||||
async def get_inspiration(
|
||||
inspiration_id: str,
|
||||
service: InspirationService = Depends(get_inspiration_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
inspiration = await service.get_inspiration(inspiration_id)
|
||||
if not inspiration:
|
||||
raise HTTPException(status_code=404, detail="Inspiration not found")
|
||||
return inspiration
|
||||
|
||||
|
||||
@router.patch("/{inspiration_id}/complete", response_model=InspirationResponse)
|
||||
async def mark_inspiration_complete(
|
||||
inspiration_id: str,
|
||||
is_completed: bool = True,
|
||||
service: InspirationService = Depends(get_inspiration_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
inspiration = await service.mark_as_completed(inspiration_id, is_completed)
|
||||
if not inspiration:
|
||||
raise HTTPException(status_code=404, detail="Inspiration not found")
|
||||
return inspiration
|
||||
|
||||
|
||||
@router.delete("/{inspiration_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_inspiration(
|
||||
inspiration_id: str,
|
||||
service: InspirationService = Depends(get_inspiration_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
success = await service.delete_inspiration(inspiration_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Inspiration not found")
|
||||
return None
|
||||
98
api/endpoints/post_router.py
Normal file
98
api/endpoints/post_router.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from api.dependency import get_post_service, get_project_id
|
||||
from api.endpoints.auth import get_current_user
|
||||
from api.service.post_service import PostService
|
||||
from api.models import PostRequest, PostCreateRequest, PostUpdateRequest, AddGenerationsRequest
|
||||
from models.Post import Post
|
||||
|
||||
router = APIRouter(prefix="/api/posts", tags=["posts"])
|
||||
|
||||
|
||||
@router.post("", response_model=Post)
|
||||
async def create_post(
|
||||
request: PostCreateRequest,
|
||||
project_id: str | None = Depends(get_project_id),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
post_service: PostService = Depends(get_post_service),
|
||||
):
|
||||
pid = project_id or request.project_id
|
||||
return await post_service.create_post(
|
||||
date=request.date,
|
||||
topic=request.topic,
|
||||
generation_ids=request.generation_ids,
|
||||
project_id=pid,
|
||||
user_id=str(current_user["_id"]),
|
||||
)
|
||||
|
||||
|
||||
@router.get("", response_model=list[Post])
|
||||
async def get_posts(
|
||||
project_id: str | None = Depends(get_project_id),
|
||||
limit: int = 200,
|
||||
offset: int = 0,
|
||||
date_from: datetime | None = None,
|
||||
date_to: datetime | None = None,
|
||||
current_user: dict = Depends(get_current_user),
|
||||
post_service: PostService = Depends(get_post_service),
|
||||
):
|
||||
return await post_service.get_posts(project_id, str(current_user["_id"]), limit, offset, date_from, date_to)
|
||||
|
||||
|
||||
@router.get("/{post_id}", response_model=Post)
|
||||
async def get_post(
|
||||
post_id: str,
|
||||
post_service: PostService = Depends(get_post_service),
|
||||
):
|
||||
post = await post_service.get_post(post_id)
|
||||
if not post:
|
||||
raise HTTPException(status_code=404, detail="Post not found")
|
||||
return post
|
||||
|
||||
|
||||
@router.put("/{post_id}", response_model=Post)
|
||||
async def update_post(
|
||||
post_id: str,
|
||||
request: PostUpdateRequest,
|
||||
post_service: PostService = Depends(get_post_service),
|
||||
):
|
||||
post = await post_service.update_post(post_id, date=request.date, topic=request.topic)
|
||||
if not post:
|
||||
raise HTTPException(status_code=404, detail="Post not found")
|
||||
return post
|
||||
|
||||
|
||||
@router.delete("/{post_id}")
|
||||
async def delete_post(
|
||||
post_id: str,
|
||||
post_service: PostService = Depends(get_post_service),
|
||||
):
|
||||
success = await post_service.delete_post(post_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Post not found or could not be deleted")
|
||||
return {"status": "success"}
|
||||
|
||||
|
||||
@router.post("/{post_id}/generations")
|
||||
async def add_generations(
|
||||
post_id: str,
|
||||
request: AddGenerationsRequest,
|
||||
post_service: PostService = Depends(get_post_service),
|
||||
):
|
||||
success = await post_service.add_generations(post_id, request.generation_ids)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Post not found")
|
||||
return {"status": "success"}
|
||||
|
||||
|
||||
@router.delete("/{post_id}/generations/{generation_id}")
|
||||
async def remove_generation(
|
||||
post_id: str,
|
||||
generation_id: str,
|
||||
post_service: PostService = Depends(get_post_service),
|
||||
):
|
||||
success = await post_service.remove_generation(post_id, generation_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Post not found or generation not linked")
|
||||
return {"status": "success"}
|
||||
181
api/endpoints/project_router.py
Normal file
181
api/endpoints/project_router.py
Normal file
@@ -0,0 +1,181 @@
|
||||
|
||||
from bson import ObjectId
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel
|
||||
from api.dependency import get_dao
|
||||
from api.endpoints.auth import get_current_user
|
||||
from models.Project import Project
|
||||
from repos.dao import DAO
|
||||
|
||||
router = APIRouter(prefix="/api/projects", tags=["Projects"])
|
||||
|
||||
class ProjectCreate(BaseModel):
|
||||
name: str
|
||||
description: str | None = None
|
||||
|
||||
class ProjectMemberResponse(BaseModel):
|
||||
id: str
|
||||
username: str
|
||||
|
||||
class ProjectResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str | None = None
|
||||
owner_id: str
|
||||
members: list[ProjectMemberResponse]
|
||||
is_owner: bool = False
|
||||
|
||||
async def _get_project_response(project: Project, current_user_id: str, dao: DAO) -> ProjectResponse:
|
||||
member_responses = []
|
||||
for member_id in project.members:
|
||||
# We need a way to get user by ID. Let's check UsersRepo for get_user by ObjectId or string.
|
||||
# Currently UsersRepo has get_user(user_id: int) for Telegram IDs.
|
||||
# But for Web users we might need to search by _id.
|
||||
# Let's try to get user info.
|
||||
# Since project.members contains strings (ObjectIds as strings), we search by _id.
|
||||
user_doc = await dao.users.collection.find_one({"_id": ObjectId(member_id)})
|
||||
if not user_doc and member_id.isdigit():
|
||||
# Fallback for telegram IDs if they are stored as strings of digits
|
||||
user_doc = await dao.users.get_user(int(member_id))
|
||||
|
||||
username = "unknown"
|
||||
if user_doc:
|
||||
username = user_doc.get("username", "unknown")
|
||||
|
||||
member_responses.append(ProjectMemberResponse(id=member_id, username=username))
|
||||
|
||||
return ProjectResponse(
|
||||
id=project.id,
|
||||
name=project.name,
|
||||
description=project.description,
|
||||
owner_id=project.owner_id,
|
||||
members=member_responses,
|
||||
is_owner=(project.owner_id == current_user_id)
|
||||
)
|
||||
|
||||
@router.post("", response_model=ProjectResponse)
|
||||
async def create_project(
|
||||
project_data: ProjectCreate,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
user_id = str(current_user["_id"])
|
||||
new_project = Project(
|
||||
name=project_data.name,
|
||||
description=project_data.description,
|
||||
owner_id=user_id,
|
||||
members=[user_id]
|
||||
)
|
||||
project_id = await dao.projects.create_project(new_project)
|
||||
new_project.id = project_id
|
||||
|
||||
# Add project to user's project list
|
||||
await dao.users.collection.update_one(
|
||||
{"_id": current_user["_id"]},
|
||||
{"$addToSet": {"project_ids": project_id}}
|
||||
)
|
||||
|
||||
return await _get_project_response(new_project, user_id, dao)
|
||||
|
||||
@router.get("", response_model=list[ProjectResponse])
|
||||
async def get_my_projects(
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
user_id = str(current_user["_id"])
|
||||
projects = await dao.projects.get_projects_by_user(user_id)
|
||||
|
||||
responses = []
|
||||
for p in projects:
|
||||
responses.append(await _get_project_response(p, user_id, dao))
|
||||
return responses
|
||||
|
||||
class MemberAdd(BaseModel):
|
||||
username: str
|
||||
|
||||
@router.post("/{project_id}/members", dependencies=[Depends(get_current_user)])
|
||||
async def add_member(
|
||||
project_id: str,
|
||||
member_data: MemberAdd,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
user_id = str(current_user["_id"])
|
||||
project = await dao.projects.get_project(project_id)
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
if project.owner_id != user_id:
|
||||
raise HTTPException(status_code=403, detail="Only owner can add members")
|
||||
|
||||
target_user = await dao.users.get_user_by_username(member_data.username)
|
||||
if not target_user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
target_user_id = str(target_user["_id"])
|
||||
|
||||
if target_user_id in project.members:
|
||||
return {"message": "User already in project"}
|
||||
|
||||
await dao.projects.add_member(project_id, target_user_id)
|
||||
|
||||
# Update target user's project list
|
||||
await dao.users.collection.update_one(
|
||||
{"_id": target_user["_id"]},
|
||||
{"$addToSet": {"project_ids": project_id}}
|
||||
)
|
||||
|
||||
return {"message": "Member added"}
|
||||
|
||||
@router.post("/{project_id}/join", dependencies=[Depends(get_current_user)])
|
||||
async def join_project(
|
||||
project_id: str,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
# Retrieve project to verify it exists
|
||||
project = await dao.projects.get_project(project_id)
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
user_id = str(current_user["_id"])
|
||||
|
||||
# Check if user is ALREADY in project
|
||||
if user_id in project.members:
|
||||
return {"message": "Already a member"}
|
||||
|
||||
# Add member
|
||||
await dao.projects.add_member(project_id, user_id)
|
||||
|
||||
# Update user's project list
|
||||
await dao.users.collection.update_one(
|
||||
{"_id": current_user["_id"]},
|
||||
{"$addToSet": {"project_ids": project_id}}
|
||||
)
|
||||
|
||||
return {"message": "Joined project"}
|
||||
|
||||
|
||||
@router.delete("/{project_id}", dependencies=[Depends(get_current_user)] )
|
||||
async def delete_project(
|
||||
project_id: str,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
user_id = str(current_user["_id"])
|
||||
project = await dao.projects.get_project(project_id)
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
if project.owner_id != user_id:
|
||||
raise HTTPException(status_code=403, detail="Only owner can delete project")
|
||||
|
||||
await dao.projects.delete_project(project_id)
|
||||
|
||||
# Remove project from user's project list
|
||||
await dao.users.collection.update_one(
|
||||
{"_id": current_user["_id"]},
|
||||
{"$pull": {"project_ids": project_id}}
|
||||
)
|
||||
|
||||
return {"message": "Project deleted"}
|
||||
@@ -1,5 +1,4 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -11,10 +10,10 @@ class AssetResponse(BaseModel):
|
||||
name: str
|
||||
type: str # uploaded / generated
|
||||
content_type: str # image / prompt
|
||||
linked_char_id: Optional[str] = None
|
||||
linked_char_id: str | None = None
|
||||
created_at: datetime
|
||||
url: Optional[str] = None
|
||||
url: str | None = None
|
||||
|
||||
class AssetsResponse(BaseModel):
|
||||
assets: List[AssetResponse]
|
||||
assets: list[AssetResponse]
|
||||
total_count: int
|
||||
17
api/models/CharacterDTO.py
Normal file
17
api/models/CharacterDTO.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
class CharacterCreateRequest(BaseModel):
|
||||
name: str
|
||||
character_bio: str
|
||||
character_image_doc_tg_id: str | None = None
|
||||
avatar_image: str | None = None
|
||||
character_image_tg_id: str | None = None
|
||||
project_id: str | None = None
|
||||
|
||||
class CharacterUpdateRequest(BaseModel):
|
||||
name: str | None = None
|
||||
character_bio: str | None = None
|
||||
character_image_doc_tg_id: str | None = None
|
||||
avatar_image: str | None = None
|
||||
character_image_tg_id: str | None = None
|
||||
project_id: str | None = None
|
||||
22
api/models/EnvironmentRequest.py
Normal file
22
api/models/EnvironmentRequest.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class EnvironmentCreate(BaseModel):
|
||||
character_id: str
|
||||
name: str = Field(..., min_length=1)
|
||||
description: str | None = None
|
||||
asset_ids: list[str] | None = []
|
||||
|
||||
|
||||
class EnvironmentUpdate(BaseModel):
|
||||
name: str | None = Field(None, min_length=1)
|
||||
description: str | None = None
|
||||
asset_ids: list[str] | None = None
|
||||
|
||||
|
||||
class AssetToEnvironment(BaseModel):
|
||||
asset_id: str
|
||||
|
||||
|
||||
class AssetsToEnvironment(BaseModel):
|
||||
asset_ids: list[str]
|
||||
40
api/models/ExternalGenerationDTO.py
Normal file
40
api/models/ExternalGenerationDTO.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from models.enums import AspectRatios, Quality
|
||||
|
||||
|
||||
class ExternalGenerationRequest(BaseModel):
|
||||
"""Request model for importing external generations."""
|
||||
|
||||
prompt: str
|
||||
tech_prompt: str | None = None
|
||||
|
||||
# Image can be provided as base64 string OR URL (one must be provided)
|
||||
image_data: str | None = Field(None, description="Base64-encoded image data")
|
||||
image_url: str | None = Field(None, description="URL to download image from")
|
||||
|
||||
nsfw: bool = False
|
||||
|
||||
# Generation metadata
|
||||
aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN # "1:1","2:3","3:2","3:4","4:3","4:5","5:4","9:16","16:9","21:9"
|
||||
quality: Quality = Quality.ONEK
|
||||
model: str | None = None
|
||||
seed: int | None = None
|
||||
|
||||
# Optional linking
|
||||
linked_character_id: str | None = None
|
||||
created_by: str = Field(..., description="User ID from external system")
|
||||
project_id: str | None = None
|
||||
|
||||
# Performance metrics
|
||||
execution_time_seconds: float | None = None
|
||||
api_execution_time_seconds: float | None = None
|
||||
token_usage: int | None = None
|
||||
input_token_usage: int | None = None
|
||||
output_token_usage: int | None = None
|
||||
|
||||
def validate_image_source(self):
|
||||
"""Ensure at least one image source is provided."""
|
||||
if not self.image_data and not self.image_url:
|
||||
raise ValueError("Either image_data or image_url must be provided")
|
||||
if self.image_data and self.image_url:
|
||||
raise ValueError("Only one of image_data or image_url should be provided")
|
||||
17
api/models/FinancialUsageDTO.py
Normal file
17
api/models/FinancialUsageDTO.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
class UsageStats(BaseModel):
|
||||
total_runs: int
|
||||
total_tokens: int
|
||||
total_input_tokens: int
|
||||
total_output_tokens: int
|
||||
total_cost: float
|
||||
|
||||
class UsageByEntity(BaseModel):
|
||||
entity_id: str | None = None
|
||||
stats: UsageStats
|
||||
|
||||
class FinancialReport(BaseModel):
|
||||
summary: UsageStats
|
||||
by_user: list[UsageByEntity] | None = None
|
||||
by_project: list[UsageByEntity] | None = None
|
||||
@@ -1,55 +1,78 @@
|
||||
from datetime import datetime, UTC
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from models.Asset import Asset
|
||||
from models.Generation import GenerationStatus
|
||||
from models.enums import AspectRatios, Quality, GenType
|
||||
from models.enums import AspectRatios, Quality, GenType, ImageModel, TextModel
|
||||
|
||||
|
||||
class GenerationRequest(BaseModel):
|
||||
linked_character_id: Optional[str] = None
|
||||
aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN
|
||||
linked_character_id: str | None = None
|
||||
aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN # "1:1","2:3","3:2","3:4","4:3","4:5","5:4","9:16","16:9","21:9"
|
||||
quality: Quality = Quality.ONEK
|
||||
prompt: str
|
||||
telegram_id: Optional[int] = None
|
||||
model: ImageModel = Field(default=ImageModel.GEMINI_3_PRO_IMAGE_PREVIEW)
|
||||
telegram_id: int | None = None
|
||||
use_profile_image: bool = True
|
||||
assets_list: List[str]
|
||||
assets_list: list[str]
|
||||
environment_id: str | None = None
|
||||
project_id: str | None = None
|
||||
idea_id: str | None = None
|
||||
nsfw: bool = False
|
||||
count: int = Field(default=1, ge=1, le=10)
|
||||
|
||||
|
||||
class NsfwRequest(BaseModel):
|
||||
is_nsfw: bool
|
||||
|
||||
|
||||
class GenerationsResponse(BaseModel):
|
||||
generations: List["GenerationResponse"]
|
||||
generations: list["GenerationResponse"]
|
||||
total_count: int
|
||||
|
||||
|
||||
class GenerationResponse(BaseModel):
|
||||
id: str
|
||||
status: GenerationStatus
|
||||
failed_reason: Optional[str] = None
|
||||
|
||||
linked_character_id: Optional[str] = None
|
||||
failed_reason: str | None = None
|
||||
project_id: str | None = None
|
||||
linked_character_id: str | None = None
|
||||
aspect_ratio: AspectRatios
|
||||
quality: Quality
|
||||
prompt: str
|
||||
tech_prompt: Optional[str] = None
|
||||
assets_list: List[str]
|
||||
result_list: List[str] = []
|
||||
result: Optional[str] = None
|
||||
execution_time_seconds: Optional[float] = None
|
||||
api_execution_time_seconds: Optional[float] = None
|
||||
token_usage: Optional[int] = None
|
||||
input_token_usage: Optional[int] = None
|
||||
output_token_usage: Optional[int] = None
|
||||
model: ImageModel | None = None
|
||||
seed: int | None = None
|
||||
tech_prompt: str | None = None
|
||||
assets_list: list[str]
|
||||
result_list: list[str] = []
|
||||
result: str | None = None
|
||||
execution_time_seconds: float | None = None
|
||||
api_execution_time_seconds: float | None = None
|
||||
token_usage: int | None = None
|
||||
input_token_usage: int | None = None
|
||||
output_token_usage: int | None = None
|
||||
progress: int = 0
|
||||
cost: float | None = None
|
||||
created_by: str | None = None
|
||||
generation_group_id: str | None = None
|
||||
idea_id: str | None = None
|
||||
likes_count: int = 0
|
||||
is_liked: bool = False
|
||||
nsfw: bool = False
|
||||
created_at: datetime = datetime.now(UTC)
|
||||
updated_at: datetime = datetime.now(UTC)
|
||||
|
||||
|
||||
class GenerationGroupResponse(BaseModel):
|
||||
generation_group_id: str
|
||||
generations: list[GenerationResponse]
|
||||
|
||||
|
||||
class PromptRequest(BaseModel):
|
||||
prompt: str
|
||||
linked_assets: List[str] = []
|
||||
model: TextModel = Field(default=TextModel.GEMINI_3_1_PRO_PREVIEW)
|
||||
linked_assets: list[str] = []
|
||||
|
||||
|
||||
class PromptResponse(BaseModel):
|
||||
|
||||
17
api/models/IdeaRequest.py
Normal file
17
api/models/IdeaRequest.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from pydantic import BaseModel
|
||||
from models.Idea import Idea
|
||||
from api.models.GenerationRequest import GenerationResponse
|
||||
|
||||
class IdeaCreateRequest(BaseModel):
|
||||
name: str
|
||||
description: str | None = None
|
||||
project_id: str | None = None # Optional in body if passed via header/dependency
|
||||
inspiration_id: str | None = None
|
||||
|
||||
class IdeaUpdateRequest(BaseModel):
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
inspiration_id: str | None = None
|
||||
|
||||
class IdeaResponse(Idea):
|
||||
last_generation: GenerationResponse | None = None
|
||||
28
api/models/InspirationRequest.py
Normal file
28
api/models/InspirationRequest.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from models.Inspiration import Inspiration
|
||||
|
||||
|
||||
class InspirationCreateRequest(BaseModel):
|
||||
source_url: str
|
||||
caption: str | None = None
|
||||
project_id: str | None = None
|
||||
|
||||
|
||||
class InspirationResponse(BaseModel):
|
||||
id: str
|
||||
source_url: str
|
||||
caption: str | None = None
|
||||
asset_id: str
|
||||
is_completed: bool
|
||||
created_by: str
|
||||
project_id: str | None = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class InspirationListResponse(BaseModel):
|
||||
inspirations: list[InspirationResponse]
|
||||
total_count: int
|
||||
18
api/models/PostRequest.py
Normal file
18
api/models/PostRequest.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class PostCreateRequest(BaseModel):
|
||||
date: datetime
|
||||
topic: str
|
||||
generation_ids: list[str] = []
|
||||
project_id: str | None = None
|
||||
|
||||
|
||||
class PostUpdateRequest(BaseModel):
|
||||
date: datetime | None = None
|
||||
topic: str | None = None
|
||||
|
||||
|
||||
class AddGenerationsRequest(BaseModel):
|
||||
generation_ids: list[str]
|
||||
@@ -0,0 +1,7 @@
|
||||
from .AssetDTO import AssetResponse, AssetsResponse
|
||||
from .CharacterDTO import CharacterCreateRequest, CharacterUpdateRequest
|
||||
from .ExternalGenerationDTO import ExternalGenerationRequest
|
||||
from .FinancialUsageDTO import FinancialReport, UsageStats, UsageByEntity
|
||||
from .GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse, PromptRequest, PromptResponse, NsfwRequest
|
||||
from .IdeaRequest import IdeaCreateRequest, IdeaUpdateRequest, IdeaResponse
|
||||
from .PostRequest import PostCreateRequest, PostUpdateRequest, AddGenerationsRequest
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,66 +1,73 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import random
|
||||
from datetime import datetime, UTC
|
||||
from typing import List, Optional, Tuple, Any, Dict
|
||||
from io import BytesIO
|
||||
from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
from aiogram import Bot
|
||||
from aiogram.types import BufferedInputFile
|
||||
|
||||
from adapters.Exception import GoogleGenerationException
|
||||
from adapters.google_adapter import GoogleAdapter
|
||||
from api.models.GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse
|
||||
# Импортируйте ваши модели DAO, Asset, Generation корректно
|
||||
from adapters.s3_adapter import S3Adapter
|
||||
from api.models import (
|
||||
FinancialReport, UsageStats, UsageByEntity,
|
||||
GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse
|
||||
)
|
||||
from models.Asset import Asset, AssetType, AssetContentType
|
||||
from models.Generation import Generation, GenerationStatus
|
||||
from models.enums import AspectRatios, Quality, GenType
|
||||
from models.enums import AspectRatios, Quality
|
||||
from repos.dao import DAO
|
||||
from adapters.s3_adapter import S3Adapter
|
||||
from utils.image_utils import create_thumbnail
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Limit concurrent generations to 4
|
||||
generation_semaphore = asyncio.Semaphore(4)
|
||||
|
||||
|
||||
# --- Вспомогательная функция генерации ---
|
||||
async def generate_image_task(
|
||||
prompt: str,
|
||||
media_group_bytes: List[bytes],
|
||||
aspect_ratio: AspectRatios,
|
||||
quality: Quality,
|
||||
model: str,
|
||||
gemini: GoogleAdapter,
|
||||
|
||||
) -> Tuple[List[bytes], Dict[str, Any]]:
|
||||
"""
|
||||
Обертка для вызова синхронного метода Gemini в отдельном потоке.
|
||||
Возвращает список байтов сгенерированных изображений.
|
||||
Wrapper for calling Gemini's synchronous method in a separate thread.
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Starting generate_image_task with prompt length: {len(prompt)}")
|
||||
# Запускаем блокирующую операцию в отдельном потоке, чтобы не тормозить Event Loop
|
||||
result = await asyncio.to_thread(
|
||||
gemini.generate_image,
|
||||
prompt=prompt,
|
||||
images_list=media_group_bytes,
|
||||
aspect_ratio=aspect_ratio,
|
||||
quality=quality,
|
||||
model=model,
|
||||
)
|
||||
generated_images_io, metrics = result
|
||||
|
||||
logger.info(f"generate_image_task completed, received {len(generated_images_io) if generated_images_io else 0} images")
|
||||
except GoogleGenerationException as e:
|
||||
raise e
|
||||
except GoogleGenerationException:
|
||||
raise
|
||||
finally:
|
||||
del media_group_bytes
|
||||
|
||||
images_bytes = []
|
||||
if generated_images_io:
|
||||
for img_io in generated_images_io:
|
||||
# Читаем байты из BytesIO
|
||||
img_io.seek(0)
|
||||
content = img_io.read()
|
||||
images_bytes.append(content)
|
||||
|
||||
# Закрываем поток
|
||||
images_bytes.append(img_io.read())
|
||||
img_io.close()
|
||||
del generated_images_io
|
||||
|
||||
return images_bytes, metrics
|
||||
|
||||
|
||||
class GenerationService:
|
||||
def __init__(self, dao: DAO, gemini: GoogleAdapter, s3_adapter: S3Adapter, bot: Optional[Bot] = None):
|
||||
self.dao = dao
|
||||
@@ -68,167 +75,109 @@ class GenerationService:
|
||||
self.s3_adapter = s3_adapter
|
||||
self.bot = bot
|
||||
|
||||
# --- Public API ---
|
||||
|
||||
async def ask_prompt_assistant(self, prompt: str, assets: List[str] = None) -> str:
|
||||
future_prompt = """You are an prompt-assistant. You improving user-entered prompts for image generation. User may upload reference image too.
|
||||
I will provide sources prompt entered by user. Understand user needs and generate best variation of prompt.
|
||||
ANSWER ONLY PROMPT STRING!!! USER_ENTERED_PROMPT: """
|
||||
future_prompt += prompt
|
||||
async def ask_prompt_assistant(self, prompt: str, assets: list[str] | None = None, model: str = "gemini-3.1-pro-preview") -> str:
|
||||
future_prompt = (
|
||||
"You are an prompt-assistant. You improving user-entered prompts for image generation. "
|
||||
"User may upload reference image too. I will provide sources prompt entered by user. "
|
||||
"Understand user needs and generate best variation of prompt. ANSWER ONLY PROMPT STRING!!! "
|
||||
f"USER_ENTERED_PROMPT: {prompt}"
|
||||
)
|
||||
assets_data = []
|
||||
if assets is not None:
|
||||
if assets:
|
||||
assets_db = await self.dao.assets.get_assets_by_ids(assets)
|
||||
assets_data.extend(asset.data for asset in assets_db)
|
||||
generated_prompt = await asyncio.to_thread(self.gemini.generate_text, future_prompt, assets_data)
|
||||
logger.info(future_prompt)
|
||||
logger.info(generated_prompt)
|
||||
assets_data.extend(asset.data for asset in assets_db if asset.data)
|
||||
|
||||
generated_prompt = await asyncio.to_thread(self.gemini.generate_text, future_prompt, model, assets_data)
|
||||
logger.info(f"Prompt Assistant: {generated_prompt}")
|
||||
return generated_prompt
|
||||
|
||||
async def generate_prompt_from_images(self, images: List[bytes], user_prompt: Optional[str] = None) -> str:
|
||||
async def generate_prompt_from_images(self, images: List[bytes], user_prompt: Optional[str] = None, model: str = "gemini-3.1-pro-preview") -> str:
|
||||
technical_prompt = "You are a prompt engineer. Describe this image in detail to create a stable diffusion using this image as reference. "
|
||||
if user_prompt:
|
||||
technical_prompt += f"User also provided this context: {user_prompt}. "
|
||||
|
||||
technical_prompt += "Provide ONLY the detailed prompt."
|
||||
|
||||
return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, images_list=images)
|
||||
return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, model=model, images_list=images)
|
||||
|
||||
async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0) -> List[
|
||||
Generation]:
|
||||
generations = await self.dao.generations.get_generations(character_id = character_id,limit=limit, offset=offset)
|
||||
total_count = await self.dao.generations.count_generations(character_id = character_id)
|
||||
generations = [GenerationResponse(**gen.model_dump()) for gen in generations]
|
||||
return GenerationsResponse(generations=generations, total_count=total_count)
|
||||
async def get_generations(self, **kwargs) -> GenerationsResponse:
|
||||
current_user_id = kwargs.pop('current_user_id', None)
|
||||
generations = await self.dao.generations.get_generations(**kwargs)
|
||||
total_count = await self.dao.generations.count_generations(
|
||||
character_id=kwargs.get('character_id'),
|
||||
created_by=kwargs.get('created_by'),
|
||||
project_id=kwargs.get('project_id'),
|
||||
idea_id=kwargs.get('idea_id'),
|
||||
only_liked_by=kwargs.get('only_liked_by')
|
||||
)
|
||||
return GenerationsResponse(
|
||||
generations=[self._map_to_response(gen, current_user_id) for gen in generations],
|
||||
total_count=total_count
|
||||
)
|
||||
|
||||
async def get_generation(self, generation_id: str) -> Optional[GenerationResponse]:
|
||||
async def get_generation(self, generation_id: str, current_user_id: Optional[str] = None) -> Optional[GenerationResponse]:
|
||||
gen = await self.dao.generations.get_generation(generation_id)
|
||||
if gen is None:
|
||||
return None
|
||||
else:
|
||||
return GenerationResponse(**gen.model_dump())
|
||||
return self._map_to_response(gen, current_user_id) if gen else None
|
||||
|
||||
async def get_running_generations(self) -> List[Generation]:
|
||||
return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING)
|
||||
async def toggle_like(self, generation_id: str, user_id: str) -> bool | None:
|
||||
return await self.dao.generations.toggle_like(generation_id, user_id)
|
||||
|
||||
async def create_generation_task(self, generation_request: GenerationRequest) -> GenerationResponse:
|
||||
gen_id = None
|
||||
generation_model = None
|
||||
async def get_generations_by_group(self, group_id: str, current_user_id: Optional[str] = None) -> GenerationGroupResponse:
|
||||
generations = await self.dao.generations.get_generations_by_group(group_id)
|
||||
return GenerationGroupResponse(
|
||||
generation_group_id=group_id,
|
||||
generations=[self._map_to_response(gen, current_user_id) for gen in generations]
|
||||
)
|
||||
|
||||
try:
|
||||
generation_model = Generation(**generation_request.model_dump())
|
||||
gen_id = await self.dao.generations.create_generation(generation_model)
|
||||
generation_model.id = gen_id
|
||||
def _map_to_response(self, gen: Generation, current_user_id: Optional[str] = None) -> GenerationResponse:
|
||||
res = GenerationResponse(**gen.model_dump())
|
||||
res.likes_count = len(gen.liked_by) if gen.liked_by else 0
|
||||
res.is_liked = current_user_id in gen.liked_by if current_user_id and gen.liked_by else False
|
||||
return res
|
||||
|
||||
async def runner(gen):
|
||||
logger.info(f"Starting background generation task for ID: {gen.id}")
|
||||
try:
|
||||
await self.create_generation(gen)
|
||||
logger.info(f"Background generation task finished for ID: {gen.id}")
|
||||
except Exception:
|
||||
# если генерация уже пошла и упала — пометим FAILED
|
||||
try:
|
||||
db_gen = await self.dao.generations.get_generation(gen.id)
|
||||
db_gen.status = GenerationStatus.FAILED
|
||||
await self.dao.generations.update_generation(db_gen)
|
||||
except Exception:
|
||||
logger.exception("Failed to mark generation as FAILED")
|
||||
logger.exception("create_generation task failed")
|
||||
async def get_running_generations(self, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
|
||||
return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING, created_by=user_id, project_id=project_id)
|
||||
|
||||
asyncio.create_task(runner(generation_model))
|
||||
async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationGroupResponse:
|
||||
if generation_group_id is None:
|
||||
generation_group_id = str(uuid4())
|
||||
|
||||
return GenerationResponse(**generation_model.model_dump())
|
||||
|
||||
except Exception:
|
||||
# если не успели создать запись — нечего помечать
|
||||
if gen_id is not None:
|
||||
try:
|
||||
gen = await self.dao.generations.get_generation(gen_id)
|
||||
gen.status = GenerationStatus.FAILED
|
||||
await self.dao.generations.update_generation(gen)
|
||||
except Exception:
|
||||
logger.exception("Failed to mark generation as FAILED in create_generation_task")
|
||||
raise
|
||||
results = []
|
||||
for _ in range(generation_request.count):
|
||||
gen_response = await self._create_single_generation(generation_request, user_id, generation_group_id)
|
||||
results.append(gen_response)
|
||||
return GenerationGroupResponse(generation_group_id=generation_group_id, generations=results)
|
||||
|
||||
async def create_generation(self, generation: Generation):
|
||||
start_time = datetime.now()
|
||||
logger.info(f"Processing generation {generation.id}. Character ID: {generation.linked_character_id}")
|
||||
|
||||
# 2. Получаем ассеты-референсы (если они есть)
|
||||
reference_assets: List[Asset] = []
|
||||
media_group_bytes: List[bytes] = []
|
||||
generation_prompt = f"""
|
||||
# 1. Prepare input
|
||||
media_group_bytes, generation_prompt = await self._prepare_generation_input(generation)
|
||||
|
||||
Create detailed image of character in scene.
|
||||
|
||||
SCENE DESCRIPTION: {generation.prompt}
|
||||
|
||||
Rules:
|
||||
- Integrate the character's appearance naturally into the scene description.
|
||||
- Focus on lighting, texture, and composition.
|
||||
"""
|
||||
if generation.linked_character_id is not None:
|
||||
char_info = await self.dao.chars.get_character(generation.linked_character_id, with_image_data=True)
|
||||
if char_info is None:
|
||||
raise Exception(f"Character ID {generation.linked_character_id} not found")
|
||||
if generation.use_profile_image:
|
||||
media_group_bytes.append(char_info.character_image_data)
|
||||
# generation_prompt = generation_prompt.replace("$char_bio_inserted", f"1. CHARACTER BIO (Must be strictly followed): {char_info.character_bio}")
|
||||
|
||||
reference_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
|
||||
|
||||
# Извлекаем данные (bytes) из ассетов для отправки в Gemini
|
||||
for asset in reference_assets:
|
||||
if asset.content_type != AssetContentType.IMAGE:
|
||||
continue
|
||||
|
||||
img_data = None
|
||||
if asset.minio_object_name:
|
||||
img_data = await self.s3_adapter.get_file(asset.minio_object_name)
|
||||
elif asset.data:
|
||||
img_data = asset.data
|
||||
|
||||
if img_data:
|
||||
media_group_bytes.append(img_data)
|
||||
|
||||
if media_group_bytes:
|
||||
generation_prompt += " \n\n[Reference Image Guidance]: Use the provided image(s) as the STRICT reference for the main character's facial features and hair, enviroment or clothes. Maintain high fidelity to the reference identity."
|
||||
|
||||
logger.info(f"Final generation prompt assembled. Length: {len(generation_prompt)}. Media count: {len(media_group_bytes)}")
|
||||
|
||||
# 3. Запускаем процесс генерации и симуляцию прогресса
|
||||
# 2. Run generation with progress simulation
|
||||
progress_task = asyncio.create_task(self._simulate_progress(generation))
|
||||
|
||||
try:
|
||||
|
||||
# Default to Image Generation (Gemini)
|
||||
generated_bytes_list, metrics = await generate_image_task(
|
||||
prompt=generation_prompt, # или request.prompt
|
||||
prompt=generation_prompt,
|
||||
media_group_bytes=media_group_bytes,
|
||||
aspect_ratio=generation.aspect_ratio, # предполагаем поля в request
|
||||
aspect_ratio=generation.aspect_ratio,
|
||||
quality=generation.quality,
|
||||
model=generation.model or "gemini-3-pro-image-preview",
|
||||
gemini=self.gemini
|
||||
)
|
||||
self._update_generation_metrics(generation, metrics)
|
||||
|
||||
# 3. Process results
|
||||
created_assets = await self._process_generated_images(generation, generated_bytes_list)
|
||||
|
||||
# Update metrics from API (Common for both)
|
||||
generation.api_execution_time_seconds = metrics.get("api_execution_time_seconds")
|
||||
generation.token_usage = metrics.get("token_usage")
|
||||
generation.input_token_usage = metrics.get("input_token_usage")
|
||||
generation.output_token_usage = metrics.get("output_token_usage")
|
||||
# 4. Finalize generation record
|
||||
await self._finalize_generation(generation, created_assets, generation_prompt, start_time)
|
||||
|
||||
except GoogleGenerationException as e:
|
||||
generation.status = GenerationStatus.FAILED
|
||||
generation.failed_reason = str(e)
|
||||
generation.updated_at = datetime.now(UTC)
|
||||
await self.dao.generations.update_generation(generation)
|
||||
raise e
|
||||
except Exception as e:
|
||||
# Тут стоит добавить логирование ошибки
|
||||
logging.error(f"Generation failed: {e}")
|
||||
generation.status = GenerationStatus.FAILED
|
||||
generation.failed_reason = str(e)
|
||||
generation.updated_at = datetime.now(UTC)
|
||||
await self.dao.generations.update_generation(generation)
|
||||
raise e
|
||||
# 5. Notify
|
||||
if generation.telegram_id and self.bot:
|
||||
await self._notify_telegram(generation, created_assets)
|
||||
finally:
|
||||
if not progress_task.done():
|
||||
progress_task.cancel()
|
||||
@@ -237,103 +186,53 @@ class GenerationService:
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# 4. Сохраняем полученные изображения как новые Ассеты
|
||||
created_assets: List[Asset] = []
|
||||
async def import_external_generation(self, external_gen) -> Generation:
|
||||
external_gen.validate_image_source()
|
||||
logger.info(f"Importing external generation for user: {external_gen.created_by}")
|
||||
|
||||
for idx, img_bytes in enumerate(generated_bytes_list):
|
||||
# Generate thumbnail
|
||||
thumbnail_bytes = None
|
||||
from utils.image_utils import create_thumbnail
|
||||
thumbnail_bytes = await asyncio.to_thread(create_thumbnail, img_bytes)
|
||||
image_bytes = await self._fetch_external_image(external_gen)
|
||||
|
||||
# Save to S3
|
||||
filename = f"generated/{generation.linked_character_id}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.png"
|
||||
await self.s3_adapter.upload_file(filename, img_bytes, content_type="image/png")
|
||||
|
||||
new_asset = Asset(
|
||||
name=f"Generated_{generation.linked_character_id}",
|
||||
type=AssetType.GENERATED,
|
||||
content_type=AssetContentType.IMAGE,
|
||||
linked_char_id=generation.linked_character_id,
|
||||
data=None, # Not storing bytes in DB anymore
|
||||
minio_object_name=filename,
|
||||
minio_bucket=self.s3_adapter.bucket_name,
|
||||
thumbnail=thumbnail_bytes
|
||||
# Reuse internal processing logic
|
||||
new_asset = await self._save_asset(
|
||||
image_bytes=image_bytes,
|
||||
name=f"External_Generated_{external_gen.linked_character_id or 'no_char'}",
|
||||
created_by=external_gen.created_by,
|
||||
project_id=external_gen.project_id,
|
||||
linked_char_id=external_gen.linked_character_id,
|
||||
folder="external"
|
||||
)
|
||||
|
||||
# Сохраняем в БД
|
||||
asset_id = await self.dao.assets.create_asset(new_asset)
|
||||
new_asset.id = str(asset_id) # Присваиваем ID, полученный от базы
|
||||
|
||||
created_assets.append(new_asset)
|
||||
|
||||
# 5. (Опционально) Обновляем запись генерации ссылками на результаты
|
||||
# Предполагаем, что у модели Generation есть поле result_asset_ids
|
||||
result_ids = [a.id for a in created_assets]
|
||||
|
||||
generation.result_list = result_ids
|
||||
generation.status = GenerationStatus.DONE
|
||||
generation.progress = 100
|
||||
generation.updated_at = datetime.now(UTC)
|
||||
generation.tech_prompt = generation_prompt
|
||||
|
||||
end_time = datetime.now()
|
||||
generation.execution_time_seconds = (end_time - start_time).total_seconds()
|
||||
|
||||
logger.info(f"DEBUG: Saving generation {generation.id}. Metrics: api_exec={generation.api_execution_time_seconds}, tokens={generation.token_usage}, in_tokens={generation.input_token_usage}, out_tokens={generation.output_token_usage}, exec={generation.execution_time_seconds}")
|
||||
|
||||
await self.dao.generations.update_generation(generation)
|
||||
logger.info(f"Generation {generation.id} completed successfully. {len(created_assets)} assets created. Total Time: {generation.execution_time_seconds:.2f}s")
|
||||
|
||||
# 6. Send to Telegram if telegram_id is provided
|
||||
if generation.telegram_id and self.bot:
|
||||
try:
|
||||
for asset in created_assets:
|
||||
if asset.data:
|
||||
await self.bot.send_photo(
|
||||
chat_id=generation.telegram_id,
|
||||
photo=BufferedInputFile(asset.data, filename=f"{asset.name}.jpg"),
|
||||
caption=f"Generated from prompt: {generation.prompt[:100]}..."
|
||||
generation = Generation(
|
||||
status=GenerationStatus.DONE,
|
||||
linked_character_id=external_gen.linked_character_id,
|
||||
aspect_ratio=external_gen.aspect_ratio,
|
||||
quality=external_gen.quality,
|
||||
prompt=external_gen.prompt,
|
||||
model=external_gen.model,
|
||||
tech_prompt=external_gen.tech_prompt,
|
||||
seed=external_gen.seed,
|
||||
result_list=[new_asset.id],
|
||||
result=new_asset.id,
|
||||
progress=100,
|
||||
nsfw=external_gen.nsfw,
|
||||
execution_time_seconds=external_gen.execution_time_seconds,
|
||||
api_execution_time_seconds=external_gen.api_execution_time_seconds,
|
||||
token_usage=external_gen.token_usage,
|
||||
input_token_usage=external_gen.input_token_usage,
|
||||
output_token_usage=external_gen.output_token_usage,
|
||||
created_by=external_gen.created_by,
|
||||
project_id=external_gen.project_id
|
||||
)
|
||||
logger.info(f"Sent {len(created_assets)} assets to Telegram ID: {generation.telegram_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send assets to Telegram ID {generation.telegram_id}: {e}")
|
||||
|
||||
|
||||
async def _simulate_progress(self, generation: Generation):
|
||||
"""
|
||||
Increments progress from 0 to 90 over ~20 seconds.
|
||||
"""
|
||||
current_progress = 0
|
||||
try:
|
||||
while current_progress < 90:
|
||||
await asyncio.sleep(4)
|
||||
# Random increment between 5 and 15
|
||||
increment = random.randint(5, 15)
|
||||
current_progress = min(current_progress + increment, 90)
|
||||
|
||||
# Fetch latest state (optional, but good practice to avoid overwriting unrelated fields)
|
||||
# But for simplicity here we just use the object we have and save it.
|
||||
# Ideally, we should fetch-update-save or use partial update if DAO supports it.
|
||||
# Assuming simple update is fine for now.
|
||||
generation.progress = current_progress
|
||||
await self.dao.generations.update_generation(generation)
|
||||
except asyncio.CancelledError:
|
||||
# Task cancelled, generation finished (or failed)
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Error in progress simulation: {e}")
|
||||
|
||||
gen_id = await self.dao.generations.create_generation(generation)
|
||||
generation.id = gen_id
|
||||
return generation
|
||||
|
||||
async def delete_generation(self, generation_id: str) -> bool:
|
||||
"""
|
||||
Soft delete generation by marking it as deleted.
|
||||
"""
|
||||
try:
|
||||
generation = await self.dao.generations.get_generation(generation_id)
|
||||
if not generation:
|
||||
return False
|
||||
|
||||
generation.is_deleted = True
|
||||
generation.updated_at = datetime.now(UTC)
|
||||
await self.dao.generations.update_generation(generation)
|
||||
@@ -341,3 +240,207 @@ class GenerationService:
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting generation {generation_id}: {e}")
|
||||
return False
|
||||
|
||||
async def cleanup_stale_generations(self):
|
||||
try:
|
||||
count = await self.dao.generations.cancel_stale_generations(timeout_minutes=5)
|
||||
if count > 0:
|
||||
logger.info(f"Cleaned up {count} stale generations")
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up stale generations: {e}")
|
||||
|
||||
async def cleanup_old_data(self, days: int = 30):
|
||||
try:
|
||||
gen_count, asset_ids = await self.dao.generations.soft_delete_old_generations(days=days)
|
||||
if gen_count > 0:
|
||||
logger.info(f"Soft-deleted {gen_count} generations. Purging {len(asset_ids)} assets.")
|
||||
if asset_ids:
|
||||
await self.dao.assets.soft_delete_and_purge_assets(asset_ids)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during old data cleanup: {e}")
|
||||
|
||||
async def get_financial_report(self, user_id: Optional[str] = None, project_id: Optional[str] = None, breakdown_by: Optional[str] = None) -> FinancialReport:
|
||||
summary_data = await self.dao.generations.get_usage_stats(created_by=user_id, project_id=project_id)
|
||||
summary = UsageStats(**summary_data)
|
||||
|
||||
by_user, by_project = None, None
|
||||
if breakdown_by == "created_by":
|
||||
res = await self.dao.generations.get_usage_breakdown(group_by="created_by", project_id=project_id, created_by=user_id)
|
||||
by_user = [UsageByEntity(**item) for item in res]
|
||||
if breakdown_by == "project_id":
|
||||
res = await self.dao.generations.get_usage_breakdown(group_by="project_id", project_id=project_id, created_by=user_id)
|
||||
by_project = [UsageByEntity(**item) for item in res]
|
||||
|
||||
return FinancialReport(summary=summary, by_user=by_user, by_project=by_project)
|
||||
|
||||
# --- Private Helpers ---
|
||||
|
||||
async def _create_single_generation(self, generation_request: GenerationRequest, user_id: Optional[str], generation_group_id: str) -> GenerationResponse:
|
||||
try:
|
||||
gen_model = Generation(**generation_request.model_dump(exclude={'count'}))
|
||||
gen_model.created_by = user_id
|
||||
gen_model.generation_group_id = generation_group_id
|
||||
|
||||
gen_id = await self.dao.generations.create_generation(gen_model)
|
||||
gen_model.id = gen_id
|
||||
|
||||
asyncio.create_task(self._queued_generation_runner(gen_model))
|
||||
return GenerationResponse(**gen_model.model_dump())
|
||||
except Exception:
|
||||
logger.exception("Failed to initiate single generation")
|
||||
raise
|
||||
|
||||
async def _queued_generation_runner(self, gen: Generation):
|
||||
logger.info(f"Generation {gen.id} waiting for slot...")
|
||||
try:
|
||||
async with generation_semaphore:
|
||||
await self.create_generation(gen)
|
||||
except Exception as e:
|
||||
await self._handle_generation_failure(gen, e)
|
||||
logger.exception(f"Background generation task failed for ID: {gen.id}")
|
||||
|
||||
async def _prepare_generation_input(self, generation: Generation) -> Tuple[List[bytes], str]:
|
||||
media_group_bytes: List[bytes] = []
|
||||
prompt = generation.prompt
|
||||
|
||||
# 1. Character Avatar
|
||||
if generation.linked_character_id:
|
||||
char_info = await self.dao.chars.get_character(generation.linked_character_id)
|
||||
if not char_info:
|
||||
raise ValueError(f"Character {generation.linked_character_id} not found")
|
||||
|
||||
if generation.use_profile_image and char_info.avatar_asset_id:
|
||||
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
|
||||
if avatar_asset:
|
||||
data = await self._get_asset_data_bytes(avatar_asset)
|
||||
if data: media_group_bytes.append(data)
|
||||
|
||||
# 2. Reference Assets
|
||||
if generation.assets_list:
|
||||
assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
|
||||
for asset in assets:
|
||||
data = await self._get_asset_data_bytes(asset)
|
||||
if data: media_group_bytes.append(data)
|
||||
|
||||
# 3. Environment Assets
|
||||
if generation.environment_id:
|
||||
env = await self.dao.environments.get_env(generation.environment_id)
|
||||
if env and env.asset_ids:
|
||||
env_assets = await self.dao.assets.get_assets_by_ids(env.asset_ids)
|
||||
for asset in env_assets:
|
||||
data = await self._get_asset_data_bytes(asset)
|
||||
if data: media_group_bytes.append(data)
|
||||
|
||||
if media_group_bytes:
|
||||
prompt += (
|
||||
" \n\n[Reference Image Guidance]: Use the provided image(s) as the STRICT reference for the main "
|
||||
"character's facial features and hair, environment or clothes. Maintain high fidelity to the reference identity."
|
||||
)
|
||||
|
||||
return media_group_bytes, prompt
|
||||
|
||||
async def _get_asset_data_bytes(self, asset: Asset) -> Optional[bytes]:
|
||||
if asset.content_type != AssetContentType.IMAGE:
|
||||
return None
|
||||
if asset.minio_object_name:
|
||||
return await self.s3_adapter.get_file(asset.minio_object_name)
|
||||
return asset.data
|
||||
|
||||
def _update_generation_metrics(self, generation: Generation, metrics: Dict[str, Any]):
|
||||
generation.api_execution_time_seconds = metrics.get("api_execution_time_seconds")
|
||||
generation.token_usage = metrics.get("token_usage")
|
||||
generation.input_token_usage = metrics.get("input_token_usage")
|
||||
generation.output_token_usage = metrics.get("output_token_usage")
|
||||
|
||||
async def _handle_generation_failure(self, generation: Generation, error: Optional[Exception]):
|
||||
logger.error(f"Generation {generation.id} failed: {error}")
|
||||
generation.status = GenerationStatus.FAILED
|
||||
# Don't overwrite if reason is already set, unless a new error is provided
|
||||
if error:
|
||||
generation.failed_reason = str(error)
|
||||
elif not generation.failed_reason:
|
||||
generation.failed_reason = "Unknown error"
|
||||
|
||||
generation.updated_at = datetime.now(UTC)
|
||||
await self.dao.generations.update_generation(generation)
|
||||
|
||||
async def _process_generated_images(self, generation: Generation, bytes_list: List[bytes]) -> List[Asset]:
|
||||
created_assets = []
|
||||
for img_bytes in bytes_list:
|
||||
asset = await self._save_asset(
|
||||
image_bytes=img_bytes,
|
||||
name=f"Generated_{generation.linked_character_id}",
|
||||
created_by=generation.created_by,
|
||||
project_id=generation.project_id,
|
||||
linked_char_id=generation.linked_character_id,
|
||||
folder="generated"
|
||||
)
|
||||
created_assets.append(asset)
|
||||
return created_assets
|
||||
|
||||
async def _save_asset(self, image_bytes: bytes, name: str, created_by: str, project_id: str, linked_char_id: str, folder: str) -> Asset:
|
||||
thumbnail_bytes = await asyncio.to_thread(create_thumbnail, image_bytes)
|
||||
filename = f"{folder}/{linked_char_id}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.png"
|
||||
|
||||
await self.s3_adapter.upload_file(filename, image_bytes, content_type="image/png")
|
||||
|
||||
new_asset = Asset(
|
||||
name=name,
|
||||
type=AssetType.GENERATED,
|
||||
content_type=AssetContentType.IMAGE,
|
||||
linked_char_id=linked_char_id,
|
||||
data=None,
|
||||
minio_object_name=filename,
|
||||
minio_bucket=self.s3_adapter.bucket_name,
|
||||
thumbnail=thumbnail_bytes,
|
||||
created_by=created_by,
|
||||
project_id=project_id
|
||||
)
|
||||
asset_id = await self.dao.assets.create_asset(new_asset)
|
||||
new_asset.id = str(asset_id)
|
||||
return new_asset
|
||||
|
||||
async def _finalize_generation(self, generation: Generation, assets: List[Asset], tech_prompt: str, start_time: datetime):
|
||||
generation.result_list = [a.id for a in assets]
|
||||
generation.status = GenerationStatus.DONE
|
||||
generation.progress = 100
|
||||
generation.updated_at = datetime.now(UTC)
|
||||
generation.tech_prompt = tech_prompt
|
||||
generation.execution_time_seconds = (datetime.now() - start_time).total_seconds()
|
||||
await self.dao.generations.update_generation(generation)
|
||||
logger.info(f"Generation {generation.id} finalized. Time: {generation.execution_time_seconds:.2f}s")
|
||||
|
||||
async def _notify_telegram(self, generation: Generation, assets: List[Asset]):
|
||||
try:
|
||||
for asset in assets:
|
||||
# Need to get data for telegram if it's not in Asset object
|
||||
img_data = await self.s3_adapter.get_file(asset.minio_object_name) if asset.minio_object_name else asset.data
|
||||
if img_data:
|
||||
await self.bot.send_photo(
|
||||
chat_id=generation.telegram_id,
|
||||
photo=BufferedInputFile(img_data, filename=f"{asset.name}.png"),
|
||||
caption=f"Generated from: {generation.prompt[:100]}..."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send to Telegram: {e}")
|
||||
|
||||
async def _simulate_progress(self, generation: Generation):
|
||||
current_progress = 0
|
||||
try:
|
||||
while current_progress < 90:
|
||||
await asyncio.sleep(4)
|
||||
current_progress = min(current_progress + random.randint(5, 15), 90)
|
||||
generation.progress = current_progress
|
||||
await self.dao.generations.update_generation(generation)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def _fetch_external_image(self, external_gen) -> bytes:
|
||||
if external_gen.image_url:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(external_gen.image_url, timeout=30.0)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
elif external_gen.image_data:
|
||||
return base64.b64decode(external_gen.image_data)
|
||||
raise ValueError("No image source provided")
|
||||
|
||||
82
api/service/idea_service.py
Normal file
82
api/service/idea_service.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from repos.dao import DAO
|
||||
from models.Idea import Idea
|
||||
|
||||
class IdeaService:
|
||||
def __init__(self, dao: DAO):
|
||||
self.dao = dao
|
||||
|
||||
async def create_idea(self, name: str, description: Optional[str], project_id: Optional[str], user_id: str, inspiration_id: Optional[str] = None) -> Idea:
|
||||
idea = Idea(
|
||||
name=name,
|
||||
description=description,
|
||||
project_id=project_id,
|
||||
created_by=user_id,
|
||||
inspiration_id=inspiration_id
|
||||
)
|
||||
idea_id = await self.dao.ideas.create_idea(idea)
|
||||
idea.id = idea_id
|
||||
return idea
|
||||
|
||||
async def get_ideas(self, project_id: Optional[str], user_id: str, limit: int = 20, offset: int = 0) -> List[dict]:
|
||||
return await self.dao.ideas.get_ideas(project_id, user_id, limit, offset)
|
||||
|
||||
async def get_idea(self, idea_id: str) -> Optional[Idea]:
|
||||
return await self.dao.ideas.get_idea(idea_id)
|
||||
|
||||
async def update_idea(self, idea_id: str, name: Optional[str] = None, description: Optional[str] = None, inspiration_id: Optional[str] = None) -> Optional[Idea]:
|
||||
idea = await self.dao.ideas.get_idea(idea_id)
|
||||
if not idea:
|
||||
return None
|
||||
|
||||
if name is not None:
|
||||
idea.name = name
|
||||
if description is not None:
|
||||
idea.description = description
|
||||
if inspiration_id is not None:
|
||||
idea.inspiration_id = inspiration_id
|
||||
|
||||
idea.updated_at = datetime.now()
|
||||
await self.dao.ideas.update_idea(idea)
|
||||
return idea
|
||||
|
||||
async def delete_idea(self, idea_id: str) -> bool:
|
||||
return await self.dao.ideas.delete_idea(idea_id)
|
||||
|
||||
async def add_generation_to_idea(self, idea_id: str, generation_id: str) -> bool:
|
||||
# Verify idea exists
|
||||
idea = await self.dao.ideas.get_idea(idea_id)
|
||||
if not idea:
|
||||
return False
|
||||
|
||||
# Get generation
|
||||
gen = await self.dao.generations.get_generation(generation_id)
|
||||
if not gen:
|
||||
return False
|
||||
|
||||
# Link
|
||||
gen.idea_id = idea_id
|
||||
gen.updated_at = datetime.now()
|
||||
await self.dao.generations.update_generation(gen)
|
||||
return True
|
||||
|
||||
async def remove_generation_from_idea(self, idea_id: str, generation_id: str) -> bool:
|
||||
# Verify idea exists (optional, but good for validation)
|
||||
idea = await self.dao.ideas.get_idea(idea_id)
|
||||
if not idea:
|
||||
return False
|
||||
|
||||
# Get generation
|
||||
gen = await self.dao.generations.get_generation(generation_id)
|
||||
if not gen:
|
||||
return False
|
||||
|
||||
# Unlink only if currently linked to this idea
|
||||
if gen.idea_id == idea_id:
|
||||
gen.idea_id = None
|
||||
gen.updated_at = datetime.now()
|
||||
await self.dao.generations.update_generation(gen)
|
||||
return True
|
||||
|
||||
return False
|
||||
146
api/service/inspiration_service.py
Normal file
146
api/service/inspiration_service.py
Normal file
@@ -0,0 +1,146 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException
|
||||
|
||||
from models.Asset import Asset, AssetType, AssetContentType
|
||||
from models.Inspiration import Inspiration
|
||||
from repos.dao import DAO
|
||||
from adapters.s3_adapter import S3Adapter
|
||||
|
||||
# Try to import yt_dlp, but don't crash if it's missing (though we added it to requirements)
|
||||
try:
|
||||
import yt_dlp
|
||||
except ImportError:
|
||||
yt_dlp = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class InspirationService:
|
||||
def __init__(self, dao: DAO, s3_adapter: S3Adapter):
|
||||
self.dao = dao
|
||||
self.s3_adapter = s3_adapter
|
||||
|
||||
async def create_inspiration(self, source_url: str, created_by: str, project_id: Optional[str] = None, caption: Optional[str] = None) -> Inspiration:
|
||||
# 1. Download content from Instagram
|
||||
try:
|
||||
content_bytes, content_type, ext = await self._download_content(source_url)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download content from {source_url}: {e}")
|
||||
raise HTTPException(status_code=400, detail=f"Failed to download content: {str(e)}")
|
||||
|
||||
# 2. Save as Asset
|
||||
filename = f"inspirations/{datetime.now().strftime('%Y%m%d_%H%M%S')}_insta.{ext}"
|
||||
|
||||
await self.s3_adapter.upload_file(filename, content_bytes, content_type=content_type)
|
||||
|
||||
asset = Asset(
|
||||
name=f"Inspiration from {source_url}",
|
||||
type=AssetType.INSPIRATION,
|
||||
content_type=AssetContentType.VIDEO if content_type.startswith("video") else AssetContentType.IMAGE,
|
||||
minio_object_name=filename,
|
||||
minio_bucket=self.s3_adapter.bucket_name,
|
||||
created_by=created_by,
|
||||
project_id=project_id
|
||||
)
|
||||
asset_id = await self.dao.assets.create_asset(asset)
|
||||
|
||||
# 3. Create Inspiration object
|
||||
inspiration = Inspiration(
|
||||
source_url=source_url,
|
||||
caption=caption,
|
||||
asset_id=str(asset_id),
|
||||
created_by=created_by,
|
||||
project_id=project_id
|
||||
)
|
||||
insp_id = await self.dao.inspirations.create_inspiration(inspiration)
|
||||
inspiration.id = insp_id
|
||||
|
||||
return inspiration
|
||||
|
||||
async def get_inspirations(self, project_id: Optional[str], created_by: str, limit: int = 20, offset: int = 0) -> List[Inspiration]:
|
||||
return await self.dao.inspirations.get_inspirations(project_id, created_by, limit, offset)
|
||||
|
||||
async def get_inspiration(self, inspiration_id: str) -> Optional[Inspiration]:
|
||||
return await self.dao.inspirations.get_inspiration(inspiration_id)
|
||||
|
||||
async def mark_as_completed(self, inspiration_id: str, is_completed: bool = True) -> Optional[Inspiration]:
|
||||
inspiration = await self.dao.inspirations.get_inspiration(inspiration_id)
|
||||
if not inspiration:
|
||||
return None
|
||||
|
||||
inspiration.is_completed = is_completed
|
||||
inspiration.updated_at = datetime.now()
|
||||
await self.dao.inspirations.update_inspiration(inspiration)
|
||||
return inspiration
|
||||
|
||||
async def delete_inspiration(self, inspiration_id: str) -> bool:
|
||||
inspiration = await self.dao.inspirations.get_inspiration(inspiration_id)
|
||||
if not inspiration:
|
||||
return False
|
||||
|
||||
# Delete associated asset
|
||||
if inspiration.asset_id:
|
||||
await self.dao.assets.delete_asset(inspiration.asset_id)
|
||||
|
||||
return await self.dao.inspirations.delete_inspiration(inspiration_id)
|
||||
|
||||
async def _download_content(self, url: str) -> Tuple[bytes, str, str]:
|
||||
"""
|
||||
Downloads content using yt-dlp.
|
||||
Returns (content_bytes, content_type, extension)
|
||||
"""
|
||||
if not yt_dlp:
|
||||
raise RuntimeError("yt-dlp is not installed")
|
||||
|
||||
logger.info(f"Downloading from {url} using yt-dlp...")
|
||||
|
||||
def run_yt_dlp():
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
ydl_opts = {
|
||||
'outtmpl': f'{tmpdirname}/%(id)s.%(ext)s',
|
||||
'quiet': True,
|
||||
'no_warnings': True,
|
||||
'format': 'best', # Best quality single file
|
||||
'noplaylist': True, # Only single video if it's a playlist/profile
|
||||
'writethumbnail': False,
|
||||
'writesubtitles': False,
|
||||
}
|
||||
|
||||
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
|
||||
ydl.download([url])
|
||||
|
||||
# Find the downloaded file
|
||||
files = os.listdir(tmpdirname)
|
||||
if not files:
|
||||
raise Exception("No files downloaded")
|
||||
|
||||
# Pick the largest file if multiple (e.g. if yt-dlp downloaded parts)
|
||||
# But with 'format': 'best', it should be one.
|
||||
# If carousel, it might be multiple. Let's pick the first one.
|
||||
filename = files[0]
|
||||
filepath = os.path.join(tmpdirname, filename)
|
||||
|
||||
with open(filepath, 'rb') as f:
|
||||
data = f.read()
|
||||
|
||||
ext = filename.split('.')[-1].lower()
|
||||
|
||||
# Determine content type
|
||||
if ext in ['mp4', 'mov', 'avi', 'mkv', 'webm']:
|
||||
content_type = f"video/{ext}"
|
||||
if ext == 'mov': content_type = "video/quicktime"
|
||||
elif ext in ['jpg', 'jpeg', 'png', 'webp']:
|
||||
content_type = f"image/{ext}"
|
||||
if ext == 'jpg': content_type = "image/jpeg"
|
||||
else:
|
||||
content_type = "application/octet-stream"
|
||||
|
||||
return data, content_type, ext
|
||||
|
||||
return await asyncio.to_thread(run_yt_dlp)
|
||||
79
api/service/post_service.py
Normal file
79
api/service/post_service.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from typing import List, Optional
|
||||
from datetime import datetime, UTC
|
||||
|
||||
from repos.dao import DAO
|
||||
from models.Post import Post
|
||||
|
||||
|
||||
class PostService:
|
||||
def __init__(self, dao: DAO):
|
||||
self.dao = dao
|
||||
|
||||
async def create_post(
|
||||
self,
|
||||
date: datetime,
|
||||
topic: str,
|
||||
generation_ids: List[str],
|
||||
project_id: Optional[str],
|
||||
user_id: str,
|
||||
) -> Post:
|
||||
post = Post(
|
||||
date=date,
|
||||
topic=topic,
|
||||
generation_ids=generation_ids,
|
||||
project_id=project_id,
|
||||
created_by=user_id,
|
||||
)
|
||||
post_id = await self.dao.posts.create_post(post)
|
||||
post.id = post_id
|
||||
return post
|
||||
|
||||
async def get_post(self, post_id: str) -> Optional[Post]:
|
||||
return await self.dao.posts.get_post(post_id)
|
||||
|
||||
async def get_posts(
|
||||
self,
|
||||
project_id: Optional[str],
|
||||
user_id: str,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
date_from: Optional[datetime] = None,
|
||||
date_to: Optional[datetime] = None,
|
||||
) -> List[Post]:
|
||||
return await self.dao.posts.get_posts(project_id, user_id, limit, offset, date_from, date_to)
|
||||
|
||||
async def update_post(
|
||||
self,
|
||||
post_id: str,
|
||||
date: Optional[datetime] = None,
|
||||
topic: Optional[str] = None,
|
||||
) -> Optional[Post]:
|
||||
post = await self.dao.posts.get_post(post_id)
|
||||
if not post:
|
||||
return None
|
||||
|
||||
updates: dict = {"updated_at": datetime.now(UTC)}
|
||||
if date is not None:
|
||||
updates["date"] = date
|
||||
if topic is not None:
|
||||
updates["topic"] = topic
|
||||
|
||||
await self.dao.posts.update_post(post_id, updates)
|
||||
|
||||
# Return refreshed post
|
||||
return await self.dao.posts.get_post(post_id)
|
||||
|
||||
async def delete_post(self, post_id: str) -> bool:
|
||||
return await self.dao.posts.delete_post(post_id)
|
||||
|
||||
async def add_generations(self, post_id: str, generation_ids: List[str]) -> bool:
|
||||
post = await self.dao.posts.get_post(post_id)
|
||||
if not post:
|
||||
return False
|
||||
return await self.dao.posts.add_generations(post_id, generation_ids)
|
||||
|
||||
async def remove_generation(self, post_id: str, generation_id: str) -> bool:
|
||||
post = await self.dao.posts.get_post(post_id)
|
||||
if not post:
|
||||
return False
|
||||
return await self.dao.posts.remove_generation(post_id, generation_id)
|
||||
39
config.py
Normal file
39
config.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
# Telegram Bot
|
||||
BOT_TOKEN: str
|
||||
ADMIN_ID: int = 0
|
||||
|
||||
# AI Service
|
||||
GEMINI_API_KEY: str
|
||||
|
||||
# Database
|
||||
MONGO_HOST: str = "mongodb://localhost:27017"
|
||||
DB_NAME: str = "my_bot_db"
|
||||
|
||||
# S3 Storage (Minio)
|
||||
MINIO_ENDPOINT: str = "http://localhost:9000"
|
||||
MINIO_ACCESS_KEY: str = "minioadmin"
|
||||
MINIO_SECRET_KEY: str = "minioadmin"
|
||||
MINIO_BUCKET: str = "ai-char"
|
||||
|
||||
# External API
|
||||
EXTERNAL_API_SECRET: Optional[str] = None
|
||||
|
||||
# JWT Security
|
||||
SECRET_KEY: str = "CHANGE_ME_TO_A_SUPER_SECRET_KEY"
|
||||
ALGORITHM: str = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 * 24 * 60 # 30 days
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=os.getenv("ENV_FILE", ".env"),
|
||||
env_file_encoding="utf-8",
|
||||
extra="ignore"
|
||||
)
|
||||
|
||||
|
||||
settings = Settings()
|
||||
@@ -27,19 +27,19 @@ class AlbumMiddleware(BaseMiddleware):
|
||||
# Ждем сбора остальных частей
|
||||
await asyncio.sleep(self.latency)
|
||||
|
||||
# Проверяем, что ключ все еще существует (на всякий случай)
|
||||
# Проверяем, что ключ все еще существует
|
||||
if group_id in self.album_data:
|
||||
# Передаем собранный альбом в хендлер
|
||||
# Сортируем по message_id, чтобы порядок был верным
|
||||
self.album_data[group_id].sort(key=lambda x: x.message_id)
|
||||
data["album"] = self.album_data[group_id]
|
||||
current_album = self.album_data[group_id]
|
||||
current_album.sort(key=lambda x: x.message_id)
|
||||
data["album"] = current_album
|
||||
return await handler(event, data)
|
||||
|
||||
finally:
|
||||
# ЧИСТКА: Удаляем всегда, если это "головной" поток, который создал запись
|
||||
# Проверяем, что мы удаляем именно то, что создали, и ключ существует
|
||||
if group_id in self.album_data and self.album_data[group_id][0] == event:
|
||||
del self.album_data[group_id]
|
||||
# ЧИСТКА: Удаляем запись после обработки или таймаута
|
||||
# Используем pop() с дефолтом, чтобы избежать KeyError
|
||||
self.album_data.pop(group_id, None)
|
||||
|
||||
else:
|
||||
# Если группа уже собирается - просто добавляем и выходим
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
from datetime import datetime, UTC
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class Album(BaseModel):
|
||||
id: Optional[str] = None
|
||||
id: str | None = None
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
cover_asset_id: Optional[str] = None
|
||||
generation_ids: List[str] = []
|
||||
description: str | None = None
|
||||
cover_asset_id: str | None = None
|
||||
generation_ids: list[str] = []
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from datetime import datetime, UTC
|
||||
from enum import Enum
|
||||
from typing import Optional, Any, List
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, computed_field, Field, model_validator
|
||||
|
||||
@@ -8,26 +8,31 @@ from pydantic import BaseModel, computed_field, Field, model_validator
|
||||
class AssetContentType(str, Enum):
|
||||
IMAGE = 'image'
|
||||
PROMPT = 'prompt'
|
||||
VIDEO = 'video'
|
||||
|
||||
class AssetType(str, Enum):
|
||||
UPLOADED = 'uploaded'
|
||||
GENERATED = 'generated'
|
||||
INSPIRATION = 'inspiration'
|
||||
|
||||
|
||||
class Asset(BaseModel):
|
||||
id: Optional[str] = None
|
||||
id: str | None = None
|
||||
name: str
|
||||
type: AssetType = AssetType.GENERATED
|
||||
content_type: AssetContentType = AssetContentType.IMAGE
|
||||
linked_char_id: Optional[str] = None
|
||||
data: Optional[bytes] = None
|
||||
tg_doc_file_id: Optional[str] = None
|
||||
tg_photo_file_id: Optional[str] = None
|
||||
minio_object_name: Optional[str] = None
|
||||
minio_bucket: Optional[str] = None
|
||||
minio_thumbnail_object_name: Optional[str] = None
|
||||
thumbnail: Optional[bytes] = None
|
||||
tags: List[str] = []
|
||||
linked_char_id: str | None = None
|
||||
data: bytes | None = None
|
||||
tg_doc_file_id: str | None = None
|
||||
tg_photo_file_id: str | None = None
|
||||
minio_object_name: str | None = None
|
||||
minio_bucket: str | None = None
|
||||
minio_thumbnail_object_name: str | None = None
|
||||
thumbnail: bytes | None = None
|
||||
tags: list[str] = []
|
||||
created_by: str | None = None
|
||||
project_id: str | None = None
|
||||
is_deleted: bool = False
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
|
||||
@@ -60,6 +65,7 @@ class Asset(BaseModel):
|
||||
|
||||
# --- CALCULATED FIELD ---
|
||||
@computed_field
|
||||
@property
|
||||
def url(self) -> str:
|
||||
"""
|
||||
Это поле автоматически вычислится и попадет в model_dump() / .json()
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic_core.core_schema import computed_field
|
||||
|
||||
|
||||
class Character(BaseModel):
|
||||
id: str | None
|
||||
id: str | None = None
|
||||
name: str
|
||||
avatar_image: Optional[str] = None
|
||||
character_image_data: Optional[bytes] = None
|
||||
character_image_doc_tg_id: str
|
||||
character_image_tg_id: str | None
|
||||
character_bio: str
|
||||
|
||||
avatar_asset_id: str | None = None
|
||||
avatar_image: str | None = None
|
||||
character_image_doc_tg_id: str | None = None
|
||||
character_image_tg_id: str | None = None
|
||||
character_bio: str | None = None
|
||||
created_by: str | None = None
|
||||
project_id: str | None = None
|
||||
|
||||
19
models/Environment.py
Normal file
19
models/Environment.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from datetime import datetime
|
||||
from bson import ObjectId
|
||||
|
||||
|
||||
class Environment(BaseModel):
|
||||
id: str | None = Field(None, alias="_id")
|
||||
character_id: str
|
||||
name: str = Field(..., min_length=1)
|
||||
description: str | None = None
|
||||
asset_ids: list[str] = Field(default_factory=list)
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
json_encoders={ObjectId: str},
|
||||
arbitrary_types_allowed=True
|
||||
)
|
||||
@@ -1,11 +1,9 @@
|
||||
from datetime import datetime, UTC
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, computed_field
|
||||
|
||||
from models.Asset import Asset
|
||||
from models.enums import AspectRatios, Quality, GenType
|
||||
from models.enums import AspectRatios, Quality
|
||||
|
||||
|
||||
class GenerationStatus(str, Enum):
|
||||
@@ -14,25 +12,43 @@ class GenerationStatus(str, Enum):
|
||||
FAILED = "failed"
|
||||
|
||||
class Generation(BaseModel):
|
||||
id: Optional[str] = None
|
||||
id: str | None = None
|
||||
status: GenerationStatus = GenerationStatus.RUNNING
|
||||
failed_reason: Optional[str] = None
|
||||
linked_character_id: Optional[str] = None
|
||||
telegram_id: Optional[int] = None
|
||||
failed_reason: str | None = None
|
||||
linked_character_id: str | None = None
|
||||
telegram_id: int | None = None
|
||||
use_profile_image: bool = True
|
||||
aspect_ratio: AspectRatios
|
||||
quality: Quality
|
||||
prompt: str
|
||||
tech_prompt: Optional[str] = None
|
||||
assets_list: List[str] = Field(default_factory=list)
|
||||
result_list: List[str] = Field(default_factory=list)
|
||||
result: Optional[str] = None
|
||||
model: str | None = None
|
||||
seed: int | None = None
|
||||
tech_prompt: str | None = None
|
||||
assets_list: list[str] = Field(default_factory=list)
|
||||
result_list: list[str] = Field(default_factory=list)
|
||||
result: str | None = None
|
||||
progress: int = 0
|
||||
execution_time_seconds: Optional[float] = None
|
||||
api_execution_time_seconds: Optional[float] = None
|
||||
token_usage: Optional[int] = None
|
||||
input_token_usage: Optional[int] = None
|
||||
output_token_usage: Optional[int] = None
|
||||
execution_time_seconds: float | None = None
|
||||
api_execution_time_seconds: float | None = None
|
||||
token_usage: int | None = None
|
||||
input_token_usage: int | None = None
|
||||
output_token_usage: int | None = None
|
||||
is_deleted: bool = False
|
||||
album_id: str | None = None
|
||||
environment_id: str | None = None
|
||||
generation_group_id: str | None = None
|
||||
created_by: str | None = None # Stores User ID (Telegram ID or Web User ObjectId)
|
||||
project_id: str | None = None
|
||||
idea_id: str | None = None
|
||||
liked_by: list[str] = Field(default_factory=list)
|
||||
nsfw: bool = False
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
|
||||
@computed_field
|
||||
def cost(self) -> float:
|
||||
if self.status == GenerationStatus.DONE and self.input_token_usage and self.output_token_usage:
|
||||
cost_input = self.input_token_usage * 0.000002
|
||||
cost_output = self.output_token_usage * 0.00012
|
||||
return round(cost_input + cost_output, 3)
|
||||
return 0.0
|
||||
|
||||
13
models/Idea.py
Normal file
13
models/Idea.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class Idea(BaseModel):
|
||||
id: str | None = None
|
||||
name: str = "New Idea"
|
||||
description: str | None = None
|
||||
project_id: str | None = None
|
||||
inspiration_id: str | None = None # Link to Inspiration
|
||||
created_by: str # User ID
|
||||
is_deleted: bool = False
|
||||
created_at: datetime = Field(default_factory=datetime.now)
|
||||
updated_at: datetime = Field(default_factory=datetime.now)
|
||||
15
models/Inspiration.py
Normal file
15
models/Inspiration.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from datetime import datetime, UTC
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Inspiration(BaseModel):
|
||||
id: str | None = None
|
||||
source_url: str
|
||||
caption: str | None = None
|
||||
asset_id: str
|
||||
is_completed: bool = False
|
||||
created_by: str
|
||||
project_id: str | None = None
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
22
models/Post.py
Normal file
22
models/Post.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from datetime import datetime, timezone, UTC
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class Post(BaseModel):
|
||||
id: str | None = None
|
||||
date: datetime
|
||||
topic: str
|
||||
generation_ids: list[str] = Field(default_factory=list)
|
||||
project_id: str | None = None
|
||||
created_by: str
|
||||
is_deleted: bool = False
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
|
||||
@model_validator(mode="after")
|
||||
def ensure_tz_aware(self):
|
||||
for field in ("date", "created_at", "updated_at"):
|
||||
val = getattr(self, field)
|
||||
if val is not None and val.tzinfo is None:
|
||||
setattr(self, field, val.replace(tzinfo=timezone.utc))
|
||||
return self
|
||||
11
models/Project.py
Normal file
11
models/Project.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class Project(BaseModel):
|
||||
id: str | None = None
|
||||
name: str
|
||||
description: str | None = None
|
||||
owner_id: str
|
||||
members: list[str] = [] # List of User IDs
|
||||
is_deleted: bool = False
|
||||
created_at: datetime = Field(default_factory=datetime.now)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -2,19 +2,30 @@ from enum import Enum
|
||||
|
||||
|
||||
class AspectRatios(str, Enum):
|
||||
NINESIXTEEN = "NINESIXTEEN"
|
||||
SIXTEENNINE = "SIXTEENNINE"
|
||||
THREEFOUR = "THREEFOUR"
|
||||
FOURTHREE = "FOURTHREE"
|
||||
ONEONE = "1:1"
|
||||
TWOTHREE = "2:3"
|
||||
THREETWO = "3:2"
|
||||
THREEFOUR = "3:4"
|
||||
FOURTHREE = "4:3"
|
||||
FOURFIVE = "4:5"
|
||||
FIVEFOUR = "5:4"
|
||||
NINESIXTEEN = "9:16"
|
||||
SIXTEENNINE = "16:9"
|
||||
TWENTYONENINE = "21:9"
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value):
|
||||
mapping = {
|
||||
"NINESIXTEEN": cls.NINESIXTEEN,
|
||||
"SIXTEENNINE": cls.SIXTEENNINE,
|
||||
"THREEFOUR": cls.THREEFOUR,
|
||||
"FOURTHREE": cls.FOURTHREE,
|
||||
}
|
||||
return mapping.get(value)
|
||||
|
||||
@property
|
||||
def value_ratio(self) -> str:
|
||||
return {
|
||||
AspectRatios.NINESIXTEEN: "9:16",
|
||||
AspectRatios.SIXTEENNINE: "16:9",
|
||||
AspectRatios.THREEFOUR: "3:4",
|
||||
AspectRatios.FOURTHREE: "4:3",
|
||||
}[self]
|
||||
return self.value
|
||||
|
||||
|
||||
class Quality(str, Enum):
|
||||
@@ -41,3 +52,20 @@ class GenType(str, Enum):
|
||||
GenType.TEXT: 'Text',
|
||||
GenType.IMAGE: 'Image',
|
||||
}[self]
|
||||
|
||||
|
||||
class TextModel(str, Enum):
|
||||
GEMINI_3_1_PRO_PREVIEW = "gemini-3.1-pro-preview"
|
||||
|
||||
@property
|
||||
def value_model(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
class ImageModel(str, Enum):
|
||||
GEMINI_3_PRO_IMAGE_PREVIEW = "gemini-3-pro-image-preview"
|
||||
GEMINI_3_1_FLASH_IMAGE_PREVIEW = "gemini-3.1-flash-image-preview"
|
||||
|
||||
@property
|
||||
def value_model(self) -> str:
|
||||
return self.value
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,6 +1,8 @@
|
||||
from typing import List, Optional
|
||||
from typing import Any, List, Optional
|
||||
import logging
|
||||
from datetime import datetime, UTC
|
||||
from bson import ObjectId
|
||||
from uuid import uuid4
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
|
||||
from models.Asset import Asset
|
||||
@@ -19,7 +21,8 @@ class AssetsRepo:
|
||||
# Main data
|
||||
if asset.data:
|
||||
ts = int(asset.created_at.timestamp())
|
||||
object_name = f"{asset.type.value}/{ts}_{asset.name}"
|
||||
uid = uuid4().hex[:8]
|
||||
object_name = f"{asset.type.value}/{ts}_{uid}_{asset.name}"
|
||||
|
||||
uploaded = await self.s3.upload_file(object_name, asset.data)
|
||||
if uploaded:
|
||||
@@ -32,7 +35,8 @@ class AssetsRepo:
|
||||
# Thumbnail
|
||||
if asset.thumbnail:
|
||||
ts = int(asset.created_at.timestamp())
|
||||
thumb_name = f"{asset.type.value}/thumbs/{ts}_{asset.name}_thumb.jpg"
|
||||
uid = uuid4().hex[:8]
|
||||
thumb_name = f"{asset.type.value}/thumbs/{ts}_{uid}_{asset.name}_thumb.jpg"
|
||||
|
||||
uploaded_thumb = await self.s3.upload_file(thumb_name, asset.thumbnail)
|
||||
if uploaded_thumb:
|
||||
@@ -46,8 +50,8 @@ class AssetsRepo:
|
||||
res = await self.collection.insert_one(asset.model_dump())
|
||||
return str(res.inserted_id)
|
||||
|
||||
async def get_assets(self, asset_type: Optional[str] = None, limit: int = 10, offset: int = 0, with_data: bool = False) -> List[Asset]:
|
||||
filter = {}
|
||||
async def get_assets(self, asset_type: Optional[str] = None, limit: int = 10, offset: int = 0, with_data: bool = False, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Asset]:
|
||||
filter: dict[str, Any]= {"is_deleted": {"$ne": True}}
|
||||
if asset_type:
|
||||
filter["type"] = asset_type
|
||||
args = {}
|
||||
@@ -70,6 +74,12 @@ class AssetsRepo:
|
||||
# if not with_data: args["data"] = 0; args["thumbnail"] = 0
|
||||
# So list DOES NOT return thumbnails by default.
|
||||
args["thumbnail"] = 0
|
||||
if created_by:
|
||||
filter["created_by"] = created_by
|
||||
filter['project_id'] = None
|
||||
|
||||
if project_id:
|
||||
filter["project_id"] = project_id
|
||||
|
||||
res = await self.collection.find(filter, args).sort("created_at", -1).skip(offset).limit(limit).to_list(None)
|
||||
assets = []
|
||||
@@ -92,7 +102,7 @@ class AssetsRepo:
|
||||
|
||||
return assets
|
||||
|
||||
async def get_asset(self, asset_id: str, with_data: bool = True) -> Asset:
|
||||
async def get_asset(self, asset_id: str, with_data: bool = True) -> Optional[Asset]:
|
||||
projection = None
|
||||
if not with_data:
|
||||
projection = {"data": 0, "thumbnail": 0}
|
||||
@@ -128,7 +138,8 @@ class AssetsRepo:
|
||||
if self.s3:
|
||||
if asset.data:
|
||||
ts = int(asset.created_at.timestamp())
|
||||
object_name = f"{asset.type.value}/{ts}_{asset.name}"
|
||||
uid = uuid4().hex[:8]
|
||||
object_name = f"{asset.type.value}/{ts}_{uid}_{asset.name}"
|
||||
if await self.s3.upload_file(object_name, asset.data):
|
||||
asset.minio_object_name = object_name
|
||||
asset.minio_bucket = self.s3.bucket_name
|
||||
@@ -136,7 +147,8 @@ class AssetsRepo:
|
||||
|
||||
if asset.thumbnail:
|
||||
ts = int(asset.created_at.timestamp())
|
||||
thumb_name = f"{asset.type.value}/thumbs/{ts}_{asset.name}_thumb.jpg"
|
||||
uid = uuid4().hex[:8]
|
||||
thumb_name = f"{asset.type.value}/thumbs/{ts}_{uid}_{asset.name}_thumb.jpg"
|
||||
if await self.s3.upload_file(thumb_name, asset.thumbnail):
|
||||
asset.minio_thumbnail_object_name = thumb_name
|
||||
asset.thumbnail = None
|
||||
@@ -157,11 +169,22 @@ class AssetsRepo:
|
||||
assets.append(Asset(**doc))
|
||||
return assets
|
||||
|
||||
async def get_asset_count(self, character_id: Optional[str] = None) -> int:
|
||||
return await self.collection.count_documents({"linked_char_id": character_id} if character_id else {})
|
||||
async def get_asset_count(self, character_id: Optional[str] = None, created_by: Optional[str] = None, project_id: Optional[str] = None) -> int:
|
||||
filter = {}
|
||||
if character_id:
|
||||
filter["linked_char_id"] = character_id
|
||||
if created_by:
|
||||
filter["created_by"] = created_by
|
||||
if project_id is None:
|
||||
filter["project_id"] = None
|
||||
if project_id:
|
||||
filter["project_id"] = project_id
|
||||
return await self.collection.count_documents(filter)
|
||||
|
||||
async def get_assets_by_ids(self, asset_ids: List[str]) -> List[Asset]:
|
||||
object_ids = [ObjectId(asset_id) for asset_id in asset_ids]
|
||||
object_ids = [ObjectId(asset_id) for asset_id in asset_ids if ObjectId.is_valid(asset_id)]
|
||||
if not object_ids:
|
||||
return []
|
||||
res = self.collection.find({"_id": {"$in": object_ids}}, {"data": 0}) # Exclude data but maybe allow thumbnail if small?
|
||||
# Original excluded thumbnail too.
|
||||
assets = []
|
||||
@@ -184,6 +207,61 @@ class AssetsRepo:
|
||||
res = await self.collection.delete_one({"_id": ObjectId(asset_id)})
|
||||
return res.deleted_count > 0
|
||||
|
||||
async def soft_delete_and_purge_assets(self, asset_ids: List[str]) -> int:
|
||||
"""
|
||||
Мягко удаляет ассеты и жёстко удаляет их файлы из S3.
|
||||
Возвращает количество обработанных ассетов.
|
||||
"""
|
||||
if not asset_ids:
|
||||
return 0
|
||||
|
||||
object_ids = [ObjectId(aid) for aid in asset_ids if ObjectId.is_valid(aid)]
|
||||
if not object_ids:
|
||||
return 0
|
||||
|
||||
# Находим ассеты, которые ещё не удалены
|
||||
cursor = self.collection.find(
|
||||
{"_id": {"$in": object_ids}, "is_deleted": {"$ne": True}},
|
||||
{"minio_object_name": 1, "minio_thumbnail_object_name": 1}
|
||||
)
|
||||
|
||||
purged_count = 0
|
||||
ids_to_update = []
|
||||
|
||||
async for doc in cursor:
|
||||
ids_to_update.append(doc["_id"])
|
||||
|
||||
# Жёсткое удаление файлов из S3
|
||||
if self.s3:
|
||||
if doc.get("minio_object_name"):
|
||||
try:
|
||||
await self.s3.delete_file(doc["minio_object_name"])
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete S3 object {doc['minio_object_name']}: {e}")
|
||||
if doc.get("minio_thumbnail_object_name"):
|
||||
try:
|
||||
await self.s3.delete_file(doc["minio_thumbnail_object_name"])
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete S3 thumbnail {doc['minio_thumbnail_object_name']}: {e}")
|
||||
|
||||
purged_count += 1
|
||||
|
||||
# Мягкое удаление + очистка ссылок на S3
|
||||
if ids_to_update:
|
||||
await self.collection.update_many(
|
||||
{"_id": {"$in": ids_to_update}},
|
||||
{
|
||||
"$set": {
|
||||
"is_deleted": True,
|
||||
"minio_object_name": None,
|
||||
"minio_thumbnail_object_name": None,
|
||||
"updated_at": datetime.now(UTC)
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
return purged_count
|
||||
|
||||
async def migrate_to_minio(self) -> dict:
|
||||
"""Переносит данные и thumbnails из Mongo в MinIO."""
|
||||
if not self.s3:
|
||||
@@ -203,7 +281,8 @@ class AssetsRepo:
|
||||
created_at = doc.get("created_at")
|
||||
ts = int(created_at.timestamp()) if created_at else 0
|
||||
|
||||
object_name = f"{type_}/{ts}_{asset_id}_{name}"
|
||||
uid = uuid4().hex[:8]
|
||||
object_name = f"{type_}/{ts}_{uid}_{asset_id}_{name}"
|
||||
if await self.s3.upload_file(object_name, data):
|
||||
await self.collection.update_one(
|
||||
{"_id": asset_id},
|
||||
@@ -230,7 +309,8 @@ class AssetsRepo:
|
||||
created_at = doc.get("created_at")
|
||||
ts = int(created_at.timestamp()) if created_at else 0
|
||||
|
||||
thumb_name = f"{type_}/thumbs/{ts}_{asset_id}_{name}_thumb.jpg"
|
||||
uid = uuid4().hex[:8]
|
||||
thumb_name = f"{type_}/thumbs/{ts}_{uid}_{asset_id}_{name}_thumb.jpg"
|
||||
if await self.s3.upload_file(thumb_name, thumb):
|
||||
await self.collection.update_one(
|
||||
{"_id": asset_id},
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from bson import ObjectId
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
@@ -12,32 +12,37 @@ class CharacterRepo:
|
||||
|
||||
async def add_character(self, character: Character) -> Character:
|
||||
op = await self.collection.insert_one(character.model_dump())
|
||||
character.id = op.inserted_id
|
||||
character.id = str(op.inserted_id)
|
||||
return character
|
||||
|
||||
async def get_character(self, character_id: str, with_image_data: bool = False) -> Character | None:
|
||||
args = {}
|
||||
if not with_image_data:
|
||||
args["character_image_data"] = 0
|
||||
res = await self.collection.find_one({"_id": ObjectId(character_id)}, args)
|
||||
async def get_character(self, character_id: str) -> Character | None:
|
||||
res = await self.collection.find_one({"_id": ObjectId(character_id)})
|
||||
if res is None:
|
||||
return None
|
||||
else:
|
||||
res["id"] = str(res.pop("_id"))
|
||||
return Character(**res)
|
||||
|
||||
async def get_all_characters(self) -> List[Character]:
|
||||
docs = await self.collection.find({}, {"character_image_data": 0}).to_list(None)
|
||||
async def get_all_characters(self, created_by: Optional[str] = None, project_id: Optional[str] = None, limit: int = 100, offset: int = 0) -> List[Character]:
|
||||
filter = {}
|
||||
if created_by:
|
||||
filter["created_by"] = created_by
|
||||
if project_id is None:
|
||||
filter["project_id"] = None
|
||||
if project_id:
|
||||
filter["project_id"] = project_id
|
||||
|
||||
characters = []
|
||||
for doc in docs:
|
||||
# Конвертируем ObjectId в строку и кладем в поле id
|
||||
res = await self.collection.find(filter).skip(offset).limit(limit).to_list(None)
|
||||
chars = []
|
||||
for doc in res:
|
||||
doc["id"] = str(doc.pop("_id"))
|
||||
chars.append(Character(**doc))
|
||||
return chars
|
||||
|
||||
# Создаем объект
|
||||
characters.append(Character(**doc))
|
||||
async def update_char(self, char_id: str, character: Character) -> bool:
|
||||
result = await self.collection.update_one({"_id": ObjectId(char_id)}, {"$set": character.model_dump()})
|
||||
return result.modified_count > 0
|
||||
|
||||
return characters
|
||||
|
||||
async def update_char(self, char_id: str, character: Character) -> None:
|
||||
await self.collection.update_one({"_id": ObjectId(char_id)}, {"$set": character.model_dump()})
|
||||
async def delete_character(self, char_id: str) -> bool:
|
||||
result = await self.collection.delete_one({"_id": ObjectId(char_id)})
|
||||
return result.deleted_count > 0
|
||||
|
||||
11
repos/dao.py
11
repos/dao.py
@@ -5,6 +5,11 @@ from repos.char_repo import CharacterRepo
|
||||
from repos.generation_repo import GenerationRepo
|
||||
from repos.user_repo import UsersRepo
|
||||
from repos.albums_repo import AlbumsRepo
|
||||
from repos.project_repo import ProjectRepo
|
||||
from repos.idea_repo import IdeaRepo
|
||||
from repos.post_repo import PostRepo
|
||||
from repos.environment_repo import EnvironmentRepo
|
||||
from repos.inspiration_repo import InspirationRepo
|
||||
|
||||
|
||||
from typing import Optional
|
||||
@@ -16,3 +21,9 @@ class DAO:
|
||||
self.assets = AssetsRepo(client, s3_adapter, db_name)
|
||||
self.generations = GenerationRepo(client, db_name)
|
||||
self.albums = AlbumsRepo(client, db_name)
|
||||
self.projects = ProjectRepo(client, db_name)
|
||||
self.users = UsersRepo(client, db_name)
|
||||
self.ideas = IdeaRepo(client, db_name)
|
||||
self.posts = PostRepo(client, db_name)
|
||||
self.environments = EnvironmentRepo(client, db_name)
|
||||
self.inspirations = InspirationRepo(client, db_name)
|
||||
|
||||
73
repos/environment_repo.py
Normal file
73
repos/environment_repo.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from bson import ObjectId
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from models.Environment import Environment
|
||||
|
||||
|
||||
class EnvironmentRepo:
|
||||
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
|
||||
self.collection = client[db_name]["environments"]
|
||||
|
||||
async def create_env(self, env: Environment) -> Environment:
|
||||
env_dict = env.model_dump(exclude={"id"})
|
||||
res = await self.collection.insert_one(env_dict)
|
||||
env.id = str(res.inserted_id)
|
||||
return env
|
||||
|
||||
async def get_env(self, env_id: str) -> Optional[Environment]:
|
||||
res = await self.collection.find_one({"_id": ObjectId(env_id)})
|
||||
if not res:
|
||||
return None
|
||||
res["id"] = str(res.pop("_id"))
|
||||
return Environment(**res)
|
||||
|
||||
async def get_character_envs(self, character_id: str) -> List[Environment]:
|
||||
cursor = self.collection.find({"character_id": character_id})
|
||||
envs = []
|
||||
async for doc in cursor:
|
||||
doc["id"] = str(doc.pop("_id"))
|
||||
envs.append(Environment(**doc))
|
||||
return envs
|
||||
|
||||
async def update_env(self, env_id: str, update_data: dict) -> bool:
|
||||
update_data["updated_at"] = datetime.utcnow()
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(env_id)},
|
||||
{"$set": update_data}
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def delete_env(self, env_id: str) -> bool:
|
||||
res = await self.collection.delete_one({"_id": ObjectId(env_id)})
|
||||
return res.deleted_count > 0
|
||||
|
||||
async def add_asset(self, env_id: str, asset_id: str) -> bool:
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(env_id)},
|
||||
{
|
||||
"$addToSet": {"asset_ids": asset_id},
|
||||
"$set": {"updated_at": datetime.utcnow()}
|
||||
}
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def add_assets(self, env_id: str, asset_ids: List[str]) -> bool:
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(env_id)},
|
||||
{
|
||||
"$addToSet": {"asset_ids": {"$each": asset_ids}},
|
||||
"$set": {"updated_at": datetime.utcnow()}
|
||||
}
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def remove_asset(self, env_id: str, asset_id: str) -> bool:
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(env_id)},
|
||||
{
|
||||
"$pull": {"asset_ids": asset_id},
|
||||
"$set": {"updated_at": datetime.utcnow()}
|
||||
}
|
||||
)
|
||||
return res.modified_count > 0
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Optional, List
|
||||
from typing import Any, Optional, List
|
||||
from datetime import datetime, timedelta, UTC
|
||||
|
||||
from PIL.ImageChops import offset
|
||||
from bson import ObjectId
|
||||
@@ -16,7 +17,7 @@ class GenerationRepo:
|
||||
res = await self.collection.insert_one(generation.model_dump())
|
||||
return str(res.inserted_id)
|
||||
|
||||
async def get_generation(self, generation_id: str) -> Optional[Generation]:
|
||||
async def get_generation(self, generation_id: str) -> Generation | None:
|
||||
res = await self.collection.find_one({"_id": ObjectId(generation_id)})
|
||||
if res is None:
|
||||
return None
|
||||
@@ -25,14 +26,32 @@ class GenerationRepo:
|
||||
return Generation(**res)
|
||||
|
||||
async def get_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
|
||||
limit: int = 10, offset: int = 10) -> List[Generation]:
|
||||
limit: int = 10, offset: int = 0, created_by: Optional[str] = None, project_id: Optional[str] = None,
|
||||
idea_id: Optional[str] = None, only_liked_by: Optional[str] = None) -> List[Generation]:
|
||||
|
||||
filter = {"is_deleted": False}
|
||||
filter: dict[str, Any] = {"is_deleted": False}
|
||||
if character_id is not None:
|
||||
filter["linked_character_id"] = character_id
|
||||
if status is not None:
|
||||
filter["status"] = status
|
||||
res = await self.collection.find(filter).sort("created_at", -1).skip(
|
||||
if created_by is not None:
|
||||
filter["created_by"] = created_by
|
||||
# If filtering by created_by user (e.g. "My Generations"), we typically imply personal scope if project_id is None.
|
||||
# But if project_id is passed, we filter by that.
|
||||
if project_id is None:
|
||||
filter["project_id"] = None
|
||||
if project_id is not None:
|
||||
filter["project_id"] = project_id
|
||||
if idea_id is not None:
|
||||
filter["idea_id"] = idea_id
|
||||
if only_liked_by is not None:
|
||||
filter["liked_by"] = only_liked_by
|
||||
|
||||
# If fetching for an idea, sort by created_at ascending (cronological)
|
||||
# Otherwise typically descending (newest first)
|
||||
sort_order = 1 if idea_id else -1
|
||||
|
||||
res = await self.collection.find(filter).sort("created_at", sort_order).skip(
|
||||
offset).limit(limit).to_list(None)
|
||||
generations: List[Generation] = []
|
||||
for generation in res:
|
||||
@@ -40,12 +59,26 @@ class GenerationRepo:
|
||||
generations.append(Generation(**generation))
|
||||
return generations
|
||||
|
||||
async def count_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None, album_id: Optional[str] = None) -> int:
|
||||
async def count_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
|
||||
album_id: Optional[str] = None, created_by: Optional[str] = None, project_id: Optional[str] = None,
|
||||
idea_id: Optional[str] = None, only_liked_by: Optional[str] = None) -> int:
|
||||
args = {}
|
||||
if character_id is not None:
|
||||
args["linked_character_id"] = character_id
|
||||
if status is not None:
|
||||
args["status"] = status
|
||||
if created_by is not None:
|
||||
args["created_by"] = created_by
|
||||
if project_id is None:
|
||||
args["project_id"] = None
|
||||
if project_id is not None:
|
||||
args["project_id"] = project_id
|
||||
if idea_id is not None:
|
||||
args["idea_id"] = idea_id
|
||||
if album_id is not None:
|
||||
args["album_id"] = album_id
|
||||
if only_liked_by is not None:
|
||||
args["liked_by"] = only_liked_by
|
||||
return await self.collection.count_documents(args)
|
||||
|
||||
async def get_generations_by_ids(self, generation_ids: List[str]) -> List[Generation]:
|
||||
@@ -66,3 +99,219 @@ class GenerationRepo:
|
||||
|
||||
async def update_generation(self, generation: Generation, ):
|
||||
res = await self.collection.update_one({"_id": ObjectId(generation.id)}, {"$set": generation.model_dump()})
|
||||
|
||||
async def toggle_like(self, generation_id: str, user_id: str) -> bool | None:
|
||||
"""
|
||||
Toggles like for a user on a generation.
|
||||
Returns True if liked, False if unliked, None if generation not found.
|
||||
"""
|
||||
if not ObjectId.is_valid(generation_id):
|
||||
return None
|
||||
|
||||
oid = ObjectId(generation_id)
|
||||
|
||||
# Check if generation exists
|
||||
gen = await self.collection.find_one({"_id": oid}, {"liked_by": 1})
|
||||
|
||||
if not gen:
|
||||
return None
|
||||
|
||||
if user_id in gen.get("liked_by", []):
|
||||
# Unlike
|
||||
await self.collection.update_one(
|
||||
{"_id": oid},
|
||||
{"$pull": {"liked_by": user_id}}
|
||||
)
|
||||
return False
|
||||
else:
|
||||
# Like
|
||||
await self.collection.update_one(
|
||||
{"_id": oid},
|
||||
{"$addToSet": {"liked_by": user_id}}
|
||||
)
|
||||
return True
|
||||
|
||||
async def mark_nsfw(self, generation_id: str, is_nsfw: bool) -> bool:
|
||||
if not ObjectId.is_valid(generation_id):
|
||||
return False
|
||||
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(generation_id)},
|
||||
{"$set": {"nsfw": is_nsfw}}
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def get_usage_stats(self, created_by: Optional[str] = None, project_id: Optional[str] = None) -> dict:
|
||||
"""
|
||||
Calculates usage statistics (runs, tokens, cost) using MongoDB aggregation.
|
||||
Includes even soft-deleted generations to reflect actual expenditure.
|
||||
"""
|
||||
pipeline = []
|
||||
|
||||
# 1. Match all done generations (including soft-deleted)
|
||||
match_stage = {"status": GenerationStatus.DONE}
|
||||
if created_by:
|
||||
match_stage["created_by"] = created_by
|
||||
if project_id:
|
||||
match_stage["project_id"] = project_id
|
||||
|
||||
pipeline.append({"$match": match_stage})
|
||||
|
||||
# 2. Group by null (total)
|
||||
pipeline.append({
|
||||
"$group": {
|
||||
"_id": None,
|
||||
"total_runs": {"$sum": 1},
|
||||
"total_tokens": {
|
||||
"$sum": {
|
||||
"$cond": [
|
||||
{"$and": [{"$gt": ["$input_token_usage", 0]}, {"$gt": ["$output_token_usage", 0]}]},
|
||||
{"$add": ["$input_token_usage", "$output_token_usage"]},
|
||||
{"$ifNull": ["$token_usage", 0]}
|
||||
]
|
||||
}
|
||||
},
|
||||
"total_input_tokens": {"$sum": {"$ifNull": ["$input_token_usage", 0]}},
|
||||
"total_output_tokens": {"$sum": {"$ifNull": ["$output_token_usage", 0]}},
|
||||
"total_cost": {
|
||||
"$sum": {
|
||||
"$add": [
|
||||
{"$multiply": [{"$ifNull": ["$input_token_usage", 0]}, 0.000002]},
|
||||
{"$multiply": [{"$ifNull": ["$output_token_usage", 0]}, 0.00012]}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
cursor = self.collection.aggregate(pipeline)
|
||||
res = await cursor.to_list(1)
|
||||
|
||||
if not res:
|
||||
return {
|
||||
"total_runs": 0,
|
||||
"total_tokens": 0,
|
||||
"total_input_tokens": 0,
|
||||
"total_output_tokens": 0,
|
||||
"total_cost": 0.0
|
||||
}
|
||||
|
||||
result = res[0]
|
||||
result.pop("_id")
|
||||
result["total_cost"] = round(result["total_cost"], 4)
|
||||
return result
|
||||
|
||||
async def get_usage_breakdown(self, group_by: str = "created_by", project_id: Optional[str] = None, created_by: Optional[str] = None) -> List[dict]:
|
||||
"""
|
||||
Returns usage statistics grouped by user or project.
|
||||
Includes even soft-deleted generations to reflect actual expenditure.
|
||||
"""
|
||||
pipeline = []
|
||||
|
||||
match_stage = {"status": GenerationStatus.DONE}
|
||||
if project_id:
|
||||
match_stage["project_id"] = project_id
|
||||
if created_by:
|
||||
match_stage["created_by"] = created_by
|
||||
|
||||
pipeline.append({"$match": match_stage})
|
||||
|
||||
pipeline.append({
|
||||
"$group": {
|
||||
"_id": f"${group_by}",
|
||||
"total_runs": {"$sum": 1},
|
||||
"total_tokens": {
|
||||
"$sum": {
|
||||
"$cond": [
|
||||
{"$and": [{"$gt": ["$input_token_usage", 0]}, {"$gt": ["$output_token_usage", 0]}]},
|
||||
{"$add": ["$input_token_usage", "$output_token_usage"]},
|
||||
{"$ifNull": ["$token_usage", 0]}
|
||||
]
|
||||
}
|
||||
},
|
||||
"total_input_tokens": {"$sum": {"$ifNull": ["$input_token_usage", 0]}},
|
||||
"total_output_tokens": {"$sum": {"$ifNull": ["$output_token_usage", 0]}},
|
||||
"total_cost": {
|
||||
"$sum": {
|
||||
"$add": [
|
||||
{"$multiply": [{"$ifNull": ["$input_token_usage", 0]}, 0.000002]},
|
||||
{"$multiply": [{"$ifNull": ["$output_token_usage", 0]}, 0.00012]}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
pipeline.append({"$sort": {"total_cost": -1}})
|
||||
|
||||
cursor = self.collection.aggregate(pipeline)
|
||||
res = await cursor.to_list(None)
|
||||
|
||||
results = []
|
||||
for item in res:
|
||||
entity_id = item.pop("_id")
|
||||
item["total_cost"] = round(item["total_cost"], 4)
|
||||
results.append({
|
||||
"entity_id": str(entity_id) if entity_id else "unknown",
|
||||
"stats": item
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
async def get_generations_by_group(self, group_id: str) -> List[Generation]:
|
||||
res = await self.collection.find({"generation_group_id": group_id, "is_deleted": False}).sort("created_at", 1).to_list(None)
|
||||
generations: List[Generation] = []
|
||||
for generation in res:
|
||||
generation["id"] = str(generation.pop("_id"))
|
||||
generations.append(Generation(**generation))
|
||||
return generations
|
||||
|
||||
async def cancel_stale_generations(self, timeout_minutes: int = 5) -> int:
|
||||
cutoff_time = datetime.now(UTC) - timedelta(minutes=timeout_minutes)
|
||||
res = await self.collection.update_many(
|
||||
{
|
||||
"status": GenerationStatus.RUNNING,
|
||||
"created_at": {"$lt": cutoff_time}
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
"status": GenerationStatus.FAILED,
|
||||
"failed_reason": "Timeout: Execution time limit exceeded",
|
||||
"updated_at": datetime.now(UTC)
|
||||
}
|
||||
}
|
||||
)
|
||||
return res.modified_count
|
||||
|
||||
async def soft_delete_old_generations(self, days: int = 2) -> tuple[int, List[str]]:
|
||||
"""
|
||||
Мягко удаляет генерации старше N дней.
|
||||
Возвращает (количество удалённых, список asset IDs для очистки).
|
||||
"""
|
||||
cutoff_time = datetime.now(UTC) - timedelta(days=days)
|
||||
filter_query = {
|
||||
"is_deleted": False,
|
||||
"status": {"$in": [GenerationStatus.DONE, GenerationStatus.FAILED]},
|
||||
"created_at": {"$lt": cutoff_time}
|
||||
}
|
||||
|
||||
# Сначала собираем asset IDs из удаляемых генераций
|
||||
asset_ids: List[str] = []
|
||||
cursor = self.collection.find(filter_query, {"result_list": 1, "assets_list": 1})
|
||||
async for doc in cursor:
|
||||
asset_ids.extend(doc.get("result_list", []))
|
||||
|
||||
# Мягкое удаление
|
||||
res = await self.collection.update_many(
|
||||
filter_query,
|
||||
{
|
||||
"$set": {
|
||||
"is_deleted": True,
|
||||
"updated_at": datetime.now(UTC)
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# Убираем дубликаты
|
||||
unique_asset_ids = list(set(asset_ids))
|
||||
return res.modified_count, unique_asset_ids
|
||||
|
||||
91
repos/idea_repo.py
Normal file
91
repos/idea_repo.py
Normal file
@@ -0,0 +1,91 @@
|
||||
from typing import Optional, List
|
||||
from bson import ObjectId
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from models.Idea import Idea
|
||||
|
||||
class IdeaRepo:
|
||||
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
|
||||
self.collection = client[db_name]["ideas"]
|
||||
|
||||
async def create_idea(self, idea: Idea) -> str:
|
||||
res = await self.collection.insert_one(idea.model_dump())
|
||||
return str(res.inserted_id)
|
||||
|
||||
async def get_idea(self, idea_id: str) -> Optional[Idea]:
|
||||
if not ObjectId.is_valid(idea_id):
|
||||
return None
|
||||
res = await self.collection.find_one({"_id": ObjectId(idea_id)})
|
||||
if res:
|
||||
res["id"] = str(res.pop("_id"))
|
||||
return Idea(**res)
|
||||
return None
|
||||
|
||||
async def get_ideas(self, project_id: Optional[str], user_id: str, limit: int = 20, offset: int = 0) -> List[dict]:
|
||||
if project_id:
|
||||
match_stage = {"project_id": project_id, "is_deleted": False}
|
||||
else:
|
||||
match_stage = {"created_by": user_id, "project_id": None, "is_deleted": False}
|
||||
|
||||
pipeline = [
|
||||
{"$match": match_stage},
|
||||
{"$sort": {"updated_at": -1}},
|
||||
{"$skip": offset},
|
||||
{"$limit": limit},
|
||||
# Add string id field for lookup
|
||||
{"$addFields": {"str_id": {"$toString": "$_id"}}},
|
||||
# Lookup generations
|
||||
{
|
||||
"$lookup": {
|
||||
"from": "generations",
|
||||
"let": {"idea_id": "$str_id"},
|
||||
"pipeline": [
|
||||
{
|
||||
"$match": {
|
||||
"$and": [
|
||||
{"$expr": {"$eq": ["$idea_id", "$$idea_id"]}},
|
||||
{"status": "done"},
|
||||
{"result_list": {"$exists": True, "$not": {"$size": 0}}},
|
||||
{"is_deleted": False}
|
||||
]
|
||||
}
|
||||
},
|
||||
{"$sort": {"created_at": -1}}, # Ensure we get the latest successful
|
||||
{"$limit": 1}
|
||||
],
|
||||
"as": "generations"
|
||||
}
|
||||
},
|
||||
# Unwind generations array (preserve ideas without generations)
|
||||
{"$unwind": {"path": "$generations", "preserveNullAndEmptyArrays": True}},
|
||||
# Rename for clarity
|
||||
{"$addFields": {
|
||||
"last_generation": "$generations",
|
||||
"id": "$str_id"
|
||||
}},
|
||||
{"$project": {"generations": 0, "str_id": 0, "_id": 0}}
|
||||
]
|
||||
|
||||
return await self.collection.aggregate(pipeline).to_list(None)
|
||||
|
||||
async def delete_idea(self, idea_id: str) -> bool:
|
||||
if not ObjectId.is_valid(idea_id):
|
||||
return False
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(idea_id)},
|
||||
{"$set": {"is_deleted": True}}
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def update_idea(self, idea: Idea) -> bool:
|
||||
if not idea.id or not ObjectId.is_valid(idea.id):
|
||||
return False
|
||||
|
||||
idea_dict = idea.model_dump()
|
||||
if "id" in idea_dict:
|
||||
del idea_dict["id"]
|
||||
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(idea.id)},
|
||||
{"$set": idea_dict}
|
||||
)
|
||||
return res.modified_count > 0
|
||||
54
repos/inspiration_repo.py
Normal file
54
repos/inspiration_repo.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from bson import ObjectId
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
|
||||
from models.Inspiration import Inspiration
|
||||
|
||||
|
||||
class InspirationRepo:
|
||||
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
|
||||
self.collection = client[db_name]["inspirations"]
|
||||
|
||||
async def create_inspiration(self, inspiration: Inspiration) -> str:
|
||||
res = await self.collection.insert_one(inspiration.model_dump(exclude={"id"}))
|
||||
return str(res.inserted_id)
|
||||
|
||||
async def get_inspiration(self, inspiration_id: str) -> Optional[Inspiration]:
|
||||
res = await self.collection.find_one({"_id": ObjectId(inspiration_id)})
|
||||
if res:
|
||||
res["id"] = str(res.pop("_id"))
|
||||
return Inspiration(**res)
|
||||
return None
|
||||
|
||||
async def get_inspirations(self, project_id: Optional[str] = None, created_by: Optional[str] = None, limit: int = 20, offset: int = 0) -> List[Inspiration]:
|
||||
query = {}
|
||||
if project_id:
|
||||
query["project_id"] = project_id
|
||||
if created_by:
|
||||
query["created_by"] = created_by
|
||||
|
||||
cursor = self.collection.find(query).sort("created_at", -1).skip(offset).limit(limit)
|
||||
inspirations = []
|
||||
async for doc in cursor:
|
||||
doc["id"] = str(doc.pop("_id"))
|
||||
inspirations.append(Inspiration(**doc))
|
||||
return inspirations
|
||||
|
||||
async def count_inspirations(self, project_id: Optional[str] = None, created_by: Optional[str] = None) -> int:
|
||||
query = {}
|
||||
if project_id:
|
||||
query["project_id"] = project_id
|
||||
if created_by:
|
||||
query["created_by"] = created_by
|
||||
return await self.collection.count_documents(query)
|
||||
|
||||
async def update_inspiration(self, inspiration: Inspiration):
|
||||
await self.collection.update_one(
|
||||
{"_id": ObjectId(inspiration.id)},
|
||||
{"$set": inspiration.model_dump(exclude={"id"})}
|
||||
)
|
||||
|
||||
async def delete_inspiration(self, inspiration_id: str) -> bool:
|
||||
res = await self.collection.delete_one({"_id": ObjectId(inspiration_id)})
|
||||
return res.deleted_count > 0
|
||||
97
repos/post_repo.py
Normal file
97
repos/post_repo.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
import logging
|
||||
from bson import ObjectId
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
|
||||
from models.Post import Post
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PostRepo:
|
||||
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
|
||||
self.collection = client[db_name]["posts"]
|
||||
|
||||
async def create_post(self, post: Post) -> str:
|
||||
res = await self.collection.insert_one(post.model_dump())
|
||||
return str(res.inserted_id)
|
||||
|
||||
async def get_post(self, post_id: str) -> Optional[Post]:
|
||||
if not ObjectId.is_valid(post_id):
|
||||
return None
|
||||
res = await self.collection.find_one({"_id": ObjectId(post_id), "is_deleted": False})
|
||||
if res:
|
||||
res["id"] = str(res.pop("_id"))
|
||||
return Post(**res)
|
||||
return None
|
||||
|
||||
async def get_posts(
|
||||
self,
|
||||
project_id: Optional[str],
|
||||
user_id: str,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
date_from: Optional[datetime] = None,
|
||||
date_to: Optional[datetime] = None,
|
||||
) -> List[Post]:
|
||||
if project_id:
|
||||
match = {"project_id": project_id, "is_deleted": False}
|
||||
else:
|
||||
match = {"created_by": user_id, "project_id": None, "is_deleted": False}
|
||||
|
||||
if date_from or date_to:
|
||||
date_filter = {}
|
||||
if date_from:
|
||||
date_filter["$gte"] = date_from
|
||||
if date_to:
|
||||
date_filter["$lte"] = date_to
|
||||
match["date"] = date_filter
|
||||
|
||||
cursor = (
|
||||
self.collection.find(match)
|
||||
.sort("date", -1)
|
||||
.skip(offset)
|
||||
.limit(limit)
|
||||
)
|
||||
posts = []
|
||||
async for doc in cursor:
|
||||
doc["id"] = str(doc.pop("_id"))
|
||||
posts.append(Post(**doc))
|
||||
return posts
|
||||
|
||||
async def update_post(self, post_id: str, data: dict) -> bool:
|
||||
if not ObjectId.is_valid(post_id):
|
||||
return False
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(post_id)},
|
||||
{"$set": data},
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def delete_post(self, post_id: str) -> bool:
|
||||
if not ObjectId.is_valid(post_id):
|
||||
return False
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(post_id)},
|
||||
{"$set": {"is_deleted": True}},
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def add_generations(self, post_id: str, generation_ids: List[str]) -> bool:
|
||||
if not ObjectId.is_valid(post_id):
|
||||
return False
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(post_id)},
|
||||
{"$addToSet": {"generation_ids": {"$each": generation_ids}}},
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def remove_generation(self, post_id: str, generation_id: str) -> bool:
|
||||
if not ObjectId.is_valid(post_id):
|
||||
return False
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(post_id)},
|
||||
{"$pull": {"generation_ids": generation_id}},
|
||||
)
|
||||
return res.modified_count > 0
|
||||
62
repos/project_repo.py
Normal file
62
repos/project_repo.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from typing import List, Optional
|
||||
from bson import ObjectId
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from models.Project import Project
|
||||
|
||||
class ProjectRepo:
|
||||
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
|
||||
self.collection = client[db_name]["projects"]
|
||||
|
||||
async def create_project(self, project: Project) -> str:
|
||||
res = await self.collection.insert_one(project.model_dump())
|
||||
return str(res.inserted_id)
|
||||
|
||||
async def get_project(self, project_id: str) -> Optional[Project]:
|
||||
if not ObjectId.is_valid(project_id):
|
||||
return None
|
||||
res = await self.collection.find_one({"_id": ObjectId(project_id)})
|
||||
if res:
|
||||
res["id"] = str(res.pop("_id"))
|
||||
return Project(**res)
|
||||
return None
|
||||
|
||||
async def get_projects_by_user(self, user_id: str) -> List[Project]:
|
||||
# Find projects where user is owner OR in members
|
||||
filter = {
|
||||
"$or": [
|
||||
{"owner_id": user_id},
|
||||
{"members": user_id}
|
||||
],
|
||||
"is_deleted": False
|
||||
}
|
||||
cursor = self.collection.find(filter).sort("created_at", -1)
|
||||
projects = []
|
||||
async for doc in cursor:
|
||||
doc["id"] = str(doc.pop("_id"))
|
||||
projects.append(Project(**doc))
|
||||
return projects
|
||||
|
||||
async def add_member(self, project_id: str, user_id: str) -> bool:
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(project_id)},
|
||||
{"$addToSet": {"members": user_id}}
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def remove_member(self, project_id: str, user_id: str) -> bool:
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(project_id)},
|
||||
{"$pull": {"members": user_id}}
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def update_project(self, project_id: str, updates: dict) -> bool:
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(project_id)},
|
||||
{"$set": updates}
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def delete_project(self, project_id: str) -> bool:
|
||||
res = await self.collection.update_one({"_id": ObjectId(project_id)}, {"$set": {"is_deleted": True}})
|
||||
return res.modified_count > 0
|
||||
@@ -19,10 +19,16 @@ class UsersRepo:
|
||||
self.collection = client[db_name]["users"]
|
||||
|
||||
async def get_user(self, user_id: int):
|
||||
return await self.collection.find_one({"user_id": user_id})
|
||||
user = await self.collection.find_one({"user_id": user_id})
|
||||
if user:
|
||||
user["id"] = str(user["_id"])
|
||||
return user
|
||||
|
||||
async def get_user_by_username(self, username: str):
|
||||
return await self.collection.find_one({"username": username})
|
||||
user = await self.collection.find_one({"username": username})
|
||||
if user:
|
||||
user["id"] = str(user["_id"])
|
||||
return user
|
||||
|
||||
async def create_user(self, username: str, password: str, full_name: Optional[str] = None):
|
||||
"""Создает нового пользователя с username/паролем"""
|
||||
@@ -38,15 +44,23 @@ class UsersRepo:
|
||||
"created_at": datetime.now(),
|
||||
"is_email_user": False, # Теперь это просто "обычный" юзер, не телеграм (хотя поле можно переименовать)
|
||||
"is_web_user": True,
|
||||
"is_admin": False
|
||||
"is_admin": False,
|
||||
"project_ids": [],
|
||||
"current_project_id": None
|
||||
}
|
||||
result = await self.collection.insert_one(user_doc)
|
||||
return await self.collection.find_one({"_id": result.inserted_id})
|
||||
user = await self.collection.find_one({"_id": result.inserted_id})
|
||||
if user:
|
||||
user["id"] = str(user["_id"])
|
||||
return user
|
||||
|
||||
async def get_pending_users(self):
|
||||
"""Возвращает список пользователей со статусом PENDING"""
|
||||
cursor = self.collection.find({"status": UserStatus.PENDING})
|
||||
return await cursor.to_list(length=100)
|
||||
users = await cursor.to_list(length=100)
|
||||
for user in users:
|
||||
user["id"] = str(user["_id"])
|
||||
return users
|
||||
|
||||
async def approve_user(self, username: str):
|
||||
await self.collection.update_one(
|
||||
|
||||
@@ -50,3 +50,6 @@ passlib[argon2]==1.7.4
|
||||
python-jose[cryptography]==3.3.0
|
||||
python-multipart==0.0.22
|
||||
email-validator
|
||||
prometheus-fastapi-instrumentator
|
||||
pydantic-settings==2.13.0
|
||||
yt-dlp
|
||||
|
||||
Binary file not shown.
Binary file not shown.
@@ -51,56 +51,66 @@ async def new_char_bio(message: Message, state: FSMContext, dao: DAO, bot: Bot):
|
||||
wait_msg = await message.answer("💾 Сохраняю персонажа...")
|
||||
|
||||
try:
|
||||
# ВОТ ТУТ скачиваем файл (прямо перед сохранением)
|
||||
# 1. Скачиваем файл (один раз)
|
||||
# TODO: Для больших файлов лучше использовать streaming или сохранять во временный файл
|
||||
file_io = await bot.download(file_id)
|
||||
# photo_bytes = file_io.getvalue() # Получаем байты
|
||||
file_bytes = file_io.read()
|
||||
|
||||
|
||||
# Создаем модель
|
||||
# 2. Создаем Character (сначала без ассета, чтобы получить ID)
|
||||
char = Character(
|
||||
id=None,
|
||||
name=name,
|
||||
character_image_data=file_io.read(),
|
||||
character_image_tg_id=None,
|
||||
character_image_doc_tg_id=file_id,
|
||||
character_bio=bio
|
||||
character_bio=bio,
|
||||
created_by=str(message.from_user.id)
|
||||
)
|
||||
file_io.close()
|
||||
|
||||
# Сохраняем через DAO
|
||||
|
||||
# Сохраняем, чтобы получить ID
|
||||
await dao.chars.add_character(char)
|
||||
file_info = await bot.get_file(char.character_image_doc_tg_id)
|
||||
file_bytes = await bot.download_file(file_info.file_path)
|
||||
file_io = file_bytes.read()
|
||||
avatar_asset = await dao.assets.create_asset(
|
||||
Asset(name="avatar.png", type=AssetType.UPLOADED, content_type=AssetContentType.IMAGE, linked_char_id=str(char.id), data=file_io,
|
||||
tg_doc_file_id=file_id))
|
||||
char.avatar_image = avatar_asset.link
|
||||
|
||||
# 3. Создаем Asset (связанный с персонажем)
|
||||
avatar_asset_id = await dao.assets.create_asset(
|
||||
Asset(
|
||||
name="avatar.png",
|
||||
type=AssetType.UPLOADED,
|
||||
content_type=AssetContentType.IMAGE,
|
||||
linked_char_id=str(char.id),
|
||||
data=file_bytes,
|
||||
tg_doc_file_id=file_id
|
||||
)
|
||||
)
|
||||
|
||||
# 4. Обновляем персонажа ссылками на ассет
|
||||
char.avatar_asset_id = avatar_asset_id
|
||||
char.avatar_image = f"/api/assets/{avatar_asset_id}" # Формируем ссылку вручную или используем метод, если появится
|
||||
|
||||
# Отправляем подтверждение
|
||||
# Используем байты для отправки обратно
|
||||
photo_msg = await message.answer_photo(
|
||||
photo=BufferedInputFile(file_io,
|
||||
filename="char.jpg") if not char.character_image_tg_id else char.character_image_tg_id,
|
||||
photo=BufferedInputFile(file_bytes, filename="char.jpg"),
|
||||
caption=(
|
||||
"🎉 <b>Персонаж создан!</b>\n\n"
|
||||
f"👤 <b>Имя:</b> {char.name}\n"
|
||||
f"📝 <b>Био:</b> {char.character_bio}"
|
||||
)
|
||||
)
|
||||
file_bytes.close()
|
||||
char.character_image_tg_id = photo_msg.photo[0].file_id
|
||||
|
||||
# Сохраняем TG ID фото (которое отправили как фото, а не документ)
|
||||
char.character_image_tg_id = photo_msg.photo[-1].file_id
|
||||
|
||||
# Финальное обновление персонажа
|
||||
await dao.chars.update_char(char.id, char)
|
||||
|
||||
await wait_msg.delete()
|
||||
file_io.close()
|
||||
|
||||
# Сбрасываем состояние
|
||||
await state.clear()
|
||||
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
logger.error(f"Error creating character: {e}")
|
||||
traceback.print_exc()
|
||||
await wait_msg.edit_text(f"❌ Ошибка при сохранении: {e}")
|
||||
# Не сбрасываем стейт, даем возможность попробовать ввести био снова или начать заново
|
||||
|
||||
|
||||
@router.message(Command("chars"))
|
||||
|
||||
@@ -51,7 +51,7 @@ async def generate_image_cmd(message: Message, state: FSMContext, dao: DAO, gemi
|
||||
await wait_msg.delete()
|
||||
doc = await message.answer_document(res[0], caption="Generated result 💫")
|
||||
await dao.assets.create_asset(Asset(id=None, name=res[0].filename, type=AssetType.GENERATED, content_type=AssetContentType.IMAGE, data=res[0].data,
|
||||
tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None, linked_char_id=None))
|
||||
tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None, linked_char_id=None, created_by=str(message.from_user.id)))
|
||||
|
||||
|
||||
@router.message(Command("gen_mode"))
|
||||
@@ -126,12 +126,11 @@ async def change_char(call: CallbackQuery, state: FSMContext, dao: DAO):
|
||||
@router.callback_query(States.gen_mode, F.data == 'gen_mode_change_aspect_ratio')
|
||||
async def gen_mode_change_aspect_ratio(call: CallbackQuery, state: FSMContext, dao: DAO):
|
||||
await call.answer()
|
||||
keyboards = []
|
||||
for ratio in AspectRatios:
|
||||
keyboards.append(InlineKeyboardButton(text=ratio.value, callback_data=f'select_ratio_{ratio.name}'))
|
||||
buttons = [InlineKeyboardButton(text=ratio.value, callback_data=f'select_ratio_{ratio.name}') for ratio in AspectRatios]
|
||||
keyboard_rows = [buttons[i:i + 4] for i in range(0, len(buttons), 4)]
|
||||
keyboard_rows.append([InlineKeyboardButton(text="⬅️ Назад", callback_data="gen_mode_cancel_ratio_change")])
|
||||
await call.message.edit_caption(caption="Выбери соотношение сторон",
|
||||
reply_markup=InlineKeyboardMarkup(inline_keyboard=[keyboards, [InlineKeyboardButton(
|
||||
text="⬅️ Назад", callback_data="gen_mode_cancel_ratio_change")]]))
|
||||
reply_markup=InlineKeyboardMarkup(inline_keyboard=keyboard_rows))
|
||||
|
||||
|
||||
@router.callback_query(States.gen_mode, F.data.startswith('select_ratio_'))
|
||||
@@ -259,7 +258,8 @@ async def handle_album(
|
||||
doc = await message.answer_document(file, caption="✨ Generated result")
|
||||
await dao.assets.create_asset(Asset(id=None, name=file.filename, type=AssetType.GENERATED, content_type=AssetContentType.IMAGE, data=file.data,
|
||||
tg_doc_file_id = doc.document.file_id, tg_photo_file_id = None,
|
||||
linked_char_id = data["char_id"]))
|
||||
linked_char_id = data["char_id"],
|
||||
created_by=str(message.from_user.id)))
|
||||
else:
|
||||
await message.answer("❌ Генерация не вернула изображений.")
|
||||
await gen_mode_base_msg(message=message, state=state, dao=dao, call_type="start")
|
||||
@@ -314,7 +314,8 @@ async def gen_mode_start(
|
||||
doc = await message.answer_document(file, caption="✨ Generated result")
|
||||
await dao.assets.create_asset(Asset(id=None, name=file.filename, type=AssetType.GENERATED, content_type=AssetContentType.IMAGE, data=file.data,
|
||||
tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None,
|
||||
linked_char_id=data["char_id"]))
|
||||
linked_char_id=data["char_id"],
|
||||
created_by=str(message.from_user.id)))
|
||||
|
||||
else:
|
||||
await message.answer("❌ Ничего не сгенерировалось.")
|
||||
|
||||
101
tests/test_character_crud.py
Normal file
101
tests/test_character_crud.py
Normal file
@@ -0,0 +1,101 @@
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
import asyncio
|
||||
from config import settings
|
||||
|
||||
from aiws import app
|
||||
from api.endpoints.auth import get_current_user
|
||||
from api.dependency import get_dao
|
||||
from repos.dao import DAO
|
||||
from models.Character import Character
|
||||
|
||||
# Config for test DB
|
||||
MONGO_HOST = settings.MONGO_HOST
|
||||
DB_NAME = "bot_db_test_chars"
|
||||
|
||||
# Mock User
|
||||
MOCK_USER_ID = "507f1f77bcf86cd799439011"
|
||||
MOCK_USER = {
|
||||
"_id": MOCK_USER_ID,
|
||||
"username": "testuser",
|
||||
"is_admin": False,
|
||||
"status": "allowed"
|
||||
}
|
||||
|
||||
# Override get_current_user to bypass auth
|
||||
def mock_get_current_user():
|
||||
return MOCK_USER
|
||||
|
||||
app.dependency_overrides[get_current_user] = mock_get_current_user
|
||||
|
||||
# Setup Real DAO with Test DB
|
||||
client_mongo = AsyncIOMotorClient(MONGO_HOST)
|
||||
dao = DAO(client_mongo, db_name=DB_NAME)
|
||||
|
||||
def mock_get_dao():
|
||||
return dao
|
||||
|
||||
app.dependency_overrides[get_dao] = mock_get_dao
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def setup_teardown():
|
||||
# Setup: Ensure clean state
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_until_complete(client_mongo[DB_NAME]["characters"].drop())
|
||||
|
||||
yield
|
||||
|
||||
# Teardown
|
||||
loop.run_until_complete(client_mongo[DB_NAME]["characters"].drop())
|
||||
loop.close()
|
||||
|
||||
def test_character_crud_flow():
|
||||
# 1. Create Character
|
||||
create_payload = {
|
||||
"name": "Test Character",
|
||||
"character_bio": "A bio for test character",
|
||||
"character_image_doc_tg_id": "file_123",
|
||||
"avatar_image": "http://example.com/avatar.jpg"
|
||||
}
|
||||
|
||||
response = client.post("/api/characters/", json=create_payload)
|
||||
assert response.status_code == 200, response.text
|
||||
char_data = response.json()
|
||||
assert char_data["name"] == create_payload["name"]
|
||||
assert char_data["created_by"] == MOCK_USER_ID
|
||||
char_id = char_data["id"]
|
||||
assert char_id is not None
|
||||
|
||||
# 2. Get Character
|
||||
response = client.get(f"/api/characters/{char_id}")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["id"] == char_id
|
||||
|
||||
# 3. Update Character
|
||||
update_payload = {
|
||||
"name": "Updated Name",
|
||||
"character_bio": "Updated bio"
|
||||
}
|
||||
response = client.put(f"/api/characters/{char_id}", json=update_payload)
|
||||
assert response.status_code == 200
|
||||
updated_data = response.json()
|
||||
assert updated_data["name"] == "Updated Name"
|
||||
assert updated_data["character_bio"] == "Updated bio"
|
||||
|
||||
# Verify update persistent
|
||||
response = client.get(f"/api/characters/{char_id}")
|
||||
assert response.json()["name"] == "Updated Name"
|
||||
|
||||
# 4. Delete Character
|
||||
response = client.delete(f"/api/characters/{char_id}")
|
||||
assert response.status_code == 204
|
||||
|
||||
# Verify deletion
|
||||
response = client.get(f"/api/characters/{char_id}")
|
||||
assert response.status_code == 404, "Deleted character should return 404"
|
||||
64
tests/test_character_integration.py
Normal file
64
tests/test_character_integration.py
Normal file
@@ -0,0 +1,64 @@
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
# 1. Set Auth Bypass and Test Config
|
||||
os.environ["DB_NAME"] = "bot_db_test_integration"
|
||||
# We keep MONGO_HOST as is (it works in verified script)
|
||||
|
||||
# 2. Import app AFTER setting env
|
||||
from main import app
|
||||
from api.endpoints.auth import get_current_user
|
||||
|
||||
# 3. Override Auth
|
||||
MOCK_USER_ID = "507f1f77bcf86cd799439011"
|
||||
MOCK_USER = {
|
||||
"_id": MOCK_USER_ID,
|
||||
"username": "testuser",
|
||||
"is_admin": False,
|
||||
"status": "allowed",
|
||||
"project_ids": []
|
||||
}
|
||||
|
||||
def mock_get_current_user():
|
||||
return MOCK_USER
|
||||
|
||||
app.dependency_overrides[get_current_user] = mock_get_current_user
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
def test_character_crud_lifecycle():
|
||||
# 1. Create
|
||||
create_payload = {
|
||||
"name": "Integration Test Char",
|
||||
"character_bio": "Testing with real app structure",
|
||||
"character_image_doc_tg_id": "doc_123",
|
||||
"avatar_image": "http://example.com/img.jpg"
|
||||
}
|
||||
|
||||
response = client.post("/api/characters/", json=create_payload)
|
||||
assert response.status_code == 200, response.text
|
||||
char_data = response.json()
|
||||
assert char_data["name"] == create_payload["name"]
|
||||
char_id = char_data["id"]
|
||||
|
||||
# 2. Get
|
||||
response = client.get(f"/api/characters/{char_id}")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["id"] == char_id
|
||||
|
||||
# 3. Update
|
||||
update_payload = {"name": "Updated Int Name"}
|
||||
response = client.put(f"/api/characters/{char_id}", json=update_payload)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == "Updated Int Name"
|
||||
|
||||
# 4. Delete
|
||||
response = client.delete(f"/api/characters/{char_id}")
|
||||
assert response.status_code == 204
|
||||
|
||||
# 5. Verify Delete
|
||||
response = client.get(f"/api/characters/{char_id}")
|
||||
assert response.status_code == 404
|
||||
63
tests/test_external_import.py
Executable file
63
tests/test_external_import.py
Executable file
@@ -0,0 +1,63 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for external generation import API.
|
||||
This script demonstrates how to call the import endpoint with proper HMAC signature.
|
||||
"""
|
||||
|
||||
import hmac
|
||||
import hashlib
|
||||
import json
|
||||
import requests
|
||||
import base64
|
||||
import os
|
||||
from config import settings
|
||||
|
||||
# Load env is not needed as settings handles it
|
||||
|
||||
# Configuration
|
||||
API_URL = "http://localhost:8090/api/generations/import"
|
||||
SECRET = settings.EXTERNAL_API_SECRET or "your_super_secret_key_change_this_in_production"
|
||||
|
||||
# Sample generation data
|
||||
generation_data = {
|
||||
"prompt": "A beautiful sunset over mountains",
|
||||
"tech_prompt": "High quality landscape photography",
|
||||
"image_url": "https://picsum.photos/512/512", # Sample image URL
|
||||
# OR use base64:
|
||||
# "image_data": "base64_encoded_image_string_here",
|
||||
"aspect_ratio": "9:16",
|
||||
"quality": "1k",
|
||||
"created_by": "external_user_123",
|
||||
"execution_time_seconds": 5.2,
|
||||
"token_usage": 1000,
|
||||
"input_token_usage": 200,
|
||||
"output_token_usage": 800
|
||||
}
|
||||
|
||||
# Convert to JSON
|
||||
body = json.dumps(generation_data).encode('utf-8')
|
||||
|
||||
# Compute HMAC signature
|
||||
signature = hmac.new(
|
||||
SECRET.encode('utf-8'),
|
||||
body,
|
||||
hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
# Make request
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"X-Signature": signature
|
||||
}
|
||||
|
||||
print(f"Sending request to {API_URL}")
|
||||
print(f"Signature: {signature}")
|
||||
|
||||
try:
|
||||
response = requests.post(API_URL, data=body, headers=headers)
|
||||
print(f"\nStatus Code: {response.status_code}")
|
||||
print(f"Response: {json.dumps(response.json(), indent=2)}")
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
if hasattr(e, 'response'):
|
||||
print(f"Response text: {e.response.text}")
|
||||
96
tests/test_idea.py
Normal file
96
tests/test_idea.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import asyncio
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from bson import ObjectId
|
||||
|
||||
# Import from project root (requires PYTHONPATH=.)
|
||||
from api.service.idea_service import IdeaService
|
||||
from repos.dao import DAO
|
||||
from models.Idea import Idea
|
||||
from models.Generation import Generation, GenerationStatus
|
||||
from models.enums import AspectRatios, Quality
|
||||
from config import settings
|
||||
|
||||
MONGO_HOST = settings.MONGO_HOST
|
||||
DB_NAME = settings.DB_NAME
|
||||
|
||||
print(f"Connecting to MongoDB: {MONGO_HOST}, DB: {DB_NAME}")
|
||||
|
||||
async def test_idea_flow():
|
||||
client = AsyncIOMotorClient(MONGO_HOST)
|
||||
dao = DAO(client, db_name=DB_NAME)
|
||||
service = IdeaService(dao)
|
||||
|
||||
# 1. Create an Idea
|
||||
print("Creating idea...")
|
||||
user_id = "test_user_123"
|
||||
project_id = "test_project_abc"
|
||||
idea = await service.create_idea("My Test Idea", "Initial Description", project_id, user_id)
|
||||
print(f"Idea created: {idea.id} - {idea.name}")
|
||||
|
||||
# 2. Update Idea
|
||||
print("Updating idea...")
|
||||
updated_idea = await service.update_idea(idea.id, description="Updated description")
|
||||
print(f"Idea updated: {updated_idea.description}")
|
||||
if updated_idea.description == "Updated description":
|
||||
print("✅ Idea update successful")
|
||||
else:
|
||||
print("❌ Idea update FAILED")
|
||||
|
||||
# 3. Add Generation linked to Idea
|
||||
print("Creating generation linked to idea...")
|
||||
gen = Generation(
|
||||
prompt="idea generation 1",
|
||||
# idea_id=idea.id, <-- Intentionally NOT linking initially to test linking method
|
||||
project_id=project_id,
|
||||
created_by=user_id,
|
||||
aspect_ratio=AspectRatios.NINESIXTEEN,
|
||||
quality=Quality.ONEK,
|
||||
assets_list=[]
|
||||
)
|
||||
gen_id = await dao.generations.create_generation(gen)
|
||||
print(f"Created generation: {gen_id}")
|
||||
|
||||
# Link generation to idea
|
||||
print("Linking generation to idea...")
|
||||
success = await service.add_generation_to_idea(idea.id, gen_id)
|
||||
if success:
|
||||
print("✅ Linking successful")
|
||||
else:
|
||||
print("❌ Linking FAILED")
|
||||
|
||||
# Debug: Check if generation was saved with idea_id
|
||||
saved_gen = await dao.generations.collection.find_one({"_id": ObjectId(gen_id)})
|
||||
print(f"DEBUG: Saved Generation in DB idea_id: {saved_gen.get('idea_id')}")
|
||||
|
||||
# 4. Fetch Generations for Idea (Verify filtering and ordering)
|
||||
print("Fetching generations for idea...")
|
||||
gens = await service.dao.generations.get_generations(idea_id=idea.id) # using repo directly as service might return wrapper
|
||||
print(f"Found {len(gens)} generations in idea")
|
||||
|
||||
if len(gens) == 1 and gens[0].id == gen_id:
|
||||
print("✅ Generation retrieval successful")
|
||||
else:
|
||||
print("❌ Generation retrieval FAILED")
|
||||
|
||||
# 5. Fetch Ideas for Project
|
||||
ideas = await service.get_ideas(project_id)
|
||||
print(f"Found {len(ideas)} ideas for project")
|
||||
|
||||
# Cleaning up
|
||||
print("Cleaning up...")
|
||||
await service.delete_idea(idea.id)
|
||||
await dao.generations.collection.delete_one({"_id": ObjectId(gen_id)})
|
||||
|
||||
# Verify deletion
|
||||
deleted_idea = await service.get_idea(idea.id)
|
||||
# IdeaRepo.delete_idea logic sets is_deleted=True
|
||||
if deleted_idea and deleted_idea.is_deleted:
|
||||
print(f"✅ Idea deleted successfully")
|
||||
|
||||
# Hard delete for cleanup
|
||||
await dao.ideas.collection.delete_one({"_id": ObjectId(idea.id)})
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_idea_flow())
|
||||
@@ -1,15 +1,14 @@
|
||||
import asyncio
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from config import settings
|
||||
from adapters.s3_adapter import S3Adapter
|
||||
|
||||
async def test_s3():
|
||||
load_dotenv()
|
||||
|
||||
endpoint = os.getenv("MINIO_ENDPOINT", "http://localhost:9000")
|
||||
access_key = os.getenv("MINIO_ACCESS_KEY")
|
||||
secret_key = os.getenv("MINIO_SECRET_KEY")
|
||||
bucket = os.getenv("MINIO_BUCKET")
|
||||
endpoint = settings.MINIO_ENDPOINT
|
||||
access_key = settings.MINIO_ACCESS_KEY
|
||||
secret_key = settings.MINIO_SECRET_KEY
|
||||
bucket = settings.MINIO_BUCKET
|
||||
|
||||
print(f"Connecting to {endpoint}, bucket: {bucket}")
|
||||
|
||||
|
||||
50
tests/test_scheduler.py
Normal file
50
tests/test_scheduler.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import asyncio
|
||||
import os
|
||||
from datetime import datetime, timedelta, UTC
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from models.Generation import Generation, GenerationStatus
|
||||
from repos.generation_repo import GenerationRepo
|
||||
from config import settings
|
||||
|
||||
# Mock configs if not present in env
|
||||
MONGO_HOST = settings.MONGO_HOST
|
||||
DB_NAME = settings.DB_NAME
|
||||
|
||||
print(f"Connecting to MongoDB: {MONGO_HOST}, DB: {DB_NAME}")
|
||||
|
||||
async def test_scheduler():
|
||||
client = AsyncIOMotorClient(MONGO_HOST)
|
||||
repo = GenerationRepo(client, db_name=DB_NAME)
|
||||
|
||||
# 1. Create a "stale" generation (2 hours ago)
|
||||
stale_gen = Generation(
|
||||
prompt="stale test",
|
||||
status=GenerationStatus.RUNNING,
|
||||
created_at=datetime.now(UTC) - timedelta(minutes=120),
|
||||
assets_list=[],
|
||||
aspect_ratio="NINESIXTEEN",
|
||||
quality="ONEK"
|
||||
)
|
||||
gen_id = await repo.create_generation(stale_gen)
|
||||
print(f"Created stale generation: {gen_id}")
|
||||
|
||||
# 2. Run cleanup
|
||||
print("Running cleanup...")
|
||||
count = await repo.cancel_stale_generations(timeout_minutes=60)
|
||||
print(f"Cleaned up {count} generations")
|
||||
|
||||
# 3. Verify status
|
||||
updated_gen = await repo.get_generation(gen_id)
|
||||
print(f"Generation status: {updated_gen.status}")
|
||||
print(f"Failed reason: {updated_gen.failed_reason}")
|
||||
|
||||
if updated_gen.status == GenerationStatus.FAILED:
|
||||
print("✅ SUCCESS: Generation marked as FAILED")
|
||||
else:
|
||||
print("❌ FAILURE: Generation status not updated")
|
||||
|
||||
# Cleanup
|
||||
await repo.collection.delete_one({"_id": updated_gen.id}) # Remove test data
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_scheduler())
|
||||
@@ -10,10 +10,11 @@ from repos.dao import DAO
|
||||
from models.Album import Album
|
||||
from models.Generation import Generation, GenerationStatus
|
||||
from models.enums import AspectRatios, Quality
|
||||
from config import settings
|
||||
|
||||
# Mock config
|
||||
# Use the same host as main.py but different DB
|
||||
MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://admin:super_secure_password@31.59.58.220:27017")
|
||||
# Use the same host as aiws.py but different DB
|
||||
MONGO_HOST = settings.MONGO_HOST
|
||||
DB_NAME = "bot_db_test_albums"
|
||||
|
||||
async def test_albums():
|
||||
@@ -83,8 +84,6 @@ async def test_albums():
|
||||
client.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
try:
|
||||
asyncio.run(test_albums())
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,29 +1,28 @@
|
||||
import asyncio
|
||||
import os
|
||||
from datetime import datetime
|
||||
from dotenv import load_dotenv
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
|
||||
from config import settings
|
||||
from models.Asset import Asset, AssetType
|
||||
from repos.assets_repo import AssetsRepo
|
||||
from adapters.s3_adapter import S3Adapter
|
||||
|
||||
# Load env to get credentials
|
||||
load_dotenv()
|
||||
# Load env is not needed as settings handles it
|
||||
|
||||
async def test_integration():
|
||||
print("🚀 Starting integration test...")
|
||||
|
||||
# 1. Setup Dependencies
|
||||
mongo_uri = os.getenv("MONGO_HOST", "mongodb://localhost:27017")
|
||||
mongo_uri = settings.MONGO_HOST
|
||||
client = AsyncIOMotorClient(mongo_uri)
|
||||
db_name = os.getenv("DB_NAME", "bot_db_test")
|
||||
db_name = settings.DB_NAME + "_test"
|
||||
|
||||
s3_adapter = S3Adapter(
|
||||
endpoint_url=os.getenv("MINIO_ENDPOINT", "http://localhost:9000"),
|
||||
aws_access_key_id=os.getenv("MINIO_ACCESS_KEY", "admin"),
|
||||
aws_secret_access_key=os.getenv("MINIO_SECRET_KEY", "SuperSecretPassword123!"),
|
||||
bucket_name=os.getenv("MINIO_BUCKET", "ai-char")
|
||||
endpoint_url=settings.MINIO_ENDPOINT,
|
||||
aws_access_key_id=settings.MINIO_ACCESS_KEY,
|
||||
aws_secret_access_key=settings.MINIO_SECRET_KEY,
|
||||
bucket_name=settings.MINIO_BUCKET
|
||||
)
|
||||
|
||||
repo = AssetsRepo(client, s3_adapter, db_name=db_name)
|
||||
|
||||
Binary file not shown.
Binary file not shown.
46
utils/external_auth.py
Normal file
46
utils/external_auth.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import hmac
|
||||
import hashlib
|
||||
import os
|
||||
from fastapi import Header, HTTPException
|
||||
from typing import Optional
|
||||
|
||||
def verify_signature(body: bytes, signature: str, secret: str) -> bool:
|
||||
"""
|
||||
Verify HMAC-SHA256 signature.
|
||||
|
||||
Args:
|
||||
body: Raw request body bytes
|
||||
signature: Signature from X-Signature header
|
||||
secret: Shared secret key
|
||||
|
||||
Returns:
|
||||
True if signature is valid, False otherwise
|
||||
"""
|
||||
expected_signature = hmac.new(
|
||||
secret.encode('utf-8'),
|
||||
body,
|
||||
hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
return hmac.compare_digest(signature, expected_signature)
|
||||
|
||||
|
||||
async def verify_external_signature(
|
||||
x_signature: Optional[str] = Header(None, alias="X-Signature")
|
||||
):
|
||||
"""
|
||||
FastAPI dependency to verify external API signature.
|
||||
|
||||
Raises:
|
||||
HTTPException: If signature is missing or invalid
|
||||
"""
|
||||
if not x_signature:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Missing X-Signature header"
|
||||
)
|
||||
|
||||
# Note: We'll need to access the raw request body in the endpoint
|
||||
# This dependency just validates the header exists
|
||||
# Actual signature verification happens in the endpoint
|
||||
return x_signature
|
||||
@@ -3,12 +3,12 @@ from typing import Optional, Union, Any
|
||||
|
||||
from jose import jwt
|
||||
from passlib.context import CryptContext
|
||||
from config import settings
|
||||
|
||||
# Настройки безопасности (лучше вынести в config/env, но для старта здесь)
|
||||
# SECRET_KEY должен быть сложным и секретным в продакшене!
|
||||
SECRET_KEY = "CHANGE_ME_TO_A_SUPER_SECRET_KEY"
|
||||
ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 30 * 24 * 60 # 30 дней, например
|
||||
# Настройки безопасности берутся из config.py
|
||||
SECRET_KEY = settings.SECRET_KEY
|
||||
ALGORITHM = settings.ALGORITHM
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
|
||||
pwd_context = CryptContext(schemes=["argon2"], deprecated="auto")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user