77 Commits

Author  SHA1        Message  Date
xds  32ff77e04b  feat: Implement video generation functionality and integrate with Kling API.  2026-02-12 10:27:07 +03:00
xds  d1f67c773f  123  2026-02-12 00:25:08 +03:00
xds  c63b51ef75  123 (plus stray git commit-template text: "…er the commit message for your changes. Lines starting…")  2026-02-12 00:24:43 +03:00
xds  456562ec1d  main -> aiws  2026-02-12 00:13:06 +03:00
xds  0d0fbdf7d6  main -> aiws  2026-02-11 12:56:51 +03:00
xds  f63bcedb13  main -> aiws  2026-02-11 12:46:57 +03:00
xds  be92c766ac  main -> aiws  2026-02-11 12:46:35 +03:00
xds  482bc1d9b7  main -> aiws  2026-02-11 12:30:05 +03:00
xds  a2321cf070  + prometheus  2026-02-11 11:56:08 +03:00
xds  29ccd5743e  main -> aiws  2026-02-11 11:37:04 +03:00
xds  d9de2f48d2  main -> aiws  2026-02-11 11:19:50 +03:00
xds  1ddeb0af46  main -> aiws  2026-02-11 11:15:21 +03:00
xds  a7c2319f13  feat: Implement external generation import API secured by HMAC-SHA256 signature verification.  2026-02-10 14:06:37 +03:00
xds  00e83b8561  fix  2026-02-09 17:01:48 +03:00
xds  a9d24c725e  Update user repository implementation.  2026-02-09 16:16:55 +03:00
xds  458b6ebfc3  feat: Implement project management with new models, repositories, and API endpoints, and enhance character management with project association and DTOs.  2026-02-09 16:06:54 +03:00
xds  668aadcdc9  fix  2026-02-09 09:49:49 +03:00
xds  4461964791  feat: Add created_by and cost fields to generation models, populate created_by from the authenticated user, and implement cost calculation.  2026-02-09 01:52:23 +03:00
xds  fa3e1bb05f  refactor: Remove trailing slashes from album router endpoint paths.  2026-02-09 00:47:54 +03:00
xds  8a89b27624  feat: Add album management functionality with new data model, repository, service, API, and generation integration.  2026-02-08 23:13:31 +03:00
xds  c17c47ccc1  catch exception123  2026-02-08 22:56:08 +03:00
xds  c25b029006  123  2026-02-08 22:53:09 +03:00
xds  a449f65de9  123  2026-02-08 17:57:09 +03:00
xds  3cf7db5cdf  Merge branch 'main' of https://gitea.luminic.space/ai-char/ai-char-bot (# Conflicts: .DS_Store)  2026-02-08 17:53:19 +03:00
xds  288515fa04  123  2026-02-08 17:41:07 +03:00
     f1033210cc  Merge pull request 'auth' (#1) from auth into main (Reviewed-on: #1)  2026-02-08 14:37:30 +00:00
xds  1832d07caa  init  2026-02-08 17:36:52 +03:00
xds  b704707abc  init auth  2026-02-08 17:36:40 +03:00
xds  31893414eb  feat: Add pagination with total count to generation listings and enable filtering assets by type.  2026-02-08 02:13:59 +03:00
xds  aa50b1cc03  fix: prevent filtering by linked_character_id: None when character_id is None in generation queries.  2026-02-07 15:07:41 +03:00
xds  305ad24576  feat: Separate asset origin type from content type for improved asset categorization and handling.  2026-02-07 14:41:03 +03:00
xds  ce87ac7edb  fix  2026-02-06 22:30:55 +03:00
xds  2f8de7a298  huge fix  2026-02-06 21:54:25 +03:00
xds  b8e96a2dca  12  2026-02-06 19:53:53 +03:00
xds  137279bcc5  fix  2026-02-06 19:53:05 +03:00
xds  553335940f  fix  2026-02-06 19:05:52 +03:00
xds  fd1b023e7d  push  2026-02-06 19:01:46 +03:00
xds  eeea0f5b8f  213  2026-02-06 18:33:49 +03:00
xds  ac5cc53006  fix  2026-02-06 18:33:12 +03:00
xds  c3b13360e0  fix  2026-02-06 18:12:03 +03:00
xds  63292a1699  fix  2026-02-06 17:51:08 +03:00
xds  59c40524e0  fix  2026-02-06 14:23:12 +03:00
xds  cdb09e84fc  + s3  2026-02-06 14:07:10 +03:00
xds  37e69088a1  delete asset  2026-02-06 09:07:03 +03:00
xds  7e2f79aab1  fix  2026-02-05 22:51:36 +03:00
xds  c0debab0cb  feat: Add use_profile_image and detailed token usage fields to generation models.  2026-02-05 22:19:57 +03:00
xds  002c949f08  fix  2026-02-05 21:22:36 +03:00
xds  d4682b1418  feat: Add optional telegram_id field to Generation and GenerationRequest models.  2026-02-05 21:10:50 +03:00
xds  463e73fa1e  feat: Add deploy.sh script for automated remote deployment.  2026-02-05 20:54:49 +03:00
xds  76dd976854  feat: Implement image thumbnail generation, storage, and API endpoints for assets, including a regeneration utility.  2026-02-05 20:52:50 +03:00
xds  736e5a8c12  feat: Add logging to API endpoints, update generation response model, and refine project configurations.  2026-02-05 15:29:31 +03:00
xds  9ae6e8e08e  feat: Update generation service, models, and API endpoints, along with refining Google generation exception handling.  2026-02-05 15:28:53 +03:00
xds  bf8396a790  catch exception  2026-02-04 18:23:36 +03:00
xds  53b2bce1b2  fix dates  2026-02-04 17:36:06 +03:00
xds  fba18728d6  + logging  2026-02-04 16:45:25 +03:00
xds  c86dfa917d  + logging  2026-02-04 16:42:50 +03:00
xds  f36a368051  + logging  2026-02-04 15:59:29 +03:00
xds  c8984dc472  + logging  2026-02-04 15:57:57 +03:00
xds  b4f4ead3b3  + api  2026-02-04 15:54:39 +03:00
xds  35de8efc56  + api  2026-02-04 15:10:55 +03:00
xds  11c1f4f7dc  + api  2026-02-03 23:16:18 +03:00
xds  43e9c263d5  + api  2026-02-03 17:02:14 +03:00
xds  30daa1340a  + api  2026-02-03 16:51:23 +03:00
xds  e43cd575b0  + api  2026-02-03 16:21:15 +03:00
xds  cba813337e  + api  2026-02-03 16:14:04 +03:00
xds  b8b708c659  + api  2026-02-03 16:11:36 +03:00
xds  a1dc734cdb  + assets  2026-02-03 14:42:48 +03:00
xds  7050999ed8  + assets  2026-02-03 14:20:22 +03:00
xds  739f027742  + fixes  2026-02-03 10:27:36 +03:00
xds  f69e8f3c35  + fixes  2026-02-03 10:26:01 +03:00
xds  e8b91af804  + fixes  2026-02-03 10:24:07 +03:00
xds  befd1a66f7  + fixes  2026-02-03 10:18:27 +03:00
xds  2c310cae09  + fixes  2026-02-03 09:47:24 +03:00
xds  447107834c  + fixes  2026-02-03 09:46:30 +03:00
xds  21f86afa38  + fixes  2026-02-03 09:45:27 +03:00
xds  2693675e85  + fixes  2026-02-03 09:40:43 +03:00
xds  ea3f50db50  + fixes  2026-02-03 09:37:03 +03:00
67 changed files with 4383 additions and 192 deletions
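Commit a7c2319f13 mentions securing the external generation import API with HMAC-SHA256 signature verification. That endpoint's code is not included in this chunk; the following is only a minimal sketch of how such verification typically works with a shared secret like EXTERNAL_API_SECRET (function names are hypothetical, not taken from the repository):

```python
import hashlib
import hmac


def sign_payload(secret: str, body: bytes) -> str:
    """Compute the hex HMAC-SHA256 signature of a raw request body."""
    return hmac.new(secret.encode(), body, hashlib.sha256).hexdigest()


def verify_signature(secret: str, body: bytes, provided: str) -> bool:
    """Compare signatures in constant time to avoid timing attacks."""
    return hmac.compare_digest(sign_payload(secret, body), provided)


sig = sign_payload("test-secret", b'{"id": 1}')
print(verify_signature("test-secret", b'{"id": 1}', sig))  # True
print(verify_signature("test-secret", b'{"id": 2}', sig))  # False
```

The caller would send the signature in a header and the server would recompute it over the raw body before parsing; `hmac.compare_digest` is the important detail, since a plain `==` leaks timing information.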

.dockerignore (new file, +19 lines)

@@ -0,0 +1,19 @@
.git
.gitignore
__pycache__
*.pyc
*.pyo
*.pyd
.Python
env/
venv/
.venv/
node_modules/
tmp/
logs/
*.log
dist/
build/
.cache/
.idea/
.vscode/

.env (modified, +9 lines)

@@ -1,4 +1,13 @@
BOT_TOKEN=8495170789:AAHyjjhHwwVtd9_ROnjHqPHRdnmyVr1aeaY
# BOT_TOKEN=8011562605:AAF3kyzrZJgii0Jx-H8Sum5Njbo0BdbsiAo
GEMINI_API_KEY=AIzaSyAHzDYhgjOqZZnvOnOFRGaSkKu4OAN3kZE
MONGO_HOST=mongodb://admin:super_secure_password@31.59.58.220:27017/
ADMIN_ID=567047
MINIO_ENDPOINT=http://31.59.58.220:9000
MINIO_ACCESS_KEY=admin
MINIO_SECRET_KEY=SuperSecretPassword123!
MINIO_BUCKET=ai-char
MODE=production
EXTERNAL_API_SECRET=Gt9TyQ8OAYhcELh2YCbKjdHLflZGufKHJZcG338MQDW
KLING_ACCESS_KEY=AngRfYYeLhPQB3pmr9CpHfgHPCrmeeM4
KLING_SECRET_KEY=ndJfyayfQgbg4bMnE49yHnkACPChKMp4

.gitignore (vendored, new file, +11 lines)

@@ -0,0 +1,11 @@
minio_backup.tar.gz
.DS_Store
**/__pycache__/
*.py[cod]
*$py.class
*.cpython-*.pyc
**/.DS_Store
.idea/ai-char-bot.iml
.idea
.venv
.vscode

.idea/.gitignore (generated, vendored, new file, +8 lines)

@@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

.idea/ai-char-bot.iml (generated, new file, +10 lines)

@@ -0,0 +1,10 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
  <component name="NewModuleRootManager">
    <content url="file://$MODULE_DIR$">
      <excludeFolder url="file://$MODULE_DIR$/.venv" />
    </content>
    <orderEntry type="jdk" jdkName="Python 3.13 (ai-char-bot)" jdkType="Python SDK" />
    <orderEntry type="sourceFolder" forTests="false" />
  </component>
</module>

(filename not captured in export) (new file, +16 lines)

@@ -0,0 +1,16 @@
<component name="InspectionProjectProfileManager">
  <profile version="1.0">
    <option name="myName" value="Project Default" />
    <inspection_tool class="PyAssertTypeInspection" enabled="true" level="ERROR" enabled_by_default="true" editorAttributes="ERRORS_ATTRIBUTES" />
    <inspection_tool class="PyAsyncCallInspection" enabled="true" level="ERROR" enabled_by_default="true" editorAttributes="ERRORS_ATTRIBUTES" />
    <inspection_tool class="PyPep8NamingInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
      <option name="ignoredErrors">
        <list>
          <option value="N802" />
        </list>
      </option>
    </inspection_tool>
    <inspection_tool class="PyTypeCheckerInspection" enabled="true" level="ERROR" enabled_by_default="true" editorAttributes="ERRORS_ATTRIBUTES" />
    <inspection_tool class="PyUnreachableCodeInspection" enabled="true" level="ERROR" enabled_by_default="true" editorAttributes="ERRORS_ATTRIBUTES" />
  </profile>
</component>

(filename not captured in export) (new file, +6 lines)

@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
  <settings>
    <option name="USE_PROJECT_PROFILE" value="false" />
    <version value="1.0" />
  </settings>
</component>

.idea/misc.xml (generated, new file, +7 lines)

@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="Black">
    <option name="sdkName" value="Python 3.13 (ai-char-bot)" />
  </component>
  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.13 (ai-char-bot)" project-jdk-type="Python SDK" />
</project>

.idea/modules.xml (generated, new file, +8 lines)

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectModuleManager">
    <modules>
      <module fileurl="file://$PROJECT_DIR$/.idea/ai-char-bot.iml" filepath="$PROJECT_DIR$/.idea/ai-char-bot.iml" />
    </modules>
  </component>
</project>

.idea/vcs.xml (generated, new file, +6 lines)

@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="VcsDirectoryMappings">
    <mapping directory="$PROJECT_DIR$" vcs="Git" />
  </component>
</project>

.vscode/launch.json (vendored, new file, +46 lines)

@@ -0,0 +1,46 @@
{
  "version": "0.2.0",
  "configurations": [
    {
      "name": "Python Debugger: FastAPI",
      "type": "debugpy",
      "request": "launch",
      "module": "uvicorn",
      "args": [
        "aiws:app",
        "--reload",
        "--port",
        "8090",
        "--host",
        "0.0.0.0"
      ],
      "jinja": true,
      "justMyCode": true
    },
    {
      "name": "Python: Current File",
      "type": "debugpy",
      "request": "launch",
      "program": "${file}",
      "console": "integratedTerminal",
      "justMyCode": true,
      "env": {
        "PYTHONPATH": "${workspaceFolder}"
      }
    },
    {
      "name": "Debug Tests: Current File",
      "type": "debugpy",
      "request": "launch",
      "module": "pytest",
      "args": [
        "${file}"
      ],
      "console": "integratedTerminal",
      "justMyCode": true,
      "env": {
        "PYTHONPATH": "${workspaceFolder}"
      }
    }
  ]
}

(filename not captured in export; the content is a Dockerfile) (modified)

@@ -11,4 +11,4 @@ RUN pip install --no-cache-dir -r requirements.txt
 COPY . .
 # Start the application (replace app.py with your own file)
-CMD ["python", "main.py"]
+CMD ["uvicorn", "aiws:app", "--host", "0.0.0.0", "--port", "8090"]

adapters/Exception.py (new file, +4 lines)

@@ -0,0 +1,4 @@
class GoogleGenerationException(Exception):
def __init__(self, message: str):
self.message = message
super().__init__(message)

(filename not captured in export; the content matches adapters/google_adapter.py) (modified)

@@ -1,12 +1,13 @@
 import io
 import logging
 from datetime import datetime
-from typing import List, Union
+from typing import List, Union, Tuple, Dict, Any
 from PIL import Image
 from google import genai
 from google.genai import types
+from adapters.Exception import GoogleGenerationException
 from models.enums import AspectRatios, Quality

 logger = logging.getLogger(__name__)
@@ -26,6 +27,7 @@ class GoogleAdapter:
         """Helper method that prepares the content (text + images)"""
         contents = [prompt]
         if images_list:
+            logger.info(f"Preparing content with {len(images_list)} images")
             for img_bytes in images_list:
                 try:
                     # The Gemini API expects a PIL Image as input
@@ -33,6 +35,8 @@ class GoogleAdapter:
                     contents.append(image)
                 except Exception as e:
                     logger.error(f"Error processing input image: {e}")
+        else:
+            logger.info("Preparing content with no images")
         return contents

     def generate_text(self, prompt: str, images_list: List[bytes] = None) -> str:
@@ -41,7 +45,7 @@ class GoogleAdapter:
         Returns the response as a string.
         """
         contents = self._prepare_contents(prompt, images_list)
+        logger.info(f"Generating text: {prompt}")
         try:
             response = self.client.models.generate_content(
                 model=self.TEXT_MODEL,
@@ -58,19 +62,24 @@ class GoogleAdapter:
             for part in response.parts:
                 if part.text:
                     result_text += part.text
+            logger.info(f"Generated text length: {len(result_text)}")
             return result_text
         except Exception as e:
             logger.error(f"Gemini Text API Error: {e}")
-            return f"Ошибка генерации текста: {e}"
+            raise GoogleGenerationException(f"Gemini Text API Error: {e}")

-    def generate_image(self, prompt: str, aspect_ratio: AspectRatios, quality: Quality, images_list: List[bytes] = None) -> List[io.BytesIO]:
+    def generate_image(self, prompt: str, aspect_ratio: AspectRatios, quality: Quality, images_list: List[bytes] = None) -> Tuple[List[io.BytesIO], Dict[str, Any]]:
         """
         Image generation (Text-to-Image or Image-to-Image).
         Returns a list of byte streams (ready to send).
         """
         contents = self._prepare_contents(prompt, images_list)
+        logger.info(f"Generating image. Prompt length: {len(prompt)}, Ratio: {aspect_ratio}, Quality: {quality}")
+        start_time = datetime.now()
+        token_usage = 0
         try:
             response = self.client.models.generate_content(
@@ -80,12 +89,21 @@ class GoogleAdapter:
                     response_modalities=['IMAGE'],
                     temperature=1.0,
                     image_config=types.ImageConfig(
-                        aspect_ratio=aspect_ratio.value,
+                        aspect_ratio=aspect_ratio.value_ratio,
-                        image_size=quality.value
+                        image_size=quality.value_quality
                     ),
                 )
             )
+            end_time = datetime.now()
+            api_duration = (end_time - start_time).total_seconds()
+            if response.usage_metadata:
+                token_usage = response.usage_metadata.total_token_count
+            if response.parts is None and response.candidates[0].finish_reason is not None:
+                raise GoogleGenerationException(f"Generation blocked because of {response.candidates[0].finish_reason.value}")

             generated_images = []
             if response.parts:
@@ -108,9 +126,25 @@ class GoogleAdapter:
                 except Exception as e:
                     logger.error(f"Error processing output image: {e}")

-            return generated_images
+            if generated_images:
+                logger.info(f"Successfully generated {len(generated_images)} images in {api_duration:.2f}s. Tokens: {token_usage}")
+            else:
+                logger.warning("No images generated from parts")
+            input_tokens = 0
+            output_tokens = 0
+            if response.usage_metadata:
+                input_tokens = response.usage_metadata.prompt_token_count
+                output_tokens = response.usage_metadata.candidates_token_count
+            metrics = {
+                "api_execution_time_seconds": api_duration,
+                "token_usage": token_usage,
+                "input_token_usage": input_tokens,
+                "output_token_usage": output_tokens
+            }
+            return generated_images, metrics
         except Exception as e:
             logger.error(f"Gemini Image API Error: {e}")
-            # On error, return an empty list (or an exception could be raised)
-            return []
+            raise GoogleGenerationException(f"Gemini Image API Error: {e}")
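After this change, generate_image returns an (images, metrics) tuple and raises GoogleGenerationException instead of silently returning an empty list, so every caller needs updating. A sketch of the caller-side handling, with the adapter call replaced by a stub (the stub and its values are illustrative, not from the repository):

```python
import io


class GoogleGenerationException(Exception):
    """Mirrors adapters/Exception.py from the diff."""
    def __init__(self, message: str):
        self.message = message
        super().__init__(message)


def generate_image_stub(prompt: str):
    """Stub standing in for GoogleAdapter.generate_image after the refactor."""
    if not prompt:
        # The real adapter raises when generation is blocked or the API errors out
        raise GoogleGenerationException("Generation blocked because of PROHIBITED_CONTENT")
    metrics = {"api_execution_time_seconds": 1.2, "token_usage": 300,
               "input_token_usage": 250, "output_token_usage": 50}
    return [io.BytesIO(b"fake-png")], metrics


try:
    images, metrics = generate_image_stub("a cat")
    print(len(images), metrics["token_usage"])  # 1 300
except GoogleGenerationException as e:
    print("blocked:", e.message)
```

The key contract change is that failure is now an exception path rather than an empty-list sentinel, which lets the service layer distinguish "blocked" from "generated nothing".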

adapters/kling_adapter.py (new file, +165 lines)

@@ -0,0 +1,165 @@
import logging
import time
import asyncio
from typing import Optional, Dict, Any

import httpx
import jwt

logger = logging.getLogger(__name__)

KLING_API_BASE = "https://api.klingai.com"


class KlingApiException(Exception):
    pass


class KlingAdapter:
    def __init__(self, access_key: str, secret_key: str):
        if not access_key or not secret_key:
            raise ValueError("Kling API credentials are missing")
        self.access_key = access_key
        self.secret_key = secret_key

    def _generate_token(self) -> str:
        """Generate a JWT token for Kling API authentication."""
        now = int(time.time())
        payload = {
            "iss": self.access_key,
            "exp": now + 1800,  # 30 minutes
            "iat": now - 5,  # small backward margin for clock skew
            "nbf": now - 5,
        }
        return jwt.encode(payload, self.secret_key, algorithm="HS256",
                          headers={"typ": "JWT", "alg": "HS256"})

    def _headers(self) -> dict:
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self._generate_token()}"
        }

    async def create_video_task(
        self,
        image_url: str,
        prompt: str = "",
        negative_prompt: str = "",
        model_name: str = "kling-v2-6",
        duration: int = 5,
        mode: str = "std",
        cfg_scale: float = 0.5,
        aspect_ratio: str = "16:9",
        callback_url: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Create an image-to-video generation task.
        Returns the full task data dict including task_id.
        """
        body: Dict[str, Any] = {
            "model_name": model_name,
            "image": image_url,
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "duration": str(duration),
            "mode": mode,
            "cfg_scale": cfg_scale,
            "aspect_ratio": aspect_ratio,
        }
        if callback_url:
            body["callback_url"] = callback_url

        logger.info(f"Creating Kling video task. Model: {model_name}, Duration: {duration}s, Mode: {mode}")
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                f"{KLING_API_BASE}/v1/videos/image2video",
                headers=self._headers(),
                json=body,
            )
            data = response.json()
            logger.info(f"Kling create task response: code={data.get('code')}, message={data.get('message')}")
            if response.status_code != 200 or data.get("code") != 0:
                error_msg = data.get("message", "Unknown Kling API error")
                raise KlingApiException(f"Failed to create video task: {error_msg} (code={data.get('code')})")
            task_data = data.get("data", {})
            task_id = task_data.get("task_id")
            if not task_id:
                raise KlingApiException("No task_id returned from Kling API")
            logger.info(f"Kling video task created: task_id={task_id}")
            return task_data

    async def get_task_status(self, task_id: str) -> Dict[str, Any]:
        """
        Query the status of a video generation task.
        Returns the full task data dict.
        """
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.get(
                f"{KLING_API_BASE}/v1/videos/image2video/{task_id}",
                headers=self._headers(),
            )
            data = response.json()
            if response.status_code != 200 or data.get("code") != 0:
                error_msg = data.get("message", "Unknown error")
                raise KlingApiException(f"Failed to query task {task_id}: {error_msg}")
            return data.get("data", {})

    async def wait_for_completion(
        self,
        task_id: str,
        poll_interval: int = 10,
        timeout: int = 600,
        progress_callback=None,
    ) -> Dict[str, Any]:
        """
        Poll the task status until completion.
        Args:
            task_id: Kling task ID
            poll_interval: seconds between polls
            timeout: max seconds to wait
            progress_callback: async callable(progress_pct: int) to report progress
        Returns:
            Final task data dict with video URL on success.
        Raises:
            KlingApiException on failure or timeout.
        """
        start = time.time()
        attempt = 0
        while True:
            elapsed = time.time() - start
            if elapsed > timeout:
                raise KlingApiException(f"Video generation timed out after {timeout}s for task {task_id}")
            task_data = await self.get_task_status(task_id)
            status = task_data.get("task_status")
            logger.info(f"Kling task {task_id}: status={status}, elapsed={elapsed:.0f}s")
            if status == "succeed":
                logger.info(f"Kling task {task_id} completed successfully")
                return task_data
            if status == "failed":
                fail_reason = task_data.get("task_status_msg", "Unknown failure")
                raise KlingApiException(f"Video generation failed: {fail_reason}")
            # Report a progress estimate (linear approximation based on typical time)
            if progress_callback:
                # Estimate: a typical generation takes ~120s; cap at 90%
                estimated_progress = min(int((elapsed / 120) * 90), 90)
                attempt += 1
                await progress_callback(estimated_progress)
            await asyncio.sleep(poll_interval)
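The progress approximation in wait_for_completion (elapsed time scaled against a ~120 s typical generation, capped at 90% so the bar never claims completion before the API does) can be isolated into a pure function and checked on its own:

```python
def estimate_progress(elapsed_seconds: float, typical_duration: float = 120.0, cap: int = 90) -> int:
    """Linear progress estimate used while polling Kling, capped below 100%."""
    return min(int((elapsed_seconds / typical_duration) * cap), cap)


print(estimate_progress(0))    # 0
print(estimate_progress(60))   # 45
print(estimate_progress(300))  # 90
```

Extracting the formula like this (the helper name is hypothetical; the adapter inlines the expression) makes the polling loop testable without hitting the network.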

adapters/s3_adapter.py (new file, +81 lines)

@@ -0,0 +1,81 @@
from contextlib import asynccontextmanager
from typing import Optional, BinaryIO

import aioboto3
from botocore.exceptions import ClientError
import os


class S3Adapter:
    def __init__(self,
                 endpoint_url: str,
                 aws_access_key_id: str,
                 aws_secret_access_key: str,
                 bucket_name: str):
        self.endpoint_url = endpoint_url
        self.aws_access_key_id = aws_access_key_id
        self.aws_secret_access_key = aws_secret_access_key
        self.bucket_name = bucket_name
        self.session = aioboto3.Session()

    @asynccontextmanager
    async def _get_client(self):
        async with self.session.client(
            "s3",
            endpoint_url=self.endpoint_url,
            aws_access_key_id=self.aws_access_key_id,
            aws_secret_access_key=self.aws_secret_access_key,
        ) as client:
            yield client

    async def upload_file(self, object_name: str, data: bytes, content_type: Optional[str] = None):
        """Uploads bytes data to S3."""
        try:
            extra_args = {}
            if content_type:
                extra_args["ContentType"] = content_type
            async with self._get_client() as client:
                await client.put_object(
                    Bucket=self.bucket_name,
                    Key=object_name,
                    Body=data,
                    **extra_args
                )
            return True
        except ClientError as e:
            # logging.error(e)
            print(f"Error uploading to S3: {e}")
            return False

    async def get_file(self, object_name: str) -> Optional[bytes]:
        """Downloads a file from S3 and returns bytes."""
        try:
            async with self._get_client() as client:
                response = await client.get_object(Bucket=self.bucket_name, Key=object_name)
                return await response['Body'].read()
        except ClientError as e:
            print(f"Error downloading from S3: {e}")
            return None

    async def delete_file(self, object_name: str):
        """Deletes a file from S3."""
        try:
            async with self._get_client() as client:
                await client.delete_object(Bucket=self.bucket_name, Key=object_name)
            return True
        except ClientError as e:
            print(f"Error deleting from S3: {e}")
            return False

    async def get_presigned_url(self, object_name: str, expiration: int = 3600) -> Optional[str]:
        """Generate a presigned URL to share an S3 object."""
        try:
            async with self._get_client() as client:
                response = await client.generate_presigned_url(
                    'get_object',
                    Params={'Bucket': self.bucket_name, 'Key': object_name},
                    ExpiresIn=expiration
                )
            return response
        except ClientError as e:
            print(f"Error generating presigned URL: {e}")
            return None
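upload_file only sets ContentType when the caller passes one explicitly; without it, S3/MinIO typically serves objects as application/octet-stream and browsers will download rather than render them. A small helper (hypothetical, not present in the repository) can derive the content type from the object key using the standard library:

```python
import mimetypes
from typing import Optional


def guess_content_type(object_name: str) -> Optional[str]:
    """Derive a Content-Type for S3 uploads from the object key's extension."""
    content_type, _ = mimetypes.guess_type(object_name)
    return content_type  # None when the extension is unknown


print(guess_content_type("assets/thumb.png"))  # image/png
print(guess_content_type("videos/clip.mp4"))   # video/mp4
```

A caller could then do `await s3.upload_file(key, data, content_type=guess_content_type(key))` so presigned URLs render inline in the browser.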

aiws.py (new file, +254 lines)

@@ -0,0 +1,254 @@
import asyncio
import logging
import os
from contextlib import asynccontextmanager

from aiogram import Bot, Dispatcher, Router, F
from aiogram.client.default import DefaultBotProperties
from aiogram.enums import ParseMode
from aiogram.filters import CommandStart, Command
from aiogram.types import Message
from aiogram.fsm.storage.mongo import MongoStorage
from dotenv import load_dotenv
from fastapi import FastAPI
from motor.motor_asyncio import AsyncIOMotorClient
from prometheus_client import Info
from starlette.middleware.cors import CORSMiddleware
from prometheus_fastapi_instrumentator import Instrumentator

# --- PROJECT IMPORTS ---
from adapters.google_adapter import GoogleAdapter
from adapters.kling_adapter import KlingAdapter
from adapters.s3_adapter import S3Adapter
from api.service.generation_service import GenerationService
from api.service.album_service import AlbumService
from middlewares.album import AlbumMiddleware
from middlewares.auth import AuthMiddleware
from middlewares.dao import DaoMiddleware

# Repositories and DAO
from repos.char_repo import CharacterRepo
from repos.user_repo import UsersRepo
from repos.dao import DAO

# Routers
from routers.auth_router import router as auth_router
from routers.gen_router import router as gen_router
from routers.char_router import router as char_router
from routers.assets_router import router as assets_router  # Bot router for assets
from api.endpoints.assets_router import router as api_assets_router  # FastAPI router
from api.endpoints.character_router import router as api_char_router  # FastAPI router
from api.endpoints.generation_router import router as api_gen_router
from api.endpoints.auth import router as api_auth_router
from api.endpoints.admin import router as api_admin_router
from api.endpoints.album_router import router as api_album_router
from api.endpoints.project_router import router as project_api_router

load_dotenv()
logger = logging.getLogger(__name__)

# --- CONFIGURATION ---
BOT_TOKEN = os.getenv("BOT_TOKEN")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
MONGO_HOST = os.getenv("MONGO_HOST")  # e.g. mongodb://localhost:27017
DB_NAME = os.getenv("DB_NAME", "my_bot_db")  # Database name
ADMIN_ID = int(os.getenv("ADMIN_ID", 0))


def setup_logging():
    logging.basicConfig(level=logging.INFO,
                        format="%(asctime)s [%(levelname)s] %(name)s (%(filename)s:%(lineno)d): %(message)s",
                        force=True)


# --- DEPENDENCY INITIALIZATION ---
bot = Bot(token=BOT_TOKEN, default=DefaultBotProperties(parse_mode=ParseMode.HTML))

# The DB client is created globally so it is available to both the bot (Storage) and the API
mongo_client = AsyncIOMotorClient(MONGO_HOST)

# Repositories
users_repo = UsersRepo(mongo_client)
char_repo = CharacterRepo(mongo_client)

# S3 Adapter
s3_adapter = S3Adapter(
    endpoint_url=os.getenv("MINIO_ENDPOINT", "http://31.59.58.220:9000"),
    aws_access_key_id=os.getenv("MINIO_ACCESS_KEY", "minioadmin"),
    aws_secret_access_key=os.getenv("MINIO_SECRET_KEY", "minioadmin"),
    bucket_name=os.getenv("MINIO_BUCKET", "ai-char")
)

dao = DAO(mongo_client, s3_adapter)  # Main DAO for the bot
gemini = GoogleAdapter(api_key=GEMINI_API_KEY)

# Kling Adapter (optional, for video generation)
kling_access_key = os.getenv("KLING_ACCESS_KEY", "")
kling_secret_key = os.getenv("KLING_SECRET_KEY", "")
kling_adapter = None
if kling_access_key and kling_secret_key:
    kling_adapter = KlingAdapter(access_key=kling_access_key, secret_key=kling_secret_key)
    logger.info("Kling adapter initialized")
else:
    logger.warning("KLING_ACCESS_KEY / KLING_SECRET_KEY not set, video generation disabled")

generation_service = GenerationService(dao, gemini, s3_adapter, bot, kling_adapter)
album_service = AlbumService(dao)

# Dispatcher
dp = Dispatcher(storage=MongoStorage(mongo_client, db_name=DB_NAME))

# Dependency injection (global for the bot)
dp["repo"] = users_repo
dp["admin_id"] = ADMIN_ID
dp["gemini"] = gemini

# --- BOT ROUTER SETUP ---
# 1. Routers without middlewares (e.g. auth)
dp.include_router(auth_router)

# 2. Main routers
main_router = Router()
dp.include_router(main_router)
dp.include_router(assets_router)
dp.include_router(char_router)
dp.include_router(gen_router)

# --- BOT MIDDLEWARES ---
# DaoMiddleware injects the 'dao' object into all handlers
dp.update.middleware(DaoMiddleware(dao=dao))

# AuthMiddleware checks access rights
main_router.message.middleware(AuthMiddleware(repo=users_repo, admin_id=ADMIN_ID))
gen_router.message.middleware(AuthMiddleware(repo=users_repo, admin_id=ADMIN_ID))
assets_router.message.middleware(AuthMiddleware(repo=users_repo, admin_id=ADMIN_ID))

# AlbumMiddleware for handling photo groups
gen_router.message.middleware(AlbumMiddleware(latency=0.8))


# --- LIFESPAN (FastAPI + Bot startup) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    # --- STARTUP ---
    setup_logging()
    print("🚀 Starting up...")

    # 1. DAO setup for FastAPI
    # Reuse the already created mongo_client
    db = mongo_client[DB_NAME]

    # Initialize the DAO for assets and store it in the application state,
    # so endpoints can access request.app.state.assets_dao
    app.state.mongo_client = mongo_client
    app.state.gemini_client = gemini
    app.state.bot = bot
    app.state.s3_adapter = s3_adapter
    app.state.kling_adapter = kling_adapter
    app.state.album_service = album_service
    app.state.users_repo = users_repo  # Add the repository to state
    print("✅ DB & DAO initialized")

    # 2. BOT STARTUP (in the background)
    # Important: handle_signals=False so the bot does not intercept shutdown signals from uvicorn
    # We do NOT pass dao=... here, since it is already wired in through the middleware above
    # polling_task = asyncio.create_task(
    #     dp.start_polling(bot, handle_signals=False)
    # )
    # print("🤖 Bot polling started")

    yield

    # --- SHUTDOWN ---
    print("🛑 Shutting down...")

    # 3. Bot shutdown (kept commented out while polling above is disabled;
    # calling polling_task.cancel() without starting polling would raise NameError)
    # polling_task.cancel()
    # try:
    #     await polling_task
    # except asyncio.CancelledError:
    #     print("🤖 Bot polling stopped")

    # 4. DB disconnect
    # Motor does not strictly require closing on exit, but it is considered good form
    # mongo_client.close()
    print("🛑 DB Connection closed")


# --- FASTAPI SETUP ---
app = FastAPI(title="Assets API", lifespan=lifespan)

# CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Attach the API routers
app.include_router(api_auth_router)
app.include_router(api_admin_router)
app.include_router(api_assets_router)
app.include_router(api_char_router)
app.include_router(api_gen_router)
app.include_router(api_album_router)
app.include_router(project_api_router)

# Prometheus metrics (instrument after all routers are added)
Instrumentator(
    should_group_status_codes=False,  # report 200/201/204 separately (optional)
    should_ignore_untemplated=False,  # do NOT ignore "raw" paths
    # should_group_untemplated=False,  # (optional) do not collapse untemplated paths into "none"
).instrument(
    app,
    metric_namespace="ai_bot",
).expose(app, endpoint="/metrics", include_in_schema=False)

app_info = Info("fastapi_app_info", "FastAPI application info")
app_info.info({"app_name": "ai-bot"})


# --- BOT HANDLERS (Main Router) ---
@main_router.message(Command("help"))
async def show_help(message: Message) -> None:
    await message.answer(" <b>Справка:</b>\n\n"
                         "📝 <b>Текст:</b> Просто отправь промпт.\n"
                         "🎨 <b>Фото:</b> /image {промпт} (или прикрепи фото с подписью).\n\n"
                         "⚠️ Диалоги не сохраняются (каждое сообщение — новый запрос).")


@main_router.message(CommandStart())
async def cmd_start(message: Message):
    await message.answer("👋 Привет! Я готов к работе.\n\n"
                         "Напиши мне, что нужно сгенерировать, или используй /help.")


# --- ENTRY POINT ---
if __name__ == "__main__":
    import uvicorn

    setup_logging()

    async def main():
        # Build the uvicorn configuration manually
        # loop="asyncio" forces the standard event loop
        config = uvicorn.Config(app, host="0.0.0.0", port=8090, loop="asyncio", timeout_keep_alive=120, env_file=".env.development")
        server = uvicorn.Server(config)
        # Start the server (lifespan runs inside)
        await server.serve()

    try:
        # Run the loop ourselves, controlling the arguments
        asyncio.run(main())
    except KeyboardInterrupt:
        # Handle exit gracefully
        pass

api/__init__.py (new file, empty)

api/dependency.py (new file, +56 lines)

@@ -0,0 +1,56 @@
# dependency.py
from fastapi import Request, Depends
from motor.motor_asyncio import AsyncIOMotorClient
from adapters.google_adapter import GoogleAdapter
from adapters.kling_adapter import KlingAdapter
from api.service.generation_service import GenerationService
from repos.dao import DAO
# ... ваши импорты ...
from aiogram import Bot
from adapters.s3_adapter import S3Adapter
from typing import Optional
# Провайдеры "сырых" клиентов из состояния приложения
def get_mongo_client(request: Request) -> AsyncIOMotorClient:
return request.app.state.mongo_client
def get_gemini_client(request: Request) -> GoogleAdapter:
return request.app.state.gemini_client
def get_bot_client(request: Request) -> Bot:
return request.app.state.bot
def get_s3_adapter(request: Request) -> Optional[S3Adapter]:
return request.app.state.s3_adapter
# Провайдер DAO (собирается из mongo_client)
def get_dao(
mongo_client: AsyncIOMotorClient = Depends(get_mongo_client),
s3_adapter: Optional[S3Adapter] = Depends(get_s3_adapter)
) -> DAO:
# FastAPI кэширует результат Depends в рамках одного запроса,
# так что DAO создастся один раз за запрос.
return DAO(mongo_client, s3_adapter)
def get_kling_adapter(request: Request) -> Optional[KlingAdapter]:
return request.app.state.kling_adapter
# Service provider (built from the DAO and the Gemini adapter)
def get_generation_service(
dao: DAO = Depends(get_dao),
gemini: GoogleAdapter = Depends(get_gemini_client),
s3_adapter: S3Adapter = Depends(get_s3_adapter),
bot: Bot = Depends(get_bot_client),
kling_adapter: Optional[KlingAdapter] = Depends(get_kling_adapter),
) -> GenerationService:
return GenerationService(dao, gemini, s3_adapter, bot, kling_adapter=kling_adapter)
from fastapi import Header
async def get_project_id(x_project_id: Optional[str] = Header(None, alias="X-Project-ID")) -> Optional[str]:
return x_project_id
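The comment above notes that FastAPI caches each `Depends` result within a single request, so `get_dao` builds one `DAO` per request. A minimal pure-Python sketch of that per-request memoisation (the `resolve` helper and counters are illustrative, not FastAPI internals):

```python
# Sketch of FastAPI's per-request dependency cache: within one request,
# each provider runs at most once; a new request gets a fresh cache.
def resolve(provider, cache):
    """Return the cached result for `provider`, computing it on first use."""
    if provider not in cache:
        cache[provider] = provider()
    return cache[provider]

calls = {"dao": 0}

def get_dao():
    calls["dao"] += 1
    return object()

request_cache = {}  # one cache per incoming request
a = resolve(get_dao, request_cache)
b = resolve(get_dao, request_cache)
assert a is b and calls["dao"] == 1  # same instance, built once per request
```

A second request would start with an empty cache and build a fresh `DAO`.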

api/endpoints/admin.py
@@ -0,0 +1,96 @@
from typing import Annotated, List
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel
from repos.user_repo import UsersRepo, UserStatus
from utils.security import ALGORITHM, SECRET_KEY
from jose import JWTError, jwt
router = APIRouter(prefix="/api/admin", tags=["admin"])
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/token")
from api.endpoints.auth import get_users_repo
async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)], repo: Annotated[UsersRepo, Depends(get_users_repo)]):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise credentials_exception
except JWTError:
raise credentials_exception
user = await repo.get_user_by_username(username)
if user is None:
raise credentials_exception
return user
async def get_current_admin(user: Annotated[dict, Depends(get_current_user)]):
if not user.get("is_admin"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not enough permissions",
)
return user
class UserResponse(BaseModel):
username: str
full_name: str | None = None
status: str
created_at: str | None = None
is_admin: bool
class Config:
from_attributes = True
@router.get("/approvals", response_model=List[UserResponse])
async def list_pending_users(
admin: Annotated[dict, Depends(get_current_admin)],
repo: Annotated[UsersRepo, Depends(get_users_repo)]
):
users = await repo.get_pending_users()
# Pydantic conversion handles the list of dicts
return [
UserResponse(
username=u["username"],
full_name=u.get("full_name"),
status=u["status"],
created_at=str(u.get("created_at")),
is_admin=u.get("is_admin", False)
) for u in users
]
@router.post("/approve/{username}")
async def approve_user(
username: str,
admin: Annotated[dict, Depends(get_current_admin)],
repo: Annotated[UsersRepo, Depends(get_users_repo)]
):
user = await repo.get_user_by_username(username)
if not user:
raise HTTPException(status_code=404, detail="User not found")
await repo.approve_user(username)
return {"message": f"User {username} approved"}
@router.post("/deny/{username}")
async def deny_user(
username: str,
admin: Annotated[dict, Depends(get_current_admin)],
repo: Annotated[UsersRepo, Depends(get_users_repo)]
):
user = await repo.get_user_by_username(username)
if not user:
raise HTTPException(status_code=404, detail="User not found")
await repo.deny_user(username)
return {"message": f"User {username} denied"}

@@ -0,0 +1,81 @@
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, status, Request
from pydantic import BaseModel
from api.models.GenerationRequest import GenerationResponse
from api.service.album_service import AlbumService  # import path assumed; AlbumService was used below but never imported
router = APIRouter(prefix="/api/albums", tags=["Albums"])
class AlbumCreateRequest(BaseModel):
name: str
description: Optional[str] = None
class AlbumUpdateRequest(BaseModel):
name: Optional[str] = None
description: Optional[str] = None
class AlbumResponse(BaseModel):
id: str
name: str
description: Optional[str] = None
generation_ids: List[str] = []
cover_asset_id: Optional[str] = None # Not implemented yet
@router.post("", response_model=AlbumResponse)
async def create_album(request: Request, album_in: AlbumCreateRequest):
service: AlbumService = request.app.state.album_service
album = await service.create_album(name=album_in.name, description=album_in.description)
return AlbumResponse(**album.model_dump())
@router.get("", response_model=List[AlbumResponse])
async def get_albums(request: Request, limit: int = 10, offset: int = 0):
service: AlbumService = request.app.state.album_service
albums = await service.get_albums(limit=limit, offset=offset)
return [AlbumResponse(**album.model_dump()) for album in albums]
@router.get("/{album_id}", response_model=AlbumResponse)
async def get_album(request: Request, album_id: str):
service: AlbumService = request.app.state.album_service
album = await service.get_album(album_id)
if not album:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album not found")
return AlbumResponse(**album.model_dump())
@router.put("/{album_id}", response_model=AlbumResponse)
async def update_album(request: Request, album_id: str, album_in: AlbumUpdateRequest):
service: AlbumService = request.app.state.album_service
album = await service.update_album(album_id, name=album_in.name, description=album_in.description)
if not album:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album not found")
return AlbumResponse(**album.model_dump())
@router.delete("/{album_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_album(request: Request, album_id: str):
service: AlbumService = request.app.state.album_service
deleted = await service.delete_album(album_id)
if not deleted:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album not found")
@router.post("/{album_id}/generations/{generation_id}")
async def add_generation_to_album(request: Request, album_id: str, generation_id: str):
service: AlbumService = request.app.state.album_service
success = await service.add_generation_to_album(album_id, generation_id)
if not success:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album or Generation not found")
return {"status": "success"}
@router.delete("/{album_id}/generations/{generation_id}")
async def remove_generation_from_album(request: Request, album_id: str, generation_id: str):
service: AlbumService = request.app.state.album_service
success = await service.remove_generation_from_album(album_id, generation_id)
if not success:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album or Generation not found")
return {"status": "success"}
@router.get("/{album_id}/generations", response_model=List[GenerationResponse])
async def get_album_generations(request: Request, album_id: str, limit: int = 10, offset: int = 0):
service: AlbumService = request.app.state.album_service
generations = await service.get_generations_by_album(album_id, limit=limit, offset=offset)
return [GenerationResponse(**gen.model_dump()) for gen in generations]

@@ -0,0 +1,310 @@
from typing import List, Optional, Dict, Any
from aiogram.types import BufferedInputFile
from bson import ObjectId
from fastapi import APIRouter, UploadFile, File, Form, Depends
from motor.motor_asyncio import AsyncIOMotorClient
from starlette import status
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import Response
from adapters.s3_adapter import S3Adapter
from api.models.AssetDTO import AssetsResponse, AssetResponse
from models.Asset import Asset, AssetType, AssetContentType
from repos.dao import DAO
from api.dependency import get_dao, get_mongo_client, get_s3_adapter
import asyncio
import logging
logger = logging.getLogger(__name__)
from api.endpoints.auth import get_current_user
from api.dependency import get_project_id
router = APIRouter(prefix="/api/assets", tags=["Assets"])
@router.get("/{asset_id}")
async def get_asset(
asset_id: str,
request: Request,
thumbnail: bool = False,
dao: DAO = Depends(get_dao)
) -> Response:
logger.debug(f"get_asset called for ID: {asset_id}, thumbnail={thumbnail}")
asset = await dao.assets.get_asset(asset_id)
# Make sure the asset exists
if not asset:
raise HTTPException(status_code=404, detail="Asset not found")
headers = {
# Cache for one year (31536000 seconds)
"Cache-Control": "public, max-age=31536000, immutable"
}
content = asset.data
media_type = "image/png" # Default, or detect
if thumbnail and asset.thumbnail:
content = asset.thumbnail
media_type = "image/jpeg"
return Response(content=content, media_type=media_type, headers=headers)
@router.delete("/orphans", dependencies=[Depends(get_current_user)])
async def delete_orphan_assets_from_minio(
mongo: AsyncIOMotorClient = Depends(get_mongo_client),
minio_client: S3Adapter = Depends(get_s3_adapter),
*,
assets_collection: str = "assets",
generations_collection: str = "generations",
asset_type: Optional[str] = "generated",
project_id: Optional[str] = None,
dry_run: bool = True,
mark_assets_deleted: bool = False,
batch_size: int = 500,
) -> Dict[str, Any]:
db = mongo['bot_db']  # the client comes from get_mongo_client; select the bot database here
assets = db[assets_collection]
match_assets: Dict[str, Any] = {}
if asset_type is not None:
match_assets["type"] = asset_type
if project_id is not None:
match_assets["project_id"] = project_id
pipeline: List[Dict[str, Any]] = [
{"$match": match_assets} if match_assets else {"$match": {}},
{
"$lookup": {
"from": generations_collection,
"let": {"assetIdStr": {"$toString": "$_id"}},
"pipeline": [
# считаем "живыми" те, где is_deleted != True (т.е. false или поля нет)
{"$match": {"is_deleted": {"$ne": True}}},
{
"$match": {
"$expr": {
"$in": [
"$$assetIdStr",
{"$ifNull": ["$result_list", []]},
]
}
}
},
{"$limit": 1},
],
"as": "alive_generations",
}
},
{
"$match": {
"$expr": {"$eq": [{"$size": "$alive_generations"}, 0]}
}
},
{
"$project": {
"_id": 1,
"minio_object_name": 1,
"minio_thumbnail_object_name": 1,
}
},
]
logger.debug(f"Orphan lookup pipeline: {pipeline}")
cursor = assets.aggregate(pipeline, allowDiskUse=True, batchSize=batch_size)
deleted_objects = 0
deleted_assets = 0
errors: List[Dict[str, Any]] = []
orphan_asset_ids: List[ObjectId] = []
async for asset in cursor:
aid = asset["_id"]
obj = asset.get("minio_object_name")
thumb = asset.get("minio_thumbnail_object_name")
orphan_asset_ids.append(aid)
if dry_run:
print(f"[DRY RUN] orphan asset={aid} obj={obj} thumb={thumb}")
continue
try:
if obj:
await minio_client.delete_file(obj)
deleted_objects += 1
if thumb:
await minio_client.delete_file(thumb)
deleted_objects += 1
deleted_assets += 1
except Exception as e:
errors.append({"asset_id": str(aid), "error": str(e)})
if (not dry_run) and mark_assets_deleted and orphan_asset_ids:
res = await assets.update_many(
{"_id": {"$in": orphan_asset_ids}},
{"$set": {"is_deleted": True}},
)
marked = res.modified_count
else:
marked = 0
return {
"dry_run": dry_run,
"filter": {
"asset_type": asset_type,
"project_id": project_id,
},
"orphans_found": len(orphan_asset_ids),
"deleted_assets": deleted_assets,
"deleted_objects": deleted_objects,
"marked_assets_deleted": marked,
"errors": errors,
}
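The `$lookup` pipeline above marks an asset as orphaned when no non-deleted generation lists its stringified `_id` in `result_list`. The same rule in plain Python, with toy data, just to illustrate the pipeline's semantics:

```python
def find_orphans(assets, generations):
    # IDs referenced by generations that are still "alive" (is_deleted != True,
    # i.e. the field is False or absent -- mirroring the $ne match above)
    referenced = {
        asset_id
        for gen in generations
        if gen.get("is_deleted") is not True
        for asset_id in gen.get("result_list", [])
    }
    return [a["_id"] for a in assets if str(a["_id"]) not in referenced]

assets = [{"_id": 1}, {"_id": 2}, {"_id": 3}]
generations = [
    {"result_list": ["1"]},                      # keeps asset 1 alive
    {"result_list": ["2"], "is_deleted": True},  # deleted: does not protect 2
]
assert find_orphans(assets, generations) == [2, 3]
```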
@router.delete("/{asset_id}", status_code=status.HTTP_204_NO_CONTENT, dependencies=[Depends(get_current_user)])
async def delete_asset(
asset_id: str,
dao: DAO = Depends(get_dao)
):
logger.info(f"delete_asset called for ID: {asset_id}")
# delete_one returns False when nothing matched, which is all we need for a 404,
# so just attempt the delete directly.
deleted = await dao.assets.delete_asset(asset_id)
if not deleted:
raise HTTPException(status_code=404, detail="Asset not found")
logger.info(f"Asset {asset_id} deleted successfully")
return None
@router.get("", dependencies=[Depends(get_current_user)])
async def get_assets(request: Request, dao: DAO = Depends(get_dao), type: Optional[str] = None, limit: int = 10, offset: int = 0, current_user: dict = Depends(get_current_user), project_id: Optional[str] = Depends(get_project_id)) -> AssetsResponse:
logger.info(f"get_assets called. Limit: {limit}, Offset: {offset}")
user_id_filter = current_user["id"]
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
user_id_filter = None
assets = await dao.assets.get_assets(type, limit, offset, created_by=user_id_filter, project_id=project_id)
total_count = await dao.assets.get_asset_count(created_by=user_id_filter, project_id=project_id)
# Map to DTOs explicitly so the response list contains validated Pydantic models
# (Asset.model_dump() includes computed fields such as url).
asset_responses = [AssetResponse.model_validate(a.model_dump()) for a in assets]
return AssetsResponse(assets=asset_responses, total_count=total_count)
@router.post("/upload", response_model=AssetResponse, status_code=status.HTTP_201_CREATED)
async def upload_asset(
file: UploadFile = File(...),
linked_char_id: Optional[str] = Form(None),
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id)
):
logger.info(f"upload_asset called. Filename: {file.filename}, ContentType: {file.content_type}, LinkedCharId: {linked_char_id}")
if not file.content_type:
raise HTTPException(status_code=400, detail="Unknown file type")
if not file.content_type.startswith("image/"):
raise HTTPException(status_code=400, detail=f"Unsupported content type: {file.content_type}")
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
data = await file.read()
if not data:
raise HTTPException(status_code=400, detail="Empty file")
# Generate thumbnail
from utils.image_utils import create_thumbnail
thumbnail_bytes = await asyncio.to_thread(create_thumbnail, data)
asset = Asset(
name=file.filename or "upload",
type=AssetType.UPLOADED,
content_type=AssetContentType.IMAGE,
linked_char_id=linked_char_id,
data=data,
thumbnail=thumbnail_bytes,
created_by=str(current_user["_id"]),
project_id=project_id,
)
asset_id = await dao.assets.create_asset(asset)
asset.id = str(asset_id)
logger.info(f"Asset created successfully. ID: {asset_id}")
return AssetResponse(
id=asset.id,
name=asset.name,
type=asset.type.value if hasattr(asset.type, "value") else asset.type,
content_type=asset.content_type.value if hasattr(asset.content_type, "value") else asset.content_type,
linked_char_id=asset.linked_char_id,
created_at=asset.created_at,
url=asset.url
)
@router.post("/regenerate_thumbnails", dependencies=[Depends(get_current_user)])
async def regenerate_thumbnails(dao: DAO = Depends(get_dao)):
"""
Regenerates thumbnails for all existing image assets that don't have one.
"""
logger.info("Starting thumbnail regeneration task")
from utils.image_utils import create_thumbnail
import asyncio
# Fetch a bounded batch for now; very large datasets would need chunked
# iteration via a dedicated repo method instead of one big query.
assets = await dao.assets.get_assets(limit=1000, offset=0, with_data=True)
logger.info(f"Found {len(assets)} assets")
count = 0
updated = 0
for asset in assets:
if asset.content_type == AssetContentType.IMAGE and asset.data:
try:
thumb = await asyncio.to_thread(create_thumbnail, asset.data)
if thumb:
asset.thumbnail = thumb
await dao.assets.update_asset(asset.id, asset)
updated += 1
except Exception as e:
logger.error(f"Failed to regenerate thumbnail for asset {asset.id}: {e}")
count += 1
return {"status": "completed", "processed": count, "updated": updated}
@router.post("/migrate_to_minio", dependencies=[Depends(get_current_user)])
async def migrate_to_minio(dao: DAO = Depends(get_dao)):
"""
Migrates assets from MongoDB to MinIO.
"""
logger.info("Starting migration to MinIO")
result = await dao.assets.migrate_to_minio()
logger.info(f"Migration result: {result}")
return result

api/endpoints/auth.py
@@ -0,0 +1,123 @@
from datetime import timedelta
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from pydantic import BaseModel
from jose import JWTError, jwt
from repos.user_repo import UsersRepo, UserStatus
from utils.security import verify_password, create_access_token, ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, SECRET_KEY
from starlette.requests import Request
router = APIRouter(prefix="/api/auth", tags=["auth"])
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/token")
async def get_users_repo(request: Request) -> UsersRepo:
if not hasattr(request.app.state, "users_repo"):
raise HTTPException(status_code=500, detail="Users repo not initialized")
return request.app.state.users_repo
async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)], repo: Annotated[UsersRepo, Depends(get_users_repo)]):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
if username is None:
raise credentials_exception
except JWTError:
raise credentials_exception
user = await repo.get_user_by_username(username)
if user is None:
raise credentials_exception
return user
async def get_current_admin(user: Annotated[dict, Depends(get_current_user)]):
if not user.get("is_admin"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not enough permissions",
)
return user
class UserRegister(BaseModel):
username: str
password: str
full_name: str | None = None
class Token(BaseModel):
access_token: str
token_type: str
class UserResponse(BaseModel):
id: str
username: str
full_name: str | None = None
status: str
is_admin: bool = False
@router.get("/me", response_model=UserResponse)
async def read_users_me(current_user: Annotated[dict, Depends(get_current_user)]):
return current_user
@router.post("/register")
async def register(user_data: UserRegister, repo: Annotated[UsersRepo, Depends(get_users_repo)]):
try:
await repo.create_user(
username=user_data.username,
password=user_data.password,
full_name=user_data.full_name
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return {"message": "Registration successful. Please wait for administrator approval."}
@router.post("/token", response_model=Token)
async def login_for_access_token(
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
repo: Annotated[UsersRepo, Depends(get_users_repo)]
):
user = await repo.get_user_by_username(form_data.username)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)
# Verify the password
if not verify_password(form_data.password, user["hashed_password"]):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)
# Check the account status
if user.get("status") != UserStatus.ALLOWED:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Account is not approved yet. Please contact administrator.",
)
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_access_token(
data={"sub": user["username"]}, expires_delta=access_token_expires
)
return {"access_token": access_token, "token_type": "bearer"}

@@ -0,0 +1,187 @@
from typing import List, Optional
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from starlette.exceptions import HTTPException
from starlette.requests import Request
from api.models.AssetDTO import AssetsResponse, AssetResponse
from api.models.GenerationRequest import GenerationRequest, GenerationResponse
from models.Character import Character
from api.models.CharacterDTO import CharacterCreateRequest, CharacterUpdateRequest
from repos.dao import DAO
from api.dependency import get_dao
import logging
logger = logging.getLogger(__name__)
from api.endpoints.auth import get_current_user
from api.dependency import get_project_id
router = APIRouter(prefix="/api/characters", tags=["Characters"], dependencies=[Depends(get_current_user)])
@router.get("/", response_model=List[Character])
async def get_characters(request: Request, dao: DAO = Depends(get_dao), current_user: dict = Depends(get_current_user), project_id: Optional[str] = Depends(get_project_id)) -> List[Character]:
logger.info("get_characters called")
user_id_filter = str(current_user["_id"])
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
user_id_filter = None
characters = await dao.chars.get_all_characters(created_by=user_id_filter, project_id=project_id)
return characters
@router.get("/{character_id}/assets", response_model=AssetsResponse)
async def get_character_assets(character_id: str, dao: DAO = Depends(get_dao), limit: int = 10,
offset: int = 0, current_user: dict = Depends(get_current_user)) -> AssetsResponse:
logger.info(f"get_character_assets called. CharacterID: {character_id}, Limit: {limit}, Offset: {offset}")
character = await dao.chars.get_character(character_id)
if character is None:
raise HTTPException(status_code=404, detail="Character not found")
# Access Check
is_creator = character.created_by == str(current_user["_id"])
is_project_member = False
if character.project_id and character.project_id in current_user.get("project_ids", []):
is_project_member = True
if not is_creator and not is_project_member:
raise HTTPException(status_code=403, detail="Access denied")
assets = await dao.assets.get_assets_by_char_id(character_id, limit, offset)
# Owning the character grants access to its assets; per-asset
# created_by filtering is intentionally not applied here.
total_count = await dao.assets.get_asset_count(character_id)
asset_responses = [AssetResponse.model_validate(a.model_dump()) for a in assets]
return AssetsResponse(assets=asset_responses, total_count=total_count)
@router.get("/{character_id}", response_model=Character)
async def get_character_by_id(character_id: str, request: Request, dao: DAO = Depends(get_dao), current_user: dict = Depends(get_current_user)) -> Character:
logger.debug(f"get_character_by_id called. ID: {character_id}")
character = await dao.chars.get_character(character_id)
if not character:
raise HTTPException(status_code=404, detail="Character not found")
is_creator = character.created_by == str(current_user["_id"])
is_project_member = False
if character.project_id and character.project_id in current_user.get("project_ids", []):
is_project_member = True
if not is_creator and not is_project_member:
raise HTTPException(status_code=403, detail="Access denied")
return character
@router.post("/", response_model=Character)
async def create_character(
char_req: CharacterCreateRequest,
project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
) -> Character:
logger.info("create_character called")
char_req.project_id = project_id
char_data = char_req.model_dump()
char_data["created_by"] = str(current_user["_id"])
if "id" not in char_data:
char_data["id"] = None
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
new_char = Character(**char_data)
if new_char.avatar_image:
new_char.avatar_asset_id = new_char.avatar_image.split("/")[-1]
created_char = await dao.chars.add_character(new_char)
return created_char
@router.put("/{character_id}", response_model=Character)
async def update_character(
character_id: str,
char_update: CharacterUpdateRequest,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
) -> Character:
logger.info(f"update_character called. ID: {character_id}")
existing_char = await dao.chars.get_character(character_id)
if not existing_char:
raise HTTPException(status_code=404, detail="Character not found")
is_creator = existing_char.created_by == str(current_user["_id"])
is_project_member = False
if existing_char.project_id and existing_char.project_id in current_user.get("project_ids", []):
is_project_member = True
if not is_creator and not is_project_member:
raise HTTPException(status_code=403, detail="Access denied")
update_data = char_update.model_dump(exclude_unset=True)
if "project_id" in update_data and update_data["project_id"]:
new_project_id = update_data["project_id"]
project = await dao.projects.get_project(new_project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Target project access denied")
updated_char_data = existing_char.model_dump()
updated_char_data.update(update_data)
updated_char = Character(**updated_char_data)
success = await dao.chars.update_char(character_id, updated_char)
if not success:
raise HTTPException(status_code=500, detail="Failed to update character")
return updated_char
@router.delete("/{character_id}", status_code=204)
async def delete_character(
character_id: str,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
):
logger.info(f"delete_character called. ID: {character_id}")
existing_char = await dao.chars.get_character(character_id)
if not existing_char:
raise HTTPException(status_code=404, detail="Character not found")
is_creator = existing_char.created_by == str(current_user["_id"])
is_project_member = False
if existing_char.project_id and existing_char.project_id in current_user.get("project_ids", []):
is_project_member = True
if not is_creator and not is_project_member:
raise HTTPException(status_code=403, detail="Access denied")
success = await dao.chars.delete_character(character_id)
if not success:
raise HTTPException(status_code=500, detail="Failed to delete character")
return
@router.post("/{character_id}/_run", response_model=GenerationResponse)
async def post_character_generation(character_id: str, generation: GenerationRequest,
request: Request) -> GenerationResponse:
logger.info(f"post_character_generation called. CharacterID: {character_id}")
generation_service = request.app.state.generation_service

@@ -0,0 +1,193 @@
from typing import List, Optional
from fastapi import APIRouter, UploadFile, File, Form, Header, HTTPException
from fastapi.params import Depends
from starlette.requests import Request
from api.dependency import get_generation_service, get_project_id, get_dao
from repos.dao import DAO
from api.models.GenerationRequest import GenerationResponse, GenerationRequest, GenerationsResponse, PromptResponse, PromptRequest
from api.models.VideoGenerationRequest import VideoGenerationRequest
from api.service.generation_service import GenerationService
from models.Generation import Generation
from starlette import status
import logging
logger = logging.getLogger(__name__)
from api.endpoints.auth import get_current_user
router = APIRouter(prefix='/api/generations', tags=["Generation"])
@router.post("/prompt-assistant", response_model=PromptResponse)
async def ask_prompt_assistant(prompt_request: PromptRequest, request: Request,
generation_service: GenerationService = Depends(
get_generation_service),
current_user: dict = Depends(get_current_user)) -> PromptResponse:
logger.info(f"ask_prompt_assistant called with prompt length: {len(prompt_request.prompt)}. Linked assets: {len(prompt_request.linked_assets) if prompt_request.linked_assets else 0}")
generated_prompt = await generation_service.ask_prompt_assistant(prompt_request.prompt, prompt_request.linked_assets)
return PromptResponse(prompt=generated_prompt)
@router.post("/prompt-from-image", response_model=PromptResponse)
async def prompt_from_image(
prompt: Optional[str] = Form(None),
images: List[UploadFile] = File(...),
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)
) -> PromptResponse:
logger.info(f"prompt_from_image called. Images count: {len(images)}. Prompt provided: {bool(prompt)}")
images_bytes = []
for image in images:
content = await image.read()
images_bytes.append(content)
generated_prompt = await generation_service.generate_prompt_from_images(images_bytes, prompt)
return PromptResponse(prompt=generated_prompt)
@router.get("", response_model=GenerationsResponse)
async def get_generations(character_id: Optional[str] = None, limit: int = 10, offset: int = 0,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao)):
logger.info(f"get_generations called. CharacterId: {character_id}, Limit: {limit}, Offset: {offset}")
user_id_filter = str(current_user["_id"])
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
user_id_filter = None # Show all project generations
return await generation_service.get_generations(character_id, limit=limit, offset=offset, user_id=user_id_filter, project_id=project_id)
@router.post("/_run", response_model=GenerationResponse)
async def post_generation(generation: GenerationRequest, request: Request,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao)) -> GenerationResponse:
logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}")
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
generation.project_id = project_id
return await generation_service.create_generation_task(generation, user_id=str(current_user.get("_id")))
@router.get("/{generation_id}", response_model=GenerationResponse)
async def get_generation(generation_id: str,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)) -> GenerationResponse:
logger.debug(f"get_generation called for ID: {generation_id}")
gen = await generation_service.get_generation(generation_id)
if gen and gen.created_by != str(current_user["_id"]):
raise HTTPException(status_code=403, detail="Access denied")
return gen
@router.get("/running")
async def get_running_generations(request: Request,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao)):
user_id_filter = str(current_user["_id"])
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
user_id_filter = None
return await generation_service.get_running_generations(user_id=user_id_filter, project_id=project_id)
@router.post("/video/_run", response_model=GenerationResponse)
async def post_video_generation(
video_request: VideoGenerationRequest,
request: Request,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao),
) -> GenerationResponse:
"""Start image-to-video generation using Kling AI."""
logger.info(f"post_video_generation called. AssetId: {video_request.image_asset_id}, Duration: {video_request.duration}s, Mode: {video_request.mode}")
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
video_request.project_id = project_id
return await generation_service.create_video_generation_task(video_request, user_id=str(current_user.get("_id")))
@router.post("/import", response_model=GenerationResponse)
async def import_external_generation(
request: Request,
generation_service: GenerationService = Depends(get_generation_service),
x_signature: str = Header(..., alias="X-Signature")
) -> GenerationResponse:
"""
Import a generation from an external source.
Requires server-to-server authentication via HMAC signature.
"""
import os
from utils.external_auth import verify_signature
from api.models.ExternalGenerationDTO import ExternalGenerationRequest
logger.info("import_external_generation called")
# Get raw request body for signature verification
body = await request.body()
# Verify signature
secret = os.getenv("EXTERNAL_API_SECRET")
if not secret:
logger.error("EXTERNAL_API_SECRET not configured")
raise HTTPException(status_code=500, detail="Server configuration error")
if not verify_signature(body, x_signature, secret):
logger.warning("Invalid signature for external generation import")
raise HTTPException(status_code=401, detail="Invalid signature")
# Parse request body
import json
try:
data = json.loads(body.decode('utf-8'))
external_gen = ExternalGenerationRequest(**data)
except Exception as e:
logger.error(f"Failed to parse request body: {e}")
raise HTTPException(status_code=400, detail=f"Invalid request body: {str(e)}")
# Import generation
try:
generation = await generation_service.import_external_generation(external_gen)
return GenerationResponse(**generation.model_dump())
except Exception as e:
logger.error(f"Failed to import external generation: {e}")
raise HTTPException(status_code=500, detail=f"Import failed: {str(e)}")
@router.delete("/{generation_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_generation(generation_id: str,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)):
logger.info(f"delete_generation called for ID: {generation_id}")
deleted = await generation_service.delete_generation(generation_id)
if not deleted:
raise HTTPException(status_code=404, detail="Generation not found")
return None
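For reference, the client side of the `/import` endpoint could sign the raw body like this. This is a sketch under the assumption that `verify_signature` compares a hex-encoded HMAC-SHA256 digest of the exact request bytes (the usual convention; the encoding is not shown in this excerpt):

```python
import hashlib
import hmac
import json


def sign_body(body: bytes, secret: str) -> str:
    """Hex-encoded HMAC-SHA256 of the raw request body."""
    return hmac.new(secret.encode("utf-8"), body, hashlib.sha256).hexdigest()


# Serialize once and sign the exact bytes that will be sent;
# re-serializing on the way out would change the signature.
payload = json.dumps({"prompt": "a red fox", "created_by": "user-123"}).encode("utf-8")
signature = sign_body(payload, "my-shared-secret")

# The request then carries the body verbatim plus the signature header:
headers = {"X-Signature": signature, "Content-Type": "application/json"}
```

The server must verify against `await request.body()` rather than a re-parsed model, for the same reason.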


@@ -0,0 +1,167 @@
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from api.dependency import get_dao
from api.endpoints.auth import get_current_user
from models.Project import Project
from repos.dao import DAO
router = APIRouter(prefix="/api/projects", tags=["Projects"])
class ProjectCreate(BaseModel):
name: str
description: Optional[str] = None
class ProjectResponse(BaseModel):
id: str
name: str
description: Optional[str] = None
owner_id: str
members: List[str]
is_owner: bool = False
@router.post("", response_model=ProjectResponse)
async def create_project(
project_data: ProjectCreate,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
):
user_id = str(current_user["_id"])
new_project = Project(
name=project_data.name,
description=project_data.description,
owner_id=user_id,
members=[user_id]
)
project_id = await dao.projects.create_project(new_project)
    # UserRepo has no add_project helper yet, so update the user's
    # project_ids list directly on the collection.
await dao.users.collection.update_one(
{"_id": current_user["_id"]},
{"$addToSet": {"project_ids": project_id}}
)
return ProjectResponse(
id=project_id,
name=new_project.name,
description=new_project.description,
owner_id=new_project.owner_id,
members=new_project.members,
is_owner=True
)
@router.get("", response_model=List[ProjectResponse])
async def get_my_projects(
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
):
user_id = str(current_user["_id"])
projects = await dao.projects.get_projects_by_user(user_id)
responses = []
for p in projects:
responses.append(ProjectResponse(
id=p.id,
name=p.name,
description=p.description,
owner_id=p.owner_id,
members=p.members,
is_owner=(p.owner_id == user_id)
))
return responses
class MemberAdd(BaseModel):
username: str
@router.post("/{project_id}/members", dependencies=[Depends(get_current_user)])
async def add_member(
project_id: str,
member_data: MemberAdd,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
):
user_id = str(current_user["_id"])
project = await dao.projects.get_project(project_id)
if not project:
raise HTTPException(status_code=404, detail="Project not found")
if project.owner_id != user_id:
raise HTTPException(status_code=403, detail="Only owner can add members")
target_user = await dao.users.get_user_by_username(member_data.username)
if not target_user:
raise HTTPException(status_code=404, detail="User not found")
target_user_id = str(target_user["_id"])
if target_user_id in project.members:
return {"message": "User already in project"}
await dao.projects.add_member(project_id, target_user_id)
# Update target user's project list
await dao.users.collection.update_one(
{"_id": target_user["_id"]},
{"$addToSet": {"project_ids": project_id}}
)
return {"message": "Member added"}
@router.post("/{project_id}/join", dependencies=[Depends(get_current_user)])
async def join_project(
project_id: str,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
):
# Retrieve project to verify it exists
project = await dao.projects.get_project(project_id)
if not project:
raise HTTPException(status_code=404, detail="Project not found")
user_id = str(current_user["_id"])
# Check if user is ALREADY in project
if user_id in project.members:
return {"message": "Already a member"}
# Add member
await dao.projects.add_member(project_id, user_id)
# Update user's project list
await dao.users.collection.update_one(
{"_id": current_user["_id"]},
{"$addToSet": {"project_ids": project_id}}
)
return {"message": "Joined project"}
@router.delete("/{project_id}", dependencies=[Depends(get_current_user)] )
async def delete_project(
project_id: str,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
):
user_id = str(current_user["_id"])
project = await dao.projects.get_project(project_id)
if not project:
raise HTTPException(status_code=404, detail="Project not found")
if project.owner_id != user_id:
raise HTTPException(status_code=403, detail="Only owner can delete project")
await dao.projects.delete_project(project_id)
# Remove project from user's project list
await dao.users.collection.update_one(
{"_id": current_user["_id"]},
{"$pull": {"project_ids": project_id}}
)
return {"message": "Project deleted"}

api/models/AssetDTO.py Normal file

@@ -0,0 +1,20 @@
from datetime import datetime
from typing import List, Optional
from pydantic import BaseModel
from models.Asset import Asset
class AssetResponse(BaseModel):
id: str
name: str
type: str # uploaded / generated
content_type: str # image / prompt
linked_char_id: Optional[str] = None
created_at: datetime
url: Optional[str] = None
class AssetsResponse(BaseModel):
assets: List[AssetResponse]
total_count: int


@@ -0,0 +1,18 @@
from typing import Optional
from pydantic import BaseModel
class CharacterCreateRequest(BaseModel):
name: str
character_bio: str
character_image_doc_tg_id: Optional[str] = None
avatar_image: Optional[str] = None
character_image_tg_id: Optional[str] = None
project_id: Optional[str] = None
class CharacterUpdateRequest(BaseModel):
name: Optional[str] = None
character_bio: Optional[str] = None
character_image_doc_tg_id: Optional[str] = None
avatar_image: Optional[str] = None
character_image_tg_id: Optional[str] = None
project_id: Optional[str] = None


@@ -0,0 +1,37 @@
from typing import Optional
from pydantic import BaseModel, Field
from models.enums import AspectRatios, Quality
class ExternalGenerationRequest(BaseModel):
"""Request model for importing external generations."""
prompt: str
tech_prompt: Optional[str] = None
# Image can be provided as base64 string OR URL (one must be provided)
image_data: Optional[str] = Field(None, description="Base64-encoded image data")
image_url: Optional[str] = Field(None, description="URL to download image from")
# Generation metadata
aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN
quality: Quality = Quality.ONEK
# Optional linking
linked_character_id: Optional[str] = None
created_by: str = Field(..., description="User ID from external system")
project_id: Optional[str] = None
# Performance metrics
execution_time_seconds: Optional[float] = None
api_execution_time_seconds: Optional[float] = None
token_usage: Optional[int] = None
input_token_usage: Optional[int] = None
output_token_usage: Optional[int] = None
def validate_image_source(self):
"""Ensure at least one image source is provided."""
if not self.image_data and not self.image_url:
raise ValueError("Either image_data or image_url must be provided")
if self.image_data and self.image_url:
raise ValueError("Only one of image_data or image_url should be provided")


@@ -0,0 +1,64 @@
from datetime import datetime, UTC
from typing import List, Optional
from pydantic import BaseModel, Field
from models.Asset import Asset
from models.Generation import GenerationStatus
from models.enums import AspectRatios, Quality, GenType
class GenerationRequest(BaseModel):
linked_character_id: Optional[str] = None
aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN
quality: Quality = Quality.ONEK
prompt: str
telegram_id: Optional[int] = None
use_profile_image: bool = True
assets_list: List[str]
project_id: Optional[str] = None
class GenerationsResponse(BaseModel):
generations: List["GenerationResponse"]
total_count: int
class GenerationResponse(BaseModel):
id: str
status: GenerationStatus
gen_type: GenType = GenType.IMAGE
failed_reason: Optional[str] = None
linked_character_id: Optional[str] = None
aspect_ratio: AspectRatios
quality: Quality
prompt: str
tech_prompt: Optional[str] = None
assets_list: List[str]
result_list: List[str] = []
result: Optional[str] = None
execution_time_seconds: Optional[float] = None
api_execution_time_seconds: Optional[float] = None
token_usage: Optional[int] = None
input_token_usage: Optional[int] = None
output_token_usage: Optional[int] = None
progress: int = 0
cost: Optional[float] = None
created_by: Optional[str] = None
# Video-specific
kling_task_id: Optional[str] = None
video_duration: Optional[int] = None
video_mode: Optional[str] = None
    # default_factory gives each instance a fresh timestamp; a plain default
    # would be evaluated once at import time and then shared
    created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
    updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
class PromptRequest(BaseModel):
prompt: str
linked_assets: List[str] = []
class PromptResponse(BaseModel):
prompt: str


@@ -0,0 +1,16 @@
from typing import Optional
from pydantic import BaseModel
class VideoGenerationRequest(BaseModel):
prompt: str = ""
negative_prompt: Optional[str] = ""
    image_asset_id: str  # ID of the image asset to use as the source image
duration: int = 5 # 5 or 10 seconds
mode: str = "std" # "std" or "pro"
model_name: str = "kling-v2-1"
cfg_scale: float = 0.5
aspect_ratio: str = "16:9"
linked_character_id: Optional[str] = None
project_id: Optional[str] = None

api/models/__init__.py Normal file

api/service/__init__.py Normal file

@@ -0,0 +1,85 @@
from typing import List, Optional
from models.Album import Album
from models.Generation import Generation, GenerationStatus
from repos.dao import DAO
class AlbumService:
def __init__(self, dao: DAO):
self.dao = dao
async def create_album(self, name: str, description: Optional[str] = None) -> Album:
album = Album(name=name, description=description)
album_id = await self.dao.albums.create_album(album)
album.id = album_id
return album
async def get_albums(self, limit: int = 10, offset: int = 0) -> List[Album]:
return await self.dao.albums.get_albums(limit=limit, offset=offset)
async def get_album(self, album_id: str) -> Optional[Album]:
return await self.dao.albums.get_album(album_id)
async def update_album(self, album_id: str, name: Optional[str] = None, description: Optional[str] = None) -> Optional[Album]:
album = await self.dao.albums.get_album(album_id)
if not album:
return None
if name:
album.name = name
if description is not None:
album.description = description
await self.dao.albums.update_album(album_id, album)
return album
async def delete_album(self, album_id: str) -> bool:
return await self.dao.albums.delete_album(album_id)
async def add_generation_to_album(self, album_id: str, generation_id: str) -> bool:
# Verify album exists
album = await self.dao.albums.get_album(album_id)
if not album:
return False
# Verify generation exists (optional but good practice)
gen = await self.dao.generations.get_generation(generation_id)
if not gen:
return False
        # Use the first result as the album cover if none is set yet
        if album.cover_asset_id is None and gen.status == GenerationStatus.DONE and gen.result_list:
            album.cover_asset_id = gen.result_list[0]
return await self.dao.albums.add_generation(album_id, generation_id, album.cover_asset_id)
async def remove_generation_from_album(self, album_id: str, generation_id: str) -> bool:
return await self.dao.albums.remove_generation(album_id, generation_id)
async def get_generations_by_album(self, album_id: str, limit: int = 10, offset: int = 0) -> List[Generation]:
album = await self.dao.albums.get_album(album_id)
if not album or not album.generation_ids:
return []
        # Paginate on the ID list, then fetch only that slice of objects.
        # List order is insertion order, which is chronological as long as
        # generations are always appended.
        sliced_ids = album.generation_ids[offset : offset + limit]
        if not sliced_ids:
            return []
        # Fetch the whole slice in one query via GenerationRepo.get_generations_by_ids
        # instead of looping over individual lookups.
        return await self.dao.generations.get_generations_by_ids(sliced_ids)
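The slice-then-fetch pagination used by `get_generations_by_album` can be shown in isolation (`paginate_ids` is a hypothetical name for this sketch; the real fetch happens afterwards via `get_generations_by_ids`):

```python
from typing import List


def paginate_ids(ids: List[str], limit: int, offset: int) -> List[str]:
    """Page over an in-document ID list; the objects themselves are then
    fetched in a single $in query rather than one lookup per ID."""
    if limit <= 0 or offset < 0:
        return []
    return ids[offset : offset + limit]


album_ids = ["g1", "g2", "g3", "g4", "g5"]
print(paginate_ids(album_ids, limit=2, offset=2))  # ['g3', 'g4']
```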


@@ -0,0 +1,611 @@
import asyncio
import logging
import random
import base64
from datetime import datetime, UTC
from typing import List, Optional, Tuple, Any, Dict
from io import BytesIO
import httpx
from aiogram import Bot
from aiogram.types import BufferedInputFile
from adapters.Exception import GoogleGenerationException
from adapters.google_adapter import GoogleAdapter
from adapters.kling_adapter import KlingAdapter, KlingApiException
from api.models.GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse
from api.models.VideoGenerationRequest import VideoGenerationRequest
# Import the DAO, Asset and Generation models
from models.Asset import Asset, AssetType, AssetContentType
from models.Generation import Generation, GenerationStatus
from models.enums import AspectRatios, Quality, GenType
from repos.dao import DAO
from adapters.s3_adapter import S3Adapter
logger = logging.getLogger(__name__)
# --- Generation helper ---
async def generate_image_task(
prompt: str,
media_group_bytes: List[bytes],
aspect_ratio: AspectRatios,
quality: Quality,
gemini: GoogleAdapter,
) -> Tuple[List[bytes], Dict[str, Any]]:
"""
Обертка для вызова синхронного метода Gemini в отдельном потоке.
Возвращает список байтов сгенерированных изображений.
"""
try :
logger.info(f"Starting generate_image_task with prompt length: {len(prompt)}")
# Запускаем блокирующую операцию в отдельном потоке, чтобы не тормозить Event Loop
result = await asyncio.to_thread(
gemini.generate_image,
prompt=prompt,
images_list=media_group_bytes,
aspect_ratio=aspect_ratio,
quality=quality,
)
generated_images_io, metrics = result
logger.info(f"generate_image_task completed, received {len(generated_images_io) if generated_images_io else 0} images")
    except GoogleGenerationException:
        raise
images_bytes = []
if generated_images_io:
for img_io in generated_images_io:
            # Read the bytes out of the BytesIO buffer
img_io.seek(0)
content = img_io.read()
images_bytes.append(content)
            # Close the buffer
img_io.close()
return images_bytes, metrics
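The thread-offloading pattern in `generate_image_task` reduces to the following self-contained sketch, with a stand-in blocking function in place of the Gemini SDK call:

```python
import asyncio
import time


def blocking_render(prompt: str) -> bytes:
    # Stand-in for a synchronous SDK call that would otherwise block the event loop
    time.sleep(0.1)
    return f"image-for:{prompt}".encode()


async def main() -> bytes:
    # asyncio.to_thread runs the callable in the default executor,
    # so other coroutines keep running while it sleeps
    return await asyncio.to_thread(blocking_render, "a red fox")


result = asyncio.run(main())
print(result)  # b'image-for:a red fox'
```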
class GenerationService:
def __init__(self, dao: DAO, gemini: GoogleAdapter, s3_adapter: S3Adapter, bot: Optional[Bot] = None, kling_adapter: Optional[KlingAdapter] = None):
self.dao = dao
self.gemini = gemini
self.s3_adapter = s3_adapter
self.bot = bot
self.kling_adapter = kling_adapter
    async def ask_prompt_assistant(self, prompt: str, assets: Optional[List[str]] = None) -> str:
future_prompt = """You are an prompt-assistant. You improving user-entered prompts for image generation. User may upload reference image too.
I will provide sources prompt entered by user. Understand user needs and generate best variation of prompt.
ANSWER ONLY PROMPT STRING!!! USER_ENTERED_PROMPT: """
future_prompt += prompt
assets_data = []
if assets is not None:
assets_db = await self.dao.assets.get_assets_by_ids(assets)
assets_data.extend(asset.data for asset in assets_db)
generated_prompt = await asyncio.to_thread(self.gemini.generate_text, future_prompt, assets_data)
logger.info(future_prompt)
logger.info(generated_prompt)
return generated_prompt
async def generate_prompt_from_images(self, images: List[bytes], user_prompt: Optional[str] = None) -> str:
        technical_prompt = "You are a prompt engineer. Describe this image in detail to create a Stable Diffusion prompt that uses this image as a reference. "
if user_prompt:
technical_prompt += f"User also provided this context: {user_prompt}. "
technical_prompt += "Provide ONLY the detailed prompt."
return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, images_list=images)
    async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0, user_id: Optional[str] = None, project_id: Optional[str] = None) -> GenerationsResponse:
        generations = await self.dao.generations.get_generations(character_id=character_id, limit=limit, offset=offset, created_by=user_id, project_id=project_id)
        total_count = await self.dao.generations.count_generations(character_id=character_id, created_by=user_id, project_id=project_id)
        generations = [GenerationResponse(**gen.model_dump()) for gen in generations]
        return GenerationsResponse(generations=generations, total_count=total_count)
async def get_generation(self, generation_id: str) -> Optional[GenerationResponse]:
gen = await self.dao.generations.get_generation(generation_id)
if gen is None:
return None
else:
return GenerationResponse(**gen.model_dump())
async def get_running_generations(self, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING, created_by=user_id, project_id=project_id)
async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None) -> GenerationResponse:
gen_id = None
generation_model = None
try:
generation_model = Generation(**generation_request.model_dump())
if user_id:
generation_model.created_by = user_id
gen_id = await self.dao.generations.create_generation(generation_model)
generation_model.id = gen_id
async def runner(gen):
logger.info(f"Starting background generation task for ID: {gen.id}")
try:
await self.create_generation(gen)
logger.info(f"Background generation task finished for ID: {gen.id}")
except Exception:
                    # If the generation already started and then failed, mark it FAILED
try:
db_gen = await self.dao.generations.get_generation(gen.id)
db_gen.status = GenerationStatus.FAILED
await self.dao.generations.update_generation(db_gen)
except Exception:
logger.exception("Failed to mark generation as FAILED")
logger.exception("create_generation task failed")
asyncio.create_task(runner(generation_model))
return GenerationResponse(**generation_model.model_dump())
except Exception:
            # If the record was never created, there is nothing to mark
if gen_id is not None:
try:
gen = await self.dao.generations.get_generation(gen_id)
gen.status = GenerationStatus.FAILED
await self.dao.generations.update_generation(gen)
except Exception:
logger.exception("Failed to mark generation as FAILED in create_generation_task")
raise
async def create_generation(self, generation: Generation):
start_time = datetime.now()
logger.info(f"Processing generation {generation.id}. Character ID: {generation.linked_character_id}")
        # 2. Load the reference assets (if any)
reference_assets: List[Asset] = []
media_group_bytes: List[bytes] = []
generation_prompt = generation.prompt
# generation_prompt = f"""
# Create detailed image of character in scene.
# SCENE DESCRIPTION: {generation.prompt}
# Rules:
# - Integrate the character's appearance naturally into the scene description.
# - Focus on lighting, texture, and composition.
# """
if generation.linked_character_id is not None:
char_info = await self.dao.chars.get_character(generation.linked_character_id)
if char_info is None:
raise Exception(f"Character ID {generation.linked_character_id} not found")
if generation.use_profile_image:
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
if avatar_asset:
media_group_bytes.append(avatar_asset.data)
# generation_prompt = generation_prompt.replace("$char_bio_inserted", f"1. CHARACTER BIO (Must be strictly followed): {char_info.character_bio}")
reference_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
        # Extract the raw bytes from the assets to send to Gemini
for asset in reference_assets:
if asset.content_type != AssetContentType.IMAGE:
continue
img_data = None
if asset.minio_object_name:
img_data = await self.s3_adapter.get_file(asset.minio_object_name)
elif asset.data:
img_data = asset.data
if img_data:
media_group_bytes.append(img_data)
if media_group_bytes:
            generation_prompt += " \n\n[Reference Image Guidance]: Use the provided image(s) as the STRICT reference for the main character's facial features, hair, environment or clothes. Maintain high fidelity to the reference identity."
logger.info(f"Final generation prompt assembled. Length: {len(generation_prompt)}. Media count: {len(media_group_bytes)}")
        # 3. Start the generation and the progress simulation
progress_task = asyncio.create_task(self._simulate_progress(generation))
try:
# Default to Image Generation (Gemini)
generated_bytes_list, metrics = await generate_image_task(
                prompt=generation_prompt,
media_group_bytes=media_group_bytes,
                aspect_ratio=generation.aspect_ratio,
quality=generation.quality,
gemini=self.gemini
)
# Update metrics from API (Common for both)
generation.api_execution_time_seconds = metrics.get("api_execution_time_seconds")
generation.token_usage = metrics.get("token_usage")
generation.input_token_usage = metrics.get("input_token_usage")
generation.output_token_usage = metrics.get("output_token_usage")
except GoogleGenerationException as e:
generation.status = GenerationStatus.FAILED
generation.failed_reason = str(e)
generation.updated_at = datetime.now(UTC)
await self.dao.generations.update_generation(generation)
raise e
        except Exception as e:
            # Log the error and mark the generation as failed
            logger.error(f"Generation failed: {e}")
generation.status = GenerationStatus.FAILED
generation.failed_reason = str(e)
generation.updated_at = datetime.now(UTC)
await self.dao.generations.update_generation(generation)
raise e
finally:
if not progress_task.done():
progress_task.cancel()
try:
await progress_task
except asyncio.CancelledError:
pass
        # 4. Save the generated images as new Assets
created_assets: List[Asset] = []
for idx, img_bytes in enumerate(generated_bytes_list):
# Generate thumbnail
thumbnail_bytes = None
from utils.image_utils import create_thumbnail
thumbnail_bytes = await asyncio.to_thread(create_thumbnail, img_bytes)
# Save to S3
filename = f"generated/{generation.linked_character_id}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.png"
await self.s3_adapter.upload_file(filename, img_bytes, content_type="image/png")
new_asset = Asset(
name=f"Generated_{generation.linked_character_id}",
type=AssetType.GENERATED,
content_type=AssetContentType.IMAGE,
linked_char_id=generation.linked_character_id,
data=None, # Not storing bytes in DB anymore
minio_object_name=filename,
minio_bucket=self.s3_adapter.bucket_name,
thumbnail=thumbnail_bytes,
created_by=generation.created_by,
project_id=generation.project_id
)
            # Persist to the DB
            asset_id = await self.dao.assets.create_asset(new_asset)
            new_asset.id = str(asset_id)  # Assign the ID returned by the database
created_assets.append(new_asset)
        # 5. Update the generation record with links to the results
result_ids = [a.id for a in created_assets]
generation.result_list = result_ids
generation.status = GenerationStatus.DONE
generation.progress = 100
generation.updated_at = datetime.now(UTC)
generation.tech_prompt = generation_prompt
end_time = datetime.now()
generation.execution_time_seconds = (end_time - start_time).total_seconds()
logger.info(f"DEBUG: Saving generation {generation.id}. Metrics: api_exec={generation.api_execution_time_seconds}, tokens={generation.token_usage}, in_tokens={generation.input_token_usage}, out_tokens={generation.output_token_usage}, exec={generation.execution_time_seconds}")
await self.dao.generations.update_generation(generation)
logger.info(f"Generation {generation.id} completed successfully. {len(created_assets)} assets created. Total Time: {generation.execution_time_seconds:.2f}s")
        # 6. Send to Telegram if telegram_id is provided
        if generation.telegram_id and self.bot:
            try:
                for asset in created_assets:
                    # Image bytes now live in S3 (asset.data is None), so fetch them before sending
                    img_data = asset.data or await self.s3_adapter.get_file(asset.minio_object_name)
                    if img_data:
                        await self.bot.send_photo(
                            chat_id=generation.telegram_id,
                            photo=BufferedInputFile(img_data, filename=f"{asset.name}.jpg"),
                            caption=f"Generated from prompt: {generation.prompt[:100]}..."
                        )
logger.info(f"Sent {len(created_assets)} assets to Telegram ID: {generation.telegram_id}")
except Exception as e:
logger.error(f"Failed to send assets to Telegram ID {generation.telegram_id}: {e}")
async def _simulate_progress(self, generation: Generation):
"""
Increments progress from 0 to 90 over ~20 seconds.
"""
current_progress = 0
try:
while current_progress < 90:
await asyncio.sleep(4)
# Random increment between 5 and 15
increment = random.randint(5, 15)
current_progress = min(current_progress + increment, 90)
                # Ideally fetch-update-save (or a partial update, if the DAO
                # supports it) to avoid overwriting unrelated fields; a simple
                # full update is acceptable here.
generation.progress = current_progress
await self.dao.generations.update_generation(generation)
except asyncio.CancelledError:
# Task cancelled, generation finished (or failed)
pass
except Exception as e:
logger.error(f"Error in progress simulation: {e}")
async def import_external_generation(self, external_gen) -> Generation:
"""
Import a generation from an external source.
Args:
external_gen: ExternalGenerationRequest with generation data and image
Returns:
Created Generation object
"""
from api.models.ExternalGenerationDTO import ExternalGenerationRequest
# Validate image source
external_gen.validate_image_source()
logger.info(f"Importing external generation for user: {external_gen.created_by}")
# 1. Process image (download or decode)
image_bytes = None
if external_gen.image_url:
# Download image from URL
logger.info(f"Downloading image from URL: {external_gen.image_url}")
async with httpx.AsyncClient() as client:
response = await client.get(external_gen.image_url, timeout=30.0)
response.raise_for_status()
image_bytes = response.content
elif external_gen.image_data:
# Decode base64 image
logger.info("Decoding base64 image data")
image_bytes = base64.b64decode(external_gen.image_data)
if not image_bytes:
raise ValueError("Failed to process image data")
# 2. Generate thumbnail
from utils.image_utils import create_thumbnail
thumbnail_bytes = await asyncio.to_thread(create_thumbnail, image_bytes)
# 3. Save to S3
filename = f"external/{external_gen.created_by}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.png"
await self.s3_adapter.upload_file(filename, image_bytes, content_type="image/png")
# 4. Create Asset
new_asset = Asset(
name=f"External_Generated_{external_gen.linked_character_id or 'no_char'}",
type=AssetType.GENERATED,
content_type=AssetContentType.IMAGE,
linked_char_id=external_gen.linked_character_id,
data=None, # Not storing bytes in DB
minio_object_name=filename,
minio_bucket=self.s3_adapter.bucket_name,
thumbnail=thumbnail_bytes,
created_by=external_gen.created_by,
project_id=external_gen.project_id
)
asset_id = await self.dao.assets.create_asset(new_asset)
new_asset.id = str(asset_id)
logger.info(f"Created asset {asset_id} for external generation")
# 5. Create Generation record
generation = Generation(
status=GenerationStatus.DONE,
linked_character_id=external_gen.linked_character_id,
aspect_ratio=external_gen.aspect_ratio,
quality=external_gen.quality,
prompt=external_gen.prompt,
tech_prompt=external_gen.tech_prompt,
result_list=[new_asset.id],
result=new_asset.id,
progress=100,
execution_time_seconds=external_gen.execution_time_seconds,
api_execution_time_seconds=external_gen.api_execution_time_seconds,
token_usage=external_gen.token_usage,
input_token_usage=external_gen.input_token_usage,
output_token_usage=external_gen.output_token_usage,
created_by=external_gen.created_by,
project_id=external_gen.project_id,
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC)
)
gen_id = await self.dao.generations.create_generation(generation)
generation.id = gen_id
logger.info(f"Created generation {gen_id} from external source")
return generation
# === VIDEO GENERATION (Kling) ===
async def create_video_generation_task(self, request: VideoGenerationRequest, user_id: Optional[str] = None) -> GenerationResponse:
"""Create a video generation task (async, returns immediately)."""
if not self.kling_adapter:
raise Exception("Kling adapter is not configured")
generation = Generation(
status=GenerationStatus.RUNNING,
gen_type=GenType.VIDEO,
linked_character_id=request.linked_character_id,
aspect_ratio=AspectRatios.SIXTEENNINE, # default for video
quality=Quality.ONEK,
prompt=request.prompt,
assets_list=[request.image_asset_id],
video_duration=request.duration,
video_mode=request.mode,
project_id=request.project_id,
)
if user_id:
generation.created_by = user_id
gen_id = await self.dao.generations.create_generation(generation)
generation.id = gen_id
async def runner(gen, req):
logger.info(f"Starting background video generation task for ID: {gen.id}")
try:
await self.create_video_generation(gen, req)
logger.info(f"Background video generation task finished for ID: {gen.id}")
except Exception:
try:
db_gen = await self.dao.generations.get_generation(gen.id)
if db_gen and db_gen.status != GenerationStatus.FAILED:
db_gen.status = GenerationStatus.FAILED
db_gen.updated_at = datetime.now(UTC)
await self.dao.generations.update_generation(db_gen)
except Exception:
logger.exception("Failed to mark video generation as FAILED")
logger.exception("create_video_generation task failed")
asyncio.create_task(runner(generation, request))
return GenerationResponse(**generation.model_dump())
async def create_video_generation(self, generation: Generation, request: VideoGenerationRequest):
"""Background video generation: call Kling API, poll, download result, save asset."""
start_time = datetime.now()
try:
# 1. Get source image presigned URL
asset = await self.dao.assets.get_asset(request.image_asset_id)
if not asset:
raise Exception(f"Asset {request.image_asset_id} not found")
if not asset.minio_object_name:
raise Exception(f"Asset {request.image_asset_id} has no S3 object")
presigned_url = await self.s3_adapter.get_presigned_url(asset.minio_object_name, expiration=3600)
if not presigned_url:
raise Exception("Failed to generate presigned URL for source image")
logger.info(f"Video gen {generation.id}: got presigned URL for asset {request.image_asset_id}")
# 2. Create Kling task
task_data = await self.kling_adapter.create_video_task(
image_url=presigned_url,
prompt=request.prompt,
negative_prompt=request.negative_prompt or "",
model_name=request.model_name,
duration=request.duration,
mode=request.mode,
cfg_scale=request.cfg_scale,
aspect_ratio=request.aspect_ratio,
)
task_id = task_data.get("task_id")
generation.kling_task_id = task_id
await self.dao.generations.update_generation(generation)
logger.info(f"Video gen {generation.id}: Kling task created, task_id={task_id}")
# 3. Poll for completion with progress updates
async def progress_callback(progress_pct: int):
generation.progress = progress_pct
generation.updated_at = datetime.now(UTC)
await self.dao.generations.update_generation(generation)
result = await self.kling_adapter.wait_for_completion(
task_id=task_id,
poll_interval=10,
timeout=600,
progress_callback=progress_callback,
)
# 4. Extract video URL and download
works = result.get("task_result", {}).get("videos", [])
if not works:
raise Exception("No video in Kling result")
video_url = works[0].get("url")
video_duration = works[0].get("duration", request.duration)
if not video_url:
raise Exception("No video URL in Kling result")
logger.info(f"Video gen {generation.id}: downloading video from {video_url}")
async with httpx.AsyncClient(timeout=120.0) as client:
video_response = await client.get(video_url)
video_response.raise_for_status()
video_bytes = video_response.content
logger.info(f"Video gen {generation.id}: downloaded {len(video_bytes)} bytes")
# 5. Upload to S3
filename = f"generated_video/{generation.linked_character_id or 'no_char'}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.mp4"
await self.s3_adapter.upload_file(filename, video_bytes, content_type="video/mp4")
# 6. Create Asset
new_asset = Asset(
name=f"Video_{generation.linked_character_id or 'gen'}",
type=AssetType.GENERATED,
content_type=AssetContentType.VIDEO,
linked_char_id=generation.linked_character_id,
data=None,
minio_object_name=filename,
minio_bucket=self.s3_adapter.bucket_name,
thumbnail=None, # video thumbnails can be added later
created_by=generation.created_by,
project_id=generation.project_id,
)
asset_id = await self.dao.assets.create_asset(new_asset)
new_asset.id = str(asset_id)
# 7. Finalize generation
end_time = datetime.now()
generation.result_list = [new_asset.id]
generation.result = new_asset.id
generation.status = GenerationStatus.DONE
generation.progress = 100
generation.video_duration = video_duration
generation.execution_time_seconds = (end_time - start_time).total_seconds()
generation.updated_at = datetime.now(UTC)
await self.dao.generations.update_generation(generation)
logger.info(f"Video generation {generation.id} completed. Asset: {new_asset.id}, Time: {generation.execution_time_seconds:.1f}s")
except KlingApiException as e:
logger.error(f"Kling API error for generation {generation.id}: {e}")
generation.status = GenerationStatus.FAILED
generation.failed_reason = str(e)
generation.updated_at = datetime.now(UTC)
await self.dao.generations.update_generation(generation)
raise
except Exception as e:
logger.error(f"Video generation {generation.id} failed: {e}")
generation.status = GenerationStatus.FAILED
generation.failed_reason = str(e)
generation.updated_at = datetime.now(UTC)
await self.dao.generations.update_generation(generation)
raise
async def delete_generation(self, generation_id: str) -> bool:
"""
Soft delete generation by marking it as deleted.
"""
try:
generation = await self.dao.generations.get_generation(generation_id)
if not generation:
return False
generation.is_deleted = True
generation.updated_at = datetime.now(UTC)
await self.dao.generations.update_generation(generation)
return True
except Exception as e:
logger.error(f"Error deleting generation {generation_id}: {e}")
return False

deploy.sh Executable file

@@ -0,0 +1,6 @@
ssh root@31.59.58.220 "
cd /root/bots/ai-char-bot &&
git pull &&
docker compose up -d --build
"


@@ -4,6 +4,26 @@ services:
     container_name: ai-bot
     build:
       context: .
-      network: host
-    network_mode: host
+    ports:
+      - "8090:8090"  # restored the port mapping after removing network_mode: host
     restart: unless-stopped
+    depends_on:
+      - minio
+    environment:
+      # Important: inside Docker, other containers are reached by service name!
+      MINIO_ENDPOINT: "http://minio:9000"
+  minio:
+    image: minio/minio:latest
+    container_name: minio
+    restart: unless-stopped
+    command: server /data --console-address ":9001"
+    environment:
+      MINIO_ROOT_USER: admin
+      MINIO_ROOT_PASSWORD: SuperSecretPassword123!
+    ports:
+      - "9000:9000"
+      - "9001:9001"
+    volumes:
+      - ./minio_data:/data

main.py

@@ -1,104 +0,0 @@
import asyncio
import logging
import os
from aiogram import Bot, Dispatcher, Router, F
from aiogram.client.default import DefaultBotProperties
from aiogram.enums import ParseMode
from aiogram.filters import CommandStart, Command, CommandObject
from aiogram.types import Message, BufferedInputFile
from aiogram.fsm.storage.mongo import MongoStorage
from dotenv import load_dotenv
from motor.motor_asyncio import AsyncIOMotorClient
# Imports
from adapters.google_adapter import GoogleAdapter
from middlewares.album import AlbumMiddleware
from middlewares.auth import AuthMiddleware
from middlewares.dao import DaoMiddleware
from repos.char_repo import CharacterRepo
from repos.dao import DAO
from repos.user_repo import UsersRepo
# IMPORTANT: import the router that carries the button logic, don't create an empty one
from routers.auth_router import router as auth_router
from routers.gen_router import router as gen_router
from routers.char_router import router as char_router
load_dotenv()
# Settings
BOT_TOKEN = os.getenv("BOT_TOKEN")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
MONGO_HOST = os.getenv("MONGO_HOST")
ADMIN_ID = int(os.getenv("ADMIN_ID")) # cast to int right away
# Initialization
bot = Bot(token=BOT_TOKEN, default=DefaultBotProperties(parse_mode=ParseMode.HTML))
# Database
mongo_client = AsyncIOMotorClient(MONGO_HOST)
users_repo = UsersRepo(mongo_client)
char_repo = CharacterRepo(mongo_client)
# Dispatcher
# If MongoStorage is not yet configured with authSource=admin, storage=... can be dropped temporarily
dp = Dispatcher(storage=MongoStorage(mongo_client))
# DEPENDENCY INJECTION (so these are available in handlers)
dp["repo"] = users_repo
dp["admin_id"] = ADMIN_ID
dp["gemini"] = GoogleAdapter(api_key=GEMINI_API_KEY) # initialized here
# ROUTING
# 1. Auth router (buttons) - include it FIRST and WITHOUT middleware
dp.include_router(auth_router)
main_router = Router()
dp.include_router(main_router)
dp.include_router(char_router)
dp.include_router(gen_router)
# 2. Main router (chatting with the bot)
# Attach the auth guard ONLY to the main router
main_router.message.middleware(AuthMiddleware(repo=users_repo, admin_id=ADMIN_ID))
gen_router.message.middleware(AuthMiddleware(repo=users_repo, admin_id=ADMIN_ID))
gen_router.message.middleware(AlbumMiddleware(latency=0.8))
dp.update.middleware(DaoMiddleware(dao=DAO(client=mongo_client)))
def setup_logging() -> None:
logging.basicConfig(level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s")
# --- MAIN ROUTER HANDLERS ---
# Keep them right here, or move them to a separate routers/chat_router.py
@main_router.message(Command("help"))
async def show_help(message: Message) -> None:
await message.answer("For text generation, just send your prompt.\n\n"
"For image generation - /image {prompt}\n\n"
"You can also attach a photo to the /image {prompt} command\n\n"
"Dialogues are not supported! <b>Every new message starts a new dialogue</b>")
@main_router.message(CommandStart())
async def cmd_start(message: Message):
await message.answer("👋 Hi! I'm ready to work.\n\n"
"For text generation, just send your prompt.\n\n"
"For image generation - /image {prompt}\n\n"
"You can also attach a photo to the /image {prompt} command\n\n"
"Dialogues are not supported! <b>Every new message starts a new dialogue</b>"
)
# --- STARTUP ---
if __name__ == "__main__":
setup_logging()
try:
asyncio.run(dp.start_polling(bot))
except KeyboardInterrupt:
print("Bot stopped")

models/Album.py Normal file

@@ -0,0 +1,12 @@
from datetime import datetime, UTC
from typing import Optional, List
from pydantic import BaseModel, Field
class Album(BaseModel):
id: Optional[str] = None
name: str
description: Optional[str] = None
cover_asset_id: Optional[str] = None
generation_ids: List[str] = []
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))

models/Asset.py Normal file

@@ -0,0 +1,72 @@
from datetime import datetime, UTC
from enum import Enum
from typing import Optional, Any, List
from pydantic import BaseModel, computed_field, Field, model_validator
class AssetContentType(str, Enum):
IMAGE = 'image'
VIDEO = 'video'
PROMPT = 'prompt'
class AssetType(str, Enum):
UPLOADED = 'uploaded'
GENERATED = 'generated'
class Asset(BaseModel):
id: Optional[str] = None
name: str
type: AssetType = AssetType.GENERATED
content_type: AssetContentType = AssetContentType.IMAGE
linked_char_id: Optional[str] = None
data: Optional[bytes] = None
tg_doc_file_id: Optional[str] = None
tg_photo_file_id: Optional[str] = None
minio_object_name: Optional[str] = None
minio_bucket: Optional[str] = None
minio_thumbnail_object_name: Optional[str] = None
thumbnail: Optional[bytes] = None
tags: List[str] = []
created_by: Optional[str] = None
project_id: Optional[str] = None
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
@model_validator(mode='before')
@classmethod
def check_legacy_type(cls, data: Any) -> Any:
if isinstance(data, dict):
# Legacy documents stored "image"/"prompt" in `type`; move that value into
# content_type so it survives, then drop `type` so AssetType validation
# doesn't fail on the old value and the field takes its default (GENERATED).
raw_type = data.get('type')
if raw_type in ['image', 'prompt']:
data['content_type'] = raw_type
if 'type' in data:
del data['type']
return data
# --- CALCULATED FIELD ---
@computed_field
def url(self) -> str:
"""
This field is computed automatically and is included in model_dump() / .json()
"""
if self.id:
return f"/assets/{self.id}"
return ""
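The `check_legacy_type` validator above normalizes old Mongo documents before Pydantic validation runs. A plain-function sketch of that normalization (assuming the `type` key is dropped only for legacy values, so new-style documents keep theirs):

```python
def normalize_legacy_asset(doc: dict) -> dict:
    """Sketch of Asset.check_legacy_type: old documents stored "image"/"prompt"
    in `type`; move that value to content_type and drop `type`, so the new
    AssetType field falls back to its default ("generated")."""
    doc = dict(doc)  # don't mutate the caller's dict
    raw_type = doc.get("type")
    if raw_type in ("image", "prompt"):
        doc["content_type"] = raw_type
        del doc["type"]
    return doc

print(normalize_legacy_asset({"name": "old", "type": "image"}))
# {'name': 'old', 'content_type': 'image'}
```

A new-style document such as `{"type": "generated"}` passes through unchanged.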


@@ -1,10 +1,17 @@
+from typing import Optional
 from pydantic import BaseModel
+from pydantic_core.core_schema import computed_field
 class Character(BaseModel):
-    id: str | None
+    id: Optional[str] = None
     name: str
-    character_image_doc_tg_id: str
-    character_image_tg_id: str | None
-    character_bio: str
+    avatar_asset_id: Optional[str] = None
+    avatar_image: Optional[str] = None
+    character_image_data: Optional[bytes] = None
+    character_image_doc_tg_id: Optional[str] = None
+    character_image_tg_id: Optional[str] = None
+    character_bio: Optional[str] = None
+    created_by: Optional[str] = None
+    project_id: Optional[str] = None

models/Generation.py Normal file

@@ -0,0 +1,54 @@
from datetime import datetime, UTC
from enum import Enum
from typing import List, Optional
from pydantic import BaseModel, Field, computed_field
from models.Asset import Asset
from models.enums import AspectRatios, Quality, GenType
class GenerationStatus(str, Enum):
RUNNING = "running"
DONE = "done"
FAILED = "failed"
class Generation(BaseModel):
id: Optional[str] = None
status: GenerationStatus = GenerationStatus.RUNNING
gen_type: GenType = GenType.IMAGE
failed_reason: Optional[str] = None
linked_character_id: Optional[str] = None
telegram_id: Optional[int] = None
use_profile_image: bool = True
aspect_ratio: AspectRatios
quality: Quality
prompt: str
tech_prompt: Optional[str] = None
assets_list: List[str] = Field(default_factory=list)
result_list: List[str] = Field(default_factory=list)
result: Optional[str] = None
progress: int = 0
execution_time_seconds: Optional[float] = None
api_execution_time_seconds: Optional[float] = None
token_usage: Optional[int] = None
input_token_usage: Optional[int] = None
output_token_usage: Optional[int] = None
is_deleted: bool = False
album_id: Optional[str] = None
created_by: Optional[str] = None # Stores User ID (Telegram ID or Web User ObjectId)
project_id: Optional[str] = None
# Video-specific fields
kling_task_id: Optional[str] = None
video_duration: Optional[int] = None # 5 or 10 seconds
video_mode: Optional[str] = None # "std" or "pro"
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
@computed_field
def cost(self) -> float:
if self.status == GenerationStatus.DONE and self.input_token_usage and self.output_token_usage:
cost_input = self.input_token_usage * 0.000002
cost_output = self.output_token_usage * 0.00012
return round(cost_input + cost_output, 3)
return 0.0
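The computed `cost` field applies fixed per-token prices. The arithmetic, isolated for a quick sanity check (prices taken from the model above; note the model itself returns 0.0 unless the generation is DONE and both token counts are set):

```python
def generation_cost(input_tokens: int, output_tokens: int) -> float:
    """Same arithmetic as Generation.cost for a completed generation:
    $0.000002 per input token, $0.00012 per output token, rounded to 3 places."""
    return round(input_tokens * 0.000002 + output_tokens * 0.00012, 3)

print(generation_cost(1000, 500))  # 0.062
print(generation_cost(0, 0))       # 0.0
```

Output tokens dominate the price: one output token costs as much as 60 input tokens.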

models/Project.py Normal file

@@ -0,0 +1,12 @@
from datetime import datetime, UTC
from typing import List, Optional
from pydantic import BaseModel, Field
class Project(BaseModel):
id: Optional[str] = None
name: str
description: Optional[str] = None
owner_id: str
members: List[str] = [] # List of User IDs
is_deleted: bool = False
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) # UTC-aware, matching the other models


@@ -1,19 +1,45 @@
 from enum import Enum
-class AspectRatios(Enum):
-    NINESIXTEEN = '9:16'
-    SIXTEENNINE = '16:9'
-    THREEFOUR = '3:4'
-    FOURTHREE = '4:3'
+class AspectRatios(str, Enum):
+    NINESIXTEEN = "NINESIXTEEN"
+    SIXTEENNINE = "SIXTEENNINE"
+    THREEFOUR = "THREEFOUR"
+    FOURTHREE = "FOURTHREE"
+    @property
+    def value_ratio(self) -> str:
+        return {
+            AspectRatios.NINESIXTEEN: "9:16",
+            AspectRatios.SIXTEENNINE: "16:9",
+            AspectRatios.THREEFOUR: "3:4",
+            AspectRatios.FOURTHREE: "4:3",
+        }[self]
-class Quality(Enum):
-    ONEK = '1K'
-    TWOK = '2K'
-    FOURK = '4K'
+class Quality(str, Enum):
+    ONEK = 'ONEK'
+    TWOK = 'TWOK'
+    FOURK = 'FOURK'
+    @property
+    def value_quality(self) -> str:
+        return {
+            Quality.ONEK: '1K',
+            Quality.TWOK: '2K',
+            Quality.FOURK: '4K'
+        }[self]
-class GenType(Enum):
+class GenType(str, Enum):
     TEXT = 'Text'
     IMAGE = 'Image'
+    VIDEO = 'Video'
+    @property
+    def value_type(self) -> str:
+        return {
+            GenType.TEXT: 'Text',
+            GenType.IMAGE: 'Image',
+            GenType.VIDEO: 'Video',
+        }[self]
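Migrating the enums to `str, Enum` with stable member names as values decouples the stored identifier from the display string. A minimal sketch of why that matters (two members shown; the full set lives in models/enums.py):

```python
import json
from enum import Enum

# Same pattern as models/enums.py: the enum *value* is a stable storage
# token, while a property exposes the human/API-facing string.
class AspectRatios(str, Enum):
    NINESIXTEEN = "NINESIXTEEN"
    SIXTEENNINE = "SIXTEENNINE"

    @property
    def value_ratio(self) -> str:
        return {
            AspectRatios.NINESIXTEEN: "9:16",
            AspectRatios.SIXTEENNINE: "16:9",
        }[self]

ar = AspectRatios.SIXTEENNINE
print(ar.value)        # "SIXTEENNINE" - what lands in Mongo / JSON
print(ar.value_ratio)  # "16:9"        - what an image API expects
print(json.dumps({"aspect_ratio": ar}))  # str subclass serializes directly
```

Because the member inherits from `str`, it needs no custom JSON encoder and survives round-trips through Mongo and Pydantic unchanged.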

repos/albums_repo.py Normal file

@@ -0,0 +1,61 @@
from typing import List, Optional
import logging
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient
from models.Album import Album
logger = logging.getLogger(__name__)
class AlbumsRepo:
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
self.collection = client[db_name]["albums"]
async def create_album(self, album: Album) -> str:
res = await self.collection.insert_one(album.model_dump())
return str(res.inserted_id)
async def get_album(self, album_id: str) -> Optional[Album]:
try:
res = await self.collection.find_one({"_id": ObjectId(album_id)})
if not res:
return None
res["id"] = str(res.pop("_id"))
return Album(**res)
except Exception:
return None
async def get_albums(self, limit: int = 10, offset: int = 0) -> List[Album]:
res = await self.collection.find().sort("created_at", -1).skip(offset).limit(limit).to_list(None)
albums = []
for doc in res:
doc["id"] = str(doc.pop("_id"))
albums.append(Album(**doc))
return albums
async def update_album(self, album_id: str, album: Album) -> bool:
if not album.id:
album.id = album_id
model_dump = album.model_dump()
res = await self.collection.update_one({"_id": ObjectId(album_id)}, {"$set": model_dump})
return res.modified_count > 0
async def delete_album(self, album_id: str) -> bool:
res = await self.collection.delete_one({"_id": ObjectId(album_id)})
return res.deleted_count > 0
async def add_generation(self, album_id: str, generation_id: str, cover_asset_id: Optional[str] = None) -> bool:
res = await self.collection.update_one(
{"_id": ObjectId(album_id)},
{"$addToSet": {"generation_ids": generation_id}, "$set": {"cover_asset_id": cover_asset_id}}
)
return res.modified_count > 0
async def remove_generation(self, album_id: str, generation_id: str) -> bool:
res = await self.collection.update_one(
{"_id": ObjectId(album_id)},
{"$pull": {"generation_ids": generation_id}}
)
return res.modified_count > 0

repos/assets_repo.py Normal file

@@ -0,0 +1,264 @@
from typing import List, Optional
import logging
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient
from models.Asset import Asset
from adapters.s3_adapter import S3Adapter
logger = logging.getLogger(__name__)
class AssetsRepo:
def __init__(self, client: AsyncIOMotorClient, s3_adapter: Optional[S3Adapter] = None, db_name="bot_db"):
self.collection = client[db_name]["assets"]
self.s3 = s3_adapter
async def create_asset(self, asset: Asset) -> str:
# If S3 is configured and the asset carries bytes, upload them to S3
if self.s3:
# Main data
if asset.data:
ts = int(asset.created_at.timestamp())
object_name = f"{asset.type.value}/{ts}_{asset.name}"
uploaded = await self.s3.upload_file(object_name, asset.data)
if uploaded:
asset.minio_object_name = object_name
asset.minio_bucket = self.s3.bucket_name
asset.data = None # Clear data
else:
logger.error(f"Failed to upload asset {asset.name} to MinIO")
# Thumbnail
if asset.thumbnail:
ts = int(asset.created_at.timestamp())
thumb_name = f"{asset.type.value}/thumbs/{ts}_{asset.name}_thumb.jpg"
uploaded_thumb = await self.s3.upload_file(thumb_name, asset.thumbnail)
if uploaded_thumb:
asset.minio_thumbnail_object_name = thumb_name
asset.minio_bucket = self.s3.bucket_name # Assumes same bucket
asset.thumbnail = None # Clear thumbnail data
else:
logger.error(f"Failed to upload thumbnail for {asset.name} to MinIO")
res = await self.collection.insert_one(asset.model_dump())
return str(res.inserted_id)
async def get_assets(self, asset_type: Optional[str] = None, limit: int = 10, offset: int = 0, with_data: bool = False, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Asset]:
filter = {}
if asset_type:
filter["type"] = asset_type
args = {}
if not with_data:
args["data"] = 0
# Thumbnails are excluded from list responses by default (preserving the
# original behavior): fetching each one from MinIO would cost an extra S3
# round-trip per asset. Callers that need bytes should pass with_data=True
# or fetch the asset individually.
args["thumbnail"] = 0
if created_by:
filter["created_by"] = created_by
filter['project_id'] = None
if project_id:
filter["project_id"] = project_id
res = await self.collection.find(filter, args).sort("created_at", -1).skip(offset).limit(limit).to_list(None)
assets = []
for doc in res:
doc["id"] = str(doc.pop("_id"))
asset = Asset(**doc)
if with_data and self.s3:
# Fetch data
if asset.minio_object_name:
data = await self.s3.get_file(asset.minio_object_name)
if data: asset.data = data
# Fetch thumbnail
if asset.minio_thumbnail_object_name:
thumb = await self.s3.get_file(asset.minio_thumbnail_object_name)
if thumb: asset.thumbnail = thumb
assets.append(asset)
return assets
async def get_asset(self, asset_id: str, with_data: bool = True) -> Optional[Asset]:
projection = None
if not with_data:
projection = {"data": 0, "thumbnail": 0}
res = await self.collection.find_one({"_id": ObjectId(asset_id)}, projection)
if not res:
return None
res["id"] = str(res.pop("_id"))
asset = Asset(**res)
if with_data and self.s3:
if asset.minio_object_name:
data = await self.s3.get_file(asset.minio_object_name)
if data: asset.data = data
if asset.minio_thumbnail_object_name:
thumb = await self.s3.get_file(asset.minio_thumbnail_object_name)
if thumb: asset.thumbnail = thumb
return asset
async def update_asset(self, asset_id: str, asset: Asset):
if not asset.id:
if asset_id: asset.id = asset_id
else: raise Exception(f"Asset ID not found: {asset_id}")
# If new data/thumbnail bytes are provided and S3 is configured, upload them
# first (mirroring create_asset), then persist the document.
if self.s3:
if asset.data:
ts = int(asset.created_at.timestamp())
object_name = f"{asset.type.value}/{ts}_{asset.name}"
if await self.s3.upload_file(object_name, asset.data):
asset.minio_object_name = object_name
asset.minio_bucket = self.s3.bucket_name
asset.data = None
if asset.thumbnail:
ts = int(asset.created_at.timestamp())
thumb_name = f"{asset.type.value}/thumbs/{ts}_{asset.name}_thumb.jpg"
if await self.s3.upload_file(thumb_name, asset.thumbnail):
asset.minio_thumbnail_object_name = thumb_name
asset.thumbnail = None
model_dump = asset.model_dump()
await self.collection.update_one({"_id": ObjectId(asset_id)}, {"$set": model_dump})
async def set_tg_photo_file_id(self, asset_id: str, tg_photo_file_id: str):
await self.collection.update_one({"_id": ObjectId(asset_id)}, {"$set": {"tg_photo_file_id": tg_photo_file_id}})
async def get_assets_by_char_id(self, character_id: str, limit: int = 10, offset: int = 0) -> List[Asset]:
docs = await self.collection.find({"linked_char_id": character_id},
{"data": 0}, sort=[("created_at", -1)]).limit(limit).skip(offset).to_list(
None)
assets = []
for doc in docs:
doc["id"] = str(doc.pop("_id"))
assets.append(Asset(**doc))
return assets
async def get_asset_count(self, character_id: Optional[str] = None, created_by: Optional[str] = None, project_id: Optional[str] = None) -> int:
filter = {}
if character_id:
filter["linked_char_id"] = character_id
if created_by:
filter["created_by"] = created_by
if project_id:
filter["project_id"] = project_id
return await self.collection.count_documents(filter)
async def get_assets_by_ids(self, asset_ids: List[str]) -> List[Asset]:
object_ids = [ObjectId(asset_id) for asset_id in asset_ids]
res = self.collection.find({"_id": {"$in": object_ids}}, {"data": 0, "thumbnail": 0}) # exclude heavy byte fields
assets = []
async for doc in res:
doc["id"] = str(doc.pop("_id"))
assets.append(Asset(**doc))
return assets
async def delete_asset(self, asset_id: str) -> bool:
asset_doc = await self.collection.find_one({"_id": ObjectId(asset_id)})
if not asset_doc:
return False
if self.s3:
if asset_doc.get("minio_object_name"):
await self.s3.delete_file(asset_doc["minio_object_name"])
if asset_doc.get("minio_thumbnail_object_name"):
await self.s3.delete_file(asset_doc["minio_thumbnail_object_name"])
res = await self.collection.delete_one({"_id": ObjectId(asset_id)})
return res.deleted_count > 0
async def migrate_to_minio(self) -> dict:
"""Migrates asset data and thumbnails from Mongo to MinIO."""
if not self.s3:
return {"error": "MinIO adapter not initialized"}
# 1. Migrate Data
cursor_data = self.collection.find({"data": {"$ne": None}, "minio_object_name": {"$eq": None}})
count_data = 0
errors_data = 0
async for doc in cursor_data:
try:
asset_id = doc["_id"]
data = doc.get("data")
name = doc.get("name", "unknown")
type_ = doc.get("type", "image")
created_at = doc.get("created_at")
ts = int(created_at.timestamp()) if created_at else 0
object_name = f"{type_}/{ts}_{asset_id}_{name}"
if await self.s3.upload_file(object_name, data):
await self.collection.update_one(
{"_id": asset_id},
{"$set": {"minio_object_name": object_name, "minio_bucket": self.s3.bucket_name, "data": None}}
)
count_data += 1
else:
errors_data += 1
except Exception as e:
logger.error(f"Data migration error for {doc.get('_id')}: {e}")
errors_data += 1
# 2. Migrate Thumbnails
cursor_thumb = self.collection.find({"thumbnail": {"$ne": None}, "minio_thumbnail_object_name": {"$eq": None}})
count_thumb = 0
errors_thumb = 0
async for doc in cursor_thumb:
try:
asset_id = doc["_id"]
thumb = doc.get("thumbnail")
name = doc.get("name", "unknown")
type_ = doc.get("type", "image")
created_at = doc.get("created_at")
ts = int(created_at.timestamp()) if created_at else 0
thumb_name = f"{type_}/thumbs/{ts}_{asset_id}_{name}_thumb.jpg"
if await self.s3.upload_file(thumb_name, thumb):
await self.collection.update_one(
{"_id": asset_id},
{"$set": {"minio_thumbnail_object_name": thumb_name, "minio_bucket": self.s3.bucket_name, "thumbnail": None}}
)
count_thumb += 1
else:
errors_thumb += 1
except Exception as e:
logger.error(f"Thumbnail migration error for {doc.get('_id')}: {e}")
errors_thumb += 1
return {
"migrated_data": count_data,
"errors_data": errors_data,
"migrated_thumbnails": count_thumb,
"errors_thumbnails": errors_thumb
}
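`migrate_to_minio` selects only documents that still hold inline bytes and have no S3 object name, which makes the migration idempotent and safe to re-run after a partial failure. The selection logic, restated as a plain predicate:

```python
def needs_data_migration(doc: dict) -> bool:
    """Mirror of the Mongo filter
    {"data": {"$ne": None}, "minio_object_name": {"$eq": None}}:
    only documents with inline bytes and no S3 object are picked up,
    so re-running the migration skips everything already moved."""
    return doc.get("data") is not None and doc.get("minio_object_name") is None

docs = [
    {"_id": 1, "data": b"...", "minio_object_name": None},          # to migrate
    {"_id": 2, "data": None, "minio_object_name": "image/2.png"},   # already migrated
    {"_id": 3, "data": b"...", "minio_object_name": "image/3.png"}, # object set, skipped
]
print([d["_id"] for d in docs if needs_data_migration(d)])  # [1]
```

The same shape is used for thumbnails with the `thumbnail` / `minio_thumbnail_object_name` pair.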


@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Optional
 from bson import ObjectId
 from motor.motor_asyncio import AsyncIOMotorClient
@@ -12,30 +12,39 @@ class CharacterRepo:
     async def add_character(self, character: Character) -> Character:
         op = await self.collection.insert_one(character.model_dump())
-        character.id = op.inserted_id
+        character.id = str(op.inserted_id)
         return character
-    async def get_character(self, character_id: str) -> Character | None:
-        res = await self.collection.find_one({"_id": ObjectId(character_id)})
+    async def get_character(self, character_id: str, with_image_data: bool = False) -> Character | None:
+        args = {}
+        if not with_image_data:
+            args["character_image_data"] = 0
+        res = await self.collection.find_one({"_id": ObjectId(character_id)}, args)
         if res is None:
             return None
         else:
             res["id"] = str(res.pop("_id"))
             return Character(**res)
-    async def get_all_characters(self) -> List[Character]:
-        docs = await self.collection.find().to_list(None)
-        characters = []
-        for doc in docs:
-            # Convert the ObjectId to a string and store it in the id field
-            doc["id"] = str(doc.pop("_id"))
-            # Build the object
-            characters.append(Character(**doc))
-        return characters
-    async def update_char(self, char_id: str, character: Character) -> None:
-        await self.collection.update_one({"_id": ObjectId(char_id)}, {"$set": character.model_dump()})
+    async def get_all_characters(self, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Character]:
+        filter = {}
+        if created_by:
+            filter["created_by"] = created_by
+        if project_id:
+            filter["project_id"] = project_id
+        args = {"character_image_data": 0}  # don't return image data for the list
+        res = await self.collection.find(filter, args).to_list(None)
+        chars = []
+        for doc in res:
+            doc["id"] = str(doc.pop("_id"))
+            chars.append(Character(**doc))
+        return chars
+    async def update_char(self, char_id: str, character: Character) -> bool:
+        result = await self.collection.update_one({"_id": ObjectId(char_id)}, {"$set": character.model_dump()})
+        return result.modified_count > 0
+    async def delete_character(self, char_id: str) -> bool:
+        result = await self.collection.delete_one({"_id": ObjectId(char_id)})
+        return result.deleted_count > 0


@@ -1,9 +1,21 @@
 from motor.motor_asyncio import AsyncIOMotorClient
+from repos.assets_repo import AssetsRepo
 from repos.char_repo import CharacterRepo
+from repos.generation_repo import GenerationRepo
 from repos.user_repo import UsersRepo
+from repos.albums_repo import AlbumsRepo
+from repos.project_repo import ProjectRepo
+from typing import Optional
+from adapters.s3_adapter import S3Adapter
 class DAO:
-    def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
+    def __init__(self, client: AsyncIOMotorClient, s3_adapter: Optional[S3Adapter] = None, db_name="bot_db"):
         self.chars = CharacterRepo(client, db_name)
+        self.assets = AssetsRepo(client, s3_adapter, db_name)
+        self.generations = GenerationRepo(client, db_name)
+        self.albums = AlbumsRepo(client, db_name)
+        self.projects = ProjectRepo(client, db_name)
+        self.users = UsersRepo(client, db_name)

repos/generation_repo.py Normal file

@@ -0,0 +1,79 @@
from typing import Optional, List
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient
from api.models.GenerationRequest import GenerationResponse
from models.Generation import Generation, GenerationStatus
class GenerationRepo:
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
self.collection = client[db_name]["generations"]
async def create_generation(self, generation: Generation) -> str:
res = await self.collection.insert_one(generation.model_dump())
return str(res.inserted_id)
async def get_generation(self, generation_id: str) -> Optional[Generation]:
res = await self.collection.find_one({"_id": ObjectId(generation_id)})
if res is None:
return None
else:
res["id"] = str(res.pop("_id"))
return Generation(**res)
async def get_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
limit: int = 10, offset: int = 10, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
filter = {"is_deleted": False}
if character_id is not None:
filter["linked_character_id"] = character_id
if status is not None:
filter["status"] = status
if created_by is not None:
filter["created_by"] = created_by
filter["project_id"] = None
if project_id is not None:
filter["project_id"] = project_id
res = await self.collection.find(filter).sort("created_at", -1).skip(
offset).limit(limit).to_list(None)
generations: List[Generation] = []
for generation in res:
generation["id"] = str(generation.pop("_id"))
generations.append(Generation(**generation))
return generations
async def count_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
album_id: Optional[str] = None, created_by: Optional[str] = None, project_id: Optional[str] = None) -> int:
args = {}
if character_id is not None:
args["linked_character_id"] = character_id
if status is not None:
args["status"] = status
if created_by is not None:
args["created_by"] = created_by
if project_id is not None:
args["project_id"] = project_id
return await self.collection.count_documents(args)
async def get_generations_by_ids(self, generation_ids: List[str]) -> List[Generation]:
object_ids = [ObjectId(gen_id) for gen_id in generation_ids if ObjectId.is_valid(gen_id)]
res = await self.collection.find({"_id": {"$in": object_ids}}).to_list(None)
generations: List[Generation] = []
# Maintain order of generation_ids
gen_map = {str(doc["_id"]): doc for doc in res}
for gen_id in generation_ids:
doc = gen_map.get(gen_id)
if doc:
doc["id"] = str(doc.pop("_id"))
generations.append(Generation(**doc))
return generations
async def update_generation(self, generation: Generation, ):
res = await self.collection.update_one({"_id": ObjectId(generation.id)}, {"$set": generation.model_dump()})
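The `get_generations_by_ids` method preserves the caller's ID order by building an intermediate map before iterating the requested IDs. The pattern can be sketched in isolation (the function and names below are illustrative, not part of the repo):

```python
from typing import Dict, List


def order_preserving_lookup(ids: List[str], docs_by_id: Dict[str, dict]) -> List[dict]:
    """Return docs in the same order as `ids`, silently skipping unknown IDs.

    Mirrors the map-then-iterate pattern in get_generations_by_ids: one pass
    builds the lookup map, a second pass restores the request order.
    """
    results = []
    for doc_id in ids:
        doc = docs_by_id.get(doc_id)
        if doc:
            results.append(doc)
    return results
```

This avoids relying on MongoDB's `$in`, which returns documents in arbitrary order.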

62
repos/project_repo.py Normal file
View File

@@ -0,0 +1,62 @@
from typing import List, Optional

from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient

from models.Project import Project


class ProjectRepo:
    def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
        self.collection = client[db_name]["projects"]

    async def create_project(self, project: Project) -> str:
        res = await self.collection.insert_one(project.model_dump())
        return str(res.inserted_id)

    async def get_project(self, project_id: str) -> Optional[Project]:
        if not ObjectId.is_valid(project_id):
            return None
        res = await self.collection.find_one({"_id": ObjectId(project_id)})
        if res:
            res["id"] = str(res.pop("_id"))
            return Project(**res)
        return None

    async def get_projects_by_user(self, user_id: str) -> List[Project]:
        # Find projects where the user is the owner OR listed in members
        filter = {
            "$or": [
                {"owner_id": user_id},
                {"members": user_id}
            ],
            "is_deleted": False
        }
        cursor = self.collection.find(filter).sort("created_at", -1)
        projects = []
        async for doc in cursor:
            doc["id"] = str(doc.pop("_id"))
            projects.append(Project(**doc))
        return projects

    async def add_member(self, project_id: str, user_id: str) -> bool:
        res = await self.collection.update_one(
            {"_id": ObjectId(project_id)},
            {"$addToSet": {"members": user_id}}
        )
        return res.modified_count > 0

    async def remove_member(self, project_id: str, user_id: str) -> bool:
        res = await self.collection.update_one(
            {"_id": ObjectId(project_id)},
            {"$pull": {"members": user_id}}
        )
        return res.modified_count > 0

    async def update_project(self, project_id: str, updates: dict) -> bool:
        res = await self.collection.update_one(
            {"_id": ObjectId(project_id)},
            {"$set": updates}
        )
        return res.modified_count > 0

    async def delete_project(self, project_id: str) -> bool:
        res = await self.collection.update_one({"_id": ObjectId(project_id)}, {"$set": {"is_deleted": True}})
        return res.modified_count > 0

View File

@@ -1,8 +1,10 @@
from datetime import datetime, timedelta
from enum import Enum
from typing import Optional

from aiogram.types import User
from motor.motor_asyncio import AsyncIOMotorClient

from utils.security import get_password_hash


class UserStatus:
@@ -17,12 +19,65 @@ class UsersRepo:
        self.collection = client[db_name]["users"]

    async def get_user(self, user_id: int):
        user = await self.collection.find_one({"user_id": user_id})
        if user:
            user["id"] = str(user["_id"])
        return user

    async def get_user_by_username(self, username: str):
        user = await self.collection.find_one({"username": username})
        if user:
            user["id"] = str(user["_id"])
        return user

    async def create_user(self, username: str, password: str, full_name: Optional[str] = None):
        """Creates a new user with a username and password."""
        existing = await self.get_user_by_username(username)
        if existing:
            raise ValueError("User with this username already exists")
        user_doc = {
            "username": username,
            "hashed_password": get_password_hash(password),
            "full_name": full_name,
            "status": UserStatus.PENDING,  # PENDING by default
            "created_at": datetime.now(),
            "is_email_user": False,  # Just a regular (non-Telegram) user now; the field could be renamed
            "is_web_user": True,
            "is_admin": False,
            "project_ids": [],
            "current_project_id": None
        }
        result = await self.collection.insert_one(user_doc)
        user = await self.collection.find_one({"_id": result.inserted_id})
        if user:
            user["id"] = str(user["_id"])
        return user

    async def get_pending_users(self):
        """Returns the list of users with PENDING status."""
        cursor = self.collection.find({"status": UserStatus.PENDING})
        users = await cursor.to_list(length=100)
        for user in users:
            user["id"] = str(user["_id"])
        return users

    async def approve_user(self, username: str):
        await self.collection.update_one(
            {"username": username},
            {"$set": {"status": UserStatus.ALLOWED}}
        )

    async def deny_user(self, username: str):
        await self.collection.update_one(
            {"username": username},
            {"$set": {"status": UserStatus.DENIED}}
        )

    async def create_or_update_request(self, user: User):
        """
        Updates the last-request date and sets the status to PENDING.
        Stores all the user's info (for Telegram users).
        """
        now = datetime.now()
        data = {
@@ -30,7 +85,8 @@ class UsersRepo:
            "username": user.username,
            "full_name": user.full_name,
            "status": UserStatus.PENDING,
            "last_request_date": now,
            "is_email_user": False
        }
        await self.collection.update_one(
            {"user_id": user.id},
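`get_password_hash` comes from `utils.security`, which is not shown in this diff; since the requirements pin `passlib[argon2]`, it is presumably a passlib wrapper. A stdlib-only sketch of the same hash-and-verify idea follows (PBKDF2 here purely for illustration; the function names and format are assumptions, not the project's actual scheme):

```python
import hashlib
import hmac
import os


def get_password_hash(password: str, *, iterations: int = 100_000) -> str:
    """Hash a password with a fresh random salt; returns 'salthex$digesthex'."""
    salt = os.urandom(16)
    digest = hashlib.pbkdf2_hmac("sha256", password.encode(), salt, iterations)
    return f"{salt.hex()}${digest.hex()}"


def verify_password(password: str, hashed: str, *, iterations: int = 100_000) -> bool:
    """Re-derive the digest with the stored salt and compare in constant time."""
    salt_hex, digest_hex = hashed.split("$")
    digest = hashlib.pbkdf2_hmac("sha256", password.encode(), bytes.fromhex(salt_hex), iterations)
    return hmac.compare_digest(digest.hex(), digest_hex)
```

The key property either way: the plaintext password is never stored, and login re-derives the hash from the stored salt rather than decrypting anything.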

View File

@@ -3,15 +3,18 @@ aiogram==3.24.0
aiohappyeyeballs==2.6.1
aiohttp==3.11.18
aiosignal==1.4.0
annotated-doc==0.0.4
annotated-types==0.7.0
anyio==4.12.1
attrs==25.4.0
certifi==2026.1.4
cffi==2.0.0
charset-normalizer==3.4.4
click==8.3.1
cryptography==46.0.4
distro==1.9.0
dnspython==2.8.0
fastapi==0.128.0
frozenlist==1.8.0
google-auth==2.48.0
google-genai==1.61.0
@@ -31,11 +34,21 @@ pydantic==2.10.6
pydantic_core==2.27.2
pymongo==4.16.0
python-dotenv==1.2.1
python-multipart==0.0.22
requests==2.32.5
rsa==4.9.1
sniffio==1.3.1
starlette==0.50.0
tenacity==9.1.2
typing_extensions==4.15.0
urllib3==2.6.3
uvicorn==0.40.0
websockets==15.0.1
yarl==1.22.0
aioboto3==13.3.0
passlib[argon2]==1.7.4
python-jose[cryptography]==3.3.0
email-validator
prometheus-fastapi-instrumentator
PyJWT

46
routers/assets_router.py Normal file
View File

@@ -0,0 +1,46 @@
from aiogram import Router, F
from aiogram.filters import Command
from aiogram.types import Message, InputMediaPhoto, InputMedia, BufferedInputFile, InlineKeyboardButton, CallbackQuery, \
    InlineKeyboardMarkup

from repos.dao import DAO

router = Router()


@router.message(Command("assets"))
async def assets_command(msg: Message, dao: DAO):
    assets = await dao.assets.get_assets(limit=10, offset=0)
    media_group = []
    keyboard = []
    for index, asset in enumerate(assets):
        if asset.tg_photo_file_id:
            media_group.append(InputMediaPhoto(media=asset.tg_photo_file_id))
        elif asset.tg_doc_file_id:
            asset_full_info = await dao.assets.get_asset(asset.id)
            asset = asset_full_info
            media_group.append(InputMediaPhoto(media=BufferedInputFile(asset_full_info.data, asset_full_info.name)))
        else:
            continue
        keyboard.append(InlineKeyboardButton(text=f"{index + 1}", callback_data=f"asset_doc_{asset.id}"))
    mg = await msg.answer_media_group(media_group,
                                      reply_markup=InlineKeyboardMarkup(inline_keyboard=[keyboard]))
    await msg.answer("Для запроса документов выбери фото и на кнопку ниже:",
                     reply_markup=InlineKeyboardMarkup(inline_keyboard=[keyboard]))
    for media_index, media in enumerate(mg):
        if assets[media_index].tg_photo_file_id is None:
            await dao.assets.set_tg_photo_file_id(assets[media_index].id, media.photo[-1].file_id)


@router.callback_query(F.data.startswith("asset_doc_"))
async def assets_callback_query(call: CallbackQuery, dao: DAO):
    await call.answer()
    assets_id = call.data.split("asset_doc_")[-1]
    asset = await dao.assets.get_asset(assets_id, with_data=False)
    if asset.tg_doc_file_id:
        await call.message.answer_document(asset.tg_doc_file_id)
    else:
        asset_full_info = await dao.assets.get_asset(assets_id)
        doc = await call.message.answer_document(BufferedInputFile(asset_full_info.data, asset_full_info.name))
        asset_full_info.tg_doc_file_id = doc.document.file_id
        await dao.assets.update_asset(assets_id, asset_full_info)

View File

@@ -1,3 +1,4 @@
import io
import logging
import traceback
@@ -7,6 +8,7 @@ from aiogram.fsm.state import State, StatesGroup
from aiogram.types import *
from aiogram import Router, F, Bot

from models.Asset import Asset, AssetType, AssetContentType
from models.Character import Character
from repos.dao import DAO
@@ -50,35 +52,46 @@ async def new_char_bio(message: Message, state: FSMContext, dao: DAO, bot: Bot):
    try:
        # Download the file right here (just before saving)
        file_io = await bot.download(file_id)
        # photo_bytes = file_io.getvalue()  # Get the bytes

        # Create the model
        char = Character(
            id=None,
            name=name,
            character_image_data=file_io.read(),
            character_image_tg_id=None,
            character_image_doc_tg_id=file_id,
            character_bio=bio,
            created_by=str(message.from_user.id)
        )
        file_io.close()
        # Save via the DAO
        await dao.chars.add_character(char)
        file_info = await bot.get_file(char.character_image_doc_tg_id)
        file_bytes = await bot.download_file(file_info.file_path)
        file_io = file_bytes.read()
        avatar_asset = await dao.assets.create_asset(
            Asset(name="avatar.png", type=AssetType.UPLOADED, content_type=AssetContentType.IMAGE,
                  linked_char_id=str(char.id), data=file_io,
                  tg_doc_file_id=file_id))
        char.avatar_image = avatar_asset.link
        # Send a confirmation
        # Use the bytes to send it back
        photo_msg = await message.answer_photo(
            photo=BufferedInputFile(file_io,
                                    filename="char.jpg") if not char.character_image_tg_id else char.character_image_tg_id,
            caption=(
                "🎉 <b>Персонаж создан!</b>\n\n"
                f"👤 <b>Имя:</b> {char.name}\n"
                f"📝 <b>Био:</b> {char.character_bio}"
            )
        )
        file_bytes.close()
        char.character_image_tg_id = photo_msg.photo[0].file_id
        await dao.chars.update_char(char.id, char)
        await wait_msg.delete()

View File

@@ -14,6 +14,7 @@ from aiogram.types import *
import keyboards
from adapters.google_adapter import GoogleAdapter
from models.Asset import Asset, AssetType, AssetContentType
from models.Character import Character
from models.enums import AspectRatios, Quality, GenType
from repos.dao import DAO
@@ -34,6 +35,25 @@ async def init_gen_mode(state: FSMContext, dao: DAO):
    await state.update_data(data)


@router.message(Command("image"))
async def generate_image_cmd(message: Message, state: FSMContext, dao: DAO, gemini: GoogleAdapter, bot: Bot):
    wait_msg = await message.answer("Генерирую...")
    if message.photo:
        res = await generate_image(prompt=message.caption, media=[message.photo[0].file_id], state=state, dao=dao,
                                   bot=bot,
                                   gemini=gemini)
    else:
        res = await generate_image(prompt=message.text, media=None, state=state, dao=dao, bot=bot,
                                   gemini=gemini)
    await wait_msg.delete()
    doc = await message.answer_document(res[0], caption="Generated result 💫")
    await dao.assets.create_asset(Asset(id=None, name=res[0].filename, type=AssetType.GENERATED,
                                        content_type=AssetContentType.IMAGE, data=res[0].data,
                                        tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None,
                                        linked_char_id=None, created_by=str(message.from_user.id)))


@router.message(Command("gen_mode"))
async def gen_mode(message: Message, state: FSMContext, dao: DAO):
    state_on = await state.get_state()
@@ -142,12 +162,13 @@ async def change_quality(call: CallbackQuery, state: FSMContext, dao: DAO):
@router.callback_query(States.gen_mode, F.data == 'gen_mode_change_type')
async def gen_mode_change_type(call: CallbackQuery, state: FSMContext, dao: DAO):
    await call.answer()
    data = await state.get_data()
    if GenType[data['type']] is GenType.IMAGE:
        await state.update_data({"type": GenType.TEXT.name})
    else:
        await state.update_data({"type": GenType.IMAGE.name})
    await gen_mode_base_msg(call.message, state=state, dao=dao)
@@ -178,11 +199,15 @@ — @router.callback_query(States.gen_mode, F.data.startswith('select_type_'))
@@ -178,11 +199,15 @@ async def gen_mode_base_msg(message: Message, state: FSMContext, dao: DAO, call_
    else:
        try:
            await message.edit_caption(
                caption="🎉 Режим генерации включен! Просто пиши мне промпт и я отправлю в генерацию по указанным настройкам.\n\n"
                        "<b>Фото девушки грузить не надо, оно загрузится по дефолту</b>\n\n"
                        "Но дополнительные фото можно загрузить.",
                reply_markup=await keyboards.get_gen_mode_kb(state=state, dao=dao))
        except TelegramBadRequest as tbr:
            await message.edit_text(
                text="🎉 Режим генерации включен! Просто пиши мне промпт и я отправлю в генерацию по указанным настройкам.\n\n"
                     "<b>Фото девушки грузить не надо, оно загрузится по дефолту</b>\n\n"
                     "Но дополнительные фото можно загрузить.",
                reply_markup=await keyboards.get_gen_mode_kb(state=state, dao=dao))
@@ -211,9 +236,6 @@ async def handle_album(
    for msg in album:
        if msg.photo:
            file_ids.append(msg.photo[-1].file_id)

    await message.answer(f"📥 Принято {len(album)} файлов. Начинаю генерацию...")
    wait_msg = await message.answer("🎨 Генерирую...")
@@ -230,14 +252,18 @@ async def handle_album(
        )
        await wait_msg.delete()
        data = await state.get_data()
        # 4. Send the result
        if generated_files:
            for file in generated_files:
                doc = await message.answer_document(file, caption="✨ Generated result")
                await dao.assets.create_asset(Asset(id=None, name=file.filename, type=AssetType.GENERATED,
                                                    content_type=AssetContentType.IMAGE, data=file.data,
                                                    tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None,
                                                    linked_char_id=data["char_id"],
                                                    created_by=str(message.from_user.id)))
        else:
            await message.answer("❌ Генерация не вернула изображений.")
        await gen_mode_base_msg(message=message, state=state, dao=dao, call_type="start")
    except Exception as e:
        await wait_msg.edit_text(f"❌ Ошибка: {e}")
@@ -270,7 +296,6 @@ async def gen_mode_start(
        media_ids.append(message.reply_to_message.photo[-1].file_id)
    wait_msg = await message.answer("🎨 Генерирую...")
    data = await state.get_data()
    try:
        if GenType[data['type']] is GenType.IMAGE:
@@ -287,7 +312,11 @@ async def gen_mode_start(
        if generated_files:
            for file in generated_files:
                doc = await message.answer_document(file, caption="✨ Generated result")
                await dao.assets.create_asset(Asset(id=None, name=file.filename, type=AssetType.GENERATED,
                                                    content_type=AssetContentType.IMAGE, data=file.data,
                                                    tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None,
                                                    linked_char_id=data["char_id"],
                                                    created_by=str(message.from_user.id)))
        else:
            await message.answer("❌ Ничего не сгенерировалось.")
@@ -309,17 +338,27 @@ async def generate_image(
        bot: Bot,
        gemini: GoogleAdapter
) -> List[BufferedInputFile]:
    state_value = await state.get_state()
    media_group_bytes = []
    ar = AspectRatios.NINESIXTEEN
    quality = Quality.TWOK
    if state_value is not None:
        data = await state.get_data()
        ar = AspectRatios[data['aspect_ratio']]
        quality = Quality[data['quality']]
        char_id = data.get("char_id")
        if not char_id:
            raise ValueError("Character ID not found in state")
        char: Character = await dao.chars.get_character(char_id)
        # Start the list with the character's photo
        file_bytes = await bot.download(char.character_image_doc_tg_id)
        media_group_bytes.append(file_bytes.read())
        file_bytes.close()
    if media:
        # Download the files
        # tasks return a list of BytesIO objects
@@ -342,11 +381,10 @@ async def generate_image(
        gemini.generate_image,
        prompt=prompt,
        images_list=media_group_bytes,
        aspect_ratio=ar,
        quality=quality,
    )
    images = []
    if generated_images_io:
        for img_io in generated_images_io:
@@ -359,7 +397,7 @@ async def generate_image(
            images.append(
                BufferedInputFile(
                    content,
                    filename=f"img_{random.randint(1000, 99999)}.png"
                )
            )
@@ -375,11 +413,12 @@ async def handle_text(message: Message, gemini: GoogleAdapter, state: FSMContext
async def gen_start_text(message: Message, gemini: GoogleAdapter, state: FSMContext, dao: DAO, bot: Bot,
                         char_id: str = None) -> str:
    await bot.send_chat_action(message.chat.id, "typing")
    prompt = "Use a TELEGRAM HTML formatting. If you write a prompt use <pre> tag.\n\n"
    prompt += f"PROMPT:\n{message.text}\n\n"
    if char_id:
        char = await dao.chars.get_character(char_id)
        prompt += char.character_bio
    result = await asyncio.to_thread(gemini.generate_text, prompt=prompt)
    if result:
        return result
    else:

View File

@@ -0,0 +1,22 @@
import pytest
from fastapi.testclient import TestClient

from main import app

client = TestClient(app)


def test_api_protection():
    # 1. Assets
    response = client.get("/api/assets")
    assert response.status_code == 401
    # 2. Characters
    response = client.get("/api/characters")
    assert response.status_code == 401
    # 3. Generations
    response = client.get("/api/generations")
    assert response.status_code == 401
    # 4. Upload Asset (POST)
    response = client.post("/api/assets/upload")
    assert response.status_code == 401

107
tests/test_auth_flow.py Normal file
View File

@@ -0,0 +1,107 @@
import pytest
from fastapi.testclient import TestClient
from unittest.mock import AsyncMock, MagicMock
from datetime import datetime

from main import app
from api.endpoints.auth import get_users_repo
from repos.user_repo import UsersRepo, UserStatus
from utils.security import get_password_hash


# Mock Repository
class MockUsersRepo:
    def __init__(self):
        self.users = {}

    async def get_user_by_username(self, username: str):
        return self.users.get(username)

    async def create_user(self, username: str, password: str, full_name: str = None):
        if username in self.users:
            raise ValueError("User with this username already exists")
        user = {
            "username": username,
            "hashed_password": get_password_hash(password),
            "full_name": full_name,
            "status": UserStatus.PENDING,
            "is_email_user": False,
            "is_admin": False,
            "created_at": datetime.now()
        }
        self.users[username] = user
        return user

    async def get_pending_users(self):
        return [u for u in self.users.values() if u["status"] == UserStatus.PENDING]

    async def approve_user(self, username: str):
        if username in self.users:
            self.users[username]["status"] = UserStatus.ALLOWED

    async def deny_user(self, username: str):
        if username in self.users:
            self.users[username]["status"] = UserStatus.DENIED


mock_repo = MockUsersRepo()

# Override Dependency
app.dependency_overrides[get_users_repo] = lambda: mock_repo
client = TestClient(app)


def test_auth_flow_with_approval():
    # 1. Register (User)
    user_data = {
        "username": "newuser",
        "password": "password123",
        "full_name": "New User"
    }
    response = client.post("/auth/register", json=user_data)
    assert response.status_code == 200
    assert response.json()["message"] == "Registration successful. Please wait for administrator approval."

    # 2. Try Login (User) -> Should Fail (Pending)
    login_data = {
        "username": "newuser",
        "password": "password123"
    }
    response = client.post("/auth/token", data=login_data)
    assert response.status_code == 403
    assert "not approved" in response.json()["detail"]

    # 3. Setup Admin (Backdoor for test)
    mock_repo.users["admin"] = {
        "username": "admin",
        "hashed_password": get_password_hash("adminpass"),
        "status": UserStatus.ALLOWED,
        "is_admin": True,
        "created_at": datetime.now()
    }

    # 4. Admin Login
    admin_login = {
        "username": "admin",
        "password": "adminpass"
    }
    response = client.post("/auth/token", data=admin_login)
    assert response.status_code == 200
    admin_token = response.json()["access_token"]
    admin_auth = {"Authorization": f"Bearer {admin_token}"}

    # 5. List Pending (Admin)
    response = client.get("/admin/approvals", headers=admin_auth)
    assert response.status_code == 200
    users = response.json()
    assert len(users) >= 1
    assert users[0]["username"] == "newuser"
    assert users[0]["status"] == "pending"

    # 6. Approve User (Admin)
    response = client.post("/admin/approve/newuser", headers=admin_auth)
    assert response.status_code == 200
    assert response.json()["message"] == "User newuser approved"

    # 7. Login User (Again) -> Should Succeed
    response = client.post("/auth/token", data=login_data)
    assert response.status_code == 200
    assert "access_token" in response.json()

View File

@@ -0,0 +1,101 @@
import pytest
from fastapi.testclient import TestClient
from unittest.mock import MagicMock
from motor.motor_asyncio import AsyncIOMotorClient
import os
import asyncio

from main import app
from api.endpoints.auth import get_current_user
from api.dependency import get_dao
from repos.dao import DAO
from models.Character import Character

# Config for test DB
MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://admin:super_secure_password@31.59.58.220:27017")
DB_NAME = "bot_db_test_chars"

# Mock User
MOCK_USER_ID = "507f1f77bcf86cd799439011"
MOCK_USER = {
    "_id": MOCK_USER_ID,
    "username": "testuser",
    "is_admin": False,
    "status": "allowed"
}


# Override get_current_user to bypass auth
def mock_get_current_user():
    return MOCK_USER


app.dependency_overrides[get_current_user] = mock_get_current_user

# Setup Real DAO with Test DB
client_mongo = AsyncIOMotorClient(MONGO_HOST)
dao = DAO(client_mongo, db_name=DB_NAME)


def mock_get_dao():
    return dao


app.dependency_overrides[get_dao] = mock_get_dao
client = TestClient(app)


@pytest.fixture(scope="module", autouse=True)
def setup_teardown():
    # Setup: Ensure clean state
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_until_complete(client_mongo[DB_NAME]["characters"].drop())
    yield
    # Teardown
    loop.run_until_complete(client_mongo[DB_NAME]["characters"].drop())
    loop.close()


def test_character_crud_flow():
    # 1. Create Character
    create_payload = {
        "name": "Test Character",
        "character_bio": "A bio for test character",
        "character_image_doc_tg_id": "file_123",
        "avatar_image": "http://example.com/avatar.jpg"
    }
    response = client.post("/api/characters/", json=create_payload)
    assert response.status_code == 200, response.text
    char_data = response.json()
    assert char_data["name"] == create_payload["name"]
    assert char_data["created_by"] == MOCK_USER_ID
    char_id = char_data["id"]
    assert char_id is not None

    # 2. Get Character
    response = client.get(f"/api/characters/{char_id}")
    assert response.status_code == 200
    assert response.json()["id"] == char_id

    # 3. Update Character
    update_payload = {
        "name": "Updated Name",
        "character_bio": "Updated bio"
    }
    response = client.put(f"/api/characters/{char_id}", json=update_payload)
    assert response.status_code == 200
    updated_data = response.json()
    assert updated_data["name"] == "Updated Name"
    assert updated_data["character_bio"] == "Updated bio"

    # Verify the update persisted
    response = client.get(f"/api/characters/{char_id}")
    assert response.json()["name"] == "Updated Name"

    # 4. Delete Character
    response = client.delete(f"/api/characters/{char_id}")
    assert response.status_code == 204

    # Verify deletion
    response = client.get(f"/api/characters/{char_id}")
    assert response.status_code == 404, "Deleted character should return 404"

View File

@@ -0,0 +1,64 @@
import os

import pytest
from fastapi.testclient import TestClient
from unittest.mock import MagicMock

# 1. Set Auth Bypass and Test Config
os.environ["DB_NAME"] = "bot_db_test_integration"
# We keep MONGO_HOST as is (it works in the verified script)

# 2. Import app AFTER setting env
from main import app
from api.endpoints.auth import get_current_user

# 3. Override Auth
MOCK_USER_ID = "507f1f77bcf86cd799439011"
MOCK_USER = {
    "_id": MOCK_USER_ID,
    "username": "testuser",
    "is_admin": False,
    "status": "allowed",
    "project_ids": []
}


def mock_get_current_user():
    return MOCK_USER


app.dependency_overrides[get_current_user] = mock_get_current_user
client = TestClient(app)


def test_character_crud_lifecycle():
    # 1. Create
    create_payload = {
        "name": "Integration Test Char",
        "character_bio": "Testing with real app structure",
        "character_image_doc_tg_id": "doc_123",
        "avatar_image": "http://example.com/img.jpg"
    }
    response = client.post("/api/characters/", json=create_payload)
    assert response.status_code == 200, response.text
    char_data = response.json()
    assert char_data["name"] == create_payload["name"]
    char_id = char_data["id"]

    # 2. Get
    response = client.get(f"/api/characters/{char_id}")
    assert response.status_code == 200
    assert response.json()["id"] == char_id

    # 3. Update
    update_payload = {"name": "Updated Int Name"}
    response = client.put(f"/api/characters/{char_id}", json=update_payload)
    assert response.status_code == 200
    assert response.json()["name"] == "Updated Int Name"

    # 4. Delete
    response = client.delete(f"/api/characters/{char_id}")
    assert response.status_code == 204

    # 5. Verify Delete
    response = client.get(f"/api/characters/{char_id}")
    assert response.status_code == 404

63
tests/test_external_import.py Executable file
View File

@@ -0,0 +1,63 @@
#!/usr/bin/env python3
"""
Test script for the external generation import API.
Demonstrates how to call the import endpoint with a proper HMAC signature.
"""
import hmac
import hashlib
import json
import requests
import base64
import os
from dotenv import load_dotenv

load_dotenv()

# Configuration
API_URL = "http://localhost:8090/api/generations/import"
SECRET = os.getenv("EXTERNAL_API_SECRET", "your_super_secret_key_change_this_in_production")

# Sample generation data
generation_data = {
    "prompt": "A beautiful sunset over mountains",
    "tech_prompt": "High quality landscape photography",
    "image_url": "https://picsum.photos/512/512",  # Sample image URL
    # OR use base64:
    # "image_data": "base64_encoded_image_string_here",
    "aspect_ratio": "9:16",
    "quality": "1k",
    "created_by": "external_user_123",
    "execution_time_seconds": 5.2,
    "token_usage": 1000,
    "input_token_usage": 200,
    "output_token_usage": 800
}

# Convert to JSON
body = json.dumps(generation_data).encode('utf-8')

# Compute the HMAC signature over the raw body
signature = hmac.new(
    SECRET.encode('utf-8'),
    body,
    hashlib.sha256
).hexdigest()

# Make the request
headers = {
    "Content-Type": "application/json",
    "X-Signature": signature
}

print(f"Sending request to {API_URL}")
print(f"Signature: {signature}")

try:
    response = requests.post(API_URL, data=body, headers=headers)
    print(f"\nStatus Code: {response.status_code}")
    print(f"Response: {json.dumps(response.json(), indent=2)}")
except Exception as e:
    print(f"Error: {e}")
    if hasattr(e, 'response'):
        print(f"Response text: {e.response.text}")


@@ -0,0 +1,44 @@
import asyncio
import os

from dotenv import load_dotenv

from adapters.s3_adapter import S3Adapter


async def test_s3():
    load_dotenv()
    endpoint = os.getenv("MINIO_ENDPOINT", "http://localhost:9000")
    access_key = os.getenv("MINIO_ACCESS_KEY")
    secret_key = os.getenv("MINIO_SECRET_KEY")
    bucket = os.getenv("MINIO_BUCKET")

    print(f"Connecting to {endpoint}, bucket: {bucket}")
    s3 = S3Adapter(endpoint, access_key, secret_key, bucket)

    test_filename = "test_connection.txt"
    test_data = b"Hello MinIO!"

    print("Uploading...")
    success = await s3.upload_file(test_filename, test_data)
    if success:
        print("Upload successful!")
    else:
        print("Upload failed!")
        return

    print("Downloading...")
    data = await s3.get_file(test_filename)
    if data == test_data:
        print("Download successful and data matches!")
    else:
        print(f"Download mismatch: {data}")

    print("Deleting...")
    deleted = await s3.delete_file(test_filename)
    if deleted:
        print("Delete successful!")
    else:
        print("Delete failed!")


if __name__ == "__main__":
    asyncio.run(test_s3())


@@ -0,0 +1,91 @@
import asyncio
import os
import sys

# Add project root to path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from motor.motor_asyncio import AsyncIOMotorClient

from repos.dao import DAO
from models.Album import Album
from models.Generation import Generation, GenerationStatus
from models.enums import AspectRatios, Quality

# Mock config
# Use the same host as aiws.py but a different DB
MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://admin:super_secure_password@31.59.58.220:27017")
DB_NAME = "bot_db_test_albums"


async def test_albums():
    print(f"🚀 Starting Album Manual Verification using {MONGO_HOST}...")
    # Must run inside an asyncio event loop (entered via asyncio.run below)
    client = AsyncIOMotorClient(MONGO_HOST)
    dao = DAO(client, db_name=DB_NAME)
    try:
        # 1. Clean up
        await client[DB_NAME]["albums"].drop()
        await client[DB_NAME]["generations"].drop()
        print("✅ Cleaned up test database")

        # 2. Create Album
        album = Album(name="Test Album", description="A test album")
        print("Creating album...")
        album_id = await dao.albums.create_album(album)
        print(f"✅ Created Album: {album_id}")

        # 3. Create Generations
        gen1 = Generation(prompt="Gen 1", aspect_ratio=AspectRatios.NINESIXTEEN, quality=Quality.ONEK)
        gen2 = Generation(prompt="Gen 2", aspect_ratio=AspectRatios.NINESIXTEEN, quality=Quality.ONEK)
        print("Creating generations...")
        gen1_id = await dao.generations.create_generation(gen1)
        gen2_id = await dao.generations.create_generation(gen2)
        print(f"✅ Created Generations: {gen1_id}, {gen2_id}")

        # 4. Add generations to album
        print("Adding generations to album...")
        await dao.albums.add_generation(album_id, gen1_id)
        await dao.albums.add_generation(album_id, gen2_id)
        print("✅ Added generations to album")

        # 5. Fetch album and check generation_ids
        album_fetched = await dao.albums.get_album(album_id)
        assert album_fetched is not None
        assert len(album_fetched.generation_ids) == 2
        assert gen1_id in album_fetched.generation_ids
        assert gen2_id in album_fetched.generation_ids
        print("✅ Verified generations in album")

        # 6. Fetch generations by IDs via GenerationRepo
        generations = await dao.generations.get_generations_by_ids([gen1_id, gen2_id])
        assert len(generations) == 2
        # Ensure ID type match (str vs ObjectId handling in repo)
        gen_ids_fetched = [g.id for g in generations]
        assert gen1_id in gen_ids_fetched
        assert gen2_id in gen_ids_fetched
        print("✅ Verified fetching generations by IDs")

        # 7. Remove generation
        print("Removing generation...")
        await dao.albums.remove_generation(album_id, gen1_id)
        album_fetched = await dao.albums.get_album(album_id)
        assert len(album_fetched.generation_ids) == 1
        assert album_fetched.generation_ids[0] == gen2_id
        print("✅ Verified removing generation from album")

        print("🎉 Album Verification SUCCESS")
    finally:
        # Close the Mongo client
        client.close()


if __name__ == "__main__":
    from dotenv import load_dotenv
    load_dotenv()
    try:
        asyncio.run(test_albums())
    except Exception as e:
        print(f"Error: {e}")


@@ -0,0 +1,84 @@
import asyncio
import os

from dotenv import load_dotenv
from motor.motor_asyncio import AsyncIOMotorClient

from models.Asset import Asset, AssetType
from repos.assets_repo import AssetsRepo
from adapters.s3_adapter import S3Adapter

# Load env to get credentials
load_dotenv()


async def test_integration():
    print("🚀 Starting integration test...")

    # 1. Set up dependencies
    mongo_uri = os.getenv("MONGO_HOST", "mongodb://localhost:27017")
    client = AsyncIOMotorClient(mongo_uri)
    db_name = os.getenv("DB_NAME", "bot_db_test")
    s3_adapter = S3Adapter(
        endpoint_url=os.getenv("MINIO_ENDPOINT", "http://localhost:9000"),
        aws_access_key_id=os.getenv("MINIO_ACCESS_KEY", "admin"),
        aws_secret_access_key=os.getenv("MINIO_SECRET_KEY", "SuperSecretPassword123!"),
        bucket_name=os.getenv("MINIO_BUCKET", "ai-char")
    )
    repo = AssetsRepo(client, s3_adapter, db_name=db_name)

    # 2. Create an asset with data and thumbnail
    print("📝 Creating asset...")
    dummy_data = b"image_data_bytes"
    dummy_thumb = b"thumbnail_bytes"
    asset = Asset(
        name="test_asset_with_thumb.png",
        type=AssetType.IMAGE,
        data=dummy_data,
        thumbnail=dummy_thumb
    )
    asset_id = await repo.create_asset(asset)
    print(f"✅ Asset created with ID: {asset_id}")

    # 3. Verify object names in Mongo
    print("🔍 Verifying Mongo metadata...")
    # Fetching through the repo is preferable to a raw Mongo query
    fetched_asset = await repo.get_asset(asset_id, with_data=False)
    if fetched_asset.minio_object_name:
        print(f"✅ minio_object_name set: {fetched_asset.minio_object_name}")
    else:
        print("❌ minio_object_name NOT set!")
    if fetched_asset.minio_thumbnail_object_name:
        print(f"✅ minio_thumbnail_object_name set: {fetched_asset.minio_thumbnail_object_name}")
    else:
        print("❌ minio_thumbnail_object_name NOT set!")

    # 4. Fetch data from S3 via the repo
    print("📥 Fetching data from MinIO...")
    full_asset = await repo.get_asset(asset_id, with_data=True)
    if full_asset.data == dummy_data:
        print("✅ Data matches!")
    else:
        print(f"❌ Data mismatch! Got: {full_asset.data}")
    if full_asset.thumbnail == dummy_thumb:
        print("✅ Thumbnail matches!")
    else:
        print(f"❌ Thumbnail mismatch! Got: {full_asset.thumbnail}")

    # 5. Clean up
    print("🧹 Cleaning up...")
    deleted = await repo.delete_asset(asset_id)
    if deleted:
        print("✅ Asset deleted")
    else:
        print("❌ Failed to delete asset")


if __name__ == "__main__":
    asyncio.run(test_integration())

utils/external_auth.py Normal file

@@ -0,0 +1,46 @@
import hmac
import hashlib
from typing import Optional

from fastapi import Header, HTTPException


def verify_signature(body: bytes, signature: str, secret: str) -> bool:
    """
    Verify an HMAC-SHA256 signature.

    Args:
        body: Raw request body bytes
        signature: Signature from the X-Signature header
        secret: Shared secret key

    Returns:
        True if the signature is valid, False otherwise
    """
    expected_signature = hmac.new(
        secret.encode('utf-8'),
        body,
        hashlib.sha256
    ).hexdigest()
    # Constant-time comparison to avoid timing attacks
    return hmac.compare_digest(signature, expected_signature)


async def verify_external_signature(
    x_signature: Optional[str] = Header(None, alias="X-Signature")
):
    """
    FastAPI dependency to verify the external API signature header.

    Raises:
        HTTPException: If the X-Signature header is missing
    """
    if not x_signature:
        raise HTTPException(
            status_code=401,
            detail="Missing X-Signature header"
        )
    # Note: the raw request body is only accessible inside the endpoint,
    # so this dependency just validates that the header exists; the actual
    # signature verification happens in the endpoint.
    return x_signature

utils/image_utils.py Normal file

@@ -0,0 +1,27 @@
import logging
from io import BytesIO
from typing import Tuple, Optional

from PIL import Image

logger = logging.getLogger(__name__)


def create_thumbnail(image_data: bytes, size: Tuple[int, int] = (800, 800)) -> Optional[bytes]:
    """
    Creates a thumbnail from image bytes.

    Returns the thumbnail as bytes (JPEG format), or None on failure.
    """
    try:
        with Image.open(BytesIO(image_data)) as img:
            # Convert to RGB if necessary (RGBA/P images cannot be saved as JPEG)
            if img.mode in ('RGBA', 'P'):
                img = img.convert('RGB')
            img.thumbnail(size)
            thumb_io = BytesIO()
            img.save(thumb_io, format='JPEG', quality=85)
            thumb_io.seek(0)
            return thumb_io.read()
    except Exception as e:
        logger.error(f"Failed to create thumbnail: {e}")
        return None
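A quick self-check of the helper above, using Pillow (already a dependency here) to build a test image in memory; the image dimensions are arbitrary:

```python
from io import BytesIO

from PIL import Image


def create_thumbnail(image_data, size=(800, 800)):
    # Same logic as utils/image_utils.py, restated so this snippet runs standalone.
    with Image.open(BytesIO(image_data)) as img:
        if img.mode in ("RGBA", "P"):
            img = img.convert("RGB")
        img.thumbnail(size)
        out = BytesIO()
        img.save(out, format="JPEG", quality=85)
        return out.getvalue()


# Build a 1200x900 RGBA test image in memory and shrink it.
buf = BytesIO()
Image.new("RGBA", (1200, 900), (255, 0, 0, 128)).save(buf, format="PNG")
thumb = create_thumbnail(buf.getvalue(), size=(100, 100))

with Image.open(BytesIO(thumb)) as img:
    assert img.format == "JPEG"           # RGBA input was converted for JPEG output
    assert img.size == (100, 75)          # thumbnail() preserves aspect ratio
```

Note that `Image.thumbnail` fits the image *within* the given box rather than cropping to it, which is why the result is 100x75 and not 100x100.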

utils/security.py Normal file

@@ -0,0 +1,35 @@
from datetime import datetime, timedelta
from typing import Optional, Union, Any

from jose import jwt
from passlib.context import CryptContext

# Security settings (better moved to config/env, but kept here to start with)
# SECRET_KEY must be strong and kept secret in production!
SECRET_KEY = "CHANGE_ME_TO_A_SUPER_SECRET_KEY"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30 * 24 * 60  # e.g. 30 days

pwd_context = CryptContext(schemes=["argon2"], deprecated="auto")


def verify_password(plain_password: str, hashed_password: str) -> bool:
    return pwd_context.verify(plain_password, hashed_password)


def get_password_hash(password: str) -> str:
    return pwd_context.hash(password)


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    to_encode = data.copy()
    if expires_delta:
        expire = datetime.utcnow() + expires_delta
    else:
        expire = datetime.utcnow() + timedelta(minutes=15)
    # Standard 'exp' claim for JWT
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt
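The `exp` claim set above is what makes tokens expire. To make the mechanics concrete, here is a stdlib-only sketch of HS256 encoding and verification (the hypothetical helpers `encode_hs256`/`verify_hs256` are for illustration; the project itself uses python-jose for this):

```python
import base64
import hashlib
import hmac
import json
import time

SECRET_KEY = "CHANGE_ME_TO_A_SUPER_SECRET_KEY"


def _b64url(data: bytes) -> str:
    # JWT uses URL-safe base64 without padding.
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode()


def encode_hs256(claims: dict, secret: str) -> str:
    header = _b64url(json.dumps({"alg": "HS256", "typ": "JWT"}).encode())
    payload = _b64url(json.dumps(claims).encode())
    signing_input = f"{header}.{payload}".encode()
    sig = _b64url(hmac.new(secret.encode(), signing_input, hashlib.sha256).digest())
    return f"{header}.{payload}.{sig}"


def verify_hs256(token: str, secret: str):
    # Returns the claims dict, or None if the signature is bad or the token expired.
    header, payload, sig = token.split(".")
    signing_input = f"{header}.{payload}".encode()
    expected = _b64url(hmac.new(secret.encode(), signing_input, hashlib.sha256).digest())
    if not hmac.compare_digest(sig, expected):
        return None
    pad = "=" * (-len(payload) % 4)  # restore base64 padding before decoding
    claims = json.loads(base64.urlsafe_b64decode(payload + pad))
    if claims.get("exp", 0) < time.time():
        return None  # expired
    return claims


token = encode_hs256({"sub": "user1", "exp": time.time() + 60}, SECRET_KEY)
assert verify_hs256(token, SECRET_KEY)["sub"] == "user1"
assert verify_hs256(token, "wrong_secret") is None
```

This also shows why HS256 is symmetric: anyone holding `SECRET_KEY` can both mint and verify tokens, which is why the key must never be shipped to clients.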