64 Commits

Author SHA1 Message Date
xds 14f9e7b7e9 models + refactor 2026-03-17 16:46:32 +03:00
xds e011805186 models + refactor 2026-02-27 20:37:24 +03:00
xds d9caececd7 nsfw mark api 2026-02-27 14:44:14 +03:00
xds c1300b7a2d nsfw mark api 2026-02-27 14:33:37 +03:00
xds f6001f5994 nsfw mark api 2026-02-27 13:51:22 +03:00
xds e4a39e90c3 nsfw mark api 2026-02-27 09:08:48 +03:00
xds e976fe1c58 inspirations 2026-02-26 11:26:18 +03:00
xds ecc8d69039 inspirations 2026-02-24 16:42:46 +03:00
xds bc9230a49b fixes 2026-02-24 12:11:19 +03:00
xds f07105b0e5 fixes 2026-02-21 18:31:28 +03:00
xds 9a5d54a373 fixes 2026-02-20 17:54:57 +03:00
xds 1868864f76 fixes 2026-02-20 13:10:37 +03:00
xds 9e0c522b5f fixes 2026-02-20 10:28:56 +03:00
xds e1d941a2cd + env 2026-02-20 02:02:13 +03:00
c7c27197c9 Merge pull request '+ env' (#4) from enviroments into main (Reviewed-on: #4) 2026-02-19 18:32:51 +00:00
xds 5aa6391dc8 + env 2026-02-19 21:25:29 +03:00
xds ffb0463fe0 os.getenv -> config.py 2026-02-19 15:28:04 +03:00
xds dd0f8a1cb6 os.getenv -> config.py 2026-02-19 13:00:51 +03:00
xds 4af5134726 fixes 2026-02-18 17:06:17 +03:00
xds 7488665d04 fixes 2026-02-18 17:01:06 +03:00
xds ecc88aca62 fixes 2026-02-18 16:53:28 +03:00
xds 70f50170fc fixes 2026-02-18 16:45:39 +03:00
xds f4207fc4c1 fixes 2026-02-18 16:45:02 +03:00
xds c50d2c8ad9 fixes 2026-02-18 16:44:04 +03:00
xds 4586daac38 fixes 2026-02-18 16:35:04 +03:00
198ac44960 Merge pull request 'feat: introduce post resource with full CRUD operations and generation linking.' (#3) from posts into main (Reviewed-on: #3) 2026-02-17 12:54:47 +00:00
xds d820d9145b feat: introduce post resource with full CRUD operations and generation linking. 2026-02-17 15:54:01 +03:00
xds c93e577bcf feat: Implement asset soft deletion with S3 file purging, enhance type safety, and improve error handling in generation and adapter services. 2026-02-17 12:51:40 +03:00
xds c5d4849bff Update compiled bytecode for assets_repo.py. 2026-02-16 16:41:53 +03:00
xds 9abfbef871 Merge branch 'ideas' 2026-02-16 16:41:13 +03:00
xds 68a3f529cb feat: Enhance idea retrieval to include the latest generation and support user-specific ideas not tied to a project, while also improving asset storage uniqueness and adjusting generation cancellation timeout. 2026-02-16 16:35:26 +03:00
xds e2c050515d feat: Limit concurrent image generations to 4 using an asyncio semaphore. 2026-02-16 00:30:34 +03:00
xds 5e7dc19bf3 ideas 2026-02-15 12:42:15 +03:00
xds 97483b7030 + ideas 2026-02-15 10:26:01 +03:00
xds 2d3da59de9 12 2026-02-13 17:31:14 +03:00
xds 279cb5c6f6 feat: Implement cancellation of stale generations in the service and repository, along with a new test. 2026-02-13 17:30:11 +03:00
xds 30138bab38 feat: Introduce generation grouping, enabling multiple generations per request via a new count parameter and retrieval by group ID. 2026-02-13 11:18:11 +03:00
xds 977cab92f8 fix 2026-02-12 18:41:01 +03:00
xds dcab238d3e fix 2026-02-12 15:50:43 +03:00
xds 9d2e4e47de fix 2026-02-12 15:32:28 +03:00
xds c6142715d9 feat: Add image, update VS Code launch configuration, and enhance gitignore rules for build artifacts. 2026-02-12 14:02:36 +03:00
xds 456562ec1d main -> aiws 2026-02-12 00:13:06 +03:00
xds 0d0fbdf7d6 main -> aiws 2026-02-11 12:56:51 +03:00
xds f63bcedb13 main -> aiws 2026-02-11 12:46:57 +03:00
xds be92c766ac main -> aiws 2026-02-11 12:46:35 +03:00
xds 482bc1d9b7 main -> aiws 2026-02-11 12:30:05 +03:00
xds a2321cf070 + prometheus 2026-02-11 11:56:08 +03:00
xds 29ccd5743e main -> aiws 2026-02-11 11:37:04 +03:00
xds d9de2f48d2 main -> aiws 2026-02-11 11:19:50 +03:00
xds 1ddeb0af46 main -> aiws 2026-02-11 11:15:21 +03:00
xds a7c2319f13 feat: Implement external generation import API secured by HMAC-SHA256 signature verification. 2026-02-10 14:06:37 +03:00
xds 00e83b8561 fix 2026-02-09 17:01:48 +03:00
xds a9d24c725e Update user repository implementation. 2026-02-09 16:16:55 +03:00
xds 458b6ebfc3 feat: Implement project management with new models, repositories, and API endpoints, and enhance character management with project association and DTOs. 2026-02-09 16:06:54 +03:00
xds 668aadcdc9 fix 2026-02-09 09:49:49 +03:00
xds 4461964791 feat: Add created_by and cost fields to generation models, populate created_by from the authenticated user, and implement cost calculation. 2026-02-09 01:52:23 +03:00
xds fa3e1bb05f refactor: Remove trailing slashes from album router endpoint paths. 2026-02-09 00:47:54 +03:00
xds 8a89b27624 feat: Add album management functionality with new data model, repository, service, API, and generation integration. 2026-02-08 23:13:31 +03:00
xds c17c47ccc1 catch exception123 2026-02-08 22:56:08 +03:00
xds c25b029006 123 2026-02-08 22:53:09 +03:00
xds a449f65de9 123 2026-02-08 17:57:09 +03:00
xds 3cf7db5cdf Merge branch 'main' of https://gitea.luminic.space/ai-char/ai-char-bot (conflicts: .DS_Store) 2026-02-08 17:53:19 +03:00
xds 288515fa04 123 2026-02-08 17:41:07 +03:00
f1033210cc Merge pull request 'auth' (#1) from auth into main (Reviewed-on: #1) 2026-02-08 14:37:30 +00:00
112 changed files with 5191 additions and 601 deletions

.DS_Store vendored (binary file, not shown)

.context.md Normal file (33 lines)

@@ -0,0 +1,33 @@
# Project Context: AI Char Bot
## Overview
Python backend project using FastAPI and MongoDB (Motor).
Root: `/Users/xds/develop/py projects/ai-char-bot`
## Architecture
- **API Layer**: `api/endpoints` (FastAPI routers).
- **Service Layer**: `api/service` (Business logic).
- **Data Layer**: `repos` (DAOs and Repositories).
- **Models**: `models` (Domain models) and `api/models` (Request/Response DTOs).
- **Adapters**: `adapters` (External services like S3, Google Gemini).
## Coding Standards & Preferences
- **Type Hinting**: Use `Type | None` instead of `Optional[Type]` (Python 3.10+ style).
- **Async/Await**: Extensive use of `asyncio` and asynchronous DB drivers.
- **Error Handling**:
  - Repositories should return `None` if an entity is not found (e.g., `toggle_like`).
  - Services/Routers handle `HTTPException`.
## Key Features & Implementation Details
- **Generations**:
  - Managed by `GenerationService` and `GenerationRepo`.
  - `toggle_like` returns `bool | None` (True=Liked, False=Unliked, None=Not Found).
  - `get_generations` requires `current_user_id` to correctly calculate `is_liked`.
- **Ideas**:
  - Managed by `IdeaService` and `IdeaRepo`.
  - Can have linked generations.
  - When fetching generations for an idea, ensure `current_user_id` is passed to `GenerationService`.
## Recent Changes
- Refactored `toggle_like` to handle non-existent generations and return `bool | None`.
- Updated `IdeaRouter` to pass `current_user_id` when fetching generations to ensure `is_liked` flag is correct.

.dockerignore Normal file (19 lines)

@@ -0,0 +1,19 @@
.git
.gitignore
__pycache__
*.pyc
*.pyo
*.pyd
.Python
env/
venv/
.venv/
node_modules/
tmp/
logs/
*.log
dist/
build/
.cache/
.idea/
.vscode/

.env (3 lines changed)

@@ -8,3 +8,6 @@ MINIO_ACCESS_KEY=admin
 MINIO_SECRET_KEY=SuperSecretPassword123!
 MINIO_BUCKET=ai-char
 MODE=production
+EXTERNAL_API_SECRET=Gt9TyQ8OAYhcELh2YCbKjdHLflZGufKHJZcG338MQDW
+PROXY_SECRET_SALT=AbVJUkwTPcUWJWhPzmjXb5p4SYyKmYC5m1uVW7Dhi7o
+SCHEDULER_CHARACTER_ID=69931c10721fbd539804589b

.gitignore vendored (25 lines changed)

@@ -1 +1,26 @@
 minio_backup.tar.gz
+.DS_Store
+**/__pycache__/
+*.py[cod]
+*$py.class
+*.cpython-*.pyc
+**/.DS_Store
+.idea/ai-char-bot.iml
+.idea
+.venv
+.vscode
+.vscode/launch.json
+middlewares/__pycache__/
+middlewares/*.pyc
+api/__pycache__/
+api/*.pyc
+repos/__pycache__/
+repos/*.pyc
+adapters/__pycache__/
+adapters/*.pyc
+services/__pycache__/
+services/*.pyc
+utils/__pycache__/
+utils/*.pyc
+.vscode/launch.json
+repos/__pycache__/assets_repo.cpython-313.pyc

.vscode/launch.json vendored (31 lines changed)

@@ -7,38 +7,15 @@
             "request": "launch",
             "module": "uvicorn",
             "args": [
-                "main:app",
+                "aiws:app",
                 "--reload",
                 "--port",
-                "8090"
+                "8090",
+                "--host",
+                "0.0.0.0"
             ],
             "jinja": true,
             "justMyCode": true
-        },
-        {
-            "name": "Python: Current File",
-            "type": "debugpy",
-            "request": "launch",
-            "program": "${file}",
-            "console": "integratedTerminal",
-            "justMyCode": true,
-            "env": {
-                "PYTHONPATH": "${workspaceFolder}"
-            }
-        },
-        {
-            "name": "Debug Tests: Current File",
-            "type": "debugpy",
-            "request": "launch",
-            "module": "pytest",
-            "args": [
-                "${file}"
-            ],
-            "console": "integratedTerminal",
-            "justMyCode": true,
-            "env": {
-                "PYTHONPATH": "${workspaceFolder}"
-            }
         }
     ]
 }


@@ -11,4 +11,4 @@ RUN pip install --no-cache-dir -r requirements.txt
 COPY . .
 # Start the application (replace app.py with your own entry point)
-CMD ["python", "main.py"]
+CMD ["uvicorn", "aiws:app", "--host", "0.0.0.0", "--port", "8090"]

Binary file not shown.


@@ -0,0 +1,100 @@
import logging
import io
import httpx
import hashlib
import time
from typing import List, Tuple, Dict, Any, Optional
from datetime import datetime

from models.enums import AspectRatios, Quality
from config import settings

logger = logging.getLogger(__name__)


class AIProxyAdapter:
    def __init__(self, base_url: str = "http://82.22.174.14:8001", salt: str = None):
        self.base_url = base_url.rstrip("/")
        self.salt = salt or settings.PROXY_SECRET_SALT

    def _generate_headers(self) -> Dict[str, str]:
        timestamp = int(time.time())
        hash_input = f"{timestamp}{self.salt}".encode()
        signature = hashlib.sha256(hash_input).hexdigest()
        return {
            "X-Timestamp": str(timestamp),
            "X-Signature": signature
        }

    async def generate_text(self, prompt: str, model: str = "gemini-3.1-pro-preview", asset_urls: List[str] | None = None) -> str:
        """
        Generates text using the AI Proxy with signature verification.
        """
        url = f"{self.base_url}/generate_text"
        messages = [{"role": "user", "content": prompt}]
        payload = {
            "messages": messages,
            "asset_urls": asset_urls
        }
        headers = self._generate_headers()
        async with httpx.AsyncClient() as client:
            try:
                response = await client.post(url, json=payload, headers=headers, timeout=60.0)
                response.raise_for_status()
                data = response.json()
                if data.get("finish_reason") != "STOP":
                    logger.warning(f"AI Proxy generation finished with reason: {data.get('finish_reason')}")
                return data.get("response") or ""
            except Exception as e:
                logger.error(f"AI Proxy Text Error: {e}")
                raise Exception(f"AI Proxy Text Error: {e}")

    async def generate_image(
        self,
        prompt: str,
        aspect_ratio: AspectRatios,
        quality: Quality,
        model: str = "gemini-3-pro-image-preview",
        asset_urls: List[str] | None = None
    ) -> Tuple[List[io.BytesIO], Dict[str, Any]]:
        """
        Generates image using the AI Proxy with signature verification.
        """
        url = f"{self.base_url}/generate_image"
        payload = {
            "prompt": prompt,
            "asset_urls": asset_urls
        }
        headers = self._generate_headers()
        start_time = datetime.now()
        async with httpx.AsyncClient() as client:
            try:
                response = await client.post(url, json=payload, headers=headers, timeout=120.0)
                response.raise_for_status()
                image_bytes = response.content
                byte_arr = io.BytesIO(image_bytes)
                byte_arr.name = f"{datetime.now().timestamp()}.png"
                byte_arr.seek(0)
                end_time = datetime.now()
                api_duration = (end_time - start_time).total_seconds()
                metrics = {
                    "api_execution_time_seconds": api_duration,
                    "token_usage": 0,
                    "input_token_usage": 0,
                    "output_token_usage": 0
                }
                return [byte_arr], metrics
            except Exception as e:
                logger.error(f"AI Proxy Image Error: {e}")
                raise Exception(f"AI Proxy Image Error: {e}")


@@ -8,7 +8,7 @@ from google import genai
 from google.genai import types
 from adapters.Exception import GoogleGenerationException
-from models.enums import AspectRatios, Quality
+from models.enums import AspectRatios, Quality, TextModel, ImageModel

 logger = logging.getLogger(__name__)
@@ -19,36 +19,37 @@ class GoogleAdapter:
             raise ValueError("API Key for Gemini is missing")
         self.client = genai.Client(api_key=api_key)

-        # Model constants
-        self.TEXT_MODEL = "gemini-3-pro-preview"
-        self.IMAGE_MODEL = "gemini-3-pro-image-preview"
-
-    def _prepare_contents(self, prompt: str, images_list: List[bytes] = None) -> list:
-        """Helper method for preparing content (text + images)."""
-        contents = [prompt]
+    def _prepare_contents(self, prompt: str, images_list: List[bytes] | None = None) -> tuple:
+        """Helper method for preparing content (text + images).
+        Returns (contents, opened_images); caller MUST close opened_images after use."""
+        contents: list[Any] = [prompt]
+        opened_images = []
         if images_list:
             logger.info(f"Preparing content with {len(images_list)} images")
             for img_bytes in images_list:
                 try:
-                    # The Gemini API expects a PIL Image as input
                     image = Image.open(io.BytesIO(img_bytes))
                     contents.append(image)
+                    opened_images.append(image)
                 except Exception as e:
                     logger.error(f"Error processing input image: {e}")
         else:
             logger.info("Preparing content with no images")
-        return contents
+        return contents, opened_images

-    def generate_text(self, prompt: str, images_list: List[bytes] = None) -> str:
+    def generate_text(self, prompt: str, model: str = "gemini-3.1-pro-preview", images_list: List[bytes] | None = None) -> str:
         """
         Text generation (chat or vision).
         Returns the response string.
         """
-        contents = self._prepare_contents(prompt, images_list)
-        logger.info(f"Generating text: {prompt}")
+        if model not in [m.value for m in TextModel]:
+            raise ValueError(f"Invalid model for text generation: {model}. Expected one of: {[m.value for m in TextModel]}")
+        contents, opened_images = self._prepare_contents(prompt, images_list)
+        logger.info(f"Generating text: {prompt} with model: {model}")
         try:
             response = self.client.models.generate_content(
-                model=self.TEXT_MODEL,
+                model=model,
                 contents=contents,
                 config=types.GenerateContentConfig(
                     response_modalities=['TEXT'],
@@ -68,22 +69,27 @@ class GoogleAdapter:
         except Exception as e:
             logger.error(f"Gemini Text API Error: {e}")
             raise GoogleGenerationException(f"Gemini Text API Error: {e}")
+        finally:
+            for img in opened_images:
+                img.close()

-    def generate_image(self, prompt: str, aspect_ratio: AspectRatios, quality: Quality, images_list: List[bytes] = None, ) -> Tuple[List[io.BytesIO], Dict[str, Any]]:
+    def generate_image(self, prompt: str, aspect_ratio: AspectRatios, quality: Quality, model: str = "gemini-3-pro-image-preview", images_list: List[bytes] | None = None, ) -> Tuple[List[io.BytesIO], Dict[str, Any]]:
         """
         Image generation (text-to-image or image-to-image).
         Returns a list of byte streams (ready to send).
         """
-        contents = self._prepare_contents(prompt, images_list)
-        logger.info(f"Generating image. Prompt length: {len(prompt)}, Ratio: {aspect_ratio}, Quality: {quality}")
+        if model not in [m.value for m in ImageModel]:
+            raise ValueError(f"Invalid model for image generation: {model}. Expected one of: {[m.value for m in ImageModel]}")
+        contents, opened_images = self._prepare_contents(prompt, images_list)
+        logger.info(f"Generating image. Prompt length: {len(prompt)}, Ratio: {aspect_ratio}, Quality: {quality}, Model: {model}")
         start_time = datetime.now()
         token_usage = 0
         try:
             response = self.client.models.generate_content(
-                model=self.IMAGE_MODEL,
+                model=model,
                 contents=contents,
                 config=types.GenerateContentConfig(
                     response_modalities=['IMAGE'],
@@ -101,8 +107,20 @@ class GoogleAdapter:
             if response.usage_metadata:
                 token_usage = response.usage_metadata.total_token_count

-            if response.parts is None and response.candidates[0].finish_reason is not None:
-                raise GoogleGenerationException(f"Generation blocked in cause of {response.candidates[0].finish_reason.value}")
+            # Check prompt-level block (e.g. PROHIBITED_CONTENT); no candidates in this case
+            if response.prompt_feedback and response.prompt_feedback.block_reason:
+                raise GoogleGenerationException(
+                    f"Generation blocked at prompt level: {response.prompt_feedback.block_reason.value}"
+                )
+            # Check candidate-level block
+            if response.parts is None:
+                response_reason = (
+                    response.candidates[0].finish_reason
+                    if response.candidates and len(response.candidates) > 0
+                    else "Unknown"
+                )
+                raise GoogleGenerationException(f"Generation blocked: {response_reason}")

             generated_images = []
@@ -113,7 +131,9 @@ class GoogleAdapter:
                 try:
                     # 1. Take the raw bytes
                     raw_data = part.inline_data.data
-                    byte_arr = io.BytesIO(raw_data)
+                    if raw_data is None:
+                        raise GoogleGenerationException("Generation returned no data")
+                    byte_arr: io.BytesIO = io.BytesIO(raw_data)

                     # 2. Naming (nominal, for TG)
                     timestamp = datetime.now().timestamp()
@@ -148,3 +168,7 @@ class GoogleAdapter:
         except Exception as e:
             logger.error(f"Gemini Image API Error: {e}")
             raise GoogleGenerationException(f"Gemini Image API Error: {e}")
+        finally:
+            for img in opened_images:
+                img.close()
+            del contents

adapters/meta_adapter.py Normal file (88 lines)

@@ -0,0 +1,88 @@
import logging
from typing import Optional

import httpx

logger = logging.getLogger(__name__)

META_GRAPH_VERSION = "v18.0"
META_GRAPH_BASE = f"https://graph.facebook.com/{META_GRAPH_VERSION}"


class MetaAdapter:
    """Adapter for Meta Platform API (Instagram Graph API).

    Requires:
    - access_token: long-lived Page or Instagram access token
    - instagram_account_id: Instagram Business Account ID
    """

    def __init__(self, access_token: str, instagram_account_id: str):
        self.access_token = access_token
        self.instagram_account_id = instagram_account_id

    async def post_to_feed(self, image_url: str, caption: str) -> Optional[str]:
        """Upload image and publish to Instagram feed.

        Returns the post ID on success, raises on failure.
        """
        async with httpx.AsyncClient(timeout=30.0) as client:
            # Step 1: create media container
            resp = await client.post(
                f"{META_GRAPH_BASE}/{self.instagram_account_id}/media",
                data={
                    "image_url": image_url,
                    "caption": caption,
                    "access_token": self.access_token,
                },
            )
            resp.raise_for_status()
            creation_id = resp.json().get("id")
            if not creation_id:
                raise ValueError(f"No creation_id from Meta API: {resp.text}")
            # Step 2: publish
            resp2 = await client.post(
                f"{META_GRAPH_BASE}/{self.instagram_account_id}/media_publish",
                data={
                    "creation_id": creation_id,
                    "access_token": self.access_token,
                },
            )
            resp2.raise_for_status()
            post_id = resp2.json().get("id")
            logger.info(f"Published to Instagram feed: {post_id}")
            return post_id

    async def post_to_story(self, image_url: str) -> Optional[str]:
        """Upload image and publish to Instagram story.

        Returns the story ID on success, raises on failure.
        """
        async with httpx.AsyncClient(timeout=30.0) as client:
            # Step 1: create story container
            resp = await client.post(
                f"{META_GRAPH_BASE}/{self.instagram_account_id}/media",
                data={
                    "image_url": image_url,
                    "media_type": "STORIES",
                    "access_token": self.access_token,
                },
            )
            resp.raise_for_status()
            creation_id = resp.json().get("id")
            if not creation_id:
                raise ValueError(f"No creation_id from Meta API: {resp.text}")
            # Step 2: publish
            resp2 = await client.post(
                f"{META_GRAPH_BASE}/{self.instagram_account_id}/media_publish",
                data={
                    "creation_id": creation_id,
                    "access_token": self.access_token,
                },
            )
            resp2.raise_for_status()
            story_id = resp2.json().get("id")
            logger.info(f"Published to Instagram story: {story_id}")
            return story_id


@@ -1,5 +1,5 @@
 from contextlib import asynccontextmanager
-from typing import Optional, BinaryIO
+from typing import Optional, BinaryIO, AsyncGenerator
 import aioboto3
 from botocore.exceptions import ClientError
 import os
@@ -18,7 +18,7 @@ class S3Adapter:
     @asynccontextmanager
     async def _get_client(self):
-        async with self.session.client(
+        async with self.session.client(  # type: ignore[reportGeneralTypeIssues]
             "s3",
             endpoint_url=self.endpoint_url,
             aws_access_key_id=self.aws_access_key_id,
@@ -56,6 +56,37 @@ class S3Adapter:
             print(f"Error downloading from S3: {e}")
             return None

+    async def get_file_size(self, object_name: str) -> Optional[int]:
+        """Returns the size of the file in bytes."""
+        try:
+            async with self._get_client() as client:
+                response = await client.head_object(Bucket=self.bucket_name, Key=object_name)
+                return response['ContentLength']
+        except ClientError as e:
+            print(f"Error getting file size from S3: {e}")
+            return None
+
+    async def stream_file(self, object_name: str, range_header: Optional[str] = None, chunk_size: int = 65536) -> AsyncGenerator[bytes, None]:
+        """Streams a file from S3 yielding chunks. Memory-efficient for large files."""
+        try:
+            async with self._get_client() as client:
+                args = {'Bucket': self.bucket_name, 'Key': object_name}
+                if range_header:
+                    args['Range'] = range_header
+                response = await client.get_object(**args)
+                # aioboto3 Body is an aiohttp StreamReader wrapper
+                body = response['Body']
+                while True:
+                    chunk = await body.read(chunk_size)
+                    if not chunk:
+                        break
+                    yield chunk
+        except ClientError as e:
+            print(f"Error streaming from S3: {e}")
+            return
+
     async def delete_file(self, object_name: str):
         """Deletes a file from S3."""
         try:


@@ -1,6 +1,5 @@
import asyncio import asyncio
import logging import logging
import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from aiogram import Bot, Dispatcher, Router, F from aiogram import Bot, Dispatcher, Router, F
@@ -9,18 +8,23 @@ from aiogram.enums import ParseMode
from aiogram.filters import CommandStart, Command from aiogram.filters import CommandStart, Command
from aiogram.types import Message from aiogram.types import Message
from aiogram.fsm.storage.mongo import MongoStorage from aiogram.fsm.storage.mongo import MongoStorage
from dotenv import load_dotenv
from fastapi import FastAPI from fastapi import FastAPI
from motor.motor_asyncio import AsyncIOMotorClient from motor.motor_asyncio import AsyncIOMotorClient
from prometheus_client import Info
from starlette.middleware.cors import CORSMiddleware from starlette.middleware.cors import CORSMiddleware
from prometheus_fastapi_instrumentator import Instrumentator
# --- ИМПОРТЫ ПРОЕКТА --- # --- ИМПОРТЫ ПРОЕКТА ---
from config import settings
from adapters.google_adapter import GoogleAdapter from adapters.google_adapter import GoogleAdapter
from adapters.s3_adapter import S3Adapter from adapters.s3_adapter import S3Adapter
from api.service.generation_service import GenerationService from api.service.generation_service import GenerationService
from api.service.album_service import AlbumService
from middlewares.album import AlbumMiddleware from middlewares.album import AlbumMiddleware
from middlewares.auth import AuthMiddleware from middlewares.auth import AuthMiddleware
from middlewares.dao import DaoMiddleware from middlewares.dao import DaoMiddleware
from scheduler.daily_scheduler import DailyScheduler
from scheduler.telegram_admin_handler import create_daily_scheduler_router
# Репозитории и DAO # Репозитории и DAO
from repos.char_repo import CharacterRepo from repos.char_repo import CharacterRepo
@@ -38,17 +42,23 @@ from api.endpoints.character_router import router as api_char_router # Роут
from api.endpoints.generation_router import router as api_gen_router from api.endpoints.generation_router import router as api_gen_router
from api.endpoints.auth import router as api_auth_router from api.endpoints.auth import router as api_auth_router
from api.endpoints.admin import router as api_admin_router from api.endpoints.admin import router as api_admin_router
from api.endpoints.album_router import router as api_album_router
from api.endpoints.project_router import router as project_api_router
from api.endpoints.idea_router import router as idea_api_router
from api.endpoints.post_router import router as post_api_router
from api.endpoints.environment_router import router as environment_api_router
from api.endpoints.inspiration_router import router as inspiration_api_router
load_dotenv()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# --- КОНФИГУРАЦИЯ --- # --- КОНФИГУРАЦИЯ ---
BOT_TOKEN = os.getenv("BOT_TOKEN") # Настройки теперь берутся из config.py
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") BOT_TOKEN = settings.BOT_TOKEN
GEMINI_API_KEY = settings.GEMINI_API_KEY
MONGO_HOST = os.getenv("MONGO_HOST") # Например: mongodb://localhost:27017 MONGO_HOST = settings.MONGO_HOST
DB_NAME = os.getenv("DB_NAME", "my_bot_db") # Имя базы данных DB_NAME = settings.DB_NAME
ADMIN_ID = int(os.getenv("ADMIN_ID", 0)) ADMIN_ID = settings.ADMIN_ID
def setup_logging(): def setup_logging():
@@ -58,6 +68,8 @@ def setup_logging():
# --- ИНИЦИАЛИЗАЦИЯ ЗАВИСИМОСТЕЙ --- # --- ИНИЦИАЛИЗАЦИЯ ЗАВИСИМОСТЕЙ ---
if BOT_TOKEN is None:
raise ValueError("BOT_TOKEN is not set")
bot = Bot(token=BOT_TOKEN, default=DefaultBotProperties(parse_mode=ParseMode.HTML)) bot = Bot(token=BOT_TOKEN, default=DefaultBotProperties(parse_mode=ParseMode.HTML))
# Клиент БД создаем глобально, чтобы он был доступен и боту (Storage), и API # Клиент БД создаем глобально, чтобы он был доступен и боту (Storage), и API
@@ -70,15 +82,20 @@ char_repo = CharacterRepo(mongo_client)
# S3 Adapter # S3 Adapter
s3_adapter = S3Adapter( s3_adapter = S3Adapter(
endpoint_url=os.getenv("MINIO_ENDPOINT", "http://31.59.58.220:9000"), endpoint_url=settings.MINIO_ENDPOINT,
aws_access_key_id=os.getenv("MINIO_ACCESS_KEY", "minioadmin"), aws_access_key_id=settings.MINIO_ACCESS_KEY,
aws_secret_access_key=os.getenv("MINIO_SECRET_KEY", "minioadmin"), aws_secret_access_key=settings.MINIO_SECRET_KEY,
bucket_name=os.getenv("MINIO_BUCKET", "ai-char") bucket_name=settings.MINIO_BUCKET
) )
dao = DAO(mongo_client, s3_adapter) # Главный DAO для бота dao = DAO(mongo_client, s3_adapter) # Главный DAO для бота
if GEMINI_API_KEY is None:
raise ValueError("GEMINI_API_KEY is not set")
gemini = GoogleAdapter(api_key=GEMINI_API_KEY) gemini = GoogleAdapter(api_key=GEMINI_API_KEY)
generation_service = GenerationService(dao, gemini, bot) if bot is None:
raise ValueError("bot is not set")
generation_service = GenerationService(dao=dao, gemini=gemini, s3_adapter=s3_adapter, bot=bot)
album_service = AlbumService(dao)
# Dispatcher # Dispatcher
dp = Dispatcher(storage=MongoStorage(mongo_client, db_name=DB_NAME)) dp = Dispatcher(storage=MongoStorage(mongo_client, db_name=DB_NAME))
@@ -93,7 +110,7 @@ dp["gemini"] = gemini
# 1. Роутеры без мидлварей (например, auth) # 1. Роутеры без мидлварей (например, auth)
dp.include_router(auth_router) dp.include_router(auth_router)
# 2. Основные роутеры # 2. Основные роутеры (daily_scheduler router добавляется в lifespan)
main_router = Router() main_router = Router()
dp.include_router(main_router) dp.include_router(main_router)
dp.include_router(assets_router) dp.include_router(assets_router)
@@ -114,6 +131,46 @@ assets_router.message.middleware(AuthMiddleware(repo=users_repo, admin_id=ADMIN_
gen_router.message.middleware(AlbumMiddleware(latency=0.8)) gen_router.message.middleware(AlbumMiddleware(latency=0.8))
async def start_scheduler(service: GenerationService):
while True:
try:
logger.info("Running scheduler for stacked generation killing")
await service.cleanup_stale_generations()
await service.cleanup_old_data(days=14)
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Scheduler error: {e}")
await asyncio.sleep(60) # Check every 60 seconds
def _build_daily_scheduler() -> DailyScheduler:
"""Construct DailyScheduler; MetaAdapter is optional (needs env vars)."""
meta_adapter = None
if settings.META_ACCESS_TOKEN and settings.META_INSTAGRAM_ACCOUNT_ID:
from adapters.meta_adapter import MetaAdapter
meta_adapter = MetaAdapter(
access_token=settings.META_ACCESS_TOKEN,
instagram_account_id=settings.META_INSTAGRAM_ACCOUNT_ID,
)
logger.info("MetaAdapter initialized")
else:
logger.warning("META_ACCESS_TOKEN / META_INSTAGRAM_ACCOUNT_ID not set — Instagram publishing disabled")
if not settings.SCHEDULER_CHARACTER_ID:
logger.warning("SCHEDULER_CHARACTER_ID not set — daily scheduler will error at runtime")
return DailyScheduler(
dao=dao,
gemini=gemini,
s3_adapter=s3_adapter,
generation_service=generation_service,
bot=bot,
admin_id=ADMIN_ID,
character_id=settings.SCHEDULER_CHARACTER_ID or "",
meta_adapter=meta_adapter,
)
# --- LIFESPAN (Запуск FastAPI + Bot) --- # --- LIFESPAN (Запуск FastAPI + Bot) ---
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
@@ -132,29 +189,44 @@ async def lifespan(app: FastAPI):
app.state.gemini_client = gemini app.state.gemini_client = gemini
app.state.bot = bot app.state.bot = bot
app.state.s3_adapter = s3_adapter app.state.s3_adapter = s3_adapter
app.state.album_service = album_service
app.state.users_repo = users_repo # Добавляем репозиторий в state app.state.users_repo = users_repo # Добавляем репозиторий в state
print("✅ DB & DAO initialized") print("✅ DB & DAO initialized")
# 2. ЗАПУСК БОТА (в фоне) # 2. Инициализация и регистрация daily_scheduler роутера
# Важно: handle_signals=False, чтобы бот не перехватывал сигналы остановки у uvicorn daily_scheduler = _build_daily_scheduler()
# Мы НЕ передаем сюда dao=..., так как он уже подключен через Middleware выше dp.include_router(create_daily_scheduler_router(daily_scheduler))
print("📅 Daily scheduler router registered")
# 3. ЗАПУСК БОТА (в фоне)
# handle_signals=False — бот не перехватывает сигналы остановки у uvicorn
polling_task = asyncio.create_task( polling_task = asyncio.create_task(
dp.start_polling(bot, handle_signals=False) dp.start_polling(bot, handle_signals=False)
) )
print("🤖 Bot polling started") print("🤖 Bot polling started")
# 4. ЗАПУСК ШЕДУЛЕРОВ
scheduler_task = asyncio.create_task(start_scheduler(generation_service))
daily_scheduler_task = asyncio.create_task(daily_scheduler.run_loop())
print("⏰ Schedulers started")
yield yield
# --- SHUTDOWN --- # --- SHUTDOWN ---
print("🛑 Shutting down...") print("🛑 Shutting down...")
# 3. Остановка бота # Останавливаем все фоновые задачи
polling_task.cancel() for task, name in [
(polling_task, "Bot polling"),
(scheduler_task, "Stale-gen scheduler"),
(daily_scheduler_task, "Daily scheduler"),
]:
task.cancel()
try: try:
await polling_task await task
except asyncio.CancelledError: except asyncio.CancelledError:
print("🤖 Bot polling stopped") print(f"{name} stopped")
# 4. Отключение БД # 4. Отключение БД
# Обычно Motor закрывать не обязательно при выходе, но хорошим тоном считается # Обычно Motor закрывать не обязательно при выходе, но хорошим тоном считается
@@ -173,16 +245,30 @@ app.add_middleware(
allow_headers=["*"],
)
# Mount the API routers
app.include_router(api_auth_router)
app.include_router(api_admin_router)
app.include_router(api_assets_router)
app.include_router(api_char_router)
app.include_router(api_gen_router)
app.include_router(api_album_router)
app.include_router(project_api_router)
app.include_router(idea_api_router)
app.include_router(post_api_router)
app.include_router(environment_api_router)
app.include_router(inspiration_api_router)
# Prometheus Metrics (Instrument after all routers are added)
Instrumentator(
should_group_status_codes=False, # report 200/201/204 separately (optional)
should_ignore_untemplated=False, # do NOT ignore untemplated ("raw") paths
# should_group_untemplated=False, # (optional) do not collapse untemplated paths into "none"
).instrument(
app,
metric_namespace="ai_bot",
).expose(app, endpoint="/metrics", include_in_schema=False)
app_info = Info("fastapi_app_info", "FastAPI application info")
app_info.info({"app_name": "ai-bot"})
# --- BOT HANDLERS (Main Router) ---
@@ -209,7 +295,7 @@ if __name__ == "__main__":
async def main():
# Build the uvicorn config manually
# loop="asyncio" forces the standard asyncio event loop
config = uvicorn.Config(app, host="0.0.0.0", port=8090, loop="asyncio", timeout_keep_alive=120)
server = uvicorn.Server(config)
# Start the server (lifespan will run inside)


@@ -5,6 +5,7 @@ from motor.motor_asyncio import AsyncIOMotorClient
from adapters.google_adapter import GoogleAdapter
from api.service.generation_service import GenerationService
from repos.dao import DAO
from api.service.album_service import AlbumService
# ... your imports ...
@@ -44,3 +45,26 @@ def get_generation_service(
bot: Bot = Depends(get_bot_client),
) -> GenerationService:
return GenerationService(dao, gemini, s3_adapter, bot)
from api.service.idea_service import IdeaService
def get_idea_service(dao: DAO = Depends(get_dao)) -> IdeaService:
return IdeaService(dao)
from fastapi import Header
async def get_project_id(x_project_id: Optional[str] = Header(None, alias="X-Project-ID")) -> Optional[str]:
return x_project_id
async def get_album_service(dao: DAO = Depends(get_dao)) -> AlbumService:
return AlbumService(dao)
from api.service.post_service import PostService
def get_post_service(dao: DAO = Depends(get_dao)) -> PostService:
return PostService(dao)
from api.service.inspiration_service import InspirationService
def get_inspiration_service(dao: DAO = Depends(get_dao), s3_adapter: S3Adapter = Depends(get_s3_adapter)) -> InspirationService:
return InspirationService(dao, s3_adapter)


@@ -1,10 +1,13 @@
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel
from repos.user_repo import UsersRepo, UserStatus
from api.dependency import get_dao
from repos.dao import DAO
from models.Settings import SystemSettings
from utils.security import verify_password, create_access_token, ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, SECRET_KEY
from jose import JWTError, jwt
from starlette.requests import Request
@@ -23,7 +26,7 @@ async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)], repo:
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str | None = payload.get("sub")
if username is None:
raise credentials_exception
except JWTError:
@@ -52,7 +55,7 @@ class UserResponse(BaseModel):
class Config:
from_attributes = True
@router.get("/approvals", response_model=list[UserResponse])
async def list_pending_users(
admin: Annotated[dict, Depends(get_current_admin)],
repo: Annotated[UsersRepo, Depends(get_users_repo)]
@@ -94,3 +97,21 @@ async def deny_user(
await repo.deny_user(username)
return {"message": f"User {username} denied"}
@router.get("/settings", response_model=SystemSettings)
async def get_settings(
admin: Annotated[dict, Depends(get_current_admin)],
dao: Annotated[DAO, Depends(get_dao)]
):
return await dao.settings.get_settings()
@router.post("/settings")
async def update_settings(
settings: SystemSettings,
admin: Annotated[dict, Depends(get_current_admin)],
dao: Annotated[DAO, Depends(get_dao)]
):
success = await dao.settings.update_settings(settings)
if not success:
raise HTTPException(status_code=500, detail="Failed to update settings")
return {"message": "Settings updated successfully"}


@@ -0,0 +1,83 @@
from fastapi import APIRouter, HTTPException, status, Request
from pydantic import BaseModel
from api.models.GenerationRequest import GenerationResponse
from models.Album import Album
from repos.dao import DAO
from api.dependency import get_album_service
from api.service.album_service import AlbumService
router = APIRouter(prefix="/api/albums", tags=["Albums"])
class AlbumCreateRequest(BaseModel):
name: str
description: str | None = None
class AlbumUpdateRequest(BaseModel):
name: str | None = None
description: str | None = None
class AlbumResponse(BaseModel):
id: str
name: str
description: str | None = None
generation_ids: list[str] = []
cover_asset_id: str | None = None # Not implemented yet
@router.post("", response_model=AlbumResponse)
async def create_album(request: Request, album_in: AlbumCreateRequest):
service: AlbumService = request.app.state.album_service
album = await service.create_album(name=album_in.name, description=album_in.description)
return AlbumResponse(**album.model_dump())
@router.get("", response_model=list[AlbumResponse])
async def get_albums(request: Request, limit: int = 10, offset: int = 0):
service: AlbumService = request.app.state.album_service
albums = await service.get_albums(limit=limit, offset=offset)
return [AlbumResponse(**album.model_dump()) for album in albums]
@router.get("/{album_id}", response_model=AlbumResponse)
async def get_album(request: Request, album_id: str):
service: AlbumService = request.app.state.album_service
album = await service.get_album(album_id)
if not album:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album not found")
return AlbumResponse(**album.model_dump())
@router.put("/{album_id}", response_model=AlbumResponse)
async def update_album(request: Request, album_id: str, album_in: AlbumUpdateRequest):
service: AlbumService = request.app.state.album_service
album = await service.update_album(album_id, name=album_in.name, description=album_in.description)
if not album:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album not found")
return AlbumResponse(**album.model_dump())
@router.delete("/{album_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_album(request: Request, album_id: str):
service: AlbumService = request.app.state.album_service
deleted = await service.delete_album(album_id)
if not deleted:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album not found")
@router.post("/{album_id}/generations/{generation_id}")
async def add_generation_to_album(request: Request, album_id: str, generation_id: str):
service: AlbumService = request.app.state.album_service
success = await service.add_generation_to_album(album_id, generation_id)
if not success:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album or Generation not found")
return {"status": "success"}
@router.delete("/{album_id}/generations/{generation_id}")
async def remove_generation_from_album(request: Request, album_id: str, generation_id: str):
service: AlbumService = request.app.state.album_service
success = await service.remove_generation_from_album(album_id, generation_id)
if not success:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album or Generation not found")
return {"status": "success"}
@router.get("/{album_id}/generations", response_model=list[GenerationResponse])
async def get_album_generations(request: Request, album_id: str, limit: int = 10, offset: int = 0):
service: AlbumService = request.app.state.album_service
generations = await service.get_generations_by_album(album_id, limit=limit, offset=offset)
return [GenerationResponse(**gen.model_dump()) for gen in generations]


@@ -1,17 +1,21 @@
from typing import Any
from aiogram.types import BufferedInputFile
from bson import ObjectId
from fastapi import APIRouter, UploadFile, File, Form, Depends
from fastapi.openapi.models import MediaType
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import MongoClient
from starlette import status
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import Response, JSONResponse, StreamingResponse
from adapters.s3_adapter import S3Adapter
from api.models import AssetsResponse, AssetResponse
from models.Asset import Asset, AssetType, AssetContentType
from repos.dao import DAO
from api.dependency import get_dao, get_mongo_client, get_s3_adapter
import asyncio
import logging
@@ -19,6 +23,7 @@ import logging
logger = logging.getLogger(__name__)
from api.endpoints.auth import get_current_user
from api.dependency import get_project_id
router = APIRouter(prefix="/api/assets", tags=["Assets"])
@@ -28,28 +33,214 @@ async def get_asset(
asset_id: str,
request: Request,
thumbnail: bool = False,
dao: DAO = Depends(get_dao),
s3_adapter: S3Adapter = Depends(get_s3_adapter),
) -> Response:
logger.debug(f"get_asset called for ID: {asset_id}, thumbnail={thumbnail}")
# Load metadata only (skip data/thumbnail bytes)
asset = await dao.assets.get_asset(asset_id, with_data=False)
if not asset:
raise HTTPException(status_code=404, detail="Asset not found")
base_headers = {
"Cache-Control": "public, max-age=31536000, immutable",
"Accept-Ranges": "bytes"
}
# Thumbnail: small, safe to load into RAM
if thumbnail:
if asset.minio_thumbnail_object_name and s3_adapter:
thumb_bytes = await s3_adapter.get_file(asset.minio_thumbnail_object_name)
if thumb_bytes:
return Response(content=thumb_bytes, media_type="image/jpeg", headers=base_headers)
# Fallback: thumbnail in DB
if asset.thumbnail:
return Response(content=asset.thumbnail, media_type="image/jpeg", headers=base_headers)
# No thumbnail available — fall through to main content
# Main content: stream from S3 without loading it into RAM
if asset.minio_object_name and s3_adapter:
content_type = "image/png"
if asset.content_type == AssetContentType.VIDEO:
content_type = "video/mp4" # Or detect from extension if stored
elif asset.content_type == AssetContentType.IMAGE:
content_type = "image/png" # Default for images
# Content-type detection by extension is crude, but it covers the common cases for now
if asset.minio_object_name.endswith(".mp4"):
content_type = "video/mp4"
elif asset.minio_object_name.endswith(".mov"):
content_type = "video/quicktime"
elif asset.minio_object_name.endswith(".jpg") or asset.minio_object_name.endswith(".jpeg"):
content_type = "image/jpeg"
# Handle Range requests for video streaming
range_header = request.headers.get("range")
file_size = await s3_adapter.get_file_size(asset.minio_object_name)
if range_header and file_size:
try:
# Parse Range header: bytes=start-end
byte_range = range_header.replace("bytes=", "")
start_str, end_str = byte_range.split("-")
start = int(start_str)
end = int(end_str) if end_str else file_size - 1
# Validate range
if start >= file_size:
# 416 Range Not Satisfiable
return Response(status_code=416, headers={"Content-Range": f"bytes */{file_size}"})
chunk_size = end - start + 1
headers = base_headers.copy()
headers.update({
"Content-Range": f"bytes {start}-{end}/{file_size}",
"Content-Length": str(chunk_size),
})
# Pass the exact range string to S3
s3_range = f"bytes={start}-{end}"
return StreamingResponse(
s3_adapter.stream_file(asset.minio_object_name, range_header=s3_range),
status_code=206,
headers=headers,
media_type=content_type
)
except ValueError:
pass # Fallback to full content if range parsing fails
# Full content response
headers = base_headers.copy()
if file_size:
headers["Content-Length"] = str(file_size)
return StreamingResponse(
s3_adapter.stream_file(asset.minio_object_name),
media_type=content_type,
headers=headers,
)
# Fallback: data stored in DB (legacy)
if asset.data:
return Response(content=asset.data, media_type="image/png", headers=base_headers)
raise HTTPException(status_code=404, detail="Asset data not found")
@router.delete("/orphans", dependencies=[Depends(get_current_user)])
async def delete_orphan_assets_from_minio(
mongo: AsyncIOMotorClient = Depends(get_mongo_client),
minio_client: S3Adapter = Depends(get_s3_adapter),
*,
assets_collection: str = "assets",
generations_collection: str = "generations",
asset_type: str | None = "generated",
project_id: str | None = None,
dry_run: bool = True,
mark_assets_deleted: bool = False,
batch_size: int = 500,
) -> dict[str, Any]:
db = mongo['bot_db'] # get_mongo_client returns the client; select the bot DB here
assets = db[assets_collection]
match_assets: dict[str, Any] = {}
if asset_type is not None:
match_assets["type"] = asset_type
if project_id is not None:
match_assets["project_id"] = project_id
pipeline: list[dict[str, Any]] = [
{"$match": match_assets} if match_assets else {"$match": {}},
{
"$lookup": {
"from": generations_collection,
"let": {"assetIdStr": {"$toString": "$_id"}},
"pipeline": [
# treat as "alive" those where is_deleted != True (i.e. False or field absent)
{"$match": {"is_deleted": {"$ne": True}}},
{
"$match": {
"$expr": {
"$in": [
"$$assetIdStr",
{"$ifNull": ["$result_list", []]},
]
}
}
},
{"$limit": 1},
],
"as": "alive_generations",
}
},
{
"$match": {
"$expr": {"$eq": [{"$size": "$alive_generations"}, 0]}
}
},
{
"$project": {
"_id": 1,
"minio_object_name": 1,
"minio_thumbnail_object_name": 1,
}
},
]
print(pipeline)
cursor = assets.aggregate(pipeline, allowDiskUse=True, batchSize=batch_size)
deleted_objects = 0
deleted_assets = 0
errors: list[dict[str, Any]] = []
orphan_asset_ids: list[ObjectId] = []
async for asset in cursor:
aid = asset["_id"]
obj = asset.get("minio_object_name")
thumb = asset.get("minio_thumbnail_object_name")
orphan_asset_ids.append(aid)
if dry_run:
print(f"[DRY RUN] orphan asset={aid} obj={obj} thumb={thumb}")
continue
try:
if obj:
await minio_client.delete_file(obj)
deleted_objects += 1
if thumb:
await minio_client.delete_file(thumb)
deleted_objects += 1
deleted_assets += 1
except Exception as e:
errors.append({"asset_id": str(aid), "error": str(e)})
if (not dry_run) and mark_assets_deleted and orphan_asset_ids:
res = await assets.update_many(
{"_id": {"$in": orphan_asset_ids}},
{"$set": {"is_deleted": True}},
)
marked = res.modified_count
else:
marked = 0
return {
"dry_run": dry_run,
"filter": {
"asset_type": asset_type,
"project_id": project_id,
},
"orphans_found": len(orphan_asset_ids),
"deleted_assets": deleted_assets,
"deleted_objects": deleted_objects,
"marked_assets_deleted": marked,
"errors": errors,
}
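The `$lookup` pipeline above encodes a single rule: an asset is orphaned when no live generation (one with `is_deleted != True`) lists the asset's stringified `_id` in its `result_list`. A pure-Python sketch of the same rule, useful for reasoning about or unit-testing the criterion (the `find_orphans` helper is hypothetical; plain dicts stand in for Mongo documents):

```python
def find_orphans(assets: list[dict], generations: list[dict]) -> list[dict]:
    # Collect every asset id referenced by a live generation, then keep
    # only the assets that nobody references.
    alive_refs: set[str] = set()
    for gen in generations:
        if gen.get("is_deleted") is not True:  # False or missing counts as alive
            alive_refs.update(gen.get("result_list") or [])
    return [a for a in assets if str(a["_id"]) not in alive_refs]

assets = [{"_id": "a1"}, {"_id": "a2"}, {"_id": "a3"}]
generations = [
    {"result_list": ["a1"]},                      # alive: keeps a1
    {"result_list": ["a2"], "is_deleted": True},  # deleted: does not keep a2
]
print([a["_id"] for a in find_orphans(assets, generations)])  # ['a2', 'a3']
```

The Mongo version pushes this set-membership test into the database via `$expr`/`$in` so orphans never have to be materialized client-side.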
@router.delete("/{asset_id}", status_code=status.HTTP_204_NO_CONTENT, dependencies=[Depends(get_current_user)])
async def delete_asset(
@@ -68,11 +259,19 @@ async def delete_asset(
@router.get("", dependencies=[Depends(get_current_user)])
async def get_assets(request: Request, dao: DAO = Depends(get_dao), type: str | None = None, limit: int = 10, offset: int = 0, current_user: dict = Depends(get_current_user), project_id: str | None = Depends(get_project_id)) -> AssetsResponse:
logger.info(f"get_assets called. Limit: {limit}, Offset: {offset}")
user_id_filter = current_user["id"]
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
user_id_filter = None
assets = await dao.assets.get_assets(type, limit, offset, created_by=user_id_filter, project_id=project_id)
# assets = await dao.assets.get_assets() # This line seemed redundant/conflicting in original code
total_count = await dao.assets.get_asset_count(created_by=user_id_filter, project_id=project_id)
# Manually map to DTO to trigger computed fields validation if necessary,
# but primarily to ensure valid Pydantic models for the response list.
@@ -84,11 +283,13 @@ async def get_assets(request: Request, dao: DAO = Depends(get_dao), type: Option
@router.post("/upload", response_model=AssetResponse, status_code=status.HTTP_201_CREATED)
async def upload_asset(
file: UploadFile = File(...),
linked_char_id: str | None = Form(None),
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user),
project_id: str | None = Depends(get_project_id)
):
logger.info(f"upload_asset called. Filename: {file.filename}, ContentType: {file.content_type}, LinkedCharId: {linked_char_id}")
if not file.content_type:
@@ -97,6 +298,11 @@ async def upload_asset(
if not file.content_type.startswith("image/"):
raise HTTPException(status_code=400, detail=f"Unsupported content type: {file.content_type}")
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
data = await file.read()
if not data:
raise HTTPException(status_code=400, detail="Empty file")
@@ -111,7 +317,9 @@ async def upload_asset(
content_type=AssetContentType.IMAGE,
linked_char_id=linked_char_id,
data=data,
thumbnail=thumbnail_bytes,
created_by=str(current_user["_id"]),
project_id=project_id,
)
asset_id = await dao.assets.create_asset(asset)
@@ -124,8 +332,7 @@ async def upload_asset(
type=asset.type.value if hasattr(asset.type, "value") else asset.type,
content_type=asset.content_type.value if hasattr(asset.content_type, "value") else asset.content_type,
linked_char_id=asset.linked_char_id,
created_at=asset.created_at
)
@@ -172,3 +379,4 @@ async def migrate_to_minio(dao: DAO = Depends(get_dao)):
result = await dao.assets.migrate_to_minio()
logger.info(f"Migration result: {result}")
return result


@@ -59,6 +59,7 @@ class Token(BaseModel):
class UserResponse(BaseModel):
id: str
username: str
full_name: str | None = None
status: str


@@ -1,14 +1,15 @@
from typing import Any, Coroutine
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from starlette.exceptions import HTTPException
from starlette.requests import Request
from api.models import AssetsResponse, AssetResponse
from api.models import GenerationRequest, GenerationResponse
from models.Asset import Asset
from models.Character import Character
from api.models import CharacterCreateRequest, CharacterUpdateRequest
from repos.dao import DAO
from api.dependency import get_dao
@@ -17,25 +18,61 @@ import logging
logger = logging.getLogger(__name__)
from api.endpoints.auth import get_current_user
from api.dependency import get_project_id
router = APIRouter(prefix="/api/characters", tags=["Characters"], dependencies=[Depends(get_current_user)])
@router.get("/", response_model=list[Character])
async def get_characters(
request: Request,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user),
project_id: str | None = Depends(get_project_id),
limit: int = 100,
offset: int = 0
) -> list[Character]:
logger.info(f"get_characters called. Limit: {limit}, Offset: {offset}")
user_id_filter = str(current_user["_id"])
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
user_id_filter = None
characters = await dao.chars.get_all_characters(
created_by=user_id_filter,
project_id=project_id,
limit=limit,
offset=offset
)
return characters
@router.get("/{character_id}/assets", response_model=AssetsResponse)
async def get_character_assets(character_id: str, dao: DAO = Depends(get_dao), limit: int = 10,
offset: int = 0, current_user: dict = Depends(get_current_user)) -> AssetsResponse:
logger.info(f"get_character_assets called. CharacterID: {character_id}, Limit: {limit}, Offset: {offset}")
character = await dao.chars.get_character(character_id)
if character is None:
raise HTTPException(status_code=404, detail="Character not found")
# Access Check
is_creator = character.created_by == str(current_user["_id"])
is_project_member = False
if character.project_id and character.project_id in current_user.get("project_ids", []):
is_project_member = True
if not is_creator and not is_project_member:
raise HTTPException(status_code=403, detail="Access denied")
assets = await dao.assets.get_assets_by_char_id(character_id, limit, offset)
# Should assets be filtered by user ownership here as well? Assets carry their
# own created_by, but for now we assume that owning the character grants
# access to its assets.
total_count = await dao.assets.get_asset_count(character_id)
asset_responses = [AssetResponse.model_validate(a.model_dump()) for a in assets]
@@ -43,14 +80,113 @@ async def get_character_assets(character_id: str, dao: DAO = Depends(get_dao), l
@router.get("/{character_id}", response_model=Character)
async def get_character_by_id(character_id: str, request: Request, dao: DAO = Depends(get_dao), current_user: dict = Depends(get_current_user)) -> Character:
logger.debug(f"get_character_by_id called. ID: {character_id}")
character = await dao.chars.get_character(character_id)
if not character:
raise HTTPException(status_code=404, detail="Character not found")
if character:
is_creator = character.created_by == str(current_user["_id"])
is_project_member = False
if character.project_id and character.project_id in current_user.get("project_ids", []):
is_project_member = True
if not is_creator and not is_project_member:
raise HTTPException(status_code=403, detail="Access denied")
return character
@router.post("/", response_model=Character)
async def create_character(
char_req: CharacterCreateRequest,
project_id: str | None = Depends(get_project_id),
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
) -> Character:
logger.info("create_character called")
char_req.project_id = project_id
char_data = char_req.model_dump()
char_data["created_by"] = str(current_user["_id"])
if "id" not in char_data:
char_data["id"] = None
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
new_char = Character(**char_data)
new_char.avatar_asset_id = new_char.avatar_image.split("/")[-1]
created_char = await dao.chars.add_character(new_char)
return created_char
@router.put("/{character_id}", response_model=Character)
async def update_character(
character_id: str,
char_update: CharacterUpdateRequest,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
) -> Character:
logger.info(f"update_character called. ID: {character_id}")
existing_char = await dao.chars.get_character(character_id)
if not existing_char:
raise HTTPException(status_code=404, detail="Character not found")
is_creator = existing_char.created_by == str(current_user["_id"])
is_project_member = False
if existing_char.project_id and existing_char.project_id in current_user.get("project_ids", []):
is_project_member = True
if not is_creator and not is_project_member:
raise HTTPException(status_code=403, detail="Access denied")
update_data = char_update.model_dump(exclude_unset=True)
if "project_id" in update_data and update_data["project_id"]:
new_project_id = update_data["project_id"]
project = await dao.projects.get_project(new_project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Target project access denied")
updated_char_data = existing_char.model_dump()
updated_char_data.update(update_data)
updated_char = Character(**updated_char_data)
success = await dao.chars.update_char(character_id, updated_char)
if not success:
raise HTTPException(status_code=500, detail="Failed to update character")
return updated_char
@router.delete("/{character_id}", status_code=204)
async def delete_character(
character_id: str,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
):
logger.info(f"delete_character called. ID: {character_id}")
existing_char = await dao.chars.get_character(character_id)
if not existing_char:
raise HTTPException(status_code=404, detail="Character not found")
is_creator = existing_char.created_by == str(current_user["_id"])
is_project_member = False
if existing_char.project_id and existing_char.project_id in current_user.get("project_ids", []):
is_project_member = True
if not is_creator and not is_project_member:
raise HTTPException(status_code=403, detail="Access denied")
success = await dao.chars.delete_character(character_id)
if not success:
raise HTTPException(status_code=500, detail="Failed to delete character")
return
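The creator-or-project-member check above is repeated verbatim in `update_character` and `delete_character`. A minimal sketch of the shared predicate — the helper name is hypothetical, not part of this repo:

```python
def has_character_access(char, current_user: dict) -> bool:
    """True if the user created the character or belongs to its project.

    `char` needs `created_by` and `project_id` attributes; `current_user`
    is the auth dict carrying `_id` and an optional `project_ids` list.
    """
    is_creator = char.created_by == str(current_user["_id"])
    is_project_member = bool(
        char.project_id and char.project_id in current_user.get("project_ids", [])
    )
    return is_creator or is_project_member
```

Each handler could then raise its own `HTTPException(403)` when this returns `False`, keeping the policy in one place.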

View File

@@ -0,0 +1,191 @@
import logging
from fastapi import APIRouter, Depends, HTTPException
from starlette import status
from api.dependency import get_dao
from api.endpoints.auth import get_current_user
from api.models.EnvironmentRequest import EnvironmentCreate, EnvironmentUpdate, AssetToEnvironment, AssetsToEnvironment
from models.Environment import Environment
from repos.dao import DAO
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/environments", tags=["Environments"], dependencies=[Depends(get_current_user)])
async def check_character_access(character_id: str, current_user: dict, dao: DAO):
character = await dao.chars.get_character(character_id)
if not character:
raise HTTPException(status_code=404, detail="Character not found")
is_creator = character.created_by == str(current_user["_id"])
is_project_member = False
if character.project_id and character.project_id in current_user.get("project_ids", []):
is_project_member = True
if not is_creator and not is_project_member:
raise HTTPException(status_code=403, detail="Access denied to character")
return character
@router.post("/", response_model=Environment)
async def create_environment(
env_req: EnvironmentCreate,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
):
logger.info(f"Creating environment '{env_req.name}' for character {env_req.character_id}")
await check_character_access(env_req.character_id, current_user, dao)
# Verify assets exist if provided
if env_req.asset_ids:
for aid in env_req.asset_ids:
asset = await dao.assets.get_asset(aid)
if not asset:
raise HTTPException(status_code=400, detail=f"Asset {aid} not found")
new_env = Environment(**env_req.model_dump())
created_env = await dao.environments.create_env(new_env)
return created_env
@router.get("/character/{character_id}", response_model=list[Environment])
async def get_character_environments(
character_id: str,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
):
logger.info(f"Getting environments for character {character_id}")
await check_character_access(character_id, current_user, dao)
return await dao.environments.get_character_envs(character_id)
@router.get("/{env_id}", response_model=Environment)
async def get_environment(
env_id: str,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
):
env = await dao.environments.get_env(env_id)
if not env:
raise HTTPException(status_code=404, detail="Environment not found")
await check_character_access(env.character_id, current_user, dao)
return env
@router.put("/{env_id}", response_model=Environment)
async def update_environment(
env_id: str,
env_update: EnvironmentUpdate,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
):
env = await dao.environments.get_env(env_id)
if not env:
raise HTTPException(status_code=404, detail="Environment not found")
await check_character_access(env.character_id, current_user, dao)
update_data = env_update.model_dump(exclude_unset=True)
if not update_data:
return env
# Verify assets exist if provided
if "asset_ids" in update_data:
if update_data["asset_ids"] is None:
del update_data["asset_ids"]
elif update_data["asset_ids"]:
# Verify all assets exist using batch check
assets = await dao.assets.get_assets_by_ids(update_data["asset_ids"])
if len(assets) != len(update_data["asset_ids"]):
found_ids = {a.id for a in assets}
missing_ids = [aid for aid in update_data["asset_ids"] if aid not in found_ids]
raise HTTPException(status_code=400, detail=f"Some assets not found: {missing_ids}")
success = await dao.environments.update_env(env_id, update_data)
if not success:
raise HTTPException(status_code=500, detail="Failed to update environment")
return await dao.environments.get_env(env_id)
@router.delete("/{env_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_environment(
env_id: str,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
):
env = await dao.environments.get_env(env_id)
if not env:
raise HTTPException(status_code=404, detail="Environment not found")
await check_character_access(env.character_id, current_user, dao)
success = await dao.environments.delete_env(env_id)
if not success:
raise HTTPException(status_code=500, detail="Failed to delete environment")
return None
@router.post("/{env_id}/assets", status_code=status.HTTP_200_OK)
async def add_asset_to_environment(
env_id: str,
req: AssetToEnvironment,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
):
env = await dao.environments.get_env(env_id)
if not env:
raise HTTPException(status_code=404, detail="Environment not found")
await check_character_access(env.character_id, current_user, dao)
# Verify asset exists
asset = await dao.assets.get_asset(req.asset_id)
if not asset:
raise HTTPException(status_code=404, detail="Asset not found")
success = await dao.environments.add_asset(env_id, req.asset_id)
return {"success": success}
@router.post("/{env_id}/assets/batch", status_code=status.HTTP_200_OK)
async def add_assets_batch_to_environment(
env_id: str,
req: AssetsToEnvironment,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
):
env = await dao.environments.get_env(env_id)
if not env:
raise HTTPException(status_code=404, detail="Environment not found")
await check_character_access(env.character_id, current_user, dao)
# Verify all assets exist
assets = await dao.assets.get_assets_by_ids(req.asset_ids)
if len(assets) != len(req.asset_ids):
found_ids = {a.id for a in assets}
missing_ids = [aid for aid in req.asset_ids if aid not in found_ids]
raise HTTPException(status_code=404, detail=f"Some assets not found: {missing_ids}")
success = await dao.environments.add_assets(env_id, req.asset_ids)
return {"success": success}
@router.delete("/{env_id}/assets/{asset_id}", status_code=status.HTTP_200_OK)
async def remove_asset_from_environment(
env_id: str,
asset_id: str,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
):
env = await dao.environments.get_env(env_id)
if not env:
raise HTTPException(status_code=404, detail="Environment not found")
await check_character_access(env.character_id, current_user, dao)
success = await dao.environments.remove_asset(env_id, asset_id)
return {"success": success}
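The batch existence check used by the asset endpoints reduces to a set difference that preserves the request order of the missing ids. A standalone sketch, raising `ValueError` here in place of `HTTPException`:

```python
def check_assets_exist(requested_ids: list[str], found_assets: list) -> None:
    """Raise ValueError listing missing asset ids, preserving request order."""
    found_ids = {a.id for a in found_assets}
    missing = [aid for aid in requested_ids if aid not in found_ids]
    if missing:
        raise ValueError(f"Some assets not found: {missing}")
```

Building the `found_ids` set once keeps the check O(n) instead of scanning the found list per id.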

View File

@@ -1,84 +1,258 @@
import logging
import json
from fastapi import APIRouter, UploadFile, File, Form, Header, HTTPException
from fastapi.params import Depends
from starlette import status
from starlette.requests import Request
from config import settings
from api.dependency import get_generation_service, get_project_id, get_dao
from api.endpoints.auth import get_current_user
from api.models import (
    GenerationResponse,
    GenerationRequest,
    GenerationsResponse,
    PromptResponse,
    PromptRequest,
    GenerationGroupResponse,
    FinancialReport,
    ExternalGenerationRequest,
    NsfwRequest
)
from api.service.generation_service import GenerationService
from repos.dao import DAO
from utils.external_auth import verify_signature
logger = logging.getLogger(__name__)
router = APIRouter(prefix='/api/generations', tags=["Generation"])
async def check_project_access(project_id: str | None, current_user: dict, dao: DAO):
    """Helper to check if user has access to project."""
    if not project_id:
        return
    project = await dao.projects.get_project(project_id)
    if not project or str(current_user["_id"]) not in project.members:
        raise HTTPException(status_code=403, detail="Project access denied")
@router.post("/prompt-assistant", response_model=PromptResponse)
async def ask_prompt_assistant(
    prompt_request: PromptRequest,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user)
) -> PromptResponse:
    logger.info(f"ask_prompt_assistant: {len(prompt_request.prompt)} chars")
    generated_prompt = await generation_service.ask_prompt_assistant(
        prompt_request.prompt,
        prompt_request.linked_assets,
        prompt_request.model
    )
    return PromptResponse(prompt=generated_prompt)
@router.post("/prompt-from-image", response_model=PromptResponse)
async def prompt_from_image(
    prompt: str | None = Form(None),
    model: str = Form("gemini-3.1-pro-preview"),
    images: list[UploadFile] = File(...),
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user)
) -> PromptResponse:
    images_bytes = [await img.read() for img in images]
    generated_prompt = await generation_service.generate_prompt_from_images(images_bytes, prompt, model)
    return PromptResponse(prompt=generated_prompt)
@router.get("", response_model=GenerationsResponse)
async def get_generations(
    character_id: str | None = None,
    limit: int = 10,
    offset: int = 0,
    only_liked: bool = False,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
    project_id: str | None = Depends(get_project_id),
    dao: DAO = Depends(get_dao)
):
    await check_project_access(project_id, current_user, dao)
    # If project_id is set, we don't filter by user to show all project-wide generations
    created_by_filter = None if project_id else str(current_user["_id"])
    only_liked_by = str(current_user["_id"]) if only_liked else None
    return await generation_service.get_generations(
        character_id=character_id,
        limit=limit,
        offset=offset,
        created_by=created_by_filter,
        project_id=project_id,
        only_liked_by=only_liked_by,
        current_user_id=str(current_user["_id"])
    )
@router.get("/usage", response_model=FinancialReport)
async def get_usage_report(
    breakdown: str | None = None,  # "user" or "project"
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
    project_id: str | None = Depends(get_project_id),
    dao: DAO = Depends(get_dao)
) -> FinancialReport:
    await check_project_access(project_id, current_user, dao)
    user_id_filter = str(current_user["_id"]) if not project_id else None
    breakdown_by = None
    if breakdown == "user":
        breakdown_by = "created_by"
    elif breakdown == "project":
        breakdown_by = "project_id"
    return await generation_service.get_financial_report(
        user_id=user_id_filter,
        project_id=project_id,
        breakdown_by=breakdown_by
    )
@router.post("/_run", response_model=GenerationGroupResponse)
async def post_generation(
    generation: GenerationRequest,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
    project_id: str | None = Depends(get_project_id),
    dao: DAO = Depends(get_dao)
) -> GenerationGroupResponse:
    await check_project_access(project_id, current_user, dao)
    if project_id:
        generation.project_id = project_id
    return await generation_service.create_generation_task(
        generation,
        user_id=str(current_user.get("_id"))
    )
@router.get("/running")
async def get_running_generations(
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
    project_id: str | None = Depends(get_project_id),
    dao: DAO = Depends(get_dao)
):
    await check_project_access(project_id, current_user, dao)
    user_id_filter = None if project_id else str(current_user["_id"])
    return await generation_service.get_running_generations(
        user_id=user_id_filter,
        project_id=project_id
    )
@router.get("/group/{group_id}", response_model=GenerationGroupResponse)
async def get_generation_group(
    group_id: str,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user)
):
    return await generation_service.get_generations_by_group(group_id, current_user_id=str(current_user["_id"]))
@router.get("/{generation_id}", response_model=GenerationResponse)
async def get_generation(
    generation_id: str,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user)
) -> GenerationResponse:
    gen = await generation_service.get_generation(generation_id, current_user_id=str(current_user["_id"]))
    if not gen:
        raise HTTPException(status_code=404, detail="Generation not found")
    if gen.created_by != str(current_user["_id"]):
        # Check project membership
        is_member = False
        if gen.project_id:
            project = await generation_service.dao.projects.get_project(gen.project_id)
            if project and str(current_user["_id"]) in project.members:
                is_member = True
        if not is_member:
            raise HTTPException(status_code=403, detail="Access denied")
    return gen
@router.post("/{generation_id}/like", response_model=dict)
async def toggle_like(
    generation_id: str,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user)
):
    is_liked = await generation_service.toggle_like(generation_id, str(current_user["_id"]))
    if is_liked is None:
        raise HTTPException(status_code=404, detail="Generation not found")
    return {"is_liked": is_liked}
@router.post("/{generation_id}/nsfw", status_code=status.HTTP_204_NO_CONTENT)
async def mark_generation_nsfw(
    generation_id: str,
    request: NsfwRequest,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user)
):
    gen = await generation_service.get_generation(generation_id, current_user_id=str(current_user["_id"]))
    if not gen:
        raise HTTPException(status_code=404, detail="Generation not found")
    if gen.created_by != str(current_user["_id"]):
        is_member = False
        if gen.project_id:
            project = await generation_service.dao.projects.get_project(gen.project_id)
            if project and str(current_user["_id"]) in project.members:
                is_member = True
        if not is_member:
            raise HTTPException(status_code=403, detail="Access denied")
    await generation_service.dao.generations.mark_nsfw(generation_id, request.is_nsfw)
    return None
@router.post("/import", response_model=GenerationResponse)
async def import_external_generation(
    request: Request,
    generation_service: GenerationService = Depends(get_generation_service),
    x_signature: str = Header(..., alias="X-Signature")
) -> GenerationResponse:
    body = await request.body()
    secret = settings.EXTERNAL_API_SECRET
    if not secret:
        raise HTTPException(status_code=500, detail="Server configuration error")
    if not verify_signature(body, x_signature, secret):
        raise HTTPException(status_code=401, detail="Invalid signature")
    try:
        data = json.loads(body.decode('utf-8'))
        external_gen = ExternalGenerationRequest(**data)
        generation = await generation_service.import_external_generation(external_gen)
        return GenerationResponse(**generation.model_dump())
    except Exception as e:
        logger.error(f"Failed to import external generation: {e}")
        raise HTTPException(status_code=500, detail=f"Import failed: {str(e)}")
@router.delete("/{generation_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_generation(
    generation_id: str,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user)
):
    if not await generation_service.delete_generation(generation_id):
        raise HTTPException(status_code=404, detail="Generation not found")
    return None
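The `/import` endpoint authenticates callers with a shared-secret signature over the raw body. The exact scheme `verify_signature` implements is not shown in this diff; assuming the common hex-encoded HMAC-SHA256 pattern (an assumption, not confirmed by this repo), a client would sign its payload like this:

```python
import hashlib
import hmac
import json

def sign_body(body: bytes, secret: str) -> str:
    # Assumption: hex-encoded HMAC-SHA256 of the raw request body.
    return hmac.new(secret.encode(), body, hashlib.sha256).hexdigest()

def verify(body: bytes, signature: str, secret: str) -> bool:
    # Constant-time comparison, as a verify_signature helper typically uses.
    return hmac.compare_digest(sign_body(body, secret), signature)

# Client side: sign the exact bytes that will be sent.
payload = json.dumps({"prompt": "a red fox", "model": "example"}).encode()
headers = {"X-Signature": sign_body(payload, "my-secret")}
```

Signing the raw bytes (rather than a re-serialized dict) matters: the server verifies `await request.body()` before parsing, so any whitespace or key-order difference would invalidate the signature.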

View File

@@ -0,0 +1,106 @@
from fastapi import APIRouter, Depends, HTTPException
from api.dependency import get_idea_service, get_project_id, get_generation_service
from api.endpoints.auth import get_current_user
from api.service.idea_service import IdeaService
from api.service.generation_service import GenerationService
from models.Idea import Idea
from api.models import GenerationsResponse
from api.models.IdeaRequest import IdeaCreateRequest, IdeaUpdateRequest, IdeaResponse
router = APIRouter(prefix="/api/ideas", tags=["ideas"])
@router.post("", response_model=Idea)
async def create_idea(
request: IdeaCreateRequest,
project_id: str | None = Depends(get_project_id),
current_user: dict = Depends(get_current_user),
idea_service: IdeaService = Depends(get_idea_service)
):
pid = project_id or request.project_id
return await idea_service.create_idea(
name=request.name,
description=request.description,
project_id=pid,
user_id=str(current_user["_id"]),
inspiration_id=request.inspiration_id
)
@router.get("", response_model=list[IdeaResponse])
async def get_ideas(
project_id: str | None = Depends(get_project_id),
limit: int = 20,
offset: int = 0,
current_user: dict = Depends(get_current_user),
idea_service: IdeaService = Depends(get_idea_service)
):
return await idea_service.get_ideas(project_id, str(current_user["_id"]), limit, offset)
@router.get("/{idea_id}", response_model=Idea)
async def get_idea(
idea_id: str,
idea_service: IdeaService = Depends(get_idea_service)
):
idea = await idea_service.get_idea(idea_id)
if not idea:
raise HTTPException(status_code=404, detail="Idea not found")
return idea
@router.put("/{idea_id}", response_model=Idea)
async def update_idea(
idea_id: str,
request: IdeaUpdateRequest,
idea_service: IdeaService = Depends(get_idea_service)
):
idea = await idea_service.update_idea(
idea_id=idea_id,
name=request.name,
description=request.description,
inspiration_id=request.inspiration_id
)
if not idea:
raise HTTPException(status_code=404, detail="Idea not found")
return idea
@router.delete("/{idea_id}")
async def delete_idea(
idea_id: str,
idea_service: IdeaService = Depends(get_idea_service)
):
success = await idea_service.delete_idea(idea_id)
if not success:
raise HTTPException(status_code=404, detail="Idea not found or could not be deleted")
return {"status": "success"}
@router.get("/{idea_id}/generations", response_model=GenerationsResponse)
async def get_idea_generations(
idea_id: str,
limit: int = 50,
offset: int = 0,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)
):
return await generation_service.get_generations(idea_id=idea_id, limit=limit, offset=offset, current_user_id=str(current_user["_id"]))
@router.post("/{idea_id}/generations/{generation_id}")
async def add_generation_to_idea(
idea_id: str,
generation_id: str,
idea_service: IdeaService = Depends(get_idea_service)
):
success = await idea_service.add_generation_to_idea(idea_id, generation_id)
if not success:
raise HTTPException(status_code=404, detail="Idea or Generation not found")
return {"status": "success"}
@router.delete("/{idea_id}/generations/{generation_id}")
async def remove_generation_from_idea(
idea_id: str,
generation_id: str,
idea_service: IdeaService = Depends(get_idea_service)
):
success = await idea_service.remove_generation_from_idea(idea_id, generation_id)
if not success:
raise HTTPException(status_code=404, detail="Idea or Generation not found")
return {"status": "success"}

View File

@@ -0,0 +1,94 @@
from fastapi import APIRouter, Depends, HTTPException, status
from api.dependency import get_inspiration_service, get_project_id
from api.endpoints.auth import get_current_user
from api.models.InspirationRequest import InspirationCreateRequest, InspirationResponse, InspirationListResponse
from api.service.inspiration_service import InspirationService
from models.Inspiration import Inspiration
router = APIRouter(prefix="/api/inspirations", tags=["Inspirations"])
@router.post("", response_model=InspirationResponse, status_code=status.HTTP_201_CREATED)
async def create_inspiration(
request: InspirationCreateRequest,
project_id: str | None = Depends(get_project_id),
current_user: dict = Depends(get_current_user),
service: InspirationService = Depends(get_inspiration_service)
):
pid = project_id or request.project_id
inspiration = await service.create_inspiration(
source_url=request.source_url,
created_by=str(current_user["_id"]),
project_id=pid,
caption=request.caption
)
return inspiration
@router.get("", response_model=InspirationListResponse)
async def get_inspirations(
project_id: str | None = Depends(get_project_id),
limit: int = 20,
offset: int = 0,
current_user: dict = Depends(get_current_user),
service: InspirationService = Depends(get_inspiration_service)
):
    # With a project_id header, show the project's inspirations (shared by
    # all members); otherwise show the user's personal inspirations.
pid = project_id
uid = str(current_user["_id"])
inspirations = await service.get_inspirations(project_id=pid, created_by=uid if not pid else None, limit=limit, offset=offset)
total_count = await service.dao.inspirations.count_inspirations(project_id=pid, created_by=uid if not pid else None)
return InspirationListResponse(
inspirations=[InspirationResponse(**inspiration.model_dump()) for inspiration in inspirations],
total_count=total_count
)
@router.get("/{inspiration_id}", response_model=InspirationResponse)
async def get_inspiration(
inspiration_id: str,
service: InspirationService = Depends(get_inspiration_service),
current_user: dict = Depends(get_current_user)
):
inspiration = await service.get_inspiration(inspiration_id)
if not inspiration:
raise HTTPException(status_code=404, detail="Inspiration not found")
return inspiration
@router.patch("/{inspiration_id}/complete", response_model=InspirationResponse)
async def mark_inspiration_complete(
inspiration_id: str,
is_completed: bool = True,
service: InspirationService = Depends(get_inspiration_service),
current_user: dict = Depends(get_current_user)
):
inspiration = await service.mark_as_completed(inspiration_id, is_completed)
if not inspiration:
raise HTTPException(status_code=404, detail="Inspiration not found")
return inspiration
@router.delete("/{inspiration_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_inspiration(
inspiration_id: str,
service: InspirationService = Depends(get_inspiration_service),
current_user: dict = Depends(get_current_user)
):
success = await service.delete_inspiration(inspiration_id)
if not success:
raise HTTPException(status_code=404, detail="Inspiration not found")
return None

View File

@@ -0,0 +1,98 @@
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException
from api.dependency import get_post_service, get_project_id
from api.endpoints.auth import get_current_user
from api.service.post_service import PostService
from api.models import PostRequest, PostCreateRequest, PostUpdateRequest, AddGenerationsRequest
from models.Post import Post
router = APIRouter(prefix="/api/posts", tags=["posts"])
@router.post("", response_model=Post)
async def create_post(
request: PostCreateRequest,
project_id: str | None = Depends(get_project_id),
current_user: dict = Depends(get_current_user),
post_service: PostService = Depends(get_post_service),
):
pid = project_id or request.project_id
return await post_service.create_post(
date=request.date,
topic=request.topic,
generation_ids=request.generation_ids,
project_id=pid,
user_id=str(current_user["_id"]),
)
@router.get("", response_model=list[Post])
async def get_posts(
project_id: str | None = Depends(get_project_id),
limit: int = 200,
offset: int = 0,
date_from: datetime | None = None,
date_to: datetime | None = None,
current_user: dict = Depends(get_current_user),
post_service: PostService = Depends(get_post_service),
):
return await post_service.get_posts(project_id, str(current_user["_id"]), limit, offset, date_from, date_to)
@router.get("/{post_id}", response_model=Post)
async def get_post(
post_id: str,
post_service: PostService = Depends(get_post_service),
):
post = await post_service.get_post(post_id)
if not post:
raise HTTPException(status_code=404, detail="Post not found")
return post
@router.put("/{post_id}", response_model=Post)
async def update_post(
post_id: str,
request: PostUpdateRequest,
post_service: PostService = Depends(get_post_service),
):
post = await post_service.update_post(post_id, date=request.date, topic=request.topic)
if not post:
raise HTTPException(status_code=404, detail="Post not found")
return post
@router.delete("/{post_id}")
async def delete_post(
post_id: str,
post_service: PostService = Depends(get_post_service),
):
success = await post_service.delete_post(post_id)
if not success:
raise HTTPException(status_code=404, detail="Post not found or could not be deleted")
return {"status": "success"}
@router.post("/{post_id}/generations")
async def add_generations(
post_id: str,
request: AddGenerationsRequest,
post_service: PostService = Depends(get_post_service),
):
success = await post_service.add_generations(post_id, request.generation_ids)
if not success:
raise HTTPException(status_code=404, detail="Post not found")
return {"status": "success"}
@router.delete("/{post_id}/generations/{generation_id}")
async def remove_generation(
post_id: str,
generation_id: str,
post_service: PostService = Depends(get_post_service),
):
success = await post_service.remove_generation(post_id, generation_id)
if not success:
raise HTTPException(status_code=404, detail="Post not found or generation not linked")
return {"status": "success"}

View File

@@ -0,0 +1,181 @@
from bson import ObjectId
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from api.dependency import get_dao
from api.endpoints.auth import get_current_user
from models.Project import Project
from repos.dao import DAO
router = APIRouter(prefix="/api/projects", tags=["Projects"])
class ProjectCreate(BaseModel):
name: str
description: str | None = None
class ProjectMemberResponse(BaseModel):
id: str
username: str
class ProjectResponse(BaseModel):
id: str
name: str
description: str | None = None
owner_id: str
members: list[ProjectMemberResponse]
is_owner: bool = False
async def _get_project_response(project: Project, current_user_id: str, dao: DAO) -> ProjectResponse:
member_responses = []
for member_id in project.members:
        # project.members stores user ObjectIds as strings, so look up by _id;
        # fall back to the numeric Telegram ID for legacy users.
user_doc = await dao.users.collection.find_one({"_id": ObjectId(member_id)})
if not user_doc and member_id.isdigit():
# Fallback for telegram IDs if they are stored as strings of digits
user_doc = await dao.users.get_user(int(member_id))
username = "unknown"
if user_doc:
username = user_doc.get("username", "unknown")
member_responses.append(ProjectMemberResponse(id=member_id, username=username))
return ProjectResponse(
id=project.id,
name=project.name,
description=project.description,
owner_id=project.owner_id,
members=member_responses,
is_owner=(project.owner_id == current_user_id)
)
@router.post("", response_model=ProjectResponse)
async def create_project(
project_data: ProjectCreate,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
):
user_id = str(current_user["_id"])
new_project = Project(
name=project_data.name,
description=project_data.description,
owner_id=user_id,
members=[user_id]
)
project_id = await dao.projects.create_project(new_project)
new_project.id = project_id
# Add project to user's project list
await dao.users.collection.update_one(
{"_id": current_user["_id"]},
{"$addToSet": {"project_ids": project_id}}
)
return await _get_project_response(new_project, user_id, dao)
@router.get("", response_model=list[ProjectResponse])
async def get_my_projects(
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
):
user_id = str(current_user["_id"])
projects = await dao.projects.get_projects_by_user(user_id)
responses = []
for p in projects:
responses.append(await _get_project_response(p, user_id, dao))
return responses
class MemberAdd(BaseModel):
username: str
@router.post("/{project_id}/members", dependencies=[Depends(get_current_user)])
async def add_member(
project_id: str,
member_data: MemberAdd,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
):
user_id = str(current_user["_id"])
project = await dao.projects.get_project(project_id)
if not project:
raise HTTPException(status_code=404, detail="Project not found")
if project.owner_id != user_id:
raise HTTPException(status_code=403, detail="Only owner can add members")
target_user = await dao.users.get_user_by_username(member_data.username)
if not target_user:
raise HTTPException(status_code=404, detail="User not found")
target_user_id = str(target_user["_id"])
if target_user_id in project.members:
return {"message": "User already in project"}
await dao.projects.add_member(project_id, target_user_id)
# Update target user's project list
await dao.users.collection.update_one(
{"_id": target_user["_id"]},
{"$addToSet": {"project_ids": project_id}}
)
return {"message": "Member added"}
@router.post("/{project_id}/join", dependencies=[Depends(get_current_user)])
async def join_project(
project_id: str,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
):
# Retrieve project to verify it exists
project = await dao.projects.get_project(project_id)
if not project:
raise HTTPException(status_code=404, detail="Project not found")
user_id = str(current_user["_id"])
# Check if user is ALREADY in project
if user_id in project.members:
return {"message": "Already a member"}
# Add member
await dao.projects.add_member(project_id, user_id)
# Update user's project list
await dao.users.collection.update_one(
{"_id": current_user["_id"]},
{"$addToSet": {"project_ids": project_id}}
)
return {"message": "Joined project"}
@router.delete("/{project_id}", dependencies=[Depends(get_current_user)] )
async def delete_project(
project_id: str,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
):
user_id = str(current_user["_id"])
project = await dao.projects.get_project(project_id)
if not project:
raise HTTPException(status_code=404, detail="Project not found")
if project.owner_id != user_id:
raise HTTPException(status_code=403, detail="Only owner can delete project")
    await dao.projects.delete_project(project_id)
    # Remove the project from every member's project list, not just the owner's.
    await dao.users.collection.update_many(
        {"project_ids": project_id},
        {"$pull": {"project_ids": project_id}}
    )
return {"message": "Project deleted"}


@@ -1,5 +1,4 @@
 from datetime import datetime
-from typing import List, Optional
 from pydantic import BaseModel
@@ -11,10 +10,10 @@ class AssetResponse(BaseModel):
     name: str
     type: str  # uploaded / generated
     content_type: str  # image / prompt
-    linked_char_id: Optional[str] = None
+    linked_char_id: str | None = None
     created_at: datetime
-    url: Optional[str] = None
+    url: str | None = None
 class AssetsResponse(BaseModel):
-    assets: List[AssetResponse]
+    assets: list[AssetResponse]
     total_count: int


@@ -0,0 +1,17 @@
from pydantic import BaseModel

class CharacterCreateRequest(BaseModel):
    name: str
    character_bio: str
    character_image_doc_tg_id: str | None = None
    avatar_image: str | None = None
    character_image_tg_id: str | None = None
    project_id: str | None = None

class CharacterUpdateRequest(BaseModel):
    name: str | None = None
    character_bio: str | None = None
    character_image_doc_tg_id: str | None = None
    avatar_image: str | None = None
    character_image_tg_id: str | None = None
    project_id: str | None = None


@@ -0,0 +1,22 @@
from pydantic import BaseModel, Field
class EnvironmentCreate(BaseModel):
character_id: str
name: str = Field(..., min_length=1)
description: str | None = None
asset_ids: list[str] | None = []
class EnvironmentUpdate(BaseModel):
name: str | None = Field(None, min_length=1)
description: str | None = None
asset_ids: list[str] | None = None
class AssetToEnvironment(BaseModel):
asset_id: str
class AssetsToEnvironment(BaseModel):
asset_ids: list[str]


@@ -0,0 +1,40 @@
from pydantic import BaseModel, Field
from models.enums import AspectRatios, Quality

class ExternalGenerationRequest(BaseModel):
    """Request model for importing external generations."""
    prompt: str
    tech_prompt: str | None = None
    # Image can be provided as a base64 string OR a URL (exactly one must be set)
    image_data: str | None = Field(None, description="Base64-encoded image data")
    image_url: str | None = Field(None, description="URL to download image from")
    nsfw: bool = False
    # Generation metadata
    aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN  # "1:1","2:3","3:2","3:4","4:3","4:5","5:4","9:16","16:9","21:9"
    quality: Quality = Quality.ONEK
    model: str | None = None
    seed: int | None = None
    # Optional linking
    linked_character_id: str | None = None
    created_by: str = Field(..., description="User ID from external system")
    project_id: str | None = None
    # Performance metrics
    execution_time_seconds: float | None = None
    api_execution_time_seconds: float | None = None
    token_usage: int | None = None
    input_token_usage: int | None = None
    output_token_usage: int | None = None

    def validate_image_source(self):
        """Ensure exactly one image source is provided."""
        if not self.image_data and not self.image_url:
            raise ValueError("Either image_data or image_url must be provided")
        if self.image_data and self.image_url:
            raise ValueError("Only one of image_data or image_url should be provided")


@@ -0,0 +1,17 @@
from pydantic import BaseModel

class UsageStats(BaseModel):
    total_runs: int
    total_tokens: int
    total_input_tokens: int
    total_output_tokens: int
    total_cost: float

class UsageByEntity(BaseModel):
    entity_id: str | None = None
    stats: UsageStats

class FinancialReport(BaseModel):
    summary: UsageStats
    by_user: list[UsageByEntity] | None = None
    by_project: list[UsageByEntity] | None = None


@@ -1,55 +1,78 @@
 from datetime import datetime, UTC
-from typing import List, Optional
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 from models.Asset import Asset
 from models.Generation import GenerationStatus
-from models.enums import AspectRatios, Quality, GenType
+from models.enums import AspectRatios, Quality, GenType, ImageModel, TextModel
 class GenerationRequest(BaseModel):
-    linked_character_id: Optional[str] = None
-    aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN
+    linked_character_id: str | None = None
+    aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN  # "1:1","2:3","3:2","3:4","4:3","4:5","5:4","9:16","16:9","21:9"
     quality: Quality = Quality.ONEK
     prompt: str
-    telegram_id: Optional[int] = None
+    model: ImageModel = Field(default=ImageModel.GEMINI_3_PRO_IMAGE_PREVIEW)
+    telegram_id: int | None = None
     use_profile_image: bool = True
-    assets_list: List[str]
+    assets_list: list[str]
+    environment_id: str | None = None
+    project_id: str | None = None
+    idea_id: str | None = None
+    nsfw: bool = False
+    count: int = Field(default=1, ge=1, le=10)
+class NsfwRequest(BaseModel):
+    is_nsfw: bool
 class GenerationsResponse(BaseModel):
-    generations: List["GenerationResponse"]
+    generations: list["GenerationResponse"]
     total_count: int
 class GenerationResponse(BaseModel):
     id: str
     status: GenerationStatus
-    failed_reason: Optional[str] = None
-    linked_character_id: Optional[str] = None
+    failed_reason: str | None = None
+    project_id: str | None = None
+    linked_character_id: str | None = None
     aspect_ratio: AspectRatios
     quality: Quality
     prompt: str
-    tech_prompt: Optional[str] = None
-    assets_list: List[str]
-    result_list: List[str] = []
-    result: Optional[str] = None
-    execution_time_seconds: Optional[float] = None
-    api_execution_time_seconds: Optional[float] = None
-    token_usage: Optional[int] = None
-    input_token_usage: Optional[int] = None
-    output_token_usage: Optional[int] = None
+    model: ImageModel | None = None
+    seed: int | None = None
+    tech_prompt: str | None = None
+    assets_list: list[str]
+    result_list: list[str] = []
+    result: str | None = None
+    execution_time_seconds: float | None = None
+    api_execution_time_seconds: float | None = None
+    token_usage: int | None = None
+    input_token_usage: int | None = None
+    output_token_usage: int | None = None
     progress: int = 0
+    cost: float | None = None
+    created_by: str | None = None
+    generation_group_id: str | None = None
+    idea_id: str | None = None
+    likes_count: int = 0
+    is_liked: bool = False
+    nsfw: bool = False
     created_at: datetime = datetime.now(UTC)
     updated_at: datetime = datetime.now(UTC)
+class GenerationGroupResponse(BaseModel):
+    generation_group_id: str
+    generations: list[GenerationResponse]
 class PromptRequest(BaseModel):
     prompt: str
-    linked_assets: List[str] = []
+    model: TextModel = Field(default=TextModel.GEMINI_3_1_PRO_PREVIEW)
+    linked_assets: list[str] = []
 class PromptResponse(BaseModel):

api/models/IdeaRequest.py Normal file

@@ -0,0 +1,17 @@
from pydantic import BaseModel
from models.Idea import Idea
from api.models.GenerationRequest import GenerationResponse

class IdeaCreateRequest(BaseModel):
    name: str
    description: str | None = None
    project_id: str | None = None  # Optional in the body if passed via header/dependency
    inspiration_id: str | None = None

class IdeaUpdateRequest(BaseModel):
    name: str | None = None
    description: str | None = None
    inspiration_id: str | None = None

class IdeaResponse(Idea):
    last_generation: GenerationResponse | None = None


@@ -0,0 +1,28 @@
from datetime import datetime
from pydantic import BaseModel
from models.Inspiration import Inspiration

class InspirationCreateRequest(BaseModel):
    source_url: str
    caption: str | None = None
    project_id: str | None = None

class InspirationResponse(BaseModel):
    id: str
    source_url: str
    caption: str | None = None
    asset_id: str
    is_completed: bool
    created_by: str
    project_id: str | None = None
    created_at: datetime
    updated_at: datetime

class InspirationListResponse(BaseModel):
    inspirations: list[InspirationResponse]
    total_count: int

api/models/PostRequest.py Normal file

@@ -0,0 +1,18 @@
from datetime import datetime
from pydantic import BaseModel

class PostCreateRequest(BaseModel):
    date: datetime
    topic: str
    generation_ids: list[str] = []
    project_id: str | None = None

class PostUpdateRequest(BaseModel):
    date: datetime | None = None
    topic: str | None = None

class AddGenerationsRequest(BaseModel):
    generation_ids: list[str]


@@ -0,0 +1,7 @@
from .AssetDTO import AssetResponse, AssetsResponse
from .CharacterDTO import CharacterCreateRequest, CharacterUpdateRequest
from .ExternalGenerationDTO import ExternalGenerationRequest
from .FinancialUsageDTO import FinancialReport, UsageStats, UsageByEntity
from .GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse, PromptRequest, PromptResponse, NsfwRequest
from .IdeaRequest import IdeaCreateRequest, IdeaUpdateRequest, IdeaResponse
from .PostRequest import PostCreateRequest, PostUpdateRequest, AddGenerationsRequest



@@ -0,0 +1,85 @@
from typing import List, Optional
from models.Album import Album
from models.Generation import Generation
from repos.dao import DAO

class AlbumService:
    def __init__(self, dao: DAO):
        self.dao = dao

    async def create_album(self, name: str, description: Optional[str] = None) -> Album:
        album = Album(name=name, description=description)
        album_id = await self.dao.albums.create_album(album)
        album.id = album_id
        return album

    async def get_albums(self, limit: int = 10, offset: int = 0) -> List[Album]:
        return await self.dao.albums.get_albums(limit=limit, offset=offset)

    async def get_album(self, album_id: str) -> Optional[Album]:
        return await self.dao.albums.get_album(album_id)

    async def update_album(self, album_id: str, name: Optional[str] = None, description: Optional[str] = None) -> Optional[Album]:
        album = await self.dao.albums.get_album(album_id)
        if not album:
            return None
        if name:
            album.name = name
        if description is not None:
            album.description = description
        await self.dao.albums.update_album(album_id, album)
        return album

    async def delete_album(self, album_id: str) -> bool:
        return await self.dao.albums.delete_album(album_id)

    async def add_generation_to_album(self, album_id: str, generation_id: str) -> bool:
        # Verify both the album and the generation exist before linking them.
        album = await self.dao.albums.get_album(album_id)
        if not album:
            return False
        gen = await self.dao.generations.get_generation(generation_id)
        if not gen:
            return False
        # Use the first result of a finished generation as the cover if none is set.
        if album.cover_asset_id is None and gen.status == 'done' and gen.result_list:
            album.cover_asset_id = gen.result_list[0]
        return await self.dao.albums.add_generation(album_id, generation_id, album.cover_asset_id)

    async def remove_generation_from_album(self, album_id: str, generation_id: str) -> bool:
        return await self.dao.albums.remove_generation(album_id, generation_id)

    async def get_generations_by_album(self, album_id: str, limit: int = 10, offset: int = 0) -> List[Generation]:
        album = await self.dao.albums.get_album(album_id)
        if not album or not album.generation_ids:
            return []
        # Paginate over the stored ID list (insertion order), then fetch only that slice.
        sliced_ids = album.generation_ids[offset : offset + limit]
        if not sliced_ids:
            return []
        return await self.dao.generations.get_generations_by_ids(sliced_ids)
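One subtlety for a `get_generations_by_ids` implementation: a MongoDB `$in` query does not return documents in the order of the input list, so the repo method should restore the album's order after fetching. A sketch of that reordering step (hypothetical helper, assuming documents carry an `_id`):

```python
def order_by_ids(docs: list[dict], ids: list[str]) -> list[dict]:
    # $in returns documents in arbitrary order; restore the requested order
    # and silently drop IDs that no longer resolve to a document.
    by_id = {str(d["_id"]): d for d in docs}
    return [by_id[i] for i in ids if i in by_id]
```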


@@ -1,234 +1,218 @@
 import asyncio
+import base64
 import logging
 import random
 from datetime import datetime, UTC
 from typing import List, Optional, Tuple, Any, Dict
-from io import BytesIO
+from uuid import uuid4
+import httpx
 from aiogram import Bot
 from aiogram.types import BufferedInputFile
 from adapters.Exception import GoogleGenerationException
 from adapters.google_adapter import GoogleAdapter
-from api.models.GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse
-# Import your DAO, Asset, Generation models correctly
+from adapters.ai_proxy_adapter import AIProxyAdapter
+from adapters.s3_adapter import S3Adapter
+from api.models import (
+    FinancialReport, UsageStats, UsageByEntity,
+    GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse
+)
 from models.Asset import Asset, AssetType, AssetContentType
 from models.Generation import Generation, GenerationStatus
-from models.enums import AspectRatios, Quality, GenType
+from models.enums import AspectRatios, Quality
 from repos.dao import DAO
-from adapters.s3_adapter import S3Adapter
+from utils.image_utils import create_thumbnail
 logger = logging.getLogger(__name__)
+# Limit concurrent generations to 4
+generation_semaphore = asyncio.Semaphore(4)
-# --- Image generation helper ---
 async def generate_image_task(
     prompt: str,
     media_group_bytes: List[bytes],
     aspect_ratio: AspectRatios,
     quality: Quality,
+    model: str,
     gemini: GoogleAdapter,
 ) -> Tuple[List[bytes], Dict[str, Any]]:
     """
     Wrapper for calling Gemini's synchronous method in a separate thread.
-    Returns a list of byte strings of the generated images.
     """
     try:
         logger.info(f"Starting generate_image_task with prompt length: {len(prompt)}")
-        # Run the blocking call in a separate thread so the event loop is not stalled
         result = await asyncio.to_thread(
             gemini.generate_image,
             prompt=prompt,
             images_list=media_group_bytes,
             aspect_ratio=aspect_ratio,
             quality=quality,
+            model=model,
         )
         generated_images_io, metrics = result
         logger.info(f"generate_image_task completed, received {len(generated_images_io) if generated_images_io else 0} images")
-    except GoogleGenerationException as e:
-        raise e
-    finally:
-        del media_group_bytes
+    except GoogleGenerationException:
+        raise
     images_bytes = []
     if generated_images_io:
         for img_io in generated_images_io:
-            # Read the bytes from the BytesIO stream
             img_io.seek(0)
-            content = img_io.read()
-            images_bytes.append(content)
-            # Close the stream
+            images_bytes.append(img_io.read())
             img_io.close()
-        del generated_images_io
     return images_bytes, metrics
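The `asyncio.to_thread` wrapper above is the standard way to run a blocking SDK call without stalling the event loop; a self-contained illustration of the pattern (dummy function, not the real adapter):

```python
import asyncio
import time

def blocking_generate(prompt: str) -> str:
    # Stand-in for the synchronous Gemini SDK call.
    time.sleep(0.05)
    return f"image-for:{prompt}"

async def main() -> list[str]:
    # Each blocking call runs in a worker thread, so the event loop
    # stays free and the two calls overlap instead of serializing.
    return list(await asyncio.gather(
        asyncio.to_thread(blocking_generate, "a cat"),
        asyncio.to_thread(blocking_generate, "a dog"),
    ))

results = asyncio.run(main())
```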
 class GenerationService:
     def __init__(self, dao: DAO, gemini: GoogleAdapter, s3_adapter: S3Adapter, bot: Optional[Bot] = None):
         self.dao = dao
         self.gemini = gemini
+        self.ai_proxy = AIProxyAdapter()
         self.s3_adapter = s3_adapter
         self.bot = bot
-    # --- Public API ---
-    async def ask_prompt_assistant(self, prompt: str, assets: List[str] = None) -> str:
-        future_prompt = """You are an prompt-assistant. You improving user-entered prompts for image generation. User may upload reference image too.
-I will provide sources prompt entered by user. Understand user needs and generate best variation of prompt.
-ANSWER ONLY PROMPT STRING!!! USER_ENTERED_PROMPT: """
-        future_prompt += prompt
+    async def ask_prompt_assistant(self, prompt: str, assets: list[str] | None = None, model: str = "gemini-3.1-pro-preview") -> str:
+        future_prompt = (
+            "You are an prompt-assistant. You improving user-entered prompts for image generation. "
+            "User may upload reference image too. I will provide sources prompt entered by user. "
+            "Understand user needs and generate best variation of prompt. ANSWER ONLY PROMPT STRING!!! "
+            f"USER_ENTERED_PROMPT: {prompt}"
+        )
+        settings = await self.dao.settings.get_settings()
+        if settings.use_ai_proxy:
+            asset_urls = await self._prepare_asset_urls(assets) if assets else None
+            generated_prompt = await self.ai_proxy.generate_text(future_prompt, model, asset_urls)
+        else:
-        assets_data = []
-        if assets is not None:
-            assets_db = await self.dao.assets.get_assets_by_ids(assets)
-            assets_data.extend(asset.data for asset in assets_db)
-        generated_prompt = await asyncio.to_thread(self.gemini.generate_text, future_prompt, assets_data)
-        logger.info(future_prompt)
-        logger.info(generated_prompt)
+            assets_data = []
+            if assets:
+                assets_db = await self.dao.assets.get_assets_by_ids(assets)
+                assets_data.extend(asset.data for asset in assets_db if asset.data)
+            generated_prompt = await asyncio.to_thread(self.gemini.generate_text, future_prompt, model, assets_data)
+        logger.info(f"Prompt Assistant: {generated_prompt}")
         return generated_prompt
-    async def generate_prompt_from_images(self, images: List[bytes], user_prompt: Optional[str] = None) -> str:
+    async def generate_prompt_from_images(self, images: List[bytes], user_prompt: Optional[str] = None, model: str = "gemini-3.1-pro-preview") -> str:
         technical_prompt = "You are a prompt engineer. Describe this image in detail to create a stable diffusion using this image as reference. "
         if user_prompt:
             technical_prompt += f"User also provided this context: {user_prompt}. "
         technical_prompt += "Provide ONLY the detailed prompt."
-        return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, images_list=images)
+        settings = await self.dao.settings.get_settings()
+        if settings.use_ai_proxy:
+            # The proxy can't accept raw image bytes yet; they would need a temp
+            # upload first, so for now call it with the text prompt only.
+            return await self.ai_proxy.generate_text(prompt=technical_prompt, model=model)
+        return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, model=model, images_list=images)
-    async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0) -> List[
-            Generation]:
-        generations = await self.dao.generations.get_generations(character_id=character_id, limit=limit, offset=offset)
-        total_count = await self.dao.generations.count_generations(character_id=character_id)
-        generations = [GenerationResponse(**gen.model_dump()) for gen in generations]
-        return GenerationsResponse(generations=generations, total_count=total_count)
+    async def get_generations(self, **kwargs) -> GenerationsResponse:
+        current_user_id = kwargs.pop('current_user_id', None)
+        generations = await self.dao.generations.get_generations(**kwargs)
+        total_count = await self.dao.generations.count_generations(
+            character_id=kwargs.get('character_id'),
+            created_by=kwargs.get('created_by'),
+            project_id=kwargs.get('project_id'),
+            idea_id=kwargs.get('idea_id'),
+            only_liked_by=kwargs.get('only_liked_by')
+        )
+        return GenerationsResponse(
+            generations=[self._map_to_response(gen, current_user_id) for gen in generations],
+            total_count=total_count
+        )
-    async def get_generation(self, generation_id: str) -> Optional[GenerationResponse]:
+    async def get_generation(self, generation_id: str, current_user_id: Optional[str] = None) -> Optional[GenerationResponse]:
         gen = await self.dao.generations.get_generation(generation_id)
-        if gen is None:
-            return None
-        else:
-            return GenerationResponse(**gen.model_dump())
+        return self._map_to_response(gen, current_user_id) if gen else None
+    async def toggle_like(self, generation_id: str, user_id: str) -> bool | None:
+        return await self.dao.generations.toggle_like(generation_id, user_id)
+    async def get_generations_by_group(self, group_id: str, current_user_id: Optional[str] = None) -> GenerationGroupResponse:
+        generations = await self.dao.generations.get_generations_by_group(group_id)
+        return GenerationGroupResponse(
+            generation_group_id=group_id,
+            generations=[self._map_to_response(gen, current_user_id) for gen in generations]
+        )
+    def _map_to_response(self, gen: Generation, current_user_id: Optional[str] = None) -> GenerationResponse:
+        res = GenerationResponse(**gen.model_dump())
+        res.likes_count = len(gen.liked_by) if gen.liked_by else 0
+        res.is_liked = current_user_id in gen.liked_by if current_user_id and gen.liked_by else False
+        return res
-    async def get_running_generations(self) -> List[Generation]:
-        return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING)
+    async def get_running_generations(self, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
+        return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING, created_by=user_id, project_id=project_id)
-    async def create_generation_task(self, generation_request: GenerationRequest) -> GenerationResponse:
-        gen_id = None
-        generation_model = None
-        try:
-            generation_model = Generation(**generation_request.model_dump())
-            gen_id = await self.dao.generations.create_generation(generation_model)
-            generation_model.id = gen_id
-            async def runner(gen):
-                logger.info(f"Starting background generation task for ID: {gen.id}")
-                try:
-                    await self.create_generation(gen)
-                    logger.info(f"Background generation task finished for ID: {gen.id}")
-                except Exception:
-                    # if the generation already started and then failed, mark it FAILED
-                    try:
-                        db_gen = await self.dao.generations.get_generation(gen.id)
-                        db_gen.status = GenerationStatus.FAILED
-                        await self.dao.generations.update_generation(db_gen)
-                    except Exception:
-                        logger.exception("Failed to mark generation as FAILED")
-                    logger.exception("create_generation task failed")
-            asyncio.create_task(runner(generation_model))
-            return GenerationResponse(**generation_model.model_dump())
-        except Exception:
-            # if the record was never created, there is nothing to mark
-            if gen_id is not None:
-                try:
-                    gen = await self.dao.generations.get_generation(gen_id)
-                    gen.status = GenerationStatus.FAILED
-                    await self.dao.generations.update_generation(gen)
-                except Exception:
-                    logger.exception("Failed to mark generation as FAILED in create_generation_task")
-            raise
+    async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationGroupResponse:
+        if generation_group_id is None:
+            generation_group_id = str(uuid4())
+        results = []
+        for _ in range(generation_request.count):
+            gen_response = await self._create_single_generation(generation_request, user_id, generation_group_id)
+            results.append(gen_response)
+        return GenerationGroupResponse(generation_group_id=generation_group_id, generations=results)
     async def create_generation(self, generation: Generation):
         start_time = datetime.now()
         logger.info(f"Processing generation {generation.id}. Character ID: {generation.linked_character_id}")
-        # 2. Fetch the reference assets (if any)
-        reference_assets: List[Asset] = []
-        media_group_bytes: List[bytes] = []
-        generation_prompt = f"""
-Create detailed image of character in scene.
-SCENE DESCRIPTION: {generation.prompt}
-Rules:
-- Integrate the character's appearance naturally into the scene description.
-- Focus on lighting, texture, and composition.
-"""
-        if generation.linked_character_id is not None:
-            char_info = await self.dao.chars.get_character(generation.linked_character_id, with_image_data=True)
-            if char_info is None:
-                raise Exception(f"Character ID {generation.linked_character_id} not found")
-            if generation.use_profile_image:
-                media_group_bytes.append(char_info.character_image_data)
-        reference_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
-        # Extract the bytes from the assets to send to Gemini
-        for asset in reference_assets:
-            if asset.content_type != AssetContentType.IMAGE:
-                continue
-            img_data = None
-            if asset.minio_object_name:
-                img_data = await self.s3_adapter.get_file(asset.minio_object_name)
-            elif asset.data:
-                img_data = asset.data
-            if img_data:
-                media_group_bytes.append(img_data)
-        if media_group_bytes:
-            generation_prompt += " \n\n[Reference Image Guidance]: Use the provided image(s) as the STRICT reference for the main character's facial features and hair, enviroment or clothes. Maintain high fidelity to the reference identity."
-        logger.info(f"Final generation prompt assembled. Length: {len(generation_prompt)}. Media count: {len(media_group_bytes)}")
-        # 3. Start the generation and the progress simulation
+        # 1. Prepare input
+        media_group_bytes, generation_prompt, asset_ids = await self._prepare_generation_input(generation)
+        # 2. Run generation with progress simulation
         progress_task = asyncio.create_task(self._simulate_progress(generation))
         try:
-            # Default to Image Generation (Gemini)
-            generated_bytes_list, metrics = await generate_image_task(
-                prompt=generation_prompt,
-                media_group_bytes=media_group_bytes,
-                aspect_ratio=generation.aspect_ratio,
-                quality=generation.quality,
-                gemini=self.gemini
-            )
-            # Update metrics from API (Common for both)
-            generation.api_execution_time_seconds = metrics.get("api_execution_time_seconds")
-            generation.token_usage = metrics.get("token_usage")
-            generation.input_token_usage = metrics.get("input_token_usage")
-            generation.output_token_usage = metrics.get("output_token_usage")
-        except GoogleGenerationException as e:
-            generation.status = GenerationStatus.FAILED
-            generation.failed_reason = str(e)
-            generation.updated_at = datetime.now(UTC)
-            await self.dao.generations.update_generation(generation)
-            raise e
-        except Exception as e:
-            # Error logging should be added here
-            logging.error(f"Generation failed: {e}")
-            generation.status = GenerationStatus.FAILED
-            generation.failed_reason = str(e)
-            generation.updated_at = datetime.now(UTC)
-            await self.dao.generations.update_generation(generation)
-            raise e
+            settings = await self.dao.settings.get_settings()
+            if settings.use_ai_proxy:
+                asset_urls = await self._prepare_asset_urls(asset_ids) if asset_ids else None
+                generated_images_io, metrics = await self.ai_proxy.generate_image(
+                    prompt=generation_prompt,
+                    aspect_ratio=generation.aspect_ratio,
+                    quality=generation.quality,
+                    model=generation.model or "gemini-3-pro-image-preview",
+                    asset_urls=asset_urls
+                )
+                generated_bytes_list = []
+                for img_io in generated_images_io:
+                    img_io.seek(0)
+                    generated_bytes_list.append(img_io.read())
+                    img_io.close()
+            else:
+                generated_bytes_list, metrics = await generate_image_task(
+                    prompt=generation_prompt,
+                    media_group_bytes=media_group_bytes,
+                    aspect_ratio=generation.aspect_ratio,
+                    quality=generation.quality,
+                    model=generation.model or "gemini-3-pro-image-preview",
+                    gemini=self.gemini
+                )
+            self._update_generation_metrics(generation, metrics)
+            # 3. Process results
+            created_assets = await self._process_generated_images(generation, generated_bytes_list)
+            # 4. Finalize generation record
+            await self._finalize_generation(generation, created_assets, generation_prompt, start_time)
+            # 5. Notify
+            if generation.telegram_id and self.bot:
+                await self._notify_telegram(generation, created_assets)
         finally:
             if not progress_task.done():
                 progress_task.cancel()
@@ -237,103 +221,53 @@ class GenerationService:
            except asyncio.CancelledError:
                pass

    async def import_external_generation(self, external_gen) -> Generation:
        external_gen.validate_image_source()
        logger.info(f"Importing external generation for user: {external_gen.created_by}")
        image_bytes = await self._fetch_external_image(external_gen)

        # Reuse internal processing logic
        new_asset = await self._save_asset(
            image_bytes=image_bytes,
            name=f"External_Generated_{external_gen.linked_character_id or 'no_char'}",
            created_by=external_gen.created_by,
            project_id=external_gen.project_id,
            linked_char_id=external_gen.linked_character_id,
            folder="external"
        )

        generation = Generation(
            status=GenerationStatus.DONE,
            linked_character_id=external_gen.linked_character_id,
            aspect_ratio=external_gen.aspect_ratio,
            quality=external_gen.quality,
            prompt=external_gen.prompt,
            model=external_gen.model,
            tech_prompt=external_gen.tech_prompt,
            seed=external_gen.seed,
            result_list=[new_asset.id],
            result=new_asset.id,
            progress=100,
            nsfw=external_gen.nsfw,
            execution_time_seconds=external_gen.execution_time_seconds,
            api_execution_time_seconds=external_gen.api_execution_time_seconds,
            token_usage=external_gen.token_usage,
            input_token_usage=external_gen.input_token_usage,
            output_token_usage=external_gen.output_token_usage,
            created_by=external_gen.created_by,
            project_id=external_gen.project_id
        )
        gen_id = await self.dao.generations.create_generation(generation)
        generation.id = gen_id
        return generation
    async def delete_generation(self, generation_id: str) -> bool:
        try:
            generation = await self.dao.generations.get_generation(generation_id)
            if not generation:
                return False
            generation.is_deleted = True
            generation.updated_at = datetime.now(UTC)
            await self.dao.generations.update_generation(generation)
@@ -341,3 +275,222 @@ class GenerationService:
        except Exception as e:
            logger.error(f"Error deleting generation {generation_id}: {e}")
            return False
    async def cleanup_stale_generations(self):
        try:
            count = await self.dao.generations.cancel_stale_generations(timeout_minutes=5)
            if count > 0:
                logger.info(f"Cleaned up {count} stale generations")
        except Exception as e:
            logger.error(f"Error cleaning up stale generations: {e}")

    async def cleanup_old_data(self, days: int = 30):
        try:
            gen_count, asset_ids = await self.dao.generations.soft_delete_old_generations(days=days)
            if gen_count > 0:
                logger.info(f"Soft-deleted {gen_count} generations. Purging {len(asset_ids)} assets.")
                if asset_ids:
                    await self.dao.assets.soft_delete_and_purge_assets(asset_ids)
        except Exception as e:
            logger.error(f"Error during old data cleanup: {e}")

    async def get_financial_report(self, user_id: Optional[str] = None, project_id: Optional[str] = None, breakdown_by: Optional[str] = None) -> FinancialReport:
        summary_data = await self.dao.generations.get_usage_stats(created_by=user_id, project_id=project_id)
        summary = UsageStats(**summary_data)
        by_user, by_project = None, None
        if breakdown_by == "created_by":
            res = await self.dao.generations.get_usage_breakdown(group_by="created_by", project_id=project_id, created_by=user_id)
            by_user = [UsageByEntity(**item) for item in res]
        if breakdown_by == "project_id":
            res = await self.dao.generations.get_usage_breakdown(group_by="project_id", project_id=project_id, created_by=user_id)
            by_project = [UsageByEntity(**item) for item in res]
        return FinancialReport(summary=summary, by_user=by_user, by_project=by_project)

    # --- Private Helpers ---

    async def _create_single_generation(self, generation_request: GenerationRequest, user_id: Optional[str], generation_group_id: str) -> GenerationResponse:
        try:
            gen_model = Generation(**generation_request.model_dump(exclude={'count'}))
            gen_model.created_by = user_id
            gen_model.generation_group_id = generation_group_id
            gen_id = await self.dao.generations.create_generation(gen_model)
            gen_model.id = gen_id
            asyncio.create_task(self._queued_generation_runner(gen_model))
            return GenerationResponse(**gen_model.model_dump())
        except Exception:
            logger.exception("Failed to initiate single generation")
            raise

    async def _queued_generation_runner(self, gen: Generation):
        logger.info(f"Generation {gen.id} waiting for slot...")
        try:
            async with generation_semaphore:
                await self.create_generation(gen)
        except Exception as e:
            await self._handle_generation_failure(gen, e)
            logger.exception(f"Background generation task failed for ID: {gen.id}")

    async def _prepare_generation_input(self, generation: Generation) -> Tuple[List[bytes], str, List[str]]:
        media_group_bytes: List[bytes] = []
        prompt = generation.prompt
        asset_ids = []

        # 1. Character Avatar
        if generation.linked_character_id:
            char_info = await self.dao.chars.get_character(generation.linked_character_id)
            if not char_info:
                raise ValueError(f"Character {generation.linked_character_id} not found")
            if generation.use_profile_image and char_info.avatar_asset_id:
                asset_ids.append(char_info.avatar_asset_id)
                avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id, with_data=True)
                if avatar_asset and avatar_asset.content_type == AssetContentType.IMAGE and avatar_asset.data:
                    media_group_bytes.append(avatar_asset.data)

        # 2. Reference Assets
        if generation.assets_list:
            asset_ids.extend(generation.assets_list)
            assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
            for asset in assets:
                data = await self._load_asset_image_data(asset)
                if data:
                    media_group_bytes.append(data)

        # 3. Environment Assets
        if generation.environment_id:
            env = await self.dao.environments.get_env(generation.environment_id)
            if env and env.asset_ids:
                asset_ids.extend(env.asset_ids)
                env_assets = await self.dao.assets.get_assets_by_ids(env.asset_ids)
                for asset in env_assets:
                    data = await self._load_asset_image_data(asset)
                    if data:
                        media_group_bytes.append(data)

        if media_group_bytes:
            prompt += (
                " \n\n[Reference Image Guidance]: Use the provided image(s) as the STRICT reference for the main "
                "character's facial features and hair, environment or clothes. Maintain high fidelity to the reference identity."
            )
        return media_group_bytes, prompt, asset_ids

    async def _prepare_asset_urls(self, asset_ids: List[str]) -> List[str]:
        assets = await self.dao.assets.get_assets_by_ids(asset_ids)
        urls = []
        for asset in assets:
            if asset.minio_object_name:
                bucket = asset.minio_bucket or self.s3_adapter.bucket_name
                urls.append(f"{bucket}/{asset.minio_object_name}")
        return urls

    async def _load_asset_image_data(self, asset: Asset) -> Optional[bytes]:
        """Load image bytes for an asset that was fetched without data (e.g. from get_assets_by_ids)."""
        if asset.content_type != AssetContentType.IMAGE:
            return None
        if asset.data:
            return asset.data
        if asset.minio_object_name:
            return await self.s3_adapter.get_file(asset.minio_object_name)
        return None

    def _update_generation_metrics(self, generation: Generation, metrics: Dict[str, Any]):
        generation.api_execution_time_seconds = metrics.get("api_execution_time_seconds")
        generation.token_usage = metrics.get("token_usage")
        generation.input_token_usage = metrics.get("input_token_usage")
        generation.output_token_usage = metrics.get("output_token_usage")

    async def _handle_generation_failure(self, generation: Generation, error: Optional[Exception]):
        logger.error(f"Generation {generation.id} failed: {error}")
        generation.status = GenerationStatus.FAILED
        # Don't overwrite if reason is already set, unless a new error is provided
        if error:
            generation.failed_reason = str(error)
        elif not generation.failed_reason:
            generation.failed_reason = "Unknown error"
        generation.updated_at = datetime.now(UTC)
        await self.dao.generations.update_generation(generation)

    async def _process_generated_images(self, generation: Generation, bytes_list: List[bytes]) -> List[Asset]:
        created_assets = []
        for img_bytes in bytes_list:
            asset = await self._save_asset(
                image_bytes=img_bytes,
                name=f"Generated_{generation.linked_character_id}",
                created_by=generation.created_by,
                project_id=generation.project_id,
                linked_char_id=generation.linked_character_id,
                folder="generated"
            )
            created_assets.append(asset)
        return created_assets

    async def _save_asset(self, image_bytes: bytes, name: str, created_by: str, project_id: str, linked_char_id: str, folder: str) -> Asset:
        thumbnail_bytes = await asyncio.to_thread(create_thumbnail, image_bytes)
        filename = f"{folder}/{linked_char_id}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.png"
        await self.s3_adapter.upload_file(filename, image_bytes, content_type="image/png")
        new_asset = Asset(
            name=name,
            type=AssetType.GENERATED,
            content_type=AssetContentType.IMAGE,
            linked_char_id=linked_char_id,
            data=None,
            minio_object_name=filename,
            minio_bucket=self.s3_adapter.bucket_name,
            thumbnail=thumbnail_bytes,
            created_by=created_by,
            project_id=project_id
        )
        asset_id = await self.dao.assets.create_asset(new_asset)
        new_asset.id = str(asset_id)
        return new_asset

    async def _finalize_generation(self, generation: Generation, assets: List[Asset], tech_prompt: str, start_time: datetime):
        generation.result_list = [a.id for a in assets]
        generation.status = GenerationStatus.DONE
        generation.progress = 100
        generation.updated_at = datetime.now(UTC)
        generation.tech_prompt = tech_prompt
        generation.execution_time_seconds = (datetime.now() - start_time).total_seconds()
        await self.dao.generations.update_generation(generation)
        logger.info(f"Generation {generation.id} finalized. Time: {generation.execution_time_seconds:.2f}s")

    async def _notify_telegram(self, generation: Generation, assets: List[Asset]):
        try:
            for asset in assets:
                # Need to get data for telegram if it's not in the Asset object
                img_data = await self.s3_adapter.get_file(asset.minio_object_name) if asset.minio_object_name else asset.data
                if img_data:
                    await self.bot.send_photo(
                        chat_id=generation.telegram_id,
                        photo=BufferedInputFile(img_data, filename=f"{asset.name}.png"),
                        caption=f"Generated from: {generation.prompt[:100]}..."
                    )
        except Exception as e:
            logger.error(f"Failed to send to Telegram: {e}")

    async def _simulate_progress(self, generation: Generation):
        current_progress = 0
        try:
            while current_progress < 90:
                await asyncio.sleep(4)
                current_progress = min(current_progress + random.randint(5, 15), 90)
                generation.progress = current_progress
                await self.dao.generations.update_generation(generation)
        except asyncio.CancelledError:
            pass

    async def _fetch_external_image(self, external_gen) -> bytes:
        if external_gen.image_url:
            async with httpx.AsyncClient() as client:
                response = await client.get(external_gen.image_url, timeout=30.0)
                response.raise_for_status()
                return response.content
        elif external_gen.image_data:
            return base64.b64decode(external_gen.image_data)
        raise ValueError("No image source provided")
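The branch logic of `_fetch_external_image` (prefer the URL, fall back to a base64 payload) can be sketched in isolation. `ExternalGen` below is a hypothetical stand-in for the request model, and stdlib `urllib` substitutes for the service's `httpx.AsyncClient`:

```python
import base64
from dataclasses import dataclass
from typing import Optional
from urllib.request import urlopen


@dataclass
class ExternalGen:
    """Minimal stand-in for the external generation payload."""
    image_url: Optional[str] = None
    image_data: Optional[str] = None  # base64-encoded image bytes


def fetch_external_image(external_gen: ExternalGen) -> bytes:
    """Prefer the URL source; otherwise decode the inline base64 payload."""
    if external_gen.image_url:
        # The service uses httpx.AsyncClient; urllib stands in for the sketch
        with urlopen(external_gen.image_url, timeout=30.0) as resp:
            return resp.read()
    elif external_gen.image_data:
        return base64.b64decode(external_gen.image_data)
    raise ValueError("No image source provided")


payload = ExternalGen(image_data=base64.b64encode(b"\x89PNG...").decode())
print(fetch_external_image(payload))  # the original bytes, recovered from base64
```

A payload with neither field set raises `ValueError`, matching the guard at the end of the service method.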


@@ -0,0 +1,82 @@
from typing import List, Optional
from datetime import datetime
from repos.dao import DAO
from models.Idea import Idea
class IdeaService:
    def __init__(self, dao: DAO):
        self.dao = dao

    async def create_idea(self, name: str, description: Optional[str], project_id: Optional[str], user_id: str, inspiration_id: Optional[str] = None) -> Idea:
        idea = Idea(
            name=name,
            description=description,
            project_id=project_id,
            created_by=user_id,
            inspiration_id=inspiration_id
        )
        idea_id = await self.dao.ideas.create_idea(idea)
        idea.id = idea_id
        return idea

    async def get_ideas(self, project_id: Optional[str], user_id: str, limit: int = 20, offset: int = 0) -> List[dict]:
        return await self.dao.ideas.get_ideas(project_id, user_id, limit, offset)

    async def get_idea(self, idea_id: str) -> Optional[Idea]:
        return await self.dao.ideas.get_idea(idea_id)

    async def update_idea(self, idea_id: str, name: Optional[str] = None, description: Optional[str] = None, inspiration_id: Optional[str] = None) -> Optional[Idea]:
        idea = await self.dao.ideas.get_idea(idea_id)
        if not idea:
            return None
        if name is not None:
            idea.name = name
        if description is not None:
            idea.description = description
        if inspiration_id is not None:
            idea.inspiration_id = inspiration_id
        idea.updated_at = datetime.now()
        await self.dao.ideas.update_idea(idea)
        return idea

    async def delete_idea(self, idea_id: str) -> bool:
        return await self.dao.ideas.delete_idea(idea_id)

    async def add_generation_to_idea(self, idea_id: str, generation_id: str) -> bool:
        # Verify idea exists
        idea = await self.dao.ideas.get_idea(idea_id)
        if not idea:
            return False
        # Get generation
        gen = await self.dao.generations.get_generation(generation_id)
        if not gen:
            return False
        # Link
        gen.idea_id = idea_id
        gen.updated_at = datetime.now()
        await self.dao.generations.update_generation(gen)
        return True

    async def remove_generation_from_idea(self, idea_id: str, generation_id: str) -> bool:
        # Verify idea exists (optional, but good for validation)
        idea = await self.dao.ideas.get_idea(idea_id)
        if not idea:
            return False
        # Get generation
        gen = await self.dao.generations.get_generation(generation_id)
        if not gen:
            return False
        # Unlink only if currently linked to this idea
        if gen.idea_id == idea_id:
            gen.idea_id = None
            gen.updated_at = datetime.now()
            await self.dao.generations.update_generation(gen)
            return True
        return False
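The unlink guard in `remove_generation_from_idea` — clear `idea_id` only when the generation actually points at the given idea — reduces to a few lines. `Gen` here is a minimal stand-in for the `Generation` model, not the real class:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Gen:
    """Stand-in carrying just the field the guard inspects."""
    idea_id: Optional[str] = None


def unlink_from_idea(gen: Gen, idea_id: str) -> bool:
    """Unlink only if the generation is currently linked to this idea."""
    if gen.idea_id == idea_id:
        gen.idea_id = None
        return True
    return False


g = Gen(idea_id="idea-1")
print(unlink_from_idea(g, "idea-2"))  # False: linked to a different idea
print(unlink_from_idea(g, "idea-1"))  # True: link is cleared
```

The guard makes the operation idempotent and prevents one idea's removal request from clobbering a link owned by another idea.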


@@ -0,0 +1,146 @@
import asyncio
import logging
import os
import tempfile
from datetime import datetime
from typing import List, Optional, Tuple
import httpx
from fastapi import HTTPException
from models.Asset import Asset, AssetType, AssetContentType
from models.Inspiration import Inspiration
from repos.dao import DAO
from adapters.s3_adapter import S3Adapter
# Try to import yt_dlp, but don't crash if it's missing (though we added it to requirements)
try:
    import yt_dlp
except ImportError:
    yt_dlp = None

logger = logging.getLogger(__name__)


class InspirationService:
    def __init__(self, dao: DAO, s3_adapter: S3Adapter):
        self.dao = dao
        self.s3_adapter = s3_adapter

    async def create_inspiration(self, source_url: str, created_by: str, project_id: Optional[str] = None, caption: Optional[str] = None) -> Inspiration:
        # 1. Download content from Instagram
        try:
            content_bytes, content_type, ext = await self._download_content(source_url)
        except Exception as e:
            logger.error(f"Failed to download content from {source_url}: {e}")
            raise HTTPException(status_code=400, detail=f"Failed to download content: {str(e)}")

        # 2. Save as Asset
        filename = f"inspirations/{datetime.now().strftime('%Y%m%d_%H%M%S')}_insta.{ext}"
        await self.s3_adapter.upload_file(filename, content_bytes, content_type=content_type)
        asset = Asset(
            name=f"Inspiration from {source_url}",
            type=AssetType.INSPIRATION,
            content_type=AssetContentType.VIDEO if content_type.startswith("video") else AssetContentType.IMAGE,
            minio_object_name=filename,
            minio_bucket=self.s3_adapter.bucket_name,
            created_by=created_by,
            project_id=project_id
        )
        asset_id = await self.dao.assets.create_asset(asset)

        # 3. Create Inspiration object
        inspiration = Inspiration(
            source_url=source_url,
            caption=caption,
            asset_id=str(asset_id),
            created_by=created_by,
            project_id=project_id
        )
        insp_id = await self.dao.inspirations.create_inspiration(inspiration)
        inspiration.id = insp_id
        return inspiration

    async def get_inspirations(self, project_id: Optional[str], created_by: str, limit: int = 20, offset: int = 0) -> List[Inspiration]:
        return await self.dao.inspirations.get_inspirations(project_id, created_by, limit, offset)

    async def get_inspiration(self, inspiration_id: str) -> Optional[Inspiration]:
        return await self.dao.inspirations.get_inspiration(inspiration_id)

    async def mark_as_completed(self, inspiration_id: str, is_completed: bool = True) -> Optional[Inspiration]:
        inspiration = await self.dao.inspirations.get_inspiration(inspiration_id)
        if not inspiration:
            return None
        inspiration.is_completed = is_completed
        inspiration.updated_at = datetime.now()
        await self.dao.inspirations.update_inspiration(inspiration)
        return inspiration

    async def delete_inspiration(self, inspiration_id: str) -> bool:
        inspiration = await self.dao.inspirations.get_inspiration(inspiration_id)
        if not inspiration:
            return False
        # Delete associated asset
        if inspiration.asset_id:
            await self.dao.assets.delete_asset(inspiration.asset_id)
        return await self.dao.inspirations.delete_inspiration(inspiration_id)

    async def _download_content(self, url: str) -> Tuple[bytes, str, str]:
        """
        Downloads content using yt-dlp.
        Returns (content_bytes, content_type, extension)
        """
        if not yt_dlp:
            raise RuntimeError("yt-dlp is not installed")
        logger.info(f"Downloading from {url} using yt-dlp...")

        def run_yt_dlp():
            with tempfile.TemporaryDirectory() as tmpdirname:
                ydl_opts = {
                    'outtmpl': f'{tmpdirname}/%(id)s.%(ext)s',
                    'quiet': True,
                    'no_warnings': True,
                    'format': 'best',     # Best quality single file
                    'noplaylist': True,   # Only single video if it's a playlist/profile
                    'writethumbnail': False,
                    'writesubtitles': False,
                }
                with yt_dlp.YoutubeDL(ydl_opts) as ydl:
                    ydl.download([url])
                # Find the downloaded file
                files = os.listdir(tmpdirname)
                if not files:
                    raise Exception("No files downloaded")
                # With 'format': 'best' there should be a single file;
                # a carousel may yield several, so just take the first one.
                filename = files[0]
                filepath = os.path.join(tmpdirname, filename)
                with open(filepath, 'rb') as f:
                    data = f.read()
                ext = filename.split('.')[-1].lower()
                # Determine content type
                if ext in ['mp4', 'mov', 'avi', 'mkv', 'webm']:
                    content_type = f"video/{ext}"
                    if ext == 'mov':
                        content_type = "video/quicktime"
                elif ext in ['jpg', 'jpeg', 'png', 'webp']:
                    content_type = f"image/{ext}"
                    if ext == 'jpg':
                        content_type = "image/jpeg"
                else:
                    content_type = "application/octet-stream"
                return data, content_type, ext

        return await asyncio.to_thread(run_yt_dlp)
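The extension-to-MIME branching inside `run_yt_dlp` can be pulled out as a pure function. This is a sketch of the same mapping, not a helper the service actually exposes:

```python
def guess_content_type(filename: str) -> str:
    """Map a downloaded file's extension to a MIME type (same branching as run_yt_dlp)."""
    ext = filename.rsplit('.', 1)[-1].lower()
    if ext in ('mp4', 'mov', 'avi', 'mkv', 'webm'):
        # mov is the one video extension whose MIME type is not "video/<ext>"
        return "video/quicktime" if ext == 'mov' else f"video/{ext}"
    if ext in ('jpg', 'jpeg', 'png', 'webp'):
        # jpg files are served as image/jpeg
        return "image/jpeg" if ext == 'jpg' else f"image/{ext}"
    return "application/octet-stream"


print(guess_content_type("clip.MOV"))   # video/quicktime
print(guess_content_type("photo.jpg"))  # image/jpeg
print(guess_content_type("data.bin"))   # application/octet-stream
```

Keeping the mapping in one place makes the two special cases (`mov`, `jpg`) easy to test without running yt-dlp.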


@@ -0,0 +1,79 @@
from typing import List, Optional
from datetime import datetime, UTC
from repos.dao import DAO
from models.Post import Post
class PostService:
    def __init__(self, dao: DAO):
        self.dao = dao

    async def create_post(
        self,
        date: datetime,
        topic: str,
        generation_ids: List[str],
        project_id: Optional[str],
        user_id: str,
    ) -> Post:
        post = Post(
            date=date,
            topic=topic,
            generation_ids=generation_ids,
            project_id=project_id,
            created_by=user_id,
        )
        post_id = await self.dao.posts.create_post(post)
        post.id = post_id
        return post

    async def get_post(self, post_id: str) -> Optional[Post]:
        return await self.dao.posts.get_post(post_id)

    async def get_posts(
        self,
        project_id: Optional[str],
        user_id: str,
        limit: int = 20,
        offset: int = 0,
        date_from: Optional[datetime] = None,
        date_to: Optional[datetime] = None,
    ) -> List[Post]:
        return await self.dao.posts.get_posts(project_id, user_id, limit, offset, date_from, date_to)

    async def update_post(
        self,
        post_id: str,
        date: Optional[datetime] = None,
        topic: Optional[str] = None,
    ) -> Optional[Post]:
        post = await self.dao.posts.get_post(post_id)
        if not post:
            return None
        updates: dict = {"updated_at": datetime.now(UTC)}
        if date is not None:
            updates["date"] = date
        if topic is not None:
            updates["topic"] = topic
        await self.dao.posts.update_post(post_id, updates)
        # Return refreshed post
        return await self.dao.posts.get_post(post_id)

    async def delete_post(self, post_id: str) -> bool:
        return await self.dao.posts.delete_post(post_id)

    async def add_generations(self, post_id: str, generation_ids: List[str]) -> bool:
        post = await self.dao.posts.get_post(post_id)
        if not post:
            return False
        return await self.dao.posts.add_generations(post_id, generation_ids)

    async def remove_generation(self, post_id: str, generation_id: str) -> bool:
        post = await self.dao.posts.get_post(post_id)
        if not post:
            return False
        return await self.dao.posts.remove_generation(post_id, generation_id)

config.py Normal file

@@ -0,0 +1,49 @@
import os
from typing import Optional
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
    # Telegram Bot
    BOT_TOKEN: str
    ADMIN_ID: int = 0

    # AI Service
    GEMINI_API_KEY: str

    # Database
    MONGO_HOST: str = "mongodb://localhost:27017"
    DB_NAME: str = "my_bot_db"

    # S3 Storage (Minio)
    MINIO_ENDPOINT: str = "http://localhost:9000"
    MINIO_ACCESS_KEY: str = "minioadmin"
    MINIO_SECRET_KEY: str = "minioadmin"
    MINIO_BUCKET: str = "ai-char"

    # External API
    EXTERNAL_API_SECRET: Optional[str] = None

    # Daily Scheduler
    SCHEDULER_CHARACTER_ID: Optional[str] = None  # Character ID used for daily generation

    # Meta Platform (Instagram Graph API)
    META_ACCESS_TOKEN: Optional[str] = None  # Long-lived page/Instagram access token
    META_INSTAGRAM_ACCOUNT_ID: Optional[str] = None  # Instagram Business Account ID

    # AI Proxy Security
    PROXY_SECRET_SALT: str = "AbVJUkwTPcUWJWhPzmjXb5p4SYyKmYC5m1uVW7Dhi7o"

    # JWT Security
    SECRET_KEY: str = "CHANGE_ME_TO_A_SUPER_SECRET_KEY"
    ALGORITHM: str = "HS256"
    ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 * 24 * 60  # 30 days

    model_config = SettingsConfigDict(
        env_file=os.getenv("ENV_FILE", ".env"),
        env_file_encoding="utf-8",
        extra="ignore"
    )


settings = Settings()
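The lookup order this `Settings` class relies on — process environment first, then the dotenv file selected by `ENV_FILE`, then the field default — can be approximated with stdlib only. A rough illustration, not pydantic-settings itself:

```python
import os
from typing import Optional


def resolve_setting(name: str, env_file_values: dict, default: Optional[str] = None) -> Optional[str]:
    """Approximate the pydantic-settings resolution order:
    process environment > values parsed from the .env file > field default."""
    if name in os.environ:
        return os.environ[name]
    if name in env_file_values:
        return env_file_values[name]
    return default


# An env var beats the .env file, which beats the class default
os.environ["DB_NAME"] = "prod_db"
print(resolve_setting("DB_NAME", {"DB_NAME": "file_db"}, "my_bot_db"))  # prod_db
print(resolve_setting("MINIO_BUCKET", {}, "ai-char"))                   # ai-char
```

This is also why `ENV_FILE` itself is read with `os.getenv`: it has to be resolved before the settings object that would normally do the resolving exists.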


@@ -27,19 +27,19 @@ class AlbumMiddleware(BaseMiddleware):
                # Wait for the remaining parts of the album to arrive
                await asyncio.sleep(self.latency)
                # Check that the key still exists
                if group_id in self.album_data:
                    # Pass the collected album to the handler,
                    # sorted by message_id so the order is correct
                    current_album = self.album_data[group_id]
                    current_album.sort(key=lambda x: x.message_id)
                    data["album"] = current_album
                    return await handler(event, data)
            finally:
                # CLEANUP: always remove the entry after handling or timeout.
                # pop() with a default avoids a KeyError.
                self.album_data.pop(group_id, None)
        else:
            # If the group is already being collected, just append and return

models/.DS_Store vendored (binary file not shown)

models/Album.py Normal file

@@ -0,0 +1,11 @@
from datetime import datetime, UTC
from pydantic import BaseModel, Field
class Album(BaseModel):
    id: str | None = None
    name: str
    description: str | None = None
    cover_asset_id: str | None = None
    generation_ids: list[str] = []
    created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
    updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))


@@ -1,6 +1,6 @@
from datetime import datetime, UTC
from enum import Enum
from typing import Any

from pydantic import BaseModel, computed_field, Field, model_validator
@@ -8,26 +8,31 @@ from pydantic import BaseModel, computed_field, Field, model_validator
class AssetContentType(str, Enum):
    IMAGE = 'image'
    PROMPT = 'prompt'
    VIDEO = 'video'


class AssetType(str, Enum):
    UPLOADED = 'uploaded'
    GENERATED = 'generated'
    INSPIRATION = 'inspiration'


class Asset(BaseModel):
    id: str | None = None
    name: str
    type: AssetType = AssetType.GENERATED
    content_type: AssetContentType = AssetContentType.IMAGE
    linked_char_id: str | None = None
    data: bytes | None = None
    tg_doc_file_id: str | None = None
    tg_photo_file_id: str | None = None
    minio_object_name: str | None = None
    minio_bucket: str | None = None
    minio_thumbnail_object_name: str | None = None
    thumbnail: bytes | None = None
    tags: list[str] = []
    created_by: str | None = None
    project_id: str | None = None
    is_deleted: bool = False
    created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
    updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
@@ -60,6 +65,7 @@ class Asset(BaseModel):
    # --- CALCULATED FIELD ---
    @computed_field
    @property
    def url(self) -> str:
        """
        This field is computed automatically and is included in model_dump() / .json()


@@ -1,15 +1,15 @@
from pydantic import BaseModel
from pydantic_core.core_schema import computed_field


class Character(BaseModel):
    id: str | None = None
    name: str
    avatar_asset_id: str | None = None
    avatar_image: str | None = None
    character_image_doc_tg_id: str | None = None
    character_image_tg_id: str | None = None
    character_bio: str | None = None
    created_by: str | None = None
    project_id: str | None = None
models/Environment.py Normal file

@@ -0,0 +1,19 @@
from pydantic import BaseModel, Field, ConfigDict
from datetime import datetime
from bson import ObjectId
class Environment(BaseModel):
    id: str | None = Field(None, alias="_id")
    character_id: str
    name: str = Field(..., min_length=1)
    description: str | None = None
    asset_ids: list[str] = Field(default_factory=list)
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)

    model_config = ConfigDict(
        populate_by_name=True,
        json_encoders={ObjectId: str},
        arbitrary_types_allowed=True
    )


@@ -1,11 +1,9 @@
from datetime import datetime, UTC
from enum import Enum

from pydantic import BaseModel, Field, computed_field

from models.enums import AspectRatios, Quality


class GenerationStatus(str, Enum):
@@ -14,25 +12,43 @@ class GenerationStatus(str, Enum):
    FAILED = "failed"


class Generation(BaseModel):
    id: str | None = None
    status: GenerationStatus = GenerationStatus.RUNNING
    failed_reason: str | None = None
    linked_character_id: str | None = None
    telegram_id: int | None = None
    use_profile_image: bool = True
    aspect_ratio: AspectRatios
    quality: Quality
    prompt: str
    model: str | None = None
    seed: int | None = None
    tech_prompt: str | None = None
    assets_list: list[str] = Field(default_factory=list)
    result_list: list[str] = Field(default_factory=list)
    result: str | None = None
    progress: int = 0
    execution_time_seconds: float | None = None
    api_execution_time_seconds: float | None = None
    token_usage: int | None = None
    input_token_usage: int | None = None
    output_token_usage: int | None = None
    is_deleted: bool = False
    album_id: str | None = None
    environment_id: str | None = None
    generation_group_id: str | None = None
    created_by: str | None = None  # Stores User ID (Telegram ID or Web User ObjectId)
    project_id: str | None = None
    idea_id: str | None = None
    liked_by: list[str] = Field(default_factory=list)
    nsfw: bool = False
    created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
    updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))

    @computed_field
    def cost(self) -> float:
        if self.status == GenerationStatus.DONE and self.input_token_usage and self.output_token_usage:
            cost_input = self.input_token_usage * 0.000002
            cost_output = self.output_token_usage * 0.00012
            return round(cost_input + cost_output, 3)
        return 0.0

models/Idea.py Normal file

@@ -0,0 +1,13 @@
from datetime import datetime, UTC
from pydantic import BaseModel, Field

class Idea(BaseModel):
    id: str | None = None
    name: str = "New Idea"
    description: str | None = None
    project_id: str | None = None
    inspiration_id: str | None = None  # Link to Inspiration
    created_by: str  # User ID
    is_deleted: bool = False
    # Tz-aware defaults; a bare datetime.now would store naive local time
    created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
    updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))

models/Inspiration.py Normal file

@@ -0,0 +1,15 @@
from datetime import datetime, UTC
from pydantic import BaseModel, Field
class Inspiration(BaseModel):
id: str | None = None
source_url: str
caption: str | None = None
asset_id: str
is_completed: bool = False
created_by: str
project_id: str | None = None
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))

models/Post.py Normal file

@@ -0,0 +1,22 @@
from datetime import datetime, timezone, UTC
from pydantic import BaseModel, Field, model_validator
class Post(BaseModel):
id: str | None = None
date: datetime
topic: str
generation_ids: list[str] = Field(default_factory=list)
project_id: str | None = None
created_by: str
is_deleted: bool = False
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
@model_validator(mode="after")
def ensure_tz_aware(self):
for field in ("date", "created_at", "updated_at"):
val = getattr(self, field)
if val is not None and val.tzinfo is None:
setattr(self, field, val.replace(tzinfo=timezone.utc))
return self
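The `ensure_tz_aware` validator above applies one rule: any naive datetime on the model is assumed to be UTC and gets the timezone attached. A standalone sketch of that rule (hypothetical helper, not repo code):

```python
from datetime import datetime, timezone

def to_utc_aware(dt: datetime) -> datetime:
    # Same rule as Post.ensure_tz_aware: attach UTC to naive values, leave aware ones untouched
    return dt.replace(tzinfo=timezone.utc) if dt.tzinfo is None else dt

naive = datetime(2026, 2, 17, 12, 0)
print(to_utc_aware(naive).isoformat())  # 2026-02-17T12:00:00+00:00
```

Normalizing at validation time keeps later comparisons safe: Python raises `TypeError` when comparing naive and aware datetimes.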

models/Project.py Normal file

@@ -0,0 +1,11 @@
from datetime import datetime, UTC
from pydantic import BaseModel, Field

class Project(BaseModel):
    id: str | None = None
    name: str
    description: str | None = None
    owner_id: str
    members: list[str] = Field(default_factory=list)  # List of User IDs
    is_deleted: bool = False
    created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))

models/Settings.py Normal file

@@ -0,0 +1,10 @@
from datetime import datetime, UTC
from pydantic import BaseModel, ConfigDict, Field

class SystemSettings(BaseModel):
    model_config = ConfigDict(populate_by_name=True)

    id: str = Field(default="system_settings", alias="_id")
    use_ai_proxy: bool = False
    updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))


@@ -2,19 +2,30 @@ from enum import Enum
 class AspectRatios(str, Enum):
-    NINESIXTEEN = "NINESIXTEEN"
-    SIXTEENNINE = "SIXTEENNINE"
-    THREEFOUR = "THREEFOUR"
-    FOURTHREE = "FOURTHREE"
+    ONEONE = "1:1"
+    TWOTHREE = "2:3"
+    THREETWO = "3:2"
+    THREEFOUR = "3:4"
+    FOURTHREE = "4:3"
+    FOURFIVE = "4:5"
+    FIVEFOUR = "5:4"
+    NINESIXTEEN = "9:16"
+    SIXTEENNINE = "16:9"
+    TWENTYONENINE = "21:9"
+
+    @classmethod
+    def _missing_(cls, value):
+        mapping = {
+            "NINESIXTEEN": cls.NINESIXTEEN,
+            "SIXTEENNINE": cls.SIXTEENNINE,
+            "THREEFOUR": cls.THREEFOUR,
+            "FOURTHREE": cls.FOURTHREE,
+        }
+        return mapping.get(value)

     @property
     def value_ratio(self) -> str:
-        return {
-            AspectRatios.NINESIXTEEN: "9:16",
-            AspectRatios.SIXTEENNINE: "16:9",
-            AspectRatios.THREEFOUR: "3:4",
-            AspectRatios.FOURTHREE: "4:3",
-        }[self]
+        return self.value

 class Quality(str, Enum):
@@ -41,3 +52,20 @@ class GenType(str, Enum):
         GenType.TEXT: 'Text',
         GenType.IMAGE: 'Image',
     }[self]
+
+class TextModel(str, Enum):
+    GEMINI_3_1_PRO_PREVIEW = "gemini-3.1-pro-preview"
+
+    @property
+    def value_model(self) -> str:
+        return self.value
+
+class ImageModel(str, Enum):
+    GEMINI_3_PRO_IMAGE_PREVIEW = "gemini-3-pro-image-preview"
+    GEMINI_3_1_FLASH_IMAGE_PREVIEW = "gemini-3.1-flash-image-preview"
+
+    @property
+    def value_model(self) -> str:
+        return self.value
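The `_missing_` hook in `AspectRatios` is what keeps legacy documents loadable after the value migration: old records stored the member name (e.g. "NINESIXTEEN"), new ones store the ratio string ("9:16"). A minimal standalone sketch of the same pattern (toy enum, not the repo's):

```python
from enum import Enum

class Ratio(str, Enum):
    NINESIXTEEN = "9:16"
    SIXTEENNINE = "16:9"

    @classmethod
    def _missing_(cls, value):
        # Called only when value lookup fails; fall back to legacy member-name values
        return {"NINESIXTEEN": cls.NINESIXTEEN, "SIXTEENNINE": cls.SIXTEENNINE}.get(value)

assert Ratio("9:16") is Ratio.NINESIXTEEN       # new-style stored value
assert Ratio("NINESIXTEEN") is Ratio.NINESIXTEEN  # legacy stored value still resolves
```

Returning `None` from `_missing_` (the `.get` miss case) preserves the normal `ValueError` for genuinely unknown values.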

repos/.DS_Store vendored (binary file not shown)

repos/albums_repo.py Normal file

@@ -0,0 +1,61 @@
from typing import List, Optional
import logging
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient
from models.Album import Album
logger = logging.getLogger(__name__)
class AlbumsRepo:
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
self.collection = client[db_name]["albums"]
async def create_album(self, album: Album) -> str:
res = await self.collection.insert_one(album.model_dump())
return str(res.inserted_id)
async def get_album(self, album_id: str) -> Optional[Album]:
try:
res = await self.collection.find_one({"_id": ObjectId(album_id)})
if not res:
return None
res["id"] = str(res.pop("_id"))
return Album(**res)
except Exception:
return None
async def get_albums(self, limit: int = 10, offset: int = 0) -> List[Album]:
res = await self.collection.find().sort("created_at", -1).skip(offset).limit(limit).to_list(None)
albums = []
for doc in res:
doc["id"] = str(doc.pop("_id"))
albums.append(Album(**doc))
return albums
async def update_album(self, album_id: str, album: Album) -> bool:
if not album.id:
album.id = album_id
model_dump = album.model_dump()
res = await self.collection.update_one({"_id": ObjectId(album_id)}, {"$set": model_dump})
return res.modified_count > 0
async def delete_album(self, album_id: str) -> bool:
res = await self.collection.delete_one({"_id": ObjectId(album_id)})
return res.deleted_count > 0
    async def add_generation(self, album_id: str, generation_id: str, cover_asset_id: Optional[str] = None) -> bool:
        update: dict = {"$addToSet": {"generation_ids": generation_id}}
        # Only replace the cover when a new cover_asset_id is provided,
        # so an omitted cover does not wipe the existing one
        if cover_asset_id is not None:
            update["$set"] = {"cover_asset_id": cover_asset_id}
        res = await self.collection.update_one({"_id": ObjectId(album_id)}, update)
        return res.modified_count > 0
async def remove_generation(self, album_id: str, generation_id: str) -> bool:
res = await self.collection.update_one(
{"_id": ObjectId(album_id)},
{"$pull": {"generation_ids": generation_id}}
)
return res.modified_count > 0

repos/assets_repo.py

@@ -1,6 +1,8 @@
-from typing import List, Optional
+from typing import Any, List, Optional
 import logging
+from datetime import datetime, UTC
 from bson import ObjectId
+from uuid import uuid4
 from motor.motor_asyncio import AsyncIOMotorClient
 from models.Asset import Asset
@@ -19,7 +21,8 @@ class AssetsRepo:
         # Main data
         if asset.data:
             ts = int(asset.created_at.timestamp())
-            object_name = f"{asset.type.value}/{ts}_{asset.name}"
+            uid = uuid4().hex[:8]
+            object_name = f"{asset.type.value}/{ts}_{uid}_{asset.name}"

             uploaded = await self.s3.upload_file(object_name, asset.data)
             if uploaded:
@@ -32,7 +35,8 @@ class AssetsRepo:
         # Thumbnail
         if asset.thumbnail:
             ts = int(asset.created_at.timestamp())
-            thumb_name = f"{asset.type.value}/thumbs/{ts}_{asset.name}_thumb.jpg"
+            uid = uuid4().hex[:8]
+            thumb_name = f"{asset.type.value}/thumbs/{ts}_{uid}_{asset.name}_thumb.jpg"

             uploaded_thumb = await self.s3.upload_file(thumb_name, asset.thumbnail)
             if uploaded_thumb:
@@ -46,8 +50,8 @@ class AssetsRepo:
         res = await self.collection.insert_one(asset.model_dump())
         return str(res.inserted_id)

-    async def get_assets(self, asset_type: Optional[str] = None, limit: int = 10, offset: int = 0, with_data: bool = False) -> List[Asset]:
-        filter = {}
+    async def get_assets(self, asset_type: Optional[str] = None, limit: int = 10, offset: int = 0, with_data: bool = False, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Asset]:
+        filter: dict[str, Any] = {"is_deleted": {"$ne": True}}
         if asset_type:
             filter["type"] = asset_type
         args = {}
@@ -70,6 +74,12 @@ class AssetsRepo:
         # if not with_data: args["data"] = 0; args["thumbnail"] = 0
         # So list DOES NOT return thumbnails by default.
         args["thumbnail"] = 0
+        if created_by:
+            filter["created_by"] = created_by
+            filter["project_id"] = None
+        if project_id:
+            filter["project_id"] = project_id

         res = await self.collection.find(filter, args).sort("created_at", -1).skip(offset).limit(limit).to_list(None)
         assets = []
@@ -92,7 +102,7 @@ class AssetsRepo:
         return assets

-    async def get_asset(self, asset_id: str, with_data: bool = True) -> Asset:
+    async def get_asset(self, asset_id: str, with_data: bool = True) -> Optional[Asset]:
         projection = None
         if not with_data:
             projection = {"data": 0, "thumbnail": 0}
@@ -128,7 +138,8 @@ class AssetsRepo:
         if self.s3:
             if asset.data:
                 ts = int(asset.created_at.timestamp())
-                object_name = f"{asset.type.value}/{ts}_{asset.name}"
+                uid = uuid4().hex[:8]
+                object_name = f"{asset.type.value}/{ts}_{uid}_{asset.name}"
                 if await self.s3.upload_file(object_name, asset.data):
                     asset.minio_object_name = object_name
                     asset.minio_bucket = self.s3.bucket_name
@@ -136,7 +147,8 @@ class AssetsRepo:
             if asset.thumbnail:
                 ts = int(asset.created_at.timestamp())
-                thumb_name = f"{asset.type.value}/thumbs/{ts}_{asset.name}_thumb.jpg"
+                uid = uuid4().hex[:8]
+                thumb_name = f"{asset.type.value}/thumbs/{ts}_{uid}_{asset.name}_thumb.jpg"
                 if await self.s3.upload_file(thumb_name, asset.thumbnail):
                     asset.minio_thumbnail_object_name = thumb_name
                     asset.thumbnail = None
@@ -157,11 +169,22 @@ class AssetsRepo:
             assets.append(Asset(**doc))
         return assets

-    async def get_asset_count(self, character_id: Optional[str] = None) -> int:
-        return await self.collection.count_documents({"linked_char_id": character_id} if character_id else {})
+    async def get_asset_count(self, character_id: Optional[str] = None, created_by: Optional[str] = None, project_id: Optional[str] = None) -> int:
+        filter = {}
+        if character_id:
+            filter["linked_char_id"] = character_id
+        if created_by:
+            filter["created_by"] = created_by
+        if project_id is None:
+            filter["project_id"] = None
+        if project_id:
+            filter["project_id"] = project_id
+        return await self.collection.count_documents(filter)

     async def get_assets_by_ids(self, asset_ids: List[str]) -> List[Asset]:
-        object_ids = [ObjectId(asset_id) for asset_id in asset_ids]
+        object_ids = [ObjectId(asset_id) for asset_id in asset_ids if ObjectId.is_valid(asset_id)]
+        if not object_ids:
+            return []
         res = self.collection.find({"_id": {"$in": object_ids}}, {"data": 0})  # Exclude data but maybe allow thumbnail if small?
         # Original excluded thumbnail too.
         assets = []
@@ -184,6 +207,61 @@ class AssetsRepo:
         res = await self.collection.delete_one({"_id": ObjectId(asset_id)})
         return res.deleted_count > 0
async def soft_delete_and_purge_assets(self, asset_ids: List[str]) -> int:
"""
        Soft-deletes the assets and permanently removes their files from S3.
        Returns the number of assets processed.
"""
if not asset_ids:
return 0
object_ids = [ObjectId(aid) for aid in asset_ids if ObjectId.is_valid(aid)]
if not object_ids:
return 0
        # Find assets that have not been deleted yet
cursor = self.collection.find(
{"_id": {"$in": object_ids}, "is_deleted": {"$ne": True}},
{"minio_object_name": 1, "minio_thumbnail_object_name": 1}
)
purged_count = 0
ids_to_update = []
async for doc in cursor:
ids_to_update.append(doc["_id"])
            # Permanently delete the files from S3
if self.s3:
if doc.get("minio_object_name"):
try:
await self.s3.delete_file(doc["minio_object_name"])
except Exception as e:
logger.error(f"Failed to delete S3 object {doc['minio_object_name']}: {e}")
if doc.get("minio_thumbnail_object_name"):
try:
await self.s3.delete_file(doc["minio_thumbnail_object_name"])
except Exception as e:
logger.error(f"Failed to delete S3 thumbnail {doc['minio_thumbnail_object_name']}: {e}")
purged_count += 1
        # Soft delete + clear the S3 references
if ids_to_update:
await self.collection.update_many(
{"_id": {"$in": ids_to_update}},
{
"$set": {
"is_deleted": True,
"minio_object_name": None,
"minio_thumbnail_object_name": None,
"updated_at": datetime.now(UTC)
}
}
)
return purged_count
     async def migrate_to_minio(self) -> dict:
         """Migrates data and thumbnails from Mongo to MinIO."""
         if not self.s3:
@@ -203,7 +281,8 @@ class AssetsRepo:
             created_at = doc.get("created_at")
             ts = int(created_at.timestamp()) if created_at else 0
-            object_name = f"{type_}/{ts}_{asset_id}_{name}"
+            uid = uuid4().hex[:8]
+            object_name = f"{type_}/{ts}_{uid}_{asset_id}_{name}"
             if await self.s3.upload_file(object_name, data):
                 await self.collection.update_one(
                     {"_id": asset_id},
@@ -230,7 +309,8 @@ class AssetsRepo:
             created_at = doc.get("created_at")
             ts = int(created_at.timestamp()) if created_at else 0
-            thumb_name = f"{type_}/thumbs/{ts}_{asset_id}_{name}_thumb.jpg"
+            uid = uuid4().hex[:8]
+            thumb_name = f"{type_}/thumbs/{ts}_{uid}_{asset_id}_{name}_thumb.jpg"
             if await self.s3.upload_file(thumb_name, thumb):
                 await self.collection.update_one(
                     {"_id": asset_id},
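The `ts`-plus-`uuid4` naming that this diff threads through every upload path exists to prevent S3 key collisions when two assets share a name and a creation second. A standalone sketch of the scheme (helper name hypothetical, not repo code):

```python
from datetime import datetime, timezone
from uuid import uuid4

def object_key(asset_type: str, name: str, created_at: datetime) -> str:
    # Timestamp keeps keys roughly sortable; the 8-hex-char uuid fragment de-duplicates
    ts = int(created_at.timestamp())
    uid = uuid4().hex[:8]
    return f"{asset_type}/{ts}_{uid}_{name}"

now = datetime.now(timezone.utc)
k1 = object_key("image", "cat.png", now)
k2 = object_key("image", "cat.png", now)
print(k1 != k2)  # True: identical name and second, still distinct keys
```

Eight hex chars give 2^32 possibilities, so accidental collisions within one second are vanishingly unlikely for this workload.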

repos/char_repo.py

@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Optional
 from bson import ObjectId
 from motor.motor_asyncio import AsyncIOMotorClient
@@ -12,32 +12,37 @@ class CharacterRepo:
     async def add_character(self, character: Character) -> Character:
         op = await self.collection.insert_one(character.model_dump())
-        character.id = op.inserted_id
+        character.id = str(op.inserted_id)
         return character

-    async def get_character(self, character_id: str, with_image_data: bool = False) -> Character | None:
-        args = {}
-        if not with_image_data:
-            args["character_image_data"] = 0
-        res = await self.collection.find_one({"_id": ObjectId(character_id)}, args)
+    async def get_character(self, character_id: str) -> Character | None:
+        res = await self.collection.find_one({"_id": ObjectId(character_id)})
         if res is None:
             return None
         else:
             res["id"] = str(res.pop("_id"))
             return Character(**res)

-    async def get_all_characters(self) -> List[Character]:
-        docs = await self.collection.find({}, {"character_image_data": 0}).to_list(None)
-        characters = []
-        for doc in docs:
-            # Convert the ObjectId to a string and put it into the id field
-            doc["id"] = str(doc.pop("_id"))
-            # Create the object
-            characters.append(Character(**doc))
-        return characters
-
-    async def update_char(self, char_id: str, character: Character) -> None:
-        await self.collection.update_one({"_id": ObjectId(char_id)}, {"$set": character.model_dump()})
+    async def get_all_characters(self, created_by: Optional[str] = None, project_id: Optional[str] = None, limit: int = 100, offset: int = 0) -> List[Character]:
+        filter = {}
+        if created_by:
+            filter["created_by"] = created_by
+        if project_id is None:
+            filter["project_id"] = None
+        if project_id:
+            filter["project_id"] = project_id
+        res = await self.collection.find(filter).skip(offset).limit(limit).to_list(None)
+        chars = []
+        for doc in res:
+            doc["id"] = str(doc.pop("_id"))
+            chars.append(Character(**doc))
+        return chars
+
+    async def update_char(self, char_id: str, character: Character) -> bool:
+        result = await self.collection.update_one({"_id": ObjectId(char_id)}, {"$set": character.model_dump()})
+        return result.modified_count > 0
+
+    async def delete_character(self, char_id: str) -> bool:
+        result = await self.collection.delete_one({"_id": ObjectId(char_id)})
+        return result.deleted_count > 0


@@ -4,6 +4,13 @@ from repos.assets_repo import AssetsRepo
 from repos.char_repo import CharacterRepo
 from repos.generation_repo import GenerationRepo
 from repos.user_repo import UsersRepo
+from repos.albums_repo import AlbumsRepo
+from repos.project_repo import ProjectRepo
+from repos.idea_repo import IdeaRepo
+from repos.post_repo import PostRepo
+from repos.environment_repo import EnvironmentRepo
+from repos.inspiration_repo import InspirationRepo
+from repos.settings_repo import SettingsRepo
 from typing import Optional
@@ -14,3 +21,11 @@ class DAO:
         self.chars = CharacterRepo(client, db_name)
         self.assets = AssetsRepo(client, s3_adapter, db_name)
         self.generations = GenerationRepo(client, db_name)
+        self.albums = AlbumsRepo(client, db_name)
+        self.projects = ProjectRepo(client, db_name)
+        self.users = UsersRepo(client, db_name)
+        self.ideas = IdeaRepo(client, db_name)
+        self.posts = PostRepo(client, db_name)
+        self.environments = EnvironmentRepo(client, db_name)
+        self.inspirations = InspirationRepo(client, db_name)
+        self.settings = SettingsRepo(client, db_name)

repos/environment_repo.py Normal file

@@ -0,0 +1,73 @@
from typing import List, Optional
from datetime import datetime, UTC
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient
from models.Environment import Environment

class EnvironmentRepo:
    def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
        self.collection = client[db_name]["environments"]

    async def create_env(self, env: Environment) -> Environment:
        env_dict = env.model_dump(exclude={"id"})
        res = await self.collection.insert_one(env_dict)
        env.id = str(res.inserted_id)
        return env

    async def get_env(self, env_id: str) -> Optional[Environment]:
        res = await self.collection.find_one({"_id": ObjectId(env_id)})
        if not res:
            return None
        res["id"] = str(res.pop("_id"))
        return Environment(**res)

    async def get_character_envs(self, character_id: str) -> List[Environment]:
        cursor = self.collection.find({"character_id": character_id})
        envs = []
        async for doc in cursor:
            doc["id"] = str(doc.pop("_id"))
            envs.append(Environment(**doc))
        return envs

    async def update_env(self, env_id: str, update_data: dict) -> bool:
        # datetime.now(UTC) keeps timestamps tz-aware (datetime.utcnow() is naive and deprecated)
        update_data["updated_at"] = datetime.now(UTC)
        res = await self.collection.update_one(
            {"_id": ObjectId(env_id)},
            {"$set": update_data}
        )
        return res.modified_count > 0

    async def delete_env(self, env_id: str) -> bool:
        res = await self.collection.delete_one({"_id": ObjectId(env_id)})
        return res.deleted_count > 0

    async def add_asset(self, env_id: str, asset_id: str) -> bool:
        res = await self.collection.update_one(
            {"_id": ObjectId(env_id)},
            {
                "$addToSet": {"asset_ids": asset_id},
                "$set": {"updated_at": datetime.now(UTC)}
            }
        )
        return res.modified_count > 0

    async def add_assets(self, env_id: str, asset_ids: List[str]) -> bool:
        res = await self.collection.update_one(
            {"_id": ObjectId(env_id)},
            {
                "$addToSet": {"asset_ids": {"$each": asset_ids}},
                "$set": {"updated_at": datetime.now(UTC)}
            }
        )
        return res.modified_count > 0

    async def remove_asset(self, env_id: str, asset_id: str) -> bool:
        res = await self.collection.update_one(
            {"_id": ObjectId(env_id)},
            {
                "$pull": {"asset_ids": asset_id},
                "$set": {"updated_at": datetime.now(UTC)}
            }
        )
        return res.modified_count > 0

repos/generation_repo.py

@@ -1,4 +1,5 @@
-from typing import Optional, List
+from typing import Any, Optional, List
+from datetime import datetime, timedelta, UTC

 from PIL.ImageChops import offset
 from bson import ObjectId
@@ -16,7 +17,7 @@ class GenerationRepo:
         res = await self.collection.insert_one(generation.model_dump())
         return str(res.inserted_id)

-    async def get_generation(self, generation_id: str) -> Optional[Generation]:
+    async def get_generation(self, generation_id: str) -> Generation | None:
         res = await self.collection.find_one({"_id": ObjectId(generation_id)})
         if res is None:
             return None
@@ -25,14 +26,32 @@ class GenerationRepo:
         return Generation(**res)

     async def get_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
-                              limit: int = 10, offset: int = 10) -> List[Generation]:
-        filter = {"is_deleted": False}
+                              limit: int = 10, offset: int = 0, created_by: Optional[str] = None, project_id: Optional[str] = None,
+                              idea_id: Optional[str] = None, only_liked_by: Optional[str] = None) -> List[Generation]:
+        filter: dict[str, Any] = {"is_deleted": False}
         if character_id is not None:
             filter["linked_character_id"] = character_id
         if status is not None:
             filter["status"] = status
-        res = await self.collection.find(filter).sort("created_at", -1).skip(
+        if created_by is not None:
+            filter["created_by"] = created_by
+        # If filtering by created_by user (e.g. "My Generations"), we typically imply personal scope if project_id is None.
+        # But if a project_id is passed, we filter by that.
+        if project_id is None:
+            filter["project_id"] = None
+        if project_id is not None:
+            filter["project_id"] = project_id
+        if idea_id is not None:
+            filter["idea_id"] = idea_id
+        if only_liked_by is not None:
+            filter["liked_by"] = only_liked_by
+        # If fetching for an idea, sort by created_at ascending (chronological);
+        # otherwise descending (newest first).
+        sort_order = 1 if idea_id else -1
+        res = await self.collection.find(filter).sort("created_at", sort_order).skip(
             offset).limit(limit).to_list(None)
         generations: List[Generation] = []
         for generation in res:
@@ -40,13 +59,259 @@ class GenerationRepo:
             generations.append(Generation(**generation))
         return generations

-    async def count_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None) -> int:
+    async def count_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
+                                album_id: Optional[str] = None, created_by: Optional[str] = None, project_id: Optional[str] = None,
+                                idea_id: Optional[str] = None, only_liked_by: Optional[str] = None) -> int:
         args = {}
         if character_id is not None:
             args["linked_character_id"] = character_id
         if status is not None:
             args["status"] = status
+        if created_by is not None:
+            args["created_by"] = created_by
+        if project_id is None:
+            args["project_id"] = None
+        if project_id is not None:
+            args["project_id"] = project_id
+        if idea_id is not None:
+            args["idea_id"] = idea_id
+        if album_id is not None:
+            args["album_id"] = album_id
+        if only_liked_by is not None:
+            args["liked_by"] = only_liked_by
         return await self.collection.count_documents(args)

+    async def get_generations_by_ids(self, generation_ids: List[str]) -> List[Generation]:
+        object_ids = [ObjectId(gen_id) for gen_id in generation_ids if ObjectId.is_valid(gen_id)]
+        res = await self.collection.find({"_id": {"$in": object_ids}}).to_list(None)
+        generations: List[Generation] = []
+        # Maintain the order of generation_ids
+        gen_map = {str(doc["_id"]): doc for doc in res}
+        for gen_id in generation_ids:
+            doc = gen_map.get(gen_id)
+            if doc:
+                doc["id"] = str(doc.pop("_id"))
+                generations.append(Generation(**doc))
+        return generations
+
     async def update_generation(self, generation: Generation):
         res = await self.collection.update_one({"_id": ObjectId(generation.id)}, {"$set": generation.model_dump()})
async def toggle_like(self, generation_id: str, user_id: str) -> bool | None:
"""
Toggles like for a user on a generation.
Returns True if liked, False if unliked, None if generation not found.
"""
if not ObjectId.is_valid(generation_id):
return None
oid = ObjectId(generation_id)
# Check if generation exists
gen = await self.collection.find_one({"_id": oid}, {"liked_by": 1})
if not gen:
return None
if user_id in gen.get("liked_by", []):
# Unlike
await self.collection.update_one(
{"_id": oid},
{"$pull": {"liked_by": user_id}}
)
return False
else:
# Like
await self.collection.update_one(
{"_id": oid},
{"$addToSet": {"liked_by": user_id}}
)
return True
async def mark_nsfw(self, generation_id: str, is_nsfw: bool) -> bool:
if not ObjectId.is_valid(generation_id):
return False
res = await self.collection.update_one(
{"_id": ObjectId(generation_id)},
{"$set": {"nsfw": is_nsfw}}
)
return res.modified_count > 0
async def get_usage_stats(self, created_by: Optional[str] = None, project_id: Optional[str] = None) -> dict:
"""
Calculates usage statistics (runs, tokens, cost) using MongoDB aggregation.
Includes even soft-deleted generations to reflect actual expenditure.
"""
pipeline = []
# 1. Match all done generations (including soft-deleted)
match_stage = {"status": GenerationStatus.DONE}
if created_by:
match_stage["created_by"] = created_by
if project_id:
match_stage["project_id"] = project_id
pipeline.append({"$match": match_stage})
# 2. Group by null (total)
pipeline.append({
"$group": {
"_id": None,
"total_runs": {"$sum": 1},
"total_tokens": {
"$sum": {
"$cond": [
{"$and": [{"$gt": ["$input_token_usage", 0]}, {"$gt": ["$output_token_usage", 0]}]},
{"$add": ["$input_token_usage", "$output_token_usage"]},
{"$ifNull": ["$token_usage", 0]}
]
}
},
"total_input_tokens": {"$sum": {"$ifNull": ["$input_token_usage", 0]}},
"total_output_tokens": {"$sum": {"$ifNull": ["$output_token_usage", 0]}},
"total_cost": {
"$sum": {
"$add": [
{"$multiply": [{"$ifNull": ["$input_token_usage", 0]}, 0.000002]},
{"$multiply": [{"$ifNull": ["$output_token_usage", 0]}, 0.00012]}
]
}
}
}
})
cursor = self.collection.aggregate(pipeline)
res = await cursor.to_list(1)
if not res:
return {
"total_runs": 0,
"total_tokens": 0,
"total_input_tokens": 0,
"total_output_tokens": 0,
"total_cost": 0.0
}
result = res[0]
result.pop("_id")
result["total_cost"] = round(result["total_cost"], 4)
return result
async def get_usage_breakdown(self, group_by: str = "created_by", project_id: Optional[str] = None, created_by: Optional[str] = None) -> List[dict]:
"""
Returns usage statistics grouped by user or project.
Includes even soft-deleted generations to reflect actual expenditure.
"""
pipeline = []
match_stage = {"status": GenerationStatus.DONE}
if project_id:
match_stage["project_id"] = project_id
if created_by:
match_stage["created_by"] = created_by
pipeline.append({"$match": match_stage})
pipeline.append({
"$group": {
"_id": f"${group_by}",
"total_runs": {"$sum": 1},
"total_tokens": {
"$sum": {
"$cond": [
{"$and": [{"$gt": ["$input_token_usage", 0]}, {"$gt": ["$output_token_usage", 0]}]},
{"$add": ["$input_token_usage", "$output_token_usage"]},
{"$ifNull": ["$token_usage", 0]}
]
}
},
"total_input_tokens": {"$sum": {"$ifNull": ["$input_token_usage", 0]}},
"total_output_tokens": {"$sum": {"$ifNull": ["$output_token_usage", 0]}},
"total_cost": {
"$sum": {
"$add": [
{"$multiply": [{"$ifNull": ["$input_token_usage", 0]}, 0.000002]},
{"$multiply": [{"$ifNull": ["$output_token_usage", 0]}, 0.00012]}
]
}
}
}
})
pipeline.append({"$sort": {"total_cost": -1}})
cursor = self.collection.aggregate(pipeline)
res = await cursor.to_list(None)
results = []
for item in res:
entity_id = item.pop("_id")
item["total_cost"] = round(item["total_cost"], 4)
results.append({
"entity_id": str(entity_id) if entity_id else "unknown",
"stats": item
})
return results
async def get_generations_by_group(self, group_id: str) -> List[Generation]:
res = await self.collection.find({"generation_group_id": group_id, "is_deleted": False}).sort("created_at", 1).to_list(None)
generations: List[Generation] = []
for generation in res:
generation["id"] = str(generation.pop("_id"))
generations.append(Generation(**generation))
return generations
async def cancel_stale_generations(self, timeout_minutes: int = 5) -> int:
cutoff_time = datetime.now(UTC) - timedelta(minutes=timeout_minutes)
res = await self.collection.update_many(
{
"status": GenerationStatus.RUNNING,
"created_at": {"$lt": cutoff_time}
},
{
"$set": {
"status": GenerationStatus.FAILED,
"failed_reason": "Timeout: Execution time limit exceeded",
"updated_at": datetime.now(UTC)
}
}
)
return res.modified_count
async def soft_delete_old_generations(self, days: int = 2) -> tuple[int, List[str]]:
"""
        Soft-deletes generations older than N days.
        Returns (number deleted, list of asset IDs to purge).
"""
cutoff_time = datetime.now(UTC) - timedelta(days=days)
filter_query = {
"is_deleted": False,
"status": {"$in": [GenerationStatus.DONE, GenerationStatus.FAILED]},
"created_at": {"$lt": cutoff_time}
}
        # First collect the asset IDs from the generations being deleted
asset_ids: List[str] = []
cursor = self.collection.find(filter_query, {"result_list": 1, "assets_list": 1})
async for doc in cursor:
asset_ids.extend(doc.get("result_list", []))
        # Soft delete
res = await self.collection.update_many(
filter_query,
{
"$set": {
"is_deleted": True,
"updated_at": datetime.now(UTC)
}
}
)
        # Deduplicate
unique_asset_ids = list(set(asset_ids))
return res.modified_count, unique_asset_ids
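The `$cond` in both aggregation stages of `get_usage_stats` and `get_usage_breakdown` encodes one rule: when both per-direction counters are present and positive, total tokens is their sum; otherwise fall back to the legacy `token_usage` field. A hypothetical pure-Python equivalent of that per-document rule, for reference only:

```python
def doc_total_tokens(doc: dict) -> int:
    # Mirrors the $cond: prefer input+output counters, else legacy token_usage
    # ($ifNull in the pipeline plays the role of the `or 0` fallbacks here)
    inp = doc.get("input_token_usage") or 0
    out = doc.get("output_token_usage") or 0
    if inp > 0 and out > 0:
        return inp + out
    return doc.get("token_usage") or 0

print(doc_total_tokens({"input_token_usage": 120, "output_token_usage": 30}))  # 150
print(doc_total_tokens({"token_usage": 90}))                                   # 90
print(doc_total_tokens({"input_token_usage": 120}))                            # 0 (no legacy field)
```

Keeping the fallback in the pipeline rather than in Python lets Mongo compute totals server-side over documents that predate the split counters.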

repos/idea_repo.py Normal file

@@ -0,0 +1,91 @@
from typing import Optional, List
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient
from models.Idea import Idea
class IdeaRepo:
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
self.collection = client[db_name]["ideas"]
async def create_idea(self, idea: Idea) -> str:
res = await self.collection.insert_one(idea.model_dump())
return str(res.inserted_id)
async def get_idea(self, idea_id: str) -> Optional[Idea]:
if not ObjectId.is_valid(idea_id):
return None
res = await self.collection.find_one({"_id": ObjectId(idea_id)})
if res:
res["id"] = str(res.pop("_id"))
return Idea(**res)
return None
async def get_ideas(self, project_id: Optional[str], user_id: str, limit: int = 20, offset: int = 0) -> List[dict]:
if project_id:
match_stage = {"project_id": project_id, "is_deleted": False}
else:
match_stage = {"created_by": user_id, "project_id": None, "is_deleted": False}
pipeline = [
{"$match": match_stage},
{"$sort": {"updated_at": -1}},
{"$skip": offset},
{"$limit": limit},
# Add string id field for lookup
{"$addFields": {"str_id": {"$toString": "$_id"}}},
# Lookup generations
{
"$lookup": {
"from": "generations",
"let": {"idea_id": "$str_id"},
"pipeline": [
{
"$match": {
"$and": [
{"$expr": {"$eq": ["$idea_id", "$$idea_id"]}},
{"status": "done"},
{"result_list": {"$exists": True, "$not": {"$size": 0}}},
{"is_deleted": False}
]
}
},
{"$sort": {"created_at": -1}}, # Ensure we get the latest successful
{"$limit": 1}
],
"as": "generations"
}
},
# Unwind generations array (preserve ideas without generations)
{"$unwind": {"path": "$generations", "preserveNullAndEmptyArrays": True}},
# Rename for clarity
{"$addFields": {
"last_generation": "$generations",
"id": "$str_id"
}},
{"$project": {"generations": 0, "str_id": 0, "_id": 0}}
]
return await self.collection.aggregate(pipeline).to_list(None)
async def delete_idea(self, idea_id: str) -> bool:
if not ObjectId.is_valid(idea_id):
return False
res = await self.collection.update_one(
{"_id": ObjectId(idea_id)},
{"$set": {"is_deleted": True}}
)
return res.modified_count > 0
async def update_idea(self, idea: Idea) -> bool:
if not idea.id or not ObjectId.is_valid(idea.id):
return False
idea_dict = idea.model_dump()
if "id" in idea_dict:
del idea_dict["id"]
res = await self.collection.update_one(
{"_id": ObjectId(idea.id)},
{"$set": idea_dict}
)
return res.modified_count > 0
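The `get_ideas` pipeline above joins a string foreign key (`idea_id` on generations) against the ideas' ObjectId converted via `$toString`. Stripped of the domain-specific filters, the general shape of that join looks like this (a sketch; collection and field names match the code above):

```python
def build_joined_pipeline(match_stage: dict, limit: int, offset: int) -> list[dict]:
    """Join child docs that store the parent's _id as a string FK."""
    return [
        {"$match": match_stage},
        {"$skip": offset},
        {"$limit": limit},
        # Parent _id is an ObjectId; children store it as a string.
        {"$addFields": {"str_id": {"$toString": "$_id"}}},
        {"$lookup": {
            "from": "generations",
            "let": {"idea_id": "$str_id"},
            # $expr is required to compare a field against the let-variable.
            "pipeline": [{"$match": {"$expr": {"$eq": ["$idea_id", "$$idea_id"]}}}],
            "as": "generations",
        }},
    ]

pipeline = build_joined_pipeline({"is_deleted": False}, limit=20, offset=0)
```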

repos/inspiration_repo.py Normal file

@@ -0,0 +1,54 @@
from typing import List, Optional
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient
from models.Inspiration import Inspiration
class InspirationRepo:
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
self.collection = client[db_name]["inspirations"]
async def create_inspiration(self, inspiration: Inspiration) -> str:
res = await self.collection.insert_one(inspiration.model_dump(exclude={"id"}))
return str(res.inserted_id)
async def get_inspiration(self, inspiration_id: str) -> Optional[Inspiration]:
if not ObjectId.is_valid(inspiration_id):
return None
res = await self.collection.find_one({"_id": ObjectId(inspiration_id)})
if res:
res["id"] = str(res.pop("_id"))
return Inspiration(**res)
return None
async def get_inspirations(self, project_id: Optional[str] = None, created_by: Optional[str] = None, limit: int = 20, offset: int = 0) -> List[Inspiration]:
query = {}
if project_id:
query["project_id"] = project_id
if created_by:
query["created_by"] = created_by
cursor = self.collection.find(query).sort("created_at", -1).skip(offset).limit(limit)
inspirations = []
async for doc in cursor:
doc["id"] = str(doc.pop("_id"))
inspirations.append(Inspiration(**doc))
return inspirations
async def count_inspirations(self, project_id: Optional[str] = None, created_by: Optional[str] = None) -> int:
query = {}
if project_id:
query["project_id"] = project_id
if created_by:
query["created_by"] = created_by
return await self.collection.count_documents(query)
async def update_inspiration(self, inspiration: Inspiration):
if not inspiration.id or not ObjectId.is_valid(inspiration.id):
return
await self.collection.update_one(
{"_id": ObjectId(inspiration.id)},
{"$set": inspiration.model_dump(exclude={"id"})}
)
async def delete_inspiration(self, inspiration_id: str) -> bool:
if not ObjectId.is_valid(inspiration_id):
return False
res = await self.collection.delete_one({"_id": ObjectId(inspiration_id)})
return res.deleted_count > 0

repos/post_repo.py Normal file

@@ -0,0 +1,97 @@
from typing import List, Optional
from datetime import datetime
import logging
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient
from models.Post import Post
logger = logging.getLogger(__name__)
class PostRepo:
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
self.collection = client[db_name]["posts"]
async def create_post(self, post: Post) -> str:
res = await self.collection.insert_one(post.model_dump())
return str(res.inserted_id)
async def get_post(self, post_id: str) -> Optional[Post]:
if not ObjectId.is_valid(post_id):
return None
res = await self.collection.find_one({"_id": ObjectId(post_id), "is_deleted": False})
if res:
res["id"] = str(res.pop("_id"))
return Post(**res)
return None
async def get_posts(
self,
project_id: Optional[str],
user_id: str,
limit: int = 20,
offset: int = 0,
date_from: Optional[datetime] = None,
date_to: Optional[datetime] = None,
) -> List[Post]:
if project_id:
match = {"project_id": project_id, "is_deleted": False}
else:
match = {"created_by": user_id, "project_id": None, "is_deleted": False}
if date_from or date_to:
date_filter = {}
if date_from:
date_filter["$gte"] = date_from
if date_to:
date_filter["$lte"] = date_to
match["date"] = date_filter
cursor = (
self.collection.find(match)
.sort("date", -1)
.skip(offset)
.limit(limit)
)
posts = []
async for doc in cursor:
doc["id"] = str(doc.pop("_id"))
posts.append(Post(**doc))
return posts
async def update_post(self, post_id: str, data: dict) -> bool:
if not ObjectId.is_valid(post_id):
return False
res = await self.collection.update_one(
{"_id": ObjectId(post_id)},
{"$set": data},
)
return res.modified_count > 0
async def delete_post(self, post_id: str) -> bool:
if not ObjectId.is_valid(post_id):
return False
res = await self.collection.update_one(
{"_id": ObjectId(post_id)},
{"$set": {"is_deleted": True}},
)
return res.modified_count > 0
async def add_generations(self, post_id: str, generation_ids: List[str]) -> bool:
if not ObjectId.is_valid(post_id):
return False
res = await self.collection.update_one(
{"_id": ObjectId(post_id)},
{"$addToSet": {"generation_ids": {"$each": generation_ids}}},
)
return res.modified_count > 0
async def remove_generation(self, post_id: str, generation_id: str) -> bool:
if not ObjectId.is_valid(post_id):
return False
res = await self.collection.update_one(
{"_id": ObjectId(post_id)},
{"$pull": {"generation_ids": generation_id}},
)
return res.modified_count > 0

repos/project_repo.py Normal file

@@ -0,0 +1,62 @@
from typing import List, Optional
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient
from models.Project import Project
class ProjectRepo:
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
self.collection = client[db_name]["projects"]
async def create_project(self, project: Project) -> str:
res = await self.collection.insert_one(project.model_dump())
return str(res.inserted_id)
async def get_project(self, project_id: str) -> Optional[Project]:
if not ObjectId.is_valid(project_id):
return None
res = await self.collection.find_one({"_id": ObjectId(project_id)})
if res:
res["id"] = str(res.pop("_id"))
return Project(**res)
return None
async def get_projects_by_user(self, user_id: str) -> List[Project]:
# Find projects where user is owner OR in members
filter = {
"$or": [
{"owner_id": user_id},
{"members": user_id}
],
"is_deleted": False
}
cursor = self.collection.find(filter).sort("created_at", -1)
projects = []
async for doc in cursor:
doc["id"] = str(doc.pop("_id"))
projects.append(Project(**doc))
return projects
async def add_member(self, project_id: str, user_id: str) -> bool:
res = await self.collection.update_one(
{"_id": ObjectId(project_id)},
{"$addToSet": {"members": user_id}}
)
return res.modified_count > 0
async def remove_member(self, project_id: str, user_id: str) -> bool:
res = await self.collection.update_one(
{"_id": ObjectId(project_id)},
{"$pull": {"members": user_id}}
)
return res.modified_count > 0
async def update_project(self, project_id: str, updates: dict) -> bool:
res = await self.collection.update_one(
{"_id": ObjectId(project_id)},
{"$set": updates}
)
return res.modified_count > 0
async def delete_project(self, project_id: str) -> bool:
res = await self.collection.update_one({"_id": ObjectId(project_id)}, {"$set": {"is_deleted": True}})
return res.modified_count > 0

repos/settings_repo.py Normal file

@@ -0,0 +1,26 @@
from typing import Optional
from motor.motor_asyncio import AsyncIOMotorClient
from models.Settings import SystemSettings
from datetime import datetime, UTC
class SettingsRepo:
def __init__(self, client: AsyncIOMotorClient, db_name: str):
self.collection = client[db_name]["settings"]
async def get_settings(self) -> SystemSettings:
doc = await self.collection.find_one({"_id": "system_settings"})
if not doc:
# Create default settings if not exists
settings = SystemSettings()
await self.collection.insert_one(settings.model_dump(by_alias=True))
return settings
return SystemSettings(**doc)
async def update_settings(self, settings: SystemSettings) -> bool:
settings.updated_at = datetime.now(UTC)
result = await self.collection.replace_one(
{"_id": "system_settings"},
settings.model_dump(by_alias=True),
upsert=True
)
return result.modified_count > 0 or result.upserted_id is not None


@@ -19,10 +19,16 @@ class UsersRepo:
self.collection = client[db_name]["users"]
async def get_user(self, user_id: int):
user = await self.collection.find_one({"user_id": user_id})
if user:
user["id"] = str(user["_id"])
return user
async def get_user_by_username(self, username: str):
user = await self.collection.find_one({"username": username})
if user:
user["id"] = str(user["_id"])
return user
async def create_user(self, username: str, password: str, full_name: Optional[str] = None):
"""Creates a new user with a username/password pair"""
@@ -38,15 +44,23 @@ class UsersRepo:
"created_at": datetime.now(),
"is_email_user": False,  # Now simply a "regular" (non-Telegram) user; the field could be renamed
"is_web_user": True,
"is_admin": False,
"project_ids": [],
"current_project_id": None
}
result = await self.collection.insert_one(user_doc)
user = await self.collection.find_one({"_id": result.inserted_id})
if user:
user["id"] = str(user["_id"])
return user
async def get_pending_users(self):
"""Returns the list of users with PENDING status"""
cursor = self.collection.find({"status": UserStatus.PENDING})
users = await cursor.to_list(length=100)
for user in users:
user["id"] = str(user["_id"])
return users
async def approve_user(self, username: str):
await self.collection.update_one(


@@ -50,3 +50,6 @@ passlib[argon2]==1.7.4
python-jose[cryptography]==3.3.0
python-multipart==0.0.22
email-validator
prometheus-fastapi-instrumentator
pydantic-settings==2.13.0
yt-dlp


@@ -51,56 +51,66 @@ async def new_char_bio(message: Message, state: FSMContext, dao: DAO, bot: Bot):
wait_msg = await message.answer("💾 Сохраняю персонажа...")
try:
# 1. Download the file (once)
# TODO: for large files, prefer streaming or saving to a temporary file
file_io = await bot.download(file_id)
file_bytes = file_io.read()
# 2. Create the Character first (without the asset) to obtain its ID
char = Character(
id=None,
name=name,
character_image_tg_id=None,
character_image_doc_tg_id=file_id,
character_bio=bio,
created_by=str(message.from_user.id)
)
file_io.close()
# Save to obtain the ID
await dao.chars.add_character(char)
# 3. Create the Asset (linked to the character)
avatar_asset_id = await dao.assets.create_asset(
Asset(
name="avatar.png",
type=AssetType.UPLOADED,
content_type=AssetContentType.IMAGE,
linked_char_id=str(char.id),
data=file_bytes,
tg_doc_file_id=file_id
)
)
# 4. Update the character with references to the asset
char.avatar_asset_id = avatar_asset_id
char.avatar_image = f"/api/assets/{avatar_asset_id}"  # Build the link manually, or via a helper method if one appears
# Send confirmation, reusing the downloaded bytes
photo_msg = await message.answer_photo(
photo=BufferedInputFile(file_bytes, filename="char.jpg"),
caption=(
"🎉 <b>Персонаж создан!</b>\n\n"
f"👤 <b>Имя:</b> {char.name}\n"
f"📝 <b>Био:</b> {char.character_bio}"
)
)
# Store the TG file ID of the photo (sent as a photo, not as a document)
char.character_image_tg_id = photo_msg.photo[-1].file_id
# Final character update
await dao.chars.update_char(char.id, char)
await wait_msg.delete()
# Reset the state
await state.clear()
except Exception as e:
logger.error(f"Error creating character: {e}")
traceback.print_exc()
await wait_msg.edit_text(f"❌ Ошибка при сохранении: {e}")
# Keep the state so the user can retry the bio or start over
@router.message(Command("chars"))


@@ -51,7 +51,7 @@ async def generate_image_cmd(message: Message, state: FSMContext, dao: DAO, gemi
await wait_msg.delete()
doc = await message.answer_document(res[0], caption="Generated result 💫")
await dao.assets.create_asset(Asset(id=None, name=res[0].filename, type=AssetType.GENERATED, content_type=AssetContentType.IMAGE, data=res[0].data,
tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None, linked_char_id=None, created_by=str(message.from_user.id)))
@router.message(Command("gen_mode"))
@@ -126,12 +126,11 @@ async def change_char(call: CallbackQuery, state: FSMContext, dao: DAO):
@router.callback_query(States.gen_mode, F.data == 'gen_mode_change_aspect_ratio')
async def gen_mode_change_aspect_ratio(call: CallbackQuery, state: FSMContext, dao: DAO):
await call.answer()
buttons = [InlineKeyboardButton(text=ratio.value, callback_data=f'select_ratio_{ratio.name}') for ratio in AspectRatios]
keyboard_rows = [buttons[i:i + 4] for i in range(0, len(buttons), 4)]
keyboard_rows.append([InlineKeyboardButton(text="⬅️ Назад", callback_data="gen_mode_cancel_ratio_change")])
await call.message.edit_caption(caption="Выбери соотношение сторон",
reply_markup=InlineKeyboardMarkup(inline_keyboard=keyboard_rows))
@router.callback_query(States.gen_mode, F.data.startswith('select_ratio_'))
@@ -259,7 +258,8 @@ async def handle_album(
doc = await message.answer_document(file, caption="✨ Generated result")
await dao.assets.create_asset(Asset(id=None, name=file.filename, type=AssetType.GENERATED, content_type=AssetContentType.IMAGE, data=file.data,
tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None,
linked_char_id=data["char_id"],
created_by=str(message.from_user.id)))
else:
await message.answer("❌ Генерация не вернула изображений.")
await gen_mode_base_msg(message=message, state=state, dao=dao, call_type="start")
@@ -314,7 +314,8 @@ async def gen_mode_start(
doc = await message.answer_document(file, caption="✨ Generated result")
await dao.assets.create_asset(Asset(id=None, name=file.filename, type=AssetType.GENERATED, content_type=AssetContentType.IMAGE, data=file.data,
tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None,
linked_char_id=data["char_id"],
created_by=str(message.from_user.id)))
else:
await message.answer("❌ Ничего не сгенерировалось.")
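The refactored aspect-ratio keyboard above relies on a slice-based chunking idiom to lay buttons out four per row. Isolated from aiogram, the idiom behaves like this:

```python
def chunk(items: list, size: int) -> list[list]:
    # Split a flat list into rows of at most `size` elements
    return [items[i:i + size] for i in range(0, len(items), size)]

print(chunk(["1:1", "16:9", "9:16", "4:3", "3:4"], 4))  # [['1:1', '16:9', '9:16', '4:3'], ['3:4']]
```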

scheduler/__init__.py Normal file


@@ -0,0 +1,456 @@
import asyncio
import logging
from datetime import datetime, timezone, timedelta
from typing import Any, Dict, Optional, Tuple
from aiogram import Bot
from aiogram.types import BufferedInputFile, InlineKeyboardButton, InlineKeyboardMarkup
from adapters.google_adapter import GoogleAdapter
from adapters.ai_proxy_adapter import AIProxyAdapter
from adapters.s3_adapter import S3Adapter
from api.service.generation_service import GenerationService
from models.Asset import Asset
from models.Generation import Generation, GenerationStatus
from models.enums import AspectRatios, ImageModel, Quality, TextModel
from repos.dao import DAO
logger = logging.getLogger(__name__)
MSK_TZ = timezone(timedelta(hours=3))
SCHEDULE_HOUR_MSK = 11
SCHEDULE_MINUTE_MSK = 0
# Callback data prefixes for inline keyboard buttons
CB_POST = "daily_post"
CB_REGEN_ALL = "daily_regen_all"
CB_REGEN_IMG = "daily_regen_img"
CB_REGEN_MORE = "daily_regen_more"
CB_CANCEL = "daily_cancel"
def make_admin_keyboard(generation_id: str) -> InlineKeyboardMarkup:
return InlineKeyboardMarkup(
inline_keyboard=[
[
InlineKeyboardButton(text="✅ Выложить", callback_data=f"{CB_POST}:{generation_id}"),
InlineKeyboardButton(text="❌ Отмена", callback_data=f"{CB_CANCEL}:{generation_id}"),
],
[
InlineKeyboardButton(text="🔄 Перегенерить с нуля", callback_data=f"{CB_REGEN_ALL}:{generation_id}"),
InlineKeyboardButton(text="🖼 Перегенерить изображение", callback_data=f"{CB_REGEN_IMG}:{generation_id}"),
],
[
InlineKeyboardButton(text="Сгенерировать ещё 2", callback_data=f"{CB_REGEN_MORE}:{generation_id}"),
],
]
)
class DailyScheduler:
"""Orchestrates the daily AI-character content generation pipeline.
Flow:
1. Generate image prompt + social caption via LLM (with character avatar).
2. Generate image via GenerationService.create_generation() (reuses existing pipeline).
3. Send to Telegram admin with action buttons.
Admin actions (inline keyboard):
- Выложить → post to Instagram feed + story via Meta API.
- Перегенерить с нуля → restart from step 1.
- Перегенерить изображение → restart from step 2 (same prompt/caption).
- Сгенерировать ещё 2 → generate 2 pose-varied images.
- Отмена → dismiss (no action).
"""
def __init__(
self,
dao: DAO,
gemini: GoogleAdapter,
s3_adapter: S3Adapter,
generation_service: GenerationService,
bot: Bot,
admin_id: int,
character_id: str,
meta_adapter=None, # Optional[MetaAdapter]
):
self.dao = dao
self.gemini = gemini
self.ai_proxy = AIProxyAdapter()
self.s3_adapter = s3_adapter
self.generation_service = generation_service
self.bot = bot
self.admin_id = admin_id
self.character_id = character_id
self.meta_adapter = meta_adapter
# Stores session state keyed by generation_id.
# Each value: {prompt, caption, asset_id, message_id, chat_id}
self.pending_sessions: Dict[str, Dict[str, Any]] = {}
# ------------------------------------------------------------------
# Scheduler loop
# ------------------------------------------------------------------
async def run_loop(self):
"""Run indefinitely, triggering daily generation at 11:00 MSK."""
logger.info("Daily scheduler loop started")
while True:
try:
await self._wait_until_next_run()
logger.info("Daily scheduler: triggering daily generation")
await self.run_daily_generation()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Daily scheduler loop error: {e}", exc_info=True)
async def _wait_until_next_run(self):
now = datetime.now(MSK_TZ)
next_run = now.replace(
hour=SCHEDULE_HOUR_MSK,
minute=SCHEDULE_MINUTE_MSK,
second=0,
microsecond=0,
)
if now >= next_run:
next_run += timedelta(days=1)
wait_seconds = (next_run - now).total_seconds()
logger.info(
f"Next daily generation at {next_run.strftime('%Y-%m-%d %H:%M MSK')} "
f"(in {wait_seconds / 3600:.1f}h)"
)
await asyncio.sleep(wait_seconds)
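The wait computation in `_wait_until_next_run` can be factored into a pure function, which makes the rollover-to-tomorrow logic easy to verify (a sketch mirroring the code above):

```python
from datetime import datetime, timedelta, timezone

MSK_TZ = timezone(timedelta(hours=3))

def seconds_until_next_run(now: datetime, hour: int = 11, minute: int = 0) -> float:
    next_run = now.replace(hour=hour, minute=minute, second=0, microsecond=0)
    if now >= next_run:
        # Today's slot already passed: schedule for tomorrow
        next_run += timedelta(days=1)
    return (next_run - now).total_seconds()

# 10:00 MSK: one hour until the 11:00 run
print(seconds_until_next_run(datetime(2026, 3, 1, 10, 0, tzinfo=MSK_TZ)))  # 3600.0
```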
# ------------------------------------------------------------------
# Main generation pipeline
# ------------------------------------------------------------------
async def run_daily_generation(self):
"""Full pipeline: prompt → image → send to admin."""
try:
prompt, caption = await self._generate_prompt_and_caption()
logger.info(f"Prompt generated ({len(prompt)} chars), caption ({len(caption)} chars)")
generation, asset = await self._generate_image_and_save(prompt)
logger.info(f"Generation done: id={generation.id}, asset={asset.id}")
await self._send_to_admin(generation, asset, prompt, caption)
except Exception as e:
logger.error(f"Daily generation pipeline failed: {e}", exc_info=True)
try:
await self.bot.send_message(
chat_id=self.admin_id,
text=f"❌ <b>Ежедневная генерация провалилась:</b>\n<code>{e}</code>",
)
except Exception:
pass
# ------------------------------------------------------------------
# Step 1 — Generate prompt + caption via LLM
# ------------------------------------------------------------------
async def _generate_prompt_and_caption(self) -> Tuple[str, str]:
"""Ask Gemini to produce an image prompt and social caption.
Passes the character's avatar photo to the model so it can create
a prompt that is faithful to the character's appearance.
"""
char = await self.dao.chars.get_character(self.character_id)
if not char:
raise ValueError(f"Character {self.character_id} not found in DB")
avatar_bytes_list: list[bytes] = []
if char.avatar_asset_id:
avatar_asset = await self.dao.assets.get_asset(char.avatar_asset_id, with_data=True)
if avatar_asset and avatar_asset.data:
avatar_bytes_list.append(avatar_asset.data)
char_bio = char.character_bio or "An expressive, stylish AI character."
system_prompt = (
f"You are a creative director for the social media account of an AI character named '{char.name}'.\n"
# f"Character description: {char_bio}\n\n"
"I'm attaching the character's avatar photo. Based on it, produce TWO things:\n\n"
"1. IMAGE_PROMPT: A detailed, vivid image generation prompt in English. "
"Describe the pose, environment, lighting, color palette, and artistic style. It must look amateur. "
"Make it unique and suitable for a social media post.\n\n"
"2. SOCIAL_CAPTION: An engaging caption in English for Instagram and TikTok. "
"Include 5-10 relevant hashtags at the end.\n\n"
"Reply in EXACTLY this format (two lines, no extra text before IMAGE_PROMPT):\n"
"IMAGE_PROMPT: <prompt here>\n"
"SOCIAL_CAPTION: <caption here>"
)
settings = await self.dao.settings.get_settings()
if settings.use_ai_proxy:
asset_urls = await self._prepare_asset_urls([char.avatar_asset_id]) if char.avatar_asset_id else None
raw = await self.ai_proxy.generate_text(
system_prompt,
TextModel.GEMINI_3_1_PRO_PREVIEW.value,
asset_urls
)
else:
raw = await asyncio.to_thread(
self.gemini.generate_text,
system_prompt,
TextModel.GEMINI_3_1_PRO_PREVIEW.value,
avatar_bytes_list or None,
)
logger.debug(f"LLM raw response: {raw[:500]}")
prompt, caption = self._parse_prompt_and_caption(raw, char.name)
return prompt, caption
async def _prepare_asset_urls(self, asset_ids: list[str]) -> list[str]:
assets = await self.dao.assets.get_assets_by_ids(asset_ids)
urls = []
for asset in assets:
if asset.minio_object_name:
bucket = asset.minio_bucket or self.s3_adapter.bucket_name
urls.append(f"{bucket}/{asset.minio_object_name}")
return urls
@staticmethod
def _parse_prompt_and_caption(raw: str, char_name: str) -> Tuple[str, str]:
prompt = ""
caption = ""
if "IMAGE_PROMPT:" in raw and "SOCIAL_CAPTION:" in raw:
after_label = raw.split("IMAGE_PROMPT:", 1)[1]
prompt = after_label.split("SOCIAL_CAPTION:", 1)[0].strip()
caption = after_label.split("SOCIAL_CAPTION:", 1)[1].strip()
elif "IMAGE_PROMPT:" in raw:
prompt = raw.split("IMAGE_PROMPT:", 1)[1].strip()
else:
prompt = raw.strip()
if not prompt:
raise ValueError(f"LLM did not produce IMAGE_PROMPT. Raw snippet: {raw[:300]}")
if not caption:
caption = f"✨ Новый контент от {char_name}"
return prompt, caption
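As a usage sketch, the parser above splits on the two labels; a standalone equivalent of the happy path (both labels present, `IMAGE_PROMPT` first) is:

```python
def parse_labeled_response(raw: str) -> tuple[str, str]:
    # Mirrors the happy path of _parse_prompt_and_caption
    after = raw.split("IMAGE_PROMPT:", 1)[1]
    prompt = after.split("SOCIAL_CAPTION:", 1)[0].strip()
    caption = after.split("SOCIAL_CAPTION:", 1)[1].strip()
    return prompt, caption

raw = "IMAGE_PROMPT: a candid rooftop photo at dusk\nSOCIAL_CAPTION: chasing sunsets #ai"
print(parse_labeled_response(raw))  # ('a candid rooftop photo at dusk', 'chasing sunsets #ai')
```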
# ------------------------------------------------------------------
# Step 2 — Generate image via GenerationService
# ------------------------------------------------------------------
async def _generate_image_and_save(
self,
prompt: str,
variation_hint: Optional[str] = None,
) -> Tuple[Generation, Asset]:
"""Create a Generation record and delegate execution to GenerationService.
Uses GenerationService.create_generation() which handles:
- loading character avatar / reference assets
- calling Gemini image generation
- saving result as Asset in S3
- finalizing the Generation record with metrics
No telegram_id is set, so the service won't send its own notification —
we handle that ourselves in _send_to_admin() with action buttons.
"""
actual_prompt = prompt
if variation_hint:
actual_prompt = f"{prompt}. {variation_hint}"
# Create Generation record (GenerationService.create_generation expects it pre-saved)
generation = Generation(
status=GenerationStatus.RUNNING,
linked_character_id=self.character_id,
aspect_ratio=AspectRatios.NINESIXTEEN,
quality=Quality.ONEK,
prompt=actual_prompt,
model=ImageModel.GEMINI_3_PRO_IMAGE_PREVIEW.value,
use_profile_image=True,
# No telegram_id → service won't send its own notification
)
gen_id = await self.dao.generations.create_generation(generation)
generation.id = gen_id
try:
# Delegate all heavy lifting to the existing service
await self.generation_service.create_generation(generation)
except Exception as e:
# create_generation doesn't mark FAILED itself; the caller (_queued_generation_runner) does,
# so we handle the failure here, preserving the original exception.
await self.generation_service._handle_generation_failure(generation, e)
raise
# After create_generation, generation.result_list is populated
if not generation.result_list:
raise ValueError("Generation completed but produced no assets")
asset = await self.dao.assets.get_asset(generation.result_list[0], with_data=False)
if not asset:
raise ValueError(f"Asset {generation.result_list[0]} not found after generation")
return generation, asset
# ------------------------------------------------------------------
# Step 3 — Send to admin
# ------------------------------------------------------------------
async def _send_to_admin(
self,
generation: Generation,
asset: Asset,
prompt: str,
caption: str,
):
img_data = await self.s3_adapter.get_file(asset.minio_object_name)
if not img_data:
raise ValueError(f"Cannot load image from S3: {asset.minio_object_name}")
self.pending_sessions[generation.id] = {
"prompt": prompt,
"caption": caption,
"asset_id": asset.id,
}
msg = await self.bot.send_photo(
chat_id=self.admin_id,
photo=BufferedInputFile(img_data, filename="daily.png"),
caption=(
f"📸 <b>Ежедневная генерация</b>\n\n"
f"<b>Подпись для соцсетей:</b>\n{caption}\n\n"
f"<b>Промпт:</b>\n<code>{prompt[:300]}</code>"
),
reply_markup=make_admin_keyboard(generation.id),
)
self.pending_sessions[generation.id]["message_id"] = msg.message_id
self.pending_sessions[generation.id]["chat_id"] = msg.chat.id
# ------------------------------------------------------------------
# Admin action handlers (called from Telegram callback router)
# ------------------------------------------------------------------
async def handle_post(self, generation_id: str, message_id: int, chat_id: int):
"""Post to Instagram feed + story."""
session = self.pending_sessions.get(generation_id)
if not session:
return
if not self.meta_adapter:
await self.bot.edit_message_caption(
chat_id=chat_id,
message_id=message_id,
caption="⚠️ Meta API не настроен (META_ACCESS_TOKEN не задан). Публикация недоступна.",
)
return
try:
asset = await self.dao.assets.get_asset(session["asset_id"], with_data=False)
if not asset or not asset.minio_object_name:
raise ValueError("Asset not found in DB")
image_url = await self.s3_adapter.get_presigned_url(
asset.minio_object_name, expiration=3600
)
if not image_url:
raise ValueError("Could not generate presigned URL for image")
feed_id = await self.meta_adapter.post_to_feed(image_url, session["caption"])
story_id = await self.meta_adapter.post_to_story(image_url)
self.pending_sessions.pop(generation_id, None)
await self.bot.edit_message_caption(
chat_id=chat_id,
message_id=message_id,
caption=(
f"✅ <b>Опубликовано!</b>\n\n"
f"📰 Feed ID: <code>{feed_id}</code>\n"
f"📖 Story ID: <code>{story_id}</code>"
),
)
except Exception as e:
logger.error(f"Meta publish failed for generation {generation_id}: {e}", exc_info=True)
await self.bot.edit_message_caption(
chat_id=chat_id,
message_id=message_id,
caption=f"❌ <b>Ошибка публикации:</b>\n<code>{e}</code>",
reply_markup=make_admin_keyboard(generation_id),
)
async def handle_regen_all(self, generation_id: str, message_id: int, chat_id: int):
"""Restart from step 1: generate new prompt, caption, and image."""
self.pending_sessions.pop(generation_id, None)
await self.bot.edit_message_caption(
chat_id=chat_id,
message_id=message_id,
caption="🔄 <b>Перегенерация с нуля...</b>",
)
asyncio.create_task(self._run_regen_all(chat_id))
        # NOTE: fire-and-forget; the loop holds only a weak reference to this task.
async def _run_regen_all(self, chat_id: int):
try:
await self.run_daily_generation()
except Exception as e:
logger.error(f"Regen-all failed: {e}", exc_info=True)
await self.bot.send_message(chat_id=chat_id, text=f"❌ Ошибка перегенерации:\n<code>{e}</code>")
async def handle_regen_image(self, generation_id: str, message_id: int, chat_id: int):
"""Restart from step 2: generate new image using existing prompt/caption."""
session = self.pending_sessions.pop(generation_id, None)
if not session:
return
prompt = session["prompt"]
caption = session["caption"]
await self.bot.edit_message_caption(
chat_id=chat_id,
message_id=message_id,
caption="🖼 <b>Перегенерация изображения...</b>",
)
asyncio.create_task(self._run_regen_image(prompt, caption, chat_id))
async def _run_regen_image(self, prompt: str, caption: str, chat_id: int):
try:
generation, asset = await self._generate_image_and_save(prompt)
await self._send_to_admin(generation, asset, prompt, caption)
except Exception as e:
logger.error(f"Regen-image failed: {e}", exc_info=True)
await self.bot.send_message(chat_id=chat_id, text=f"❌ Ошибка генерации:\n<code>{e}</code>")
async def handle_regen_more(self, generation_id: str, message_id: int, chat_id: int):
"""Generate 2 more pose-varied images using the existing prompt/caption."""
session = self.pending_sessions.get(generation_id)
if not session:
return
prompt = session["prompt"]
caption = session["caption"]
await self.bot.edit_message_caption(
chat_id=chat_id,
message_id=message_id,
caption="<b>Генерирую ещё 2 варианта...</b>",
)
asyncio.create_task(self._run_regen_more(prompt, caption, chat_id))
async def _run_regen_more(self, prompt: str, caption: str, chat_id: int):
variation_hints = [
"Slightly vary the pose and camera angle while keeping the same scene, environment and lighting.",
"Try a different subtle pose or expression, same background and setting as described.",
]
for i, hint in enumerate(variation_hints):
try:
generation, asset = await self._generate_image_and_save(prompt, variation_hint=hint)
await self._send_to_admin(generation, asset, prompt, caption)
except Exception as e:
logger.error(f"Regen-more variant {i + 1} failed: {e}", exc_info=True)
await self.bot.send_message(
chat_id=chat_id,
text=f"❌ Ошибка варианта {i + 1}:\n<code>{e}</code>",
)
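The `asyncio.create_task(...)` calls in the regen handlers above are fire-and-forget. The event loop keeps only a weak reference to tasks, so a long-running regeneration task can in principle be garbage-collected mid-flight if nothing else references it. A common mitigation is a small wrapper that holds a strong reference until the task completes — the `spawn` helper below is a hypothetical sketch, not code from the repo:

```python
import asyncio

_background_tasks: set = set()

def spawn(coro) -> asyncio.Task:
    """create_task wrapper that keeps a strong reference until the task finishes."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    # discard removes the reference once the task is done (success or failure)
    task.add_done_callback(_background_tasks.discard)
    return task

async def _demo() -> int:
    async def work() -> int:
        await asyncio.sleep(0)
        return 42
    return await spawn(work())

result = asyncio.run(_demo())
```

Swapping `asyncio.create_task(self._run_regen_all(chat_id))` for such a helper would make the regeneration pipeline robust against that edge case without changing its fire-and-forget semantics.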
async def handle_cancel(self, generation_id: str, message_id: int, chat_id: int):
"""Dismiss: remove buttons, do nothing else."""
self.pending_sessions.pop(generation_id, None)
await self.bot.edit_message_caption(
chat_id=chat_id,
message_id=message_id,
caption="🚫 Отменено.",
)


@@ -0,0 +1,82 @@
"""Telegram inline-keyboard handlers for the daily scheduler admin flow.
Usage (in aiws.py):
from scheduler.telegram_admin_handler import create_daily_scheduler_router
from scheduler.daily_scheduler import DailyScheduler
daily_scheduler = DailyScheduler(...)
dp.include_router(create_daily_scheduler_router(daily_scheduler))
"""
import logging
from aiogram import F, Router
from aiogram.types import CallbackQuery
from scheduler.daily_scheduler import (
CB_CANCEL,
CB_POST,
CB_REGEN_ALL,
CB_REGEN_IMG,
CB_REGEN_MORE,
DailyScheduler,
)
logger = logging.getLogger(__name__)
def create_daily_scheduler_router(scheduler: DailyScheduler) -> Router:
"""Return an aiogram Router with all callback handlers bound to *scheduler*."""
router = Router(name="daily_scheduler")
@router.callback_query(F.data.startswith(CB_POST + ":"))
async def on_post(callback: CallbackQuery):
generation_id = callback.data.split(":", 1)[1]
await callback.answer("Публикую в Instagram...")
await scheduler.handle_post(
generation_id=generation_id,
message_id=callback.message.message_id,
chat_id=callback.message.chat.id,
)
@router.callback_query(F.data.startswith(CB_REGEN_ALL + ":"))
async def on_regen_all(callback: CallbackQuery):
generation_id = callback.data.split(":", 1)[1]
await callback.answer("Перезапускаю с нуля...")
await scheduler.handle_regen_all(
generation_id=generation_id,
message_id=callback.message.message_id,
chat_id=callback.message.chat.id,
)
@router.callback_query(F.data.startswith(CB_REGEN_IMG + ":"))
async def on_regen_img(callback: CallbackQuery):
generation_id = callback.data.split(":", 1)[1]
await callback.answer("Генерирую новое изображение...")
await scheduler.handle_regen_image(
generation_id=generation_id,
message_id=callback.message.message_id,
chat_id=callback.message.chat.id,
)
@router.callback_query(F.data.startswith(CB_REGEN_MORE + ":"))
async def on_regen_more(callback: CallbackQuery):
generation_id = callback.data.split(":", 1)[1]
await callback.answer("Генерирую 2 варианта...")
await scheduler.handle_regen_more(
generation_id=generation_id,
message_id=callback.message.message_id,
chat_id=callback.message.chat.id,
)
@router.callback_query(F.data.startswith(CB_CANCEL + ":"))
async def on_cancel(callback: CallbackQuery):
generation_id = callback.data.split(":", 1)[1]
await callback.answer("Отменено")
await scheduler.handle_cancel(
generation_id=generation_id,
message_id=callback.message.message_id,
chat_id=callback.message.chat.id,
)
return router
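Every handler in the router above matches on a `CB_*` prefix plus `":"` and recovers the generation id with `callback.data.split(":", 1)[1]`, so the scheme only works while the action prefixes themselves contain no colon. A self-contained sketch of that round-trip — the constant's value is illustrative, the real one lives in `scheduler.daily_scheduler`:

```python
CB_POST = "daily_post"  # illustrative value; actual constant defined in daily_scheduler

def pack(action: str, generation_id: str) -> str:
    """Build callback_data as '<action>:<generation_id>'."""
    data = f"{action}:{generation_id}"
    if len(data.encode("utf-8")) > 64:  # Telegram caps callback_data at 64 bytes
        raise ValueError("callback_data exceeds Telegram's 64-byte limit")
    return data

def unpack(data: str) -> str:
    """Mirror of the callback.data.split(':', 1)[1] used in each handler."""
    return data.split(":", 1)[1]

packed = pack(CB_POST, "3f2a")
```

The 64-byte check matters here because `generation.id` rides inside the button payload; a long UUID string plus prefix stays under the limit, but it is worth guarding.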


@@ -0,0 +1,51 @@
import asyncio
from unittest.mock import MagicMock, AsyncMock, patch
from api.service.generation_service import GenerationService
from models.Settings import SystemSettings
from models.Generation import Generation
from models.enums import AspectRatios, Quality
async def test_generation_service_proxy_logic():
dao = MagicMock()
gemini = MagicMock()
s3_adapter = MagicMock()
# Mock settings to have proxy ENABLED
dao.settings.get_settings = AsyncMock(return_value=SystemSettings(use_ai_proxy=True))
dao.assets.get_assets_by_ids = AsyncMock(return_value=[])
service = GenerationService(dao, gemini, s3_adapter)
# 1. Test ask_prompt_assistant with proxy
with patch.object(service.ai_proxy, 'generate_text', new_callable=AsyncMock) as mock_proxy_text:
mock_proxy_text.return_value = "Proxy Result"
result = await service.ask_prompt_assistant("Test Prompt")
assert result == "Proxy Result"
mock_proxy_text.assert_called_once()
gemini.generate_text.assert_not_called()
# 2. Test create_generation with proxy
generation = Generation(
prompt="Test Image",
aspect_ratio=AspectRatios.ONEONE,
quality=Quality.ONEK,
assets_list=[]
)
# Mock _prepare_generation_input to avoid complex DB calls
service._prepare_generation_input = AsyncMock(return_value=([], "Test Image", []))
service._process_generated_images = AsyncMock(return_value=[])
service._finalize_generation = AsyncMock()
with patch.object(service.ai_proxy, 'generate_image', new_callable=AsyncMock) as mock_proxy_img:
import io
mock_img_io = io.BytesIO(b"fake image data")
mock_proxy_img.return_value = ([mock_img_io], {"api_execution_time_seconds": 1.0})
await service.create_generation(generation)
mock_proxy_img.assert_called_once()
        # In the non-proxy branch, gemini.generate_image would instead be reached via generate_image_task.
print("✅ Proxy logic test passed!")
if __name__ == "__main__":
asyncio.run(test_generation_service_proxy_logic())
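
The test above leans on `patch.object(..., new_callable=AsyncMock)` so that awaited adapter methods can be stubbed without touching real APIs. A self-contained illustration of the same pattern, independent of the service classes — the `Adapter` class here is a stand-in, not from the repo:

```python
import asyncio
from unittest.mock import AsyncMock, patch

class Adapter:
    async def generate_text(self, prompt: str) -> str:
        raise RuntimeError("real API call")  # never reached while patched

async def demo() -> str:
    adapter = Adapter()
    # patch.object swaps the coroutine method for an awaitable mock
    with patch.object(adapter, "generate_text", new_callable=AsyncMock) as mock_gen:
        mock_gen.return_value = "stubbed"
        result = await adapter.generate_text("hi")
    # AsyncMock records awaits, so await-specific assertions work
    mock_gen.assert_awaited_once_with("hi")
    return result

result = asyncio.run(demo())
```

A plain `MagicMock` would not work here: calling it returns a non-awaitable `MagicMock`, so the `await` inside the service would fail, which is why the test patches with `AsyncMock` specifically.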

Some files were not shown because too many files have changed in this diff.