Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 32ff77e04b | |||
| d1f67c773f | |||
| c63b51ef75 |
33
.context.md
33
.context.md
@@ -1,33 +0,0 @@
|
||||
# Project Context: AI Char Bot
|
||||
|
||||
## Overview
|
||||
Python backend project using FastAPI and MongoDB (Motor).
|
||||
Root: `/Users/xds/develop/py projects/ai-char-bot`
|
||||
|
||||
## Architecture
|
||||
- **API Layer**: `api/endpoints` (FastAPI routers).
|
||||
- **Service Layer**: `api/service` (Business logic).
|
||||
- **Data Layer**: `repos` (DAOs and Repositories).
|
||||
- **Models**: `models` (Domain models) and `api/models` (Request/Response DTOs).
|
||||
- **Adapters**: `adapters` (External services like S3, Google Gemini).
|
||||
|
||||
## Coding Standards & Preferences
|
||||
- **Type Hinting**: Use `Type | None` instead of `Optional[Type]` (Python 3.10+ style).
|
||||
- **Async/Await**: Extensive use of `asyncio` and asynchronous DB drivers.
|
||||
- **Error Handling**:
|
||||
- Repositories should return `None` if an entity is not found (e.g., `toggle_like`).
|
||||
- Services/Routers handle `HTTPException`.
|
||||
|
||||
## Key Features & Implementation Details
|
||||
- **Generations**:
|
||||
- Managed by `GenerationService` and `GenerationRepo`.
|
||||
- `toggle_like` returns `bool | None` (True=Liked, False=Unliked, None=Not Found).
|
||||
- `get_generations` requires `current_user_id` to correctly calculate `is_liked`.
|
||||
- **Ideas**:
|
||||
- Managed by `IdeaService` and `IdeaRepo`.
|
||||
- Can have linked generations.
|
||||
- When fetching generations for an idea, ensure `current_user_id` is passed to `GenerationService`.
|
||||
|
||||
## Recent Changes
|
||||
- Refactored `toggle_like` to handle non-existent generations and return `bool | None`.
|
||||
- Updated `IdeaRouter` to pass `current_user_id` when fetching generations to ensure `is_liked` flag is correct.
|
||||
2
.env
2
.env
@@ -9,3 +9,5 @@ MINIO_SECRET_KEY=SuperSecretPassword123!
|
||||
MINIO_BUCKET=ai-char
|
||||
MODE=production
|
||||
EXTERNAL_API_SECRET=Gt9TyQ8OAYhcELh2YCbKjdHLflZGufKHJZcG338MQDW
|
||||
KLING_ACCESS_KEY=AngRfYYeLhPQB3pmr9CpHfgHPCrmeeM4
|
||||
KLING_SECRET_KEY=ndJfyayfQgbg4bMnE49yHnkACPChKMp4
|
||||
@@ -1,33 +0,0 @@
|
||||
# Project Context: AI Char Bot
|
||||
|
||||
## Overview
|
||||
Python backend project using FastAPI and MongoDB (Motor).
|
||||
Root: `/Users/xds/develop/py projects/ai-char-bot`
|
||||
|
||||
## Architecture
|
||||
- **API Layer**: `api/endpoints` (FastAPI routers).
|
||||
- **Service Layer**: `api/service` (Business logic).
|
||||
- **Data Layer**: `repos` (DAOs and Repositories).
|
||||
- **Models**: `models` (Domain models) and `api/models` (Request/Response DTOs).
|
||||
- **Adapters**: `adapters` (External services like S3, Google Gemini).
|
||||
|
||||
## Coding Standards & Preferences
|
||||
- **Type Hinting**: Use `Type | None` instead of `Optional[Type]` (Python 3.10+ style).
|
||||
- **Async/Await**: Extensive use of `asyncio` and asynchronous DB drivers.
|
||||
- **Error Handling**:
|
||||
- Repositories should return `None` if an entity is not found (e.g., `toggle_like`).
|
||||
- Services/Routers handle `HTTPException`.
|
||||
|
||||
## Key Features & Implementation Details
|
||||
- **Generations**:
|
||||
- Managed by `GenerationService` and `GenerationRepo`.
|
||||
- `toggle_like` returns `bool | None` (True=Liked, False=Unliked, None=Not Found).
|
||||
- `get_generations` requires `current_user_id` to correctly calculate `is_liked`.
|
||||
- **Ideas**:
|
||||
- Managed by `IdeaService` and `IdeaRepo`.
|
||||
- Can have linked generations.
|
||||
- When fetching generations for an idea, ensure `current_user_id` is passed to `GenerationService`.
|
||||
|
||||
## Recent Changes
|
||||
- Refactored `toggle_like` to handle non-existent generations and return `bool | None`.
|
||||
- Updated `IdeaRouter` to pass `current_user_id` when fetching generations to ensure `is_liked` flag is correct.
|
||||
15
.gitignore
vendored
15
.gitignore
vendored
@@ -9,18 +9,3 @@ minio_backup.tar.gz
|
||||
.idea
|
||||
.venv
|
||||
.vscode
|
||||
.vscode/launch.json
|
||||
middlewares/__pycache__/
|
||||
middlewares/*.pyc
|
||||
api/__pycache__/
|
||||
api/*.pyc
|
||||
repos/__pycache__/
|
||||
repos/*.pyc
|
||||
adapters/__pycache__/
|
||||
adapters/*.pyc
|
||||
services/__pycache__/
|
||||
services/*.pyc
|
||||
utils/__pycache__/
|
||||
utils/*.pyc
|
||||
.vscode/launch.json
|
||||
repos/__pycache__/assets_repo.cpython-313.pyc
|
||||
|
||||
25
.vscode/launch.json
vendored
25
.vscode/launch.json
vendored
@@ -16,6 +16,31 @@
|
||||
],
|
||||
"jinja": true,
|
||||
"justMyCode": true
|
||||
},
|
||||
{
|
||||
"name": "Python: Current File",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "${file}",
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": true,
|
||||
"env": {
|
||||
"PYTHONPATH": "${workspaceFolder}"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Debug Tests: Current File",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "pytest",
|
||||
"args": [
|
||||
"${file}"
|
||||
],
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": true,
|
||||
"env": {
|
||||
"PYTHONPATH": "${workspaceFolder}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -8,7 +8,7 @@ from google import genai
|
||||
from google.genai import types
|
||||
|
||||
from adapters.Exception import GoogleGenerationException
|
||||
from models.enums import AspectRatios, Quality, TextModel, ImageModel
|
||||
from models.enums import AspectRatios, Quality
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -19,37 +19,36 @@ class GoogleAdapter:
|
||||
raise ValueError("API Key for Gemini is missing")
|
||||
self.client = genai.Client(api_key=api_key)
|
||||
|
||||
def _prepare_contents(self, prompt: str, images_list: List[bytes] | None = None) -> tuple:
|
||||
"""Вспомогательный метод для подготовки контента (текст + картинки).
|
||||
Returns (contents, opened_images) — caller MUST close opened_images after use."""
|
||||
contents : list [Any]= [prompt]
|
||||
opened_images = []
|
||||
# Константы моделей
|
||||
self.TEXT_MODEL = "gemini-3-pro-preview"
|
||||
self.IMAGE_MODEL = "gemini-3-pro-image-preview"
|
||||
|
||||
def _prepare_contents(self, prompt: str, images_list: List[bytes] = None) -> list:
|
||||
"""Вспомогательный метод для подготовки контента (текст + картинки)"""
|
||||
contents = [prompt]
|
||||
if images_list:
|
||||
logger.info(f"Preparing content with {len(images_list)} images")
|
||||
for img_bytes in images_list:
|
||||
try:
|
||||
# Gemini API требует PIL Image на входе
|
||||
image = Image.open(io.BytesIO(img_bytes))
|
||||
contents.append(image)
|
||||
opened_images.append(image)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing input image: {e}")
|
||||
else:
|
||||
logger.info("Preparing content with no images")
|
||||
return contents, opened_images
|
||||
return contents
|
||||
|
||||
def generate_text(self, prompt: str, model: str = "gemini-3.1-pro-preview", images_list: List[bytes] | None = None) -> str:
|
||||
def generate_text(self, prompt: str, images_list: List[bytes] = None) -> str:
|
||||
"""
|
||||
Генерация текста (Чат или Vision).
|
||||
Возвращает строку с ответом.
|
||||
"""
|
||||
if model not in [m.value for m in TextModel]:
|
||||
raise ValueError(f"Invalid model for text generation: {model}. Expected one of: {[m.value for m in TextModel]}")
|
||||
|
||||
contents, opened_images = self._prepare_contents(prompt, images_list)
|
||||
logger.info(f"Generating text: {prompt} with model: {model}")
|
||||
contents = self._prepare_contents(prompt, images_list)
|
||||
logger.info(f"Generating text: {prompt}")
|
||||
try:
|
||||
response = self.client.models.generate_content(
|
||||
model=model,
|
||||
model=self.TEXT_MODEL,
|
||||
contents=contents,
|
||||
config=types.GenerateContentConfig(
|
||||
response_modalities=['TEXT'],
|
||||
@@ -69,27 +68,22 @@ class GoogleAdapter:
|
||||
except Exception as e:
|
||||
logger.error(f"Gemini Text API Error: {e}")
|
||||
raise GoogleGenerationException(f"Gemini Text API Error: {e}")
|
||||
finally:
|
||||
for img in opened_images:
|
||||
img.close()
|
||||
|
||||
def generate_image(self, prompt: str, aspect_ratio: AspectRatios, quality: Quality, model: str = "gemini-3-pro-image-preview", images_list: List[bytes] | None = None, ) -> Tuple[List[io.BytesIO], Dict[str, Any]]:
|
||||
def generate_image(self, prompt: str, aspect_ratio: AspectRatios, quality: Quality, images_list: List[bytes] = None, ) -> Tuple[List[io.BytesIO], Dict[str, Any]]:
|
||||
"""
|
||||
Генерация изображений (Text-to-Image или Image-to-Image).
|
||||
Возвращает список байтовых потоков (готовых к отправке).
|
||||
"""
|
||||
if model not in [m.value for m in ImageModel]:
|
||||
raise ValueError(f"Invalid model for image generation: {model}. Expected one of: {[m.value for m in ImageModel]}")
|
||||
|
||||
contents, opened_images = self._prepare_contents(prompt, images_list)
|
||||
logger.info(f"Generating image. Prompt length: {len(prompt)}, Ratio: {aspect_ratio}, Quality: {quality}, Model: {model}")
|
||||
contents = self._prepare_contents(prompt, images_list)
|
||||
logger.info(f"Generating image. Prompt length: {len(prompt)}, Ratio: {aspect_ratio}, Quality: {quality}")
|
||||
|
||||
start_time = datetime.now()
|
||||
token_usage = 0
|
||||
|
||||
try:
|
||||
response = self.client.models.generate_content(
|
||||
model=model,
|
||||
model=self.IMAGE_MODEL,
|
||||
contents=contents,
|
||||
config=types.GenerateContentConfig(
|
||||
response_modalities=['IMAGE'],
|
||||
@@ -107,20 +101,8 @@ class GoogleAdapter:
|
||||
if response.usage_metadata:
|
||||
token_usage = response.usage_metadata.total_token_count
|
||||
|
||||
# Check prompt-level block (e.g. PROHIBITED_CONTENT) — no candidates in this case
|
||||
if response.prompt_feedback and response.prompt_feedback.block_reason:
|
||||
raise GoogleGenerationException(
|
||||
f"Generation blocked at prompt level: {response.prompt_feedback.block_reason.value}"
|
||||
)
|
||||
|
||||
# Check candidate-level block
|
||||
if response.parts is None:
|
||||
response_reason = (
|
||||
response.candidates[0].finish_reason
|
||||
if response.candidates and len(response.candidates) > 0
|
||||
else "Unknown"
|
||||
)
|
||||
raise GoogleGenerationException(f"Generation blocked: {response_reason}")
|
||||
if response.parts is None and response.candidates[0].finish_reason is not None:
|
||||
raise GoogleGenerationException(f"Generation blocked in cause of {response.candidates[0].finish_reason.value}")
|
||||
|
||||
generated_images = []
|
||||
|
||||
@@ -131,9 +113,7 @@ class GoogleAdapter:
|
||||
try:
|
||||
# 1. Берем сырые байты
|
||||
raw_data = part.inline_data.data
|
||||
if raw_data is None:
|
||||
raise GoogleGenerationException("Generation returned no data")
|
||||
byte_arr : io.BytesIO = io.BytesIO(raw_data)
|
||||
byte_arr = io.BytesIO(raw_data)
|
||||
|
||||
# 2. Нейминг (формально, для TG)
|
||||
timestamp = datetime.now().timestamp()
|
||||
@@ -168,7 +148,3 @@ class GoogleAdapter:
|
||||
except Exception as e:
|
||||
logger.error(f"Gemini Image API Error: {e}")
|
||||
raise GoogleGenerationException(f"Gemini Image API Error: {e}")
|
||||
finally:
|
||||
for img in opened_images:
|
||||
img.close()
|
||||
del contents
|
||||
165
adapters/kling_adapter.py
Normal file
165
adapters/kling_adapter.py
Normal file
@@ -0,0 +1,165 @@
|
||||
import logging
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
import httpx
|
||||
import jwt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
KLING_API_BASE = "https://api.klingai.com"
|
||||
|
||||
|
||||
class KlingApiException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class KlingAdapter:
|
||||
def __init__(self, access_key: str, secret_key: str):
|
||||
if not access_key or not secret_key:
|
||||
raise ValueError("Kling API credentials are missing")
|
||||
self.access_key = access_key
|
||||
self.secret_key = secret_key
|
||||
|
||||
def _generate_token(self) -> str:
|
||||
"""Generate a JWT token for Kling API authentication."""
|
||||
now = int(time.time())
|
||||
payload = {
|
||||
"iss": self.access_key,
|
||||
"exp": now + 1800, # 30 minutes
|
||||
"iat": now - 5, # небольшой запас назад
|
||||
"nbf": now - 5,
|
||||
}
|
||||
return jwt.encode(payload, self.secret_key, algorithm="HS256",
|
||||
headers={"typ": "JWT", "alg": "HS256"})
|
||||
|
||||
def _headers(self) -> dict:
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self._generate_token()}"
|
||||
}
|
||||
|
||||
async def create_video_task(
|
||||
self,
|
||||
image_url: str,
|
||||
prompt: str = "",
|
||||
negative_prompt: str = "",
|
||||
model_name: str = "kling-v2-6",
|
||||
duration: int = 5,
|
||||
mode: str = "std",
|
||||
cfg_scale: float = 0.5,
|
||||
aspect_ratio: str = "16:9",
|
||||
callback_url: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create an image-to-video generation task.
|
||||
Returns the full task data dict including task_id.
|
||||
"""
|
||||
body: Dict[str, Any] = {
|
||||
"model_name": model_name,
|
||||
"image": image_url,
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"duration": str(duration),
|
||||
"mode": mode,
|
||||
"cfg_scale": cfg_scale,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
}
|
||||
if callback_url:
|
||||
body["callback_url"] = callback_url
|
||||
|
||||
logger.info(f"Creating Kling video task. Model: {model_name}, Duration: {duration}s, Mode: {mode}")
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.post(
|
||||
f"{KLING_API_BASE}/v1/videos/image2video",
|
||||
headers=self._headers(),
|
||||
json=body,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
logger.info(f"Kling create task response: code={data.get('code')}, message={data.get('message')}")
|
||||
|
||||
if response.status_code != 200 or data.get("code") != 0:
|
||||
error_msg = data.get("message", "Unknown Kling API error")
|
||||
raise KlingApiException(f"Failed to create video task: {error_msg} (code={data.get('code')})")
|
||||
|
||||
task_data = data.get("data", {})
|
||||
task_id = task_data.get("task_id")
|
||||
if not task_id:
|
||||
raise KlingApiException("No task_id returned from Kling API")
|
||||
|
||||
logger.info(f"Kling video task created: task_id={task_id}")
|
||||
return task_data
|
||||
|
||||
async def get_task_status(self, task_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Query the status of a video generation task.
|
||||
Returns the full task data dict.
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(
|
||||
f"{KLING_API_BASE}/v1/videos/image2video/{task_id}",
|
||||
headers=self._headers(),
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
|
||||
if response.status_code != 200 or data.get("code") != 0:
|
||||
error_msg = data.get("message", "Unknown error")
|
||||
raise KlingApiException(f"Failed to query task {task_id}: {error_msg}")
|
||||
|
||||
return data.get("data", {})
|
||||
|
||||
async def wait_for_completion(
|
||||
self,
|
||||
task_id: str,
|
||||
poll_interval: int = 10,
|
||||
timeout: int = 600,
|
||||
progress_callback=None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Poll the task status until completion.
|
||||
|
||||
Args:
|
||||
task_id: Kling task ID
|
||||
poll_interval: seconds between polls
|
||||
timeout: max seconds to wait
|
||||
progress_callback: async callable(progress_pct: int) to report progress
|
||||
|
||||
Returns:
|
||||
Final task data dict with video URL on success.
|
||||
|
||||
Raises:
|
||||
KlingApiException on failure or timeout.
|
||||
"""
|
||||
start = time.time()
|
||||
attempt = 0
|
||||
|
||||
while True:
|
||||
elapsed = time.time() - start
|
||||
if elapsed > timeout:
|
||||
raise KlingApiException(f"Video generation timed out after {timeout}s for task {task_id}")
|
||||
|
||||
task_data = await self.get_task_status(task_id)
|
||||
status = task_data.get("task_status")
|
||||
|
||||
logger.info(f"Kling task {task_id}: status={status}, elapsed={elapsed:.0f}s")
|
||||
|
||||
if status == "succeed":
|
||||
logger.info(f"Kling task {task_id} completed successfully")
|
||||
return task_data
|
||||
|
||||
if status == "failed":
|
||||
fail_reason = task_data.get("task_status_msg", "Unknown failure")
|
||||
raise KlingApiException(f"Video generation failed: {fail_reason}")
|
||||
|
||||
# Report progress estimate (linear approximation based on typical time)
|
||||
if progress_callback:
|
||||
# Estimate: typical gen is ~120s, cap at 90%
|
||||
estimated_progress = min(int((elapsed / 120) * 90), 90)
|
||||
attempt += 1
|
||||
await progress_callback(estimated_progress)
|
||||
|
||||
await asyncio.sleep(poll_interval)
|
||||
@@ -1,5 +1,5 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional, BinaryIO, AsyncGenerator
|
||||
from typing import Optional, BinaryIO
|
||||
import aioboto3
|
||||
from botocore.exceptions import ClientError
|
||||
import os
|
||||
@@ -18,7 +18,7 @@ class S3Adapter:
|
||||
|
||||
@asynccontextmanager
|
||||
async def _get_client(self):
|
||||
async with self.session.client( # type: ignore[reportGeneralTypeIssues]
|
||||
async with self.session.client(
|
||||
"s3",
|
||||
endpoint_url=self.endpoint_url,
|
||||
aws_access_key_id=self.aws_access_key_id,
|
||||
@@ -56,37 +56,6 @@ class S3Adapter:
|
||||
print(f"Error downloading from S3: {e}")
|
||||
return None
|
||||
|
||||
async def get_file_size(self, object_name: str) -> Optional[int]:
|
||||
"""Returns the size of the file in bytes."""
|
||||
try:
|
||||
async with self._get_client() as client:
|
||||
response = await client.head_object(Bucket=self.bucket_name, Key=object_name)
|
||||
return response['ContentLength']
|
||||
except ClientError as e:
|
||||
print(f"Error getting file size from S3: {e}")
|
||||
return None
|
||||
|
||||
async def stream_file(self, object_name: str, range_header: Optional[str] = None, chunk_size: int = 65536) -> AsyncGenerator[bytes, None]:
|
||||
"""Streams a file from S3 yielding chunks. Memory-efficient for large files."""
|
||||
try:
|
||||
async with self._get_client() as client:
|
||||
args = {'Bucket': self.bucket_name, 'Key': object_name}
|
||||
if range_header:
|
||||
args['Range'] = range_header
|
||||
|
||||
response = await client.get_object(**args)
|
||||
# aioboto3 Body is an aiohttp StreamReader wrapper
|
||||
body = response['Body']
|
||||
|
||||
while True:
|
||||
chunk = await body.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
yield chunk
|
||||
except ClientError as e:
|
||||
print(f"Error streaming from S3: {e}")
|
||||
return
|
||||
|
||||
async def delete_file(self, object_name: str):
|
||||
"""Deletes a file from S3."""
|
||||
try:
|
||||
|
||||
87
aiws.py
87
aiws.py
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from aiogram import Bot, Dispatcher, Router, F
|
||||
@@ -8,6 +9,7 @@ from aiogram.enums import ParseMode
|
||||
from aiogram.filters import CommandStart, Command
|
||||
from aiogram.types import Message
|
||||
from aiogram.fsm.storage.mongo import MongoStorage
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import FastAPI
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from prometheus_client import Info
|
||||
@@ -15,8 +17,8 @@ from starlette.middleware.cors import CORSMiddleware
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
|
||||
# --- ИМПОРТЫ ПРОЕКТА ---
|
||||
from config import settings
|
||||
from adapters.google_adapter import GoogleAdapter
|
||||
from adapters.kling_adapter import KlingAdapter
|
||||
from adapters.s3_adapter import S3Adapter
|
||||
from api.service.generation_service import GenerationService
|
||||
from api.service.album_service import AlbumService
|
||||
@@ -42,21 +44,17 @@ from api.endpoints.auth import router as api_auth_router
|
||||
from api.endpoints.admin import router as api_admin_router
|
||||
from api.endpoints.album_router import router as api_album_router
|
||||
from api.endpoints.project_router import router as project_api_router
|
||||
from api.endpoints.idea_router import router as idea_api_router
|
||||
from api.endpoints.post_router import router as post_api_router
|
||||
from api.endpoints.environment_router import router as environment_api_router
|
||||
from api.endpoints.inspiration_router import router as inspiration_api_router
|
||||
|
||||
load_dotenv()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# --- КОНФИГУРАЦИЯ ---
|
||||
# Настройки теперь берутся из config.py
|
||||
BOT_TOKEN = settings.BOT_TOKEN
|
||||
GEMINI_API_KEY = settings.GEMINI_API_KEY
|
||||
BOT_TOKEN = os.getenv("BOT_TOKEN")
|
||||
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
|
||||
|
||||
MONGO_HOST = settings.MONGO_HOST
|
||||
DB_NAME = settings.DB_NAME
|
||||
ADMIN_ID = settings.ADMIN_ID
|
||||
MONGO_HOST = os.getenv("MONGO_HOST") # Например: mongodb://localhost:27017
|
||||
DB_NAME = os.getenv("DB_NAME", "my_bot_db") # Имя базы данных
|
||||
ADMIN_ID = int(os.getenv("ADMIN_ID", 0))
|
||||
|
||||
|
||||
def setup_logging():
|
||||
@@ -66,8 +64,6 @@ def setup_logging():
|
||||
|
||||
|
||||
# --- ИНИЦИАЛИЗАЦИЯ ЗАВИСИМОСТЕЙ ---
|
||||
if BOT_TOKEN is None:
|
||||
raise ValueError("BOT_TOKEN is not set")
|
||||
bot = Bot(token=BOT_TOKEN, default=DefaultBotProperties(parse_mode=ParseMode.HTML))
|
||||
|
||||
# Клиент БД создаем глобально, чтобы он был доступен и боту (Storage), и API
|
||||
@@ -80,19 +76,26 @@ char_repo = CharacterRepo(mongo_client)
|
||||
|
||||
# S3 Adapter
|
||||
s3_adapter = S3Adapter(
|
||||
endpoint_url=settings.MINIO_ENDPOINT,
|
||||
aws_access_key_id=settings.MINIO_ACCESS_KEY,
|
||||
aws_secret_access_key=settings.MINIO_SECRET_KEY,
|
||||
bucket_name=settings.MINIO_BUCKET
|
||||
endpoint_url=os.getenv("MINIO_ENDPOINT", "http://31.59.58.220:9000"),
|
||||
aws_access_key_id=os.getenv("MINIO_ACCESS_KEY", "minioadmin"),
|
||||
aws_secret_access_key=os.getenv("MINIO_SECRET_KEY", "minioadmin"),
|
||||
bucket_name=os.getenv("MINIO_BUCKET", "ai-char")
|
||||
)
|
||||
|
||||
dao = DAO(mongo_client, s3_adapter) # Главный DAO для бота
|
||||
if GEMINI_API_KEY is None:
|
||||
raise ValueError("GEMINI_API_KEY is not set")
|
||||
gemini = GoogleAdapter(api_key=GEMINI_API_KEY)
|
||||
if bot is None:
|
||||
raise ValueError("bot is not set")
|
||||
generation_service = GenerationService(dao=dao, gemini=gemini, s3_adapter=s3_adapter, bot=bot)
|
||||
|
||||
# Kling Adapter (optional, for video generation)
|
||||
kling_access_key = os.getenv("KLING_ACCESS_KEY", "")
|
||||
kling_secret_key = os.getenv("KLING_SECRET_KEY", "")
|
||||
kling_adapter = None
|
||||
if kling_access_key and kling_secret_key:
|
||||
kling_adapter = KlingAdapter(access_key=kling_access_key, secret_key=kling_secret_key)
|
||||
logger.info("Kling adapter initialized")
|
||||
else:
|
||||
logger.warning("KLING_ACCESS_KEY / KLING_SECRET_KEY not set — video generation disabled")
|
||||
|
||||
generation_service = GenerationService(dao, gemini, s3_adapter, bot, kling_adapter)
|
||||
album_service = AlbumService(dao)
|
||||
|
||||
# Dispatcher
|
||||
@@ -129,18 +132,6 @@ assets_router.message.middleware(AuthMiddleware(repo=users_repo, admin_id=ADMIN_
|
||||
gen_router.message.middleware(AlbumMiddleware(latency=0.8))
|
||||
|
||||
|
||||
async def start_scheduler(service: GenerationService):
|
||||
while True:
|
||||
try:
|
||||
logger.info("Running scheduler for stacked generation killing")
|
||||
await service.cleanup_stale_generations()
|
||||
await service.cleanup_old_data(days=14)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Scheduler error: {e}")
|
||||
await asyncio.sleep(60) # Check every 60 seconds
|
||||
|
||||
# --- LIFESPAN (Запуск FastAPI + Bot) ---
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
@@ -159,6 +150,7 @@ async def lifespan(app: FastAPI):
|
||||
app.state.gemini_client = gemini
|
||||
app.state.bot = bot
|
||||
app.state.s3_adapter = s3_adapter
|
||||
app.state.kling_adapter = kling_adapter
|
||||
app.state.album_service = album_service
|
||||
app.state.users_repo = users_repo # Добавляем репозиторий в state
|
||||
|
||||
@@ -172,28 +164,17 @@ async def lifespan(app: FastAPI):
|
||||
# )
|
||||
# print("🤖 Bot polling started")
|
||||
|
||||
# 3. ЗАПУСК ШЕДУЛЕРА
|
||||
scheduler_task = asyncio.create_task(start_scheduler(generation_service))
|
||||
print("⏰ Scheduler started")
|
||||
|
||||
yield
|
||||
|
||||
# --- SHUTDOWN ---
|
||||
print("🛑 Shutting down...")
|
||||
|
||||
# 4. Остановка шедулера
|
||||
scheduler_task.cancel()
|
||||
try:
|
||||
await scheduler_task
|
||||
except asyncio.CancelledError:
|
||||
print("⏰ Scheduler stopped")
|
||||
|
||||
# 3. Остановка бота
|
||||
# polling_task.cancel()
|
||||
# try:
|
||||
# await polling_task
|
||||
# except asyncio.CancelledError:
|
||||
# print("🤖 Bot polling stopped")
|
||||
polling_task.cancel()
|
||||
try:
|
||||
await polling_task
|
||||
except asyncio.CancelledError:
|
||||
print("🤖 Bot polling stopped")
|
||||
|
||||
# 4. Отключение БД
|
||||
# Обычно Motor закрывать не обязательно при выходе, но хорошим тоном считается
|
||||
@@ -220,10 +201,6 @@ app.include_router(api_char_router)
|
||||
app.include_router(api_gen_router)
|
||||
app.include_router(api_album_router)
|
||||
app.include_router(project_api_router)
|
||||
app.include_router(idea_api_router)
|
||||
app.include_router(post_api_router)
|
||||
app.include_router(environment_api_router)
|
||||
app.include_router(inspiration_api_router)
|
||||
|
||||
# Prometheus Metrics (Instrument after all routers are added)
|
||||
Instrumentator(
|
||||
@@ -262,7 +239,7 @@ if __name__ == "__main__":
|
||||
async def main():
|
||||
# Создаем конфигурацию uvicorn вручную
|
||||
# loop="asyncio" заставляет использовать стандартный цикл
|
||||
config = uvicorn.Config(app, host="0.0.0.0", port=8090, loop="asyncio", timeout_keep_alive=120)
|
||||
config = uvicorn.Config(app, host="0.0.0.0", port=8090, loop="asyncio", timeout_keep_alive=120, env_file=".env.development")
|
||||
server = uvicorn.Server(config)
|
||||
|
||||
# Запускаем сервер (lifespan запустится внутри)
|
||||
|
||||
Binary file not shown.
Binary file not shown.
@@ -3,9 +3,9 @@ from fastapi import Request, Depends
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
|
||||
from adapters.google_adapter import GoogleAdapter
|
||||
from adapters.kling_adapter import KlingAdapter
|
||||
from api.service.generation_service import GenerationService
|
||||
from repos.dao import DAO
|
||||
from api.service.album_service import AlbumService
|
||||
|
||||
|
||||
# ... ваши импорты ...
|
||||
@@ -37,34 +37,20 @@ def get_dao(
|
||||
# так что DAO создастся один раз за запрос.
|
||||
return DAO(mongo_client, s3_adapter)
|
||||
|
||||
def get_kling_adapter(request: Request) -> Optional[KlingAdapter]:
|
||||
return request.app.state.kling_adapter
|
||||
|
||||
# Провайдер сервиса (собирается из DAO и Gemini)
|
||||
def get_generation_service(
|
||||
dao: DAO = Depends(get_dao),
|
||||
gemini: GoogleAdapter = Depends(get_gemini_client),
|
||||
s3_adapter: S3Adapter = Depends(get_s3_adapter),
|
||||
bot: Bot = Depends(get_bot_client),
|
||||
kling_adapter: Optional[KlingAdapter] = Depends(get_kling_adapter),
|
||||
) -> GenerationService:
|
||||
return GenerationService(dao, gemini, s3_adapter, bot)
|
||||
|
||||
from api.service.idea_service import IdeaService
|
||||
|
||||
def get_idea_service(dao: DAO = Depends(get_dao)) -> IdeaService:
|
||||
return IdeaService(dao)
|
||||
return GenerationService(dao, gemini, s3_adapter, bot, kling_adapter=kling_adapter)
|
||||
|
||||
from fastapi import Header
|
||||
|
||||
async def get_project_id(x_project_id: Optional[str] = Header(None, alias="X-Project-ID")) -> Optional[str]:
|
||||
return x_project_id
|
||||
|
||||
async def get_album_service(dao: DAO = Depends(get_dao)) -> AlbumService:
|
||||
return AlbumService(dao)
|
||||
|
||||
from api.service.post_service import PostService
|
||||
|
||||
def get_post_service(dao: DAO = Depends(get_dao)) -> PostService:
|
||||
return PostService(dao)
|
||||
|
||||
from api.service.inspiration_service import InspirationService
|
||||
|
||||
def get_inspiration_service(dao: DAO = Depends(get_dao), s3_adapter: S3Adapter = Depends(get_s3_adapter)) -> InspirationService:
|
||||
return InspirationService(dao, s3_adapter)
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,12 +1,10 @@
|
||||
from typing import Annotated
|
||||
from typing import Annotated, List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from pydantic import BaseModel
|
||||
|
||||
from repos.user_repo import UsersRepo, UserStatus
|
||||
from api.dependency import get_dao
|
||||
from repos.dao import DAO
|
||||
from utils.security import verify_password, create_access_token, ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, SECRET_KEY
|
||||
from jose import JWTError, jwt
|
||||
from starlette.requests import Request
|
||||
@@ -25,7 +23,7 @@ async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)], repo:
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
username: str | None = payload.get("sub")
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise credentials_exception
|
||||
except JWTError:
|
||||
@@ -54,7 +52,7 @@ class UserResponse(BaseModel):
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
@router.get("/approvals", response_model=list[UserResponse])
|
||||
@router.get("/approvals", response_model=List[UserResponse])
|
||||
async def list_pending_users(
|
||||
admin: Annotated[dict, Depends(get_current_admin)],
|
||||
repo: Annotated[UsersRepo, Depends(get_users_repo)]
|
||||
|
||||
@@ -1,29 +1,27 @@
|
||||
from fastapi import APIRouter, HTTPException, status, Request
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from pydantic import BaseModel
|
||||
|
||||
from api.models.GenerationRequest import GenerationResponse
|
||||
from models.Album import Album
|
||||
from repos.dao import DAO
|
||||
from api.dependency import get_album_service
|
||||
from api.service.album_service import AlbumService
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/albums", tags=["Albums"])
|
||||
|
||||
class AlbumCreateRequest(BaseModel):
|
||||
name: str
|
||||
description: str | None = None
|
||||
description: Optional[str] = None
|
||||
|
||||
class AlbumUpdateRequest(BaseModel):
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
|
||||
class AlbumResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str | None = None
|
||||
generation_ids: list[str] = []
|
||||
cover_asset_id: str | None = None # Not implemented yet
|
||||
description: Optional[str] = None
|
||||
generation_ids: List[str] = []
|
||||
cover_asset_id: Optional[str] = None # Not implemented yet
|
||||
|
||||
@router.post("", response_model=AlbumResponse)
|
||||
async def create_album(request: Request, album_in: AlbumCreateRequest):
|
||||
@@ -31,7 +29,7 @@ async def create_album(request: Request, album_in: AlbumCreateRequest):
|
||||
album = await service.create_album(name=album_in.name, description=album_in.description)
|
||||
return AlbumResponse(**album.model_dump())
|
||||
|
||||
@router.get("", response_model=list[AlbumResponse])
|
||||
@router.get("", response_model=List[AlbumResponse])
|
||||
async def get_albums(request: Request, limit: int = 10, offset: int = 0):
|
||||
service: AlbumService = request.app.state.album_service
|
||||
albums = await service.get_albums(limit=limit, offset=offset)
|
||||
@@ -76,7 +74,7 @@ async def remove_generation_from_album(request: Request, album_id: str, generati
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album or Generation not found")
|
||||
return {"status": "success"}
|
||||
|
||||
@router.get("/{album_id}/generations", response_model=list[GenerationResponse])
|
||||
@router.get("/{album_id}/generations", response_model=List[GenerationResponse])
|
||||
async def get_album_generations(request: Request, album_id: str, limit: int = 10, offset: int = 0):
|
||||
service: AlbumService = request.app.state.album_service
|
||||
generations = await service.get_generations_by_album(album_id, limit=limit, offset=offset)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
from aiogram.types import BufferedInputFile
|
||||
from bson import ObjectId
|
||||
@@ -9,10 +9,10 @@ from pymongo import MongoClient
|
||||
from starlette import status
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response, JSONResponse, StreamingResponse
|
||||
from starlette.responses import Response, JSONResponse
|
||||
|
||||
from adapters.s3_adapter import S3Adapter
|
||||
from api.models import AssetsResponse, AssetResponse
|
||||
from api.models.AssetDTO import AssetsResponse, AssetResponse
|
||||
from models.Asset import Asset, AssetType, AssetContentType
|
||||
from repos.dao import DAO
|
||||
from api.dependency import get_dao, get_mongo_client, get_s3_adapter
|
||||
@@ -33,100 +33,27 @@ async def get_asset(
|
||||
asset_id: str,
|
||||
request: Request,
|
||||
thumbnail: bool = False,
|
||||
dao: DAO = Depends(get_dao),
|
||||
s3_adapter: S3Adapter = Depends(get_s3_adapter),
|
||||
dao: DAO = Depends(get_dao)
|
||||
) -> Response:
|
||||
logger.debug(f"get_asset called for ID: {asset_id}, thumbnail={thumbnail}")
|
||||
# Загружаем только метаданные (без data/thumbnail bytes)
|
||||
asset = await dao.assets.get_asset(asset_id, with_data=False)
|
||||
asset = await dao.assets.get_asset(asset_id)
|
||||
# 2. Проверка на существование
|
||||
if not asset:
|
||||
raise HTTPException(status_code=404, detail="Asset not found")
|
||||
|
||||
base_headers = {
|
||||
"Cache-Control": "public, max-age=31536000, immutable",
|
||||
"Accept-Ranges": "bytes"
|
||||
headers = {
|
||||
# Кэшировать на 1 год (31536000 сек)
|
||||
"Cache-Control": "public, max-age=31536000, immutable"
|
||||
}
|
||||
|
||||
# Thumbnail: маленький, можно грузить в RAM
|
||||
if thumbnail:
|
||||
if asset.minio_thumbnail_object_name and s3_adapter:
|
||||
thumb_bytes = await s3_adapter.get_file(asset.minio_thumbnail_object_name)
|
||||
if thumb_bytes:
|
||||
return Response(content=thumb_bytes, media_type="image/jpeg", headers=base_headers)
|
||||
# Fallback: thumbnail in DB
|
||||
if asset.thumbnail:
|
||||
return Response(content=asset.thumbnail, media_type="image/jpeg", headers=base_headers)
|
||||
# No thumbnail available — fall through to main content
|
||||
content = asset.data
|
||||
media_type = "image/png" # Default, or detect
|
||||
|
||||
# Main content: стримим из S3 без загрузки в RAM
|
||||
if asset.minio_object_name and s3_adapter:
|
||||
content_type = "image/png"
|
||||
if asset.content_type == AssetContentType.VIDEO:
|
||||
content_type = "video/mp4" # Or detect from extension if stored
|
||||
elif asset.content_type == AssetContentType.IMAGE:
|
||||
content_type = "image/png" # Default for images
|
||||
if thumbnail and asset.thumbnail:
|
||||
content = asset.thumbnail
|
||||
media_type = "image/jpeg"
|
||||
|
||||
# Better content type detection based on extension if possible, but for now this is okay
|
||||
if asset.minio_object_name.endswith(".mp4"):
|
||||
content_type = "video/mp4"
|
||||
elif asset.minio_object_name.endswith(".mov"):
|
||||
content_type = "video/quicktime"
|
||||
elif asset.minio_object_name.endswith(".jpg") or asset.minio_object_name.endswith(".jpeg"):
|
||||
content_type = "image/jpeg"
|
||||
|
||||
# Handle Range requests for video streaming
|
||||
range_header = request.headers.get("range")
|
||||
file_size = await s3_adapter.get_file_size(asset.minio_object_name)
|
||||
|
||||
if range_header and file_size:
|
||||
try:
|
||||
# Parse Range header: bytes=start-end
|
||||
byte_range = range_header.replace("bytes=", "")
|
||||
start_str, end_str = byte_range.split("-")
|
||||
start = int(start_str)
|
||||
end = int(end_str) if end_str else file_size - 1
|
||||
|
||||
# Validate range
|
||||
if start >= file_size:
|
||||
# 416 Range Not Satisfiable
|
||||
return Response(status_code=416, headers={"Content-Range": f"bytes */{file_size}"})
|
||||
|
||||
chunk_size = end - start + 1
|
||||
|
||||
headers = base_headers.copy()
|
||||
headers.update({
|
||||
"Content-Range": f"bytes {start}-{end}/{file_size}",
|
||||
"Content-Length": str(chunk_size),
|
||||
})
|
||||
|
||||
# Pass the exact range string to S3
|
||||
s3_range = f"bytes={start}-{end}"
|
||||
|
||||
return StreamingResponse(
|
||||
s3_adapter.stream_file(asset.minio_object_name, range_header=s3_range),
|
||||
status_code=206,
|
||||
headers=headers,
|
||||
media_type=content_type
|
||||
)
|
||||
except ValueError:
|
||||
pass # Fallback to full content if range parsing fails
|
||||
|
||||
# Full content response
|
||||
headers = base_headers.copy()
|
||||
if file_size:
|
||||
headers["Content-Length"] = str(file_size)
|
||||
|
||||
return StreamingResponse(
|
||||
s3_adapter.stream_file(asset.minio_object_name),
|
||||
media_type=content_type,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
# Fallback: data stored in DB (legacy)
|
||||
if asset.data:
|
||||
return Response(content=asset.data, media_type="image/png", headers=base_headers)
|
||||
|
||||
raise HTTPException(status_code=404, detail="Asset data not found")
|
||||
return Response(content=content, media_type=media_type, headers=headers)
|
||||
|
||||
@router.delete("/orphans", dependencies=[Depends(get_current_user)])
|
||||
async def delete_orphan_assets_from_minio(
|
||||
@@ -135,22 +62,22 @@ async def delete_orphan_assets_from_minio(
|
||||
*,
|
||||
assets_collection: str = "assets",
|
||||
generations_collection: str = "generations",
|
||||
asset_type: str | None = "generated",
|
||||
project_id: str | None = None,
|
||||
asset_type: Optional[str] = "generated",
|
||||
project_id: Optional[str] = None,
|
||||
dry_run: bool = True,
|
||||
mark_assets_deleted: bool = False,
|
||||
batch_size: int = 500,
|
||||
) -> dict[str, Any]:
|
||||
) -> Dict[str, Any]:
|
||||
db = mongo['bot_db'] # БД уже выбрана в get_mongo_client
|
||||
assets = db[assets_collection]
|
||||
|
||||
match_assets: dict[str, Any] = {}
|
||||
match_assets: Dict[str, Any] = {}
|
||||
if asset_type is not None:
|
||||
match_assets["type"] = asset_type
|
||||
if project_id is not None:
|
||||
match_assets["project_id"] = project_id
|
||||
|
||||
pipeline: list[dict[str, Any]] = [
|
||||
pipeline: List[Dict[str, Any]] = [
|
||||
{"$match": match_assets} if match_assets else {"$match": {}},
|
||||
{
|
||||
"$lookup": {
|
||||
@@ -192,8 +119,8 @@ async def delete_orphan_assets_from_minio(
|
||||
|
||||
deleted_objects = 0
|
||||
deleted_assets = 0
|
||||
errors: list[dict[str, Any]] = []
|
||||
orphan_asset_ids: list[ObjectId] = []
|
||||
errors: List[Dict[str, Any]] = []
|
||||
orphan_asset_ids: List[ObjectId] = []
|
||||
|
||||
async for asset in cursor:
|
||||
aid = asset["_id"]
|
||||
@@ -259,7 +186,7 @@ async def delete_asset(
|
||||
|
||||
|
||||
@router.get("", dependencies=[Depends(get_current_user)])
|
||||
async def get_assets(request: Request, dao: DAO = Depends(get_dao), type: str | None = None, limit: int = 10, offset: int = 0, current_user: dict = Depends(get_current_user), project_id: str | None = Depends(get_project_id)) -> AssetsResponse:
|
||||
async def get_assets(request: Request, dao: DAO = Depends(get_dao), type: Optional[str] = None, limit: int = 10, offset: int = 0, current_user: dict = Depends(get_current_user), project_id: Optional[str] = Depends(get_project_id)) -> AssetsResponse:
|
||||
logger.info(f"get_assets called. Limit: {limit}, Offset: {offset}")
|
||||
|
||||
user_id_filter = current_user["id"]
|
||||
@@ -286,10 +213,10 @@ async def get_assets(request: Request, dao: DAO = Depends(get_dao), type: str |
|
||||
@router.post("/upload", response_model=AssetResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def upload_asset(
|
||||
file: UploadFile = File(...),
|
||||
linked_char_id: str | None = Form(None),
|
||||
linked_char_id: Optional[str] = Form(None),
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
project_id: str | None = Depends(get_project_id)
|
||||
project_id: Optional[str] = Depends(get_project_id)
|
||||
):
|
||||
logger.info(f"upload_asset called. Filename: {file.filename}, ContentType: {file.content_type}, LinkedCharId: {linked_char_id}")
|
||||
if not file.content_type:
|
||||
@@ -332,7 +259,8 @@ async def upload_asset(
|
||||
type=asset.type.value if hasattr(asset.type, "value") else asset.type,
|
||||
content_type=asset.content_type.value if hasattr(asset.content_type, "value") else asset.content_type,
|
||||
linked_char_id=asset.linked_char_id,
|
||||
created_at=asset.created_at
|
||||
created_at=asset.created_at,
|
||||
url=asset.url
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
from typing import Any, Coroutine
|
||||
from typing import List, Any, Coroutine, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
|
||||
from api.models import AssetsResponse, AssetResponse
|
||||
from api.models import GenerationRequest, GenerationResponse
|
||||
from api.models.AssetDTO import AssetsResponse, AssetResponse
|
||||
from api.models.GenerationRequest import GenerationRequest, GenerationResponse
|
||||
from models.Asset import Asset
|
||||
from models.Character import Character
|
||||
from api.models import CharacterCreateRequest, CharacterUpdateRequest
|
||||
from api.models.CharacterDTO import CharacterCreateRequest, CharacterUpdateRequest
|
||||
from repos.dao import DAO
|
||||
from api.dependency import get_dao
|
||||
|
||||
@@ -23,16 +23,9 @@ from api.dependency import get_project_id
|
||||
router = APIRouter(prefix="/api/characters", tags=["Characters"], dependencies=[Depends(get_current_user)])
|
||||
|
||||
|
||||
@router.get("/", response_model=list[Character])
|
||||
async def get_characters(
|
||||
request: Request,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
project_id: str | None = Depends(get_project_id),
|
||||
limit: int = 100,
|
||||
offset: int = 0
|
||||
) -> list[Character]:
|
||||
logger.info(f"get_characters called. Limit: {limit}, Offset: {offset}")
|
||||
@router.get("/", response_model=List[Character])
|
||||
async def get_characters(request: Request, dao: DAO = Depends(get_dao), current_user: dict = Depends(get_current_user), project_id: Optional[str] = Depends(get_project_id)) -> List[Character]:
|
||||
logger.info("get_characters called")
|
||||
|
||||
user_id_filter = str(current_user["_id"])
|
||||
if project_id:
|
||||
@@ -41,12 +34,7 @@ async def get_characters(
|
||||
raise HTTPException(status_code=403, detail="Project access denied")
|
||||
user_id_filter = None
|
||||
|
||||
characters = await dao.chars.get_all_characters(
|
||||
created_by=user_id_filter,
|
||||
project_id=project_id,
|
||||
limit=limit,
|
||||
offset=offset
|
||||
)
|
||||
characters = await dao.chars.get_all_characters(created_by=user_id_filter, project_id=project_id)
|
||||
return characters
|
||||
|
||||
|
||||
@@ -102,7 +90,7 @@ async def get_character_by_id(character_id: str, request: Request, dao: DAO = De
|
||||
@router.post("/", response_model=Character)
|
||||
async def create_character(
|
||||
char_req: CharacterCreateRequest,
|
||||
project_id: str | None = Depends(get_project_id),
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
) -> Character:
|
||||
@@ -190,3 +178,10 @@ async def delete_character(
|
||||
raise HTTPException(status_code=500, detail="Failed to delete character")
|
||||
|
||||
return
|
||||
|
||||
|
||||
@router.post("/{character_id}/_run", response_model=GenerationResponse)
|
||||
async def post_character_generation(character_id: str, generation: GenerationRequest,
|
||||
request: Request) -> GenerationResponse:
|
||||
logger.info(f"post_character_generation called. CharacterID: {character_id}")
|
||||
generation_service = request.app.state.generation_service
|
||||
|
||||
@@ -1,191 +0,0 @@
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from starlette import status
|
||||
|
||||
from api.dependency import get_dao
|
||||
from api.endpoints.auth import get_current_user
|
||||
from api.models.EnvironmentRequest import EnvironmentCreate, EnvironmentUpdate, AssetToEnvironment, AssetsToEnvironment
|
||||
from models.Environment import Environment
|
||||
from repos.dao import DAO
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/environments", tags=["Environments"], dependencies=[Depends(get_current_user)])
|
||||
|
||||
|
||||
async def check_character_access(character_id: str, current_user: dict, dao: DAO):
|
||||
character = await dao.chars.get_character(character_id)
|
||||
if not character:
|
||||
raise HTTPException(status_code=404, detail="Character not found")
|
||||
|
||||
is_creator = character.created_by == str(current_user["_id"])
|
||||
is_project_member = False
|
||||
if character.project_id and character.project_id in current_user.get("project_ids", []):
|
||||
is_project_member = True
|
||||
|
||||
if not is_creator and not is_project_member:
|
||||
raise HTTPException(status_code=403, detail="Access denied to character")
|
||||
return character
|
||||
|
||||
|
||||
@router.post("/", response_model=Environment)
|
||||
async def create_environment(
|
||||
env_req: EnvironmentCreate,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
logger.info(f"Creating environment '{env_req.name}' for character {env_req.character_id}")
|
||||
await check_character_access(env_req.character_id, current_user, dao)
|
||||
|
||||
# Verify assets exist if provided
|
||||
if env_req.asset_ids:
|
||||
for aid in env_req.asset_ids:
|
||||
asset = await dao.assets.get_asset(aid)
|
||||
if not asset:
|
||||
raise HTTPException(status_code=400, detail=f"Asset {aid} not found")
|
||||
|
||||
new_env = Environment(**env_req.model_dump())
|
||||
created_env = await dao.environments.create_env(new_env)
|
||||
return created_env
|
||||
|
||||
|
||||
@router.get("/character/{character_id}", response_model=list[Environment])
|
||||
async def get_character_environments(
|
||||
character_id: str,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
logger.info(f"Getting environments for character {character_id}")
|
||||
await check_character_access(character_id, current_user, dao)
|
||||
return await dao.environments.get_character_envs(character_id)
|
||||
|
||||
|
||||
@router.get("/{env_id}", response_model=Environment)
|
||||
async def get_environment(
|
||||
env_id: str,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
env = await dao.environments.get_env(env_id)
|
||||
if not env:
|
||||
raise HTTPException(status_code=404, detail="Environment not found")
|
||||
|
||||
await check_character_access(env.character_id, current_user, dao)
|
||||
return env
|
||||
|
||||
|
||||
@router.put("/{env_id}", response_model=Environment)
|
||||
async def update_environment(
|
||||
env_id: str,
|
||||
env_update: EnvironmentUpdate,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
env = await dao.environments.get_env(env_id)
|
||||
if not env:
|
||||
raise HTTPException(status_code=404, detail="Environment not found")
|
||||
|
||||
await check_character_access(env.character_id, current_user, dao)
|
||||
|
||||
update_data = env_update.model_dump(exclude_unset=True)
|
||||
if not update_data:
|
||||
return env
|
||||
|
||||
# Verify assets exist if provided
|
||||
if "asset_ids" in update_data:
|
||||
if update_data["asset_ids"] is None:
|
||||
del update_data["asset_ids"]
|
||||
elif update_data["asset_ids"]:
|
||||
# Verify all assets exist using batch check
|
||||
assets = await dao.assets.get_assets_by_ids(update_data["asset_ids"])
|
||||
if len(assets) != len(update_data["asset_ids"]):
|
||||
found_ids = {a.id for a in assets}
|
||||
missing_ids = [aid for aid in update_data["asset_ids"] if aid not in found_ids]
|
||||
raise HTTPException(status_code=400, detail=f"Some assets not found: {missing_ids}")
|
||||
|
||||
success = await dao.environments.update_env(env_id, update_data)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to update environment")
|
||||
|
||||
return await dao.environments.get_env(env_id)
|
||||
|
||||
|
||||
@router.delete("/{env_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_environment(
|
||||
env_id: str,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
env = await dao.environments.get_env(env_id)
|
||||
if not env:
|
||||
raise HTTPException(status_code=404, detail="Environment not found")
|
||||
|
||||
await check_character_access(env.character_id, current_user, dao)
|
||||
|
||||
success = await dao.environments.delete_env(env_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete environment")
|
||||
return None
|
||||
|
||||
|
||||
@router.post("/{env_id}/assets", status_code=status.HTTP_200_OK)
|
||||
async def add_asset_to_environment(
|
||||
env_id: str,
|
||||
req: AssetToEnvironment,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
env = await dao.environments.get_env(env_id)
|
||||
if not env:
|
||||
raise HTTPException(status_code=404, detail="Environment not found")
|
||||
|
||||
await check_character_access(env.character_id, current_user, dao)
|
||||
|
||||
# Verify asset exists
|
||||
asset = await dao.assets.get_asset(req.asset_id)
|
||||
if not asset:
|
||||
raise HTTPException(status_code=404, detail="Asset not found")
|
||||
|
||||
success = await dao.environments.add_asset(env_id, req.asset_id)
|
||||
return {"success": success}
|
||||
|
||||
|
||||
@router.post("/{env_id}/assets/batch", status_code=status.HTTP_200_OK)
|
||||
async def add_assets_batch_to_environment(
|
||||
env_id: str,
|
||||
req: AssetsToEnvironment,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
env = await dao.environments.get_env(env_id)
|
||||
if not env:
|
||||
raise HTTPException(status_code=404, detail="Environment not found")
|
||||
|
||||
await check_character_access(env.character_id, current_user, dao)
|
||||
|
||||
# Verify all assets exist
|
||||
assets = await dao.assets.get_assets_by_ids(req.asset_ids)
|
||||
if len(assets) != len(req.asset_ids):
|
||||
found_ids = {a.id for a in assets}
|
||||
missing_ids = [aid for aid in req.asset_ids if aid not in found_ids]
|
||||
raise HTTPException(status_code=404, detail=f"Some assets not found: {missing_ids}")
|
||||
|
||||
success = await dao.environments.add_assets(env_id, req.asset_ids)
|
||||
return {"success": success}
|
||||
|
||||
|
||||
@router.delete("/{env_id}/assets/{asset_id}", status_code=status.HTTP_200_OK)
|
||||
async def remove_asset_from_environment(
|
||||
env_id: str,
|
||||
asset_id: str,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
env = await dao.environments.get_env(env_id)
|
||||
if not env:
|
||||
raise HTTPException(status_code=404, detail="Environment not found")
|
||||
|
||||
await check_character_access(env.character_id, current_user, dao)
|
||||
|
||||
success = await dao.environments.remove_asset(env_id, asset_id)
|
||||
return {"success": success}
|
||||
@@ -1,225 +1,138 @@
|
||||
import logging
|
||||
import json
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, UploadFile, File, Form, Header, HTTPException
|
||||
from fastapi.params import Depends
|
||||
from starlette import status
|
||||
from starlette.requests import Request
|
||||
|
||||
from config import settings
|
||||
from api import service
|
||||
from api.dependency import get_generation_service, get_project_id, get_dao
|
||||
from api.endpoints.auth import get_current_user
|
||||
from api.models import (
|
||||
GenerationResponse,
|
||||
GenerationRequest,
|
||||
GenerationsResponse,
|
||||
PromptResponse,
|
||||
PromptRequest,
|
||||
GenerationGroupResponse,
|
||||
FinancialReport,
|
||||
ExternalGenerationRequest,
|
||||
NsfwRequest
|
||||
)
|
||||
from api.service.generation_service import GenerationService
|
||||
from repos.dao import DAO
|
||||
from utils.external_auth import verify_signature
|
||||
|
||||
from api.models.GenerationRequest import GenerationResponse, GenerationRequest, GenerationsResponse, PromptResponse, PromptRequest
|
||||
from api.models.VideoGenerationRequest import VideoGenerationRequest
|
||||
from api.service.generation_service import GenerationService
|
||||
from models.Generation import Generation
|
||||
|
||||
from starlette import status
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from api.endpoints.auth import get_current_user
|
||||
|
||||
router = APIRouter(prefix='/api/generations', tags=["Generation"])
|
||||
|
||||
|
||||
async def check_project_access(project_id: str | None, current_user: dict, dao: DAO):
|
||||
"""Helper to check if user has access to project."""
|
||||
if not project_id:
|
||||
return
|
||||
project = await dao.projects.get_project(project_id)
|
||||
if not project or str(current_user["_id"]) not in project.members:
|
||||
raise HTTPException(status_code=403, detail="Project access denied")
|
||||
|
||||
|
||||
@router.post("/prompt-assistant", response_model=PromptResponse)
|
||||
async def ask_prompt_assistant(
|
||||
prompt_request: PromptRequest,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
) -> PromptResponse:
|
||||
logger.info(f"ask_prompt_assistant: {len(prompt_request.prompt)} chars")
|
||||
generated_prompt = await generation_service.ask_prompt_assistant(
|
||||
prompt_request.prompt,
|
||||
prompt_request.linked_assets,
|
||||
prompt_request.model
|
||||
)
|
||||
async def ask_prompt_assistant(prompt_request: PromptRequest, request: Request,
|
||||
generation_service: GenerationService = Depends(
|
||||
get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)) -> PromptResponse:
|
||||
logger.info(f"ask_prompt_assistant called with prompt length: {len(prompt_request.prompt)}. Linked assets: {len(prompt_request.linked_assets) if prompt_request.linked_assets else 0}")
|
||||
generated_prompt = await generation_service.ask_prompt_assistant(prompt_request.prompt, prompt_request.linked_assets)
|
||||
return PromptResponse(prompt=generated_prompt)
|
||||
|
||||
|
||||
@router.post("/prompt-from-image", response_model=PromptResponse)
|
||||
async def prompt_from_image(
|
||||
prompt: str | None = Form(None),
|
||||
model: str = Form("gemini-3.1-pro-preview"),
|
||||
images: list[UploadFile] = File(...),
|
||||
prompt: Optional[str] = Form(None),
|
||||
images: List[UploadFile] = File(...),
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
) -> PromptResponse:
|
||||
images_bytes = [await img.read() for img in images]
|
||||
generated_prompt = await generation_service.generate_prompt_from_images(images_bytes, prompt, model)
|
||||
logger.info(f"prompt_from_image called. Images count: {len(images)}. Prompt provided: {bool(prompt)}")
|
||||
images_bytes = []
|
||||
for image in images:
|
||||
content = await image.read()
|
||||
images_bytes.append(content)
|
||||
|
||||
generated_prompt = await generation_service.generate_prompt_from_images(images_bytes, prompt)
|
||||
return PromptResponse(prompt=generated_prompt)
|
||||
|
||||
|
||||
@router.get("", response_model=GenerationsResponse)
|
||||
async def get_generations(
|
||||
character_id: str | None = None,
|
||||
limit: int = 10,
|
||||
offset: int = 0,
|
||||
only_liked: bool = False,
|
||||
async def get_generations(character_id: Optional[str] = None, limit: int = 10, offset: int = 0,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
project_id: str | None = Depends(get_project_id),
|
||||
dao: DAO = Depends(get_dao)
|
||||
):
|
||||
await check_project_access(project_id, current_user, dao)
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
dao: DAO = Depends(get_dao)):
|
||||
logger.info(f"get_generations called. CharacterId: {character_id}, Limit: {limit}, Offset: {offset}")
|
||||
|
||||
# If project_id is set, we don't filter by user to show all project-wide generations
|
||||
created_by_filter = None if project_id else str(current_user["_id"])
|
||||
only_liked_by = str(current_user["_id"]) if only_liked else None
|
||||
|
||||
return await generation_service.get_generations(
|
||||
character_id=character_id,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
created_by=created_by_filter,
|
||||
project_id=project_id,
|
||||
only_liked_by=only_liked_by,
|
||||
current_user_id=str(current_user["_id"])
|
||||
)
|
||||
|
||||
|
||||
@router.get("/usage", response_model=FinancialReport)
|
||||
async def get_usage_report(
|
||||
breakdown: str | None = None, # "user" or "project"
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
project_id: str | None = Depends(get_project_id),
|
||||
dao: DAO = Depends(get_dao)
|
||||
) -> FinancialReport:
|
||||
await check_project_access(project_id, current_user, dao)
|
||||
|
||||
user_id_filter = str(current_user["_id"]) if not project_id else None
|
||||
breakdown_by = None
|
||||
|
||||
if breakdown == "user":
|
||||
breakdown_by = "created_by"
|
||||
elif breakdown == "project":
|
||||
breakdown_by = "project_id"
|
||||
|
||||
return await generation_service.get_financial_report(
|
||||
user_id=user_id_filter,
|
||||
project_id=project_id,
|
||||
breakdown_by=breakdown_by
|
||||
)
|
||||
|
||||
|
||||
@router.post("/_run", response_model=GenerationGroupResponse)
|
||||
async def post_generation(
|
||||
generation: GenerationRequest,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
project_id: str | None = Depends(get_project_id),
|
||||
dao: DAO = Depends(get_dao)
|
||||
) -> GenerationGroupResponse:
|
||||
await check_project_access(project_id, current_user, dao)
|
||||
user_id_filter = str(current_user["_id"])
|
||||
if project_id:
|
||||
project = await dao.projects.get_project(project_id)
|
||||
if not project or str(current_user["_id"]) not in project.members:
|
||||
raise HTTPException(status_code=403, detail="Project access denied")
|
||||
user_id_filter = None # Show all project generations
|
||||
|
||||
return await generation_service.get_generations(character_id, limit=limit, offset=offset, user_id=user_id_filter, project_id=project_id)
|
||||
|
||||
|
||||
@router.post("/_run", response_model=GenerationResponse)
|
||||
async def post_generation(generation: GenerationRequest, request: Request,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
dao: DAO = Depends(get_dao)) -> GenerationResponse:
|
||||
logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}")
|
||||
|
||||
if project_id:
|
||||
project = await dao.projects.get_project(project_id)
|
||||
if not project or str(current_user["_id"]) not in project.members:
|
||||
raise HTTPException(status_code=403, detail="Project access denied")
|
||||
generation.project_id = project_id
|
||||
|
||||
return await generation_service.create_generation_task(
|
||||
generation,
|
||||
user_id=str(current_user.get("_id"))
|
||||
)
|
||||
|
||||
|
||||
@router.get("/running")
|
||||
async def get_running_generations(
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
project_id: str | None = Depends(get_project_id),
|
||||
dao: DAO = Depends(get_dao)
|
||||
):
|
||||
await check_project_access(project_id, current_user, dao)
|
||||
user_id_filter = None if project_id else str(current_user["_id"])
|
||||
|
||||
return await generation_service.get_running_generations(
|
||||
user_id=user_id_filter,
|
||||
project_id=project_id
|
||||
)
|
||||
|
||||
|
||||
@router.get("/group/{group_id}", response_model=GenerationGroupResponse)
|
||||
async def get_generation_group(
|
||||
group_id: str,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
return await generation_service.get_generations_by_group(group_id, current_user_id=str(current_user["_id"]))
|
||||
return await generation_service.create_generation_task(generation, user_id=str(current_user.get("_id")))
|
||||
|
||||
|
||||
@router.get("/{generation_id}", response_model=GenerationResponse)
|
||||
async def get_generation(
|
||||
generation_id: str,
|
||||
async def get_generation(generation_id: str,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
) -> GenerationResponse:
|
||||
gen = await generation_service.get_generation(generation_id, current_user_id=str(current_user["_id"]))
|
||||
if not gen:
|
||||
raise HTTPException(status_code=404, detail="Generation not found")
|
||||
|
||||
if gen.created_by != str(current_user["_id"]):
|
||||
# Check project membership
|
||||
is_member = False
|
||||
if gen.project_id:
|
||||
project = await generation_service.dao.projects.get_project(gen.project_id)
|
||||
if project and str(current_user["_id"]) in project.members:
|
||||
is_member = True
|
||||
|
||||
if not is_member:
|
||||
current_user: dict = Depends(get_current_user)) -> GenerationResponse:
|
||||
logger.debug(f"get_generation called for ID: {generation_id}")
|
||||
gen = await generation_service.get_generation(generation_id)
|
||||
if gen and gen.created_by != str(current_user["_id"]):
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
return gen
|
||||
|
||||
|
||||
@router.post("/{generation_id}/like", response_model=dict)
|
||||
async def toggle_like(
|
||||
generation_id: str,
|
||||
@router.get("/running")
|
||||
async def get_running_generations(request: Request,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
is_liked = await generation_service.toggle_like(generation_id, str(current_user["_id"]))
|
||||
if is_liked is None:
|
||||
raise HTTPException(status_code=404, detail="Generation not found")
|
||||
return {"is_liked": is_liked}
|
||||
current_user: dict = Depends(get_current_user),
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
dao: DAO = Depends(get_dao)):
|
||||
|
||||
user_id_filter = str(current_user["_id"])
|
||||
if project_id:
|
||||
project = await dao.projects.get_project(project_id)
|
||||
if not project or str(current_user["_id"]) not in project.members:
|
||||
raise HTTPException(status_code=403, detail="Project access denied")
|
||||
user_id_filter = None
|
||||
|
||||
return await generation_service.get_running_generations(user_id=user_id_filter, project_id=project_id)
|
||||
|
||||
|
||||
@router.post("/{generation_id}/nsfw", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def mark_generation_nsfw(
|
||||
generation_id: str,
|
||||
request: NsfwRequest,
|
||||
@router.post("/video/_run", response_model=GenerationResponse)
|
||||
async def post_video_generation(
|
||||
video_request: VideoGenerationRequest,
|
||||
request: Request,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
gen = await generation_service.get_generation(generation_id, current_user_id=str(current_user["_id"]))
|
||||
if not gen:
|
||||
raise HTTPException(status_code=404, detail="Generation not found")
|
||||
current_user: dict = Depends(get_current_user),
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
dao: DAO = Depends(get_dao),
|
||||
) -> GenerationResponse:
|
||||
"""Start image-to-video generation using Kling AI."""
|
||||
logger.info(f"post_video_generation called. AssetId: {video_request.image_asset_id}, Duration: {video_request.duration}s, Mode: {video_request.mode}")
|
||||
|
||||
if gen.created_by != str(current_user["_id"]):
|
||||
is_member = False
|
||||
if gen.project_id:
|
||||
project = await generation_service.dao.projects.get_project(gen.project_id)
|
||||
if project and str(current_user["_id"]) in project.members:
|
||||
is_member = True
|
||||
if project_id:
|
||||
project = await dao.projects.get_project(project_id)
|
||||
if not project or str(current_user["_id"]) not in project.members:
|
||||
raise HTTPException(status_code=403, detail="Project access denied")
|
||||
video_request.project_id = project_id
|
||||
|
||||
if not is_member:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
await generation_service.dao.generations.mark_nsfw(generation_id, request.is_nsfw)
|
||||
return None
|
||||
return await generation_service.create_video_generation_task(video_request, user_id=str(current_user.get("_id")))
|
||||
|
||||
|
||||
@router.post("/import", response_model=GenerationResponse)
|
||||
@@ -228,18 +141,40 @@ async def import_external_generation(
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
x_signature: str = Header(..., alias="X-Signature")
|
||||
) -> GenerationResponse:
|
||||
"""
|
||||
Import a generation from an external source.
|
||||
Requires server-to-server authentication via HMAC signature.
|
||||
"""
|
||||
import os
|
||||
from utils.external_auth import verify_signature
|
||||
from api.models.ExternalGenerationDTO import ExternalGenerationRequest
|
||||
|
||||
logger.info("import_external_generation called")
|
||||
|
||||
# Get raw request body for signature verification
|
||||
body = await request.body()
|
||||
|
||||
secret = settings.EXTERNAL_API_SECRET
|
||||
# Verify signature
|
||||
secret = os.getenv("EXTERNAL_API_SECRET")
|
||||
if not secret:
|
||||
logger.error("EXTERNAL_API_SECRET not configured")
|
||||
raise HTTPException(status_code=500, detail="Server configuration error")
|
||||
|
||||
if not verify_signature(body, x_signature, secret):
|
||||
logger.warning("Invalid signature for external generation import")
|
||||
raise HTTPException(status_code=401, detail="Invalid signature")
|
||||
|
||||
# Parse request body
|
||||
import json
|
||||
try:
|
||||
data = json.loads(body.decode('utf-8'))
|
||||
external_gen = ExternalGenerationRequest(**data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse request body: {e}")
|
||||
raise HTTPException(status_code=400, detail=f"Invalid request body: {str(e)}")
|
||||
|
||||
# Import generation
|
||||
try:
|
||||
generation = await generation_service.import_external_generation(external_gen)
|
||||
return GenerationResponse(**generation.model_dump())
|
||||
except Exception as e:
|
||||
@@ -248,11 +183,11 @@ async def import_external_generation(
|
||||
|
||||
|
||||
@router.delete("/{generation_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_generation(
|
||||
generation_id: str,
|
||||
async def delete_generation(generation_id: str,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
if not await generation_service.delete_generation(generation_id):
|
||||
current_user: dict = Depends(get_current_user)):
|
||||
logger.info(f"delete_generation called for ID: {generation_id}")
|
||||
deleted = await generation_service.delete_generation(generation_id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Generation not found")
|
||||
return None
|
||||
@@ -1,106 +0,0 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Body
|
||||
from api.dependency import get_idea_service, get_project_id, get_generation_service
|
||||
from api.endpoints.auth import get_current_user
|
||||
from api.service.idea_service import IdeaService
|
||||
from api.service.generation_service import GenerationService
|
||||
from models.Idea import Idea
|
||||
from api.models import GenerationResponse, GenerationsResponse
|
||||
from api.models import IdeaRequest, PostRequest # Adjusting for general model usage
|
||||
from api.models.IdeaRequest import IdeaCreateRequest, IdeaUpdateRequest, IdeaResponse
|
||||
|
||||
router = APIRouter(prefix="/api/ideas", tags=["ideas"])
|
||||
|
||||
@router.post("", response_model=Idea)
|
||||
async def create_idea(
|
||||
request: IdeaCreateRequest,
|
||||
project_id: str | None = Depends(get_project_id),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
idea_service: IdeaService = Depends(get_idea_service)
|
||||
):
|
||||
pid = project_id or request.project_id
|
||||
|
||||
return await idea_service.create_idea(
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
project_id=pid,
|
||||
user_id=str(current_user["_id"]),
|
||||
inspiration_id=request.inspiration_id
|
||||
)
|
||||
|
||||
@router.get("", response_model=list[IdeaResponse])
|
||||
async def get_ideas(
|
||||
project_id: str | None = Depends(get_project_id),
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
current_user: dict = Depends(get_current_user),
|
||||
idea_service: IdeaService = Depends(get_idea_service)
|
||||
):
|
||||
return await idea_service.get_ideas(project_id, str(current_user["_id"]), limit, offset)
|
||||
|
||||
@router.get("/{idea_id}", response_model=Idea)
|
||||
async def get_idea(
|
||||
idea_id: str,
|
||||
idea_service: IdeaService = Depends(get_idea_service)
|
||||
):
|
||||
idea = await idea_service.get_idea(idea_id)
|
||||
if not idea:
|
||||
raise HTTPException(status_code=404, detail="Idea not found")
|
||||
return idea
|
||||
|
||||
@router.put("/{idea_id}", response_model=Idea)
|
||||
async def update_idea(
|
||||
idea_id: str,
|
||||
request: IdeaUpdateRequest,
|
||||
idea_service: IdeaService = Depends(get_idea_service)
|
||||
):
|
||||
idea = await idea_service.update_idea(
|
||||
idea_id=idea_id,
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
inspiration_id=request.inspiration_id
|
||||
)
|
||||
if not idea:
|
||||
raise HTTPException(status_code=404, detail="Idea not found")
|
||||
return idea
|
||||
|
||||
@router.delete("/{idea_id}")
|
||||
async def delete_idea(
|
||||
idea_id: str,
|
||||
idea_service: IdeaService = Depends(get_idea_service)
|
||||
):
|
||||
success = await idea_service.delete_idea(idea_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Idea not found or could not be deleted")
|
||||
return {"status": "success"}
|
||||
|
||||
@router.get("/{idea_id}/generations", response_model=GenerationsResponse)
|
||||
async def get_idea_generations(
|
||||
idea_id: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
return await generation_service.get_generations(idea_id=idea_id, limit=limit, offset=offset, current_user_id=str(current_user["_id"]))
|
||||
|
||||
@router.post("/{idea_id}/generations/{generation_id}")
|
||||
async def add_generation_to_idea(
|
||||
idea_id: str,
|
||||
generation_id: str,
|
||||
idea_service: IdeaService = Depends(get_idea_service)
|
||||
):
|
||||
success = await idea_service.add_generation_to_idea(idea_id, generation_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Idea or Generation not found")
|
||||
return {"status": "success"}
|
||||
|
||||
@router.delete("/{idea_id}/generations/{generation_id}")
|
||||
async def remove_generation_from_idea(
|
||||
idea_id: str,
|
||||
generation_id: str,
|
||||
idea_service: IdeaService = Depends(get_idea_service)
|
||||
):
|
||||
success = await idea_service.remove_generation_from_idea(idea_id, generation_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Idea or Generation not found")
|
||||
return {"status": "success"}
|
||||
@@ -1,94 +0,0 @@
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from api.dependency import get_inspiration_service, get_project_id
|
||||
from api.endpoints.auth import get_current_user
|
||||
from api.models.InspirationRequest import InspirationCreateRequest, InspirationResponse, InspirationListResponse
|
||||
from api.service.inspiration_service import InspirationService
|
||||
from models.Inspiration import Inspiration
|
||||
|
||||
router = APIRouter(prefix="/api/inspirations", tags=["Inspirations"])
|
||||
|
||||
|
||||
@router.post("", response_model=InspirationResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_inspiration(
|
||||
request: InspirationCreateRequest,
|
||||
project_id: str | None = Depends(get_project_id),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
service: InspirationService = Depends(get_inspiration_service)
|
||||
):
|
||||
pid = project_id or request.project_id
|
||||
|
||||
inspiration = await service.create_inspiration(
|
||||
source_url=request.source_url,
|
||||
created_by=str(current_user["_id"]),
|
||||
project_id=pid,
|
||||
caption=request.caption
|
||||
)
|
||||
return inspiration
|
||||
|
||||
|
||||
@router.get("", response_model=InspirationListResponse)
|
||||
async def get_inspirations(
|
||||
project_id: str | None = Depends(get_project_id),
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
current_user: dict = Depends(get_current_user),
|
||||
service: InspirationService = Depends(get_inspiration_service)
|
||||
):
|
||||
# If project_id is provided, filter by it. Otherwise, filter by user.
|
||||
# Or maybe we want to see all user's inspirations if no project is selected?
|
||||
# Let's follow the pattern: if project_id is present, show project's inspirations.
|
||||
# If not, show user's personal inspirations (where project_id is None) OR all user's inspirations?
|
||||
# Usually "My Inspirations" means created by me.
|
||||
|
||||
# Let's assume:
|
||||
# If project_id -> filter by project_id (and maybe created_by if we want strict ownership, but usually project members share)
|
||||
# If no project_id -> filter by created_by (personal)
|
||||
|
||||
pid = project_id
|
||||
uid = str(current_user["_id"])
|
||||
|
||||
inspirations = await service.get_inspirations(project_id=pid, created_by=uid if not pid else None, limit=limit, offset=offset)
|
||||
total_count = await service.dao.inspirations.count_inspirations(project_id=pid, created_by=uid if not pid else None)
|
||||
|
||||
return InspirationListResponse(
|
||||
inspirations=[InspirationResponse(**inspiration.model_dump()) for inspiration in inspirations],
|
||||
total_count=total_count
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{inspiration_id}", response_model=InspirationResponse)
|
||||
async def get_inspiration(
|
||||
inspiration_id: str,
|
||||
service: InspirationService = Depends(get_inspiration_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
inspiration = await service.get_inspiration(inspiration_id)
|
||||
if not inspiration:
|
||||
raise HTTPException(status_code=404, detail="Inspiration not found")
|
||||
return inspiration
|
||||
|
||||
|
||||
@router.patch("/{inspiration_id}/complete", response_model=InspirationResponse)
|
||||
async def mark_inspiration_complete(
|
||||
inspiration_id: str,
|
||||
is_completed: bool = True,
|
||||
service: InspirationService = Depends(get_inspiration_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
inspiration = await service.mark_as_completed(inspiration_id, is_completed)
|
||||
if not inspiration:
|
||||
raise HTTPException(status_code=404, detail="Inspiration not found")
|
||||
return inspiration
|
||||
|
||||
|
||||
@router.delete("/{inspiration_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_inspiration(
|
||||
inspiration_id: str,
|
||||
service: InspirationService = Depends(get_inspiration_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
success = await service.delete_inspiration(inspiration_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Inspiration not found")
|
||||
return None
|
||||
@@ -1,98 +0,0 @@
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from api.dependency import get_post_service, get_project_id
|
||||
from api.endpoints.auth import get_current_user
|
||||
from api.service.post_service import PostService
|
||||
from api.models import PostRequest, PostCreateRequest, PostUpdateRequest, AddGenerationsRequest
|
||||
from models.Post import Post
|
||||
|
||||
router = APIRouter(prefix="/api/posts", tags=["posts"])
|
||||
|
||||
|
||||
@router.post("", response_model=Post)
|
||||
async def create_post(
|
||||
request: PostCreateRequest,
|
||||
project_id: str | None = Depends(get_project_id),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
post_service: PostService = Depends(get_post_service),
|
||||
):
|
||||
pid = project_id or request.project_id
|
||||
return await post_service.create_post(
|
||||
date=request.date,
|
||||
topic=request.topic,
|
||||
generation_ids=request.generation_ids,
|
||||
project_id=pid,
|
||||
user_id=str(current_user["_id"]),
|
||||
)
|
||||
|
||||
|
||||
@router.get("", response_model=list[Post])
|
||||
async def get_posts(
|
||||
project_id: str | None = Depends(get_project_id),
|
||||
limit: int = 200,
|
||||
offset: int = 0,
|
||||
date_from: datetime | None = None,
|
||||
date_to: datetime | None = None,
|
||||
current_user: dict = Depends(get_current_user),
|
||||
post_service: PostService = Depends(get_post_service),
|
||||
):
|
||||
return await post_service.get_posts(project_id, str(current_user["_id"]), limit, offset, date_from, date_to)
|
||||
|
||||
|
||||
@router.get("/{post_id}", response_model=Post)
|
||||
async def get_post(
|
||||
post_id: str,
|
||||
post_service: PostService = Depends(get_post_service),
|
||||
):
|
||||
post = await post_service.get_post(post_id)
|
||||
if not post:
|
||||
raise HTTPException(status_code=404, detail="Post not found")
|
||||
return post
|
||||
|
||||
|
||||
@router.put("/{post_id}", response_model=Post)
|
||||
async def update_post(
|
||||
post_id: str,
|
||||
request: PostUpdateRequest,
|
||||
post_service: PostService = Depends(get_post_service),
|
||||
):
|
||||
post = await post_service.update_post(post_id, date=request.date, topic=request.topic)
|
||||
if not post:
|
||||
raise HTTPException(status_code=404, detail="Post not found")
|
||||
return post
|
||||
|
||||
|
||||
@router.delete("/{post_id}")
|
||||
async def delete_post(
|
||||
post_id: str,
|
||||
post_service: PostService = Depends(get_post_service),
|
||||
):
|
||||
success = await post_service.delete_post(post_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Post not found or could not be deleted")
|
||||
return {"status": "success"}
|
||||
|
||||
|
||||
@router.post("/{post_id}/generations")
|
||||
async def add_generations(
|
||||
post_id: str,
|
||||
request: AddGenerationsRequest,
|
||||
post_service: PostService = Depends(get_post_service),
|
||||
):
|
||||
success = await post_service.add_generations(post_id, request.generation_ids)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Post not found")
|
||||
return {"status": "success"}
|
||||
|
||||
|
||||
@router.delete("/{post_id}/generations/{generation_id}")
|
||||
async def remove_generation(
|
||||
post_id: str,
|
||||
generation_id: str,
|
||||
post_service: PostService = Depends(get_post_service),
|
||||
):
|
||||
success = await post_service.remove_generation(post_id, generation_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Post not found or generation not linked")
|
||||
return {"status": "success"}
|
||||
@@ -1,5 +1,4 @@
|
||||
|
||||
from bson import ObjectId
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel
|
||||
from api.dependency import get_dao
|
||||
@@ -11,48 +10,16 @@ router = APIRouter(prefix="/api/projects", tags=["Projects"])
|
||||
|
||||
class ProjectCreate(BaseModel):
|
||||
name: str
|
||||
description: str | None = None
|
||||
|
||||
class ProjectMemberResponse(BaseModel):
|
||||
id: str
|
||||
username: str
|
||||
description: Optional[str] = None
|
||||
|
||||
class ProjectResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str | None = None
|
||||
description: Optional[str] = None
|
||||
owner_id: str
|
||||
members: list[ProjectMemberResponse]
|
||||
members: List[str]
|
||||
is_owner: bool = False
|
||||
|
||||
async def _get_project_response(project: Project, current_user_id: str, dao: DAO) -> ProjectResponse:
|
||||
member_responses = []
|
||||
for member_id in project.members:
|
||||
# We need a way to get user by ID. Let's check UsersRepo for get_user by ObjectId or string.
|
||||
# Currently UsersRepo has get_user(user_id: int) for Telegram IDs.
|
||||
# But for Web users we might need to search by _id.
|
||||
# Let's try to get user info.
|
||||
# Since project.members contains strings (ObjectIds as strings), we search by _id.
|
||||
user_doc = await dao.users.collection.find_one({"_id": ObjectId(member_id)})
|
||||
if not user_doc and member_id.isdigit():
|
||||
# Fallback for telegram IDs if they are stored as strings of digits
|
||||
user_doc = await dao.users.get_user(int(member_id))
|
||||
|
||||
username = "unknown"
|
||||
if user_doc:
|
||||
username = user_doc.get("username", "unknown")
|
||||
|
||||
member_responses.append(ProjectMemberResponse(id=member_id, username=username))
|
||||
|
||||
return ProjectResponse(
|
||||
id=project.id,
|
||||
name=project.name,
|
||||
description=project.description,
|
||||
owner_id=project.owner_id,
|
||||
members=member_responses,
|
||||
is_owner=(project.owner_id == current_user_id)
|
||||
)
|
||||
|
||||
@router.post("", response_model=ProjectResponse)
|
||||
async def create_project(
|
||||
project_data: ProjectCreate,
|
||||
@@ -67,17 +34,29 @@ async def create_project(
|
||||
members=[user_id]
|
||||
)
|
||||
project_id = await dao.projects.create_project(new_project)
|
||||
new_project.id = project_id
|
||||
|
||||
# Add project to user's project list
|
||||
# Assuming user_repo has a method to add project or we do it directly?
|
||||
# UserRepo doesn't have add_project method yet.
|
||||
# But since UserRepo is just a wrapper around collection, lets add it here or update UserRepo later?
|
||||
# Better to update UserRepo. For now, let's just return success.
|
||||
# But user needs to see it in list.
|
||||
# Update user in DB
|
||||
await dao.users.collection.update_one(
|
||||
{"_id": current_user["_id"]},
|
||||
{"$addToSet": {"project_ids": project_id}}
|
||||
)
|
||||
|
||||
return await _get_project_response(new_project, user_id, dao)
|
||||
return ProjectResponse(
|
||||
id=project_id,
|
||||
name=new_project.name,
|
||||
description=new_project.description,
|
||||
owner_id=new_project.owner_id,
|
||||
members=new_project.members,
|
||||
is_owner=True
|
||||
)
|
||||
|
||||
@router.get("", response_model=list[ProjectResponse])
|
||||
@router.get("", response_model=List[ProjectResponse])
|
||||
async def get_my_projects(
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
@@ -87,7 +66,14 @@ async def get_my_projects(
|
||||
|
||||
responses = []
|
||||
for p in projects:
|
||||
responses.append(await _get_project_response(p, user_id, dao))
|
||||
responses.append(ProjectResponse(
|
||||
id=p.id,
|
||||
name=p.name,
|
||||
description=p.description,
|
||||
owner_id=p.owner_id,
|
||||
members=p.members,
|
||||
is_owner=(p.owner_id == user_id)
|
||||
))
|
||||
return responses
|
||||
|
||||
class MemberAdd(BaseModel):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -10,10 +11,10 @@ class AssetResponse(BaseModel):
|
||||
name: str
|
||||
type: str # uploaded / generated
|
||||
content_type: str # image / prompt
|
||||
linked_char_id: str | None = None
|
||||
linked_char_id: Optional[str] = None
|
||||
created_at: datetime
|
||||
url: str | None = None
|
||||
url: Optional[str] = None
|
||||
|
||||
class AssetsResponse(BaseModel):
|
||||
assets: list[AssetResponse]
|
||||
assets: List[AssetResponse]
|
||||
total_count: int
|
||||
@@ -1,17 +1,18 @@
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
class CharacterCreateRequest(BaseModel):
|
||||
name: str
|
||||
character_bio: str
|
||||
character_image_doc_tg_id: str | None = None
|
||||
avatar_image: str | None = None
|
||||
character_image_tg_id: str | None = None
|
||||
project_id: str | None = None
|
||||
character_image_doc_tg_id: Optional[str] = None
|
||||
avatar_image: Optional[str] = None
|
||||
character_image_tg_id: Optional[str] = None
|
||||
project_id: Optional[str] = None
|
||||
|
||||
class CharacterUpdateRequest(BaseModel):
|
||||
name: str | None = None
|
||||
character_bio: str | None = None
|
||||
character_image_doc_tg_id: str | None = None
|
||||
avatar_image: str | None = None
|
||||
character_image_tg_id: str | None = None
|
||||
project_id: str | None = None
|
||||
name: Optional[str] = None
|
||||
character_bio: Optional[str] = None
|
||||
character_image_doc_tg_id: Optional[str] = None
|
||||
avatar_image: Optional[str] = None
|
||||
character_image_tg_id: Optional[str] = None
|
||||
project_id: Optional[str] = None
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class EnvironmentCreate(BaseModel):
|
||||
character_id: str
|
||||
name: str = Field(..., min_length=1)
|
||||
description: str | None = None
|
||||
asset_ids: list[str] | None = []
|
||||
|
||||
|
||||
class EnvironmentUpdate(BaseModel):
|
||||
name: str | None = Field(None, min_length=1)
|
||||
description: str | None = None
|
||||
asset_ids: list[str] | None = None
|
||||
|
||||
|
||||
class AssetToEnvironment(BaseModel):
|
||||
asset_id: str
|
||||
|
||||
|
||||
class AssetsToEnvironment(BaseModel):
|
||||
asset_ids: list[str]
|
||||
@@ -1,3 +1,4 @@
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from models.enums import AspectRatios, Quality
|
||||
|
||||
@@ -6,31 +7,27 @@ class ExternalGenerationRequest(BaseModel):
|
||||
"""Request model for importing external generations."""
|
||||
|
||||
prompt: str
|
||||
tech_prompt: str | None = None
|
||||
tech_prompt: Optional[str] = None
|
||||
|
||||
# Image can be provided as base64 string OR URL (one must be provided)
|
||||
image_data: str | None = Field(None, description="Base64-encoded image data")
|
||||
image_url: str | None = Field(None, description="URL to download image from")
|
||||
|
||||
nsfw: bool = False
|
||||
image_data: Optional[str] = Field(None, description="Base64-encoded image data")
|
||||
image_url: Optional[str] = Field(None, description="URL to download image from")
|
||||
|
||||
# Generation metadata
|
||||
aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN # "1:1","2:3","3:2","3:4","4:3","4:5","5:4","9:16","16:9","21:9"
|
||||
aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN
|
||||
quality: Quality = Quality.ONEK
|
||||
model: str | None = None
|
||||
seed: int | None = None
|
||||
|
||||
# Optional linking
|
||||
linked_character_id: str | None = None
|
||||
linked_character_id: Optional[str] = None
|
||||
created_by: str = Field(..., description="User ID from external system")
|
||||
project_id: str | None = None
|
||||
project_id: Optional[str] = None
|
||||
|
||||
# Performance metrics
|
||||
execution_time_seconds: float | None = None
|
||||
api_execution_time_seconds: float | None = None
|
||||
token_usage: int | None = None
|
||||
input_token_usage: int | None = None
|
||||
output_token_usage: int | None = None
|
||||
execution_time_seconds: Optional[float] = None
|
||||
api_execution_time_seconds: Optional[float] = None
|
||||
token_usage: Optional[int] = None
|
||||
input_token_usage: Optional[int] = None
|
||||
output_token_usage: Optional[int] = None
|
||||
|
||||
def validate_image_source(self):
|
||||
"""Ensure at least one image source is provided."""
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
class UsageStats(BaseModel):
|
||||
total_runs: int
|
||||
total_tokens: int
|
||||
total_input_tokens: int
|
||||
total_output_tokens: int
|
||||
total_cost: float
|
||||
|
||||
class UsageByEntity(BaseModel):
|
||||
entity_id: str | None = None
|
||||
stats: UsageStats
|
||||
|
||||
class FinancialReport(BaseModel):
|
||||
summary: UsageStats
|
||||
by_user: list[UsageByEntity] | None = None
|
||||
by_project: list[UsageByEntity] | None = None
|
||||
@@ -1,78 +1,63 @@
|
||||
from datetime import datetime, UTC
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel
|
||||
|
||||
from models.Asset import Asset
|
||||
from models.Generation import GenerationStatus
|
||||
from models.enums import AspectRatios, Quality, GenType, ImageModel, TextModel
|
||||
from models.enums import AspectRatios, Quality, GenType
|
||||
|
||||
|
||||
class GenerationRequest(BaseModel):
|
||||
linked_character_id: str | None = None
|
||||
aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN # "1:1","2:3","3:2","3:4","4:3","4:5","5:4","9:16","16:9","21:9"
|
||||
linked_character_id: Optional[str] = None
|
||||
aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN
|
||||
quality: Quality = Quality.ONEK
|
||||
prompt: str
|
||||
model: ImageModel = Field(default=ImageModel.GEMINI_3_PRO_IMAGE_PREVIEW)
|
||||
telegram_id: int | None = None
|
||||
telegram_id: Optional[int] = None
|
||||
use_profile_image: bool = True
|
||||
assets_list: list[str]
|
||||
environment_id: str | None = None
|
||||
project_id: str | None = None
|
||||
idea_id: str | None = None
|
||||
nsfw: bool = False
|
||||
count: int = Field(default=1, ge=1, le=10)
|
||||
|
||||
|
||||
class NsfwRequest(BaseModel):
|
||||
is_nsfw: bool
|
||||
assets_list: List[str]
|
||||
project_id: Optional[str] = None
|
||||
|
||||
|
||||
class GenerationsResponse(BaseModel):
|
||||
generations: list["GenerationResponse"]
|
||||
generations: List["GenerationResponse"]
|
||||
total_count: int
|
||||
|
||||
|
||||
class GenerationResponse(BaseModel):
|
||||
id: str
|
||||
status: GenerationStatus
|
||||
failed_reason: str | None = None
|
||||
project_id: str | None = None
|
||||
linked_character_id: str | None = None
|
||||
gen_type: GenType = GenType.IMAGE
|
||||
failed_reason: Optional[str] = None
|
||||
|
||||
linked_character_id: Optional[str] = None
|
||||
aspect_ratio: AspectRatios
|
||||
quality: Quality
|
||||
prompt: str
|
||||
model: ImageModel | None = None
|
||||
seed: int | None = None
|
||||
tech_prompt: str | None = None
|
||||
assets_list: list[str]
|
||||
result_list: list[str] = []
|
||||
result: str | None = None
|
||||
execution_time_seconds: float | None = None
|
||||
api_execution_time_seconds: float | None = None
|
||||
token_usage: int | None = None
|
||||
input_token_usage: int | None = None
|
||||
output_token_usage: int | None = None
|
||||
tech_prompt: Optional[str] = None
|
||||
assets_list: List[str]
|
||||
result_list: List[str] = []
|
||||
result: Optional[str] = None
|
||||
execution_time_seconds: Optional[float] = None
|
||||
api_execution_time_seconds: Optional[float] = None
|
||||
token_usage: Optional[int] = None
|
||||
input_token_usage: Optional[int] = None
|
||||
output_token_usage: Optional[int] = None
|
||||
progress: int = 0
|
||||
cost: float | None = None
|
||||
created_by: str | None = None
|
||||
generation_group_id: str | None = None
|
||||
idea_id: str | None = None
|
||||
likes_count: int = 0
|
||||
is_liked: bool = False
|
||||
nsfw: bool = False
|
||||
cost: Optional[float] = None
|
||||
created_by: Optional[str] = None
|
||||
# Video-specific
|
||||
kling_task_id: Optional[str] = None
|
||||
video_duration: Optional[int] = None
|
||||
video_mode: Optional[str] = None
|
||||
created_at: datetime = datetime.now(UTC)
|
||||
updated_at: datetime = datetime.now(UTC)
|
||||
|
||||
|
||||
class GenerationGroupResponse(BaseModel):
|
||||
generation_group_id: str
|
||||
generations: list[GenerationResponse]
|
||||
|
||||
|
||||
class PromptRequest(BaseModel):
|
||||
prompt: str
|
||||
model: TextModel = Field(default=TextModel.GEMINI_3_1_PRO_PREVIEW)
|
||||
linked_assets: list[str] = []
|
||||
linked_assets: List[str] = []
|
||||
|
||||
|
||||
class PromptResponse(BaseModel):
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
from pydantic import BaseModel
|
||||
from models.Idea import Idea
|
||||
from api.models.GenerationRequest import GenerationResponse
|
||||
|
||||
class IdeaCreateRequest(BaseModel):
|
||||
name: str
|
||||
description: str | None = None
|
||||
project_id: str | None = None # Optional in body if passed via header/dependency
|
||||
inspiration_id: str | None = None
|
||||
|
||||
class IdeaUpdateRequest(BaseModel):
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
inspiration_id: str | None = None
|
||||
|
||||
class IdeaResponse(Idea):
|
||||
last_generation: GenerationResponse | None = None
|
||||
@@ -1,28 +0,0 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from models.Inspiration import Inspiration
|
||||
|
||||
|
||||
class InspirationCreateRequest(BaseModel):
|
||||
source_url: str
|
||||
caption: str | None = None
|
||||
project_id: str | None = None
|
||||
|
||||
|
||||
class InspirationResponse(BaseModel):
|
||||
id: str
|
||||
source_url: str
|
||||
caption: str | None = None
|
||||
asset_id: str
|
||||
is_completed: bool
|
||||
created_by: str
|
||||
project_id: str | None = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class InspirationListResponse(BaseModel):
|
||||
inspirations: list[InspirationResponse]
|
||||
total_count: int
|
||||
@@ -1,18 +0,0 @@
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class PostCreateRequest(BaseModel):
|
||||
date: datetime
|
||||
topic: str
|
||||
generation_ids: list[str] = []
|
||||
project_id: str | None = None
|
||||
|
||||
|
||||
class PostUpdateRequest(BaseModel):
|
||||
date: datetime | None = None
|
||||
topic: str | None = None
|
||||
|
||||
|
||||
class AddGenerationsRequest(BaseModel):
|
||||
generation_ids: list[str]
|
||||
16
api/models/VideoGenerationRequest.py
Normal file
16
api/models/VideoGenerationRequest.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class VideoGenerationRequest(BaseModel):
|
||||
prompt: str = ""
|
||||
negative_prompt: Optional[str] = ""
|
||||
image_asset_id: str # ID ассета-картинки для source image
|
||||
duration: int = 5 # 5 or 10 seconds
|
||||
mode: str = "std" # "std" or "pro"
|
||||
model_name: str = "kling-v2-1"
|
||||
cfg_scale: float = 0.5
|
||||
aspect_ratio: str = "16:9"
|
||||
linked_character_id: Optional[str] = None
|
||||
project_id: Optional[str] = None
|
||||
@@ -1,7 +0,0 @@
|
||||
from .AssetDTO import AssetResponse, AssetsResponse
|
||||
from .CharacterDTO import CharacterCreateRequest, CharacterUpdateRequest
|
||||
from .ExternalGenerationDTO import ExternalGenerationRequest
|
||||
from .FinancialUsageDTO import FinancialReport, UsageStats, UsageByEntity
|
||||
from .GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse, PromptRequest, PromptResponse, NsfwRequest
|
||||
from .IdeaRequest import IdeaCreateRequest, IdeaUpdateRequest, IdeaResponse
|
||||
from .PostRequest import PostCreateRequest, PostUpdateRequest, AddGenerationsRequest
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,183 +1,245 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import random
|
||||
import base64
|
||||
from datetime import datetime, UTC
|
||||
from typing import List, Optional, Tuple, Any, Dict
|
||||
from uuid import uuid4
|
||||
|
||||
from io import BytesIO
|
||||
import httpx
|
||||
|
||||
from aiogram import Bot
|
||||
from aiogram.types import BufferedInputFile
|
||||
|
||||
from adapters.Exception import GoogleGenerationException
|
||||
from adapters.google_adapter import GoogleAdapter
|
||||
from adapters.s3_adapter import S3Adapter
|
||||
from api.models import (
|
||||
FinancialReport, UsageStats, UsageByEntity,
|
||||
GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse
|
||||
)
|
||||
from adapters.kling_adapter import KlingAdapter, KlingApiException
|
||||
from api.models.GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse
|
||||
from api.models.VideoGenerationRequest import VideoGenerationRequest
|
||||
# Импортируйте ваши модели DAO, Asset, Generation корректно
|
||||
from models.Asset import Asset, AssetType, AssetContentType
|
||||
from models.Generation import Generation, GenerationStatus
|
||||
from models.enums import AspectRatios, Quality
|
||||
from models.enums import AspectRatios, Quality, GenType
|
||||
from repos.dao import DAO
|
||||
from utils.image_utils import create_thumbnail
|
||||
from adapters.s3_adapter import S3Adapter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Limit concurrent generations to 4
|
||||
generation_semaphore = asyncio.Semaphore(4)
|
||||
|
||||
|
||||
# --- Вспомогательная функция генерации ---
|
||||
async def generate_image_task(
|
||||
prompt: str,
|
||||
media_group_bytes: List[bytes],
|
||||
aspect_ratio: AspectRatios,
|
||||
quality: Quality,
|
||||
model: str,
|
||||
gemini: GoogleAdapter,
|
||||
|
||||
) -> Tuple[List[bytes], Dict[str, Any]]:
|
||||
"""
|
||||
Wrapper for calling Gemini's synchronous method in a separate thread.
|
||||
Обертка для вызова синхронного метода Gemini в отдельном потоке.
|
||||
Возвращает список байтов сгенерированных изображений.
|
||||
"""
|
||||
try:
|
||||
try :
|
||||
logger.info(f"Starting generate_image_task with prompt length: {len(prompt)}")
|
||||
# Запускаем блокирующую операцию в отдельном потоке, чтобы не тормозить Event Loop
|
||||
result = await asyncio.to_thread(
|
||||
gemini.generate_image,
|
||||
prompt=prompt,
|
||||
images_list=media_group_bytes,
|
||||
aspect_ratio=aspect_ratio,
|
||||
quality=quality,
|
||||
model=model,
|
||||
)
|
||||
generated_images_io, metrics = result
|
||||
logger.info(f"generate_image_task completed, received {len(generated_images_io) if generated_images_io else 0} images")
|
||||
except GoogleGenerationException:
|
||||
raise
|
||||
finally:
|
||||
del media_group_bytes
|
||||
|
||||
logger.info(f"generate_image_task completed, received {len(generated_images_io) if generated_images_io else 0} images")
|
||||
except GoogleGenerationException as e:
|
||||
raise e
|
||||
images_bytes = []
|
||||
if generated_images_io:
|
||||
for img_io in generated_images_io:
|
||||
# Читаем байты из BytesIO
|
||||
img_io.seek(0)
|
||||
images_bytes.append(img_io.read())
|
||||
content = img_io.read()
|
||||
images_bytes.append(content)
|
||||
|
||||
# Закрываем поток
|
||||
img_io.close()
|
||||
del generated_images_io
|
||||
|
||||
return images_bytes, metrics
|
||||
|
||||
|
||||
class GenerationService:
|
||||
def __init__(self, dao: DAO, gemini: GoogleAdapter, s3_adapter: S3Adapter, bot: Optional[Bot] = None):
|
||||
def __init__(self, dao: DAO, gemini: GoogleAdapter, s3_adapter: S3Adapter, bot: Optional[Bot] = None, kling_adapter: Optional[KlingAdapter] = None):
|
||||
self.dao = dao
|
||||
self.gemini = gemini
|
||||
self.s3_adapter = s3_adapter
|
||||
self.bot = bot
|
||||
self.kling_adapter = kling_adapter
|
||||
|
||||
# --- Public API ---
|
||||
|
||||
async def ask_prompt_assistant(self, prompt: str, assets: list[str] | None = None, model: str = "gemini-3.1-pro-preview") -> str:
|
||||
future_prompt = (
|
||||
"You are an prompt-assistant. You improving user-entered prompts for image generation. "
|
||||
"User may upload reference image too. I will provide sources prompt entered by user. "
|
||||
"Understand user needs and generate best variation of prompt. ANSWER ONLY PROMPT STRING!!! "
|
||||
f"USER_ENTERED_PROMPT: {prompt}"
|
||||
)
|
||||
async def ask_prompt_assistant(self, prompt: str, assets: List[str] = None) -> str:
|
||||
future_prompt = """You are an prompt-assistant. You improving user-entered prompts for image generation. User may upload reference image too.
|
||||
I will provide sources prompt entered by user. Understand user needs and generate best variation of prompt.
|
||||
ANSWER ONLY PROMPT STRING!!! USER_ENTERED_PROMPT: """
|
||||
future_prompt += prompt
|
||||
assets_data = []
|
||||
if assets:
|
||||
if assets is not None:
|
||||
assets_db = await self.dao.assets.get_assets_by_ids(assets)
|
||||
assets_data.extend(asset.data for asset in assets_db if asset.data)
|
||||
|
||||
generated_prompt = await asyncio.to_thread(self.gemini.generate_text, future_prompt, model, assets_data)
|
||||
logger.info(f"Prompt Assistant: {generated_prompt}")
|
||||
assets_data.extend(asset.data for asset in assets_db)
|
||||
generated_prompt = await asyncio.to_thread(self.gemini.generate_text, future_prompt, assets_data)
|
||||
logger.info(future_prompt)
|
||||
logger.info(generated_prompt)
|
||||
return generated_prompt
|
||||
|
||||
async def generate_prompt_from_images(self, images: List[bytes], user_prompt: Optional[str] = None, model: str = "gemini-3.1-pro-preview") -> str:
|
||||
async def generate_prompt_from_images(self, images: List[bytes], user_prompt: Optional[str] = None) -> str:
|
||||
technical_prompt = "You are a prompt engineer. Describe this image in detail to create a stable diffusion using this image as reference. "
|
||||
if user_prompt:
|
||||
technical_prompt += f"User also provided this context: {user_prompt}. "
|
||||
|
||||
technical_prompt += "Provide ONLY the detailed prompt."
|
||||
|
||||
return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, model=model, images_list=images)
|
||||
return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, images_list=images)
|
||||
|
||||
async def get_generations(self, **kwargs) -> GenerationsResponse:
|
||||
current_user_id = kwargs.pop('current_user_id', None)
|
||||
generations = await self.dao.generations.get_generations(**kwargs)
|
||||
total_count = await self.dao.generations.count_generations(
|
||||
character_id=kwargs.get('character_id'),
|
||||
created_by=kwargs.get('created_by'),
|
||||
project_id=kwargs.get('project_id'),
|
||||
idea_id=kwargs.get('idea_id'),
|
||||
only_liked_by=kwargs.get('only_liked_by')
|
||||
)
|
||||
return GenerationsResponse(
|
||||
generations=[self._map_to_response(gen, current_user_id) for gen in generations],
|
||||
total_count=total_count
|
||||
)
|
||||
async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[
|
||||
Generation]:
|
||||
generations = await self.dao.generations.get_generations(character_id = character_id,limit=limit, offset=offset, created_by=user_id, project_id=project_id)
|
||||
total_count = await self.dao.generations.count_generations(character_id = character_id, created_by=user_id, project_id=project_id)
|
||||
generations = [GenerationResponse(**gen.model_dump()) for gen in generations]
|
||||
return GenerationsResponse(generations=generations, total_count=total_count)
|
||||
|
||||
async def get_generation(self, generation_id: str, current_user_id: Optional[str] = None) -> Optional[GenerationResponse]:
|
||||
async def get_generation(self, generation_id: str) -> Optional[GenerationResponse]:
|
||||
gen = await self.dao.generations.get_generation(generation_id)
|
||||
return self._map_to_response(gen, current_user_id) if gen else None
|
||||
|
||||
async def toggle_like(self, generation_id: str, user_id: str) -> bool | None:
|
||||
return await self.dao.generations.toggle_like(generation_id, user_id)
|
||||
|
||||
async def get_generations_by_group(self, group_id: str, current_user_id: Optional[str] = None) -> GenerationGroupResponse:
|
||||
generations = await self.dao.generations.get_generations_by_group(group_id)
|
||||
return GenerationGroupResponse(
|
||||
generation_group_id=group_id,
|
||||
generations=[self._map_to_response(gen, current_user_id) for gen in generations]
|
||||
)
|
||||
|
||||
def _map_to_response(self, gen: Generation, current_user_id: Optional[str] = None) -> GenerationResponse:
|
||||
res = GenerationResponse(**gen.model_dump())
|
||||
res.likes_count = len(gen.liked_by) if gen.liked_by else 0
|
||||
res.is_liked = current_user_id in gen.liked_by if current_user_id and gen.liked_by else False
|
||||
return res
|
||||
if gen is None:
|
||||
return None
|
||||
else:
|
||||
return GenerationResponse(**gen.model_dump())
|
||||
|
||||
async def get_running_generations(self, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
|
||||
return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING, created_by=user_id, project_id=project_id)
|
||||
|
||||
async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationGroupResponse:
|
||||
if generation_group_id is None:
|
||||
generation_group_id = str(uuid4())
|
||||
async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None) -> GenerationResponse:
|
||||
gen_id = None
|
||||
generation_model = None
|
||||
|
||||
results = []
|
||||
for _ in range(generation_request.count):
|
||||
gen_response = await self._create_single_generation(generation_request, user_id, generation_group_id)
|
||||
results.append(gen_response)
|
||||
return GenerationGroupResponse(generation_group_id=generation_group_id, generations=results)
|
||||
try:
|
||||
generation_model = Generation(**generation_request.model_dump())
|
||||
if user_id:
|
||||
generation_model.created_by = user_id
|
||||
|
||||
gen_id = await self.dao.generations.create_generation(generation_model)
|
||||
generation_model.id = gen_id
|
||||
|
||||
async def runner(gen):
|
||||
logger.info(f"Starting background generation task for ID: {gen.id}")
|
||||
try:
|
||||
await self.create_generation(gen)
|
||||
logger.info(f"Background generation task finished for ID: {gen.id}")
|
||||
except Exception:
|
||||
# если генерация уже пошла и упала — пометим FAILED
|
||||
try:
|
||||
db_gen = await self.dao.generations.get_generation(gen.id)
|
||||
db_gen.status = GenerationStatus.FAILED
|
||||
await self.dao.generations.update_generation(db_gen)
|
||||
except Exception:
|
||||
logger.exception("Failed to mark generation as FAILED")
|
||||
logger.exception("create_generation task failed")
|
||||
|
||||
asyncio.create_task(runner(generation_model))
|
||||
|
||||
return GenerationResponse(**generation_model.model_dump())
|
||||
|
||||
except Exception:
|
||||
# если не успели создать запись — нечего помечать
|
||||
if gen_id is not None:
|
||||
try:
|
||||
gen = await self.dao.generations.get_generation(gen_id)
|
||||
gen.status = GenerationStatus.FAILED
|
||||
await self.dao.generations.update_generation(gen)
|
||||
except Exception:
|
||||
logger.exception("Failed to mark generation as FAILED in create_generation_task")
|
||||
raise
|
||||
|
||||
async def create_generation(self, generation: Generation):
|
||||
start_time = datetime.now()
|
||||
logger.info(f"Processing generation {generation.id}. Character ID: {generation.linked_character_id}")
|
||||
|
||||
# 1. Prepare input
|
||||
media_group_bytes, generation_prompt = await self._prepare_generation_input(generation)
|
||||
# 2. Получаем ассеты-референсы (если они есть)
|
||||
reference_assets: List[Asset] = []
|
||||
media_group_bytes: List[bytes] = []
|
||||
generation_prompt = generation.prompt
|
||||
# generation_prompt = f"""
|
||||
|
||||
# 2. Run generation with progress simulation
|
||||
# Create detailed image of character in scene.
|
||||
|
||||
# SCENE DESCRIPTION: {generation.prompt}
|
||||
|
||||
# Rules:
|
||||
# - Integrate the character's appearance naturally into the scene description.
|
||||
# - Focus on lighting, texture, and composition.
|
||||
# """
|
||||
if generation.linked_character_id is not None:
|
||||
char_info = await self.dao.chars.get_character(generation.linked_character_id)
|
||||
if char_info is None:
|
||||
raise Exception(f"Character ID {generation.linked_character_id} not found")
|
||||
if generation.use_profile_image:
|
||||
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
|
||||
if avatar_asset:
|
||||
media_group_bytes.append(avatar_asset.data)
|
||||
# generation_prompt = generation_prompt.replace("$char_bio_inserted", f"1. CHARACTER BIO (Must be strictly followed): {char_info.character_bio}")
|
||||
|
||||
reference_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
|
||||
|
||||
# Извлекаем данные (bytes) из ассетов для отправки в Gemini
|
||||
for asset in reference_assets:
|
||||
if asset.content_type != AssetContentType.IMAGE:
|
||||
continue
|
||||
|
||||
img_data = None
|
||||
if asset.minio_object_name:
|
||||
img_data = await self.s3_adapter.get_file(asset.minio_object_name)
|
||||
elif asset.data:
|
||||
img_data = asset.data
|
||||
|
||||
if img_data:
|
||||
media_group_bytes.append(img_data)
|
||||
|
||||
if media_group_bytes:
|
||||
generation_prompt += " \n\n[Reference Image Guidance]: Use the provided image(s) as the STRICT reference for the main character's facial features and hair, enviroment or clothes. Maintain high fidelity to the reference identity."
|
||||
|
||||
logger.info(f"Final generation prompt assembled. Length: {len(generation_prompt)}. Media count: {len(media_group_bytes)}")
|
||||
|
||||
# 3. Запускаем процесс генерации и симуляцию прогресса
|
||||
progress_task = asyncio.create_task(self._simulate_progress(generation))
|
||||
|
||||
try:
|
||||
|
||||
# Default to Image Generation (Gemini)
|
||||
generated_bytes_list, metrics = await generate_image_task(
|
||||
prompt=generation_prompt,
|
||||
prompt=generation_prompt, # или request.prompt
|
||||
media_group_bytes=media_group_bytes,
|
||||
aspect_ratio=generation.aspect_ratio,
|
||||
aspect_ratio=generation.aspect_ratio, # предполагаем поля в request
|
||||
quality=generation.quality,
|
||||
model=generation.model or "gemini-3-pro-image-preview",
|
||||
gemini=self.gemini
|
||||
)
|
||||
self._update_generation_metrics(generation, metrics)
|
||||
|
||||
# 3. Process results
|
||||
created_assets = await self._process_generated_images(generation, generated_bytes_list)
|
||||
|
||||
# 4. Finalize generation record
|
||||
await self._finalize_generation(generation, created_assets, generation_prompt, start_time)
|
||||
# Update metrics from API (Common for both)
|
||||
generation.api_execution_time_seconds = metrics.get("api_execution_time_seconds")
|
||||
generation.token_usage = metrics.get("token_usage")
|
||||
generation.input_token_usage = metrics.get("input_token_usage")
|
||||
generation.output_token_usage = metrics.get("output_token_usage")
|
||||
|
||||
# 5. Notify
|
||||
if generation.telegram_id and self.bot:
|
||||
await self._notify_telegram(generation, created_assets)
|
||||
except GoogleGenerationException as e:
|
||||
generation.status = GenerationStatus.FAILED
|
||||
generation.failed_reason = str(e)
|
||||
generation.updated_at = datetime.now(UTC)
|
||||
await self.dao.generations.update_generation(generation)
|
||||
raise e
|
||||
except Exception as e:
|
||||
# Тут стоит добавить логирование ошибки
|
||||
logging.error(f"Generation failed: {e}")
|
||||
generation.status = GenerationStatus.FAILED
|
||||
generation.failed_reason = str(e)
|
||||
generation.updated_at = datetime.now(UTC)
|
||||
await self.dao.generations.update_generation(generation)
|
||||
raise e
|
||||
finally:
|
||||
if not progress_task.done():
|
||||
progress_task.cancel()
|
||||
@@ -186,53 +248,360 @@ class GenerationService:
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def import_external_generation(self, external_gen) -> Generation:
|
||||
external_gen.validate_image_source()
|
||||
logger.info(f"Importing external generation for user: {external_gen.created_by}")
|
||||
# 4. Сохраняем полученные изображения как новые Ассеты
|
||||
created_assets: List[Asset] = []
|
||||
|
||||
image_bytes = await self._fetch_external_image(external_gen)
|
||||
for idx, img_bytes in enumerate(generated_bytes_list):
|
||||
# Generate thumbnail
|
||||
thumbnail_bytes = None
|
||||
from utils.image_utils import create_thumbnail
|
||||
thumbnail_bytes = await asyncio.to_thread(create_thumbnail, img_bytes)
|
||||
|
||||
# Reuse internal processing logic
|
||||
new_asset = await self._save_asset(
|
||||
image_bytes=image_bytes,
|
||||
name=f"External_Generated_{external_gen.linked_character_id or 'no_char'}",
|
||||
created_by=external_gen.created_by,
|
||||
project_id=external_gen.project_id,
|
||||
linked_char_id=external_gen.linked_character_id,
|
||||
folder="external"
|
||||
# Save to S3
|
||||
filename = f"generated/{generation.linked_character_id}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.png"
|
||||
await self.s3_adapter.upload_file(filename, img_bytes, content_type="image/png")
|
||||
|
||||
new_asset = Asset(
|
||||
name=f"Generated_{generation.linked_character_id}",
|
||||
type=AssetType.GENERATED,
|
||||
content_type=AssetContentType.IMAGE,
|
||||
linked_char_id=generation.linked_character_id,
|
||||
data=None, # Not storing bytes in DB anymore
|
||||
minio_object_name=filename,
|
||||
minio_bucket=self.s3_adapter.bucket_name,
|
||||
thumbnail=thumbnail_bytes,
|
||||
created_by=generation.created_by,
|
||||
project_id=generation.project_id
|
||||
)
|
||||
|
||||
# Сохраняем в БД
|
||||
asset_id = await self.dao.assets.create_asset(new_asset)
|
||||
new_asset.id = str(asset_id) # Присваиваем ID, полученный от базы
|
||||
|
||||
created_assets.append(new_asset)
|
||||
|
||||
# 5. (Опционально) Обновляем запись генерации ссылками на результаты
|
||||
# Предполагаем, что у модели Generation есть поле result_asset_ids
|
||||
result_ids = [a.id for a in created_assets]
|
||||
|
||||
generation.result_list = result_ids
|
||||
generation.status = GenerationStatus.DONE
|
||||
generation.progress = 100
|
||||
generation.updated_at = datetime.now(UTC)
|
||||
generation.tech_prompt = generation_prompt
|
||||
|
||||
end_time = datetime.now()
|
||||
generation.execution_time_seconds = (end_time - start_time).total_seconds()
|
||||
|
||||
logger.info(f"DEBUG: Saving generation {generation.id}. Metrics: api_exec={generation.api_execution_time_seconds}, tokens={generation.token_usage}, in_tokens={generation.input_token_usage}, out_tokens={generation.output_token_usage}, exec={generation.execution_time_seconds}")
|
||||
|
||||
await self.dao.generations.update_generation(generation)
|
||||
logger.info(f"Generation {generation.id} completed successfully. {len(created_assets)} assets created. Total Time: {generation.execution_time_seconds:.2f}s")
|
||||
|
||||
# 6. Send to Telegram if telegram_id is provided
|
||||
if generation.telegram_id and self.bot:
|
||||
try:
|
||||
for asset in created_assets:
|
||||
if asset.data:
|
||||
await self.bot.send_photo(
|
||||
chat_id=generation.telegram_id,
|
||||
photo=BufferedInputFile(asset.data, filename=f"{asset.name}.jpg"),
|
||||
caption=f"Generated from prompt: {generation.prompt[:100]}..."
|
||||
)
|
||||
logger.info(f"Sent {len(created_assets)} assets to Telegram ID: {generation.telegram_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send assets to Telegram ID {generation.telegram_id}: {e}")
|
||||
|
||||
|
||||
async def _simulate_progress(self, generation: Generation):
|
||||
"""
|
||||
Increments progress from 0 to 90 over ~20 seconds.
|
||||
"""
|
||||
current_progress = 0
|
||||
try:
|
||||
while current_progress < 90:
|
||||
await asyncio.sleep(4)
|
||||
# Random increment between 5 and 15
|
||||
increment = random.randint(5, 15)
|
||||
current_progress = min(current_progress + increment, 90)
|
||||
|
||||
# Fetch latest state (optional, but good practice to avoid overwriting unrelated fields)
|
||||
# But for simplicity here we just use the object we have and save it.
|
||||
# Ideally, we should fetch-update-save or use partial update if DAO supports it.
|
||||
# Assuming simple update is fine for now.
|
||||
generation.progress = current_progress
|
||||
await self.dao.generations.update_generation(generation)
|
||||
except asyncio.CancelledError:
|
||||
# Task cancelled, generation finished (or failed)
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Error in progress simulation: {e}")
|
||||
|
||||
|
||||
|
||||
|
||||
async def import_external_generation(self, external_gen) -> Generation:
|
||||
"""
|
||||
Import a generation from an external source.
|
||||
|
||||
Args:
|
||||
external_gen: ExternalGenerationRequest with generation data and image
|
||||
|
||||
Returns:
|
||||
Created Generation object
|
||||
"""
|
||||
from api.models.ExternalGenerationDTO import ExternalGenerationRequest
|
||||
|
||||
# Validate image source
|
||||
external_gen.validate_image_source()
|
||||
|
||||
logger.info(f"Importing external generation for user: {external_gen.created_by}")
|
||||
|
||||
# 1. Process image (download or decode)
|
||||
image_bytes = None
|
||||
|
||||
if external_gen.image_url:
|
||||
# Download image from URL
|
||||
logger.info(f"Downloading image from URL: {external_gen.image_url}")
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(external_gen.image_url, timeout=30.0)
|
||||
response.raise_for_status()
|
||||
image_bytes = response.content
|
||||
elif external_gen.image_data:
|
||||
# Decode base64 image
|
||||
logger.info("Decoding base64 image data")
|
||||
image_bytes = base64.b64decode(external_gen.image_data)
|
||||
|
||||
if not image_bytes:
|
||||
raise ValueError("Failed to process image data")
|
||||
|
||||
# 2. Generate thumbnail
|
||||
from utils.image_utils import create_thumbnail
|
||||
thumbnail_bytes = await asyncio.to_thread(create_thumbnail, image_bytes)
|
||||
|
||||
# 3. Save to S3
|
||||
filename = f"external/{external_gen.created_by}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.png"
|
||||
await self.s3_adapter.upload_file(filename, image_bytes, content_type="image/png")
|
||||
|
||||
# 4. Create Asset
|
||||
new_asset = Asset(
|
||||
name=f"External_Generated_{external_gen.linked_character_id or 'no_char'}",
|
||||
type=AssetType.GENERATED,
|
||||
content_type=AssetContentType.IMAGE,
|
||||
linked_char_id=external_gen.linked_character_id,
|
||||
data=None, # Not storing bytes in DB
|
||||
minio_object_name=filename,
|
||||
minio_bucket=self.s3_adapter.bucket_name,
|
||||
thumbnail=thumbnail_bytes,
|
||||
created_by=external_gen.created_by,
|
||||
project_id=external_gen.project_id
|
||||
)
|
||||
|
||||
asset_id = await self.dao.assets.create_asset(new_asset)
|
||||
new_asset.id = str(asset_id)
|
||||
|
||||
logger.info(f"Created asset {asset_id} for external generation")
|
||||
|
||||
# 5. Create Generation record
|
||||
generation = Generation(
|
||||
status=GenerationStatus.DONE,
|
||||
linked_character_id=external_gen.linked_character_id,
|
||||
aspect_ratio=external_gen.aspect_ratio,
|
||||
quality=external_gen.quality,
|
||||
prompt=external_gen.prompt,
|
||||
model=external_gen.model,
|
||||
tech_prompt=external_gen.tech_prompt,
|
||||
seed=external_gen.seed,
|
||||
result_list=[new_asset.id],
|
||||
result=new_asset.id,
|
||||
progress=100,
|
||||
nsfw=external_gen.nsfw,
|
||||
execution_time_seconds=external_gen.execution_time_seconds,
|
||||
api_execution_time_seconds=external_gen.api_execution_time_seconds,
|
||||
token_usage=external_gen.token_usage,
|
||||
input_token_usage=external_gen.input_token_usage,
|
||||
output_token_usage=external_gen.output_token_usage,
|
||||
created_by=external_gen.created_by,
|
||||
project_id=external_gen.project_id
|
||||
project_id=external_gen.project_id,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC)
|
||||
)
|
||||
|
||||
gen_id = await self.dao.generations.create_generation(generation)
|
||||
generation.id = gen_id
|
||||
|
||||
logger.info(f"Created generation {gen_id} from external source")
|
||||
|
||||
return generation
|
||||
|
||||
# === VIDEO GENERATION (Kling) ===
|
||||
|
||||
async def create_video_generation_task(self, request: VideoGenerationRequest, user_id: Optional[str] = None) -> GenerationResponse:
|
||||
"""Create a video generation task (async, returns immediately)."""
|
||||
if not self.kling_adapter:
|
||||
raise Exception("Kling adapter is not configured")
|
||||
|
||||
generation = Generation(
|
||||
status=GenerationStatus.RUNNING,
|
||||
gen_type=GenType.VIDEO,
|
||||
linked_character_id=request.linked_character_id,
|
||||
aspect_ratio=AspectRatios.SIXTEENNINE, # default for video
|
||||
quality=Quality.ONEK,
|
||||
prompt=request.prompt,
|
||||
assets_list=[request.image_asset_id],
|
||||
video_duration=request.duration,
|
||||
video_mode=request.mode,
|
||||
project_id=request.project_id,
|
||||
)
|
||||
if user_id:
|
||||
generation.created_by = user_id
|
||||
|
||||
gen_id = await self.dao.generations.create_generation(generation)
|
||||
generation.id = gen_id
|
||||
|
||||
async def runner(gen, req):
|
||||
logger.info(f"Starting background video generation task for ID: {gen.id}")
|
||||
try:
|
||||
await self.create_video_generation(gen, req)
|
||||
logger.info(f"Background video generation task finished for ID: {gen.id}")
|
||||
except Exception:
|
||||
try:
|
||||
db_gen = await self.dao.generations.get_generation(gen.id)
|
||||
if db_gen and db_gen.status != GenerationStatus.FAILED:
|
||||
db_gen.status = GenerationStatus.FAILED
|
||||
db_gen.updated_at = datetime.now(UTC)
|
||||
await self.dao.generations.update_generation(db_gen)
|
||||
except Exception:
|
||||
logger.exception("Failed to mark video generation as FAILED")
|
||||
logger.exception("create_video_generation task failed")
|
||||
|
||||
asyncio.create_task(runner(generation, request))
|
||||
return GenerationResponse(**generation.model_dump())
|
||||
|
||||
async def create_video_generation(self, generation: Generation, request: VideoGenerationRequest):
|
||||
"""Background video generation: call Kling API, poll, download result, save asset."""
|
||||
start_time = datetime.now()
|
||||
|
||||
try:
|
||||
# 1. Get source image presigned URL
|
||||
asset = await self.dao.assets.get_asset(request.image_asset_id)
|
||||
if not asset:
|
||||
raise Exception(f"Asset {request.image_asset_id} not found")
|
||||
|
||||
if not asset.minio_object_name:
|
||||
raise Exception(f"Asset {request.image_asset_id} has no S3 object")
|
||||
|
||||
presigned_url = await self.s3_adapter.get_presigned_url(asset.minio_object_name, expiration=3600)
|
||||
if not presigned_url:
|
||||
raise Exception("Failed to generate presigned URL for source image")
|
||||
|
||||
logger.info(f"Video gen {generation.id}: got presigned URL for asset {request.image_asset_id}")
|
||||
|
||||
# 2. Create Kling task
|
||||
task_data = await self.kling_adapter.create_video_task(
|
||||
image_url=presigned_url,
|
||||
prompt=request.prompt,
|
||||
negative_prompt=request.negative_prompt or "",
|
||||
model_name=request.model_name,
|
||||
duration=request.duration,
|
||||
mode=request.mode,
|
||||
cfg_scale=request.cfg_scale,
|
||||
aspect_ratio=request.aspect_ratio,
|
||||
)
|
||||
|
||||
task_id = task_data.get("task_id")
|
||||
generation.kling_task_id = task_id
|
||||
await self.dao.generations.update_generation(generation)
|
||||
|
||||
logger.info(f"Video gen {generation.id}: Kling task created, task_id={task_id}")
|
||||
|
||||
# 3. Poll for completion with progress updates
|
||||
async def progress_callback(progress_pct: int):
|
||||
generation.progress = progress_pct
|
||||
generation.updated_at = datetime.now(UTC)
|
||||
await self.dao.generations.update_generation(generation)
|
||||
|
||||
result = await self.kling_adapter.wait_for_completion(
|
||||
task_id=task_id,
|
||||
poll_interval=10,
|
||||
timeout=600,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
# 4. Extract video URL and download
|
||||
works = result.get("task_result", {}).get("videos", [])
|
||||
if not works:
|
||||
raise Exception("No video in Kling result")
|
||||
|
||||
video_url = works[0].get("url")
|
||||
video_duration = works[0].get("duration", request.duration)
|
||||
if not video_url:
|
||||
raise Exception("No video URL in Kling result")
|
||||
|
||||
logger.info(f"Video gen {generation.id}: downloading video from {video_url}")
|
||||
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
video_response = await client.get(video_url)
|
||||
video_response.raise_for_status()
|
||||
video_bytes = video_response.content
|
||||
|
||||
logger.info(f"Video gen {generation.id}: downloaded {len(video_bytes)} bytes")
|
||||
|
||||
# 5. Upload to S3
|
||||
filename = f"generated_video/{generation.linked_character_id or 'no_char'}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.mp4"
|
||||
await self.s3_adapter.upload_file(filename, video_bytes, content_type="video/mp4")
|
||||
|
||||
# 6. Create Asset
|
||||
new_asset = Asset(
|
||||
name=f"Video_{generation.linked_character_id or 'gen'}",
|
||||
type=AssetType.GENERATED,
|
||||
content_type=AssetContentType.VIDEO,
|
||||
linked_char_id=generation.linked_character_id,
|
||||
data=None,
|
||||
minio_object_name=filename,
|
||||
minio_bucket=self.s3_adapter.bucket_name,
|
||||
thumbnail=None, # видео thumbnails можно добавить позже
|
||||
created_by=generation.created_by,
|
||||
project_id=generation.project_id,
|
||||
)
|
||||
|
||||
asset_id = await self.dao.assets.create_asset(new_asset)
|
||||
new_asset.id = str(asset_id)
|
||||
|
||||
# 7. Finalize generation
|
||||
end_time = datetime.now()
|
||||
generation.result_list = [new_asset.id]
|
||||
generation.result = new_asset.id
|
||||
generation.status = GenerationStatus.DONE
|
||||
generation.progress = 100
|
||||
generation.video_duration = video_duration
|
||||
generation.execution_time_seconds = (end_time - start_time).total_seconds()
|
||||
generation.updated_at = datetime.now(UTC)
|
||||
await self.dao.generations.update_generation(generation)
|
||||
|
||||
logger.info(f"Video generation {generation.id} completed. Asset: {new_asset.id}, Time: {generation.execution_time_seconds:.1f}s")
|
||||
|
||||
except KlingApiException as e:
|
||||
logger.error(f"Kling API error for generation {generation.id}: {e}")
|
||||
generation.status = GenerationStatus.FAILED
|
||||
generation.failed_reason = str(e)
|
||||
generation.updated_at = datetime.now(UTC)
|
||||
await self.dao.generations.update_generation(generation)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Video generation {generation.id} failed: {e}")
|
||||
generation.status = GenerationStatus.FAILED
|
||||
generation.failed_reason = str(e)
|
||||
generation.updated_at = datetime.now(UTC)
|
||||
await self.dao.generations.update_generation(generation)
|
||||
raise
|
||||
|
||||
async def delete_generation(self, generation_id: str) -> bool:
|
||||
"""
|
||||
Soft delete generation by marking it as deleted.
|
||||
"""
|
||||
try:
|
||||
generation = await self.dao.generations.get_generation(generation_id)
|
||||
if not generation:
|
||||
return False
|
||||
|
||||
generation.is_deleted = True
|
||||
generation.updated_at = datetime.now(UTC)
|
||||
await self.dao.generations.update_generation(generation)
|
||||
@@ -240,207 +609,3 @@ class GenerationService:
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting generation {generation_id}: {e}")
|
||||
return False
|
||||
|
||||
async def cleanup_stale_generations(self):
|
||||
try:
|
||||
count = await self.dao.generations.cancel_stale_generations(timeout_minutes=5)
|
||||
if count > 0:
|
||||
logger.info(f"Cleaned up {count} stale generations")
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up stale generations: {e}")
|
||||
|
||||
async def cleanup_old_data(self, days: int = 30):
|
||||
try:
|
||||
gen_count, asset_ids = await self.dao.generations.soft_delete_old_generations(days=days)
|
||||
if gen_count > 0:
|
||||
logger.info(f"Soft-deleted {gen_count} generations. Purging {len(asset_ids)} assets.")
|
||||
if asset_ids:
|
||||
await self.dao.assets.soft_delete_and_purge_assets(asset_ids)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during old data cleanup: {e}")
|
||||
|
||||
async def get_financial_report(self, user_id: Optional[str] = None, project_id: Optional[str] = None, breakdown_by: Optional[str] = None) -> FinancialReport:
|
||||
summary_data = await self.dao.generations.get_usage_stats(created_by=user_id, project_id=project_id)
|
||||
summary = UsageStats(**summary_data)
|
||||
|
||||
by_user, by_project = None, None
|
||||
if breakdown_by == "created_by":
|
||||
res = await self.dao.generations.get_usage_breakdown(group_by="created_by", project_id=project_id, created_by=user_id)
|
||||
by_user = [UsageByEntity(**item) for item in res]
|
||||
if breakdown_by == "project_id":
|
||||
res = await self.dao.generations.get_usage_breakdown(group_by="project_id", project_id=project_id, created_by=user_id)
|
||||
by_project = [UsageByEntity(**item) for item in res]
|
||||
|
||||
return FinancialReport(summary=summary, by_user=by_user, by_project=by_project)
|
||||
|
||||
# --- Private Helpers ---
|
||||
|
||||
async def _create_single_generation(self, generation_request: GenerationRequest, user_id: Optional[str], generation_group_id: str) -> GenerationResponse:
|
||||
try:
|
||||
gen_model = Generation(**generation_request.model_dump(exclude={'count'}))
|
||||
gen_model.created_by = user_id
|
||||
gen_model.generation_group_id = generation_group_id
|
||||
|
||||
gen_id = await self.dao.generations.create_generation(gen_model)
|
||||
gen_model.id = gen_id
|
||||
|
||||
asyncio.create_task(self._queued_generation_runner(gen_model))
|
||||
return GenerationResponse(**gen_model.model_dump())
|
||||
except Exception:
|
||||
logger.exception("Failed to initiate single generation")
|
||||
raise
|
||||
|
||||
async def _queued_generation_runner(self, gen: Generation):
|
||||
logger.info(f"Generation {gen.id} waiting for slot...")
|
||||
try:
|
||||
async with generation_semaphore:
|
||||
await self.create_generation(gen)
|
||||
except Exception as e:
|
||||
await self._handle_generation_failure(gen, e)
|
||||
logger.exception(f"Background generation task failed for ID: {gen.id}")
|
||||
|
||||
async def _prepare_generation_input(self, generation: Generation) -> Tuple[List[bytes], str]:
|
||||
media_group_bytes: List[bytes] = []
|
||||
prompt = generation.prompt
|
||||
|
||||
# 1. Character Avatar
|
||||
if generation.linked_character_id:
|
||||
char_info = await self.dao.chars.get_character(generation.linked_character_id)
|
||||
if not char_info:
|
||||
raise ValueError(f"Character {generation.linked_character_id} not found")
|
||||
|
||||
if generation.use_profile_image and char_info.avatar_asset_id:
|
||||
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
|
||||
if avatar_asset:
|
||||
data = await self._get_asset_data_bytes(avatar_asset)
|
||||
if data: media_group_bytes.append(data)
|
||||
|
||||
# 2. Reference Assets
|
||||
if generation.assets_list:
|
||||
assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
|
||||
for asset in assets:
|
||||
data = await self._get_asset_data_bytes(asset)
|
||||
if data: media_group_bytes.append(data)
|
||||
|
||||
# 3. Environment Assets
|
||||
if generation.environment_id:
|
||||
env = await self.dao.environments.get_env(generation.environment_id)
|
||||
if env and env.asset_ids:
|
||||
env_assets = await self.dao.assets.get_assets_by_ids(env.asset_ids)
|
||||
for asset in env_assets:
|
||||
data = await self._get_asset_data_bytes(asset)
|
||||
if data: media_group_bytes.append(data)
|
||||
|
||||
if media_group_bytes:
|
||||
prompt += (
|
||||
" \n\n[Reference Image Guidance]: Use the provided image(s) as the STRICT reference for the main "
|
||||
"character's facial features and hair, environment or clothes. Maintain high fidelity to the reference identity."
|
||||
)
|
||||
|
||||
return media_group_bytes, prompt
|
||||
|
||||
async def _get_asset_data_bytes(self, asset: Asset) -> Optional[bytes]:
|
||||
if asset.content_type != AssetContentType.IMAGE:
|
||||
return None
|
||||
if asset.minio_object_name:
|
||||
return await self.s3_adapter.get_file(asset.minio_object_name)
|
||||
return asset.data
|
||||
|
||||
def _update_generation_metrics(self, generation: Generation, metrics: Dict[str, Any]):
|
||||
generation.api_execution_time_seconds = metrics.get("api_execution_time_seconds")
|
||||
generation.token_usage = metrics.get("token_usage")
|
||||
generation.input_token_usage = metrics.get("input_token_usage")
|
||||
generation.output_token_usage = metrics.get("output_token_usage")
|
||||
|
||||
async def _handle_generation_failure(self, generation: Generation, error: Optional[Exception]):
|
||||
logger.error(f"Generation {generation.id} failed: {error}")
|
||||
generation.status = GenerationStatus.FAILED
|
||||
# Don't overwrite if reason is already set, unless a new error is provided
|
||||
if error:
|
||||
generation.failed_reason = str(error)
|
||||
elif not generation.failed_reason:
|
||||
generation.failed_reason = "Unknown error"
|
||||
|
||||
generation.updated_at = datetime.now(UTC)
|
||||
await self.dao.generations.update_generation(generation)
|
||||
|
||||
async def _process_generated_images(self, generation: Generation, bytes_list: List[bytes]) -> List[Asset]:
|
||||
created_assets = []
|
||||
for img_bytes in bytes_list:
|
||||
asset = await self._save_asset(
|
||||
image_bytes=img_bytes,
|
||||
name=f"Generated_{generation.linked_character_id}",
|
||||
created_by=generation.created_by,
|
||||
project_id=generation.project_id,
|
||||
linked_char_id=generation.linked_character_id,
|
||||
folder="generated"
|
||||
)
|
||||
created_assets.append(asset)
|
||||
return created_assets
|
||||
|
||||
async def _save_asset(self, image_bytes: bytes, name: str, created_by: str, project_id: str, linked_char_id: str, folder: str) -> Asset:
|
||||
thumbnail_bytes = await asyncio.to_thread(create_thumbnail, image_bytes)
|
||||
filename = f"{folder}/{linked_char_id}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.png"
|
||||
|
||||
await self.s3_adapter.upload_file(filename, image_bytes, content_type="image/png")
|
||||
|
||||
new_asset = Asset(
|
||||
name=name,
|
||||
type=AssetType.GENERATED,
|
||||
content_type=AssetContentType.IMAGE,
|
||||
linked_char_id=linked_char_id,
|
||||
data=None,
|
||||
minio_object_name=filename,
|
||||
minio_bucket=self.s3_adapter.bucket_name,
|
||||
thumbnail=thumbnail_bytes,
|
||||
created_by=created_by,
|
||||
project_id=project_id
|
||||
)
|
||||
asset_id = await self.dao.assets.create_asset(new_asset)
|
||||
new_asset.id = str(asset_id)
|
||||
return new_asset
|
||||
|
||||
async def _finalize_generation(self, generation: Generation, assets: List[Asset], tech_prompt: str, start_time: datetime):
|
||||
generation.result_list = [a.id for a in assets]
|
||||
generation.status = GenerationStatus.DONE
|
||||
generation.progress = 100
|
||||
generation.updated_at = datetime.now(UTC)
|
||||
generation.tech_prompt = tech_prompt
|
||||
generation.execution_time_seconds = (datetime.now() - start_time).total_seconds()
|
||||
await self.dao.generations.update_generation(generation)
|
||||
logger.info(f"Generation {generation.id} finalized. Time: {generation.execution_time_seconds:.2f}s")
|
||||
|
||||
async def _notify_telegram(self, generation: Generation, assets: List[Asset]):
|
||||
try:
|
||||
for asset in assets:
|
||||
# Need to get data for telegram if it's not in Asset object
|
||||
img_data = await self.s3_adapter.get_file(asset.minio_object_name) if asset.minio_object_name else asset.data
|
||||
if img_data:
|
||||
await self.bot.send_photo(
|
||||
chat_id=generation.telegram_id,
|
||||
photo=BufferedInputFile(img_data, filename=f"{asset.name}.png"),
|
||||
caption=f"Generated from: {generation.prompt[:100]}..."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send to Telegram: {e}")
|
||||
|
||||
async def _simulate_progress(self, generation: Generation):
|
||||
current_progress = 0
|
||||
try:
|
||||
while current_progress < 90:
|
||||
await asyncio.sleep(4)
|
||||
current_progress = min(current_progress + random.randint(5, 15), 90)
|
||||
generation.progress = current_progress
|
||||
await self.dao.generations.update_generation(generation)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def _fetch_external_image(self, external_gen) -> bytes:
|
||||
if external_gen.image_url:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(external_gen.image_url, timeout=30.0)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
elif external_gen.image_data:
|
||||
return base64.b64decode(external_gen.image_data)
|
||||
raise ValueError("No image source provided")
|
||||
|
||||
@@ -1,82 +0,0 @@
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from repos.dao import DAO
|
||||
from models.Idea import Idea
|
||||
|
||||
class IdeaService:
|
||||
def __init__(self, dao: DAO):
|
||||
self.dao = dao
|
||||
|
||||
async def create_idea(self, name: str, description: Optional[str], project_id: Optional[str], user_id: str, inspiration_id: Optional[str] = None) -> Idea:
|
||||
idea = Idea(
|
||||
name=name,
|
||||
description=description,
|
||||
project_id=project_id,
|
||||
created_by=user_id,
|
||||
inspiration_id=inspiration_id
|
||||
)
|
||||
idea_id = await self.dao.ideas.create_idea(idea)
|
||||
idea.id = idea_id
|
||||
return idea
|
||||
|
||||
async def get_ideas(self, project_id: Optional[str], user_id: str, limit: int = 20, offset: int = 0) -> List[dict]:
|
||||
return await self.dao.ideas.get_ideas(project_id, user_id, limit, offset)
|
||||
|
||||
async def get_idea(self, idea_id: str) -> Optional[Idea]:
|
||||
return await self.dao.ideas.get_idea(idea_id)
|
||||
|
||||
async def update_idea(self, idea_id: str, name: Optional[str] = None, description: Optional[str] = None, inspiration_id: Optional[str] = None) -> Optional[Idea]:
|
||||
idea = await self.dao.ideas.get_idea(idea_id)
|
||||
if not idea:
|
||||
return None
|
||||
|
||||
if name is not None:
|
||||
idea.name = name
|
||||
if description is not None:
|
||||
idea.description = description
|
||||
if inspiration_id is not None:
|
||||
idea.inspiration_id = inspiration_id
|
||||
|
||||
idea.updated_at = datetime.now()
|
||||
await self.dao.ideas.update_idea(idea)
|
||||
return idea
|
||||
|
||||
async def delete_idea(self, idea_id: str) -> bool:
|
||||
return await self.dao.ideas.delete_idea(idea_id)
|
||||
|
||||
async def add_generation_to_idea(self, idea_id: str, generation_id: str) -> bool:
|
||||
# Verify idea exists
|
||||
idea = await self.dao.ideas.get_idea(idea_id)
|
||||
if not idea:
|
||||
return False
|
||||
|
||||
# Get generation
|
||||
gen = await self.dao.generations.get_generation(generation_id)
|
||||
if not gen:
|
||||
return False
|
||||
|
||||
# Link
|
||||
gen.idea_id = idea_id
|
||||
gen.updated_at = datetime.now()
|
||||
await self.dao.generations.update_generation(gen)
|
||||
return True
|
||||
|
||||
async def remove_generation_from_idea(self, idea_id: str, generation_id: str) -> bool:
|
||||
# Verify idea exists (optional, but good for validation)
|
||||
idea = await self.dao.ideas.get_idea(idea_id)
|
||||
if not idea:
|
||||
return False
|
||||
|
||||
# Get generation
|
||||
gen = await self.dao.generations.get_generation(generation_id)
|
||||
if not gen:
|
||||
return False
|
||||
|
||||
# Unlink only if currently linked to this idea
|
||||
if gen.idea_id == idea_id:
|
||||
gen.idea_id = None
|
||||
gen.updated_at = datetime.now()
|
||||
await self.dao.generations.update_generation(gen)
|
||||
return True
|
||||
|
||||
return False
|
||||
@@ -1,146 +0,0 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
from fastapi import HTTPException
|
||||
|
||||
from models.Asset import Asset, AssetType, AssetContentType
|
||||
from models.Inspiration import Inspiration
|
||||
from repos.dao import DAO
|
||||
from adapters.s3_adapter import S3Adapter
|
||||
|
||||
# Try to import yt_dlp, but don't crash if it's missing (though we added it to requirements)
|
||||
try:
|
||||
import yt_dlp
|
||||
except ImportError:
|
||||
yt_dlp = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class InspirationService:
|
||||
def __init__(self, dao: DAO, s3_adapter: S3Adapter):
|
||||
self.dao = dao
|
||||
self.s3_adapter = s3_adapter
|
||||
|
||||
async def create_inspiration(self, source_url: str, created_by: str, project_id: Optional[str] = None, caption: Optional[str] = None) -> Inspiration:
|
||||
# 1. Download content from Instagram
|
||||
try:
|
||||
content_bytes, content_type, ext = await self._download_content(source_url)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download content from {source_url}: {e}")
|
||||
raise HTTPException(status_code=400, detail=f"Failed to download content: {str(e)}")
|
||||
|
||||
# 2. Save as Asset
|
||||
filename = f"inspirations/{datetime.now().strftime('%Y%m%d_%H%M%S')}_insta.{ext}"
|
||||
|
||||
await self.s3_adapter.upload_file(filename, content_bytes, content_type=content_type)
|
||||
|
||||
asset = Asset(
|
||||
name=f"Inspiration from {source_url}",
|
||||
type=AssetType.INSPIRATION,
|
||||
content_type=AssetContentType.VIDEO if content_type.startswith("video") else AssetContentType.IMAGE,
|
||||
minio_object_name=filename,
|
||||
minio_bucket=self.s3_adapter.bucket_name,
|
||||
created_by=created_by,
|
||||
project_id=project_id
|
||||
)
|
||||
asset_id = await self.dao.assets.create_asset(asset)
|
||||
|
||||
# 3. Create Inspiration object
|
||||
inspiration = Inspiration(
|
||||
source_url=source_url,
|
||||
caption=caption,
|
||||
asset_id=str(asset_id),
|
||||
created_by=created_by,
|
||||
project_id=project_id
|
||||
)
|
||||
insp_id = await self.dao.inspirations.create_inspiration(inspiration)
|
||||
inspiration.id = insp_id
|
||||
|
||||
return inspiration
|
||||
|
||||
async def get_inspirations(self, project_id: Optional[str], created_by: str, limit: int = 20, offset: int = 0) -> List[Inspiration]:
|
||||
return await self.dao.inspirations.get_inspirations(project_id, created_by, limit, offset)
|
||||
|
||||
async def get_inspiration(self, inspiration_id: str) -> Optional[Inspiration]:
|
||||
return await self.dao.inspirations.get_inspiration(inspiration_id)
|
||||
|
||||
async def mark_as_completed(self, inspiration_id: str, is_completed: bool = True) -> Optional[Inspiration]:
|
||||
inspiration = await self.dao.inspirations.get_inspiration(inspiration_id)
|
||||
if not inspiration:
|
||||
return None
|
||||
|
||||
inspiration.is_completed = is_completed
|
||||
inspiration.updated_at = datetime.now()
|
||||
await self.dao.inspirations.update_inspiration(inspiration)
|
||||
return inspiration
|
||||
|
||||
async def delete_inspiration(self, inspiration_id: str) -> bool:
|
||||
inspiration = await self.dao.inspirations.get_inspiration(inspiration_id)
|
||||
if not inspiration:
|
||||
return False
|
||||
|
||||
# Delete associated asset
|
||||
if inspiration.asset_id:
|
||||
await self.dao.assets.delete_asset(inspiration.asset_id)
|
||||
|
||||
return await self.dao.inspirations.delete_inspiration(inspiration_id)
|
||||
|
||||
async def _download_content(self, url: str) -> Tuple[bytes, str, str]:
|
||||
"""
|
||||
Downloads content using yt-dlp.
|
||||
Returns (content_bytes, content_type, extension)
|
||||
"""
|
||||
if not yt_dlp:
|
||||
raise RuntimeError("yt-dlp is not installed")
|
||||
|
||||
logger.info(f"Downloading from {url} using yt-dlp...")
|
||||
|
||||
def run_yt_dlp():
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
ydl_opts = {
|
||||
'outtmpl': f'{tmpdirname}/%(id)s.%(ext)s',
|
||||
'quiet': True,
|
||||
'no_warnings': True,
|
||||
'format': 'best', # Best quality single file
|
||||
'noplaylist': True, # Only single video if it's a playlist/profile
|
||||
'writethumbnail': False,
|
||||
'writesubtitles': False,
|
||||
}
|
||||
|
||||
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
|
||||
ydl.download([url])
|
||||
|
||||
# Find the downloaded file
|
||||
files = os.listdir(tmpdirname)
|
||||
if not files:
|
||||
raise Exception("No files downloaded")
|
||||
|
||||
# Pick the largest file if multiple (e.g. if yt-dlp downloaded parts)
|
||||
# But with 'format': 'best', it should be one.
|
||||
# If carousel, it might be multiple. Let's pick the first one.
|
||||
filename = files[0]
|
||||
filepath = os.path.join(tmpdirname, filename)
|
||||
|
||||
with open(filepath, 'rb') as f:
|
||||
data = f.read()
|
||||
|
||||
ext = filename.split('.')[-1].lower()
|
||||
|
||||
# Determine content type
|
||||
if ext in ['mp4', 'mov', 'avi', 'mkv', 'webm']:
|
||||
content_type = f"video/{ext}"
|
||||
if ext == 'mov': content_type = "video/quicktime"
|
||||
elif ext in ['jpg', 'jpeg', 'png', 'webp']:
|
||||
content_type = f"image/{ext}"
|
||||
if ext == 'jpg': content_type = "image/jpeg"
|
||||
else:
|
||||
content_type = "application/octet-stream"
|
||||
|
||||
return data, content_type, ext
|
||||
|
||||
return await asyncio.to_thread(run_yt_dlp)
|
||||
@@ -1,79 +0,0 @@
|
||||
from typing import List, Optional
|
||||
from datetime import datetime, UTC
|
||||
|
||||
from repos.dao import DAO
|
||||
from models.Post import Post
|
||||
|
||||
|
||||
class PostService:
|
||||
def __init__(self, dao: DAO):
|
||||
self.dao = dao
|
||||
|
||||
async def create_post(
|
||||
self,
|
||||
date: datetime,
|
||||
topic: str,
|
||||
generation_ids: List[str],
|
||||
project_id: Optional[str],
|
||||
user_id: str,
|
||||
) -> Post:
|
||||
post = Post(
|
||||
date=date,
|
||||
topic=topic,
|
||||
generation_ids=generation_ids,
|
||||
project_id=project_id,
|
||||
created_by=user_id,
|
||||
)
|
||||
post_id = await self.dao.posts.create_post(post)
|
||||
post.id = post_id
|
||||
return post
|
||||
|
||||
async def get_post(self, post_id: str) -> Optional[Post]:
|
||||
return await self.dao.posts.get_post(post_id)
|
||||
|
||||
async def get_posts(
|
||||
self,
|
||||
project_id: Optional[str],
|
||||
user_id: str,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
date_from: Optional[datetime] = None,
|
||||
date_to: Optional[datetime] = None,
|
||||
) -> List[Post]:
|
||||
return await self.dao.posts.get_posts(project_id, user_id, limit, offset, date_from, date_to)
|
||||
|
||||
async def update_post(
|
||||
self,
|
||||
post_id: str,
|
||||
date: Optional[datetime] = None,
|
||||
topic: Optional[str] = None,
|
||||
) -> Optional[Post]:
|
||||
post = await self.dao.posts.get_post(post_id)
|
||||
if not post:
|
||||
return None
|
||||
|
||||
updates: dict = {"updated_at": datetime.now(UTC)}
|
||||
if date is not None:
|
||||
updates["date"] = date
|
||||
if topic is not None:
|
||||
updates["topic"] = topic
|
||||
|
||||
await self.dao.posts.update_post(post_id, updates)
|
||||
|
||||
# Return refreshed post
|
||||
return await self.dao.posts.get_post(post_id)
|
||||
|
||||
async def delete_post(self, post_id: str) -> bool:
|
||||
return await self.dao.posts.delete_post(post_id)
|
||||
|
||||
async def add_generations(self, post_id: str, generation_ids: List[str]) -> bool:
|
||||
post = await self.dao.posts.get_post(post_id)
|
||||
if not post:
|
||||
return False
|
||||
return await self.dao.posts.add_generations(post_id, generation_ids)
|
||||
|
||||
async def remove_generation(self, post_id: str, generation_id: str) -> bool:
|
||||
post = await self.dao.posts.get_post(post_id)
|
||||
if not post:
|
||||
return False
|
||||
return await self.dao.posts.remove_generation(post_id, generation_id)
|
||||
39
config.py
39
config.py
@@ -1,39 +0,0 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
# Telegram Bot
|
||||
BOT_TOKEN: str
|
||||
ADMIN_ID: int = 0
|
||||
|
||||
# AI Service
|
||||
GEMINI_API_KEY: str
|
||||
|
||||
# Database
|
||||
MONGO_HOST: str = "mongodb://localhost:27017"
|
||||
DB_NAME: str = "my_bot_db"
|
||||
|
||||
# S3 Storage (Minio)
|
||||
MINIO_ENDPOINT: str = "http://localhost:9000"
|
||||
MINIO_ACCESS_KEY: str = "minioadmin"
|
||||
MINIO_SECRET_KEY: str = "minioadmin"
|
||||
MINIO_BUCKET: str = "ai-char"
|
||||
|
||||
# External API
|
||||
EXTERNAL_API_SECRET: Optional[str] = None
|
||||
|
||||
# JWT Security
|
||||
SECRET_KEY: str = "CHANGE_ME_TO_A_SUPER_SECRET_KEY"
|
||||
ALGORITHM: str = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 * 24 * 60 # 30 days
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=os.getenv("ENV_FILE", ".env"),
|
||||
env_file_encoding="utf-8",
|
||||
extra="ignore"
|
||||
)
|
||||
|
||||
|
||||
settings = Settings()
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -27,19 +27,19 @@ class AlbumMiddleware(BaseMiddleware):
|
||||
# Ждем сбора остальных частей
|
||||
await asyncio.sleep(self.latency)
|
||||
|
||||
# Проверяем, что ключ все еще существует
|
||||
# Проверяем, что ключ все еще существует (на всякий случай)
|
||||
if group_id in self.album_data:
|
||||
# Передаем собранный альбом в хендлер
|
||||
# Сортируем по message_id, чтобы порядок был верным
|
||||
current_album = self.album_data[group_id]
|
||||
current_album.sort(key=lambda x: x.message_id)
|
||||
data["album"] = current_album
|
||||
self.album_data[group_id].sort(key=lambda x: x.message_id)
|
||||
data["album"] = self.album_data[group_id]
|
||||
return await handler(event, data)
|
||||
|
||||
finally:
|
||||
# ЧИСТКА: Удаляем запись после обработки или таймаута
|
||||
# Используем pop() с дефолтом, чтобы избежать KeyError
|
||||
self.album_data.pop(group_id, None)
|
||||
# ЧИСТКА: Удаляем всегда, если это "головной" поток, который создал запись
|
||||
# Проверяем, что мы удаляем именно то, что создали, и ключ существует
|
||||
if group_id in self.album_data and self.album_data[group_id][0] == event:
|
||||
del self.album_data[group_id]
|
||||
|
||||
else:
|
||||
# Если группа уже собирается - просто добавляем и выходим
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
from datetime import datetime, UTC
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class Album(BaseModel):
|
||||
id: str | None = None
|
||||
id: Optional[str] = None
|
||||
name: str
|
||||
description: str | None = None
|
||||
cover_asset_id: str | None = None
|
||||
generation_ids: list[str] = []
|
||||
description: Optional[str] = None
|
||||
cover_asset_id: Optional[str] = None
|
||||
generation_ids: List[str] = []
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
|
||||
@@ -1,38 +1,36 @@
|
||||
from datetime import datetime, UTC
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import Optional, Any, List
|
||||
|
||||
from pydantic import BaseModel, computed_field, Field, model_validator
|
||||
|
||||
|
||||
class AssetContentType(str, Enum):
|
||||
IMAGE = 'image'
|
||||
PROMPT = 'prompt'
|
||||
VIDEO = 'video'
|
||||
PROMPT = 'prompt'
|
||||
|
||||
class AssetType(str, Enum):
|
||||
UPLOADED = 'uploaded'
|
||||
GENERATED = 'generated'
|
||||
INSPIRATION = 'inspiration'
|
||||
|
||||
|
||||
class Asset(BaseModel):
|
||||
id: str | None = None
|
||||
id: Optional[str] = None
|
||||
name: str
|
||||
type: AssetType = AssetType.GENERATED
|
||||
content_type: AssetContentType = AssetContentType.IMAGE
|
||||
linked_char_id: str | None = None
|
||||
data: bytes | None = None
|
||||
tg_doc_file_id: str | None = None
|
||||
tg_photo_file_id: str | None = None
|
||||
minio_object_name: str | None = None
|
||||
minio_bucket: str | None = None
|
||||
minio_thumbnail_object_name: str | None = None
|
||||
thumbnail: bytes | None = None
|
||||
tags: list[str] = []
|
||||
created_by: str | None = None
|
||||
project_id: str | None = None
|
||||
is_deleted: bool = False
|
||||
linked_char_id: Optional[str] = None
|
||||
data: Optional[bytes] = None
|
||||
tg_doc_file_id: Optional[str] = None
|
||||
tg_photo_file_id: Optional[str] = None
|
||||
minio_object_name: Optional[str] = None
|
||||
minio_bucket: Optional[str] = None
|
||||
minio_thumbnail_object_name: Optional[str] = None
|
||||
thumbnail: Optional[bytes] = None
|
||||
tags: List[str] = []
|
||||
created_by: Optional[str] = None
|
||||
project_id: Optional[str] = None
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
|
||||
@@ -65,7 +63,6 @@ class Asset(BaseModel):
|
||||
|
||||
# --- CALCULATED FIELD ---
|
||||
@computed_field
|
||||
@property
|
||||
def url(self) -> str:
|
||||
"""
|
||||
Это поле автоматически вычислится и попадет в model_dump() / .json()
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic_core.core_schema import computed_field
|
||||
|
||||
|
||||
class Character(BaseModel):
|
||||
id: str | None = None
|
||||
id: Optional[str] = None
|
||||
name: str
|
||||
avatar_asset_id: str | None = None
|
||||
avatar_image: str | None = None
|
||||
character_image_doc_tg_id: str | None = None
|
||||
character_image_tg_id: str | None = None
|
||||
character_bio: str | None = None
|
||||
created_by: str | None = None
|
||||
project_id: str | None = None
|
||||
avatar_asset_id: Optional[str] = None
|
||||
avatar_image: Optional[str] = None
|
||||
character_image_data: Optional[bytes] = None
|
||||
character_image_doc_tg_id: Optional[str] = None
|
||||
character_image_tg_id: Optional[str] = None
|
||||
character_bio: Optional[str] = None
|
||||
created_by: Optional[str] = None
|
||||
project_id: Optional[str] = None
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from datetime import datetime
|
||||
from bson import ObjectId
|
||||
|
||||
|
||||
class Environment(BaseModel):
|
||||
id: str | None = Field(None, alias="_id")
|
||||
character_id: str
|
||||
name: str = Field(..., min_length=1)
|
||||
description: str | None = None
|
||||
asset_ids: list[str] = Field(default_factory=list)
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
json_encoders={ObjectId: str},
|
||||
arbitrary_types_allowed=True
|
||||
)
|
||||
@@ -1,9 +1,11 @@
|
||||
from datetime import datetime, UTC
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, computed_field
|
||||
|
||||
from models.enums import AspectRatios, Quality
|
||||
from models.Asset import Asset
|
||||
from models.enums import AspectRatios, Quality, GenType
|
||||
|
||||
|
||||
class GenerationStatus(str, Enum):
|
||||
@@ -12,36 +14,34 @@ class GenerationStatus(str, Enum):
|
||||
FAILED = "failed"
|
||||
|
||||
class Generation(BaseModel):
|
||||
id: str | None = None
|
||||
id: Optional[str] = None
|
||||
status: GenerationStatus = GenerationStatus.RUNNING
|
||||
failed_reason: str | None = None
|
||||
linked_character_id: str | None = None
|
||||
telegram_id: int | None = None
|
||||
gen_type: GenType = GenType.IMAGE
|
||||
failed_reason: Optional[str] = None
|
||||
linked_character_id: Optional[str] = None
|
||||
telegram_id: Optional[int] = None
|
||||
use_profile_image: bool = True
|
||||
aspect_ratio: AspectRatios
|
||||
quality: Quality
|
||||
prompt: str
|
||||
model: str | None = None
|
||||
seed: int | None = None
|
||||
tech_prompt: str | None = None
|
||||
assets_list: list[str] = Field(default_factory=list)
|
||||
result_list: list[str] = Field(default_factory=list)
|
||||
result: str | None = None
|
||||
tech_prompt: Optional[str] = None
|
||||
assets_list: List[str] = Field(default_factory=list)
|
||||
result_list: List[str] = Field(default_factory=list)
|
||||
result: Optional[str] = None
|
||||
progress: int = 0
|
||||
execution_time_seconds: float | None = None
|
||||
api_execution_time_seconds: float | None = None
|
||||
token_usage: int | None = None
|
||||
input_token_usage: int | None = None
|
||||
output_token_usage: int | None = None
|
||||
execution_time_seconds: Optional[float] = None
|
||||
api_execution_time_seconds: Optional[float] = None
|
||||
token_usage: Optional[int] = None
|
||||
input_token_usage: Optional[int] = None
|
||||
output_token_usage: Optional[int] = None
|
||||
is_deleted: bool = False
|
||||
album_id: str | None = None
|
||||
environment_id: str | None = None
|
||||
generation_group_id: str | None = None
|
||||
created_by: str | None = None # Stores User ID (Telegram ID or Web User ObjectId)
|
||||
project_id: str | None = None
|
||||
idea_id: str | None = None
|
||||
liked_by: list[str] = Field(default_factory=list)
|
||||
nsfw: bool = False
|
||||
album_id: Optional[str] = None
|
||||
created_by: Optional[str] = None # Stores User ID (Telegram ID or Web User ObjectId)
|
||||
project_id: Optional[str] = None
|
||||
# Video-specific fields
|
||||
kling_task_id: Optional[str] = None
|
||||
video_duration: Optional[int] = None # 5 or 10 seconds
|
||||
video_mode: Optional[str] = None # "std" or "pro"
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class Idea(BaseModel):
|
||||
id: str | None = None
|
||||
name: str = "New Idea"
|
||||
description: str | None = None
|
||||
project_id: str | None = None
|
||||
inspiration_id: str | None = None # Link to Inspiration
|
||||
created_by: str # User ID
|
||||
is_deleted: bool = False
|
||||
created_at: datetime = Field(default_factory=datetime.now)
|
||||
updated_at: datetime = Field(default_factory=datetime.now)
|
||||
@@ -1,15 +0,0 @@
|
||||
from datetime import datetime, UTC
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Inspiration(BaseModel):
|
||||
id: str | None = None
|
||||
source_url: str
|
||||
caption: str | None = None
|
||||
asset_id: str
|
||||
is_completed: bool = False
|
||||
created_by: str
|
||||
project_id: str | None = None
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
@@ -1,22 +0,0 @@
|
||||
from datetime import datetime, timezone, UTC
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class Post(BaseModel):
|
||||
id: str | None = None
|
||||
date: datetime
|
||||
topic: str
|
||||
generation_ids: list[str] = Field(default_factory=list)
|
||||
project_id: str | None = None
|
||||
created_by: str
|
||||
is_deleted: bool = False
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
|
||||
@model_validator(mode="after")
|
||||
def ensure_tz_aware(self):
|
||||
for field in ("date", "created_at", "updated_at"):
|
||||
val = getattr(self, field)
|
||||
if val is not None and val.tzinfo is None:
|
||||
setattr(self, field, val.replace(tzinfo=timezone.utc))
|
||||
return self
|
||||
@@ -1,11 +1,12 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class Project(BaseModel):
|
||||
id: str | None = None
|
||||
id: Optional[str] = None
|
||||
name: str
|
||||
description: str | None = None
|
||||
description: Optional[str] = None
|
||||
owner_id: str
|
||||
members: list[str] = [] # List of User IDs
|
||||
members: List[str] = [] # List of User IDs
|
||||
is_deleted: bool = False
|
||||
created_at: datetime = Field(default_factory=datetime.now)
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -2,30 +2,19 @@ from enum import Enum
|
||||
|
||||
|
||||
class AspectRatios(str, Enum):
|
||||
ONEONE = "1:1"
|
||||
TWOTHREE = "2:3"
|
||||
THREETWO = "3:2"
|
||||
THREEFOUR = "3:4"
|
||||
FOURTHREE = "4:3"
|
||||
FOURFIVE = "4:5"
|
||||
FIVEFOUR = "5:4"
|
||||
NINESIXTEEN = "9:16"
|
||||
SIXTEENNINE = "16:9"
|
||||
TWENTYONENINE = "21:9"
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value):
|
||||
mapping = {
|
||||
"NINESIXTEEN": cls.NINESIXTEEN,
|
||||
"SIXTEENNINE": cls.SIXTEENNINE,
|
||||
"THREEFOUR": cls.THREEFOUR,
|
||||
"FOURTHREE": cls.FOURTHREE,
|
||||
}
|
||||
return mapping.get(value)
|
||||
NINESIXTEEN = "NINESIXTEEN"
|
||||
SIXTEENNINE = "SIXTEENNINE"
|
||||
THREEFOUR = "THREEFOUR"
|
||||
FOURTHREE = "FOURTHREE"
|
||||
|
||||
@property
|
||||
def value_ratio(self) -> str:
|
||||
return self.value
|
||||
return {
|
||||
AspectRatios.NINESIXTEEN: "9:16",
|
||||
AspectRatios.SIXTEENNINE: "16:9",
|
||||
AspectRatios.THREEFOUR: "3:4",
|
||||
AspectRatios.FOURTHREE: "4:3",
|
||||
}[self]
|
||||
|
||||
|
||||
class Quality(str, Enum):
|
||||
@@ -45,27 +34,12 @@ class Quality(str, Enum):
|
||||
class GenType(str, Enum):
|
||||
TEXT = 'Text'
|
||||
IMAGE = 'Image'
|
||||
VIDEO = 'Video'
|
||||
|
||||
@property
|
||||
def value_type(self) -> str:
|
||||
return {
|
||||
GenType.TEXT: 'Text',
|
||||
GenType.IMAGE: 'Image',
|
||||
GenType.VIDEO: 'Video',
|
||||
}[self]
|
||||
|
||||
|
||||
class TextModel(str, Enum):
|
||||
GEMINI_3_1_PRO_PREVIEW = "gemini-3.1-pro-preview"
|
||||
|
||||
@property
|
||||
def value_model(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
class ImageModel(str, Enum):
|
||||
GEMINI_3_PRO_IMAGE_PREVIEW = "gemini-3-pro-image-preview"
|
||||
GEMINI_3_1_FLASH_IMAGE_PREVIEW = "gemini-3.1-flash-image-preview"
|
||||
|
||||
@property
|
||||
def value_model(self) -> str:
|
||||
return self.value
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,8 +1,6 @@
|
||||
from typing import Any, List, Optional
|
||||
from typing import List, Optional
|
||||
import logging
|
||||
from datetime import datetime, UTC
|
||||
from bson import ObjectId
|
||||
from uuid import uuid4
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
|
||||
from models.Asset import Asset
|
||||
@@ -21,8 +19,7 @@ class AssetsRepo:
|
||||
# Main data
|
||||
if asset.data:
|
||||
ts = int(asset.created_at.timestamp())
|
||||
uid = uuid4().hex[:8]
|
||||
object_name = f"{asset.type.value}/{ts}_{uid}_{asset.name}"
|
||||
object_name = f"{asset.type.value}/{ts}_{asset.name}"
|
||||
|
||||
uploaded = await self.s3.upload_file(object_name, asset.data)
|
||||
if uploaded:
|
||||
@@ -35,8 +32,7 @@ class AssetsRepo:
|
||||
# Thumbnail
|
||||
if asset.thumbnail:
|
||||
ts = int(asset.created_at.timestamp())
|
||||
uid = uuid4().hex[:8]
|
||||
thumb_name = f"{asset.type.value}/thumbs/{ts}_{uid}_{asset.name}_thumb.jpg"
|
||||
thumb_name = f"{asset.type.value}/thumbs/{ts}_{asset.name}_thumb.jpg"
|
||||
|
||||
uploaded_thumb = await self.s3.upload_file(thumb_name, asset.thumbnail)
|
||||
if uploaded_thumb:
|
||||
@@ -51,7 +47,7 @@ class AssetsRepo:
|
||||
return str(res.inserted_id)
|
||||
|
||||
async def get_assets(self, asset_type: Optional[str] = None, limit: int = 10, offset: int = 0, with_data: bool = False, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Asset]:
|
||||
filter: dict[str, Any]= {"is_deleted": {"$ne": True}}
|
||||
filter = {}
|
||||
if asset_type:
|
||||
filter["type"] = asset_type
|
||||
args = {}
|
||||
@@ -102,7 +98,7 @@ class AssetsRepo:
|
||||
|
||||
return assets
|
||||
|
||||
async def get_asset(self, asset_id: str, with_data: bool = True) -> Optional[Asset]:
|
||||
async def get_asset(self, asset_id: str, with_data: bool = True) -> Asset:
|
||||
projection = None
|
||||
if not with_data:
|
||||
projection = {"data": 0, "thumbnail": 0}
|
||||
@@ -138,8 +134,7 @@ class AssetsRepo:
|
||||
if self.s3:
|
||||
if asset.data:
|
||||
ts = int(asset.created_at.timestamp())
|
||||
uid = uuid4().hex[:8]
|
||||
object_name = f"{asset.type.value}/{ts}_{uid}_{asset.name}"
|
||||
object_name = f"{asset.type.value}/{ts}_{asset.name}"
|
||||
if await self.s3.upload_file(object_name, asset.data):
|
||||
asset.minio_object_name = object_name
|
||||
asset.minio_bucket = self.s3.bucket_name
|
||||
@@ -147,8 +142,7 @@ class AssetsRepo:
|
||||
|
||||
if asset.thumbnail:
|
||||
ts = int(asset.created_at.timestamp())
|
||||
uid = uuid4().hex[:8]
|
||||
thumb_name = f"{asset.type.value}/thumbs/{ts}_{uid}_{asset.name}_thumb.jpg"
|
||||
thumb_name = f"{asset.type.value}/thumbs/{ts}_{asset.name}_thumb.jpg"
|
||||
if await self.s3.upload_file(thumb_name, asset.thumbnail):
|
||||
asset.minio_thumbnail_object_name = thumb_name
|
||||
asset.thumbnail = None
|
||||
@@ -175,16 +169,12 @@ class AssetsRepo:
|
||||
filter["linked_char_id"] = character_id
|
||||
if created_by:
|
||||
filter["created_by"] = created_by
|
||||
if project_id is None:
|
||||
filter["project_id"] = None
|
||||
if project_id:
|
||||
filter["project_id"] = project_id
|
||||
return await self.collection.count_documents(filter)
|
||||
|
||||
async def get_assets_by_ids(self, asset_ids: List[str]) -> List[Asset]:
|
||||
object_ids = [ObjectId(asset_id) for asset_id in asset_ids if ObjectId.is_valid(asset_id)]
|
||||
if not object_ids:
|
||||
return []
|
||||
object_ids = [ObjectId(asset_id) for asset_id in asset_ids]
|
||||
res = self.collection.find({"_id": {"$in": object_ids}}, {"data": 0}) # Exclude data but maybe allow thumbnail if small?
|
||||
# Original excluded thumbnail too.
|
||||
assets = []
|
||||
@@ -207,61 +197,6 @@ class AssetsRepo:
|
||||
res = await self.collection.delete_one({"_id": ObjectId(asset_id)})
|
||||
return res.deleted_count > 0
|
||||
|
||||
async def soft_delete_and_purge_assets(self, asset_ids: List[str]) -> int:
|
||||
"""
|
||||
Мягко удаляет ассеты и жёстко удаляет их файлы из S3.
|
||||
Возвращает количество обработанных ассетов.
|
||||
"""
|
||||
if not asset_ids:
|
||||
return 0
|
||||
|
||||
object_ids = [ObjectId(aid) for aid in asset_ids if ObjectId.is_valid(aid)]
|
||||
if not object_ids:
|
||||
return 0
|
||||
|
||||
# Находим ассеты, которые ещё не удалены
|
||||
cursor = self.collection.find(
|
||||
{"_id": {"$in": object_ids}, "is_deleted": {"$ne": True}},
|
||||
{"minio_object_name": 1, "minio_thumbnail_object_name": 1}
|
||||
)
|
||||
|
||||
purged_count = 0
|
||||
ids_to_update = []
|
||||
|
||||
async for doc in cursor:
|
||||
ids_to_update.append(doc["_id"])
|
||||
|
||||
# Жёсткое удаление файлов из S3
|
||||
if self.s3:
|
||||
if doc.get("minio_object_name"):
|
||||
try:
|
||||
await self.s3.delete_file(doc["minio_object_name"])
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete S3 object {doc['minio_object_name']}: {e}")
|
||||
if doc.get("minio_thumbnail_object_name"):
|
||||
try:
|
||||
await self.s3.delete_file(doc["minio_thumbnail_object_name"])
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete S3 thumbnail {doc['minio_thumbnail_object_name']}: {e}")
|
||||
|
||||
purged_count += 1
|
||||
|
||||
# Мягкое удаление + очистка ссылок на S3
|
||||
if ids_to_update:
|
||||
await self.collection.update_many(
|
||||
{"_id": {"$in": ids_to_update}},
|
||||
{
|
||||
"$set": {
|
||||
"is_deleted": True,
|
||||
"minio_object_name": None,
|
||||
"minio_thumbnail_object_name": None,
|
||||
"updated_at": datetime.now(UTC)
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
return purged_count
|
||||
|
||||
async def migrate_to_minio(self) -> dict:
|
||||
"""Переносит данные и thumbnails из Mongo в MinIO."""
|
||||
if not self.s3:
|
||||
@@ -281,8 +216,7 @@ class AssetsRepo:
|
||||
created_at = doc.get("created_at")
|
||||
ts = int(created_at.timestamp()) if created_at else 0
|
||||
|
||||
uid = uuid4().hex[:8]
|
||||
object_name = f"{type_}/{ts}_{uid}_{asset_id}_{name}"
|
||||
object_name = f"{type_}/{ts}_{asset_id}_{name}"
|
||||
if await self.s3.upload_file(object_name, data):
|
||||
await self.collection.update_one(
|
||||
{"_id": asset_id},
|
||||
@@ -309,8 +243,7 @@ class AssetsRepo:
|
||||
created_at = doc.get("created_at")
|
||||
ts = int(created_at.timestamp()) if created_at else 0
|
||||
|
||||
uid = uuid4().hex[:8]
|
||||
thumb_name = f"{type_}/thumbs/{ts}_{uid}_{asset_id}_{name}_thumb.jpg"
|
||||
thumb_name = f"{type_}/thumbs/{ts}_{asset_id}_{name}_thumb.jpg"
|
||||
if await self.s3.upload_file(thumb_name, thumb):
|
||||
await self.collection.update_one(
|
||||
{"_id": asset_id},
|
||||
|
||||
@@ -15,24 +15,26 @@ class CharacterRepo:
|
||||
character.id = str(op.inserted_id)
|
||||
return character
|
||||
|
||||
async def get_character(self, character_id: str) -> Character | None:
|
||||
res = await self.collection.find_one({"_id": ObjectId(character_id)})
|
||||
async def get_character(self, character_id: str, with_image_data: bool = False) -> Character | None:
|
||||
args = {}
|
||||
if not with_image_data:
|
||||
args["character_image_data"] = 0
|
||||
res = await self.collection.find_one({"_id": ObjectId(character_id)}, args)
|
||||
if res is None:
|
||||
return None
|
||||
else:
|
||||
res["id"] = str(res.pop("_id"))
|
||||
return Character(**res)
|
||||
|
||||
async def get_all_characters(self, created_by: Optional[str] = None, project_id: Optional[str] = None, limit: int = 100, offset: int = 0) -> List[Character]:
|
||||
async def get_all_characters(self, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Character]:
|
||||
filter = {}
|
||||
if created_by:
|
||||
filter["created_by"] = created_by
|
||||
if project_id is None:
|
||||
filter["project_id"] = None
|
||||
if project_id:
|
||||
filter["project_id"] = project_id
|
||||
|
||||
res = await self.collection.find(filter).skip(offset).limit(limit).to_list(None)
|
||||
args = {"character_image_data": 0} # don't return image data for list
|
||||
res = await self.collection.find(filter, args).to_list(None)
|
||||
chars = []
|
||||
for doc in res:
|
||||
doc["id"] = str(doc.pop("_id"))
|
||||
|
||||
@@ -6,10 +6,6 @@ from repos.generation_repo import GenerationRepo
|
||||
from repos.user_repo import UsersRepo
|
||||
from repos.albums_repo import AlbumsRepo
|
||||
from repos.project_repo import ProjectRepo
|
||||
from repos.idea_repo import IdeaRepo
|
||||
from repos.post_repo import PostRepo
|
||||
from repos.environment_repo import EnvironmentRepo
|
||||
from repos.inspiration_repo import InspirationRepo
|
||||
|
||||
|
||||
from typing import Optional
|
||||
@@ -23,7 +19,3 @@ class DAO:
|
||||
self.albums = AlbumsRepo(client, db_name)
|
||||
self.projects = ProjectRepo(client, db_name)
|
||||
self.users = UsersRepo(client, db_name)
|
||||
self.ideas = IdeaRepo(client, db_name)
|
||||
self.posts = PostRepo(client, db_name)
|
||||
self.environments = EnvironmentRepo(client, db_name)
|
||||
self.inspirations = InspirationRepo(client, db_name)
|
||||
|
||||
@@ -1,73 +0,0 @@
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from bson import ObjectId
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from models.Environment import Environment
|
||||
|
||||
|
||||
class EnvironmentRepo:
|
||||
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
|
||||
self.collection = client[db_name]["environments"]
|
||||
|
||||
async def create_env(self, env: Environment) -> Environment:
|
||||
env_dict = env.model_dump(exclude={"id"})
|
||||
res = await self.collection.insert_one(env_dict)
|
||||
env.id = str(res.inserted_id)
|
||||
return env
|
||||
|
||||
async def get_env(self, env_id: str) -> Optional[Environment]:
|
||||
res = await self.collection.find_one({"_id": ObjectId(env_id)})
|
||||
if not res:
|
||||
return None
|
||||
res["id"] = str(res.pop("_id"))
|
||||
return Environment(**res)
|
||||
|
||||
async def get_character_envs(self, character_id: str) -> List[Environment]:
|
||||
cursor = self.collection.find({"character_id": character_id})
|
||||
envs = []
|
||||
async for doc in cursor:
|
||||
doc["id"] = str(doc.pop("_id"))
|
||||
envs.append(Environment(**doc))
|
||||
return envs
|
||||
|
||||
async def update_env(self, env_id: str, update_data: dict) -> bool:
|
||||
update_data["updated_at"] = datetime.utcnow()
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(env_id)},
|
||||
{"$set": update_data}
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def delete_env(self, env_id: str) -> bool:
|
||||
res = await self.collection.delete_one({"_id": ObjectId(env_id)})
|
||||
return res.deleted_count > 0
|
||||
|
||||
async def add_asset(self, env_id: str, asset_id: str) -> bool:
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(env_id)},
|
||||
{
|
||||
"$addToSet": {"asset_ids": asset_id},
|
||||
"$set": {"updated_at": datetime.utcnow()}
|
||||
}
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def add_assets(self, env_id: str, asset_ids: List[str]) -> bool:
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(env_id)},
|
||||
{
|
||||
"$addToSet": {"asset_ids": {"$each": asset_ids}},
|
||||
"$set": {"updated_at": datetime.utcnow()}
|
||||
}
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def remove_asset(self, env_id: str, asset_id: str) -> bool:
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(env_id)},
|
||||
{
|
||||
"$pull": {"asset_ids": asset_id},
|
||||
"$set": {"updated_at": datetime.utcnow()}
|
||||
}
|
||||
)
|
||||
return res.modified_count > 0
|
||||
@@ -1,5 +1,4 @@
|
||||
from typing import Any, Optional, List
|
||||
from datetime import datetime, timedelta, UTC
|
||||
from typing import Optional, List
|
||||
|
||||
from PIL.ImageChops import offset
|
||||
from bson import ObjectId
|
||||
@@ -17,7 +16,7 @@ class GenerationRepo:
|
||||
res = await self.collection.insert_one(generation.model_dump())
|
||||
return str(res.inserted_id)
|
||||
|
||||
async def get_generation(self, generation_id: str) -> Generation | None:
|
||||
async def get_generation(self, generation_id: str) -> Optional[Generation]:
|
||||
res = await self.collection.find_one({"_id": ObjectId(generation_id)})
|
||||
if res is None:
|
||||
return None
|
||||
@@ -26,32 +25,20 @@ class GenerationRepo:
|
||||
return Generation(**res)
|
||||
|
||||
async def get_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
|
||||
limit: int = 10, offset: int = 0, created_by: Optional[str] = None, project_id: Optional[str] = None,
|
||||
idea_id: Optional[str] = None, only_liked_by: Optional[str] = None) -> List[Generation]:
|
||||
limit: int = 10, offset: int = 10, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
|
||||
|
||||
filter: dict[str, Any] = {"is_deleted": False}
|
||||
filter = {"is_deleted": False}
|
||||
if character_id is not None:
|
||||
filter["linked_character_id"] = character_id
|
||||
if status is not None:
|
||||
filter["status"] = status
|
||||
if created_by is not None:
|
||||
filter["created_by"] = created_by
|
||||
# If filtering by created_by user (e.g. "My Generations"), we typically imply personal scope if project_id is None.
|
||||
# But if project_id is passed, we filter by that.
|
||||
if project_id is None:
|
||||
filter["project_id"] = None
|
||||
if project_id is not None:
|
||||
filter["project_id"] = project_id
|
||||
if idea_id is not None:
|
||||
filter["idea_id"] = idea_id
|
||||
if only_liked_by is not None:
|
||||
filter["liked_by"] = only_liked_by
|
||||
|
||||
# If fetching for an idea, sort by created_at ascending (cronological)
|
||||
# Otherwise typically descending (newest first)
|
||||
sort_order = 1 if idea_id else -1
|
||||
|
||||
res = await self.collection.find(filter).sort("created_at", sort_order).skip(
|
||||
res = await self.collection.find(filter).sort("created_at", -1).skip(
|
||||
offset).limit(limit).to_list(None)
|
||||
generations: List[Generation] = []
|
||||
for generation in res:
|
||||
@@ -60,8 +47,7 @@ class GenerationRepo:
|
||||
return generations
|
||||
|
||||
async def count_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
|
||||
album_id: Optional[str] = None, created_by: Optional[str] = None, project_id: Optional[str] = None,
|
||||
idea_id: Optional[str] = None, only_liked_by: Optional[str] = None) -> int:
|
||||
album_id: Optional[str] = None, created_by: Optional[str] = None, project_id: Optional[str] = None) -> int:
|
||||
args = {}
|
||||
if character_id is not None:
|
||||
args["linked_character_id"] = character_id
|
||||
@@ -69,16 +55,8 @@ class GenerationRepo:
|
||||
args["status"] = status
|
||||
if created_by is not None:
|
||||
args["created_by"] = created_by
|
||||
if project_id is None:
|
||||
args["project_id"] = None
|
||||
if project_id is not None:
|
||||
args["project_id"] = project_id
|
||||
if idea_id is not None:
|
||||
args["idea_id"] = idea_id
|
||||
if album_id is not None:
|
||||
args["album_id"] = album_id
|
||||
if only_liked_by is not None:
|
||||
args["liked_by"] = only_liked_by
|
||||
return await self.collection.count_documents(args)
|
||||
|
||||
async def get_generations_by_ids(self, generation_ids: List[str]) -> List[Generation]:
|
||||
@@ -99,219 +77,3 @@ class GenerationRepo:
|
||||
|
||||
async def update_generation(self, generation: Generation, ):
|
||||
res = await self.collection.update_one({"_id": ObjectId(generation.id)}, {"$set": generation.model_dump()})
|
||||
|
||||
async def toggle_like(self, generation_id: str, user_id: str) -> bool | None:
|
||||
"""
|
||||
Toggles like for a user on a generation.
|
||||
Returns True if liked, False if unliked, None if generation not found.
|
||||
"""
|
||||
if not ObjectId.is_valid(generation_id):
|
||||
return None
|
||||
|
||||
oid = ObjectId(generation_id)
|
||||
|
||||
# Check if generation exists
|
||||
gen = await self.collection.find_one({"_id": oid}, {"liked_by": 1})
|
||||
|
||||
if not gen:
|
||||
return None
|
||||
|
||||
if user_id in gen.get("liked_by", []):
|
||||
# Unlike
|
||||
await self.collection.update_one(
|
||||
{"_id": oid},
|
||||
{"$pull": {"liked_by": user_id}}
|
||||
)
|
||||
return False
|
||||
else:
|
||||
# Like
|
||||
await self.collection.update_one(
|
||||
{"_id": oid},
|
||||
{"$addToSet": {"liked_by": user_id}}
|
||||
)
|
||||
return True
|
||||
|
||||
async def mark_nsfw(self, generation_id: str, is_nsfw: bool) -> bool:
|
||||
if not ObjectId.is_valid(generation_id):
|
||||
return False
|
||||
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(generation_id)},
|
||||
{"$set": {"nsfw": is_nsfw}}
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def get_usage_stats(self, created_by: Optional[str] = None, project_id: Optional[str] = None) -> dict:
|
||||
"""
|
||||
Calculates usage statistics (runs, tokens, cost) using MongoDB aggregation.
|
||||
Includes even soft-deleted generations to reflect actual expenditure.
|
||||
"""
|
||||
pipeline = []
|
||||
|
||||
# 1. Match all done generations (including soft-deleted)
|
||||
match_stage = {"status": GenerationStatus.DONE}
|
||||
if created_by:
|
||||
match_stage["created_by"] = created_by
|
||||
if project_id:
|
||||
match_stage["project_id"] = project_id
|
||||
|
||||
pipeline.append({"$match": match_stage})
|
||||
|
||||
# 2. Group by null (total)
|
||||
pipeline.append({
|
||||
"$group": {
|
||||
"_id": None,
|
||||
"total_runs": {"$sum": 1},
|
||||
"total_tokens": {
|
||||
"$sum": {
|
||||
"$cond": [
|
||||
{"$and": [{"$gt": ["$input_token_usage", 0]}, {"$gt": ["$output_token_usage", 0]}]},
|
||||
{"$add": ["$input_token_usage", "$output_token_usage"]},
|
||||
{"$ifNull": ["$token_usage", 0]}
|
||||
]
|
||||
}
|
||||
},
|
||||
"total_input_tokens": {"$sum": {"$ifNull": ["$input_token_usage", 0]}},
|
||||
"total_output_tokens": {"$sum": {"$ifNull": ["$output_token_usage", 0]}},
|
||||
"total_cost": {
|
||||
"$sum": {
|
||||
"$add": [
|
||||
{"$multiply": [{"$ifNull": ["$input_token_usage", 0]}, 0.000002]},
|
||||
{"$multiply": [{"$ifNull": ["$output_token_usage", 0]}, 0.00012]}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
cursor = self.collection.aggregate(pipeline)
|
||||
res = await cursor.to_list(1)
|
||||
|
||||
if not res:
|
||||
return {
|
||||
"total_runs": 0,
|
||||
"total_tokens": 0,
|
||||
"total_input_tokens": 0,
|
||||
"total_output_tokens": 0,
|
||||
"total_cost": 0.0
|
||||
}
|
||||
|
||||
result = res[0]
|
||||
result.pop("_id")
|
||||
result["total_cost"] = round(result["total_cost"], 4)
|
||||
return result
|
||||
|
||||
async def get_usage_breakdown(self, group_by: str = "created_by", project_id: Optional[str] = None, created_by: Optional[str] = None) -> List[dict]:
|
||||
"""
|
||||
Returns usage statistics grouped by user or project.
|
||||
Includes even soft-deleted generations to reflect actual expenditure.
|
||||
"""
|
||||
pipeline = []
|
||||
|
||||
match_stage = {"status": GenerationStatus.DONE}
|
||||
if project_id:
|
||||
match_stage["project_id"] = project_id
|
||||
if created_by:
|
||||
match_stage["created_by"] = created_by
|
||||
|
||||
pipeline.append({"$match": match_stage})
|
||||
|
||||
pipeline.append({
|
||||
"$group": {
|
||||
"_id": f"${group_by}",
|
||||
"total_runs": {"$sum": 1},
|
||||
"total_tokens": {
|
||||
"$sum": {
|
||||
"$cond": [
|
||||
{"$and": [{"$gt": ["$input_token_usage", 0]}, {"$gt": ["$output_token_usage", 0]}]},
|
||||
{"$add": ["$input_token_usage", "$output_token_usage"]},
|
||||
{"$ifNull": ["$token_usage", 0]}
|
||||
]
|
||||
}
|
||||
},
|
||||
"total_input_tokens": {"$sum": {"$ifNull": ["$input_token_usage", 0]}},
|
||||
"total_output_tokens": {"$sum": {"$ifNull": ["$output_token_usage", 0]}},
|
||||
"total_cost": {
|
||||
"$sum": {
|
||||
"$add": [
|
||||
{"$multiply": [{"$ifNull": ["$input_token_usage", 0]}, 0.000002]},
|
||||
{"$multiply": [{"$ifNull": ["$output_token_usage", 0]}, 0.00012]}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
pipeline.append({"$sort": {"total_cost": -1}})
|
||||
|
||||
cursor = self.collection.aggregate(pipeline)
|
||||
res = await cursor.to_list(None)
|
||||
|
||||
results = []
|
||||
for item in res:
|
||||
entity_id = item.pop("_id")
|
||||
item["total_cost"] = round(item["total_cost"], 4)
|
||||
results.append({
|
||||
"entity_id": str(entity_id) if entity_id else "unknown",
|
||||
"stats": item
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
async def get_generations_by_group(self, group_id: str) -> List[Generation]:
|
||||
res = await self.collection.find({"generation_group_id": group_id, "is_deleted": False}).sort("created_at", 1).to_list(None)
|
||||
generations: List[Generation] = []
|
||||
for generation in res:
|
||||
generation["id"] = str(generation.pop("_id"))
|
||||
generations.append(Generation(**generation))
|
||||
return generations
|
||||
|
||||
async def cancel_stale_generations(self, timeout_minutes: int = 5) -> int:
|
||||
cutoff_time = datetime.now(UTC) - timedelta(minutes=timeout_minutes)
|
||||
res = await self.collection.update_many(
|
||||
{
|
||||
"status": GenerationStatus.RUNNING,
|
||||
"created_at": {"$lt": cutoff_time}
|
||||
},
|
||||
{
|
||||
"$set": {
|
||||
"status": GenerationStatus.FAILED,
|
||||
"failed_reason": "Timeout: Execution time limit exceeded",
|
||||
"updated_at": datetime.now(UTC)
|
||||
}
|
||||
}
|
||||
)
|
||||
return res.modified_count
|
||||
|
||||
async def soft_delete_old_generations(self, days: int = 2) -> tuple[int, List[str]]:
|
||||
"""
|
||||
Мягко удаляет генерации старше N дней.
|
||||
Возвращает (количество удалённых, список asset IDs для очистки).
|
||||
"""
|
||||
cutoff_time = datetime.now(UTC) - timedelta(days=days)
|
||||
filter_query = {
|
||||
"is_deleted": False,
|
||||
"status": {"$in": [GenerationStatus.DONE, GenerationStatus.FAILED]},
|
||||
"created_at": {"$lt": cutoff_time}
|
||||
}
|
||||
|
||||
# Сначала собираем asset IDs из удаляемых генераций
|
||||
asset_ids: List[str] = []
|
||||
cursor = self.collection.find(filter_query, {"result_list": 1, "assets_list": 1})
|
||||
async for doc in cursor:
|
||||
asset_ids.extend(doc.get("result_list", []))
|
||||
|
||||
# Мягкое удаление
|
||||
res = await self.collection.update_many(
|
||||
filter_query,
|
||||
{
|
||||
"$set": {
|
||||
"is_deleted": True,
|
||||
"updated_at": datetime.now(UTC)
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# Убираем дубликаты
|
||||
unique_asset_ids = list(set(asset_ids))
|
||||
return res.modified_count, unique_asset_ids
|
||||
|
||||
@@ -1,91 +0,0 @@
|
||||
from typing import Optional, List
|
||||
from bson import ObjectId
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from models.Idea import Idea
|
||||
|
||||
class IdeaRepo:
|
||||
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
|
||||
self.collection = client[db_name]["ideas"]
|
||||
|
||||
async def create_idea(self, idea: Idea) -> str:
|
||||
res = await self.collection.insert_one(idea.model_dump())
|
||||
return str(res.inserted_id)
|
||||
|
||||
async def get_idea(self, idea_id: str) -> Optional[Idea]:
|
||||
if not ObjectId.is_valid(idea_id):
|
||||
return None
|
||||
res = await self.collection.find_one({"_id": ObjectId(idea_id)})
|
||||
if res:
|
||||
res["id"] = str(res.pop("_id"))
|
||||
return Idea(**res)
|
||||
return None
|
||||
|
||||
async def get_ideas(self, project_id: Optional[str], user_id: str, limit: int = 20, offset: int = 0) -> List[dict]:
|
||||
if project_id:
|
||||
match_stage = {"project_id": project_id, "is_deleted": False}
|
||||
else:
|
||||
match_stage = {"created_by": user_id, "project_id": None, "is_deleted": False}
|
||||
|
||||
pipeline = [
|
||||
{"$match": match_stage},
|
||||
{"$sort": {"updated_at": -1}},
|
||||
{"$skip": offset},
|
||||
{"$limit": limit},
|
||||
# Add string id field for lookup
|
||||
{"$addFields": {"str_id": {"$toString": "$_id"}}},
|
||||
# Lookup generations
|
||||
{
|
||||
"$lookup": {
|
||||
"from": "generations",
|
||||
"let": {"idea_id": "$str_id"},
|
||||
"pipeline": [
|
||||
{
|
||||
"$match": {
|
||||
"$and": [
|
||||
{"$expr": {"$eq": ["$idea_id", "$$idea_id"]}},
|
||||
{"status": "done"},
|
||||
{"result_list": {"$exists": True, "$not": {"$size": 0}}},
|
||||
{"is_deleted": False}
|
||||
]
|
||||
}
|
||||
},
|
||||
{"$sort": {"created_at": -1}}, # Ensure we get the latest successful
|
||||
{"$limit": 1}
|
||||
],
|
||||
"as": "generations"
|
||||
}
|
||||
},
|
||||
# Unwind generations array (preserve ideas without generations)
|
||||
{"$unwind": {"path": "$generations", "preserveNullAndEmptyArrays": True}},
|
||||
# Rename for clarity
|
||||
{"$addFields": {
|
||||
"last_generation": "$generations",
|
||||
"id": "$str_id"
|
||||
}},
|
||||
{"$project": {"generations": 0, "str_id": 0, "_id": 0}}
|
||||
]
|
||||
|
||||
return await self.collection.aggregate(pipeline).to_list(None)
|
||||
|
||||
async def delete_idea(self, idea_id: str) -> bool:
|
||||
if not ObjectId.is_valid(idea_id):
|
||||
return False
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(idea_id)},
|
||||
{"$set": {"is_deleted": True}}
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def update_idea(self, idea: Idea) -> bool:
|
||||
if not idea.id or not ObjectId.is_valid(idea.id):
|
||||
return False
|
||||
|
||||
idea_dict = idea.model_dump()
|
||||
if "id" in idea_dict:
|
||||
del idea_dict["id"]
|
||||
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(idea.id)},
|
||||
{"$set": idea_dict}
|
||||
)
|
||||
return res.modified_count > 0
|
||||
@@ -1,54 +0,0 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from bson import ObjectId
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
|
||||
from models.Inspiration import Inspiration
|
||||
|
||||
|
||||
class InspirationRepo:
|
||||
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
|
||||
self.collection = client[db_name]["inspirations"]
|
||||
|
||||
async def create_inspiration(self, inspiration: Inspiration) -> str:
|
||||
res = await self.collection.insert_one(inspiration.model_dump(exclude={"id"}))
|
||||
return str(res.inserted_id)
|
||||
|
||||
async def get_inspiration(self, inspiration_id: str) -> Optional[Inspiration]:
|
||||
res = await self.collection.find_one({"_id": ObjectId(inspiration_id)})
|
||||
if res:
|
||||
res["id"] = str(res.pop("_id"))
|
||||
return Inspiration(**res)
|
||||
return None
|
||||
|
||||
async def get_inspirations(self, project_id: Optional[str] = None, created_by: Optional[str] = None, limit: int = 20, offset: int = 0) -> List[Inspiration]:
|
||||
query = {}
|
||||
if project_id:
|
||||
query["project_id"] = project_id
|
||||
if created_by:
|
||||
query["created_by"] = created_by
|
||||
|
||||
cursor = self.collection.find(query).sort("created_at", -1).skip(offset).limit(limit)
|
||||
inspirations = []
|
||||
async for doc in cursor:
|
||||
doc["id"] = str(doc.pop("_id"))
|
||||
inspirations.append(Inspiration(**doc))
|
||||
return inspirations
|
||||
|
||||
async def count_inspirations(self, project_id: Optional[str] = None, created_by: Optional[str] = None) -> int:
|
||||
query = {}
|
||||
if project_id:
|
||||
query["project_id"] = project_id
|
||||
if created_by:
|
||||
query["created_by"] = created_by
|
||||
return await self.collection.count_documents(query)
|
||||
|
||||
async def update_inspiration(self, inspiration: Inspiration):
|
||||
await self.collection.update_one(
|
||||
{"_id": ObjectId(inspiration.id)},
|
||||
{"$set": inspiration.model_dump(exclude={"id"})}
|
||||
)
|
||||
|
||||
async def delete_inspiration(self, inspiration_id: str) -> bool:
|
||||
res = await self.collection.delete_one({"_id": ObjectId(inspiration_id)})
|
||||
return res.deleted_count > 0
|
||||
@@ -1,97 +0,0 @@
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
import logging
|
||||
from bson import ObjectId
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
|
||||
from models.Post import Post
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PostRepo:
|
||||
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
|
||||
self.collection = client[db_name]["posts"]
|
||||
|
||||
async def create_post(self, post: Post) -> str:
|
||||
res = await self.collection.insert_one(post.model_dump())
|
||||
return str(res.inserted_id)
|
||||
|
||||
async def get_post(self, post_id: str) -> Optional[Post]:
|
||||
if not ObjectId.is_valid(post_id):
|
||||
return None
|
||||
res = await self.collection.find_one({"_id": ObjectId(post_id), "is_deleted": False})
|
||||
if res:
|
||||
res["id"] = str(res.pop("_id"))
|
||||
return Post(**res)
|
||||
return None
|
||||
|
||||
async def get_posts(
|
||||
self,
|
||||
project_id: Optional[str],
|
||||
user_id: str,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
date_from: Optional[datetime] = None,
|
||||
date_to: Optional[datetime] = None,
|
||||
) -> List[Post]:
|
||||
if project_id:
|
||||
match = {"project_id": project_id, "is_deleted": False}
|
||||
else:
|
||||
match = {"created_by": user_id, "project_id": None, "is_deleted": False}
|
||||
|
||||
if date_from or date_to:
|
||||
date_filter = {}
|
||||
if date_from:
|
||||
date_filter["$gte"] = date_from
|
||||
if date_to:
|
||||
date_filter["$lte"] = date_to
|
||||
match["date"] = date_filter
|
||||
|
||||
cursor = (
|
||||
self.collection.find(match)
|
||||
.sort("date", -1)
|
||||
.skip(offset)
|
||||
.limit(limit)
|
||||
)
|
||||
posts = []
|
||||
async for doc in cursor:
|
||||
doc["id"] = str(doc.pop("_id"))
|
||||
posts.append(Post(**doc))
|
||||
return posts
|
||||
|
||||
async def update_post(self, post_id: str, data: dict) -> bool:
|
||||
if not ObjectId.is_valid(post_id):
|
||||
return False
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(post_id)},
|
||||
{"$set": data},
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def delete_post(self, post_id: str) -> bool:
|
||||
if not ObjectId.is_valid(post_id):
|
||||
return False
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(post_id)},
|
||||
{"$set": {"is_deleted": True}},
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def add_generations(self, post_id: str, generation_ids: List[str]) -> bool:
|
||||
if not ObjectId.is_valid(post_id):
|
||||
return False
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(post_id)},
|
||||
{"$addToSet": {"generation_ids": {"$each": generation_ids}}},
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def remove_generation(self, post_id: str, generation_id: str) -> bool:
|
||||
if not ObjectId.is_valid(post_id):
|
||||
return False
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(post_id)},
|
||||
{"$pull": {"generation_ids": generation_id}},
|
||||
)
|
||||
return res.modified_count > 0
|
||||
@@ -51,5 +51,4 @@ python-jose[cryptography]==3.3.0
|
||||
python-multipart==0.0.22
|
||||
email-validator
|
||||
prometheus-fastapi-instrumentator
|
||||
pydantic-settings==2.13.0
|
||||
yt-dlp
|
||||
PyJWT
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -51,66 +51,57 @@ async def new_char_bio(message: Message, state: FSMContext, dao: DAO, bot: Bot):
|
||||
wait_msg = await message.answer("💾 Сохраняю персонажа...")
|
||||
|
||||
try:
|
||||
# 1. Скачиваем файл (один раз)
|
||||
# TODO: Для больших файлов лучше использовать streaming или сохранять во временный файл
|
||||
# ВОТ ТУТ скачиваем файл (прямо перед сохранением)
|
||||
file_io = await bot.download(file_id)
|
||||
file_bytes = file_io.read()
|
||||
# photo_bytes = file_io.getvalue() # Получаем байты
|
||||
|
||||
# 2. Создаем Character (сначала без ассета, чтобы получить ID)
|
||||
|
||||
# Создаем модель
|
||||
char = Character(
|
||||
id=None,
|
||||
name=name,
|
||||
character_image_data=file_io.read(),
|
||||
character_image_tg_id=None,
|
||||
character_image_doc_tg_id=file_id,
|
||||
character_bio=bio,
|
||||
created_by=str(message.from_user.id)
|
||||
)
|
||||
file_io.close()
|
||||
|
||||
# Сохраняем через DAO
|
||||
|
||||
# Сохраняем, чтобы получить ID
|
||||
await dao.chars.add_character(char)
|
||||
|
||||
# 3. Создаем Asset (связанный с персонажем)
|
||||
avatar_asset_id = await dao.assets.create_asset(
|
||||
Asset(
|
||||
name="avatar.png",
|
||||
type=AssetType.UPLOADED,
|
||||
content_type=AssetContentType.IMAGE,
|
||||
linked_char_id=str(char.id),
|
||||
data=file_bytes,
|
||||
tg_doc_file_id=file_id
|
||||
)
|
||||
)
|
||||
|
||||
# 4. Обновляем персонажа ссылками на ассет
|
||||
char.avatar_asset_id = avatar_asset_id
|
||||
char.avatar_image = f"/api/assets/{avatar_asset_id}" # Формируем ссылку вручную или используем метод, если появится
|
||||
|
||||
file_info = await bot.get_file(char.character_image_doc_tg_id)
|
||||
file_bytes = await bot.download_file(file_info.file_path)
|
||||
file_io = file_bytes.read()
|
||||
avatar_asset = await dao.assets.create_asset(
|
||||
Asset(name="avatar.png", type=AssetType.UPLOADED, content_type=AssetContentType.IMAGE, linked_char_id=str(char.id), data=file_io,
|
||||
tg_doc_file_id=file_id))
|
||||
char.avatar_image = avatar_asset.link
|
||||
# Отправляем подтверждение
|
||||
# Используем байты для отправки обратно
|
||||
photo_msg = await message.answer_photo(
|
||||
photo=BufferedInputFile(file_bytes, filename="char.jpg"),
|
||||
photo=BufferedInputFile(file_io,
|
||||
filename="char.jpg") if not char.character_image_tg_id else char.character_image_tg_id,
|
||||
caption=(
|
||||
"🎉 <b>Персонаж создан!</b>\n\n"
|
||||
f"👤 <b>Имя:</b> {char.name}\n"
|
||||
f"📝 <b>Био:</b> {char.character_bio}"
|
||||
)
|
||||
)
|
||||
file_bytes.close()
|
||||
char.character_image_tg_id = photo_msg.photo[0].file_id
|
||||
|
||||
# Сохраняем TG ID фото (которое отправили как фото, а не документ)
|
||||
char.character_image_tg_id = photo_msg.photo[-1].file_id
|
||||
|
||||
# Финальное обновление персонажа
|
||||
await dao.chars.update_char(char.id, char)
|
||||
|
||||
await wait_msg.delete()
|
||||
file_io.close()
|
||||
|
||||
# Сбрасываем состояние
|
||||
await state.clear()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating character: {e}")
|
||||
traceback.print_exc()
|
||||
logging.error(e)
|
||||
await wait_msg.edit_text(f"❌ Ошибка при сохранении: {e}")
|
||||
# Не сбрасываем стейт, даем возможность попробовать ввести био снова или начать заново
|
||||
|
||||
|
||||
@router.message(Command("chars"))
|
||||
|
||||
@@ -126,11 +126,12 @@ async def change_char(call: CallbackQuery, state: FSMContext, dao: DAO):
|
||||
@router.callback_query(States.gen_mode, F.data == 'gen_mode_change_aspect_ratio')
|
||||
async def gen_mode_change_aspect_ratio(call: CallbackQuery, state: FSMContext, dao: DAO):
|
||||
await call.answer()
|
||||
buttons = [InlineKeyboardButton(text=ratio.value, callback_data=f'select_ratio_{ratio.name}') for ratio in AspectRatios]
|
||||
keyboard_rows = [buttons[i:i + 4] for i in range(0, len(buttons), 4)]
|
||||
keyboard_rows.append([InlineKeyboardButton(text="⬅️ Назад", callback_data="gen_mode_cancel_ratio_change")])
|
||||
keyboards = []
|
||||
for ratio in AspectRatios:
|
||||
keyboards.append(InlineKeyboardButton(text=ratio.value, callback_data=f'select_ratio_{ratio.name}'))
|
||||
await call.message.edit_caption(caption="Выбери соотношение сторон",
|
||||
reply_markup=InlineKeyboardMarkup(inline_keyboard=keyboard_rows))
|
||||
reply_markup=InlineKeyboardMarkup(inline_keyboard=[keyboards, [InlineKeyboardButton(
|
||||
text="⬅️ Назад", callback_data="gen_mode_cancel_ratio_change")]]))
|
||||
|
||||
|
||||
@router.callback_query(States.gen_mode, F.data.startswith('select_ratio_'))
|
||||
|
||||
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user