Compare commits
3 Commits
198ac44960
...
video
| Author | SHA1 | Date | |
|---|---|---|---|
| 32ff77e04b | |||
| d1f67c773f | |||
| c63b51ef75 |
4
.env
4
.env
@@ -8,4 +8,6 @@ MINIO_ACCESS_KEY=admin
|
|||||||
MINIO_SECRET_KEY=SuperSecretPassword123!
|
MINIO_SECRET_KEY=SuperSecretPassword123!
|
||||||
MINIO_BUCKET=ai-char
|
MINIO_BUCKET=ai-char
|
||||||
MODE=production
|
MODE=production
|
||||||
EXTERNAL_API_SECRET=Gt9TyQ8OAYhcELh2YCbKjdHLflZGufKHJZcG338MQDW
|
EXTERNAL_API_SECRET=Gt9TyQ8OAYhcELh2YCbKjdHLflZGufKHJZcG338MQDW
|
||||||
|
KLING_ACCESS_KEY=AngRfYYeLhPQB3pmr9CpHfgHPCrmeeM4
|
||||||
|
KLING_SECRET_KEY=ndJfyayfQgbg4bMnE49yHnkACPChKMp4
|
||||||
17
.gitignore
vendored
17
.gitignore
vendored
@@ -8,19 +8,4 @@ minio_backup.tar.gz
|
|||||||
.idea/ai-char-bot.iml
|
.idea/ai-char-bot.iml
|
||||||
.idea
|
.idea
|
||||||
.venv
|
.venv
|
||||||
.vscode
|
.vscode
|
||||||
.vscode/launch.json
|
|
||||||
middlewares/__pycache__/
|
|
||||||
middlewares/*.pyc
|
|
||||||
api/__pycache__/
|
|
||||||
api/*.pyc
|
|
||||||
repos/__pycache__/
|
|
||||||
repos/*.pyc
|
|
||||||
adapters/__pycache__/
|
|
||||||
adapters/*.pyc
|
|
||||||
services/__pycache__/
|
|
||||||
services/*.pyc
|
|
||||||
utils/__pycache__/
|
|
||||||
utils/*.pyc
|
|
||||||
.vscode/launch.json
|
|
||||||
repos/__pycache__/assets_repo.cpython-313.pyc
|
|
||||||
25
.vscode/launch.json
vendored
25
.vscode/launch.json
vendored
@@ -16,6 +16,31 @@
|
|||||||
],
|
],
|
||||||
"jinja": true,
|
"jinja": true,
|
||||||
"justMyCode": true
|
"justMyCode": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Python: Current File",
|
||||||
|
"type": "debugpy",
|
||||||
|
"request": "launch",
|
||||||
|
"program": "${file}",
|
||||||
|
"console": "integratedTerminal",
|
||||||
|
"justMyCode": true,
|
||||||
|
"env": {
|
||||||
|
"PYTHONPATH": "${workspaceFolder}"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Debug Tests: Current File",
|
||||||
|
"type": "debugpy",
|
||||||
|
"request": "launch",
|
||||||
|
"module": "pytest",
|
||||||
|
"args": [
|
||||||
|
"${file}"
|
||||||
|
],
|
||||||
|
"console": "integratedTerminal",
|
||||||
|
"justMyCode": true,
|
||||||
|
"env": {
|
||||||
|
"PYTHONPATH": "${workspaceFolder}"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -23,30 +23,28 @@ class GoogleAdapter:
|
|||||||
self.TEXT_MODEL = "gemini-3-pro-preview"
|
self.TEXT_MODEL = "gemini-3-pro-preview"
|
||||||
self.IMAGE_MODEL = "gemini-3-pro-image-preview"
|
self.IMAGE_MODEL = "gemini-3-pro-image-preview"
|
||||||
|
|
||||||
def _prepare_contents(self, prompt: str, images_list: List[bytes] | None = None) -> tuple:
|
def _prepare_contents(self, prompt: str, images_list: List[bytes] = None) -> list:
|
||||||
"""Вспомогательный метод для подготовки контента (текст + картинки).
|
"""Вспомогательный метод для подготовки контента (текст + картинки)"""
|
||||||
Returns (contents, opened_images) — caller MUST close opened_images after use."""
|
contents = [prompt]
|
||||||
contents : list [Any]= [prompt]
|
|
||||||
opened_images = []
|
|
||||||
if images_list:
|
if images_list:
|
||||||
logger.info(f"Preparing content with {len(images_list)} images")
|
logger.info(f"Preparing content with {len(images_list)} images")
|
||||||
for img_bytes in images_list:
|
for img_bytes in images_list:
|
||||||
try:
|
try:
|
||||||
|
# Gemini API требует PIL Image на входе
|
||||||
image = Image.open(io.BytesIO(img_bytes))
|
image = Image.open(io.BytesIO(img_bytes))
|
||||||
contents.append(image)
|
contents.append(image)
|
||||||
opened_images.append(image)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error processing input image: {e}")
|
logger.error(f"Error processing input image: {e}")
|
||||||
else:
|
else:
|
||||||
logger.info("Preparing content with no images")
|
logger.info("Preparing content with no images")
|
||||||
return contents, opened_images
|
return contents
|
||||||
|
|
||||||
def generate_text(self, prompt: str, images_list: List[bytes] | None = None) -> str:
|
def generate_text(self, prompt: str, images_list: List[bytes] = None) -> str:
|
||||||
"""
|
"""
|
||||||
Генерация текста (Чат или Vision).
|
Генерация текста (Чат или Vision).
|
||||||
Возвращает строку с ответом.
|
Возвращает строку с ответом.
|
||||||
"""
|
"""
|
||||||
contents, opened_images = self._prepare_contents(prompt, images_list)
|
contents = self._prepare_contents(prompt, images_list)
|
||||||
logger.info(f"Generating text: {prompt}")
|
logger.info(f"Generating text: {prompt}")
|
||||||
try:
|
try:
|
||||||
response = self.client.models.generate_content(
|
response = self.client.models.generate_content(
|
||||||
@@ -70,17 +68,14 @@ class GoogleAdapter:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Gemini Text API Error: {e}")
|
logger.error(f"Gemini Text API Error: {e}")
|
||||||
raise GoogleGenerationException(f"Gemini Text API Error: {e}")
|
raise GoogleGenerationException(f"Gemini Text API Error: {e}")
|
||||||
finally:
|
|
||||||
for img in opened_images:
|
|
||||||
img.close()
|
|
||||||
|
|
||||||
def generate_image(self, prompt: str, aspect_ratio: AspectRatios, quality: Quality, images_list: List[bytes] | None = None, ) -> Tuple[List[io.BytesIO], Dict[str, Any]]:
|
def generate_image(self, prompt: str, aspect_ratio: AspectRatios, quality: Quality, images_list: List[bytes] = None, ) -> Tuple[List[io.BytesIO], Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Генерация изображений (Text-to-Image или Image-to-Image).
|
Генерация изображений (Text-to-Image или Image-to-Image).
|
||||||
Возвращает список байтовых потоков (готовых к отправке).
|
Возвращает список байтовых потоков (готовых к отправке).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
contents, opened_images = self._prepare_contents(prompt, images_list)
|
contents = self._prepare_contents(prompt, images_list)
|
||||||
logger.info(f"Generating image. Prompt length: {len(prompt)}, Ratio: {aspect_ratio}, Quality: {quality}")
|
logger.info(f"Generating image. Prompt length: {len(prompt)}, Ratio: {aspect_ratio}, Quality: {quality}")
|
||||||
|
|
||||||
start_time = datetime.now()
|
start_time = datetime.now()
|
||||||
@@ -105,21 +100,9 @@ class GoogleAdapter:
|
|||||||
|
|
||||||
if response.usage_metadata:
|
if response.usage_metadata:
|
||||||
token_usage = response.usage_metadata.total_token_count
|
token_usage = response.usage_metadata.total_token_count
|
||||||
|
|
||||||
# Check prompt-level block (e.g. PROHIBITED_CONTENT) — no candidates in this case
|
if response.parts is None and response.candidates[0].finish_reason is not None:
|
||||||
if response.prompt_feedback and response.prompt_feedback.block_reason:
|
raise GoogleGenerationException(f"Generation blocked in cause of {response.candidates[0].finish_reason.value}")
|
||||||
raise GoogleGenerationException(
|
|
||||||
f"Generation blocked at prompt level: {response.prompt_feedback.block_reason.value}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check candidate-level block
|
|
||||||
if response.parts is None:
|
|
||||||
response_reason = (
|
|
||||||
response.candidates[0].finish_reason
|
|
||||||
if response.candidates and len(response.candidates) > 0
|
|
||||||
else "Unknown"
|
|
||||||
)
|
|
||||||
raise GoogleGenerationException(f"Generation blocked: {response_reason}")
|
|
||||||
|
|
||||||
generated_images = []
|
generated_images = []
|
||||||
|
|
||||||
@@ -130,9 +113,7 @@ class GoogleAdapter:
|
|||||||
try:
|
try:
|
||||||
# 1. Берем сырые байты
|
# 1. Берем сырые байты
|
||||||
raw_data = part.inline_data.data
|
raw_data = part.inline_data.data
|
||||||
if raw_data is None:
|
byte_arr = io.BytesIO(raw_data)
|
||||||
raise GoogleGenerationException("Generation returned no data")
|
|
||||||
byte_arr : io.BytesIO = io.BytesIO(raw_data)
|
|
||||||
|
|
||||||
# 2. Нейминг (формально, для TG)
|
# 2. Нейминг (формально, для TG)
|
||||||
timestamp = datetime.now().timestamp()
|
timestamp = datetime.now().timestamp()
|
||||||
@@ -166,8 +147,4 @@ class GoogleAdapter:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Gemini Image API Error: {e}")
|
logger.error(f"Gemini Image API Error: {e}")
|
||||||
raise GoogleGenerationException(f"Gemini Image API Error: {e}")
|
raise GoogleGenerationException(f"Gemini Image API Error: {e}")
|
||||||
finally:
|
|
||||||
for img in opened_images:
|
|
||||||
img.close()
|
|
||||||
del contents
|
|
||||||
165
adapters/kling_adapter.py
Normal file
165
adapters/kling_adapter.py
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import asyncio
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import jwt
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
KLING_API_BASE = "https://api.klingai.com"
|
||||||
|
|
||||||
|
|
||||||
|
class KlingApiException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class KlingAdapter:
|
||||||
|
def __init__(self, access_key: str, secret_key: str):
|
||||||
|
if not access_key or not secret_key:
|
||||||
|
raise ValueError("Kling API credentials are missing")
|
||||||
|
self.access_key = access_key
|
||||||
|
self.secret_key = secret_key
|
||||||
|
|
||||||
|
def _generate_token(self) -> str:
|
||||||
|
"""Generate a JWT token for Kling API authentication."""
|
||||||
|
now = int(time.time())
|
||||||
|
payload = {
|
||||||
|
"iss": self.access_key,
|
||||||
|
"exp": now + 1800, # 30 minutes
|
||||||
|
"iat": now - 5, # небольшой запас назад
|
||||||
|
"nbf": now - 5,
|
||||||
|
}
|
||||||
|
return jwt.encode(payload, self.secret_key, algorithm="HS256",
|
||||||
|
headers={"typ": "JWT", "alg": "HS256"})
|
||||||
|
|
||||||
|
def _headers(self) -> dict:
|
||||||
|
return {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {self._generate_token()}"
|
||||||
|
}
|
||||||
|
|
||||||
|
async def create_video_task(
|
||||||
|
self,
|
||||||
|
image_url: str,
|
||||||
|
prompt: str = "",
|
||||||
|
negative_prompt: str = "",
|
||||||
|
model_name: str = "kling-v2-6",
|
||||||
|
duration: int = 5,
|
||||||
|
mode: str = "std",
|
||||||
|
cfg_scale: float = 0.5,
|
||||||
|
aspect_ratio: str = "16:9",
|
||||||
|
callback_url: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Create an image-to-video generation task.
|
||||||
|
Returns the full task data dict including task_id.
|
||||||
|
"""
|
||||||
|
body: Dict[str, Any] = {
|
||||||
|
"model_name": model_name,
|
||||||
|
"image": image_url,
|
||||||
|
"prompt": prompt,
|
||||||
|
"negative_prompt": negative_prompt,
|
||||||
|
"duration": str(duration),
|
||||||
|
"mode": mode,
|
||||||
|
"cfg_scale": cfg_scale,
|
||||||
|
"aspect_ratio": aspect_ratio,
|
||||||
|
}
|
||||||
|
if callback_url:
|
||||||
|
body["callback_url"] = callback_url
|
||||||
|
|
||||||
|
logger.info(f"Creating Kling video task. Model: {model_name}, Duration: {duration}s, Mode: {mode}")
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
response = await client.post(
|
||||||
|
f"{KLING_API_BASE}/v1/videos/image2video",
|
||||||
|
headers=self._headers(),
|
||||||
|
json=body,
|
||||||
|
)
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
logger.info(f"Kling create task response: code={data.get('code')}, message={data.get('message')}")
|
||||||
|
|
||||||
|
if response.status_code != 200 or data.get("code") != 0:
|
||||||
|
error_msg = data.get("message", "Unknown Kling API error")
|
||||||
|
raise KlingApiException(f"Failed to create video task: {error_msg} (code={data.get('code')})")
|
||||||
|
|
||||||
|
task_data = data.get("data", {})
|
||||||
|
task_id = task_data.get("task_id")
|
||||||
|
if not task_id:
|
||||||
|
raise KlingApiException("No task_id returned from Kling API")
|
||||||
|
|
||||||
|
logger.info(f"Kling video task created: task_id={task_id}")
|
||||||
|
return task_data
|
||||||
|
|
||||||
|
async def get_task_status(self, task_id: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Query the status of a video generation task.
|
||||||
|
Returns the full task data dict.
|
||||||
|
"""
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
response = await client.get(
|
||||||
|
f"{KLING_API_BASE}/v1/videos/image2video/{task_id}",
|
||||||
|
headers=self._headers(),
|
||||||
|
)
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
if response.status_code != 200 or data.get("code") != 0:
|
||||||
|
error_msg = data.get("message", "Unknown error")
|
||||||
|
raise KlingApiException(f"Failed to query task {task_id}: {error_msg}")
|
||||||
|
|
||||||
|
return data.get("data", {})
|
||||||
|
|
||||||
|
async def wait_for_completion(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
poll_interval: int = 10,
|
||||||
|
timeout: int = 600,
|
||||||
|
progress_callback=None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Poll the task status until completion.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: Kling task ID
|
||||||
|
poll_interval: seconds between polls
|
||||||
|
timeout: max seconds to wait
|
||||||
|
progress_callback: async callable(progress_pct: int) to report progress
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Final task data dict with video URL on success.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KlingApiException on failure or timeout.
|
||||||
|
"""
|
||||||
|
start = time.time()
|
||||||
|
attempt = 0
|
||||||
|
|
||||||
|
while True:
|
||||||
|
elapsed = time.time() - start
|
||||||
|
if elapsed > timeout:
|
||||||
|
raise KlingApiException(f"Video generation timed out after {timeout}s for task {task_id}")
|
||||||
|
|
||||||
|
task_data = await self.get_task_status(task_id)
|
||||||
|
status = task_data.get("task_status")
|
||||||
|
|
||||||
|
logger.info(f"Kling task {task_id}: status={status}, elapsed={elapsed:.0f}s")
|
||||||
|
|
||||||
|
if status == "succeed":
|
||||||
|
logger.info(f"Kling task {task_id} completed successfully")
|
||||||
|
return task_data
|
||||||
|
|
||||||
|
if status == "failed":
|
||||||
|
fail_reason = task_data.get("task_status_msg", "Unknown failure")
|
||||||
|
raise KlingApiException(f"Video generation failed: {fail_reason}")
|
||||||
|
|
||||||
|
# Report progress estimate (linear approximation based on typical time)
|
||||||
|
if progress_callback:
|
||||||
|
# Estimate: typical gen is ~120s, cap at 90%
|
||||||
|
estimated_progress = min(int((elapsed / 120) * 90), 90)
|
||||||
|
attempt += 1
|
||||||
|
await progress_callback(estimated_progress)
|
||||||
|
|
||||||
|
await asyncio.sleep(poll_interval)
|
||||||
@@ -18,7 +18,7 @@ class S3Adapter:
|
|||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def _get_client(self):
|
async def _get_client(self):
|
||||||
async with self.session.client( # type: ignore[reportGeneralTypeIssues]
|
async with self.session.client(
|
||||||
"s3",
|
"s3",
|
||||||
endpoint_url=self.endpoint_url,
|
endpoint_url=self.endpoint_url,
|
||||||
aws_access_key_id=self.aws_access_key_id,
|
aws_access_key_id=self.aws_access_key_id,
|
||||||
@@ -56,21 +56,6 @@ class S3Adapter:
|
|||||||
print(f"Error downloading from S3: {e}")
|
print(f"Error downloading from S3: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def stream_file(self, object_name: str, chunk_size: int = 65536):
|
|
||||||
"""Streams a file from S3 yielding chunks. Memory-efficient for large files."""
|
|
||||||
try:
|
|
||||||
async with self._get_client() as client:
|
|
||||||
response = await client.get_object(Bucket=self.bucket_name, Key=object_name)
|
|
||||||
# aioboto3 Body is an aiohttp StreamReader wrapper
|
|
||||||
body = response['Body']
|
|
||||||
data = await body.read()
|
|
||||||
# Yield in chunks to avoid holding entire response in StreamingResponse buffer
|
|
||||||
for i in range(0, len(data), chunk_size):
|
|
||||||
yield data[i:i + chunk_size]
|
|
||||||
except ClientError as e:
|
|
||||||
print(f"Error streaming from S3: {e}")
|
|
||||||
return
|
|
||||||
|
|
||||||
async def delete_file(self, object_name: str):
|
async def delete_file(self, object_name: str):
|
||||||
"""Deletes a file from S3."""
|
"""Deletes a file from S3."""
|
||||||
try:
|
try:
|
||||||
|
|||||||
60
aiws.py
60
aiws.py
@@ -18,6 +18,7 @@ from prometheus_fastapi_instrumentator import Instrumentator
|
|||||||
|
|
||||||
# --- ИМПОРТЫ ПРОЕКТА ---
|
# --- ИМПОРТЫ ПРОЕКТА ---
|
||||||
from adapters.google_adapter import GoogleAdapter
|
from adapters.google_adapter import GoogleAdapter
|
||||||
|
from adapters.kling_adapter import KlingAdapter
|
||||||
from adapters.s3_adapter import S3Adapter
|
from adapters.s3_adapter import S3Adapter
|
||||||
from api.service.generation_service import GenerationService
|
from api.service.generation_service import GenerationService
|
||||||
from api.service.album_service import AlbumService
|
from api.service.album_service import AlbumService
|
||||||
@@ -43,8 +44,6 @@ from api.endpoints.auth import router as api_auth_router
|
|||||||
from api.endpoints.admin import router as api_admin_router
|
from api.endpoints.admin import router as api_admin_router
|
||||||
from api.endpoints.album_router import router as api_album_router
|
from api.endpoints.album_router import router as api_album_router
|
||||||
from api.endpoints.project_router import router as project_api_router
|
from api.endpoints.project_router import router as project_api_router
|
||||||
from api.endpoints.idea_router import router as idea_api_router
|
|
||||||
from api.endpoints.post_router import router as post_api_router
|
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -65,8 +64,6 @@ def setup_logging():
|
|||||||
|
|
||||||
|
|
||||||
# --- ИНИЦИАЛИЗАЦИЯ ЗАВИСИМОСТЕЙ ---
|
# --- ИНИЦИАЛИЗАЦИЯ ЗАВИСИМОСТЕЙ ---
|
||||||
if BOT_TOKEN is None:
|
|
||||||
raise ValueError("BOT_TOKEN is not set")
|
|
||||||
bot = Bot(token=BOT_TOKEN, default=DefaultBotProperties(parse_mode=ParseMode.HTML))
|
bot = Bot(token=BOT_TOKEN, default=DefaultBotProperties(parse_mode=ParseMode.HTML))
|
||||||
|
|
||||||
# Клиент БД создаем глобально, чтобы он был доступен и боту (Storage), и API
|
# Клиент БД создаем глобально, чтобы он был доступен и боту (Storage), и API
|
||||||
@@ -86,12 +83,19 @@ s3_adapter = S3Adapter(
|
|||||||
)
|
)
|
||||||
|
|
||||||
dao = DAO(mongo_client, s3_adapter) # Главный DAO для бота
|
dao = DAO(mongo_client, s3_adapter) # Главный DAO для бота
|
||||||
if GEMINI_API_KEY is None:
|
|
||||||
raise ValueError("GEMINI_API_KEY is not set")
|
|
||||||
gemini = GoogleAdapter(api_key=GEMINI_API_KEY)
|
gemini = GoogleAdapter(api_key=GEMINI_API_KEY)
|
||||||
if bot is None:
|
|
||||||
raise ValueError("bot is not set")
|
# Kling Adapter (optional, for video generation)
|
||||||
generation_service = GenerationService(dao=dao, gemini=gemini, s3_adapter=s3_adapter, bot=bot)
|
kling_access_key = os.getenv("KLING_ACCESS_KEY", "")
|
||||||
|
kling_secret_key = os.getenv("KLING_SECRET_KEY", "")
|
||||||
|
kling_adapter = None
|
||||||
|
if kling_access_key and kling_secret_key:
|
||||||
|
kling_adapter = KlingAdapter(access_key=kling_access_key, secret_key=kling_secret_key)
|
||||||
|
logger.info("Kling adapter initialized")
|
||||||
|
else:
|
||||||
|
logger.warning("KLING_ACCESS_KEY / KLING_SECRET_KEY not set — video generation disabled")
|
||||||
|
|
||||||
|
generation_service = GenerationService(dao, gemini, s3_adapter, bot, kling_adapter)
|
||||||
album_service = AlbumService(dao)
|
album_service = AlbumService(dao)
|
||||||
|
|
||||||
# Dispatcher
|
# Dispatcher
|
||||||
@@ -128,18 +132,6 @@ assets_router.message.middleware(AuthMiddleware(repo=users_repo, admin_id=ADMIN_
|
|||||||
gen_router.message.middleware(AlbumMiddleware(latency=0.8))
|
gen_router.message.middleware(AlbumMiddleware(latency=0.8))
|
||||||
|
|
||||||
|
|
||||||
async def start_scheduler(service: GenerationService):
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
logger.info("Running scheduler for stacked generation killing")
|
|
||||||
await service.cleanup_stale_generations()
|
|
||||||
await service.cleanup_old_data(days=2)
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Scheduler error: {e}")
|
|
||||||
await asyncio.sleep(60) # Check every 60 seconds
|
|
||||||
|
|
||||||
# --- LIFESPAN (Запуск FastAPI + Bot) ---
|
# --- LIFESPAN (Запуск FastAPI + Bot) ---
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
@@ -158,6 +150,7 @@ async def lifespan(app: FastAPI):
|
|||||||
app.state.gemini_client = gemini
|
app.state.gemini_client = gemini
|
||||||
app.state.bot = bot
|
app.state.bot = bot
|
||||||
app.state.s3_adapter = s3_adapter
|
app.state.s3_adapter = s3_adapter
|
||||||
|
app.state.kling_adapter = kling_adapter
|
||||||
app.state.album_service = album_service
|
app.state.album_service = album_service
|
||||||
app.state.users_repo = users_repo # Добавляем репозиторий в state
|
app.state.users_repo = users_repo # Добавляем репозиторий в state
|
||||||
|
|
||||||
@@ -171,28 +164,17 @@ async def lifespan(app: FastAPI):
|
|||||||
# )
|
# )
|
||||||
# print("🤖 Bot polling started")
|
# print("🤖 Bot polling started")
|
||||||
|
|
||||||
# 3. ЗАПУСК ШЕДУЛЕРА
|
|
||||||
scheduler_task = asyncio.create_task(start_scheduler(generation_service))
|
|
||||||
print("⏰ Scheduler started")
|
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# --- SHUTDOWN ---
|
# --- SHUTDOWN ---
|
||||||
print("🛑 Shutting down...")
|
print("🛑 Shutting down...")
|
||||||
|
|
||||||
# 4. Остановка шедулера
|
|
||||||
scheduler_task.cancel()
|
|
||||||
try:
|
|
||||||
await scheduler_task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
print("⏰ Scheduler stopped")
|
|
||||||
|
|
||||||
# 3. Остановка бота
|
# 3. Остановка бота
|
||||||
# polling_task.cancel()
|
polling_task.cancel()
|
||||||
# try:
|
try:
|
||||||
# await polling_task
|
await polling_task
|
||||||
# except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
# print("🤖 Bot polling stopped")
|
print("🤖 Bot polling stopped")
|
||||||
|
|
||||||
# 4. Отключение БД
|
# 4. Отключение БД
|
||||||
# Обычно Motor закрывать не обязательно при выходе, но хорошим тоном считается
|
# Обычно Motor закрывать не обязательно при выходе, но хорошим тоном считается
|
||||||
@@ -219,8 +201,6 @@ app.include_router(api_char_router)
|
|||||||
app.include_router(api_gen_router)
|
app.include_router(api_gen_router)
|
||||||
app.include_router(api_album_router)
|
app.include_router(api_album_router)
|
||||||
app.include_router(project_api_router)
|
app.include_router(project_api_router)
|
||||||
app.include_router(idea_api_router)
|
|
||||||
app.include_router(post_api_router)
|
|
||||||
|
|
||||||
# Prometheus Metrics (Instrument after all routers are added)
|
# Prometheus Metrics (Instrument after all routers are added)
|
||||||
Instrumentator(
|
Instrumentator(
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
@@ -3,9 +3,9 @@ from fastapi import Request, Depends
|
|||||||
from motor.motor_asyncio import AsyncIOMotorClient
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
|
||||||
from adapters.google_adapter import GoogleAdapter
|
from adapters.google_adapter import GoogleAdapter
|
||||||
|
from adapters.kling_adapter import KlingAdapter
|
||||||
from api.service.generation_service import GenerationService
|
from api.service.generation_service import GenerationService
|
||||||
from repos.dao import DAO
|
from repos.dao import DAO
|
||||||
from api.service.album_service import AlbumService
|
|
||||||
|
|
||||||
|
|
||||||
# ... ваши импорты ...
|
# ... ваши импорты ...
|
||||||
@@ -37,29 +37,20 @@ def get_dao(
|
|||||||
# так что DAO создастся один раз за запрос.
|
# так что DAO создастся один раз за запрос.
|
||||||
return DAO(mongo_client, s3_adapter)
|
return DAO(mongo_client, s3_adapter)
|
||||||
|
|
||||||
|
def get_kling_adapter(request: Request) -> Optional[KlingAdapter]:
|
||||||
|
return request.app.state.kling_adapter
|
||||||
|
|
||||||
# Провайдер сервиса (собирается из DAO и Gemini)
|
# Провайдер сервиса (собирается из DAO и Gemini)
|
||||||
def get_generation_service(
|
def get_generation_service(
|
||||||
dao: DAO = Depends(get_dao),
|
dao: DAO = Depends(get_dao),
|
||||||
gemini: GoogleAdapter = Depends(get_gemini_client),
|
gemini: GoogleAdapter = Depends(get_gemini_client),
|
||||||
s3_adapter: S3Adapter = Depends(get_s3_adapter),
|
s3_adapter: S3Adapter = Depends(get_s3_adapter),
|
||||||
bot: Bot = Depends(get_bot_client),
|
bot: Bot = Depends(get_bot_client),
|
||||||
|
kling_adapter: Optional[KlingAdapter] = Depends(get_kling_adapter),
|
||||||
) -> GenerationService:
|
) -> GenerationService:
|
||||||
return GenerationService(dao, gemini, s3_adapter, bot)
|
return GenerationService(dao, gemini, s3_adapter, bot, kling_adapter=kling_adapter)
|
||||||
|
|
||||||
from api.service.idea_service import IdeaService
|
|
||||||
|
|
||||||
def get_idea_service(dao: DAO = Depends(get_dao)) -> IdeaService:
|
|
||||||
return IdeaService(dao)
|
|
||||||
|
|
||||||
from fastapi import Header
|
from fastapi import Header
|
||||||
|
|
||||||
async def get_project_id(x_project_id: Optional[str] = Header(None, alias="X-Project-ID")) -> Optional[str]:
|
async def get_project_id(x_project_id: Optional[str] = Header(None, alias="X-Project-ID")) -> Optional[str]:
|
||||||
return x_project_id
|
return x_project_id
|
||||||
|
|
||||||
async def get_album_service(dao: DAO = Depends(get_dao)) -> AlbumService:
|
|
||||||
return AlbumService(dao)
|
|
||||||
|
|
||||||
from api.service.post_service import PostService
|
|
||||||
|
|
||||||
def get_post_service(dao: DAO = Depends(get_dao)) -> PostService:
|
|
||||||
return PostService(dao)
|
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -23,7 +23,7 @@ async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)], repo:
|
|||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
username: str | None = payload.get("sub")
|
username: str = payload.get("sub")
|
||||||
if username is None:
|
if username is None:
|
||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
except JWTError:
|
except JWTError:
|
||||||
|
|||||||
@@ -1,13 +1,10 @@
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from fastapi import APIRouter, HTTPException, status, Request
|
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from api.models.GenerationRequest import GenerationResponse
|
from api.models.GenerationRequest import GenerationResponse
|
||||||
from models.Album import Album
|
from models.Album import Album
|
||||||
from repos.dao import DAO
|
from repos.dao import DAO
|
||||||
from api.dependency import get_album_service
|
|
||||||
from api.service.album_service import AlbumService
|
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/albums", tags=["Albums"])
|
router = APIRouter(prefix="/api/albums", tags=["Albums"])
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from pymongo import MongoClient
|
|||||||
from starlette import status
|
from starlette import status
|
||||||
from starlette.exceptions import HTTPException
|
from starlette.exceptions import HTTPException
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import Response, JSONResponse, StreamingResponse
|
from starlette.responses import Response, JSONResponse
|
||||||
|
|
||||||
from adapters.s3_adapter import S3Adapter
|
from adapters.s3_adapter import S3Adapter
|
||||||
from api.models.AssetDTO import AssetsResponse, AssetResponse
|
from api.models.AssetDTO import AssetsResponse, AssetResponse
|
||||||
@@ -33,46 +33,27 @@ async def get_asset(
|
|||||||
asset_id: str,
|
asset_id: str,
|
||||||
request: Request,
|
request: Request,
|
||||||
thumbnail: bool = False,
|
thumbnail: bool = False,
|
||||||
dao: DAO = Depends(get_dao),
|
dao: DAO = Depends(get_dao)
|
||||||
s3_adapter: S3Adapter = Depends(get_s3_adapter),
|
|
||||||
) -> Response:
|
) -> Response:
|
||||||
logger.debug(f"get_asset called for ID: {asset_id}, thumbnail={thumbnail}")
|
logger.debug(f"get_asset called for ID: {asset_id}, thumbnail={thumbnail}")
|
||||||
# Загружаем только метаданные (без data/thumbnail bytes)
|
asset = await dao.assets.get_asset(asset_id)
|
||||||
asset = await dao.assets.get_asset(asset_id, with_data=False)
|
# 2. Проверка на существование
|
||||||
if not asset:
|
if not asset:
|
||||||
raise HTTPException(status_code=404, detail="Asset not found")
|
raise HTTPException(status_code=404, detail="Asset not found")
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
|
# Кэшировать на 1 год (31536000 сек)
|
||||||
"Cache-Control": "public, max-age=31536000, immutable"
|
"Cache-Control": "public, max-age=31536000, immutable"
|
||||||
}
|
}
|
||||||
|
|
||||||
# Thumbnail: маленький, можно грузить в RAM
|
content = asset.data
|
||||||
if thumbnail:
|
media_type = "image/png" # Default, or detect
|
||||||
if asset.minio_thumbnail_object_name and s3_adapter:
|
|
||||||
thumb_bytes = await s3_adapter.get_file(asset.minio_thumbnail_object_name)
|
|
||||||
if thumb_bytes:
|
|
||||||
return Response(content=thumb_bytes, media_type="image/jpeg", headers=headers)
|
|
||||||
# Fallback: thumbnail in DB
|
|
||||||
if asset.thumbnail:
|
|
||||||
return Response(content=asset.thumbnail, media_type="image/jpeg", headers=headers)
|
|
||||||
# No thumbnail available — fall through to main content
|
|
||||||
|
|
||||||
# Main content: стримим из S3 без загрузки в RAM
|
if thumbnail and asset.thumbnail:
|
||||||
if asset.minio_object_name and s3_adapter:
|
content = asset.thumbnail
|
||||||
content_type = "image/png"
|
media_type = "image/jpeg"
|
||||||
# if asset.content_type == AssetContentType.VIDEO:
|
|
||||||
# content_type = "video/mp4"
|
return Response(content=content, media_type=media_type, headers=headers)
|
||||||
return StreamingResponse(
|
|
||||||
s3_adapter.stream_file(asset.minio_object_name),
|
|
||||||
media_type=content_type,
|
|
||||||
headers=headers,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Fallback: data stored in DB (legacy)
|
|
||||||
if asset.data:
|
|
||||||
return Response(content=asset.data, media_type="image/png", headers=headers)
|
|
||||||
|
|
||||||
raise HTTPException(status_code=404, detail="Asset data not found")
|
|
||||||
|
|
||||||
@router.delete("/orphans", dependencies=[Depends(get_current_user)])
|
@router.delete("/orphans", dependencies=[Depends(get_current_user)])
|
||||||
async def delete_orphan_assets_from_minio(
|
async def delete_orphan_assets_from_minio(
|
||||||
|
|||||||
@@ -8,7 +8,8 @@ from api import service
|
|||||||
from api.dependency import get_generation_service, get_project_id, get_dao
|
from api.dependency import get_generation_service, get_project_id, get_dao
|
||||||
from repos.dao import DAO
|
from repos.dao import DAO
|
||||||
|
|
||||||
from api.models.GenerationRequest import GenerationResponse, GenerationRequest, GenerationsResponse, PromptResponse, PromptRequest, GenerationGroupResponse
|
from api.models.GenerationRequest import GenerationResponse, GenerationRequest, GenerationsResponse, PromptResponse, PromptRequest
|
||||||
|
from api.models.VideoGenerationRequest import VideoGenerationRequest
|
||||||
from api.service.generation_service import GenerationService
|
from api.service.generation_service import GenerationService
|
||||||
from models.Generation import Generation
|
from models.Generation import Generation
|
||||||
|
|
||||||
@@ -68,12 +69,12 @@ async def get_generations(character_id: Optional[str] = None, limit: int = 10, o
|
|||||||
return await generation_service.get_generations(character_id, limit=limit, offset=offset, user_id=user_id_filter, project_id=project_id)
|
return await generation_service.get_generations(character_id, limit=limit, offset=offset, user_id=user_id_filter, project_id=project_id)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/_run", response_model=GenerationGroupResponse)
|
@router.post("/_run", response_model=GenerationResponse)
|
||||||
async def post_generation(generation: GenerationRequest, request: Request,
|
async def post_generation(generation: GenerationRequest, request: Request,
|
||||||
generation_service: GenerationService = Depends(get_generation_service),
|
generation_service: GenerationService = Depends(get_generation_service),
|
||||||
current_user: dict = Depends(get_current_user),
|
current_user: dict = Depends(get_current_user),
|
||||||
project_id: Optional[str] = Depends(get_project_id),
|
project_id: Optional[str] = Depends(get_project_id),
|
||||||
dao: DAO = Depends(get_dao)) -> GenerationGroupResponse:
|
dao: DAO = Depends(get_dao)) -> GenerationResponse:
|
||||||
logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}")
|
logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}")
|
||||||
|
|
||||||
if project_id:
|
if project_id:
|
||||||
@@ -85,6 +86,16 @@ async def post_generation(generation: GenerationRequest, request: Request,
|
|||||||
return await generation_service.create_generation_task(generation, user_id=str(current_user.get("_id")))
|
return await generation_service.create_generation_task(generation, user_id=str(current_user.get("_id")))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{generation_id}", response_model=GenerationResponse)
|
||||||
|
async def get_generation(generation_id: str,
|
||||||
|
generation_service: GenerationService = Depends(get_generation_service),
|
||||||
|
current_user: dict = Depends(get_current_user)) -> GenerationResponse:
|
||||||
|
logger.debug(f"get_generation called for ID: {generation_id}")
|
||||||
|
gen = await generation_service.get_generation(generation_id)
|
||||||
|
if gen and gen.created_by != str(current_user["_id"]):
|
||||||
|
raise HTTPException(status_code=403, detail="Access denied")
|
||||||
|
return gen
|
||||||
|
|
||||||
|
|
||||||
@router.get("/running")
|
@router.get("/running")
|
||||||
async def get_running_generations(request: Request,
|
async def get_running_generations(request: Request,
|
||||||
@@ -103,27 +114,25 @@ async def get_running_generations(request: Request,
|
|||||||
return await generation_service.get_running_generations(user_id=user_id_filter, project_id=project_id)
|
return await generation_service.get_running_generations(user_id=user_id_filter, project_id=project_id)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/group/{group_id}", response_model=GenerationGroupResponse)
|
@router.post("/video/_run", response_model=GenerationResponse)
|
||||||
async def get_generation_group(group_id: str,
|
async def post_video_generation(
|
||||||
generation_service: GenerationService = Depends(get_generation_service),
|
video_request: VideoGenerationRequest,
|
||||||
current_user: dict = Depends(get_current_user)):
|
request: Request,
|
||||||
logger.info(f"get_generation_group called for group_id: {group_id}")
|
generation_service: GenerationService = Depends(get_generation_service),
|
||||||
generations = await generation_service.dao.generations.get_generations_by_group(group_id)
|
current_user: dict = Depends(get_current_user),
|
||||||
gen_responses = [GenerationResponse(**gen.model_dump()) for gen in generations]
|
project_id: Optional[str] = Depends(get_project_id),
|
||||||
return GenerationGroupResponse(generation_group_id=group_id, generations=gen_responses)
|
dao: DAO = Depends(get_dao),
|
||||||
|
) -> GenerationResponse:
|
||||||
|
"""Start image-to-video generation using Kling AI."""
|
||||||
@router.get("/{generation_id}", response_model=GenerationResponse)
|
logger.info(f"post_video_generation called. AssetId: {video_request.image_asset_id}, Duration: {video_request.duration}s, Mode: {video_request.mode}")
|
||||||
async def get_generation(generation_id: str,
|
|
||||||
generation_service: GenerationService = Depends(get_generation_service),
|
|
||||||
current_user: dict = Depends(get_current_user)) -> GenerationResponse:
|
|
||||||
logger.debug(f"get_generation called for ID: {generation_id}")
|
|
||||||
gen = await generation_service.get_generation(generation_id)
|
|
||||||
if gen and gen.created_by != str(current_user["_id"]):
|
|
||||||
raise HTTPException(status_code=403, detail="Access denied")
|
|
||||||
return gen
|
|
||||||
|
|
||||||
|
if project_id:
|
||||||
|
project = await dao.projects.get_project(project_id)
|
||||||
|
if not project or str(current_user["_id"]) not in project.members:
|
||||||
|
raise HTTPException(status_code=403, detail="Project access denied")
|
||||||
|
video_request.project_id = project_id
|
||||||
|
|
||||||
|
return await generation_service.create_video_generation_task(video_request, user_id=str(current_user.get("_id")))
|
||||||
|
|
||||||
|
|
||||||
@router.post("/import", response_model=GenerationResponse)
|
@router.post("/import", response_model=GenerationResponse)
|
||||||
|
|||||||
@@ -1,103 +0,0 @@
|
|||||||
from typing import List, Optional
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, Body
|
|
||||||
from api.dependency import get_idea_service, get_project_id, get_generation_service
|
|
||||||
from api.endpoints.auth import get_current_user
|
|
||||||
from api.service.idea_service import IdeaService
|
|
||||||
from api.service.generation_service import GenerationService
|
|
||||||
from models.Idea import Idea
|
|
||||||
from api.models.GenerationRequest import GenerationResponse, GenerationsResponse
|
|
||||||
from api.models.IdeaRequest import IdeaCreateRequest, IdeaUpdateRequest, IdeaResponse
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/ideas", tags=["ideas"])
|
|
||||||
|
|
||||||
@router.post("", response_model=Idea)
|
|
||||||
async def create_idea(
|
|
||||||
request: IdeaCreateRequest,
|
|
||||||
project_id: Optional[str] = Depends(get_project_id),
|
|
||||||
current_user: dict = Depends(get_current_user),
|
|
||||||
idea_service: IdeaService = Depends(get_idea_service)
|
|
||||||
):
|
|
||||||
pid = project_id or request.project_id
|
|
||||||
|
|
||||||
return await idea_service.create_idea(request.name, request.description, pid, str(current_user["_id"]))
|
|
||||||
|
|
||||||
@router.get("", response_model=List[IdeaResponse])
|
|
||||||
async def get_ideas(
|
|
||||||
project_id: Optional[str] = Depends(get_project_id),
|
|
||||||
limit: int = 20,
|
|
||||||
offset: int = 0,
|
|
||||||
current_user: dict = Depends(get_current_user),
|
|
||||||
idea_service: IdeaService = Depends(get_idea_service)
|
|
||||||
):
|
|
||||||
return await idea_service.get_ideas(project_id, str(current_user["_id"]), limit, offset)
|
|
||||||
|
|
||||||
@router.get("/{idea_id}", response_model=Idea)
|
|
||||||
async def get_idea(
|
|
||||||
idea_id: str,
|
|
||||||
idea_service: IdeaService = Depends(get_idea_service)
|
|
||||||
):
|
|
||||||
idea = await idea_service.get_idea(idea_id)
|
|
||||||
if not idea:
|
|
||||||
raise HTTPException(status_code=404, detail="Idea not found")
|
|
||||||
return idea
|
|
||||||
|
|
||||||
@router.put("/{idea_id}", response_model=Idea)
|
|
||||||
async def update_idea(
|
|
||||||
idea_id: str,
|
|
||||||
request: IdeaUpdateRequest,
|
|
||||||
idea_service: IdeaService = Depends(get_idea_service)
|
|
||||||
):
|
|
||||||
idea = await idea_service.update_idea(idea_id, request.name, request.description)
|
|
||||||
if not idea:
|
|
||||||
raise HTTPException(status_code=404, detail="Idea not found")
|
|
||||||
return idea
|
|
||||||
|
|
||||||
@router.delete("/{idea_id}")
|
|
||||||
async def delete_idea(
|
|
||||||
idea_id: str,
|
|
||||||
idea_service: IdeaService = Depends(get_idea_service)
|
|
||||||
):
|
|
||||||
success = await idea_service.delete_idea(idea_id)
|
|
||||||
if not success:
|
|
||||||
raise HTTPException(status_code=404, detail="Idea not found or could not be deleted")
|
|
||||||
return {"status": "success"}
|
|
||||||
|
|
||||||
@router.get("/{idea_id}/generations", response_model=GenerationsResponse)
|
|
||||||
async def get_idea_generations(
|
|
||||||
idea_id: str,
|
|
||||||
limit: int = 50,
|
|
||||||
offset: int = 0,
|
|
||||||
generation_service: GenerationService = Depends(get_generation_service)
|
|
||||||
):
|
|
||||||
# Depending on how generation service implements filtering by idea_id.
|
|
||||||
# We might need to update generation_service to support getting by idea_id directly
|
|
||||||
# or ensure generic get_generations supports it.
|
|
||||||
# Looking at generation_router.py, get_generations doesn't have idea_id arg?
|
|
||||||
# Let's check generation_service.get_generations signature again.
|
|
||||||
# It has: (character_id, limit, offset, user_id, project_id). NO IDEA_ID.
|
|
||||||
# I need to update GenerationService.get_generations too!
|
|
||||||
|
|
||||||
# For now, let's assume I will update it.
|
|
||||||
return await generation_service.get_generations(idea_id=idea_id, limit=limit, offset=offset)
|
|
||||||
|
|
||||||
@router.post("/{idea_id}/generations/{generation_id}")
|
|
||||||
async def add_generation_to_idea(
|
|
||||||
idea_id: str,
|
|
||||||
generation_id: str,
|
|
||||||
idea_service: IdeaService = Depends(get_idea_service)
|
|
||||||
):
|
|
||||||
success = await idea_service.add_generation_to_idea(idea_id, generation_id)
|
|
||||||
if not success:
|
|
||||||
raise HTTPException(status_code=404, detail="Idea or Generation not found")
|
|
||||||
return {"status": "success"}
|
|
||||||
|
|
||||||
@router.delete("/{idea_id}/generations/{generation_id}")
|
|
||||||
async def remove_generation_from_idea(
|
|
||||||
idea_id: str,
|
|
||||||
generation_id: str,
|
|
||||||
idea_service: IdeaService = Depends(get_idea_service)
|
|
||||||
):
|
|
||||||
success = await idea_service.remove_generation_from_idea(idea_id, generation_id)
|
|
||||||
if not success:
|
|
||||||
raise HTTPException(status_code=404, detail="Idea or Generation not found")
|
|
||||||
return {"status": "success"}
|
|
||||||
@@ -1,99 +0,0 @@
|
|||||||
from typing import List, Optional
|
|
||||||
from datetime import datetime
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
|
||||||
|
|
||||||
from api.dependency import get_post_service, get_project_id
|
|
||||||
from api.endpoints.auth import get_current_user
|
|
||||||
from api.service.post_service import PostService
|
|
||||||
from api.models.PostRequest import PostCreateRequest, PostUpdateRequest, AddGenerationsRequest
|
|
||||||
from models.Post import Post
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/posts", tags=["posts"])
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=Post)
|
|
||||||
async def create_post(
|
|
||||||
request: PostCreateRequest,
|
|
||||||
project_id: Optional[str] = Depends(get_project_id),
|
|
||||||
current_user: dict = Depends(get_current_user),
|
|
||||||
post_service: PostService = Depends(get_post_service),
|
|
||||||
):
|
|
||||||
pid = project_id or request.project_id
|
|
||||||
return await post_service.create_post(
|
|
||||||
date=request.date,
|
|
||||||
topic=request.topic,
|
|
||||||
generation_ids=request.generation_ids,
|
|
||||||
project_id=pid,
|
|
||||||
user_id=str(current_user["_id"]),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("", response_model=List[Post])
|
|
||||||
async def get_posts(
|
|
||||||
project_id: Optional[str] = Depends(get_project_id),
|
|
||||||
limit: int = 200,
|
|
||||||
offset: int = 0,
|
|
||||||
date_from: Optional[datetime] = None,
|
|
||||||
date_to: Optional[datetime] = None,
|
|
||||||
current_user: dict = Depends(get_current_user),
|
|
||||||
post_service: PostService = Depends(get_post_service),
|
|
||||||
):
|
|
||||||
return await post_service.get_posts(project_id, str(current_user["_id"]), limit, offset, date_from, date_to)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{post_id}", response_model=Post)
|
|
||||||
async def get_post(
|
|
||||||
post_id: str,
|
|
||||||
post_service: PostService = Depends(get_post_service),
|
|
||||||
):
|
|
||||||
post = await post_service.get_post(post_id)
|
|
||||||
if not post:
|
|
||||||
raise HTTPException(status_code=404, detail="Post not found")
|
|
||||||
return post
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/{post_id}", response_model=Post)
|
|
||||||
async def update_post(
|
|
||||||
post_id: str,
|
|
||||||
request: PostUpdateRequest,
|
|
||||||
post_service: PostService = Depends(get_post_service),
|
|
||||||
):
|
|
||||||
post = await post_service.update_post(post_id, date=request.date, topic=request.topic)
|
|
||||||
if not post:
|
|
||||||
raise HTTPException(status_code=404, detail="Post not found")
|
|
||||||
return post
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{post_id}")
|
|
||||||
async def delete_post(
|
|
||||||
post_id: str,
|
|
||||||
post_service: PostService = Depends(get_post_service),
|
|
||||||
):
|
|
||||||
success = await post_service.delete_post(post_id)
|
|
||||||
if not success:
|
|
||||||
raise HTTPException(status_code=404, detail="Post not found or could not be deleted")
|
|
||||||
return {"status": "success"}
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{post_id}/generations")
|
|
||||||
async def add_generations(
|
|
||||||
post_id: str,
|
|
||||||
request: AddGenerationsRequest,
|
|
||||||
post_service: PostService = Depends(get_post_service),
|
|
||||||
):
|
|
||||||
success = await post_service.add_generations(post_id, request.generation_ids)
|
|
||||||
if not success:
|
|
||||||
raise HTTPException(status_code=404, detail="Post not found")
|
|
||||||
return {"status": "success"}
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{post_id}/generations/{generation_id}")
|
|
||||||
async def remove_generation(
|
|
||||||
post_id: str,
|
|
||||||
generation_id: str,
|
|
||||||
post_service: PostService = Depends(get_post_service),
|
|
||||||
):
|
|
||||||
success = await post_service.remove_generation(post_id, generation_id)
|
|
||||||
if not success:
|
|
||||||
raise HTTPException(status_code=404, detail="Post not found or generation not linked")
|
|
||||||
return {"status": "success"}
|
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
from datetime import datetime, UTC
|
from datetime import datetime, UTC
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from models.Asset import Asset
|
from models.Asset import Asset
|
||||||
from models.Generation import GenerationStatus
|
from models.Generation import GenerationStatus
|
||||||
@@ -17,8 +17,6 @@ class GenerationRequest(BaseModel):
|
|||||||
use_profile_image: bool = True
|
use_profile_image: bool = True
|
||||||
assets_list: List[str]
|
assets_list: List[str]
|
||||||
project_id: Optional[str] = None
|
project_id: Optional[str] = None
|
||||||
idea_id: Optional[str] = None
|
|
||||||
count: int = Field(default=1, ge=1, le=10)
|
|
||||||
|
|
||||||
|
|
||||||
class GenerationsResponse(BaseModel):
|
class GenerationsResponse(BaseModel):
|
||||||
@@ -29,6 +27,7 @@ class GenerationsResponse(BaseModel):
|
|||||||
class GenerationResponse(BaseModel):
|
class GenerationResponse(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
status: GenerationStatus
|
status: GenerationStatus
|
||||||
|
gen_type: GenType = GenType.IMAGE
|
||||||
failed_reason: Optional[str] = None
|
failed_reason: Optional[str] = None
|
||||||
|
|
||||||
linked_character_id: Optional[str] = None
|
linked_character_id: Optional[str] = None
|
||||||
@@ -47,16 +46,14 @@ class GenerationResponse(BaseModel):
|
|||||||
progress: int = 0
|
progress: int = 0
|
||||||
cost: Optional[float] = None
|
cost: Optional[float] = None
|
||||||
created_by: Optional[str] = None
|
created_by: Optional[str] = None
|
||||||
generation_group_id: Optional[str] = None
|
# Video-specific
|
||||||
idea_id: Optional[str] = None
|
kling_task_id: Optional[str] = None
|
||||||
|
video_duration: Optional[int] = None
|
||||||
|
video_mode: Optional[str] = None
|
||||||
created_at: datetime = datetime.now(UTC)
|
created_at: datetime = datetime.now(UTC)
|
||||||
updated_at: datetime = datetime.now(UTC)
|
updated_at: datetime = datetime.now(UTC)
|
||||||
|
|
||||||
|
|
||||||
class GenerationGroupResponse(BaseModel):
|
|
||||||
generation_group_id: str
|
|
||||||
generations: List[GenerationResponse]
|
|
||||||
|
|
||||||
|
|
||||||
class PromptRequest(BaseModel):
|
class PromptRequest(BaseModel):
|
||||||
prompt: str
|
prompt: str
|
||||||
|
|||||||
@@ -1,16 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from models.Idea import Idea
|
|
||||||
from api.models.GenerationRequest import GenerationResponse
|
|
||||||
|
|
||||||
class IdeaCreateRequest(BaseModel):
|
|
||||||
name: str
|
|
||||||
description: Optional[str] = None
|
|
||||||
project_id: Optional[str] = None # Optional in body if passed via header/dependency
|
|
||||||
|
|
||||||
class IdeaUpdateRequest(BaseModel):
|
|
||||||
name: Optional[str] = None
|
|
||||||
description: Optional[str] = None
|
|
||||||
|
|
||||||
class IdeaResponse(Idea):
|
|
||||||
last_generation: Optional[GenerationResponse] = None
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
from datetime import datetime
|
|
||||||
from typing import Optional, List
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
|
|
||||||
class PostCreateRequest(BaseModel):
|
|
||||||
date: datetime
|
|
||||||
topic: str
|
|
||||||
generation_ids: List[str] = []
|
|
||||||
project_id: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class PostUpdateRequest(BaseModel):
|
|
||||||
date: Optional[datetime] = None
|
|
||||||
topic: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class AddGenerationsRequest(BaseModel):
|
|
||||||
generation_ids: List[str]
|
|
||||||
16
api/models/VideoGenerationRequest.py
Normal file
16
api/models/VideoGenerationRequest.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class VideoGenerationRequest(BaseModel):
|
||||||
|
prompt: str = ""
|
||||||
|
negative_prompt: Optional[str] = ""
|
||||||
|
image_asset_id: str # ID ассета-картинки для source image
|
||||||
|
duration: int = 5 # 5 or 10 seconds
|
||||||
|
mode: str = "std" # "std" or "pro"
|
||||||
|
model_name: str = "kling-v2-1"
|
||||||
|
cfg_scale: float = 0.5
|
||||||
|
aspect_ratio: str = "16:9"
|
||||||
|
linked_character_id: Optional[str] = None
|
||||||
|
project_id: Optional[str] = None
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -5,14 +5,15 @@ import base64
|
|||||||
from datetime import datetime, UTC
|
from datetime import datetime, UTC
|
||||||
from typing import List, Optional, Tuple, Any, Dict
|
from typing import List, Optional, Tuple, Any, Dict
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from uuid import uuid4
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from aiogram import Bot
|
from aiogram import Bot
|
||||||
from aiogram.types import BufferedInputFile
|
from aiogram.types import BufferedInputFile
|
||||||
from adapters.Exception import GoogleGenerationException
|
from adapters.Exception import GoogleGenerationException
|
||||||
from adapters.google_adapter import GoogleAdapter
|
from adapters.google_adapter import GoogleAdapter
|
||||||
from api.models.GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse
|
from adapters.kling_adapter import KlingAdapter, KlingApiException
|
||||||
|
from api.models.GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse
|
||||||
|
from api.models.VideoGenerationRequest import VideoGenerationRequest
|
||||||
# Импортируйте ваши модели DAO, Asset, Generation корректно
|
# Импортируйте ваши модели DAO, Asset, Generation корректно
|
||||||
from models.Asset import Asset, AssetType, AssetContentType
|
from models.Asset import Asset, AssetType, AssetContentType
|
||||||
from models.Generation import Generation, GenerationStatus
|
from models.Generation import Generation, GenerationStatus
|
||||||
@@ -22,9 +23,6 @@ from adapters.s3_adapter import S3Adapter
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Limit concurrent generations to 4
|
|
||||||
generation_semaphore = asyncio.Semaphore(4)
|
|
||||||
|
|
||||||
|
|
||||||
# --- Вспомогательная функция генерации ---
|
# --- Вспомогательная функция генерации ---
|
||||||
async def generate_image_task(
|
async def generate_image_task(
|
||||||
@@ -54,30 +52,29 @@ async def generate_image_task(
|
|||||||
logger.info(f"generate_image_task completed, received {len(generated_images_io) if generated_images_io else 0} images")
|
logger.info(f"generate_image_task completed, received {len(generated_images_io) if generated_images_io else 0} images")
|
||||||
except GoogleGenerationException as e:
|
except GoogleGenerationException as e:
|
||||||
raise e
|
raise e
|
||||||
finally:
|
|
||||||
# Освобождаем входные данные — они больше не нужны
|
|
||||||
del media_group_bytes
|
|
||||||
|
|
||||||
images_bytes = []
|
images_bytes = []
|
||||||
if generated_images_io:
|
if generated_images_io:
|
||||||
for img_io in generated_images_io:
|
for img_io in generated_images_io:
|
||||||
|
# Читаем байты из BytesIO
|
||||||
img_io.seek(0)
|
img_io.seek(0)
|
||||||
images_bytes.append(img_io.read())
|
content = img_io.read()
|
||||||
|
images_bytes.append(content)
|
||||||
|
|
||||||
|
# Закрываем поток
|
||||||
img_io.close()
|
img_io.close()
|
||||||
# Освобождаем список BytesIO сразу
|
|
||||||
del generated_images_io
|
|
||||||
|
|
||||||
return images_bytes, metrics
|
return images_bytes, metrics
|
||||||
|
|
||||||
class GenerationService:
|
class GenerationService:
|
||||||
def __init__(self, dao: DAO, gemini: GoogleAdapter, s3_adapter: S3Adapter, bot: Optional[Bot] = None):
|
def __init__(self, dao: DAO, gemini: GoogleAdapter, s3_adapter: S3Adapter, bot: Optional[Bot] = None, kling_adapter: Optional[KlingAdapter] = None):
|
||||||
self.dao = dao
|
self.dao = dao
|
||||||
self.gemini = gemini
|
self.gemini = gemini
|
||||||
self.s3_adapter = s3_adapter
|
self.s3_adapter = s3_adapter
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
|
self.kling_adapter = kling_adapter
|
||||||
|
|
||||||
|
|
||||||
async def ask_prompt_assistant(self, prompt: str, assets: list[str] | None = None) -> str:
|
async def ask_prompt_assistant(self, prompt: str, assets: List[str] = None) -> str:
|
||||||
future_prompt = """You are an prompt-assistant. You improving user-entered prompts for image generation. User may upload reference image too.
|
future_prompt = """You are an prompt-assistant. You improving user-entered prompts for image generation. User may upload reference image too.
|
||||||
I will provide sources prompt entered by user. Understand user needs and generate best variation of prompt.
|
I will provide sources prompt entered by user. Understand user needs and generate best variation of prompt.
|
||||||
ANSWER ONLY PROMPT STRING!!! USER_ENTERED_PROMPT: """
|
ANSWER ONLY PROMPT STRING!!! USER_ENTERED_PROMPT: """
|
||||||
@@ -100,9 +97,10 @@ class GenerationService:
|
|||||||
|
|
||||||
return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, images_list=images)
|
return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, images_list=images)
|
||||||
|
|
||||||
async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0, user_id: Optional[str] = None, project_id: Optional[str] = None, idea_id: Optional[str] = None) -> GenerationsResponse:
|
async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[
|
||||||
generations = await self.dao.generations.get_generations(character_id = character_id,limit=limit, offset=offset, created_by=user_id, project_id=project_id, idea_id=idea_id)
|
Generation]:
|
||||||
total_count = await self.dao.generations.count_generations(character_id = character_id, created_by=user_id, project_id=project_id, idea_id=idea_id)
|
generations = await self.dao.generations.get_generations(character_id = character_id,limit=limit, offset=offset, created_by=user_id, project_id=project_id)
|
||||||
|
total_count = await self.dao.generations.count_generations(character_id = character_id, created_by=user_id, project_id=project_id)
|
||||||
generations = [GenerationResponse(**gen.model_dump()) for gen in generations]
|
generations = [GenerationResponse(**gen.model_dump()) for gen in generations]
|
||||||
return GenerationsResponse(generations=generations, total_count=total_count)
|
return GenerationsResponse(generations=generations, total_count=total_count)
|
||||||
|
|
||||||
@@ -116,50 +114,29 @@ class GenerationService:
|
|||||||
async def get_running_generations(self, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
|
async def get_running_generations(self, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
|
||||||
return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING, created_by=user_id, project_id=project_id)
|
return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING, created_by=user_id, project_id=project_id)
|
||||||
|
|
||||||
async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationGroupResponse:
|
async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None) -> GenerationResponse:
|
||||||
count = generation_request.count
|
|
||||||
|
|
||||||
if generation_group_id is None:
|
|
||||||
generation_group_id = str(uuid4())
|
|
||||||
|
|
||||||
results = []
|
|
||||||
for _ in range(count):
|
|
||||||
gen_response = await self._create_single_generation(generation_request, user_id, generation_group_id)
|
|
||||||
results.append(gen_response)
|
|
||||||
return GenerationGroupResponse(generation_group_id=generation_group_id, generations=results)
|
|
||||||
|
|
||||||
async def _create_single_generation(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationResponse:
|
|
||||||
gen_id = None
|
gen_id = None
|
||||||
generation_model = None
|
generation_model = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
generation_model = Generation(**generation_request.model_dump(exclude={'count'}))
|
generation_model = Generation(**generation_request.model_dump())
|
||||||
if user_id:
|
if user_id:
|
||||||
generation_model.created_by = user_id
|
generation_model.created_by = user_id
|
||||||
if generation_group_id:
|
|
||||||
generation_model.generation_group_id = generation_group_id
|
|
||||||
|
|
||||||
# Explicitly set idea_id from request if present (already in model_dump, but ensuring clarity)
|
|
||||||
if generation_request.idea_id:
|
|
||||||
generation_model.idea_id = generation_request.idea_id
|
|
||||||
|
|
||||||
gen_id = await self.dao.generations.create_generation(generation_model)
|
gen_id = await self.dao.generations.create_generation(generation_model)
|
||||||
generation_model.id = gen_id
|
generation_model.id = gen_id
|
||||||
|
|
||||||
async def runner(gen):
|
async def runner(gen):
|
||||||
logger.info(f"Generation {gen.id} entered queue (waiting for slot)...")
|
logger.info(f"Starting background generation task for ID: {gen.id}")
|
||||||
try:
|
try:
|
||||||
async with generation_semaphore:
|
await self.create_generation(gen)
|
||||||
logger.info(f"Starting background generation task for ID: {gen.id}")
|
logger.info(f"Background generation task finished for ID: {gen.id}")
|
||||||
await self.create_generation(gen)
|
|
||||||
logger.info(f"Background generation task finished for ID: {gen.id}")
|
|
||||||
except Exception:
|
except Exception:
|
||||||
# если генерация уже пошла и упала — пометим FAILED
|
# если генерация уже пошла и упала — пометим FAILED
|
||||||
try:
|
try:
|
||||||
db_gen = await self.dao.generations.get_generation(gen.id)
|
db_gen = await self.dao.generations.get_generation(gen.id)
|
||||||
if db_gen is not None:
|
db_gen.status = GenerationStatus.FAILED
|
||||||
db_gen.status = GenerationStatus.FAILED
|
await self.dao.generations.update_generation(db_gen)
|
||||||
await self.dao.generations.update_generation(db_gen)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to mark generation as FAILED")
|
logger.exception("Failed to mark generation as FAILED")
|
||||||
logger.exception("create_generation task failed")
|
logger.exception("create_generation task failed")
|
||||||
@@ -173,9 +150,8 @@ class GenerationService:
|
|||||||
if gen_id is not None:
|
if gen_id is not None:
|
||||||
try:
|
try:
|
||||||
gen = await self.dao.generations.get_generation(gen_id)
|
gen = await self.dao.generations.get_generation(gen_id)
|
||||||
if gen is not None:
|
gen.status = GenerationStatus.FAILED
|
||||||
gen.status = GenerationStatus.FAILED
|
await self.dao.generations.update_generation(gen)
|
||||||
await self.dao.generations.update_generation(gen)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to mark generation as FAILED in create_generation_task")
|
logger.exception("Failed to mark generation as FAILED in create_generation_task")
|
||||||
raise
|
raise
|
||||||
@@ -203,10 +179,9 @@ class GenerationService:
|
|||||||
if char_info is None:
|
if char_info is None:
|
||||||
raise Exception(f"Character ID {generation.linked_character_id} not found")
|
raise Exception(f"Character ID {generation.linked_character_id} not found")
|
||||||
if generation.use_profile_image:
|
if generation.use_profile_image:
|
||||||
if char_info.avatar_asset_id is not None:
|
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
|
||||||
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
|
if avatar_asset:
|
||||||
if avatar_asset and avatar_asset.data:
|
media_group_bytes.append(avatar_asset.data)
|
||||||
media_group_bytes.append(avatar_asset.data)
|
|
||||||
# generation_prompt = generation_prompt.replace("$char_bio_inserted", f"1. CHARACTER BIO (Must be strictly followed): {char_info.character_bio}")
|
# generation_prompt = generation_prompt.replace("$char_bio_inserted", f"1. CHARACTER BIO (Must be strictly followed): {char_info.character_bio}")
|
||||||
|
|
||||||
reference_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
|
reference_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
|
||||||
@@ -307,9 +282,7 @@ class GenerationService:
|
|||||||
|
|
||||||
# 5. (Опционально) Обновляем запись генерации ссылками на результаты
|
# 5. (Опционально) Обновляем запись генерации ссылками на результаты
|
||||||
# Предполагаем, что у модели Generation есть поле result_asset_ids
|
# Предполагаем, что у модели Generation есть поле result_asset_ids
|
||||||
result_ids = []
|
result_ids = [a.id for a in created_assets]
|
||||||
for a in created_assets:
|
|
||||||
result_ids.append(a.id)
|
|
||||||
|
|
||||||
generation.result_list = result_ids
|
generation.result_list = result_ids
|
||||||
generation.status = GenerationStatus.DONE
|
generation.status = GenerationStatus.DONE
|
||||||
@@ -458,6 +431,168 @@ class GenerationService:
|
|||||||
|
|
||||||
return generation
|
return generation
|
||||||
|
|
||||||
|
# === VIDEO GENERATION (Kling) ===
|
||||||
|
|
||||||
|
async def create_video_generation_task(self, request: VideoGenerationRequest, user_id: Optional[str] = None) -> GenerationResponse:
|
||||||
|
"""Create a video generation task (async, returns immediately)."""
|
||||||
|
if not self.kling_adapter:
|
||||||
|
raise Exception("Kling adapter is not configured")
|
||||||
|
|
||||||
|
generation = Generation(
|
||||||
|
status=GenerationStatus.RUNNING,
|
||||||
|
gen_type=GenType.VIDEO,
|
||||||
|
linked_character_id=request.linked_character_id,
|
||||||
|
aspect_ratio=AspectRatios.SIXTEENNINE, # default for video
|
||||||
|
quality=Quality.ONEK,
|
||||||
|
prompt=request.prompt,
|
||||||
|
assets_list=[request.image_asset_id],
|
||||||
|
video_duration=request.duration,
|
||||||
|
video_mode=request.mode,
|
||||||
|
project_id=request.project_id,
|
||||||
|
)
|
||||||
|
if user_id:
|
||||||
|
generation.created_by = user_id
|
||||||
|
|
||||||
|
gen_id = await self.dao.generations.create_generation(generation)
|
||||||
|
generation.id = gen_id
|
||||||
|
|
||||||
|
async def runner(gen, req):
|
||||||
|
logger.info(f"Starting background video generation task for ID: {gen.id}")
|
||||||
|
try:
|
||||||
|
await self.create_video_generation(gen, req)
|
||||||
|
logger.info(f"Background video generation task finished for ID: {gen.id}")
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
db_gen = await self.dao.generations.get_generation(gen.id)
|
||||||
|
if db_gen and db_gen.status != GenerationStatus.FAILED:
|
||||||
|
db_gen.status = GenerationStatus.FAILED
|
||||||
|
db_gen.updated_at = datetime.now(UTC)
|
||||||
|
await self.dao.generations.update_generation(db_gen)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to mark video generation as FAILED")
|
||||||
|
logger.exception("create_video_generation task failed")
|
||||||
|
|
||||||
|
asyncio.create_task(runner(generation, request))
|
||||||
|
return GenerationResponse(**generation.model_dump())
|
||||||
|
|
||||||
|
async def create_video_generation(self, generation: Generation, request: VideoGenerationRequest):
|
||||||
|
"""Background video generation: call Kling API, poll, download result, save asset."""
|
||||||
|
start_time = datetime.now()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 1. Get source image presigned URL
|
||||||
|
asset = await self.dao.assets.get_asset(request.image_asset_id)
|
||||||
|
if not asset:
|
||||||
|
raise Exception(f"Asset {request.image_asset_id} not found")
|
||||||
|
|
||||||
|
if not asset.minio_object_name:
|
||||||
|
raise Exception(f"Asset {request.image_asset_id} has no S3 object")
|
||||||
|
|
||||||
|
presigned_url = await self.s3_adapter.get_presigned_url(asset.minio_object_name, expiration=3600)
|
||||||
|
if not presigned_url:
|
||||||
|
raise Exception("Failed to generate presigned URL for source image")
|
||||||
|
|
||||||
|
logger.info(f"Video gen {generation.id}: got presigned URL for asset {request.image_asset_id}")
|
||||||
|
|
||||||
|
# 2. Create Kling task
|
||||||
|
task_data = await self.kling_adapter.create_video_task(
|
||||||
|
image_url=presigned_url,
|
||||||
|
prompt=request.prompt,
|
||||||
|
negative_prompt=request.negative_prompt or "",
|
||||||
|
model_name=request.model_name,
|
||||||
|
duration=request.duration,
|
||||||
|
mode=request.mode,
|
||||||
|
cfg_scale=request.cfg_scale,
|
||||||
|
aspect_ratio=request.aspect_ratio,
|
||||||
|
)
|
||||||
|
|
||||||
|
task_id = task_data.get("task_id")
|
||||||
|
generation.kling_task_id = task_id
|
||||||
|
await self.dao.generations.update_generation(generation)
|
||||||
|
|
||||||
|
logger.info(f"Video gen {generation.id}: Kling task created, task_id={task_id}")
|
||||||
|
|
||||||
|
# 3. Poll for completion with progress updates
|
||||||
|
async def progress_callback(progress_pct: int):
|
||||||
|
generation.progress = progress_pct
|
||||||
|
generation.updated_at = datetime.now(UTC)
|
||||||
|
await self.dao.generations.update_generation(generation)
|
||||||
|
|
||||||
|
result = await self.kling_adapter.wait_for_completion(
|
||||||
|
task_id=task_id,
|
||||||
|
poll_interval=10,
|
||||||
|
timeout=600,
|
||||||
|
progress_callback=progress_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Extract video URL and download
|
||||||
|
works = result.get("task_result", {}).get("videos", [])
|
||||||
|
if not works:
|
||||||
|
raise Exception("No video in Kling result")
|
||||||
|
|
||||||
|
video_url = works[0].get("url")
|
||||||
|
video_duration = works[0].get("duration", request.duration)
|
||||||
|
if not video_url:
|
||||||
|
raise Exception("No video URL in Kling result")
|
||||||
|
|
||||||
|
logger.info(f"Video gen {generation.id}: downloading video from {video_url}")
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||||
|
video_response = await client.get(video_url)
|
||||||
|
video_response.raise_for_status()
|
||||||
|
video_bytes = video_response.content
|
||||||
|
|
||||||
|
logger.info(f"Video gen {generation.id}: downloaded {len(video_bytes)} bytes")
|
||||||
|
|
||||||
|
# 5. Upload to S3
|
||||||
|
filename = f"generated_video/{generation.linked_character_id or 'no_char'}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.mp4"
|
||||||
|
await self.s3_adapter.upload_file(filename, video_bytes, content_type="video/mp4")
|
||||||
|
|
||||||
|
# 6. Create Asset
|
||||||
|
new_asset = Asset(
|
||||||
|
name=f"Video_{generation.linked_character_id or 'gen'}",
|
||||||
|
type=AssetType.GENERATED,
|
||||||
|
content_type=AssetContentType.VIDEO,
|
||||||
|
linked_char_id=generation.linked_character_id,
|
||||||
|
data=None,
|
||||||
|
minio_object_name=filename,
|
||||||
|
minio_bucket=self.s3_adapter.bucket_name,
|
||||||
|
thumbnail=None, # видео thumbnails можно добавить позже
|
||||||
|
created_by=generation.created_by,
|
||||||
|
project_id=generation.project_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
asset_id = await self.dao.assets.create_asset(new_asset)
|
||||||
|
new_asset.id = str(asset_id)
|
||||||
|
|
||||||
|
# 7. Finalize generation
|
||||||
|
end_time = datetime.now()
|
||||||
|
generation.result_list = [new_asset.id]
|
||||||
|
generation.result = new_asset.id
|
||||||
|
generation.status = GenerationStatus.DONE
|
||||||
|
generation.progress = 100
|
||||||
|
generation.video_duration = video_duration
|
||||||
|
generation.execution_time_seconds = (end_time - start_time).total_seconds()
|
||||||
|
generation.updated_at = datetime.now(UTC)
|
||||||
|
await self.dao.generations.update_generation(generation)
|
||||||
|
|
||||||
|
logger.info(f"Video generation {generation.id} completed. Asset: {new_asset.id}, Time: {generation.execution_time_seconds:.1f}s")
|
||||||
|
|
||||||
|
except KlingApiException as e:
|
||||||
|
logger.error(f"Kling API error for generation {generation.id}: {e}")
|
||||||
|
generation.status = GenerationStatus.FAILED
|
||||||
|
generation.failed_reason = str(e)
|
||||||
|
generation.updated_at = datetime.now(UTC)
|
||||||
|
await self.dao.generations.update_generation(generation)
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Video generation {generation.id} failed: {e}")
|
||||||
|
generation.status = GenerationStatus.FAILED
|
||||||
|
generation.failed_reason = str(e)
|
||||||
|
generation.updated_at = datetime.now(UTC)
|
||||||
|
await self.dao.generations.update_generation(generation)
|
||||||
|
raise
|
||||||
|
|
||||||
async def delete_generation(self, generation_id: str) -> bool:
|
async def delete_generation(self, generation_id: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Soft delete generation by marking it as deleted.
|
Soft delete generation by marking it as deleted.
|
||||||
@@ -473,37 +608,4 @@ class GenerationService:
|
|||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error deleting generation {generation_id}: {e}")
|
logger.error(f"Error deleting generation {generation_id}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def cleanup_stale_generations(self):
|
|
||||||
"""
|
|
||||||
Cancels generations that have been running for more than 1 hour.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
count = await self.dao.generations.cancel_stale_generations(timeout_minutes=60)
|
|
||||||
if count > 0:
|
|
||||||
logger.info(f"Cleaned up {count} stale generations (timeout)")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error cleaning up stale generations: {e}")
|
|
||||||
|
|
||||||
async def cleanup_old_data(self, days: int = 2):
|
|
||||||
"""
|
|
||||||
Очистка старых данных:
|
|
||||||
1. Мягко удаляет генерации старше N дней
|
|
||||||
2. Мягко удаляет связанные ассеты + жёстко удаляет файлы из S3
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 1. Мягко удаляем генерации и собираем asset IDs
|
|
||||||
gen_count, asset_ids = await self.dao.generations.soft_delete_old_generations(days=days)
|
|
||||||
|
|
||||||
if gen_count > 0:
|
|
||||||
logger.info(f"Soft-deleted {gen_count} generations older than {days} days. "
|
|
||||||
f"Found {len(asset_ids)} associated asset IDs.")
|
|
||||||
|
|
||||||
# 2. Мягко удаляем ассеты + жёстко удаляем файлы из S3
|
|
||||||
if asset_ids:
|
|
||||||
purged = await self.dao.assets.soft_delete_and_purge_assets(asset_ids)
|
|
||||||
logger.info(f"Purged {purged} assets (soft-deleted + S3 files removed).")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error during old data cleanup: {e}")
|
|
||||||
@@ -1,75 +0,0 @@
|
|||||||
from typing import List, Optional
|
|
||||||
from datetime import datetime
|
|
||||||
from repos.dao import DAO
|
|
||||||
from models.Idea import Idea
|
|
||||||
|
|
||||||
class IdeaService:
|
|
||||||
def __init__(self, dao: DAO):
|
|
||||||
self.dao = dao
|
|
||||||
|
|
||||||
async def create_idea(self, name: str, description: Optional[str], project_id: Optional[str], user_id: str) -> Idea:
|
|
||||||
idea = Idea(name=name, description=description, project_id=project_id, created_by=user_id)
|
|
||||||
idea_id = await self.dao.ideas.create_idea(idea)
|
|
||||||
idea.id = idea_id
|
|
||||||
return idea
|
|
||||||
|
|
||||||
async def get_ideas(self, project_id: Optional[str], user_id: str, limit: int = 20, offset: int = 0) -> List[dict]:
|
|
||||||
return await self.dao.ideas.get_ideas(project_id, user_id, limit, offset)
|
|
||||||
|
|
||||||
async def get_idea(self, idea_id: str) -> Optional[Idea]:
|
|
||||||
return await self.dao.ideas.get_idea(idea_id)
|
|
||||||
|
|
||||||
async def update_idea(self, idea_id: str, name: Optional[str] = None, description: Optional[str] = None) -> Optional[Idea]:
|
|
||||||
idea = await self.dao.ideas.get_idea(idea_id)
|
|
||||||
if not idea:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if name is not None:
|
|
||||||
idea.name = name
|
|
||||||
if description is not None:
|
|
||||||
idea.description = description
|
|
||||||
|
|
||||||
idea.updated_at = datetime.now()
|
|
||||||
await self.dao.ideas.update_idea(idea)
|
|
||||||
return idea
|
|
||||||
|
|
||||||
async def delete_idea(self, idea_id: str) -> bool:
|
|
||||||
return await self.dao.ideas.delete_idea(idea_id)
|
|
||||||
|
|
||||||
async def add_generation_to_idea(self, idea_id: str, generation_id: str) -> bool:
|
|
||||||
# Verify idea exists
|
|
||||||
idea = await self.dao.ideas.get_idea(idea_id)
|
|
||||||
if not idea:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Get generation
|
|
||||||
gen = await self.dao.generations.get_generation(generation_id)
|
|
||||||
if not gen:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Link
|
|
||||||
gen.idea_id = idea_id
|
|
||||||
gen.updated_at = datetime.now()
|
|
||||||
await self.dao.generations.update_generation(gen)
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def remove_generation_from_idea(self, idea_id: str, generation_id: str) -> bool:
|
|
||||||
# Verify idea exists (optional, but good for validation)
|
|
||||||
idea = await self.dao.ideas.get_idea(idea_id)
|
|
||||||
if not idea:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Get generation
|
|
||||||
gen = await self.dao.generations.get_generation(generation_id)
|
|
||||||
if not gen:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Unlink only if currently linked to this idea
|
|
||||||
if gen.idea_id == idea_id:
|
|
||||||
gen.idea_id = None
|
|
||||||
gen.updated_at = datetime.now()
|
|
||||||
await self.dao.generations.update_generation(gen)
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
@@ -1,79 +0,0 @@
|
|||||||
from typing import List, Optional
|
|
||||||
from datetime import datetime, UTC
|
|
||||||
|
|
||||||
from repos.dao import DAO
|
|
||||||
from models.Post import Post
|
|
||||||
|
|
||||||
|
|
||||||
class PostService:
|
|
||||||
def __init__(self, dao: DAO):
|
|
||||||
self.dao = dao
|
|
||||||
|
|
||||||
async def create_post(
|
|
||||||
self,
|
|
||||||
date: datetime,
|
|
||||||
topic: str,
|
|
||||||
generation_ids: List[str],
|
|
||||||
project_id: Optional[str],
|
|
||||||
user_id: str,
|
|
||||||
) -> Post:
|
|
||||||
post = Post(
|
|
||||||
date=date,
|
|
||||||
topic=topic,
|
|
||||||
generation_ids=generation_ids,
|
|
||||||
project_id=project_id,
|
|
||||||
created_by=user_id,
|
|
||||||
)
|
|
||||||
post_id = await self.dao.posts.create_post(post)
|
|
||||||
post.id = post_id
|
|
||||||
return post
|
|
||||||
|
|
||||||
async def get_post(self, post_id: str) -> Optional[Post]:
|
|
||||||
return await self.dao.posts.get_post(post_id)
|
|
||||||
|
|
||||||
async def get_posts(
|
|
||||||
self,
|
|
||||||
project_id: Optional[str],
|
|
||||||
user_id: str,
|
|
||||||
limit: int = 20,
|
|
||||||
offset: int = 0,
|
|
||||||
date_from: Optional[datetime] = None,
|
|
||||||
date_to: Optional[datetime] = None,
|
|
||||||
) -> List[Post]:
|
|
||||||
return await self.dao.posts.get_posts(project_id, user_id, limit, offset, date_from, date_to)
|
|
||||||
|
|
||||||
async def update_post(
|
|
||||||
self,
|
|
||||||
post_id: str,
|
|
||||||
date: Optional[datetime] = None,
|
|
||||||
topic: Optional[str] = None,
|
|
||||||
) -> Optional[Post]:
|
|
||||||
post = await self.dao.posts.get_post(post_id)
|
|
||||||
if not post:
|
|
||||||
return None
|
|
||||||
|
|
||||||
updates: dict = {"updated_at": datetime.now(UTC)}
|
|
||||||
if date is not None:
|
|
||||||
updates["date"] = date
|
|
||||||
if topic is not None:
|
|
||||||
updates["topic"] = topic
|
|
||||||
|
|
||||||
await self.dao.posts.update_post(post_id, updates)
|
|
||||||
|
|
||||||
# Return refreshed post
|
|
||||||
return await self.dao.posts.get_post(post_id)
|
|
||||||
|
|
||||||
async def delete_post(self, post_id: str) -> bool:
|
|
||||||
return await self.dao.posts.delete_post(post_id)
|
|
||||||
|
|
||||||
async def add_generations(self, post_id: str, generation_ids: List[str]) -> bool:
|
|
||||||
post = await self.dao.posts.get_post(post_id)
|
|
||||||
if not post:
|
|
||||||
return False
|
|
||||||
return await self.dao.posts.add_generations(post_id, generation_ids)
|
|
||||||
|
|
||||||
async def remove_generation(self, post_id: str, generation_id: str) -> bool:
|
|
||||||
post = await self.dao.posts.get_post(post_id)
|
|
||||||
if not post:
|
|
||||||
return False
|
|
||||||
return await self.dao.posts.remove_generation(post_id, generation_id)
|
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -7,6 +7,7 @@ from pydantic import BaseModel, computed_field, Field, model_validator
|
|||||||
|
|
||||||
class AssetContentType(str, Enum):
|
class AssetContentType(str, Enum):
|
||||||
IMAGE = 'image'
|
IMAGE = 'image'
|
||||||
|
VIDEO = 'video'
|
||||||
PROMPT = 'prompt'
|
PROMPT = 'prompt'
|
||||||
|
|
||||||
class AssetType(str, Enum):
|
class AssetType(str, Enum):
|
||||||
@@ -30,7 +31,6 @@ class Asset(BaseModel):
|
|||||||
tags: List[str] = []
|
tags: List[str] = []
|
||||||
created_by: Optional[str] = None
|
created_by: Optional[str] = None
|
||||||
project_id: Optional[str] = None
|
project_id: Optional[str] = None
|
||||||
is_deleted: bool = False
|
|
||||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||||
|
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ class GenerationStatus(str, Enum):
|
|||||||
class Generation(BaseModel):
|
class Generation(BaseModel):
|
||||||
id: Optional[str] = None
|
id: Optional[str] = None
|
||||||
status: GenerationStatus = GenerationStatus.RUNNING
|
status: GenerationStatus = GenerationStatus.RUNNING
|
||||||
|
gen_type: GenType = GenType.IMAGE
|
||||||
failed_reason: Optional[str] = None
|
failed_reason: Optional[str] = None
|
||||||
linked_character_id: Optional[str] = None
|
linked_character_id: Optional[str] = None
|
||||||
telegram_id: Optional[int] = None
|
telegram_id: Optional[int] = None
|
||||||
@@ -35,10 +36,12 @@ class Generation(BaseModel):
|
|||||||
output_token_usage: Optional[int] = None
|
output_token_usage: Optional[int] = None
|
||||||
is_deleted: bool = False
|
is_deleted: bool = False
|
||||||
album_id: Optional[str] = None
|
album_id: Optional[str] = None
|
||||||
generation_group_id: Optional[str] = None
|
|
||||||
created_by: Optional[str] = None # Stores User ID (Telegram ID or Web User ObjectId)
|
created_by: Optional[str] = None # Stores User ID (Telegram ID or Web User ObjectId)
|
||||||
project_id: Optional[str] = None
|
project_id: Optional[str] = None
|
||||||
idea_id: Optional[str] = None
|
# Video-specific fields
|
||||||
|
kling_task_id: Optional[str] = None
|
||||||
|
video_duration: Optional[int] = None # 5 or 10 seconds
|
||||||
|
video_mode: Optional[str] = None # "std" or "pro"
|
||||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||||
|
|
||||||
|
|||||||
@@ -1,13 +0,0 @@
|
|||||||
from datetime import datetime
|
|
||||||
from typing import Optional, List
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
class Idea(BaseModel):
|
|
||||||
id: Optional[str] = None
|
|
||||||
name: str = "New Idea"
|
|
||||||
description: Optional[str] = None
|
|
||||||
project_id: Optional[str] = None
|
|
||||||
created_by: str # User ID
|
|
||||||
is_deleted: bool = False
|
|
||||||
created_at: datetime = Field(default_factory=datetime.now)
|
|
||||||
updated_at: datetime = Field(default_factory=datetime.now)
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
from datetime import datetime, timezone, UTC
|
|
||||||
from typing import Optional, List
|
|
||||||
from pydantic import BaseModel, Field, model_validator
|
|
||||||
|
|
||||||
|
|
||||||
class Post(BaseModel):
|
|
||||||
id: Optional[str] = None
|
|
||||||
date: datetime
|
|
||||||
topic: str
|
|
||||||
generation_ids: List[str] = Field(default_factory=list)
|
|
||||||
project_id: Optional[str] = None
|
|
||||||
created_by: str
|
|
||||||
is_deleted: bool = False
|
|
||||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
|
||||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
|
||||||
def ensure_tz_aware(self):
|
|
||||||
for field in ("date", "created_at", "updated_at"):
|
|
||||||
val = getattr(self, field)
|
|
||||||
if val is not None and val.tzinfo is None:
|
|
||||||
setattr(self, field, val.replace(tzinfo=timezone.utc))
|
|
||||||
return self
|
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -34,10 +34,12 @@ class Quality(str, Enum):
|
|||||||
class GenType(str, Enum):
|
class GenType(str, Enum):
|
||||||
TEXT = 'Text'
|
TEXT = 'Text'
|
||||||
IMAGE = 'Image'
|
IMAGE = 'Image'
|
||||||
|
VIDEO = 'Video'
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def value_type(self) -> str:
|
def value_type(self) -> str:
|
||||||
return {
|
return {
|
||||||
GenType.TEXT: 'Text',
|
GenType.TEXT: 'Text',
|
||||||
GenType.IMAGE: 'Image',
|
GenType.IMAGE: 'Image',
|
||||||
|
GenType.VIDEO: 'Video',
|
||||||
}[self]
|
}[self]
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,8 +1,6 @@
|
|||||||
from typing import Any, List, Optional
|
from typing import List, Optional
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, UTC
|
|
||||||
from bson import ObjectId
|
from bson import ObjectId
|
||||||
from uuid import uuid4
|
|
||||||
from motor.motor_asyncio import AsyncIOMotorClient
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
|
||||||
from models.Asset import Asset
|
from models.Asset import Asset
|
||||||
@@ -21,8 +19,7 @@ class AssetsRepo:
|
|||||||
# Main data
|
# Main data
|
||||||
if asset.data:
|
if asset.data:
|
||||||
ts = int(asset.created_at.timestamp())
|
ts = int(asset.created_at.timestamp())
|
||||||
uid = uuid4().hex[:8]
|
object_name = f"{asset.type.value}/{ts}_{asset.name}"
|
||||||
object_name = f"{asset.type.value}/{ts}_{uid}_{asset.name}"
|
|
||||||
|
|
||||||
uploaded = await self.s3.upload_file(object_name, asset.data)
|
uploaded = await self.s3.upload_file(object_name, asset.data)
|
||||||
if uploaded:
|
if uploaded:
|
||||||
@@ -35,8 +32,7 @@ class AssetsRepo:
|
|||||||
# Thumbnail
|
# Thumbnail
|
||||||
if asset.thumbnail:
|
if asset.thumbnail:
|
||||||
ts = int(asset.created_at.timestamp())
|
ts = int(asset.created_at.timestamp())
|
||||||
uid = uuid4().hex[:8]
|
thumb_name = f"{asset.type.value}/thumbs/{ts}_{asset.name}_thumb.jpg"
|
||||||
thumb_name = f"{asset.type.value}/thumbs/{ts}_{uid}_{asset.name}_thumb.jpg"
|
|
||||||
|
|
||||||
uploaded_thumb = await self.s3.upload_file(thumb_name, asset.thumbnail)
|
uploaded_thumb = await self.s3.upload_file(thumb_name, asset.thumbnail)
|
||||||
if uploaded_thumb:
|
if uploaded_thumb:
|
||||||
@@ -51,7 +47,7 @@ class AssetsRepo:
|
|||||||
return str(res.inserted_id)
|
return str(res.inserted_id)
|
||||||
|
|
||||||
async def get_assets(self, asset_type: Optional[str] = None, limit: int = 10, offset: int = 0, with_data: bool = False, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Asset]:
|
async def get_assets(self, asset_type: Optional[str] = None, limit: int = 10, offset: int = 0, with_data: bool = False, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Asset]:
|
||||||
filter: dict[str, Any]= {"is_deleted": {"$ne": True}}
|
filter = {}
|
||||||
if asset_type:
|
if asset_type:
|
||||||
filter["type"] = asset_type
|
filter["type"] = asset_type
|
||||||
args = {}
|
args = {}
|
||||||
@@ -138,8 +134,7 @@ class AssetsRepo:
|
|||||||
if self.s3:
|
if self.s3:
|
||||||
if asset.data:
|
if asset.data:
|
||||||
ts = int(asset.created_at.timestamp())
|
ts = int(asset.created_at.timestamp())
|
||||||
uid = uuid4().hex[:8]
|
object_name = f"{asset.type.value}/{ts}_{asset.name}"
|
||||||
object_name = f"{asset.type.value}/{ts}_{uid}_{asset.name}"
|
|
||||||
if await self.s3.upload_file(object_name, asset.data):
|
if await self.s3.upload_file(object_name, asset.data):
|
||||||
asset.minio_object_name = object_name
|
asset.minio_object_name = object_name
|
||||||
asset.minio_bucket = self.s3.bucket_name
|
asset.minio_bucket = self.s3.bucket_name
|
||||||
@@ -147,8 +142,7 @@ class AssetsRepo:
|
|||||||
|
|
||||||
if asset.thumbnail:
|
if asset.thumbnail:
|
||||||
ts = int(asset.created_at.timestamp())
|
ts = int(asset.created_at.timestamp())
|
||||||
uid = uuid4().hex[:8]
|
thumb_name = f"{asset.type.value}/thumbs/{ts}_{asset.name}_thumb.jpg"
|
||||||
thumb_name = f"{asset.type.value}/thumbs/{ts}_{uid}_{asset.name}_thumb.jpg"
|
|
||||||
if await self.s3.upload_file(thumb_name, asset.thumbnail):
|
if await self.s3.upload_file(thumb_name, asset.thumbnail):
|
||||||
asset.minio_thumbnail_object_name = thumb_name
|
asset.minio_thumbnail_object_name = thumb_name
|
||||||
asset.thumbnail = None
|
asset.thumbnail = None
|
||||||
@@ -203,61 +197,6 @@ class AssetsRepo:
|
|||||||
res = await self.collection.delete_one({"_id": ObjectId(asset_id)})
|
res = await self.collection.delete_one({"_id": ObjectId(asset_id)})
|
||||||
return res.deleted_count > 0
|
return res.deleted_count > 0
|
||||||
|
|
||||||
async def soft_delete_and_purge_assets(self, asset_ids: List[str]) -> int:
|
|
||||||
"""
|
|
||||||
Мягко удаляет ассеты и жёстко удаляет их файлы из S3.
|
|
||||||
Возвращает количество обработанных ассетов.
|
|
||||||
"""
|
|
||||||
if not asset_ids:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
object_ids = [ObjectId(aid) for aid in asset_ids if ObjectId.is_valid(aid)]
|
|
||||||
if not object_ids:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
# Находим ассеты, которые ещё не удалены
|
|
||||||
cursor = self.collection.find(
|
|
||||||
{"_id": {"$in": object_ids}, "is_deleted": {"$ne": True}},
|
|
||||||
{"minio_object_name": 1, "minio_thumbnail_object_name": 1}
|
|
||||||
)
|
|
||||||
|
|
||||||
purged_count = 0
|
|
||||||
ids_to_update = []
|
|
||||||
|
|
||||||
async for doc in cursor:
|
|
||||||
ids_to_update.append(doc["_id"])
|
|
||||||
|
|
||||||
# Жёсткое удаление файлов из S3
|
|
||||||
if self.s3:
|
|
||||||
if doc.get("minio_object_name"):
|
|
||||||
try:
|
|
||||||
await self.s3.delete_file(doc["minio_object_name"])
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to delete S3 object {doc['minio_object_name']}: {e}")
|
|
||||||
if doc.get("minio_thumbnail_object_name"):
|
|
||||||
try:
|
|
||||||
await self.s3.delete_file(doc["minio_thumbnail_object_name"])
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to delete S3 thumbnail {doc['minio_thumbnail_object_name']}: {e}")
|
|
||||||
|
|
||||||
purged_count += 1
|
|
||||||
|
|
||||||
# Мягкое удаление + очистка ссылок на S3
|
|
||||||
if ids_to_update:
|
|
||||||
await self.collection.update_many(
|
|
||||||
{"_id": {"$in": ids_to_update}},
|
|
||||||
{
|
|
||||||
"$set": {
|
|
||||||
"is_deleted": True,
|
|
||||||
"minio_object_name": None,
|
|
||||||
"minio_thumbnail_object_name": None,
|
|
||||||
"updated_at": datetime.now(UTC)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return purged_count
|
|
||||||
|
|
||||||
async def migrate_to_minio(self) -> dict:
|
async def migrate_to_minio(self) -> dict:
|
||||||
"""Переносит данные и thumbnails из Mongo в MinIO."""
|
"""Переносит данные и thumbnails из Mongo в MinIO."""
|
||||||
if not self.s3:
|
if not self.s3:
|
||||||
@@ -277,8 +216,7 @@ class AssetsRepo:
|
|||||||
created_at = doc.get("created_at")
|
created_at = doc.get("created_at")
|
||||||
ts = int(created_at.timestamp()) if created_at else 0
|
ts = int(created_at.timestamp()) if created_at else 0
|
||||||
|
|
||||||
uid = uuid4().hex[:8]
|
object_name = f"{type_}/{ts}_{asset_id}_{name}"
|
||||||
object_name = f"{type_}/{ts}_{uid}_{asset_id}_{name}"
|
|
||||||
if await self.s3.upload_file(object_name, data):
|
if await self.s3.upload_file(object_name, data):
|
||||||
await self.collection.update_one(
|
await self.collection.update_one(
|
||||||
{"_id": asset_id},
|
{"_id": asset_id},
|
||||||
@@ -305,8 +243,7 @@ class AssetsRepo:
|
|||||||
created_at = doc.get("created_at")
|
created_at = doc.get("created_at")
|
||||||
ts = int(created_at.timestamp()) if created_at else 0
|
ts = int(created_at.timestamp()) if created_at else 0
|
||||||
|
|
||||||
uid = uuid4().hex[:8]
|
thumb_name = f"{type_}/thumbs/{ts}_{asset_id}_{name}_thumb.jpg"
|
||||||
thumb_name = f"{type_}/thumbs/{ts}_{uid}_{asset_id}_{name}_thumb.jpg"
|
|
||||||
if await self.s3.upload_file(thumb_name, thumb):
|
if await self.s3.upload_file(thumb_name, thumb):
|
||||||
await self.collection.update_one(
|
await self.collection.update_one(
|
||||||
{"_id": asset_id},
|
{"_id": asset_id},
|
||||||
|
|||||||
@@ -6,8 +6,6 @@ from repos.generation_repo import GenerationRepo
|
|||||||
from repos.user_repo import UsersRepo
|
from repos.user_repo import UsersRepo
|
||||||
from repos.albums_repo import AlbumsRepo
|
from repos.albums_repo import AlbumsRepo
|
||||||
from repos.project_repo import ProjectRepo
|
from repos.project_repo import ProjectRepo
|
||||||
from repos.idea_repo import IdeaRepo
|
|
||||||
from repos.post_repo import PostRepo
|
|
||||||
|
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@@ -21,5 +19,3 @@ class DAO:
|
|||||||
self.albums = AlbumsRepo(client, db_name)
|
self.albums = AlbumsRepo(client, db_name)
|
||||||
self.projects = ProjectRepo(client, db_name)
|
self.projects = ProjectRepo(client, db_name)
|
||||||
self.users = UsersRepo(client, db_name)
|
self.users = UsersRepo(client, db_name)
|
||||||
self.ideas = IdeaRepo(client, db_name)
|
|
||||||
self.posts = PostRepo(client, db_name)
|
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
from typing import Any, Optional, List
|
from typing import Optional, List
|
||||||
from datetime import datetime, timedelta, UTC
|
|
||||||
|
|
||||||
from PIL.ImageChops import offset
|
from PIL.ImageChops import offset
|
||||||
from bson import ObjectId
|
from bson import ObjectId
|
||||||
@@ -17,7 +16,7 @@ class GenerationRepo:
|
|||||||
res = await self.collection.insert_one(generation.model_dump())
|
res = await self.collection.insert_one(generation.model_dump())
|
||||||
return str(res.inserted_id)
|
return str(res.inserted_id)
|
||||||
|
|
||||||
async def get_generation(self, generation_id: str) -> Generation | None:
|
async def get_generation(self, generation_id: str) -> Optional[Generation]:
|
||||||
res = await self.collection.find_one({"_id": ObjectId(generation_id)})
|
res = await self.collection.find_one({"_id": ObjectId(generation_id)})
|
||||||
if res is None:
|
if res is None:
|
||||||
return None
|
return None
|
||||||
@@ -26,29 +25,20 @@ class GenerationRepo:
|
|||||||
return Generation(**res)
|
return Generation(**res)
|
||||||
|
|
||||||
async def get_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
|
async def get_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
|
||||||
limit: int = 10, offset: int = 0, created_by: Optional[str] = None, project_id: Optional[str] = None, idea_id: Optional[str] = None) -> List[Generation]:
|
limit: int = 10, offset: int = 10, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
|
||||||
|
|
||||||
filter: dict[str, Any] = {"is_deleted": False}
|
filter = {"is_deleted": False}
|
||||||
if character_id is not None:
|
if character_id is not None:
|
||||||
filter["linked_character_id"] = character_id
|
filter["linked_character_id"] = character_id
|
||||||
if status is not None:
|
if status is not None:
|
||||||
filter["status"] = status
|
filter["status"] = status
|
||||||
if created_by is not None:
|
if created_by is not None:
|
||||||
filter["created_by"] = created_by
|
filter["created_by"] = created_by
|
||||||
# If filtering by created_by user (e.g. "My Generations"), we typically imply personal scope if project_id is None.
|
filter["project_id"] = None
|
||||||
# But if project_id is passed, we filter by that.
|
|
||||||
if project_id is None:
|
|
||||||
filter["project_id"] = None
|
|
||||||
if project_id is not None:
|
if project_id is not None:
|
||||||
filter["project_id"] = project_id
|
filter["project_id"] = project_id
|
||||||
if idea_id is not None:
|
|
||||||
filter["idea_id"] = idea_id
|
|
||||||
|
|
||||||
# If fetching for an idea, sort by created_at ascending (cronological)
|
res = await self.collection.find(filter).sort("created_at", -1).skip(
|
||||||
# Otherwise typically descending (newest first)
|
|
||||||
sort_order = 1 if idea_id else -1
|
|
||||||
|
|
||||||
res = await self.collection.find(filter).sort("created_at", sort_order).skip(
|
|
||||||
offset).limit(limit).to_list(None)
|
offset).limit(limit).to_list(None)
|
||||||
generations: List[Generation] = []
|
generations: List[Generation] = []
|
||||||
for generation in res:
|
for generation in res:
|
||||||
@@ -57,7 +47,7 @@ class GenerationRepo:
|
|||||||
return generations
|
return generations
|
||||||
|
|
||||||
async def count_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
|
async def count_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
|
||||||
album_id: Optional[str] = None, created_by: Optional[str] = None, project_id: Optional[str] = None, idea_id: Optional[str] = None) -> int:
|
album_id: Optional[str] = None, created_by: Optional[str] = None, project_id: Optional[str] = None) -> int:
|
||||||
args = {}
|
args = {}
|
||||||
if character_id is not None:
|
if character_id is not None:
|
||||||
args["linked_character_id"] = character_id
|
args["linked_character_id"] = character_id
|
||||||
@@ -67,10 +57,6 @@ class GenerationRepo:
|
|||||||
args["created_by"] = created_by
|
args["created_by"] = created_by
|
||||||
if project_id is not None:
|
if project_id is not None:
|
||||||
args["project_id"] = project_id
|
args["project_id"] = project_id
|
||||||
if idea_id is not None:
|
|
||||||
args["idea_id"] = idea_id
|
|
||||||
if album_id is not None:
|
|
||||||
args["album_id"] = album_id
|
|
||||||
return await self.collection.count_documents(args)
|
return await self.collection.count_documents(args)
|
||||||
|
|
||||||
async def get_generations_by_ids(self, generation_ids: List[str]) -> List[Generation]:
|
async def get_generations_by_ids(self, generation_ids: List[str]) -> List[Generation]:
|
||||||
@@ -91,62 +77,3 @@ class GenerationRepo:
|
|||||||
|
|
||||||
async def update_generation(self, generation: Generation, ):
|
async def update_generation(self, generation: Generation, ):
|
||||||
res = await self.collection.update_one({"_id": ObjectId(generation.id)}, {"$set": generation.model_dump()})
|
res = await self.collection.update_one({"_id": ObjectId(generation.id)}, {"$set": generation.model_dump()})
|
||||||
|
|
||||||
async def get_generations_by_group(self, group_id: str) -> List[Generation]:
|
|
||||||
res = await self.collection.find({"generation_group_id": group_id, "is_deleted": False}).sort("created_at", 1).to_list(None)
|
|
||||||
generations: List[Generation] = []
|
|
||||||
for generation in res:
|
|
||||||
generation["id"] = str(generation.pop("_id"))
|
|
||||||
generations.append(Generation(**generation))
|
|
||||||
return generations
|
|
||||||
|
|
||||||
async def cancel_stale_generations(self, timeout_minutes: int = 5) -> int:
|
|
||||||
cutoff_time = datetime.now(UTC) - timedelta(minutes=timeout_minutes)
|
|
||||||
res = await self.collection.update_many(
|
|
||||||
{
|
|
||||||
"status": GenerationStatus.RUNNING,
|
|
||||||
"created_at": {"$lt": cutoff_time}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"$set": {
|
|
||||||
"status": GenerationStatus.FAILED,
|
|
||||||
"failed_reason": "Timeout: Execution time limit exceeded",
|
|
||||||
"updated_at": datetime.now(UTC)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return res.modified_count
|
|
||||||
|
|
||||||
async def soft_delete_old_generations(self, days: int = 2) -> tuple[int, List[str]]:
|
|
||||||
"""
|
|
||||||
Мягко удаляет генерации старше N дней.
|
|
||||||
Возвращает (количество удалённых, список asset IDs для очистки).
|
|
||||||
"""
|
|
||||||
cutoff_time = datetime.now(UTC) - timedelta(days=days)
|
|
||||||
filter_query = {
|
|
||||||
"is_deleted": False,
|
|
||||||
"status": {"$in": [GenerationStatus.DONE, GenerationStatus.FAILED]},
|
|
||||||
"created_at": {"$lt": cutoff_time}
|
|
||||||
}
|
|
||||||
|
|
||||||
# Сначала собираем asset IDs из удаляемых генераций
|
|
||||||
asset_ids: List[str] = []
|
|
||||||
cursor = self.collection.find(filter_query, {"result_list": 1, "assets_list": 1})
|
|
||||||
async for doc in cursor:
|
|
||||||
asset_ids.extend(doc.get("result_list", []))
|
|
||||||
asset_ids.extend(doc.get("assets_list", []))
|
|
||||||
|
|
||||||
# Мягкое удаление
|
|
||||||
res = await self.collection.update_many(
|
|
||||||
filter_query,
|
|
||||||
{
|
|
||||||
"$set": {
|
|
||||||
"is_deleted": True,
|
|
||||||
"updated_at": datetime.now(UTC)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Убираем дубликаты
|
|
||||||
unique_asset_ids = list(set(asset_ids))
|
|
||||||
return res.modified_count, unique_asset_ids
|
|
||||||
|
|||||||
@@ -1,82 +0,0 @@
|
|||||||
from typing import Optional, List
|
|
||||||
from bson import ObjectId
|
|
||||||
from motor.motor_asyncio import AsyncIOMotorClient
|
|
||||||
from models.Idea import Idea
|
|
||||||
|
|
||||||
class IdeaRepo:
|
|
||||||
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
|
|
||||||
self.collection = client[db_name]["ideas"]
|
|
||||||
|
|
||||||
async def create_idea(self, idea: Idea) -> str:
|
|
||||||
res = await self.collection.insert_one(idea.model_dump())
|
|
||||||
return str(res.inserted_id)
|
|
||||||
|
|
||||||
async def get_idea(self, idea_id: str) -> Optional[Idea]:
|
|
||||||
if not ObjectId.is_valid(idea_id):
|
|
||||||
return None
|
|
||||||
res = await self.collection.find_one({"_id": ObjectId(idea_id)})
|
|
||||||
if res:
|
|
||||||
res["id"] = str(res.pop("_id"))
|
|
||||||
return Idea(**res)
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def get_ideas(self, project_id: Optional[str], user_id: str, limit: int = 20, offset: int = 0) -> List[dict]:
|
|
||||||
if project_id:
|
|
||||||
match_stage = {"project_id": project_id, "is_deleted": False}
|
|
||||||
else:
|
|
||||||
match_stage = {"created_by": user_id, "project_id": None, "is_deleted": False}
|
|
||||||
|
|
||||||
pipeline = [
|
|
||||||
{"$match": match_stage},
|
|
||||||
{"$sort": {"updated_at": -1}},
|
|
||||||
{"$skip": offset},
|
|
||||||
{"$limit": limit},
|
|
||||||
# Add string id field for lookup
|
|
||||||
{"$addFields": {"str_id": {"$toString": "$_id"}}},
|
|
||||||
# Lookup generations
|
|
||||||
{
|
|
||||||
"$lookup": {
|
|
||||||
"from": "generations",
|
|
||||||
"let": {"idea_id": "$str_id"},
|
|
||||||
"pipeline": [
|
|
||||||
{"$match": {"$expr": {"$eq": ["$idea_id", "$$idea_id"]}}},
|
|
||||||
{"$sort": {"created_at": -1}}, # Ensure we get the latest
|
|
||||||
{"$limit": 1}
|
|
||||||
],
|
|
||||||
"as": "generations"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
# Unwind generations array (preserve ideas without generations)
|
|
||||||
{"$unwind": {"path": "$generations", "preserveNullAndEmptyArrays": True}},
|
|
||||||
# Rename for clarity
|
|
||||||
{"$addFields": {
|
|
||||||
"last_generation": "$generations",
|
|
||||||
"id": "$str_id"
|
|
||||||
}},
|
|
||||||
{"$project": {"generations": 0, "str_id": 0, "_id": 0}}
|
|
||||||
]
|
|
||||||
|
|
||||||
return await self.collection.aggregate(pipeline).to_list(None)
|
|
||||||
|
|
||||||
async def delete_idea(self, idea_id: str) -> bool:
|
|
||||||
if not ObjectId.is_valid(idea_id):
|
|
||||||
return False
|
|
||||||
res = await self.collection.update_one(
|
|
||||||
{"_id": ObjectId(idea_id)},
|
|
||||||
{"$set": {"is_deleted": True}}
|
|
||||||
)
|
|
||||||
return res.modified_count > 0
|
|
||||||
|
|
||||||
async def update_idea(self, idea: Idea) -> bool:
|
|
||||||
if not idea.id or not ObjectId.is_valid(idea.id):
|
|
||||||
return False
|
|
||||||
|
|
||||||
idea_dict = idea.model_dump()
|
|
||||||
if "id" in idea_dict:
|
|
||||||
del idea_dict["id"]
|
|
||||||
|
|
||||||
res = await self.collection.update_one(
|
|
||||||
{"_id": ObjectId(idea.id)},
|
|
||||||
{"$set": idea_dict}
|
|
||||||
)
|
|
||||||
return res.modified_count > 0
|
|
||||||
@@ -1,97 +0,0 @@
|
|||||||
from typing import List, Optional
|
|
||||||
from datetime import datetime
|
|
||||||
import logging
|
|
||||||
from bson import ObjectId
|
|
||||||
from motor.motor_asyncio import AsyncIOMotorClient
|
|
||||||
|
|
||||||
from models.Post import Post
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class PostRepo:
|
|
||||||
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
|
|
||||||
self.collection = client[db_name]["posts"]
|
|
||||||
|
|
||||||
async def create_post(self, post: Post) -> str:
|
|
||||||
res = await self.collection.insert_one(post.model_dump())
|
|
||||||
return str(res.inserted_id)
|
|
||||||
|
|
||||||
async def get_post(self, post_id: str) -> Optional[Post]:
|
|
||||||
if not ObjectId.is_valid(post_id):
|
|
||||||
return None
|
|
||||||
res = await self.collection.find_one({"_id": ObjectId(post_id), "is_deleted": False})
|
|
||||||
if res:
|
|
||||||
res["id"] = str(res.pop("_id"))
|
|
||||||
return Post(**res)
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def get_posts(
|
|
||||||
self,
|
|
||||||
project_id: Optional[str],
|
|
||||||
user_id: str,
|
|
||||||
limit: int = 20,
|
|
||||||
offset: int = 0,
|
|
||||||
date_from: Optional[datetime] = None,
|
|
||||||
date_to: Optional[datetime] = None,
|
|
||||||
) -> List[Post]:
|
|
||||||
if project_id:
|
|
||||||
match = {"project_id": project_id, "is_deleted": False}
|
|
||||||
else:
|
|
||||||
match = {"created_by": user_id, "project_id": None, "is_deleted": False}
|
|
||||||
|
|
||||||
if date_from or date_to:
|
|
||||||
date_filter = {}
|
|
||||||
if date_from:
|
|
||||||
date_filter["$gte"] = date_from
|
|
||||||
if date_to:
|
|
||||||
date_filter["$lte"] = date_to
|
|
||||||
match["date"] = date_filter
|
|
||||||
|
|
||||||
cursor = (
|
|
||||||
self.collection.find(match)
|
|
||||||
.sort("date", -1)
|
|
||||||
.skip(offset)
|
|
||||||
.limit(limit)
|
|
||||||
)
|
|
||||||
posts = []
|
|
||||||
async for doc in cursor:
|
|
||||||
doc["id"] = str(doc.pop("_id"))
|
|
||||||
posts.append(Post(**doc))
|
|
||||||
return posts
|
|
||||||
|
|
||||||
async def update_post(self, post_id: str, data: dict) -> bool:
|
|
||||||
if not ObjectId.is_valid(post_id):
|
|
||||||
return False
|
|
||||||
res = await self.collection.update_one(
|
|
||||||
{"_id": ObjectId(post_id)},
|
|
||||||
{"$set": data},
|
|
||||||
)
|
|
||||||
return res.modified_count > 0
|
|
||||||
|
|
||||||
async def delete_post(self, post_id: str) -> bool:
|
|
||||||
if not ObjectId.is_valid(post_id):
|
|
||||||
return False
|
|
||||||
res = await self.collection.update_one(
|
|
||||||
{"_id": ObjectId(post_id)},
|
|
||||||
{"$set": {"is_deleted": True}},
|
|
||||||
)
|
|
||||||
return res.modified_count > 0
|
|
||||||
|
|
||||||
async def add_generations(self, post_id: str, generation_ids: List[str]) -> bool:
|
|
||||||
if not ObjectId.is_valid(post_id):
|
|
||||||
return False
|
|
||||||
res = await self.collection.update_one(
|
|
||||||
{"_id": ObjectId(post_id)},
|
|
||||||
{"$addToSet": {"generation_ids": {"$each": generation_ids}}},
|
|
||||||
)
|
|
||||||
return res.modified_count > 0
|
|
||||||
|
|
||||||
async def remove_generation(self, post_id: str, generation_id: str) -> bool:
|
|
||||||
if not ObjectId.is_valid(post_id):
|
|
||||||
return False
|
|
||||||
res = await self.collection.update_one(
|
|
||||||
{"_id": ObjectId(post_id)},
|
|
||||||
{"$pull": {"generation_ids": generation_id}},
|
|
||||||
)
|
|
||||||
return res.modified_count > 0
|
|
||||||
@@ -51,3 +51,4 @@ python-jose[cryptography]==3.3.0
|
|||||||
python-multipart==0.0.22
|
python-multipart==0.0.22
|
||||||
email-validator
|
email-validator
|
||||||
prometheus-fastapi-instrumentator
|
prometheus-fastapi-instrumentator
|
||||||
|
PyJWT
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,97 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import os
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
from motor.motor_asyncio import AsyncIOMotorClient
|
|
||||||
from bson import ObjectId
|
|
||||||
|
|
||||||
# Import from project root (requires PYTHONPATH=.)
|
|
||||||
from api.service.idea_service import IdeaService
|
|
||||||
from repos.dao import DAO
|
|
||||||
from models.Idea import Idea
|
|
||||||
from models.Generation import Generation, GenerationStatus
|
|
||||||
from models.enums import AspectRatios, Quality
|
|
||||||
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://localhost:27017")
|
|
||||||
DB_NAME = os.getenv("DB_NAME", "bot_db")
|
|
||||||
|
|
||||||
print(f"Connecting to MongoDB: {MONGO_HOST}, DB: {DB_NAME}")
|
|
||||||
|
|
||||||
async def test_idea_flow():
|
|
||||||
client = AsyncIOMotorClient(MONGO_HOST)
|
|
||||||
dao = DAO(client, db_name=DB_NAME)
|
|
||||||
service = IdeaService(dao)
|
|
||||||
|
|
||||||
# 1. Create an Idea
|
|
||||||
print("Creating idea...")
|
|
||||||
user_id = "test_user_123"
|
|
||||||
project_id = "test_project_abc"
|
|
||||||
idea = await service.create_idea("My Test Idea", "Initial Description", project_id, user_id)
|
|
||||||
print(f"Idea created: {idea.id} - {idea.name}")
|
|
||||||
|
|
||||||
# 2. Update Idea
|
|
||||||
print("Updating idea...")
|
|
||||||
updated_idea = await service.update_idea(idea.id, description="Updated description")
|
|
||||||
print(f"Idea updated: {updated_idea.description}")
|
|
||||||
if updated_idea.description == "Updated description":
|
|
||||||
print("✅ Idea update successful")
|
|
||||||
else:
|
|
||||||
print("❌ Idea update FAILED")
|
|
||||||
|
|
||||||
# 3. Add Generation linked to Idea
|
|
||||||
print("Creating generation linked to idea...")
|
|
||||||
gen = Generation(
|
|
||||||
prompt="idea generation 1",
|
|
||||||
# idea_id=idea.id, <-- Intentionally NOT linking initially to test linking method
|
|
||||||
project_id=project_id,
|
|
||||||
created_by=user_id,
|
|
||||||
aspect_ratio=AspectRatios.NINESIXTEEN,
|
|
||||||
quality=Quality.ONEK,
|
|
||||||
assets_list=[]
|
|
||||||
)
|
|
||||||
gen_id = await dao.generations.create_generation(gen)
|
|
||||||
print(f"Created generation: {gen_id}")
|
|
||||||
|
|
||||||
# Link generation to idea
|
|
||||||
print("Linking generation to idea...")
|
|
||||||
success = await service.add_generation_to_idea(idea.id, gen_id)
|
|
||||||
if success:
|
|
||||||
print("✅ Linking successful")
|
|
||||||
else:
|
|
||||||
print("❌ Linking FAILED")
|
|
||||||
|
|
||||||
# Debug: Check if generation was saved with idea_id
|
|
||||||
saved_gen = await dao.generations.collection.find_one({"_id": ObjectId(gen_id)})
|
|
||||||
print(f"DEBUG: Saved Generation in DB idea_id: {saved_gen.get('idea_id')}")
|
|
||||||
|
|
||||||
# 4. Fetch Generations for Idea (Verify filtering and ordering)
|
|
||||||
print("Fetching generations for idea...")
|
|
||||||
gens = await service.dao.generations.get_generations(idea_id=idea.id) # using repo directly as service might return wrapper
|
|
||||||
print(f"Found {len(gens)} generations in idea")
|
|
||||||
|
|
||||||
if len(gens) == 1 and gens[0].id == gen_id:
|
|
||||||
print("✅ Generation retrieval successful")
|
|
||||||
else:
|
|
||||||
print("❌ Generation retrieval FAILED")
|
|
||||||
|
|
||||||
# 5. Fetch Ideas for Project
|
|
||||||
ideas = await service.get_ideas(project_id)
|
|
||||||
print(f"Found {len(ideas)} ideas for project")
|
|
||||||
|
|
||||||
# Cleaning up
|
|
||||||
print("Cleaning up...")
|
|
||||||
await service.delete_idea(idea.id)
|
|
||||||
await dao.generations.collection.delete_one({"_id": ObjectId(gen_id)})
|
|
||||||
|
|
||||||
# Verify deletion
|
|
||||||
deleted_idea = await service.get_idea(idea.id)
|
|
||||||
# IdeaRepo.delete_idea logic sets is_deleted=True
|
|
||||||
if deleted_idea and deleted_idea.is_deleted:
|
|
||||||
print(f"✅ Idea deleted successfully")
|
|
||||||
|
|
||||||
# Hard delete for cleanup
|
|
||||||
await dao.ideas.collection.delete_one({"_id": ObjectId(idea.id)})
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
asyncio.run(test_idea_flow())
|
|
||||||
@@ -1,52 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import os
|
|
||||||
from datetime import datetime, timedelta, UTC
|
|
||||||
from motor.motor_asyncio import AsyncIOMotorClient
|
|
||||||
from models.Generation import Generation, GenerationStatus
|
|
||||||
from repos.generation_repo import GenerationRepo
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
# Mock configs if not present in env
|
|
||||||
MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://localhost:27017")
|
|
||||||
DB_NAME = os.getenv("DB_NAME", "bot_db")
|
|
||||||
|
|
||||||
print(f"Connecting to MongoDB: {MONGO_HOST}, DB: {DB_NAME}")
|
|
||||||
|
|
||||||
async def test_scheduler():
|
|
||||||
client = AsyncIOMotorClient(MONGO_HOST)
|
|
||||||
repo = GenerationRepo(client, db_name=DB_NAME)
|
|
||||||
|
|
||||||
# 1. Create a "stale" generation (2 hours ago)
|
|
||||||
stale_gen = Generation(
|
|
||||||
prompt="stale test",
|
|
||||||
status=GenerationStatus.RUNNING,
|
|
||||||
created_at=datetime.now(UTC) - timedelta(minutes=120),
|
|
||||||
assets_list=[],
|
|
||||||
aspect_ratio="NINESIXTEEN",
|
|
||||||
quality="ONEK"
|
|
||||||
)
|
|
||||||
gen_id = await repo.create_generation(stale_gen)
|
|
||||||
print(f"Created stale generation: {gen_id}")
|
|
||||||
|
|
||||||
# 2. Run cleanup
|
|
||||||
print("Running cleanup...")
|
|
||||||
count = await repo.cancel_stale_generations(timeout_minutes=60)
|
|
||||||
print(f"Cleaned up {count} generations")
|
|
||||||
|
|
||||||
# 3. Verify status
|
|
||||||
updated_gen = await repo.get_generation(gen_id)
|
|
||||||
print(f"Generation status: {updated_gen.status}")
|
|
||||||
print(f"Failed reason: {updated_gen.failed_reason}")
|
|
||||||
|
|
||||||
if updated_gen.status == GenerationStatus.FAILED:
|
|
||||||
print("✅ SUCCESS: Generation marked as FAILED")
|
|
||||||
else:
|
|
||||||
print("❌ FAILURE: Generation status not updated")
|
|
||||||
|
|
||||||
# Cleanup
|
|
||||||
await repo.collection.delete_one({"_id": updated_gen.id}) # Remove test data
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
asyncio.run(test_scheduler())
|
|
||||||
Binary file not shown.
Binary file not shown.
Reference in New Issue
Block a user