Compare commits
51 Commits
9ae6e8e08e
...
video
| Author | SHA1 | Date | |
|---|---|---|---|
| 32ff77e04b | |||
| d1f67c773f | |||
| c63b51ef75 | |||
| 456562ec1d | |||
| 0d0fbdf7d6 | |||
| f63bcedb13 | |||
| be92c766ac | |||
| 482bc1d9b7 | |||
| a2321cf070 | |||
| 29ccd5743e | |||
| d9de2f48d2 | |||
| 1ddeb0af46 | |||
| a7c2319f13 | |||
| 00e83b8561 | |||
| a9d24c725e | |||
| 458b6ebfc3 | |||
| 668aadcdc9 | |||
| 4461964791 | |||
| fa3e1bb05f | |||
| 8a89b27624 | |||
| c17c47ccc1 | |||
| c25b029006 | |||
| a449f65de9 | |||
| 3cf7db5cdf | |||
| 288515fa04 | |||
| f1033210cc | |||
| 1832d07caa | |||
| b704707abc | |||
| 31893414eb | |||
| aa50b1cc03 | |||
| 305ad24576 | |||
| ce87ac7edb | |||
| 2f8de7a298 | |||
| b8e96a2dca | |||
| 137279bcc5 | |||
| 553335940f | |||
| fd1b023e7d | |||
| eeea0f5b8f | |||
| ac5cc53006 | |||
| c3b13360e0 | |||
| 63292a1699 | |||
| 59c40524e0 | |||
| cdb09e84fc | |||
| 37e69088a1 | |||
| 7e2f79aab1 | |||
| c0debab0cb | |||
| 002c949f08 | |||
| d4682b1418 | |||
| 463e73fa1e | |||
| 76dd976854 | |||
| 736e5a8c12 |
19
.dockerignore
Normal file
19
.dockerignore
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
.git
|
||||||
|
.gitignore
|
||||||
|
__pycache__
|
||||||
|
*.pyc
|
||||||
|
*.pyo
|
||||||
|
*.pyd
|
||||||
|
.Python
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
.venv/
|
||||||
|
node_modules/
|
||||||
|
tmp/
|
||||||
|
logs/
|
||||||
|
*.log
|
||||||
|
dist/
|
||||||
|
build/
|
||||||
|
.cache/
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
9
.env
9
.env
@@ -1,4 +1,13 @@
|
|||||||
BOT_TOKEN=8495170789:AAHyjjhHwwVtd9_ROnjHqPHRdnmyVr1aeaY
|
BOT_TOKEN=8495170789:AAHyjjhHwwVtd9_ROnjHqPHRdnmyVr1aeaY
|
||||||
|
# BOT_TOKEN=8011562605:AAF3kyzrZJgii0Jx-H8Sum5Njbo0BdbsiAo
|
||||||
GEMINI_API_KEY=AIzaSyAHzDYhgjOqZZnvOnOFRGaSkKu4OAN3kZE
|
GEMINI_API_KEY=AIzaSyAHzDYhgjOqZZnvOnOFRGaSkKu4OAN3kZE
|
||||||
MONGO_HOST=mongodb://admin:super_secure_password@31.59.58.220:27017/
|
MONGO_HOST=mongodb://admin:super_secure_password@31.59.58.220:27017/
|
||||||
ADMIN_ID=567047
|
ADMIN_ID=567047
|
||||||
|
MINIO_ENDPOINT=http://31.59.58.220:9000
|
||||||
|
MINIO_ACCESS_KEY=admin
|
||||||
|
MINIO_SECRET_KEY=SuperSecretPassword123!
|
||||||
|
MINIO_BUCKET=ai-char
|
||||||
|
MODE=production
|
||||||
|
EXTERNAL_API_SECRET=Gt9TyQ8OAYhcELh2YCbKjdHLflZGufKHJZcG338MQDW
|
||||||
|
KLING_ACCESS_KEY=AngRfYYeLhPQB3pmr9CpHfgHPCrmeeM4
|
||||||
|
KLING_SECRET_KEY=ndJfyayfQgbg4bMnE49yHnkACPChKMp4
|
||||||
11
.gitignore
vendored
Normal file
11
.gitignore
vendored
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
minio_backup.tar.gz
|
||||||
|
.DS_Store
|
||||||
|
**/__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
*.cpython-*.pyc
|
||||||
|
**/.DS_Store
|
||||||
|
.idea/ai-char-bot.iml
|
||||||
|
.idea
|
||||||
|
.venv
|
||||||
|
.vscode
|
||||||
8
.idea/.gitignore
generated
vendored
Normal file
8
.idea/.gitignore
generated
vendored
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
# Default ignored files
|
||||||
|
/shelf/
|
||||||
|
/workspace.xml
|
||||||
|
# Editor-based HTTP Client requests
|
||||||
|
/httpRequests/
|
||||||
|
# Datasource local storage ignored files
|
||||||
|
/dataSources/
|
||||||
|
/dataSources.local.xml
|
||||||
10
.idea/ai-char-bot.iml
generated
Normal file
10
.idea/ai-char-bot.iml
generated
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<module type="PYTHON_MODULE" version="4">
|
||||||
|
<component name="NewModuleRootManager">
|
||||||
|
<content url="file://$MODULE_DIR$">
|
||||||
|
<excludeFolder url="file://$MODULE_DIR$/.venv" />
|
||||||
|
</content>
|
||||||
|
<orderEntry type="jdk" jdkName="Python 3.13 (ai-char-bot)" jdkType="Python SDK" />
|
||||||
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
|
</component>
|
||||||
|
</module>
|
||||||
16
.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
16
.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
<component name="InspectionProjectProfileManager">
|
||||||
|
<profile version="1.0">
|
||||||
|
<option name="myName" value="Project Default" />
|
||||||
|
<inspection_tool class="PyAssertTypeInspection" enabled="true" level="ERROR" enabled_by_default="true" editorAttributes="ERRORS_ATTRIBUTES" />
|
||||||
|
<inspection_tool class="PyAsyncCallInspection" enabled="true" level="ERROR" enabled_by_default="true" editorAttributes="ERRORS_ATTRIBUTES" />
|
||||||
|
<inspection_tool class="PyPep8NamingInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
|
||||||
|
<option name="ignoredErrors">
|
||||||
|
<list>
|
||||||
|
<option value="N802" />
|
||||||
|
</list>
|
||||||
|
</option>
|
||||||
|
</inspection_tool>
|
||||||
|
<inspection_tool class="PyTypeCheckerInspection" enabled="true" level="ERROR" enabled_by_default="true" editorAttributes="ERRORS_ATTRIBUTES" />
|
||||||
|
<inspection_tool class="PyUnreachableCodeInspection" enabled="true" level="ERROR" enabled_by_default="true" editorAttributes="ERRORS_ATTRIBUTES" />
|
||||||
|
</profile>
|
||||||
|
</component>
|
||||||
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
<component name="InspectionProjectProfileManager">
|
||||||
|
<settings>
|
||||||
|
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||||
|
<version value="1.0" />
|
||||||
|
</settings>
|
||||||
|
</component>
|
||||||
7
.idea/misc.xml
generated
Normal file
7
.idea/misc.xml
generated
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="Black">
|
||||||
|
<option name="sdkName" value="Python 3.13 (ai-char-bot)" />
|
||||||
|
</component>
|
||||||
|
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.13 (ai-char-bot)" project-jdk-type="Python SDK" />
|
||||||
|
</project>
|
||||||
8
.idea/modules.xml
generated
Normal file
8
.idea/modules.xml
generated
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="ProjectModuleManager">
|
||||||
|
<modules>
|
||||||
|
<module fileurl="file://$PROJECT_DIR$/.idea/ai-char-bot.iml" filepath="$PROJECT_DIR$/.idea/ai-char-bot.iml" />
|
||||||
|
</modules>
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
6
.idea/vcs.xml
generated
Normal file
6
.idea/vcs.xml
generated
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="VcsDirectoryMappings">
|
||||||
|
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
46
.vscode/launch.json
vendored
Normal file
46
.vscode/launch.json
vendored
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
{
|
||||||
|
"version": "0.2.0",
|
||||||
|
"configurations": [
|
||||||
|
{
|
||||||
|
"name": "Python Debugger: FastAPI",
|
||||||
|
"type": "debugpy",
|
||||||
|
"request": "launch",
|
||||||
|
"module": "uvicorn",
|
||||||
|
"args": [
|
||||||
|
"aiws:app",
|
||||||
|
"--reload",
|
||||||
|
"--port",
|
||||||
|
"8090",
|
||||||
|
"--host",
|
||||||
|
"0.0.0.0"
|
||||||
|
],
|
||||||
|
"jinja": true,
|
||||||
|
"justMyCode": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Python: Current File",
|
||||||
|
"type": "debugpy",
|
||||||
|
"request": "launch",
|
||||||
|
"program": "${file}",
|
||||||
|
"console": "integratedTerminal",
|
||||||
|
"justMyCode": true,
|
||||||
|
"env": {
|
||||||
|
"PYTHONPATH": "${workspaceFolder}"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Debug Tests: Current File",
|
||||||
|
"type": "debugpy",
|
||||||
|
"request": "launch",
|
||||||
|
"module": "pytest",
|
||||||
|
"args": [
|
||||||
|
"${file}"
|
||||||
|
],
|
||||||
|
"console": "integratedTerminal",
|
||||||
|
"justMyCode": true,
|
||||||
|
"env": {
|
||||||
|
"PYTHONPATH": "${workspaceFolder}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
@@ -11,4 +11,4 @@ RUN pip install --no-cache-dir -r requirements.txt
|
|||||||
COPY . .
|
COPY . .
|
||||||
|
|
||||||
# Запуск приложения (замени app.py на свой файл)
|
# Запуск приложения (замени app.py на свой файл)
|
||||||
CMD ["python", "main.py"]
|
CMD ["uvicorn", "aiws:app", "--host", "0.0.0.0", "--port", "8090"]
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
class GoogleGenerationException(Exception):
|
class GoogleGenerationException(Exception):
|
||||||
message: str
|
def __init__(self, message: str):
|
||||||
pass
|
self.message = message
|
||||||
|
super().__init__(message)
|
||||||
Binary file not shown.
Binary file not shown.
@@ -1,7 +1,7 @@
|
|||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List, Union
|
from typing import List, Union, Tuple, Dict, Any
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from google import genai
|
from google import genai
|
||||||
@@ -27,6 +27,7 @@ class GoogleAdapter:
|
|||||||
"""Вспомогательный метод для подготовки контента (текст + картинки)"""
|
"""Вспомогательный метод для подготовки контента (текст + картинки)"""
|
||||||
contents = [prompt]
|
contents = [prompt]
|
||||||
if images_list:
|
if images_list:
|
||||||
|
logger.info(f"Preparing content with {len(images_list)} images")
|
||||||
for img_bytes in images_list:
|
for img_bytes in images_list:
|
||||||
try:
|
try:
|
||||||
# Gemini API требует PIL Image на входе
|
# Gemini API требует PIL Image на входе
|
||||||
@@ -34,6 +35,8 @@ class GoogleAdapter:
|
|||||||
contents.append(image)
|
contents.append(image)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error processing input image: {e}")
|
logger.error(f"Error processing input image: {e}")
|
||||||
|
else:
|
||||||
|
logger.info("Preparing content with no images")
|
||||||
return contents
|
return contents
|
||||||
|
|
||||||
def generate_text(self, prompt: str, images_list: List[bytes] = None) -> str:
|
def generate_text(self, prompt: str, images_list: List[bytes] = None) -> str:
|
||||||
@@ -59,19 +62,24 @@ class GoogleAdapter:
|
|||||||
for part in response.parts:
|
for part in response.parts:
|
||||||
if part.text:
|
if part.text:
|
||||||
result_text += part.text
|
result_text += part.text
|
||||||
logger.error(f"Generated text: {result_text}")
|
logger.info(f"Generated text length: {len(result_text)}")
|
||||||
return result_text
|
return result_text
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Gemini Text API Error: {e}")
|
logger.error(f"Gemini Text API Error: {e}")
|
||||||
return f"Ошибка генерации текста: {e}"
|
raise GoogleGenerationException(f"Gemini Text API Error: {e}")
|
||||||
|
|
||||||
def generate_image(self, prompt: str, aspect_ratio: AspectRatios, quality: Quality, images_list: List[bytes] = None, ) -> List[io.BytesIO]:
|
def generate_image(self, prompt: str, aspect_ratio: AspectRatios, quality: Quality, images_list: List[bytes] = None, ) -> Tuple[List[io.BytesIO], Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Генерация изображений (Text-to-Image или Image-to-Image).
|
Генерация изображений (Text-to-Image или Image-to-Image).
|
||||||
Возвращает список байтовых потоков (готовых к отправке).
|
Возвращает список байтовых потоков (готовых к отправке).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
contents = self._prepare_contents(prompt, images_list)
|
contents = self._prepare_contents(prompt, images_list)
|
||||||
|
logger.info(f"Generating image. Prompt length: {len(prompt)}, Ratio: {aspect_ratio}, Quality: {quality}")
|
||||||
|
|
||||||
|
start_time = datetime.now()
|
||||||
|
token_usage = 0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = self.client.models.generate_content(
|
response = self.client.models.generate_content(
|
||||||
@@ -86,6 +94,13 @@ class GoogleAdapter:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
end_time = datetime.now()
|
||||||
|
api_duration = (end_time - start_time).total_seconds()
|
||||||
|
|
||||||
|
if response.usage_metadata:
|
||||||
|
token_usage = response.usage_metadata.total_token_count
|
||||||
|
|
||||||
if response.parts is None and response.candidates[0].finish_reason is not None:
|
if response.parts is None and response.candidates[0].finish_reason is not None:
|
||||||
raise GoogleGenerationException(f"Generation blocked in cause of {response.candidates[0].finish_reason.value}")
|
raise GoogleGenerationException(f"Generation blocked in cause of {response.candidates[0].finish_reason.value}")
|
||||||
|
|
||||||
@@ -111,9 +126,25 @@ class GoogleAdapter:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error processing output image: {e}")
|
logger.error(f"Error processing output image: {e}")
|
||||||
|
|
||||||
return generated_images
|
if generated_images:
|
||||||
|
logger.info(f"Successfully generated {len(generated_images)} images in {api_duration:.2f}s. Tokens: {token_usage}")
|
||||||
|
else:
|
||||||
|
logger.warning("No images text generated from parts")
|
||||||
|
|
||||||
|
input_tokens = 0
|
||||||
|
output_tokens = 0
|
||||||
|
if response.usage_metadata:
|
||||||
|
input_tokens = response.usage_metadata.prompt_token_count
|
||||||
|
output_tokens = response.usage_metadata.candidates_token_count
|
||||||
|
|
||||||
|
metrics = {
|
||||||
|
"api_execution_time_seconds": api_duration,
|
||||||
|
"token_usage": token_usage,
|
||||||
|
"input_token_usage": input_tokens,
|
||||||
|
"output_token_usage": output_tokens
|
||||||
|
}
|
||||||
|
return generated_images, metrics
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Gemini Image API Error: {e}")
|
logger.error(f"Gemini Image API Error: {e}")
|
||||||
# В случае ошибки возвращаем пустой список (или можно рейзить исключение)
|
raise GoogleGenerationException(f"Gemini Image API Error: {e}")
|
||||||
return []
|
|
||||||
165
adapters/kling_adapter.py
Normal file
165
adapters/kling_adapter.py
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import asyncio
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import jwt
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
KLING_API_BASE = "https://api.klingai.com"
|
||||||
|
|
||||||
|
|
||||||
|
class KlingApiException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class KlingAdapter:
|
||||||
|
def __init__(self, access_key: str, secret_key: str):
|
||||||
|
if not access_key or not secret_key:
|
||||||
|
raise ValueError("Kling API credentials are missing")
|
||||||
|
self.access_key = access_key
|
||||||
|
self.secret_key = secret_key
|
||||||
|
|
||||||
|
def _generate_token(self) -> str:
|
||||||
|
"""Generate a JWT token for Kling API authentication."""
|
||||||
|
now = int(time.time())
|
||||||
|
payload = {
|
||||||
|
"iss": self.access_key,
|
||||||
|
"exp": now + 1800, # 30 minutes
|
||||||
|
"iat": now - 5, # небольшой запас назад
|
||||||
|
"nbf": now - 5,
|
||||||
|
}
|
||||||
|
return jwt.encode(payload, self.secret_key, algorithm="HS256",
|
||||||
|
headers={"typ": "JWT", "alg": "HS256"})
|
||||||
|
|
||||||
|
def _headers(self) -> dict:
|
||||||
|
return {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {self._generate_token()}"
|
||||||
|
}
|
||||||
|
|
||||||
|
async def create_video_task(
|
||||||
|
self,
|
||||||
|
image_url: str,
|
||||||
|
prompt: str = "",
|
||||||
|
negative_prompt: str = "",
|
||||||
|
model_name: str = "kling-v2-6",
|
||||||
|
duration: int = 5,
|
||||||
|
mode: str = "std",
|
||||||
|
cfg_scale: float = 0.5,
|
||||||
|
aspect_ratio: str = "16:9",
|
||||||
|
callback_url: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Create an image-to-video generation task.
|
||||||
|
Returns the full task data dict including task_id.
|
||||||
|
"""
|
||||||
|
body: Dict[str, Any] = {
|
||||||
|
"model_name": model_name,
|
||||||
|
"image": image_url,
|
||||||
|
"prompt": prompt,
|
||||||
|
"negative_prompt": negative_prompt,
|
||||||
|
"duration": str(duration),
|
||||||
|
"mode": mode,
|
||||||
|
"cfg_scale": cfg_scale,
|
||||||
|
"aspect_ratio": aspect_ratio,
|
||||||
|
}
|
||||||
|
if callback_url:
|
||||||
|
body["callback_url"] = callback_url
|
||||||
|
|
||||||
|
logger.info(f"Creating Kling video task. Model: {model_name}, Duration: {duration}s, Mode: {mode}")
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
response = await client.post(
|
||||||
|
f"{KLING_API_BASE}/v1/videos/image2video",
|
||||||
|
headers=self._headers(),
|
||||||
|
json=body,
|
||||||
|
)
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
logger.info(f"Kling create task response: code={data.get('code')}, message={data.get('message')}")
|
||||||
|
|
||||||
|
if response.status_code != 200 or data.get("code") != 0:
|
||||||
|
error_msg = data.get("message", "Unknown Kling API error")
|
||||||
|
raise KlingApiException(f"Failed to create video task: {error_msg} (code={data.get('code')})")
|
||||||
|
|
||||||
|
task_data = data.get("data", {})
|
||||||
|
task_id = task_data.get("task_id")
|
||||||
|
if not task_id:
|
||||||
|
raise KlingApiException("No task_id returned from Kling API")
|
||||||
|
|
||||||
|
logger.info(f"Kling video task created: task_id={task_id}")
|
||||||
|
return task_data
|
||||||
|
|
||||||
|
async def get_task_status(self, task_id: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Query the status of a video generation task.
|
||||||
|
Returns the full task data dict.
|
||||||
|
"""
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
response = await client.get(
|
||||||
|
f"{KLING_API_BASE}/v1/videos/image2video/{task_id}",
|
||||||
|
headers=self._headers(),
|
||||||
|
)
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
if response.status_code != 200 or data.get("code") != 0:
|
||||||
|
error_msg = data.get("message", "Unknown error")
|
||||||
|
raise KlingApiException(f"Failed to query task {task_id}: {error_msg}")
|
||||||
|
|
||||||
|
return data.get("data", {})
|
||||||
|
|
||||||
|
async def wait_for_completion(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
poll_interval: int = 10,
|
||||||
|
timeout: int = 600,
|
||||||
|
progress_callback=None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Poll the task status until completion.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: Kling task ID
|
||||||
|
poll_interval: seconds between polls
|
||||||
|
timeout: max seconds to wait
|
||||||
|
progress_callback: async callable(progress_pct: int) to report progress
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Final task data dict with video URL on success.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KlingApiException on failure or timeout.
|
||||||
|
"""
|
||||||
|
start = time.time()
|
||||||
|
attempt = 0
|
||||||
|
|
||||||
|
while True:
|
||||||
|
elapsed = time.time() - start
|
||||||
|
if elapsed > timeout:
|
||||||
|
raise KlingApiException(f"Video generation timed out after {timeout}s for task {task_id}")
|
||||||
|
|
||||||
|
task_data = await self.get_task_status(task_id)
|
||||||
|
status = task_data.get("task_status")
|
||||||
|
|
||||||
|
logger.info(f"Kling task {task_id}: status={status}, elapsed={elapsed:.0f}s")
|
||||||
|
|
||||||
|
if status == "succeed":
|
||||||
|
logger.info(f"Kling task {task_id} completed successfully")
|
||||||
|
return task_data
|
||||||
|
|
||||||
|
if status == "failed":
|
||||||
|
fail_reason = task_data.get("task_status_msg", "Unknown failure")
|
||||||
|
raise KlingApiException(f"Video generation failed: {fail_reason}")
|
||||||
|
|
||||||
|
# Report progress estimate (linear approximation based on typical time)
|
||||||
|
if progress_callback:
|
||||||
|
# Estimate: typical gen is ~120s, cap at 90%
|
||||||
|
estimated_progress = min(int((elapsed / 120) * 90), 90)
|
||||||
|
attempt += 1
|
||||||
|
await progress_callback(estimated_progress)
|
||||||
|
|
||||||
|
await asyncio.sleep(poll_interval)
|
||||||
81
adapters/s3_adapter.py
Normal file
81
adapters/s3_adapter.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import Optional, BinaryIO
|
||||||
|
import aioboto3
|
||||||
|
from botocore.exceptions import ClientError
|
||||||
|
import os
|
||||||
|
|
||||||
|
class S3Adapter:
|
||||||
|
def __init__(self,
|
||||||
|
endpoint_url: str,
|
||||||
|
aws_access_key_id: str,
|
||||||
|
aws_secret_access_key: str,
|
||||||
|
bucket_name: str):
|
||||||
|
self.endpoint_url = endpoint_url
|
||||||
|
self.aws_access_key_id = aws_access_key_id
|
||||||
|
self.aws_secret_access_key = aws_secret_access_key
|
||||||
|
self.bucket_name = bucket_name
|
||||||
|
self.session = aioboto3.Session()
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def _get_client(self):
|
||||||
|
async with self.session.client(
|
||||||
|
"s3",
|
||||||
|
endpoint_url=self.endpoint_url,
|
||||||
|
aws_access_key_id=self.aws_access_key_id,
|
||||||
|
aws_secret_access_key=self.aws_secret_access_key,
|
||||||
|
) as client:
|
||||||
|
yield client
|
||||||
|
|
||||||
|
async def upload_file(self, object_name: str, data: bytes, content_type: Optional[str] = None):
|
||||||
|
"""Uploads bytes data to S3."""
|
||||||
|
try:
|
||||||
|
extra_args = {}
|
||||||
|
if content_type:
|
||||||
|
extra_args["ContentType"] = content_type
|
||||||
|
|
||||||
|
async with self._get_client() as client:
|
||||||
|
await client.put_object(
|
||||||
|
Bucket=self.bucket_name,
|
||||||
|
Key=object_name,
|
||||||
|
Body=data,
|
||||||
|
**extra_args
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
except ClientError as e:
|
||||||
|
# logging.error(e)
|
||||||
|
print(f"Error uploading to S3: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def get_file(self, object_name: str) -> Optional[bytes]:
|
||||||
|
"""Downloads a file from S3 and returns bytes."""
|
||||||
|
try:
|
||||||
|
async with self._get_client() as client:
|
||||||
|
response = await client.get_object(Bucket=self.bucket_name, Key=object_name)
|
||||||
|
return await response['Body'].read()
|
||||||
|
except ClientError as e:
|
||||||
|
print(f"Error downloading from S3: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def delete_file(self, object_name: str):
|
||||||
|
"""Deletes a file from S3."""
|
||||||
|
try:
|
||||||
|
async with self._get_client() as client:
|
||||||
|
await client.delete_object(Bucket=self.bucket_name, Key=object_name)
|
||||||
|
return True
|
||||||
|
except ClientError as e:
|
||||||
|
print(f"Error deleting from S3: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def get_presigned_url(self, object_name: str, expiration: int = 3600) -> Optional[str]:
|
||||||
|
"""Generate a presigned URL to share an S3 object."""
|
||||||
|
try:
|
||||||
|
async with self._get_client() as client:
|
||||||
|
response = await client.generate_presigned_url(
|
||||||
|
'get_object',
|
||||||
|
Params={'Bucket': self.bucket_name, 'Key': object_name},
|
||||||
|
ExpiresIn=expiration
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
except ClientError as e:
|
||||||
|
print(f"Error generating presigned URL: {e}")
|
||||||
|
return None
|
||||||
@@ -12,11 +12,16 @@ from aiogram.fsm.storage.mongo import MongoStorage
|
|||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from motor.motor_asyncio import AsyncIOMotorClient
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
from prometheus_client import Info
|
||||||
from starlette.middleware.cors import CORSMiddleware
|
from starlette.middleware.cors import CORSMiddleware
|
||||||
|
from prometheus_fastapi_instrumentator import Instrumentator
|
||||||
|
|
||||||
# --- ИМПОРТЫ ПРОЕКТА ---
|
# --- ИМПОРТЫ ПРОЕКТА ---
|
||||||
from adapters.google_adapter import GoogleAdapter
|
from adapters.google_adapter import GoogleAdapter
|
||||||
|
from adapters.kling_adapter import KlingAdapter
|
||||||
|
from adapters.s3_adapter import S3Adapter
|
||||||
from api.service.generation_service import GenerationService
|
from api.service.generation_service import GenerationService
|
||||||
|
from api.service.album_service import AlbumService
|
||||||
from middlewares.album import AlbumMiddleware
|
from middlewares.album import AlbumMiddleware
|
||||||
from middlewares.auth import AuthMiddleware
|
from middlewares.auth import AuthMiddleware
|
||||||
from middlewares.dao import DaoMiddleware
|
from middlewares.dao import DaoMiddleware
|
||||||
@@ -25,8 +30,6 @@ from middlewares.dao import DaoMiddleware
|
|||||||
from repos.char_repo import CharacterRepo
|
from repos.char_repo import CharacterRepo
|
||||||
from repos.user_repo import UsersRepo
|
from repos.user_repo import UsersRepo
|
||||||
from repos.dao import DAO
|
from repos.dao import DAO
|
||||||
# Предполагаю, что AssetsDAO лежит тут или в repos.assets_dao.
|
|
||||||
# Если нет - поправьте импорт!
|
|
||||||
|
|
||||||
|
|
||||||
# Роутеры
|
# Роутеры
|
||||||
@@ -37,12 +40,18 @@ from routers.assets_router import router as assets_router # Роутер бот
|
|||||||
from api.endpoints.assets_router import router as api_assets_router # Роутер FastAPI
|
from api.endpoints.assets_router import router as api_assets_router # Роутер FastAPI
|
||||||
from api.endpoints.character_router import router as api_char_router # Роутер FastAPI
|
from api.endpoints.character_router import router as api_char_router # Роутер FastAPI
|
||||||
from api.endpoints.generation_router import router as api_gen_router
|
from api.endpoints.generation_router import router as api_gen_router
|
||||||
|
from api.endpoints.auth import router as api_auth_router
|
||||||
|
from api.endpoints.admin import router as api_admin_router
|
||||||
|
from api.endpoints.album_router import router as api_album_router
|
||||||
|
from api.endpoints.project_router import router as project_api_router
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# --- КОНФИГУРАЦИЯ ---
|
# --- КОНФИГУРАЦИЯ ---
|
||||||
BOT_TOKEN = os.getenv("BOT_TOKEN")
|
BOT_TOKEN = os.getenv("BOT_TOKEN")
|
||||||
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
|
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
|
||||||
|
|
||||||
MONGO_HOST = os.getenv("MONGO_HOST") # Например: mongodb://localhost:27017
|
MONGO_HOST = os.getenv("MONGO_HOST") # Например: mongodb://localhost:27017
|
||||||
DB_NAME = os.getenv("DB_NAME", "my_bot_db") # Имя базы данных
|
DB_NAME = os.getenv("DB_NAME", "my_bot_db") # Имя базы данных
|
||||||
ADMIN_ID = int(os.getenv("ADMIN_ID", 0))
|
ADMIN_ID = int(os.getenv("ADMIN_ID", 0))
|
||||||
@@ -50,7 +59,8 @@ ADMIN_ID = int(os.getenv("ADMIN_ID", 0))
|
|||||||
|
|
||||||
def setup_logging():
|
def setup_logging():
|
||||||
logging.basicConfig(level=logging.INFO,
|
logging.basicConfig(level=logging.INFO,
|
||||||
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s")
|
format="%(asctime)s [%(levelname)s] %(name)s (%(filename)s:%(lineno)d): %(message)s",
|
||||||
|
force=True)
|
||||||
|
|
||||||
|
|
||||||
# --- ИНИЦИАЛИЗАЦИЯ ЗАВИСИМОСТЕЙ ---
|
# --- ИНИЦИАЛИЗАЦИЯ ЗАВИСИМОСТЕЙ ---
|
||||||
@@ -59,12 +69,34 @@ bot = Bot(token=BOT_TOKEN, default=DefaultBotProperties(parse_mode=ParseMode.HTM
|
|||||||
# Клиент БД создаем глобально, чтобы он был доступен и боту (Storage), и API
|
# Клиент БД создаем глобально, чтобы он был доступен и боту (Storage), и API
|
||||||
mongo_client = AsyncIOMotorClient(MONGO_HOST)
|
mongo_client = AsyncIOMotorClient(MONGO_HOST)
|
||||||
|
|
||||||
|
# Репозитории
|
||||||
# Репозитории
|
# Репозитории
|
||||||
users_repo = UsersRepo(mongo_client)
|
users_repo = UsersRepo(mongo_client)
|
||||||
char_repo = CharacterRepo(mongo_client)
|
char_repo = CharacterRepo(mongo_client)
|
||||||
dao = DAO(mongo_client) # Главный DAO для бота
|
|
||||||
|
# S3 Adapter
|
||||||
|
s3_adapter = S3Adapter(
|
||||||
|
endpoint_url=os.getenv("MINIO_ENDPOINT", "http://31.59.58.220:9000"),
|
||||||
|
aws_access_key_id=os.getenv("MINIO_ACCESS_KEY", "minioadmin"),
|
||||||
|
aws_secret_access_key=os.getenv("MINIO_SECRET_KEY", "minioadmin"),
|
||||||
|
bucket_name=os.getenv("MINIO_BUCKET", "ai-char")
|
||||||
|
)
|
||||||
|
|
||||||
|
dao = DAO(mongo_client, s3_adapter) # Главный DAO для бота
|
||||||
gemini = GoogleAdapter(api_key=GEMINI_API_KEY)
|
gemini = GoogleAdapter(api_key=GEMINI_API_KEY)
|
||||||
generation_service = GenerationService(dao, gemini)
|
|
||||||
|
# Kling Adapter (optional, for video generation)
|
||||||
|
kling_access_key = os.getenv("KLING_ACCESS_KEY", "")
|
||||||
|
kling_secret_key = os.getenv("KLING_SECRET_KEY", "")
|
||||||
|
kling_adapter = None
|
||||||
|
if kling_access_key and kling_secret_key:
|
||||||
|
kling_adapter = KlingAdapter(access_key=kling_access_key, secret_key=kling_secret_key)
|
||||||
|
logger.info("Kling adapter initialized")
|
||||||
|
else:
|
||||||
|
logger.warning("KLING_ACCESS_KEY / KLING_SECRET_KEY not set — video generation disabled")
|
||||||
|
|
||||||
|
generation_service = GenerationService(dao, gemini, s3_adapter, bot, kling_adapter)
|
||||||
|
album_service = AlbumService(dao)
|
||||||
|
|
||||||
# Dispatcher
|
# Dispatcher
|
||||||
dp = Dispatcher(storage=MongoStorage(mongo_client, db_name=DB_NAME))
|
dp = Dispatcher(storage=MongoStorage(mongo_client, db_name=DB_NAME))
|
||||||
@@ -104,6 +136,7 @@ gen_router.message.middleware(AlbumMiddleware(latency=0.8))
|
|||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
# --- STARTUP ---
|
# --- STARTUP ---
|
||||||
|
setup_logging()
|
||||||
print("🚀 Starting up...")
|
print("🚀 Starting up...")
|
||||||
|
|
||||||
# 1. Настройка DAO для FastAPI
|
# 1. Настройка DAO для FastAPI
|
||||||
@@ -115,16 +148,21 @@ async def lifespan(app: FastAPI):
|
|||||||
|
|
||||||
app.state.mongo_client = mongo_client
|
app.state.mongo_client = mongo_client
|
||||||
app.state.gemini_client = gemini
|
app.state.gemini_client = gemini
|
||||||
|
app.state.bot = bot
|
||||||
|
app.state.s3_adapter = s3_adapter
|
||||||
|
app.state.kling_adapter = kling_adapter
|
||||||
|
app.state.album_service = album_service
|
||||||
|
app.state.users_repo = users_repo # Добавляем репозиторий в state
|
||||||
|
|
||||||
print("✅ DB & DAO initialized")
|
print("✅ DB & DAO initialized")
|
||||||
|
|
||||||
# 2. ЗАПУСК БОТА (в фоне)
|
# 2. ЗАПУСК БОТА (в фоне)
|
||||||
# Важно: handle_signals=False, чтобы бот не перехватывал сигналы остановки у uvicorn
|
# Важно: handle_signals=False, чтобы бот не перехватывал сигналы остановки у uvicorn
|
||||||
# Мы НЕ передаем сюда dao=..., так как он уже подключен через Middleware выше
|
# Мы НЕ передаем сюда dao=..., так как он уже подключен через Middleware выше
|
||||||
polling_task = asyncio.create_task(
|
# polling_task = asyncio.create_task(
|
||||||
dp.start_polling(bot, handle_signals=False)
|
# dp.start_polling(bot, handle_signals=False)
|
||||||
)
|
# )
|
||||||
print("🤖 Bot polling started")
|
# print("🤖 Bot polling started")
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
@@ -155,10 +193,26 @@ app.add_middleware(
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Подключаем роутер API
|
# Подключаем роутеры API
|
||||||
|
app.include_router(api_auth_router)
|
||||||
|
app.include_router(api_admin_router)
|
||||||
app.include_router(api_assets_router)
|
app.include_router(api_assets_router)
|
||||||
app.include_router(api_char_router)
|
app.include_router(api_char_router)
|
||||||
app.include_router(api_gen_router)
|
app.include_router(api_gen_router)
|
||||||
|
app.include_router(api_album_router)
|
||||||
|
app.include_router(project_api_router)
|
||||||
|
|
||||||
|
# Prometheus Metrics (Instrument after all routers are added)
|
||||||
|
Instrumentator(
|
||||||
|
should_group_status_codes=False, # 200/201/204 отдельно (по желанию)
|
||||||
|
should_ignore_untemplated=False, # НЕ игнорировать "сырые" пути
|
||||||
|
# should_group_untemplated=False, # (опционально) не схлопывать untemplated в "none"
|
||||||
|
).instrument(
|
||||||
|
app,
|
||||||
|
metric_namespace="ai_bot",
|
||||||
|
).expose(app, endpoint="/metrics", include_in_schema=False)
|
||||||
|
app_info = Info("fastapi_app_info", "FastAPI application info")
|
||||||
|
app_info.info({"app_name": "ai-bot"})
|
||||||
|
|
||||||
|
|
||||||
# --- ХЕНДЛЕРЫ БОТА (Main Router) ---
|
# --- ХЕНДЛЕРЫ БОТА (Main Router) ---
|
||||||
@@ -182,11 +236,10 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
setup_logging()
|
setup_logging()
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
# Создаем конфигурацию uvicorn вручную
|
# Создаем конфигурацию uvicorn вручную
|
||||||
# loop="asyncio" заставляет использовать стандартный цикл
|
# loop="asyncio" заставляет использовать стандартный цикл
|
||||||
config = uvicorn.Config(app, host="0.0.0.0", port=8090, loop="asyncio", timeout_keep_alive=120)
|
config = uvicorn.Config(app, host="0.0.0.0", port=8090, loop="asyncio", timeout_keep_alive=120, env_file=".env.development")
|
||||||
server = uvicorn.Server(config)
|
server = uvicorn.Server(config)
|
||||||
|
|
||||||
# Запускаем сервер (lifespan запустится внутри)
|
# Запускаем сервер (lifespan запустится внутри)
|
||||||
@@ -3,12 +3,18 @@ from fastapi import Request, Depends
|
|||||||
from motor.motor_asyncio import AsyncIOMotorClient
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
|
||||||
from adapters.google_adapter import GoogleAdapter
|
from adapters.google_adapter import GoogleAdapter
|
||||||
|
from adapters.kling_adapter import KlingAdapter
|
||||||
from api.service.generation_service import GenerationService
|
from api.service.generation_service import GenerationService
|
||||||
from repos.dao import DAO
|
from repos.dao import DAO
|
||||||
|
|
||||||
|
|
||||||
# ... ваши импорты ...
|
# ... ваши импорты ...
|
||||||
|
|
||||||
|
from aiogram import Bot
|
||||||
|
|
||||||
|
from adapters.s3_adapter import S3Adapter
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
# Провайдеры "сырых" клиентов из состояния приложения
|
# Провайдеры "сырых" клиентов из состояния приложения
|
||||||
def get_mongo_client(request: Request) -> AsyncIOMotorClient:
|
def get_mongo_client(request: Request) -> AsyncIOMotorClient:
|
||||||
return request.app.state.mongo_client
|
return request.app.state.mongo_client
|
||||||
@@ -16,15 +22,35 @@ def get_mongo_client(request: Request) -> AsyncIOMotorClient:
|
|||||||
def get_gemini_client(request: Request) -> GoogleAdapter:
|
def get_gemini_client(request: Request) -> GoogleAdapter:
|
||||||
return request.app.state.gemini_client
|
return request.app.state.gemini_client
|
||||||
|
|
||||||
|
def get_bot_client(request: Request) -> Bot:
|
||||||
|
return request.app.state.bot
|
||||||
|
|
||||||
|
def get_s3_adapter(request: Request) -> Optional[S3Adapter]:
|
||||||
|
return request.app.state.s3_adapter
|
||||||
|
|
||||||
# Провайдер DAO (собирается из mongo_client)
|
# Провайдер DAO (собирается из mongo_client)
|
||||||
def get_dao(mongo_client: AsyncIOMotorClient = Depends(get_mongo_client)) -> DAO:
|
def get_dao(
|
||||||
|
mongo_client: AsyncIOMotorClient = Depends(get_mongo_client),
|
||||||
|
s3_adapter: Optional[S3Adapter] = Depends(get_s3_adapter)
|
||||||
|
) -> DAO:
|
||||||
# FastAPI кэширует результат Depends в рамках одного запроса,
|
# FastAPI кэширует результат Depends в рамках одного запроса,
|
||||||
# так что DAO создастся один раз за запрос.
|
# так что DAO создастся один раз за запрос.
|
||||||
return DAO(mongo_client)
|
return DAO(mongo_client, s3_adapter)
|
||||||
|
|
||||||
|
def get_kling_adapter(request: Request) -> Optional[KlingAdapter]:
|
||||||
|
return request.app.state.kling_adapter
|
||||||
|
|
||||||
# Провайдер сервиса (собирается из DAO и Gemini)
|
# Провайдер сервиса (собирается из DAO и Gemini)
|
||||||
def get_generation_service(
|
def get_generation_service(
|
||||||
dao: DAO = Depends(get_dao),
|
dao: DAO = Depends(get_dao),
|
||||||
gemini: GoogleAdapter = Depends(get_gemini_client)
|
gemini: GoogleAdapter = Depends(get_gemini_client),
|
||||||
|
s3_adapter: S3Adapter = Depends(get_s3_adapter),
|
||||||
|
bot: Bot = Depends(get_bot_client),
|
||||||
|
kling_adapter: Optional[KlingAdapter] = Depends(get_kling_adapter),
|
||||||
) -> GenerationService:
|
) -> GenerationService:
|
||||||
return GenerationService(dao, gemini)
|
return GenerationService(dao, gemini, s3_adapter, bot, kling_adapter=kling_adapter)
|
||||||
|
|
||||||
|
from fastapi import Header
|
||||||
|
|
||||||
|
async def get_project_id(x_project_id: Optional[str] = Header(None, alias="X-Project-ID")) -> Optional[str]:
|
||||||
|
return x_project_id
|
||||||
96
api/endpoints/admin.py
Normal file
96
api/endpoints/admin.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
from typing import Annotated, List
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from repos.user_repo import UsersRepo, UserStatus
|
||||||
|
from utils.security import verify_password, create_access_token, ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, SECRET_KEY
|
||||||
|
from jose import JWTError, jwt
|
||||||
|
from starlette.requests import Request
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/admin", tags=["admin"])
|
||||||
|
|
||||||
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/token")
|
||||||
|
|
||||||
|
from api.endpoints.auth import get_users_repo
|
||||||
|
|
||||||
|
async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)], repo: Annotated[UsersRepo, Depends(get_users_repo)]):
|
||||||
|
credentials_exception = HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Could not validate credentials",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
|
username: str = payload.get("sub")
|
||||||
|
if username is None:
|
||||||
|
raise credentials_exception
|
||||||
|
except JWTError:
|
||||||
|
raise credentials_exception
|
||||||
|
|
||||||
|
user = await repo.get_user_by_username(username)
|
||||||
|
if user is None:
|
||||||
|
raise credentials_exception
|
||||||
|
return user
|
||||||
|
|
||||||
|
async def get_current_admin(user: Annotated[dict, Depends(get_current_user)]):
|
||||||
|
if not user.get("is_admin"):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Not enough permissions",
|
||||||
|
)
|
||||||
|
return user
|
||||||
|
|
||||||
|
class UserResponse(BaseModel):
|
||||||
|
username: str
|
||||||
|
full_name: str | None = None
|
||||||
|
status: str
|
||||||
|
created_at: str | None = None
|
||||||
|
is_admin: bool
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
|
|
||||||
|
@router.get("/approvals", response_model=List[UserResponse])
|
||||||
|
async def list_pending_users(
|
||||||
|
admin: Annotated[dict, Depends(get_current_admin)],
|
||||||
|
repo: Annotated[UsersRepo, Depends(get_users_repo)]
|
||||||
|
):
|
||||||
|
users = await repo.get_pending_users()
|
||||||
|
# Pydantic conversion handles the list of dicts
|
||||||
|
return [
|
||||||
|
UserResponse(
|
||||||
|
username=u["username"],
|
||||||
|
full_name=u.get("full_name"),
|
||||||
|
status=u["status"],
|
||||||
|
created_at=str(u.get("created_at")),
|
||||||
|
is_admin=u.get("is_admin", False)
|
||||||
|
) for u in users
|
||||||
|
]
|
||||||
|
|
||||||
|
@router.post("/approve/{username}")
|
||||||
|
async def approve_user(
|
||||||
|
username: str,
|
||||||
|
admin: Annotated[dict, Depends(get_current_admin)],
|
||||||
|
repo: Annotated[UsersRepo, Depends(get_users_repo)]
|
||||||
|
):
|
||||||
|
user = await repo.get_user_by_username(username)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(status_code=404, detail="User not found")
|
||||||
|
|
||||||
|
await repo.approve_user(username)
|
||||||
|
return {"message": f"User {username} approved"}
|
||||||
|
|
||||||
|
@router.post("/deny/{username}")
|
||||||
|
async def deny_user(
|
||||||
|
username: str,
|
||||||
|
admin: Annotated[dict, Depends(get_current_admin)],
|
||||||
|
repo: Annotated[UsersRepo, Depends(get_users_repo)]
|
||||||
|
):
|
||||||
|
user = await repo.get_user_by_username(username)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(status_code=404, detail="User not found")
|
||||||
|
|
||||||
|
await repo.deny_user(username)
|
||||||
|
return {"message": f"User {username} denied"}
|
||||||
81
api/endpoints/album_router.py
Normal file
81
api/endpoints/album_router.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from api.models.GenerationRequest import GenerationResponse
|
||||||
|
from models.Album import Album
|
||||||
|
from repos.dao import DAO
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/albums", tags=["Albums"])
|
||||||
|
|
||||||
|
class AlbumCreateRequest(BaseModel):
|
||||||
|
name: str
|
||||||
|
description: Optional[str] = None
|
||||||
|
|
||||||
|
class AlbumUpdateRequest(BaseModel):
|
||||||
|
name: Optional[str] = None
|
||||||
|
description: Optional[str] = None
|
||||||
|
|
||||||
|
class AlbumResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
description: Optional[str] = None
|
||||||
|
generation_ids: List[str] = []
|
||||||
|
cover_asset_id: Optional[str] = None # Not implemented yet
|
||||||
|
|
||||||
|
@router.post("", response_model=AlbumResponse)
|
||||||
|
async def create_album(request: Request, album_in: AlbumCreateRequest):
|
||||||
|
service: AlbumService = request.app.state.album_service
|
||||||
|
album = await service.create_album(name=album_in.name, description=album_in.description)
|
||||||
|
return AlbumResponse(**album.model_dump())
|
||||||
|
|
||||||
|
@router.get("", response_model=List[AlbumResponse])
|
||||||
|
async def get_albums(request: Request, limit: int = 10, offset: int = 0):
|
||||||
|
service: AlbumService = request.app.state.album_service
|
||||||
|
albums = await service.get_albums(limit=limit, offset=offset)
|
||||||
|
return [AlbumResponse(**album.model_dump()) for album in albums]
|
||||||
|
|
||||||
|
@router.get("/{album_id}", response_model=AlbumResponse)
|
||||||
|
async def get_album(request: Request, album_id: str):
|
||||||
|
service: AlbumService = request.app.state.album_service
|
||||||
|
album = await service.get_album(album_id)
|
||||||
|
if not album:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album not found")
|
||||||
|
return AlbumResponse(**album.model_dump())
|
||||||
|
|
||||||
|
@router.put("/{album_id}", response_model=AlbumResponse)
|
||||||
|
async def update_album(request: Request, album_id: str, album_in: AlbumUpdateRequest):
|
||||||
|
service: AlbumService = request.app.state.album_service
|
||||||
|
album = await service.update_album(album_id, name=album_in.name, description=album_in.description)
|
||||||
|
if not album:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album not found")
|
||||||
|
return AlbumResponse(**album.model_dump())
|
||||||
|
|
||||||
|
@router.delete("/{album_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
async def delete_album(request: Request, album_id: str):
|
||||||
|
service: AlbumService = request.app.state.album_service
|
||||||
|
deleted = await service.delete_album(album_id)
|
||||||
|
if not deleted:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album not found")
|
||||||
|
|
||||||
|
@router.post("/{album_id}/generations/{generation_id}")
|
||||||
|
async def add_generation_to_album(request: Request, album_id: str, generation_id: str):
|
||||||
|
service: AlbumService = request.app.state.album_service
|
||||||
|
success = await service.add_generation_to_album(album_id, generation_id)
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album or Generation not found")
|
||||||
|
return {"status": "success"}
|
||||||
|
|
||||||
|
@router.delete("/{album_id}/generations/{generation_id}")
|
||||||
|
async def remove_generation_from_album(request: Request, album_id: str, generation_id: str):
|
||||||
|
service: AlbumService = request.app.state.album_service
|
||||||
|
success = await service.remove_generation_from_album(album_id, generation_id)
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album or Generation not found")
|
||||||
|
return {"status": "success"}
|
||||||
|
|
||||||
|
@router.get("/{album_id}/generations", response_model=List[GenerationResponse])
|
||||||
|
async def get_album_generations(request: Request, album_id: str, limit: int = 10, offset: int = 0):
|
||||||
|
service: AlbumService = request.app.state.album_service
|
||||||
|
generations = await service.get_generations_by_album(album_id, limit=limit, offset=offset)
|
||||||
|
return [GenerationResponse(**gen.model_dump()) for gen in generations]
|
||||||
@@ -1,42 +1,212 @@
|
|||||||
from typing import List, Optional
|
from typing import List, Optional, Dict, Any
|
||||||
|
|
||||||
from aiogram.types import BufferedInputFile
|
from aiogram.types import BufferedInputFile
|
||||||
|
from bson import ObjectId
|
||||||
from fastapi import APIRouter, UploadFile, File, Form, Depends
|
from fastapi import APIRouter, UploadFile, File, Form, Depends
|
||||||
from fastapi.openapi.models import MediaType
|
from fastapi.openapi.models import MediaType
|
||||||
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
from pymongo import MongoClient
|
||||||
from starlette import status
|
from starlette import status
|
||||||
from starlette.exceptions import HTTPException
|
from starlette.exceptions import HTTPException
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import Response, JSONResponse
|
from starlette.responses import Response, JSONResponse
|
||||||
|
|
||||||
|
from adapters.s3_adapter import S3Adapter
|
||||||
from api.models.AssetDTO import AssetsResponse, AssetResponse
|
from api.models.AssetDTO import AssetsResponse, AssetResponse
|
||||||
from models.Asset import Asset, AssetType
|
from models.Asset import Asset, AssetType, AssetContentType
|
||||||
from repos.dao import DAO
|
from repos.dao import DAO
|
||||||
from api.dependency import get_dao
|
from api.dependency import get_dao, get_mongo_client, get_s3_adapter
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
from api.endpoints.auth import get_current_user
|
||||||
|
from api.dependency import get_project_id
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/assets", tags=["Assets"])
|
router = APIRouter(prefix="/api/assets", tags=["Assets"])
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{asset_id}")
|
@router.get("/{asset_id}")
|
||||||
async def get_asset(asset_id: str, request: Request,dao: DAO = Depends(get_dao),) -> Response:
|
async def get_asset(
|
||||||
|
asset_id: str,
|
||||||
|
request: Request,
|
||||||
|
thumbnail: bool = False,
|
||||||
|
dao: DAO = Depends(get_dao)
|
||||||
|
) -> Response:
|
||||||
|
logger.debug(f"get_asset called for ID: {asset_id}, thumbnail={thumbnail}")
|
||||||
asset = await dao.assets.get_asset(asset_id)
|
asset = await dao.assets.get_asset(asset_id)
|
||||||
# 2. Проверка на существование
|
# 2. Проверка на существование
|
||||||
if not asset:
|
if not asset:
|
||||||
raise HTTPException(status_code=404, detail="Asset not found")
|
raise HTTPException(status_code=404, detail="Asset not found")
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
# Кэшировать на 1 год (31536000 сек)
|
# Кэшировать на 1 год (31536000 сек)
|
||||||
"Cache-Control": "public, max-age=31536000, immutable"
|
"Cache-Control": "public, max-age=31536000, immutable"
|
||||||
}
|
}
|
||||||
return Response(content=asset.data, media_type="image/png", headers=headers)
|
|
||||||
|
content = asset.data
|
||||||
|
media_type = "image/png" # Default, or detect
|
||||||
|
|
||||||
|
if thumbnail and asset.thumbnail:
|
||||||
|
content = asset.thumbnail
|
||||||
|
media_type = "image/jpeg"
|
||||||
|
|
||||||
|
return Response(content=content, media_type=media_type, headers=headers)
|
||||||
|
|
||||||
|
@router.delete("/orphans", dependencies=[Depends(get_current_user)])
|
||||||
|
async def delete_orphan_assets_from_minio(
|
||||||
|
mongo: AsyncIOMotorClient = Depends(get_mongo_client),
|
||||||
|
minio_client: S3Adapter = Depends(get_s3_adapter),
|
||||||
|
*,
|
||||||
|
assets_collection: str = "assets",
|
||||||
|
generations_collection: str = "generations",
|
||||||
|
asset_type: Optional[str] = "generated",
|
||||||
|
project_id: Optional[str] = None,
|
||||||
|
dry_run: bool = True,
|
||||||
|
mark_assets_deleted: bool = False,
|
||||||
|
batch_size: int = 500,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
db = mongo['bot_db'] # БД уже выбрана в get_mongo_client
|
||||||
|
assets = db[assets_collection]
|
||||||
|
|
||||||
|
match_assets: Dict[str, Any] = {}
|
||||||
|
if asset_type is not None:
|
||||||
|
match_assets["type"] = asset_type
|
||||||
|
if project_id is not None:
|
||||||
|
match_assets["project_id"] = project_id
|
||||||
|
|
||||||
|
pipeline: List[Dict[str, Any]] = [
|
||||||
|
{"$match": match_assets} if match_assets else {"$match": {}},
|
||||||
|
{
|
||||||
|
"$lookup": {
|
||||||
|
"from": generations_collection,
|
||||||
|
"let": {"assetIdStr": {"$toString": "$_id"}},
|
||||||
|
"pipeline": [
|
||||||
|
# считаем "живыми" те, где is_deleted != True (т.е. false или поля нет)
|
||||||
|
{"$match": {"is_deleted": {"$ne": True}}},
|
||||||
|
{
|
||||||
|
"$match": {
|
||||||
|
"$expr": {
|
||||||
|
"$in": [
|
||||||
|
"$$assetIdStr",
|
||||||
|
{"$ifNull": ["$result_list", []]},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{"$limit": 1},
|
||||||
|
],
|
||||||
|
"as": "alive_generations",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"$match": {
|
||||||
|
"$expr": {"$eq": [{"$size": "$alive_generations"}, 0]}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"$project": {
|
||||||
|
"_id": 1,
|
||||||
|
"minio_object_name": 1,
|
||||||
|
"minio_thumbnail_object_name": 1,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
]
|
||||||
|
print(pipeline)
|
||||||
|
cursor = assets.aggregate(pipeline, allowDiskUse=True, batchSize=batch_size)
|
||||||
|
|
||||||
|
deleted_objects = 0
|
||||||
|
deleted_assets = 0
|
||||||
|
errors: List[Dict[str, Any]] = []
|
||||||
|
orphan_asset_ids: List[ObjectId] = []
|
||||||
|
|
||||||
|
async for asset in cursor:
|
||||||
|
aid = asset["_id"]
|
||||||
|
obj = asset.get("minio_object_name")
|
||||||
|
thumb = asset.get("minio_thumbnail_object_name")
|
||||||
|
|
||||||
|
orphan_asset_ids.append(aid)
|
||||||
|
|
||||||
|
if dry_run:
|
||||||
|
print(f"[DRY RUN] orphan asset={aid} obj={obj} thumb={thumb}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
if obj:
|
||||||
|
await minio_client.delete_file(obj)
|
||||||
|
deleted_objects += 1
|
||||||
|
|
||||||
|
if thumb:
|
||||||
|
await minio_client.delete_file(thumb)
|
||||||
|
deleted_objects += 1
|
||||||
|
|
||||||
|
deleted_assets += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
errors.append({"asset_id": str(aid), "error": str(e)})
|
||||||
|
|
||||||
|
if (not dry_run) and mark_assets_deleted and orphan_asset_ids:
|
||||||
|
res = await assets.update_many(
|
||||||
|
{"_id": {"$in": orphan_asset_ids}},
|
||||||
|
{"$set": {"is_deleted": True}},
|
||||||
|
)
|
||||||
|
marked = res.modified_count
|
||||||
|
else:
|
||||||
|
marked = 0
|
||||||
|
|
||||||
|
return {
|
||||||
|
"dry_run": dry_run,
|
||||||
|
"filter": {
|
||||||
|
"asset_type": asset_type,
|
||||||
|
"project_id": project_id,
|
||||||
|
},
|
||||||
|
"orphans_found": len(orphan_asset_ids),
|
||||||
|
"deleted_assets": deleted_assets,
|
||||||
|
"deleted_objects": deleted_objects,
|
||||||
|
"marked_assets_deleted": marked,
|
||||||
|
"errors": errors,
|
||||||
|
}
|
||||||
|
|
||||||
|
@router.delete("/{asset_id}", status_code=status.HTTP_204_NO_CONTENT, dependencies=[Depends(get_current_user)])
|
||||||
|
async def delete_asset(
|
||||||
|
asset_id: str,
|
||||||
|
dao: DAO = Depends(get_dao)
|
||||||
|
):
|
||||||
|
logger.info(f"delete_asset called for ID: {asset_id}")
|
||||||
|
# 1. Проверяем наличие (опционально, delete_one вернет false если нет, но для 404 нужно знать)
|
||||||
|
# Можно просто попробовать удалить
|
||||||
|
deleted = await dao.assets.delete_asset(asset_id)
|
||||||
|
if not deleted:
|
||||||
|
raise HTTPException(status_code=404, detail="Asset not found")
|
||||||
|
|
||||||
|
logger.info(f"Asset {asset_id} deleted successfully")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
@router.get("")
|
@router.get("", dependencies=[Depends(get_current_user)])
|
||||||
async def get_assets(request: Request, dao: DAO = Depends(get_dao), limit: int = 10, offset: int = 0) -> AssetsResponse:
|
async def get_assets(request: Request, dao: DAO = Depends(get_dao), type: Optional[str] = None, limit: int = 10, offset: int = 0, current_user: dict = Depends(get_current_user), project_id: Optional[str] = Depends(get_project_id)) -> AssetsResponse:
|
||||||
assets = await dao.assets.get_assets(limit, offset)
|
logger.info(f"get_assets called. Limit: {limit}, Offset: {offset}")
|
||||||
assets = await dao.assets.get_assets()
|
|
||||||
total_count = await dao.assets.get_asset_count()
|
|
||||||
|
|
||||||
return AssetsResponse(assets=assets, total_count=total_count)
|
user_id_filter = current_user["id"]
|
||||||
|
if project_id:
|
||||||
|
project = await dao.projects.get_project(project_id)
|
||||||
|
if not project or str(current_user["id"]) not in project.members:
|
||||||
|
raise HTTPException(status_code=403, detail="Project access denied")
|
||||||
|
user_id_filter = None
|
||||||
|
|
||||||
|
assets = await dao.assets.get_assets(type, limit, offset, created_by=user_id_filter, project_id=project_id)
|
||||||
|
# assets = await dao.assets.get_assets() # This line seemed redundant/conflicting in original code
|
||||||
|
total_count = await dao.assets.get_asset_count(created_by=user_id_filter, project_id=project_id)
|
||||||
|
|
||||||
|
# Manually map to DTO to trigger computed fields validation if necessary,
|
||||||
|
# but primarily to ensure valid Pydantic models for the response list.
|
||||||
|
# Asset.model_dump() generally includes computed fields (url) if configured.
|
||||||
|
# Let's ensure strict conversion.
|
||||||
|
asset_responses = [AssetResponse.model_validate(a.model_dump()) for a in assets]
|
||||||
|
|
||||||
|
return AssetsResponse(assets=asset_responses, total_count=total_count)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -45,31 +215,96 @@ async def upload_asset(
|
|||||||
file: UploadFile = File(...),
|
file: UploadFile = File(...),
|
||||||
linked_char_id: Optional[str] = Form(None),
|
linked_char_id: Optional[str] = Form(None),
|
||||||
dao: DAO = Depends(get_dao),
|
dao: DAO = Depends(get_dao),
|
||||||
|
current_user: dict = Depends(get_current_user),
|
||||||
|
project_id: Optional[str] = Depends(get_project_id)
|
||||||
):
|
):
|
||||||
|
logger.info(f"upload_asset called. Filename: {file.filename}, ContentType: {file.content_type}, LinkedCharId: {linked_char_id}")
|
||||||
if not file.content_type:
|
if not file.content_type:
|
||||||
raise HTTPException(status_code=400, detail="Unknown file type")
|
raise HTTPException(status_code=400, detail="Unknown file type")
|
||||||
|
|
||||||
if not file.content_type.startswith("image/"):
|
if not file.content_type.startswith("image/"):
|
||||||
raise HTTPException(status_code=400, detail=f"Unsupported content type: {file.content_type}")
|
raise HTTPException(status_code=400, detail=f"Unsupported content type: {file.content_type}")
|
||||||
|
|
||||||
|
if project_id:
|
||||||
|
project = await dao.projects.get_project(project_id)
|
||||||
|
if not project or str(current_user["_id"]) not in project.members:
|
||||||
|
raise HTTPException(status_code=403, detail="Project access denied")
|
||||||
|
|
||||||
data = await file.read()
|
data = await file.read()
|
||||||
if not data:
|
if not data:
|
||||||
raise HTTPException(status_code=400, detail="Empty file")
|
raise HTTPException(status_code=400, detail="Empty file")
|
||||||
|
|
||||||
|
# Generate thumbnail
|
||||||
|
from utils.image_utils import create_thumbnail
|
||||||
|
thumbnail_bytes = await asyncio.to_thread(create_thumbnail, data)
|
||||||
|
|
||||||
asset = Asset(
|
asset = Asset(
|
||||||
name=file.filename or "upload",
|
name=file.filename or "upload",
|
||||||
type=AssetType.IMAGE,
|
type=AssetType.UPLOADED,
|
||||||
|
content_type=AssetContentType.IMAGE,
|
||||||
linked_char_id=linked_char_id,
|
linked_char_id=linked_char_id,
|
||||||
data=data,
|
data=data,
|
||||||
|
thumbnail=thumbnail_bytes,
|
||||||
|
created_by=str(current_user["_id"]),
|
||||||
|
project_id=project_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
asset_id = await dao.assets.create_asset(asset)
|
asset_id = await dao.assets.create_asset(asset)
|
||||||
asset.id = str(asset_id)
|
asset.id = str(asset_id)
|
||||||
|
logger.info(f"Asset created successfully. ID: {asset_id}")
|
||||||
|
|
||||||
return AssetResponse(
|
return AssetResponse(
|
||||||
id=asset.id,
|
id=asset.id,
|
||||||
name=asset.name,
|
name=asset.name,
|
||||||
type=asset.type.value if hasattr(asset.type, "value") else asset.type,
|
type=asset.type.value if hasattr(asset.type, "value") else asset.type,
|
||||||
|
content_type=asset.content_type.value if hasattr(asset.content_type, "value") else asset.content_type,
|
||||||
linked_char_id=asset.linked_char_id,
|
linked_char_id=asset.linked_char_id,
|
||||||
created_at=asset.created_at,
|
created_at=asset.created_at,
|
||||||
|
url=asset.url
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/regenerate_thumbnails", dependencies=[Depends(get_current_user)])
|
||||||
|
async def regenerate_thumbnails(dao: DAO = Depends(get_dao)):
|
||||||
|
"""
|
||||||
|
Regenerates thumbnails for all existing image assets that don't have one.
|
||||||
|
"""
|
||||||
|
logger.info("Starting thumbnail regeneration task")
|
||||||
|
from utils.image_utils import create_thumbnail
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# Get all assets (pagination loop might be needed for huge datasets, but simple list for now)
|
||||||
|
# We'll rely on DAO providing a method or just fetch large chunk.
|
||||||
|
# Assuming get_assets might have limit, let's fetch in chunks or just all if possible within limit.
|
||||||
|
# Ideally should use a specific repo method for iteration.
|
||||||
|
# For now, let's fetch first 1000 or similar.
|
||||||
|
assets = await dao.assets.get_assets(limit=1000, offset=0, with_data=True)
|
||||||
|
logger.info(f"Found {len(assets)} assets")
|
||||||
|
count = 0
|
||||||
|
updated = 0
|
||||||
|
|
||||||
|
for asset in assets:
|
||||||
|
if asset.content_type == AssetContentType.IMAGE and asset.data :
|
||||||
|
try:
|
||||||
|
thumb = await asyncio.to_thread(create_thumbnail, asset.data)
|
||||||
|
if thumb:
|
||||||
|
asset.thumbnail = thumb
|
||||||
|
await dao.assets.update_asset(asset.id, asset)
|
||||||
|
updated += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to regenerate thumbnail for asset {asset.id}: {e}")
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
|
||||||
|
return {"status": "completed", "processed": count, "updated": updated}
|
||||||
|
|
||||||
|
@router.post("/migrate_to_minio", dependencies=[Depends(get_current_user)])
|
||||||
|
async def migrate_to_minio(dao: DAO = Depends(get_dao)):
|
||||||
|
"""
|
||||||
|
Migrates assets from MongoDB to MinIO.
|
||||||
|
"""
|
||||||
|
logger.info("Starting migration to MinIO")
|
||||||
|
result = await dao.assets.migrate_to_minio()
|
||||||
|
logger.info(f"Migration result: {result}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
|||||||
123
api/endpoints/auth.py
Normal file
123
api/endpoints/auth.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
from datetime import timedelta
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from jose import JWTError, jwt
|
||||||
|
|
||||||
|
from repos.user_repo import UsersRepo, UserStatus
|
||||||
|
from utils.security import verify_password, create_access_token, ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, SECRET_KEY
|
||||||
|
from starlette.requests import Request
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||||
|
|
||||||
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/token")
|
||||||
|
|
||||||
|
async def get_users_repo(request: Request) -> UsersRepo:
|
||||||
|
if not hasattr(request.app.state, "users_repo"):
|
||||||
|
raise HTTPException(status_code=500, detail="Users repo not initialized")
|
||||||
|
return request.app.state.users_repo
|
||||||
|
|
||||||
|
async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)], repo: Annotated[UsersRepo, Depends(get_users_repo)]):
|
||||||
|
credentials_exception = HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Could not validate credentials",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
|
username: str = payload.get("sub")
|
||||||
|
if username is None:
|
||||||
|
raise credentials_exception
|
||||||
|
except JWTError:
|
||||||
|
raise credentials_exception
|
||||||
|
|
||||||
|
user = await repo.get_user_by_username(username)
|
||||||
|
if user is None:
|
||||||
|
raise credentials_exception
|
||||||
|
return user
|
||||||
|
|
||||||
|
async def get_current_admin(user: Annotated[dict, Depends(get_current_user)]):
|
||||||
|
if not user.get("is_admin"):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Not enough permissions",
|
||||||
|
)
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
class UserRegister(BaseModel):
|
||||||
|
username: str
|
||||||
|
password: str
|
||||||
|
full_name: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class Token(BaseModel):
|
||||||
|
access_token: str
|
||||||
|
token_type: str
|
||||||
|
|
||||||
|
|
||||||
|
class UserResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
username: str
|
||||||
|
full_name: str | None = None
|
||||||
|
status: str
|
||||||
|
is_admin: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/me", response_model=UserResponse)
|
||||||
|
async def read_users_me(current_user: Annotated[dict, Depends(get_current_user)]):
|
||||||
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/register")
|
||||||
|
async def register(user_data: UserRegister, repo: Annotated[UsersRepo, Depends(get_users_repo)]):
|
||||||
|
try:
|
||||||
|
await repo.create_user(
|
||||||
|
username=user_data.username,
|
||||||
|
password=user_data.password,
|
||||||
|
full_name=user_data.full_name
|
||||||
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
|
||||||
|
return {"message": "Registration successful. Please wait for administrator approval."}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/token", response_model=Token)
|
||||||
|
async def login_for_access_token(
|
||||||
|
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||||
|
repo: Annotated[UsersRepo, Depends(get_users_repo)]
|
||||||
|
):
|
||||||
|
user = await repo.get_user_by_username(form_data.username)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Incorrect username or password",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Проверяем пароль
|
||||||
|
if not verify_password(form_data.password, user["hashed_password"]):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Incorrect username or password",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Проверка статуса
|
||||||
|
if user.get("status") != UserStatus.ALLOWED:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Account is not approved yet. Please contact administrator.",
|
||||||
|
)
|
||||||
|
|
||||||
|
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||||
|
access_token = create_access_token(
|
||||||
|
data={"sub": user["username"]}, expires_delta=access_token_expires
|
||||||
|
)
|
||||||
|
return {"access_token": access_token, "token_type": "bearer"}
|
||||||
@@ -1,44 +1,187 @@
|
|||||||
from typing import List, Any, Coroutine
|
from typing import List, Any, Coroutine, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from starlette.exceptions import HTTPException
|
from starlette.exceptions import HTTPException
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
|
|
||||||
from api.models.AssetDTO import AssetsResponse
|
from api.models.AssetDTO import AssetsResponse, AssetResponse
|
||||||
from api.models.GenerationRequest import GenerationRequest, GenerationResponse
|
from api.models.GenerationRequest import GenerationRequest, GenerationResponse
|
||||||
from models.Asset import Asset
|
from models.Asset import Asset
|
||||||
from models.Character import Character
|
from models.Character import Character
|
||||||
|
from api.models.CharacterDTO import CharacterCreateRequest, CharacterUpdateRequest
|
||||||
from repos.dao import DAO
|
from repos.dao import DAO
|
||||||
from api.dependency import get_dao
|
from api.dependency import get_dao
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/characters", tags=["Characters"])
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
from api.endpoints.auth import get_current_user
|
||||||
|
from api.dependency import get_project_id
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/characters", tags=["Characters"], dependencies=[Depends(get_current_user)])
|
||||||
|
|
||||||
|
|
||||||
@router.get("/", response_model=List[Character])
|
@router.get("/", response_model=List[Character])
|
||||||
async def get_characters(request: Request, dao: DAO = Depends(get_dao), ) -> List[Character]:
|
async def get_characters(request: Request, dao: DAO = Depends(get_dao), current_user: dict = Depends(get_current_user), project_id: Optional[str] = Depends(get_project_id)) -> List[Character]:
|
||||||
characters = await dao.chars.get_all_characters()
|
logger.info("get_characters called")
|
||||||
|
|
||||||
|
user_id_filter = str(current_user["_id"])
|
||||||
|
if project_id:
|
||||||
|
project = await dao.projects.get_project(project_id)
|
||||||
|
if not project or str(current_user["_id"]) not in project.members:
|
||||||
|
raise HTTPException(status_code=403, detail="Project access denied")
|
||||||
|
user_id_filter = None
|
||||||
|
|
||||||
|
characters = await dao.chars.get_all_characters(created_by=user_id_filter, project_id=project_id)
|
||||||
return characters
|
return characters
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{character_id}/assets", response_model=AssetsResponse)
|
@router.get("/{character_id}/assets", response_model=AssetsResponse)
|
||||||
async def get_character_assets(character_id: str, dao: DAO = Depends(get_dao), limit: int = 10,
|
async def get_character_assets(character_id: str, dao: DAO = Depends(get_dao), limit: int = 10,
|
||||||
offset: int = 0, ) -> AssetsResponse:
|
offset: int = 0, current_user: dict = Depends(get_current_user)) -> AssetsResponse:
|
||||||
|
logger.info(f"get_character_assets called. CharacterID: {character_id}, Limit: {limit}, Offset: {offset}")
|
||||||
character = await dao.chars.get_character(character_id)
|
character = await dao.chars.get_character(character_id)
|
||||||
if character is None:
|
if character is None:
|
||||||
raise HTTPException(status_code=404, detail="Character not found")
|
raise HTTPException(status_code=404, detail="Character not found")
|
||||||
|
|
||||||
|
# Access Check
|
||||||
|
is_creator = character.created_by == str(current_user["_id"])
|
||||||
|
is_project_member = False
|
||||||
|
if character.project_id and character.project_id in current_user.get("project_ids", []):
|
||||||
|
is_project_member = True
|
||||||
|
|
||||||
|
if not is_creator and not is_project_member:
|
||||||
|
raise HTTPException(status_code=403, detail="Access denied")
|
||||||
|
|
||||||
assets = await dao.assets.get_assets_by_char_id(character_id, limit, offset)
|
assets = await dao.assets.get_assets_by_char_id(character_id, limit, offset)
|
||||||
|
# Filter assets by user ownership as well?
|
||||||
|
# Usually if you own character, you see its assets.
|
||||||
|
# But assets also have specific created_by.
|
||||||
|
# Let's assume if you own character you can see its assets.
|
||||||
|
|
||||||
total_count = await dao.assets.get_asset_count(character_id)
|
total_count = await dao.assets.get_asset_count(character_id)
|
||||||
return AssetsResponse(assets=assets, total_count=total_count)
|
|
||||||
|
asset_responses = [AssetResponse.model_validate(a.model_dump()) for a in assets]
|
||||||
|
return AssetsResponse(assets=asset_responses, total_count=total_count)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{character_id}", response_model=Character)
|
@router.get("/{character_id}", response_model=Character)
|
||||||
async def get_character_by_id(character_id: str, request: Request, dao: DAO = Depends(get_dao)) -> Character:
|
async def get_character_by_id(character_id: str, request: Request, dao: DAO = Depends(get_dao), current_user: dict = Depends(get_current_user)) -> Character:
|
||||||
|
logger.debug(f"get_character_by_id called. ID: {character_id}")
|
||||||
character = await dao.chars.get_character(character_id)
|
character = await dao.chars.get_character(character_id)
|
||||||
|
|
||||||
|
if not character:
|
||||||
|
raise HTTPException(status_code=404, detail="Character not found")
|
||||||
|
|
||||||
|
if character:
|
||||||
|
is_creator = character.created_by == str(current_user["_id"])
|
||||||
|
is_project_member = False
|
||||||
|
if character.project_id and character.project_id in current_user.get("project_ids", []):
|
||||||
|
is_project_member = True
|
||||||
|
|
||||||
|
if not is_creator and not is_project_member:
|
||||||
|
raise HTTPException(status_code=403, detail="Access denied")
|
||||||
|
|
||||||
return character
|
return character
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{character_id}/_run", response_model=Asset)
|
@router.post("/", response_model=Character)
|
||||||
|
async def create_character(
|
||||||
|
char_req: CharacterCreateRequest,
|
||||||
|
project_id: Optional[str] = Depends(get_project_id),
|
||||||
|
dao: DAO = Depends(get_dao),
|
||||||
|
current_user: dict = Depends(get_current_user)
|
||||||
|
) -> Character:
|
||||||
|
logger.info("create_character called")
|
||||||
|
char_req.project_id = project_id
|
||||||
|
char_data = char_req.model_dump()
|
||||||
|
char_data["created_by"] = str(current_user["_id"])
|
||||||
|
if "id" not in char_data:
|
||||||
|
char_data["id"] = None
|
||||||
|
|
||||||
|
if project_id:
|
||||||
|
project = await dao.projects.get_project(project_id)
|
||||||
|
if not project or str(current_user["_id"]) not in project.members:
|
||||||
|
raise HTTPException(status_code=403, detail="Project access denied")
|
||||||
|
|
||||||
|
new_char = Character(**char_data)
|
||||||
|
new_char.avatar_asset_id = new_char.avatar_image.split("/")[-1]
|
||||||
|
created_char = await dao.chars.add_character(new_char)
|
||||||
|
return created_char
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{character_id}", response_model=Character)
|
||||||
|
async def update_character(
|
||||||
|
character_id: str,
|
||||||
|
char_update: CharacterUpdateRequest,
|
||||||
|
dao: DAO = Depends(get_dao),
|
||||||
|
current_user: dict = Depends(get_current_user)
|
||||||
|
) -> Character:
|
||||||
|
logger.info(f"update_character called. ID: {character_id}")
|
||||||
|
|
||||||
|
existing_char = await dao.chars.get_character(character_id)
|
||||||
|
if not existing_char:
|
||||||
|
raise HTTPException(status_code=404, detail="Character not found")
|
||||||
|
|
||||||
|
is_creator = existing_char.created_by == str(current_user["_id"])
|
||||||
|
is_project_member = False
|
||||||
|
if existing_char.project_id and existing_char.project_id in current_user.get("project_ids", []):
|
||||||
|
is_project_member = True
|
||||||
|
|
||||||
|
if not is_creator and not is_project_member:
|
||||||
|
raise HTTPException(status_code=403, detail="Access denied")
|
||||||
|
|
||||||
|
update_data = char_update.model_dump(exclude_unset=True)
|
||||||
|
|
||||||
|
if "project_id" in update_data and update_data["project_id"]:
|
||||||
|
new_project_id = update_data["project_id"]
|
||||||
|
project = await dao.projects.get_project(new_project_id)
|
||||||
|
if not project or str(current_user["_id"]) not in project.members:
|
||||||
|
raise HTTPException(status_code=403, detail="Target project access denied")
|
||||||
|
|
||||||
|
updated_char_data = existing_char.model_dump()
|
||||||
|
updated_char_data.update(update_data)
|
||||||
|
|
||||||
|
updated_char = Character(**updated_char_data)
|
||||||
|
|
||||||
|
success = await dao.chars.update_char(character_id, updated_char)
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to update character")
|
||||||
|
|
||||||
|
return updated_char
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{character_id}", status_code=204)
|
||||||
|
async def delete_character(
|
||||||
|
character_id: str,
|
||||||
|
dao: DAO = Depends(get_dao),
|
||||||
|
current_user: dict = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
logger.info(f"delete_character called. ID: {character_id}")
|
||||||
|
|
||||||
|
existing_char = await dao.chars.get_character(character_id)
|
||||||
|
if not existing_char:
|
||||||
|
raise HTTPException(status_code=404, detail="Character not found")
|
||||||
|
|
||||||
|
is_creator = existing_char.created_by == str(current_user["_id"])
|
||||||
|
is_project_member = False
|
||||||
|
if existing_char.project_id and existing_char.project_id in current_user.get("project_ids", []):
|
||||||
|
is_project_member = True
|
||||||
|
|
||||||
|
if not is_creator and not is_project_member:
|
||||||
|
raise HTTPException(status_code=403, detail="Access denied")
|
||||||
|
|
||||||
|
success = await dao.chars.delete_character(character_id)
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to delete character")
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{character_id}/_run", response_model=GenerationResponse)
|
||||||
async def post_character_generation(character_id: str, generation: GenerationRequest,
|
async def post_character_generation(character_id: str, generation: GenerationRequest,
|
||||||
request: Request) -> GenerationResponse:
|
request: Request) -> GenerationResponse:
|
||||||
|
logger.info(f"post_character_generation called. CharacterID: {character_id}")
|
||||||
generation_service = request.app.state.generation_service
|
generation_service = request.app.state.generation_service
|
||||||
|
|||||||
@@ -1,47 +1,193 @@
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter, UploadFile, File, Form, Header, HTTPException
|
||||||
from fastapi.params import Depends
|
from fastapi.params import Depends
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
|
|
||||||
from api import service
|
from api import service
|
||||||
from api.dependency import get_generation_service
|
from api.dependency import get_generation_service, get_project_id, get_dao
|
||||||
|
from repos.dao import DAO
|
||||||
|
|
||||||
from api.models.GenerationRequest import GenerationResponse, GenerationRequest, PromptResponse, PromptRequest
|
from api.models.GenerationRequest import GenerationResponse, GenerationRequest, GenerationsResponse, PromptResponse, PromptRequest
|
||||||
|
from api.models.VideoGenerationRequest import VideoGenerationRequest
|
||||||
from api.service.generation_service import GenerationService
|
from api.service.generation_service import GenerationService
|
||||||
from models.Generation import Generation
|
from models.Generation import Generation
|
||||||
|
|
||||||
|
from starlette import status
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
from api.endpoints.auth import get_current_user
|
||||||
|
|
||||||
router = APIRouter(prefix='/api/generations', tags=["Generation"])
|
router = APIRouter(prefix='/api/generations', tags=["Generation"])
|
||||||
|
|
||||||
|
|
||||||
@router.post("/prompt-assistant", response_model=PromptResponse)
|
@router.post("/prompt-assistant", response_model=PromptResponse)
|
||||||
async def ask_prompt_assistant(prompt_request: PromptRequest, request: Request,
|
async def ask_prompt_assistant(prompt_request: PromptRequest, request: Request,
|
||||||
generation_service: GenerationService = Depends(
|
generation_service: GenerationService = Depends(
|
||||||
get_generation_service)) -> PromptResponse:
|
get_generation_service),
|
||||||
|
current_user: dict = Depends(get_current_user)) -> PromptResponse:
|
||||||
|
logger.info(f"ask_prompt_assistant called with prompt length: {len(prompt_request.prompt)}. Linked assets: {len(prompt_request.linked_assets) if prompt_request.linked_assets else 0}")
|
||||||
generated_prompt = await generation_service.ask_prompt_assistant(prompt_request.prompt, prompt_request.linked_assets)
|
generated_prompt = await generation_service.ask_prompt_assistant(prompt_request.prompt, prompt_request.linked_assets)
|
||||||
return PromptResponse(prompt=generated_prompt)
|
return PromptResponse(prompt=generated_prompt)
|
||||||
|
|
||||||
|
|
||||||
@router.get("", response_model=List[GenerationResponse])
|
@router.post("/prompt-from-image", response_model=PromptResponse)
|
||||||
async def get_generations(character_id: Optional[str], limit: int = 10, offset: int = 0,
|
async def prompt_from_image(
|
||||||
generation_service: GenerationService = Depends(get_generation_service)):
|
prompt: Optional[str] = Form(None),
|
||||||
return await generation_service.get_generations(character_id, limit=limit, offset=offset)
|
images: List[UploadFile] = File(...),
|
||||||
|
generation_service: GenerationService = Depends(get_generation_service),
|
||||||
|
current_user: dict = Depends(get_current_user)
|
||||||
|
) -> PromptResponse:
|
||||||
|
logger.info(f"prompt_from_image called. Images count: {len(images)}. Prompt provided: {bool(prompt)}")
|
||||||
|
images_bytes = []
|
||||||
|
for image in images:
|
||||||
|
content = await image.read()
|
||||||
|
images_bytes.append(content)
|
||||||
|
|
||||||
|
generated_prompt = await generation_service.generate_prompt_from_images(images_bytes, prompt)
|
||||||
|
return PromptResponse(prompt=generated_prompt)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", response_model=GenerationsResponse)
|
||||||
|
async def get_generations(character_id: Optional[str] = None, limit: int = 10, offset: int = 0,
|
||||||
|
generation_service: GenerationService = Depends(get_generation_service),
|
||||||
|
current_user: dict = Depends(get_current_user),
|
||||||
|
project_id: Optional[str] = Depends(get_project_id),
|
||||||
|
dao: DAO = Depends(get_dao)):
|
||||||
|
logger.info(f"get_generations called. CharacterId: {character_id}, Limit: {limit}, Offset: {offset}")
|
||||||
|
|
||||||
|
user_id_filter = str(current_user["_id"])
|
||||||
|
if project_id:
|
||||||
|
project = await dao.projects.get_project(project_id)
|
||||||
|
if not project or str(current_user["_id"]) not in project.members:
|
||||||
|
raise HTTPException(status_code=403, detail="Project access denied")
|
||||||
|
user_id_filter = None # Show all project generations
|
||||||
|
|
||||||
|
return await generation_service.get_generations(character_id, limit=limit, offset=offset, user_id=user_id_filter, project_id=project_id)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/_run", response_model=GenerationResponse)
|
@router.post("/_run", response_model=GenerationResponse)
|
||||||
async def post_generation(generation: GenerationRequest, request: Request,
|
async def post_generation(generation: GenerationRequest, request: Request,
|
||||||
generation_service: GenerationService = Depends(
|
generation_service: GenerationService = Depends(get_generation_service),
|
||||||
get_generation_service)) -> GenerationResponse:
|
current_user: dict = Depends(get_current_user),
|
||||||
return await generation_service.create_generation_task(generation)
|
project_id: Optional[str] = Depends(get_project_id),
|
||||||
|
dao: DAO = Depends(get_dao)) -> GenerationResponse:
|
||||||
|
logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}")
|
||||||
|
|
||||||
|
if project_id:
|
||||||
|
project = await dao.projects.get_project(project_id)
|
||||||
|
if not project or str(current_user["_id"]) not in project.members:
|
||||||
|
raise HTTPException(status_code=403, detail="Project access denied")
|
||||||
|
generation.project_id = project_id
|
||||||
|
|
||||||
|
return await generation_service.create_generation_task(generation, user_id=str(current_user.get("_id")))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{generation_id}", response_model=GenerationResponse)
|
@router.get("/{generation_id}", response_model=GenerationResponse)
|
||||||
async def get_generation(generation_id: str,
|
async def get_generation(generation_id: str,
|
||||||
generation_service: GenerationService = Depends(get_generation_service)) -> GenerationResponse:
|
generation_service: GenerationService = Depends(get_generation_service),
|
||||||
return await generation_service.get_generation(generation_id)
|
current_user: dict = Depends(get_current_user)) -> GenerationResponse:
|
||||||
|
logger.debug(f"get_generation called for ID: {generation_id}")
|
||||||
|
gen = await generation_service.get_generation(generation_id)
|
||||||
|
if gen and gen.created_by != str(current_user["_id"]):
|
||||||
|
raise HTTPException(status_code=403, detail="Access denied")
|
||||||
|
return gen
|
||||||
|
|
||||||
|
|
||||||
@router.get("/running")
|
@router.get("/running")
|
||||||
async def get_running_generations(request: Request,
|
async def get_running_generations(request: Request,
|
||||||
generation_service: GenerationService = Depends(get_generation_service)):
|
generation_service: GenerationService = Depends(get_generation_service),
|
||||||
return await generation_service.get_running_generations()
|
current_user: dict = Depends(get_current_user),
|
||||||
|
project_id: Optional[str] = Depends(get_project_id),
|
||||||
|
dao: DAO = Depends(get_dao)):
|
||||||
|
|
||||||
|
user_id_filter = str(current_user["_id"])
|
||||||
|
if project_id:
|
||||||
|
project = await dao.projects.get_project(project_id)
|
||||||
|
if not project or str(current_user["_id"]) not in project.members:
|
||||||
|
raise HTTPException(status_code=403, detail="Project access denied")
|
||||||
|
user_id_filter = None
|
||||||
|
|
||||||
|
return await generation_service.get_running_generations(user_id=user_id_filter, project_id=project_id)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/video/_run", response_model=GenerationResponse)
|
||||||
|
async def post_video_generation(
|
||||||
|
video_request: VideoGenerationRequest,
|
||||||
|
request: Request,
|
||||||
|
generation_service: GenerationService = Depends(get_generation_service),
|
||||||
|
current_user: dict = Depends(get_current_user),
|
||||||
|
project_id: Optional[str] = Depends(get_project_id),
|
||||||
|
dao: DAO = Depends(get_dao),
|
||||||
|
) -> GenerationResponse:
|
||||||
|
"""Start image-to-video generation using Kling AI."""
|
||||||
|
logger.info(f"post_video_generation called. AssetId: {video_request.image_asset_id}, Duration: {video_request.duration}s, Mode: {video_request.mode}")
|
||||||
|
|
||||||
|
if project_id:
|
||||||
|
project = await dao.projects.get_project(project_id)
|
||||||
|
if not project or str(current_user["_id"]) not in project.members:
|
||||||
|
raise HTTPException(status_code=403, detail="Project access denied")
|
||||||
|
video_request.project_id = project_id
|
||||||
|
|
||||||
|
return await generation_service.create_video_generation_task(video_request, user_id=str(current_user.get("_id")))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/import", response_model=GenerationResponse)
|
||||||
|
async def import_external_generation(
|
||||||
|
request: Request,
|
||||||
|
generation_service: GenerationService = Depends(get_generation_service),
|
||||||
|
x_signature: str = Header(..., alias="X-Signature")
|
||||||
|
) -> GenerationResponse:
|
||||||
|
"""
|
||||||
|
Import a generation from an external source.
|
||||||
|
Requires server-to-server authentication via HMAC signature.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
from utils.external_auth import verify_signature
|
||||||
|
from api.models.ExternalGenerationDTO import ExternalGenerationRequest
|
||||||
|
|
||||||
|
logger.info("import_external_generation called")
|
||||||
|
|
||||||
|
# Get raw request body for signature verification
|
||||||
|
body = await request.body()
|
||||||
|
|
||||||
|
# Verify signature
|
||||||
|
secret = os.getenv("EXTERNAL_API_SECRET")
|
||||||
|
if not secret:
|
||||||
|
logger.error("EXTERNAL_API_SECRET not configured")
|
||||||
|
raise HTTPException(status_code=500, detail="Server configuration error")
|
||||||
|
|
||||||
|
if not verify_signature(body, x_signature, secret):
|
||||||
|
logger.warning("Invalid signature for external generation import")
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid signature")
|
||||||
|
|
||||||
|
# Parse request body
|
||||||
|
import json
|
||||||
|
try:
|
||||||
|
data = json.loads(body.decode('utf-8'))
|
||||||
|
external_gen = ExternalGenerationRequest(**data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to parse request body: {e}")
|
||||||
|
raise HTTPException(status_code=400, detail=f"Invalid request body: {str(e)}")
|
||||||
|
|
||||||
|
# Import generation
|
||||||
|
try:
|
||||||
|
generation = await generation_service.import_external_generation(external_gen)
|
||||||
|
return GenerationResponse(**generation.model_dump())
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to import external generation: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"Import failed: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{generation_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
async def delete_generation(generation_id: str,
|
||||||
|
generation_service: GenerationService = Depends(get_generation_service),
|
||||||
|
current_user: dict = Depends(get_current_user)):
|
||||||
|
logger.info(f"delete_generation called for ID: {generation_id}")
|
||||||
|
deleted = await generation_service.delete_generation(generation_id)
|
||||||
|
if not deleted:
|
||||||
|
raise HTTPException(status_code=404, detail="Generation not found")
|
||||||
|
return None
|
||||||
167
api/endpoints/project_router.py
Normal file
167
api/endpoints/project_router.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from api.dependency import get_dao
|
||||||
|
from api.endpoints.auth import get_current_user
|
||||||
|
from models.Project import Project
|
||||||
|
from repos.dao import DAO
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/projects", tags=["Projects"])
|
||||||
|
|
||||||
|
class ProjectCreate(BaseModel):
    """Request payload for POST /api/projects."""

    name: str
    description: Optional[str] = None
|
||||||
|
|
||||||
|
class ProjectResponse(BaseModel):
    """Project as returned by the projects API."""

    id: str
    name: str
    description: Optional[str] = None
    owner_id: str
    # User ids of all members (the owner is included on creation).
    members: List[str]
    # True when the requesting user is the project owner.
    is_owner: bool = False
|
||||||
|
|
||||||
|
@router.post("", response_model=ProjectResponse)
|
||||||
|
async def create_project(
|
||||||
|
project_data: ProjectCreate,
|
||||||
|
dao: DAO = Depends(get_dao),
|
||||||
|
current_user: dict = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
user_id = str(current_user["_id"])
|
||||||
|
new_project = Project(
|
||||||
|
name=project_data.name,
|
||||||
|
description=project_data.description,
|
||||||
|
owner_id=user_id,
|
||||||
|
members=[user_id]
|
||||||
|
)
|
||||||
|
project_id = await dao.projects.create_project(new_project)
|
||||||
|
|
||||||
|
# Add project to user's project list
|
||||||
|
# Assuming user_repo has a method to add project or we do it directly?
|
||||||
|
# UserRepo doesn't have add_project method yet.
|
||||||
|
# But since UserRepo is just a wrapper around collection, lets add it here or update UserRepo later?
|
||||||
|
# Better to update UserRepo. For now, let's just return success.
|
||||||
|
# But user needs to see it in list.
|
||||||
|
# Update user in DB
|
||||||
|
await dao.users.collection.update_one(
|
||||||
|
{"_id": current_user["_id"]},
|
||||||
|
{"$addToSet": {"project_ids": project_id}}
|
||||||
|
)
|
||||||
|
|
||||||
|
return ProjectResponse(
|
||||||
|
id=project_id,
|
||||||
|
name=new_project.name,
|
||||||
|
description=new_project.description,
|
||||||
|
owner_id=new_project.owner_id,
|
||||||
|
members=new_project.members,
|
||||||
|
is_owner=True
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.get("", response_model=List[ProjectResponse])
|
||||||
|
async def get_my_projects(
|
||||||
|
dao: DAO = Depends(get_dao),
|
||||||
|
current_user: dict = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
user_id = str(current_user["_id"])
|
||||||
|
projects = await dao.projects.get_projects_by_user(user_id)
|
||||||
|
|
||||||
|
responses = []
|
||||||
|
for p in projects:
|
||||||
|
responses.append(ProjectResponse(
|
||||||
|
id=p.id,
|
||||||
|
name=p.name,
|
||||||
|
description=p.description,
|
||||||
|
owner_id=p.owner_id,
|
||||||
|
members=p.members,
|
||||||
|
is_owner=(p.owner_id == user_id)
|
||||||
|
))
|
||||||
|
return responses
|
||||||
|
|
||||||
|
class MemberAdd(BaseModel):
    """Request payload for adding a project member by username."""

    username: str
|
||||||
|
|
||||||
|
@router.post("/{project_id}/members", dependencies=[Depends(get_current_user)])
|
||||||
|
async def add_member(
|
||||||
|
project_id: str,
|
||||||
|
member_data: MemberAdd,
|
||||||
|
dao: DAO = Depends(get_dao),
|
||||||
|
current_user: dict = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
user_id = str(current_user["_id"])
|
||||||
|
project = await dao.projects.get_project(project_id)
|
||||||
|
if not project:
|
||||||
|
raise HTTPException(status_code=404, detail="Project not found")
|
||||||
|
|
||||||
|
if project.owner_id != user_id:
|
||||||
|
raise HTTPException(status_code=403, detail="Only owner can add members")
|
||||||
|
|
||||||
|
target_user = await dao.users.get_user_by_username(member_data.username)
|
||||||
|
if not target_user:
|
||||||
|
raise HTTPException(status_code=404, detail="User not found")
|
||||||
|
|
||||||
|
target_user_id = str(target_user["_id"])
|
||||||
|
|
||||||
|
if target_user_id in project.members:
|
||||||
|
return {"message": "User already in project"}
|
||||||
|
|
||||||
|
await dao.projects.add_member(project_id, target_user_id)
|
||||||
|
|
||||||
|
# Update target user's project list
|
||||||
|
await dao.users.collection.update_one(
|
||||||
|
{"_id": target_user["_id"]},
|
||||||
|
{"$addToSet": {"project_ids": project_id}}
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"message": "Member added"}
|
||||||
|
|
||||||
|
@router.post("/{project_id}/join", dependencies=[Depends(get_current_user)])
|
||||||
|
async def join_project(
|
||||||
|
project_id: str,
|
||||||
|
dao: DAO = Depends(get_dao),
|
||||||
|
current_user: dict = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
# Retrieve project to verify it exists
|
||||||
|
project = await dao.projects.get_project(project_id)
|
||||||
|
if not project:
|
||||||
|
raise HTTPException(status_code=404, detail="Project not found")
|
||||||
|
|
||||||
|
user_id = str(current_user["_id"])
|
||||||
|
|
||||||
|
# Check if user is ALREADY in project
|
||||||
|
if user_id in project.members:
|
||||||
|
return {"message": "Already a member"}
|
||||||
|
|
||||||
|
# Add member
|
||||||
|
await dao.projects.add_member(project_id, user_id)
|
||||||
|
|
||||||
|
# Update user's project list
|
||||||
|
await dao.users.collection.update_one(
|
||||||
|
{"_id": current_user["_id"]},
|
||||||
|
{"$addToSet": {"project_ids": project_id}}
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"message": "Joined project"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{project_id}", dependencies=[Depends(get_current_user)] )
|
||||||
|
async def delete_project(
|
||||||
|
project_id: str,
|
||||||
|
dao: DAO = Depends(get_dao),
|
||||||
|
current_user: dict = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
user_id = str(current_user["_id"])
|
||||||
|
project = await dao.projects.get_project(project_id)
|
||||||
|
if not project:
|
||||||
|
raise HTTPException(status_code=404, detail="Project not found")
|
||||||
|
|
||||||
|
if project.owner_id != user_id:
|
||||||
|
raise HTTPException(status_code=403, detail="Only owner can delete project")
|
||||||
|
|
||||||
|
await dao.projects.delete_project(project_id)
|
||||||
|
|
||||||
|
# Remove project from user's project list
|
||||||
|
await dao.users.collection.update_one(
|
||||||
|
{"_id": current_user["_id"]},
|
||||||
|
{"$pull": {"project_ids": project_id}}
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"message": "Project deleted"}
|
||||||
@@ -6,14 +6,15 @@ from pydantic import BaseModel
|
|||||||
from models.Asset import Asset
|
from models.Asset import Asset
|
||||||
|
|
||||||
|
|
||||||
class AssetsResponse(BaseModel):
|
|
||||||
assets: List[Asset]
|
|
||||||
total_count: int
|
|
||||||
|
|
||||||
|
|
||||||
class AssetResponse(BaseModel):
|
class AssetResponse(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
type: str
|
type: str # uploaded / generated
|
||||||
|
content_type: str # image / prompt
|
||||||
linked_char_id: Optional[str] = None
|
linked_char_id: Optional[str] = None
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
|
url: Optional[str] = None
|
||||||
|
|
||||||
|
class AssetsResponse(BaseModel):
|
||||||
|
assets: List[AssetResponse]
|
||||||
|
total_count: int
|
||||||
18
api/models/CharacterDTO.py
Normal file
18
api/models/CharacterDTO.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
from typing import Optional
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
class CharacterCreateRequest(BaseModel):
|
||||||
|
name: str
|
||||||
|
character_bio: str
|
||||||
|
character_image_doc_tg_id: Optional[str] = None
|
||||||
|
avatar_image: Optional[str] = None
|
||||||
|
character_image_tg_id: Optional[str] = None
|
||||||
|
project_id: Optional[str] = None
|
||||||
|
|
||||||
|
class CharacterUpdateRequest(BaseModel):
|
||||||
|
name: Optional[str] = None
|
||||||
|
character_bio: Optional[str] = None
|
||||||
|
character_image_doc_tg_id: Optional[str] = None
|
||||||
|
avatar_image: Optional[str] = None
|
||||||
|
character_image_tg_id: Optional[str] = None
|
||||||
|
project_id: Optional[str] = None
|
||||||
37
api/models/ExternalGenerationDTO.py
Normal file
37
api/models/ExternalGenerationDTO.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
from typing import Optional
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from models.enums import AspectRatios, Quality
|
||||||
|
|
||||||
|
|
||||||
|
class ExternalGenerationRequest(BaseModel):
|
||||||
|
"""Request model for importing external generations."""
|
||||||
|
|
||||||
|
prompt: str
|
||||||
|
tech_prompt: Optional[str] = None
|
||||||
|
|
||||||
|
# Image can be provided as base64 string OR URL (one must be provided)
|
||||||
|
image_data: Optional[str] = Field(None, description="Base64-encoded image data")
|
||||||
|
image_url: Optional[str] = Field(None, description="URL to download image from")
|
||||||
|
|
||||||
|
# Generation metadata
|
||||||
|
aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN
|
||||||
|
quality: Quality = Quality.ONEK
|
||||||
|
|
||||||
|
# Optional linking
|
||||||
|
linked_character_id: Optional[str] = None
|
||||||
|
created_by: str = Field(..., description="User ID from external system")
|
||||||
|
project_id: Optional[str] = None
|
||||||
|
|
||||||
|
# Performance metrics
|
||||||
|
execution_time_seconds: Optional[float] = None
|
||||||
|
api_execution_time_seconds: Optional[float] = None
|
||||||
|
token_usage: Optional[int] = None
|
||||||
|
input_token_usage: Optional[int] = None
|
||||||
|
output_token_usage: Optional[int] = None
|
||||||
|
|
||||||
|
def validate_image_source(self):
|
||||||
|
"""Ensure at least one image source is provided."""
|
||||||
|
if not self.image_data and not self.image_url:
|
||||||
|
raise ValueError("Either image_data or image_url must be provided")
|
||||||
|
if self.image_data and self.image_url:
|
||||||
|
raise ValueError("Only one of image_data or image_url should be provided")
|
||||||
@@ -5,7 +5,7 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
from models.Asset import Asset
|
from models.Asset import Asset
|
||||||
from models.Generation import GenerationStatus
|
from models.Generation import GenerationStatus
|
||||||
from models.enums import AspectRatios, Quality
|
from models.enums import AspectRatios, Quality, GenType
|
||||||
|
|
||||||
|
|
||||||
class GenerationRequest(BaseModel):
|
class GenerationRequest(BaseModel):
|
||||||
@@ -13,19 +13,43 @@ class GenerationRequest(BaseModel):
|
|||||||
aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN
|
aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN
|
||||||
quality: Quality = Quality.ONEK
|
quality: Quality = Quality.ONEK
|
||||||
prompt: str
|
prompt: str
|
||||||
|
telegram_id: Optional[int] = None
|
||||||
|
use_profile_image: bool = True
|
||||||
assets_list: List[str]
|
assets_list: List[str]
|
||||||
|
project_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class GenerationsResponse(BaseModel):
|
||||||
|
generations: List["GenerationResponse"]
|
||||||
|
total_count: int
|
||||||
|
|
||||||
|
|
||||||
class GenerationResponse(BaseModel):
|
class GenerationResponse(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
status: GenerationStatus
|
status: GenerationStatus
|
||||||
|
gen_type: GenType = GenType.IMAGE
|
||||||
failed_reason: Optional[str] = None
|
failed_reason: Optional[str] = None
|
||||||
|
|
||||||
linked_character_id: Optional[str] = None
|
linked_character_id: Optional[str] = None
|
||||||
aspect_ratio: AspectRatios
|
aspect_ratio: AspectRatios
|
||||||
quality: Quality
|
quality: Quality
|
||||||
prompt: str
|
prompt: str
|
||||||
|
tech_prompt: Optional[str] = None
|
||||||
assets_list: List[str]
|
assets_list: List[str]
|
||||||
|
result_list: List[str] = []
|
||||||
result: Optional[str] = None
|
result: Optional[str] = None
|
||||||
|
execution_time_seconds: Optional[float] = None
|
||||||
|
api_execution_time_seconds: Optional[float] = None
|
||||||
|
token_usage: Optional[int] = None
|
||||||
|
input_token_usage: Optional[int] = None
|
||||||
|
output_token_usage: Optional[int] = None
|
||||||
|
progress: int = 0
|
||||||
|
cost: Optional[float] = None
|
||||||
|
created_by: Optional[str] = None
|
||||||
|
# Video-specific
|
||||||
|
kling_task_id: Optional[str] = None
|
||||||
|
video_duration: Optional[int] = None
|
||||||
|
video_mode: Optional[str] = None
|
||||||
created_at: datetime = datetime.now(UTC)
|
created_at: datetime = datetime.now(UTC)
|
||||||
updated_at: datetime = datetime.now(UTC)
|
updated_at: datetime = datetime.now(UTC)
|
||||||
|
|
||||||
|
|||||||
16
api/models/VideoGenerationRequest.py
Normal file
16
api/models/VideoGenerationRequest.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class VideoGenerationRequest(BaseModel):
|
||||||
|
prompt: str = ""
|
||||||
|
negative_prompt: Optional[str] = ""
|
||||||
|
image_asset_id: str # ID ассета-картинки для source image
|
||||||
|
duration: int = 5 # 5 or 10 seconds
|
||||||
|
mode: str = "std" # "std" or "pro"
|
||||||
|
model_name: str = "kling-v2-1"
|
||||||
|
cfg_scale: float = 0.5
|
||||||
|
aspect_ratio: str = "16:9"
|
||||||
|
linked_character_id: Optional[str] = None
|
||||||
|
project_id: Optional[str] = None
|
||||||
85
api/service/album_service.py
Normal file
85
api/service/album_service.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
from models.Album import Album
|
||||||
|
from models.Generation import Generation
|
||||||
|
from repos.dao import DAO
|
||||||
|
|
||||||
|
class AlbumService:
|
||||||
|
def __init__(self, dao: DAO):
|
||||||
|
self.dao = dao
|
||||||
|
|
||||||
|
async def create_album(self, name: str, description: Optional[str] = None) -> Album:
|
||||||
|
album = Album(name=name, description=description)
|
||||||
|
album_id = await self.dao.albums.create_album(album)
|
||||||
|
album.id = album_id
|
||||||
|
return album
|
||||||
|
|
||||||
|
async def get_albums(self, limit: int = 10, offset: int = 0) -> List[Album]:
|
||||||
|
return await self.dao.albums.get_albums(limit=limit, offset=offset)
|
||||||
|
|
||||||
|
async def get_album(self, album_id: str) -> Optional[Album]:
|
||||||
|
return await self.dao.albums.get_album(album_id)
|
||||||
|
|
||||||
|
async def update_album(self, album_id: str, name: Optional[str] = None, description: Optional[str] = None) -> Optional[Album]:
|
||||||
|
album = await self.dao.albums.get_album(album_id)
|
||||||
|
if not album:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if name:
|
||||||
|
album.name = name
|
||||||
|
if description is not None:
|
||||||
|
album.description = description
|
||||||
|
|
||||||
|
await self.dao.albums.update_album(album_id, album)
|
||||||
|
return album
|
||||||
|
|
||||||
|
async def delete_album(self, album_id: str) -> bool:
|
||||||
|
return await self.dao.albums.delete_album(album_id)
|
||||||
|
|
||||||
|
async def add_generation_to_album(self, album_id: str, generation_id: str) -> bool:
|
||||||
|
# Verify album exists
|
||||||
|
album = await self.dao.albums.get_album(album_id)
|
||||||
|
if not album:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Verify generation exists (optional but good practice)
|
||||||
|
gen = await self.dao.generations.get_generation(generation_id)
|
||||||
|
if not gen:
|
||||||
|
return False
|
||||||
|
if album.cover_asset_id is None and gen.status == 'done':
|
||||||
|
album.cover_asset_id = gen.result_list[0]
|
||||||
|
return await self.dao.albums.add_generation(album_id, generation_id, album.cover_asset_id)
|
||||||
|
|
||||||
|
async def remove_generation_from_album(self, album_id: str, generation_id: str) -> bool:
|
||||||
|
return await self.dao.albums.remove_generation(album_id, generation_id)
|
||||||
|
|
||||||
|
async def get_generations_by_album(self, album_id: str, limit: int = 10, offset: int = 0) -> List[Generation]:
|
||||||
|
album = await self.dao.albums.get_album(album_id)
|
||||||
|
if not album or not album.generation_ids:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Slice the generation IDs (simple pagination on ID list)
|
||||||
|
# Note: This pagination is on IDs, then we fetch objects.
|
||||||
|
# Ideally, fetch only slice.
|
||||||
|
|
||||||
|
# Reverse to show newest first? Or just follow list order?
|
||||||
|
# Assuming list order is insertion order (which usually is what we want for manual sorting or chronological if always appended).
|
||||||
|
# Let's assume user wants same order as in list.
|
||||||
|
|
||||||
|
sliced_ids = album.generation_ids[offset : offset + limit]
|
||||||
|
if not sliced_ids:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Fetch generations by IDs
|
||||||
|
# We need a method in GenerationRepo to fetch by IDs.
|
||||||
|
# Currently we only have get_generations with filters.
|
||||||
|
# We can add get_generations_by_ids to GenerationRepo or use loop (inefficient).
|
||||||
|
# Let's add get_generations_by_ids to GenerationRepo.
|
||||||
|
|
||||||
|
# For now, I will use a loop if I can't modify Repo immediately,
|
||||||
|
# but I SHOULD modify GenerationRepo.
|
||||||
|
|
||||||
|
# Or I can use get_generations(filter={"_id": {"$in": [ObjectId(id) for id in sliced_ids]}})
|
||||||
|
# But get_generations doesn't support generic filter passing.
|
||||||
|
|
||||||
|
# I'll update GenerationRepo to support fetching by IDs.
|
||||||
|
return await self.dao.generations.get_generations_by_ids(sliced_ids)
|
||||||
@@ -1,18 +1,25 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
|
import base64
|
||||||
from datetime import datetime, UTC
|
from datetime import datetime, UTC
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Tuple, Any, Dict
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from aiogram import Bot
|
||||||
|
from aiogram.types import BufferedInputFile
|
||||||
from adapters.Exception import GoogleGenerationException
|
from adapters.Exception import GoogleGenerationException
|
||||||
from adapters.google_adapter import GoogleAdapter
|
from adapters.google_adapter import GoogleAdapter
|
||||||
from api.models.GenerationRequest import GenerationRequest, GenerationResponse
|
from adapters.kling_adapter import KlingAdapter, KlingApiException
|
||||||
|
from api.models.GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse
|
||||||
|
from api.models.VideoGenerationRequest import VideoGenerationRequest
|
||||||
# Импортируйте ваши модели DAO, Asset, Generation корректно
|
# Импортируйте ваши модели DAO, Asset, Generation корректно
|
||||||
from models.Asset import Asset, AssetType
|
from models.Asset import Asset, AssetType, AssetContentType
|
||||||
from models.Generation import Generation, GenerationStatus
|
from models.Generation import Generation, GenerationStatus
|
||||||
from models.enums import AspectRatios, Quality
|
from models.enums import AspectRatios, Quality, GenType
|
||||||
from repos.dao import DAO
|
from repos.dao import DAO
|
||||||
|
from adapters.s3_adapter import S3Adapter
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -23,21 +30,26 @@ async def generate_image_task(
|
|||||||
media_group_bytes: List[bytes],
|
media_group_bytes: List[bytes],
|
||||||
aspect_ratio: AspectRatios,
|
aspect_ratio: AspectRatios,
|
||||||
quality: Quality,
|
quality: Quality,
|
||||||
gemini: GoogleAdapter
|
gemini: GoogleAdapter,
|
||||||
) -> List[bytes]:
|
|
||||||
|
) -> Tuple[List[bytes], Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Обертка для вызова синхронного метода Gemini в отдельном потоке.
|
Обертка для вызова синхронного метода Gemini в отдельном потоке.
|
||||||
Возвращает список байтов сгенерированных изображений.
|
Возвращает список байтов сгенерированных изображений.
|
||||||
"""
|
"""
|
||||||
try :
|
try :
|
||||||
|
logger.info(f"Starting generate_image_task with prompt length: {len(prompt)}")
|
||||||
# Запускаем блокирующую операцию в отдельном потоке, чтобы не тормозить Event Loop
|
# Запускаем блокирующую операцию в отдельном потоке, чтобы не тормозить Event Loop
|
||||||
generated_images_io: List[BytesIO] = await asyncio.to_thread(
|
result = await asyncio.to_thread(
|
||||||
gemini.generate_image,
|
gemini.generate_image,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
images_list=media_group_bytes,
|
images_list=media_group_bytes,
|
||||||
aspect_ratio=aspect_ratio,
|
aspect_ratio=aspect_ratio,
|
||||||
quality=quality,
|
quality=quality,
|
||||||
)
|
)
|
||||||
|
generated_images_io, metrics = result
|
||||||
|
|
||||||
|
logger.info(f"generate_image_task completed, received {len(generated_images_io) if generated_images_io else 0} images")
|
||||||
except GoogleGenerationException as e:
|
except GoogleGenerationException as e:
|
||||||
raise e
|
raise e
|
||||||
images_bytes = []
|
images_bytes = []
|
||||||
@@ -51,31 +63,46 @@ async def generate_image_task(
|
|||||||
# Закрываем поток
|
# Закрываем поток
|
||||||
img_io.close()
|
img_io.close()
|
||||||
|
|
||||||
return images_bytes
|
return images_bytes, metrics
|
||||||
|
|
||||||
|
|
||||||
class GenerationService:
|
class GenerationService:
|
||||||
def __init__(self, dao: DAO, gemini: GoogleAdapter):
|
def __init__(self, dao: DAO, gemini: GoogleAdapter, s3_adapter: S3Adapter, bot: Optional[Bot] = None, kling_adapter: Optional[KlingAdapter] = None):
|
||||||
self.dao = dao
|
self.dao = dao
|
||||||
self.gemini = gemini
|
self.gemini = gemini
|
||||||
|
self.s3_adapter = s3_adapter
|
||||||
|
self.bot = bot
|
||||||
|
self.kling_adapter = kling_adapter
|
||||||
|
|
||||||
|
|
||||||
async def ask_prompt_assistant(self, prompt: str, assets: List[str] = None) -> str:
|
async def ask_prompt_assistant(self, prompt: str, assets: List[str] = None) -> str:
|
||||||
future_prompt = """You are an prompt-assistant. You improving user-entered prompts for image generation. User may upload reference image too.
|
future_prompt = """You are an prompt-assistant. You improving user-entered prompts for image generation. User may upload reference image too.
|
||||||
I will provide sources prompt entered by user. Understand user needs and generate best variation of prompt.
|
I will provide sources prompt entered by user. Understand user needs and generate best variation of prompt.
|
||||||
ANSWER ONLY PROMPT STRING!!! USER_ENTERED_PROMPT:"""
|
ANSWER ONLY PROMPT STRING!!! USER_ENTERED_PROMPT: """
|
||||||
future_prompt += prompt
|
future_prompt += prompt
|
||||||
assets_data = []
|
assets_data = []
|
||||||
if assets is not None:
|
if assets is not None:
|
||||||
assets_db = await self.dao.assets.get_assets_by_ids(assets)
|
assets_db = await self.dao.assets.get_assets_by_ids(assets)
|
||||||
assets_data.extend(asset.data for asset in assets_db)
|
assets_data.extend(asset.data for asset in assets_db)
|
||||||
generated_prompt = self.gemini.generate_text(future_prompt, assets_data)
|
generated_prompt = await asyncio.to_thread(self.gemini.generate_text, future_prompt, assets_data)
|
||||||
logger.info(future_prompt)
|
logger.info(future_prompt)
|
||||||
logger.info(generated_prompt)
|
logger.info(generated_prompt)
|
||||||
return generated_prompt
|
return generated_prompt
|
||||||
|
|
||||||
async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0) -> List[
|
async def generate_prompt_from_images(self, images: List[bytes], user_prompt: Optional[str] = None) -> str:
|
||||||
|
technical_prompt = "You are a prompt engineer. Describe this image in detail to create a stable diffusion using this image as reference. "
|
||||||
|
if user_prompt:
|
||||||
|
technical_prompt += f"User also provided this context: {user_prompt}. "
|
||||||
|
|
||||||
|
technical_prompt += "Provide ONLY the detailed prompt."
|
||||||
|
|
||||||
|
return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, images_list=images)
|
||||||
|
|
||||||
|
async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[
|
||||||
Generation]:
|
Generation]:
|
||||||
return await self.dao.generations.get_generations(limit=limit, offset=offset)
|
generations = await self.dao.generations.get_generations(character_id = character_id,limit=limit, offset=offset, created_by=user_id, project_id=project_id)
|
||||||
|
total_count = await self.dao.generations.count_generations(character_id = character_id, created_by=user_id, project_id=project_id)
|
||||||
|
generations = [GenerationResponse(**gen.model_dump()) for gen in generations]
|
||||||
|
return GenerationsResponse(generations=generations, total_count=total_count)
|
||||||
|
|
||||||
async def get_generation(self, generation_id: str) -> Optional[GenerationResponse]:
|
async def get_generation(self, generation_id: str) -> Optional[GenerationResponse]:
|
||||||
gen = await self.dao.generations.get_generation(generation_id)
|
gen = await self.dao.generations.get_generation(generation_id)
|
||||||
@@ -84,21 +111,26 @@ class GenerationService:
|
|||||||
else:
|
else:
|
||||||
return GenerationResponse(**gen.model_dump())
|
return GenerationResponse(**gen.model_dump())
|
||||||
|
|
||||||
async def get_running_generations(self) -> List[Generation]:
|
async def get_running_generations(self, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
|
||||||
return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING)
|
return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING, created_by=user_id, project_id=project_id)
|
||||||
|
|
||||||
async def create_generation_task(self, generation_request: GenerationRequest) -> GenerationResponse:
|
async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None) -> GenerationResponse:
|
||||||
gen_id = None
|
gen_id = None
|
||||||
generation_model = None
|
generation_model = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
generation_model = Generation(**generation_request.model_dump())
|
generation_model = Generation(**generation_request.model_dump())
|
||||||
|
if user_id:
|
||||||
|
generation_model.created_by = user_id
|
||||||
|
|
||||||
gen_id = await self.dao.generations.create_generation(generation_model)
|
gen_id = await self.dao.generations.create_generation(generation_model)
|
||||||
generation_model.id = gen_id
|
generation_model.id = gen_id
|
||||||
|
|
||||||
async def runner(gen):
|
async def runner(gen):
|
||||||
|
logger.info(f"Starting background generation task for ID: {gen.id}")
|
||||||
try:
|
try:
|
||||||
await self.create_generation(gen)
|
await self.create_generation(gen)
|
||||||
|
logger.info(f"Background generation task finished for ID: {gen.id}")
|
||||||
except Exception:
|
except Exception:
|
||||||
# если генерация уже пошла и упала — пометим FAILED
|
# если генерация уже пошла и упала — пометим FAILED
|
||||||
try:
|
try:
|
||||||
@@ -125,59 +157,121 @@ class GenerationService:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
async def create_generation(self, generation: Generation):
|
async def create_generation(self, generation: Generation):
|
||||||
|
start_time = datetime.now()
|
||||||
|
logger.info(f"Processing generation {generation.id}. Character ID: {generation.linked_character_id}")
|
||||||
|
|
||||||
# 2. Получаем ассеты-референсы (если они есть)
|
# 2. Получаем ассеты-референсы (если они есть)
|
||||||
reference_assets: List[Asset] = []
|
reference_assets: List[Asset] = []
|
||||||
media_group_bytes: List[bytes] = []
|
media_group_bytes: List[bytes] = []
|
||||||
generation_prompt = "You are creating image. "
|
generation_prompt = generation.prompt
|
||||||
|
# generation_prompt = f"""
|
||||||
|
|
||||||
|
# Create detailed image of character in scene.
|
||||||
|
|
||||||
|
# SCENE DESCRIPTION: {generation.prompt}
|
||||||
|
|
||||||
|
# Rules:
|
||||||
|
# - Integrate the character's appearance naturally into the scene description.
|
||||||
|
# - Focus on lighting, texture, and composition.
|
||||||
|
# """
|
||||||
if generation.linked_character_id is not None:
|
if generation.linked_character_id is not None:
|
||||||
char_info = await self.dao.chars.get_character(generation.linked_character_id, with_image_data=True)
|
char_info = await self.dao.chars.get_character(generation.linked_character_id)
|
||||||
if char_info is None:
|
if char_info is None:
|
||||||
raise Exception(f"Character ID {generation.linked_character_id} not found")
|
raise Exception(f"Character ID {generation.linked_character_id} not found")
|
||||||
media_group_bytes.append(char_info.character_image_data)
|
if generation.use_profile_image:
|
||||||
generation_prompt = f"""You are creating image for {char_info.character_bio}"""
|
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
|
||||||
|
if avatar_asset:
|
||||||
|
media_group_bytes.append(avatar_asset.data)
|
||||||
|
# generation_prompt = generation_prompt.replace("$char_bio_inserted", f"1. CHARACTER BIO (Must be strictly followed): {char_info.character_bio}")
|
||||||
|
|
||||||
reference_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
|
reference_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
|
||||||
# Извлекаем данные (bytes) из ассетов для отправки в Gemini
|
|
||||||
# Фильтруем, чтобы отправлять только картинки, и где есть data
|
|
||||||
media_group_bytes.extend(
|
|
||||||
asset.data
|
|
||||||
for asset in reference_assets
|
|
||||||
if asset.data is not None and asset.type == AssetType.IMAGE
|
|
||||||
)
|
|
||||||
generation_prompt+=f"PROMPT: {generation.prompt}"
|
|
||||||
|
|
||||||
# 3. Запускаем процесс генерации
|
# Извлекаем данные (bytes) из ассетов для отправки в Gemini
|
||||||
|
for asset in reference_assets:
|
||||||
|
if asset.content_type != AssetContentType.IMAGE:
|
||||||
|
continue
|
||||||
|
|
||||||
|
img_data = None
|
||||||
|
if asset.minio_object_name:
|
||||||
|
img_data = await self.s3_adapter.get_file(asset.minio_object_name)
|
||||||
|
elif asset.data:
|
||||||
|
img_data = asset.data
|
||||||
|
|
||||||
|
if img_data:
|
||||||
|
media_group_bytes.append(img_data)
|
||||||
|
|
||||||
|
if media_group_bytes:
|
||||||
|
generation_prompt += " \n\n[Reference Image Guidance]: Use the provided image(s) as the STRICT reference for the main character's facial features and hair, enviroment or clothes. Maintain high fidelity to the reference identity."
|
||||||
|
|
||||||
|
logger.info(f"Final generation prompt assembled. Length: {len(generation_prompt)}. Media count: {len(media_group_bytes)}")
|
||||||
|
|
||||||
|
# 3. Запускаем процесс генерации и симуляцию прогресса
|
||||||
|
progress_task = asyncio.create_task(self._simulate_progress(generation))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
generated_bytes_list = await generate_image_task(
|
|
||||||
|
# Default to Image Generation (Gemini)
|
||||||
|
generated_bytes_list, metrics = await generate_image_task(
|
||||||
prompt=generation_prompt, # или request.prompt
|
prompt=generation_prompt, # или request.prompt
|
||||||
media_group_bytes=media_group_bytes,
|
media_group_bytes=media_group_bytes,
|
||||||
aspect_ratio=generation.aspect_ratio, # предполагаем поля в request
|
aspect_ratio=generation.aspect_ratio, # предполагаем поля в request
|
||||||
quality=generation.quality,
|
quality=generation.quality,
|
||||||
gemini=self.gemini
|
gemini=self.gemini
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Update metrics from API (Common for both)
|
||||||
|
generation.api_execution_time_seconds = metrics.get("api_execution_time_seconds")
|
||||||
|
generation.token_usage = metrics.get("token_usage")
|
||||||
|
generation.input_token_usage = metrics.get("input_token_usage")
|
||||||
|
generation.output_token_usage = metrics.get("output_token_usage")
|
||||||
|
|
||||||
except GoogleGenerationException as e:
|
except GoogleGenerationException as e:
|
||||||
generation.status = GenerationStatus.FAILED
|
generation.status = GenerationStatus.FAILED
|
||||||
generation.failed_reason = str(e.message)
|
generation.failed_reason = str(e)
|
||||||
generation.updated_at = datetime.now(UTC)
|
generation.updated_at = datetime.now(UTC)
|
||||||
await self.dao.generations.update_generation(generation)
|
await self.dao.generations.update_generation(generation)
|
||||||
raise
|
raise e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Тут стоит добавить логирование ошибки
|
# Тут стоит добавить логирование ошибки
|
||||||
logging.error(f"Generation failed: {e}")
|
logging.error(f"Generation failed: {e}")
|
||||||
# Можно обновить статус генерации на FAILED в БД
|
generation.status = GenerationStatus.FAILED
|
||||||
|
generation.failed_reason = str(e)
|
||||||
|
generation.updated_at = datetime.now(UTC)
|
||||||
|
await self.dao.generations.update_generation(generation)
|
||||||
raise e
|
raise e
|
||||||
|
finally:
|
||||||
|
if not progress_task.done():
|
||||||
|
progress_task.cancel()
|
||||||
|
try:
|
||||||
|
await progress_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
# 4. Сохраняем полученные изображения как новые Ассеты
|
# 4. Сохраняем полученные изображения как новые Ассеты
|
||||||
created_assets: List[Asset] = []
|
created_assets: List[Asset] = []
|
||||||
|
|
||||||
for idx, img_bytes in enumerate(generated_bytes_list):
|
for idx, img_bytes in enumerate(generated_bytes_list):
|
||||||
|
# Generate thumbnail
|
||||||
|
thumbnail_bytes = None
|
||||||
|
from utils.image_utils import create_thumbnail
|
||||||
|
thumbnail_bytes = await asyncio.to_thread(create_thumbnail, img_bytes)
|
||||||
|
|
||||||
|
# Save to S3
|
||||||
|
filename = f"generated/{generation.linked_character_id}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.png"
|
||||||
|
await self.s3_adapter.upload_file(filename, img_bytes, content_type="image/png")
|
||||||
|
|
||||||
new_asset = Asset(
|
new_asset = Asset(
|
||||||
name=f"Generated_{generation.linked_character_id}_{random.randint(1000, 9999)}",
|
name=f"Generated_{generation.linked_character_id}",
|
||||||
type=AssetType.IMAGE,
|
type=AssetType.GENERATED,
|
||||||
linked_char_id=generation.linked_character_id, # Если генерация привязана к персонажу
|
content_type=AssetContentType.IMAGE,
|
||||||
data=img_bytes,
|
linked_char_id=generation.linked_character_id,
|
||||||
# Остальные поля заполнятся дефолтными значениями (created_at)
|
data=None, # Not storing bytes in DB anymore
|
||||||
|
minio_object_name=filename,
|
||||||
|
minio_bucket=self.s3_adapter.bucket_name,
|
||||||
|
thumbnail=thumbnail_bytes,
|
||||||
|
created_by=generation.created_by,
|
||||||
|
project_id=generation.project_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# Сохраняем в БД
|
# Сохраняем в БД
|
||||||
@@ -190,8 +284,328 @@ class GenerationService:
|
|||||||
# Предполагаем, что у модели Generation есть поле result_asset_ids
|
# Предполагаем, что у модели Generation есть поле result_asset_ids
|
||||||
result_ids = [a.id for a in created_assets]
|
result_ids = [a.id for a in created_assets]
|
||||||
|
|
||||||
generation.assets_list = result_ids
|
generation.result_list = result_ids
|
||||||
generation.status = GenerationStatus.DONE
|
generation.status = GenerationStatus.DONE
|
||||||
|
generation.progress = 100
|
||||||
generation.updated_at = datetime.now(UTC)
|
generation.updated_at = datetime.now(UTC)
|
||||||
generation.tech_prompt = generation_prompt
|
generation.tech_prompt = generation_prompt
|
||||||
|
|
||||||
|
end_time = datetime.now()
|
||||||
|
generation.execution_time_seconds = (end_time - start_time).total_seconds()
|
||||||
|
|
||||||
|
logger.info(f"DEBUG: Saving generation {generation.id}. Metrics: api_exec={generation.api_execution_time_seconds}, tokens={generation.token_usage}, in_tokens={generation.input_token_usage}, out_tokens={generation.output_token_usage}, exec={generation.execution_time_seconds}")
|
||||||
|
|
||||||
await self.dao.generations.update_generation(generation)
|
await self.dao.generations.update_generation(generation)
|
||||||
|
logger.info(f"Generation {generation.id} completed successfully. {len(created_assets)} assets created. Total Time: {generation.execution_time_seconds:.2f}s")
|
||||||
|
|
||||||
|
# 6. Send to Telegram if telegram_id is provided
|
||||||
|
if generation.telegram_id and self.bot:
|
||||||
|
try:
|
||||||
|
for asset in created_assets:
|
||||||
|
if asset.data:
|
||||||
|
await self.bot.send_photo(
|
||||||
|
chat_id=generation.telegram_id,
|
||||||
|
photo=BufferedInputFile(asset.data, filename=f"{asset.name}.jpg"),
|
||||||
|
caption=f"Generated from prompt: {generation.prompt[:100]}..."
|
||||||
|
)
|
||||||
|
logger.info(f"Sent {len(created_assets)} assets to Telegram ID: {generation.telegram_id}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to send assets to Telegram ID {generation.telegram_id}: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
async def _simulate_progress(self, generation: Generation):
|
||||||
|
"""
|
||||||
|
Increments progress from 0 to 90 over ~20 seconds.
|
||||||
|
"""
|
||||||
|
current_progress = 0
|
||||||
|
try:
|
||||||
|
while current_progress < 90:
|
||||||
|
await asyncio.sleep(4)
|
||||||
|
# Random increment between 5 and 15
|
||||||
|
increment = random.randint(5, 15)
|
||||||
|
current_progress = min(current_progress + increment, 90)
|
||||||
|
|
||||||
|
# Fetch latest state (optional, but good practice to avoid overwriting unrelated fields)
|
||||||
|
# But for simplicity here we just use the object we have and save it.
|
||||||
|
# Ideally, we should fetch-update-save or use partial update if DAO supports it.
|
||||||
|
# Assuming simple update is fine for now.
|
||||||
|
generation.progress = current_progress
|
||||||
|
await self.dao.generations.update_generation(generation)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# Task cancelled, generation finished (or failed)
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in progress simulation: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
async def import_external_generation(self, external_gen) -> Generation:
    """
    Import a generation produced by an external source.

    Obtains the image (download by URL or base64 decode), stores it in S3,
    creates an Asset record, then a Generation record already in DONE state.

    Args:
        external_gen: ExternalGenerationRequest with generation metadata and
            exactly one image source (``image_url`` or base64 ``image_data``).

    Returns:
        The created Generation object (with ``id`` populated).

    Raises:
        ValueError: if no usable image bytes could be obtained.
    """
    # Ensure exactly one image source was supplied (raises otherwise).
    external_gen.validate_image_source()

    logger.info(f"Importing external generation for user: {external_gen.created_by}")

    # 1. Obtain raw image bytes (download from URL or decode base64 payload).
    image_bytes = None

    if external_gen.image_url:
        logger.info(f"Downloading image from URL: {external_gen.image_url}")
        async with httpx.AsyncClient() as client:
            response = await client.get(external_gen.image_url, timeout=30.0)
            response.raise_for_status()
            image_bytes = response.content
    elif external_gen.image_data:
        logger.info("Decoding base64 image data")
        image_bytes = base64.b64decode(external_gen.image_data)

    if not image_bytes:
        raise ValueError("Failed to process image data")

    # 2. Generate thumbnail off the event loop (image work is CPU-bound).
    from utils.image_utils import create_thumbnail
    thumbnail_bytes = await asyncio.to_thread(create_thumbnail, image_bytes)

    # 3. Save the full-size image to S3.
    filename = f"external/{external_gen.created_by}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.png"
    await self.s3_adapter.upload_file(filename, image_bytes, content_type="image/png")

    # 4. Create the Asset record; bytes live in S3, only the thumbnail
    # (small) is kept inline on the model.
    new_asset = Asset(
        name=f"External_Generated_{external_gen.linked_character_id or 'no_char'}",
        type=AssetType.GENERATED,
        content_type=AssetContentType.IMAGE,
        linked_char_id=external_gen.linked_character_id,
        data=None,  # Not storing bytes in DB
        minio_object_name=filename,
        minio_bucket=self.s3_adapter.bucket_name,
        thumbnail=thumbnail_bytes,
        created_by=external_gen.created_by,
        project_id=external_gen.project_id
    )

    asset_id = await self.dao.assets.create_asset(new_asset)
    new_asset.id = str(asset_id)

    logger.info(f"Created asset {asset_id} for external generation")

    # 5. Create the Generation record, already finished (DONE, 100%).
    generation = Generation(
        status=GenerationStatus.DONE,
        linked_character_id=external_gen.linked_character_id,
        aspect_ratio=external_gen.aspect_ratio,
        quality=external_gen.quality,
        prompt=external_gen.prompt,
        tech_prompt=external_gen.tech_prompt,
        result_list=[new_asset.id],
        result=new_asset.id,
        progress=100,
        execution_time_seconds=external_gen.execution_time_seconds,
        api_execution_time_seconds=external_gen.api_execution_time_seconds,
        token_usage=external_gen.token_usage,
        input_token_usage=external_gen.input_token_usage,
        output_token_usage=external_gen.output_token_usage,
        created_by=external_gen.created_by,
        project_id=external_gen.project_id,
        created_at=datetime.now(UTC),
        updated_at=datetime.now(UTC)
    )

    gen_id = await self.dao.generations.create_generation(generation)
    generation.id = gen_id

    logger.info(f"Created generation {gen_id} from external source")

    return generation
|
# === VIDEO GENERATION (Kling) ===
|
||||||
|
|
||||||
|
async def create_video_generation_task(self, request: VideoGenerationRequest, user_id: Optional[str] = None) -> GenerationResponse:
    """Create a video generation task (async, returns immediately).

    Persists a RUNNING Generation, launches the Kling pipeline as a
    background task, and returns a snapshot of the Generation right away.

    Args:
        request: video generation parameters (source image asset, prompt, ...).
        user_id: optional owner ID stamped onto the generation.

    Raises:
        Exception: if no Kling adapter is configured.
    """
    if not self.kling_adapter:
        raise Exception("Kling adapter is not configured")

    generation = Generation(
        status=GenerationStatus.RUNNING,
        gen_type=GenType.VIDEO,
        linked_character_id=request.linked_character_id,
        aspect_ratio=AspectRatios.SIXTEENNINE,  # default for video
        quality=Quality.ONEK,
        prompt=request.prompt,
        assets_list=[request.image_asset_id],
        video_duration=request.duration,
        video_mode=request.mode,
        project_id=request.project_id,
    )
    if user_id:
        generation.created_by = user_id

    gen_id = await self.dao.generations.create_generation(generation)
    generation.id = gen_id

    async def runner(gen, req):
        # Background wrapper: run the pipeline; on any error, best-effort
        # flip the DB record to FAILED if the pipeline didn't already.
        logger.info(f"Starting background video generation task for ID: {gen.id}")
        try:
            await self.create_video_generation(gen, req)
            logger.info(f"Background video generation task finished for ID: {gen.id}")
        except Exception:
            try:
                db_gen = await self.dao.generations.get_generation(gen.id)
                if db_gen and db_gen.status != GenerationStatus.FAILED:
                    db_gen.status = GenerationStatus.FAILED
                    db_gen.updated_at = datetime.now(UTC)
                    await self.dao.generations.update_generation(db_gen)
            except Exception:
                logger.exception("Failed to mark video generation as FAILED")
            logger.exception("create_video_generation task failed")

    # BUGFIX: keep a strong reference to the task. The event loop only holds
    # weak references, so a bare create_task() result may be garbage-collected
    # mid-flight, silently killing the generation.
    task = asyncio.create_task(runner(generation, request))
    if not hasattr(self, "_background_tasks"):
        self._background_tasks = set()
    self._background_tasks.add(task)
    task.add_done_callback(self._background_tasks.discard)

    return GenerationResponse(**generation.model_dump())
||||||
|
async def create_video_generation(self, generation: Generation, request: VideoGenerationRequest):
    """Background video generation pipeline (Kling image-to-video).

    Steps: presign the source image -> create the Kling task -> poll until
    completion (persisting progress) -> download the resulting MP4 ->
    upload it to S3 -> create an Asset -> finalize the Generation.

    On any error the generation is marked FAILED (with a reason) and the
    exception is re-raised for the caller's wrapper to handle/log.
    """
    start_time = datetime.now()

    async def _mark_failed(reason: str):
        # Shared failure bookkeeping for both except branches below
        # (previously duplicated verbatim in each handler).
        generation.status = GenerationStatus.FAILED
        generation.failed_reason = reason
        generation.updated_at = datetime.now(UTC)
        await self.dao.generations.update_generation(generation)

    try:
        # 1. Resolve the source image to a presigned URL Kling can fetch.
        asset = await self.dao.assets.get_asset(request.image_asset_id)
        if not asset:
            raise Exception(f"Asset {request.image_asset_id} not found")

        if not asset.minio_object_name:
            raise Exception(f"Asset {request.image_asset_id} has no S3 object")

        presigned_url = await self.s3_adapter.get_presigned_url(asset.minio_object_name, expiration=3600)
        if not presigned_url:
            raise Exception("Failed to generate presigned URL for source image")

        logger.info(f"Video gen {generation.id}: got presigned URL for asset {request.image_asset_id}")

        # 2. Create the Kling task.
        task_data = await self.kling_adapter.create_video_task(
            image_url=presigned_url,
            prompt=request.prompt,
            negative_prompt=request.negative_prompt or "",
            model_name=request.model_name,
            duration=request.duration,
            mode=request.mode,
            cfg_scale=request.cfg_scale,
            aspect_ratio=request.aspect_ratio,
        )

        task_id = task_data.get("task_id")
        generation.kling_task_id = task_id
        await self.dao.generations.update_generation(generation)

        logger.info(f"Video gen {generation.id}: Kling task created, task_id={task_id}")

        # 3. Poll for completion, persisting progress as Kling reports it.
        async def progress_callback(progress_pct: int):
            generation.progress = progress_pct
            generation.updated_at = datetime.now(UTC)
            await self.dao.generations.update_generation(generation)

        result = await self.kling_adapter.wait_for_completion(
            task_id=task_id,
            poll_interval=10,
            timeout=600,
            progress_callback=progress_callback,
        )

        # 4. Extract the video URL from the Kling payload and download it.
        works = result.get("task_result", {}).get("videos", [])
        if not works:
            raise Exception("No video in Kling result")

        video_url = works[0].get("url")
        video_duration = works[0].get("duration", request.duration)
        if not video_url:
            raise Exception("No video URL in Kling result")

        logger.info(f"Video gen {generation.id}: downloading video from {video_url}")

        async with httpx.AsyncClient(timeout=120.0) as client:
            video_response = await client.get(video_url)
            video_response.raise_for_status()
            video_bytes = video_response.content

        logger.info(f"Video gen {generation.id}: downloaded {len(video_bytes)} bytes")

        # 5. Upload the MP4 to S3.
        filename = f"generated_video/{generation.linked_character_id or 'no_char'}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.mp4"
        await self.s3_adapter.upload_file(filename, video_bytes, content_type="video/mp4")

        # 6. Create the Asset record (bytes stay in S3).
        new_asset = Asset(
            name=f"Video_{generation.linked_character_id or 'gen'}",
            type=AssetType.GENERATED,
            content_type=AssetContentType.VIDEO,
            linked_char_id=generation.linked_character_id,
            data=None,
            minio_object_name=filename,
            minio_bucket=self.s3_adapter.bucket_name,
            thumbnail=None,  # video thumbnails can be added later
            created_by=generation.created_by,
            project_id=generation.project_id,
        )

        asset_id = await self.dao.assets.create_asset(new_asset)
        new_asset.id = str(asset_id)

        # 7. Finalize the generation record.
        end_time = datetime.now()
        generation.result_list = [new_asset.id]
        generation.result = new_asset.id
        generation.status = GenerationStatus.DONE
        generation.progress = 100
        generation.video_duration = video_duration
        generation.execution_time_seconds = (end_time - start_time).total_seconds()
        generation.updated_at = datetime.now(UTC)
        await self.dao.generations.update_generation(generation)

        logger.info(f"Video generation {generation.id} completed. Asset: {new_asset.id}, Time: {generation.execution_time_seconds:.1f}s")

    except KlingApiException as e:
        logger.error(f"Kling API error for generation {generation.id}: {e}")
        await _mark_failed(str(e))
        raise
    except Exception as e:
        logger.error(f"Video generation {generation.id} failed: {e}")
        await _mark_failed(str(e))
        raise
|
async def delete_generation(self, generation_id: str) -> bool:
    """
    Soft delete generation by marking it as deleted.
    """
    try:
        target = await self.dao.generations.get_generation(generation_id)
        if target is None:
            # Nothing to delete.
            return False

        # Flip the soft-delete flag and persist; the record stays in the DB.
        target.is_deleted = True
        target.updated_at = datetime.now(UTC)
        await self.dao.generations.update_generation(target)
        return True
    except Exception as e:
        logger.error(f"Error deleting generation {generation_id}: {e}")
        return False
6
deploy.sh
Executable file
6
deploy.sh
Executable file
@@ -0,0 +1,6 @@
|
|||||||
|
|
||||||
|
#!/usr/bin/env bash
# Deploy: pull the latest code on the server and rebuild the containers.
# Fail fast on any error, unset variable, or broken pipe.
set -euo pipefail

ssh root@31.59.58.220 "
cd /root/bots/ai-char-bot &&
git pull &&
docker compose up -d --build
"
||||||
@@ -4,6 +4,26 @@ services:
|
|||||||
container_name: ai-bot
|
container_name: ai-bot
|
||||||
build:
|
build:
|
||||||
context: .
|
context: .
|
||||||
network: host
|
# УБРАЛИ network_mode: host
|
||||||
network_mode: host
|
ports:
|
||||||
|
- "8090:8090" # Вернули проброс порта
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
depends_on:
|
||||||
|
- minio
|
||||||
|
environment:
|
||||||
|
# Важно: внутри докера к другим контейнерам обращаемся по имени сервиса!
|
||||||
|
MINIO_ENDPOINT: "http://minio:9000"
|
||||||
|
|
||||||
|
minio:
|
||||||
|
image: minio/minio:latest
|
||||||
|
container_name: minio
|
||||||
|
restart: unless-stopped
|
||||||
|
command: server /data --console-address ":9001"
|
||||||
|
environment:
|
||||||
|
MINIO_ROOT_USER: admin
|
||||||
|
MINIO_ROOT_PASSWORD: SuperSecretPassword123!
|
||||||
|
ports:
|
||||||
|
- "9000:9000"
|
||||||
|
- "9001:9001"
|
||||||
|
volumes:
|
||||||
|
- ./minio_data:/data
|
||||||
Binary file not shown.
Binary file not shown.
12
models/Album.py
Normal file
12
models/Album.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
from datetime import datetime, UTC
|
||||||
|
from typing import Optional, List
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
class Album(BaseModel):
    """A named collection of generations with an optional cover image."""

    id: Optional[str] = None  # Mongo ObjectId as a string; None until persisted
    name: str
    description: Optional[str] = None
    cover_asset_id: Optional[str] = None  # asset shown as the album cover
    # default_factory instead of a shared `[]` literal — consistent with the
    # list fields on the Generation model.
    generation_ids: List[str] = Field(default_factory=list)
    created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
    updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
||||||
@@ -2,26 +2,65 @@ from datetime import datetime, UTC
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional, Any, List
|
from typing import Optional, Any, List
|
||||||
|
|
||||||
from pydantic import BaseModel, computed_field, Field
|
from pydantic import BaseModel, computed_field, Field, model_validator
|
||||||
|
|
||||||
|
|
||||||
|
class AssetContentType(str, Enum):
|
||||||
|
IMAGE = 'image'
|
||||||
|
VIDEO = 'video'
|
||||||
|
PROMPT = 'prompt'
|
||||||
|
|
||||||
class AssetType(str, Enum):
|
class AssetType(str, Enum):
|
||||||
IMAGE = 'image'
|
UPLOADED = 'uploaded'
|
||||||
PROMPT = 'prompt'
|
GENERATED = 'generated'
|
||||||
|
|
||||||
|
|
||||||
class Asset(BaseModel):
|
class Asset(BaseModel):
|
||||||
id: Optional[str] = None
|
id: Optional[str] = None
|
||||||
name: str
|
name: str
|
||||||
type: AssetType
|
type: AssetType = AssetType.GENERATED
|
||||||
|
content_type: AssetContentType = AssetContentType.IMAGE
|
||||||
linked_char_id: Optional[str] = None
|
linked_char_id: Optional[str] = None
|
||||||
data: Optional[bytes] = None
|
data: Optional[bytes] = None
|
||||||
tg_doc_file_id: Optional[str] = None
|
tg_doc_file_id: Optional[str] = None
|
||||||
tg_photo_file_id: Optional[str] = None
|
tg_photo_file_id: Optional[str] = None
|
||||||
|
minio_object_name: Optional[str] = None
|
||||||
|
minio_bucket: Optional[str] = None
|
||||||
|
minio_thumbnail_object_name: Optional[str] = None
|
||||||
|
thumbnail: Optional[bytes] = None
|
||||||
tags: List[str] = []
|
tags: List[str] = []
|
||||||
|
created_by: Optional[str] = None
|
||||||
|
project_id: Optional[str] = None
|
||||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||||
|
|
||||||
|
@model_validator(mode='before')
|
||||||
|
@classmethod
|
||||||
|
def check_legacy_type(cls, data: Any) -> Any:
|
||||||
|
if isinstance(data, dict):
|
||||||
|
# Если поле type содержит старые значения ("image", "prompt"),
|
||||||
|
# переносим их в content_type, а type ставим по умолчанию (GENERATED)
|
||||||
|
# или пытаемся угадать.
|
||||||
|
# Но по задаче мы дефолтим в GENERATED, и script'ом поправим.
|
||||||
|
|
||||||
|
raw_type = data.get('type')
|
||||||
|
if raw_type in ['image', 'prompt']:
|
||||||
|
data['content_type'] = raw_type
|
||||||
|
# Если в базе нет нового поля type, оно встанет в default=GENERATED
|
||||||
|
# Чтобы не вызывало ошибку валидации AssetType, меняем его или удаляем,
|
||||||
|
# полагаясь на default.
|
||||||
|
# Но если мы просто удалим, поле type примет дефолтное значение.
|
||||||
|
# Однако, если мы хотим явно отличить, можно ничего не делать,
|
||||||
|
# но тогда валидация поля `type` упадет, т.к. "image" != "generated".
|
||||||
|
# Поэтому удаляем старое значение из type, чтобы сработал дефолт.
|
||||||
|
if 'type' in data:
|
||||||
|
del data['type']
|
||||||
|
|
||||||
|
# Если content_type нет в данных (легаси), пытаемся его восстановить из удалённого type
|
||||||
|
# (выше мы его переложили).
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
# --- CALCULATED FIELD ---
|
# --- CALCULATED FIELD ---
|
||||||
@computed_field
|
@computed_field
|
||||||
def url(self) -> str:
|
def url(self) -> str:
|
||||||
|
|||||||
@@ -5,11 +5,13 @@ from pydantic_core.core_schema import computed_field
|
|||||||
|
|
||||||
|
|
||||||
class Character(BaseModel):
|
class Character(BaseModel):
|
||||||
id: str | None
|
id: Optional[str] = None
|
||||||
name: str
|
name: str
|
||||||
|
avatar_asset_id: Optional[str] = None
|
||||||
avatar_image: Optional[str] = None
|
avatar_image: Optional[str] = None
|
||||||
character_image_data: Optional[bytes] = None
|
character_image_data: Optional[bytes] = None
|
||||||
character_image_doc_tg_id: str
|
character_image_doc_tg_id: Optional[str] = None
|
||||||
character_image_tg_id: str | None
|
character_image_tg_id: Optional[str] = None
|
||||||
character_bio: str
|
character_bio: Optional[str] = None
|
||||||
|
created_by: Optional[str] = None
|
||||||
|
project_id: Optional[str] = None
|
||||||
|
|||||||
@@ -2,10 +2,10 @@ from datetime import datetime, UTC
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, computed_field
|
||||||
|
|
||||||
from models.Asset import Asset
|
from models.Asset import Asset
|
||||||
from models.enums import AspectRatios, Quality
|
from models.enums import AspectRatios, Quality, GenType
|
||||||
|
|
||||||
|
|
||||||
class GenerationStatus(str, Enum):
|
class GenerationStatus(str, Enum):
|
||||||
@@ -16,13 +16,39 @@ class GenerationStatus(str, Enum):
|
|||||||
class Generation(BaseModel):
|
class Generation(BaseModel):
|
||||||
id: Optional[str] = None
|
id: Optional[str] = None
|
||||||
status: GenerationStatus = GenerationStatus.RUNNING
|
status: GenerationStatus = GenerationStatus.RUNNING
|
||||||
|
gen_type: GenType = GenType.IMAGE
|
||||||
failed_reason: Optional[str] = None
|
failed_reason: Optional[str] = None
|
||||||
linked_character_id: Optional[str] = None
|
linked_character_id: Optional[str] = None
|
||||||
|
telegram_id: Optional[int] = None
|
||||||
|
use_profile_image: bool = True
|
||||||
aspect_ratio: AspectRatios
|
aspect_ratio: AspectRatios
|
||||||
quality: Quality
|
quality: Quality
|
||||||
prompt: str
|
prompt: str
|
||||||
tech_prompt: Optional[str] = None
|
tech_prompt: Optional[str] = None
|
||||||
assets_list: List[str]
|
assets_list: List[str] = Field(default_factory=list)
|
||||||
|
result_list: List[str] = Field(default_factory=list)
|
||||||
result: Optional[str] = None
|
result: Optional[str] = None
|
||||||
|
progress: int = 0
|
||||||
|
execution_time_seconds: Optional[float] = None
|
||||||
|
api_execution_time_seconds: Optional[float] = None
|
||||||
|
token_usage: Optional[int] = None
|
||||||
|
input_token_usage: Optional[int] = None
|
||||||
|
output_token_usage: Optional[int] = None
|
||||||
|
is_deleted: bool = False
|
||||||
|
album_id: Optional[str] = None
|
||||||
|
created_by: Optional[str] = None # Stores User ID (Telegram ID or Web User ObjectId)
|
||||||
|
project_id: Optional[str] = None
|
||||||
|
# Video-specific fields
|
||||||
|
kling_task_id: Optional[str] = None
|
||||||
|
video_duration: Optional[int] = None # 5 or 10 seconds
|
||||||
|
video_mode: Optional[str] = None # "std" or "pro"
|
||||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||||
|
|
||||||
|
@computed_field
|
||||||
|
def cost(self) -> float:
|
||||||
|
if self.status == GenerationStatus.DONE and self.input_token_usage and self.output_token_usage:
|
||||||
|
cost_input = self.input_token_usage * 0.000002
|
||||||
|
cost_output = self.output_token_usage * 0.00012
|
||||||
|
return round(cost_input + cost_output, 3)
|
||||||
|
return 0.0
|
||||||
12
models/Project.py
Normal file
12
models/Project.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Optional
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
class Project(BaseModel):
    """A workspace owned by one user and optionally shared with members."""

    id: Optional[str] = None  # Mongo ObjectId as a string; None until persisted
    name: str
    description: Optional[str] = None
    owner_id: str  # User ID of the project owner
    # default_factory instead of a shared `[]` literal — consistent with the
    # list fields on the Generation model.
    members: List[str] = Field(default_factory=list)  # List of User IDs
    is_deleted: bool = False  # soft-delete flag
    # NOTE(review): naive local time, unlike other models which use
    # datetime.now(UTC) — confirm intent and unify timezone handling.
    created_at: datetime = Field(default_factory=datetime.now)
||||||
@@ -34,10 +34,12 @@ class Quality(str, Enum):
|
|||||||
class GenType(str, Enum):
|
class GenType(str, Enum):
|
||||||
TEXT = 'Text'
|
TEXT = 'Text'
|
||||||
IMAGE = 'Image'
|
IMAGE = 'Image'
|
||||||
|
VIDEO = 'Video'
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def value_type(self) -> str:
|
def value_type(self) -> str:
|
||||||
return {
|
return {
|
||||||
GenType.TEXT: 'Text',
|
GenType.TEXT: 'Text',
|
||||||
GenType.IMAGE: 'Image'
|
GenType.IMAGE: 'Image',
|
||||||
|
GenType.VIDEO: 'Video',
|
||||||
}[self]
|
}[self]
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
61
repos/albums_repo.py
Normal file
61
repos/albums_repo.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
import logging
|
||||||
|
from bson import ObjectId
|
||||||
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
|
||||||
|
from models.Album import Album
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class AlbumsRepo:
    """Mongo repository for Album documents (collection ``albums``)."""

    def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
        self.collection = client[db_name]["albums"]

    async def create_album(self, album: Album) -> str:
        """Insert the album and return the new document id as a string.

        The logical ``id`` field is excluded: Mongo owns ``_id``, and storing
        a redundant (initially None) ``id`` key in the document is noise.
        """
        res = await self.collection.insert_one(album.model_dump(exclude={"id"}))
        return str(res.inserted_id)

    async def get_album(self, album_id: str) -> Optional[Album]:
        """Fetch one album by id; returns None when missing or on error."""
        try:
            res = await self.collection.find_one({"_id": ObjectId(album_id)})
            if not res:
                return None

            res["id"] = str(res.pop("_id"))
            return Album(**res)
        except Exception:
            # Invalid ObjectId strings land here too; log so genuine DB
            # failures are no longer swallowed silently.
            logger.exception(f"Failed to fetch album {album_id}")
            return None

    async def get_albums(self, limit: int = 10, offset: int = 0) -> List[Album]:
        """List albums, newest first, with skip/limit pagination."""
        res = await self.collection.find().sort("created_at", -1).skip(offset).limit(limit).to_list(None)
        albums = []
        for doc in res:
            doc["id"] = str(doc.pop("_id"))
            albums.append(Album(**doc))
        return albums

    async def update_album(self, album_id: str, album: Album) -> bool:
        """Overwrite the album's fields; returns True if a document changed."""
        if not album.id:
            album.id = album_id

        # Exclude `id` from the payload: the document key is `_id`.
        payload = album.model_dump(exclude={"id"})
        res = await self.collection.update_one({"_id": ObjectId(album_id)}, {"$set": payload})
        return res.modified_count > 0

    async def delete_album(self, album_id: str) -> bool:
        """Hard-delete an album document; returns True if one was removed."""
        res = await self.collection.delete_one({"_id": ObjectId(album_id)})
        return res.deleted_count > 0

    async def add_generation(self, album_id: str, generation_id: str, cover_asset_id: Optional[str] = None) -> bool:
        """Add a generation to the album (idempotent via $addToSet).

        BUGFIX: the cover is only updated when ``cover_asset_id`` is provided;
        previously a None value unconditionally wiped an existing cover.
        """
        update = {"$addToSet": {"generation_ids": generation_id}}
        if cover_asset_id:
            update["$set"] = {"cover_asset_id": cover_asset_id}
        res = await self.collection.update_one({"_id": ObjectId(album_id)}, update)
        return res.modified_count > 0

    async def remove_generation(self, album_id: str, generation_id: str) -> bool:
        """Remove a generation from the album; returns True if it was present."""
        res = await self.collection.update_one(
            {"_id": ObjectId(album_id)},
            {"$pull": {"generation_ids": generation_id}}
        )
        return res.modified_count > 0
||||||
@@ -1,47 +1,154 @@
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
import logging
|
||||||
from bson import ObjectId
|
from bson import ObjectId
|
||||||
from motor.motor_asyncio import AsyncIOMotorClient
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
|
||||||
from models.Asset import Asset
|
from models.Asset import Asset
|
||||||
|
from adapters.s3_adapter import S3Adapter
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class AssetsRepo:
|
class AssetsRepo:
|
||||||
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
|
def __init__(self, client: AsyncIOMotorClient, s3_adapter: Optional[S3Adapter] = None, db_name="bot_db"):
|
||||||
self.collection = client[db_name]["assets"]
|
self.collection = client[db_name]["assets"]
|
||||||
|
self.s3 = s3_adapter
|
||||||
|
|
||||||
async def create_asset(self, asset: Asset) -> str:
|
async def create_asset(self, asset: Asset) -> str:
|
||||||
res = await self.collection.insert_one(asset.model_dump())
|
# Если есть S3 и данные - грузим в S3
|
||||||
|
if self.s3:
|
||||||
|
# Main data
|
||||||
|
if asset.data:
|
||||||
|
ts = int(asset.created_at.timestamp())
|
||||||
|
object_name = f"{asset.type.value}/{ts}_{asset.name}"
|
||||||
|
|
||||||
|
uploaded = await self.s3.upload_file(object_name, asset.data)
|
||||||
|
if uploaded:
|
||||||
|
asset.minio_object_name = object_name
|
||||||
|
asset.minio_bucket = self.s3.bucket_name
|
||||||
|
asset.data = None # Clear data
|
||||||
|
else:
|
||||||
|
logger.error(f"Failed to upload asset {asset.name} to MinIO")
|
||||||
|
|
||||||
|
# Thumbnail
|
||||||
|
if asset.thumbnail:
|
||||||
|
ts = int(asset.created_at.timestamp())
|
||||||
|
thumb_name = f"{asset.type.value}/thumbs/{ts}_{asset.name}_thumb.jpg"
|
||||||
|
|
||||||
|
uploaded_thumb = await self.s3.upload_file(thumb_name, asset.thumbnail)
|
||||||
|
if uploaded_thumb:
|
||||||
|
asset.minio_thumbnail_object_name = thumb_name
|
||||||
|
asset.minio_bucket = self.s3.bucket_name # Assumes same bucket
|
||||||
|
asset.thumbnail = None # Clear thumbnail data
|
||||||
|
else:
|
||||||
|
logger.error(f"Failed to upload thumbnail for {asset.name} to MinIO")
|
||||||
|
|
||||||
|
|
||||||
|
res = await self.collection.insert_one(asset.model_dump())
|
||||||
return str(res.inserted_id)
|
return str(res.inserted_id)
|
||||||
|
|
||||||
async def get_assets(self, limit: int = 10, offset: int = 0) -> List[Asset]:
|
async def get_assets(self, asset_type: Optional[str] = None, limit: int = 10, offset: int = 0, with_data: bool = False, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Asset]:
|
||||||
res = await self.collection.find({}, {"data": 0}).sort("created_at", -1).skip(offset).limit(limit).to_list(None)
|
filter = {}
|
||||||
|
if asset_type:
|
||||||
|
filter["type"] = asset_type
|
||||||
|
args = {}
|
||||||
|
if not with_data:
|
||||||
|
args["data"] = 0
|
||||||
|
# We assume thumbnails are fetched only if needed or kept sparse.
|
||||||
|
# If they are in MinIO, we don't fetch them by default list unless specifically asked?
|
||||||
|
# User requirement "Get bytes ... from minio" usually refers to full asset. used in detail view.
|
||||||
|
# In list view, we might want thumbnails.
|
||||||
|
# If thumbnails are in MinIO, list view will be slow if we fetch all.
|
||||||
|
# Usually we return a URL. But this bot might serve bytes.
|
||||||
|
# Let's assuming list view needs thumbnails if they are small.
|
||||||
|
# But if we moved them to S3, we probably don't want to fetch 10x S3 requests for list.
|
||||||
|
# For now: If minio_thumbnail_object_name is present, user might need to fetch separately
|
||||||
|
# or we fetch if `with_data` is True?
|
||||||
|
# Standard pattern: return URL or ID.
|
||||||
|
# Let's keep existing logic: args["thumbnail"] = 0 if not with_data.
|
||||||
|
# EXCEPT if we want to show thumbnails in list.
|
||||||
|
# Original code:
|
||||||
|
# if not with_data: args["data"] = 0; args["thumbnail"] = 0
|
||||||
|
# So list DOES NOT return thumbnails by default.
|
||||||
|
args["thumbnail"] = 0
|
||||||
|
if created_by:
|
||||||
|
filter["created_by"] = created_by
|
||||||
|
filter['project_id'] = None
|
||||||
|
|
||||||
|
if project_id:
|
||||||
|
filter["project_id"] = project_id
|
||||||
|
|
||||||
|
res = await self.collection.find(filter, args).sort("created_at", -1).skip(offset).limit(limit).to_list(None)
|
||||||
assets = []
|
assets = []
|
||||||
for doc in res:
|
for doc in res:
|
||||||
# Конвертируем ObjectId в строку и кладем в поле id
|
|
||||||
doc["id"] = str(doc.pop("_id"))
|
doc["id"] = str(doc.pop("_id"))
|
||||||
|
asset = Asset(**doc)
|
||||||
|
|
||||||
# Создаем объект
|
if with_data and self.s3:
|
||||||
assets.append(Asset(**doc))
|
# Fetch data
|
||||||
|
if asset.minio_object_name:
|
||||||
|
data = await self.s3.get_file(asset.minio_object_name)
|
||||||
|
if data: asset.data = data
|
||||||
|
|
||||||
|
# Fetch thumbnail
|
||||||
|
if asset.minio_thumbnail_object_name:
|
||||||
|
thumb = await self.s3.get_file(asset.minio_thumbnail_object_name)
|
||||||
|
if thumb: asset.thumbnail = thumb
|
||||||
|
|
||||||
|
assets.append(asset)
|
||||||
|
|
||||||
return assets
|
return assets
|
||||||
|
|
||||||
|
|
||||||
async def get_asset(self, asset_id: str, with_data: bool = True) -> Asset:
|
async def get_asset(self, asset_id: str, with_data: bool = True) -> Asset:
|
||||||
projection = {"_id": 1, "name": 1, "type": 1, "tg_doc_file_id": 1}
|
projection = None
|
||||||
if with_data:
|
if not with_data:
|
||||||
projection["data"] = 1
|
projection = {"data": 0, "thumbnail": 0}
|
||||||
|
|
||||||
|
res = await self.collection.find_one({"_id": ObjectId(asset_id)}, projection)
|
||||||
|
if not res:
|
||||||
|
return None
|
||||||
|
|
||||||
res = await self.collection.find_one({"_id": ObjectId(asset_id)},
|
|
||||||
projection)
|
|
||||||
res["id"] = str(res.pop("_id"))
|
res["id"] = str(res.pop("_id"))
|
||||||
return Asset(**res)
|
asset = Asset(**res)
|
||||||
|
|
||||||
|
if with_data and self.s3:
|
||||||
|
if asset.minio_object_name:
|
||||||
|
data = await self.s3.get_file(asset.minio_object_name)
|
||||||
|
if data: asset.data = data
|
||||||
|
|
||||||
|
if asset.minio_thumbnail_object_name:
|
||||||
|
thumb = await self.s3.get_file(asset.minio_thumbnail_object_name)
|
||||||
|
if thumb: asset.thumbnail = thumb
|
||||||
|
|
||||||
|
return asset
|
||||||
|
|
||||||
async def update_asset(self, asset_id: str, asset: Asset):
|
async def update_asset(self, asset_id: str, asset: Asset):
|
||||||
if not asset.id:
|
if not asset.id:
|
||||||
raise Exception(f"Asset ID not found: {asset_id}")
|
if asset_id: asset.id = asset_id
|
||||||
await self.collection.update_one({"_id": ObjectId(asset_id)}, {"$set": asset.model_dump()})
|
else: raise Exception(f"Asset ID not found: {asset_id}")
|
||||||
|
|
||||||
|
# NOTE: simplistic update. If asset has data/thumbnail bytes, we might need to upload?
|
||||||
|
# Assuming for now we just save what's given.
|
||||||
|
# If user wants to update data, they should probably use a specialized method or we handle it here.
|
||||||
|
# Let's handle it: If data/thumbnail is present AND we have S3, upload it.
|
||||||
|
|
||||||
|
if self.s3:
|
||||||
|
if asset.data:
|
||||||
|
ts = int(asset.created_at.timestamp())
|
||||||
|
object_name = f"{asset.type.value}/{ts}_{asset.name}"
|
||||||
|
if await self.s3.upload_file(object_name, asset.data):
|
||||||
|
asset.minio_object_name = object_name
|
||||||
|
asset.minio_bucket = self.s3.bucket_name
|
||||||
|
asset.data = None
|
||||||
|
|
||||||
|
if asset.thumbnail:
|
||||||
|
ts = int(asset.created_at.timestamp())
|
||||||
|
thumb_name = f"{asset.type.value}/thumbs/{ts}_{asset.name}_thumb.jpg"
|
||||||
|
if await self.s3.upload_file(thumb_name, asset.thumbnail):
|
||||||
|
asset.minio_thumbnail_object_name = thumb_name
|
||||||
|
asset.thumbnail = None
|
||||||
|
|
||||||
|
model_dump = asset.model_dump()
|
||||||
|
await self.collection.update_one({"_id": ObjectId(asset_id)}, {"$set": model_dump})
|
||||||
|
|
||||||
async def set_tg_photo_file_id(self, asset_id: str, tg_photo_file_id: str):
|
async def set_tg_photo_file_id(self, asset_id: str, tg_photo_file_id: str):
|
||||||
await self.collection.update_one({"_id": ObjectId(asset_id)}, {"$set": {"tg_photo_file_id": tg_photo_file_id}})
|
await self.collection.update_one({"_id": ObjectId(asset_id)}, {"$set": {"tg_photo_file_id": tg_photo_file_id}})
|
||||||
@@ -56,16 +163,102 @@ class AssetsRepo:
|
|||||||
assets.append(Asset(**doc))
|
assets.append(Asset(**doc))
|
||||||
return assets
|
return assets
|
||||||
|
|
||||||
async def get_asset_count(self, character_id: Optional[str] = None) -> int:
|
async def get_asset_count(self, character_id: Optional[str] = None, created_by: Optional[str] = None, project_id: Optional[str] = None) -> int:
|
||||||
return await self.collection.count_documents({"linked_char_id": character_id} if character_id else {})
|
filter = {}
|
||||||
|
if character_id:
|
||||||
|
filter["linked_char_id"] = character_id
|
||||||
|
if created_by:
|
||||||
|
filter["created_by"] = created_by
|
||||||
|
if project_id:
|
||||||
|
filter["project_id"] = project_id
|
||||||
|
return await self.collection.count_documents(filter)
|
||||||
|
|
||||||
async def get_assets_by_ids(self, asset_ids: List[str]) -> List[Asset]:
|
async def get_assets_by_ids(self, asset_ids: List[str]) -> List[Asset]:
|
||||||
object_ids = [ObjectId(asset_id) for asset_id in asset_ids]
|
object_ids = [ObjectId(asset_id) for asset_id in asset_ids]
|
||||||
res = self.collection.find({"_id": {"$in": object_ids}})
|
res = self.collection.find({"_id": {"$in": object_ids}}, {"data": 0}) # Exclude data but maybe allow thumbnail if small?
|
||||||
|
# Original excluded thumbnail too.
|
||||||
assets = []
|
assets = []
|
||||||
async for doc in res:
|
async for doc in res:
|
||||||
doc["id"] = str(doc.pop("_id"))
|
doc["id"] = str(doc.pop("_id"))
|
||||||
assets.append(Asset(**doc))
|
assets.append(Asset(**doc))
|
||||||
return assets
|
return assets
|
||||||
|
|
||||||
|
async def delete_asset(self, asset_id: str) -> bool:
|
||||||
|
asset_doc = await self.collection.find_one({"_id": ObjectId(asset_id)})
|
||||||
|
if not asset_doc:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if self.s3:
|
||||||
|
if asset_doc.get("minio_object_name"):
|
||||||
|
await self.s3.delete_file(asset_doc["minio_object_name"])
|
||||||
|
if asset_doc.get("minio_thumbnail_object_name"):
|
||||||
|
await self.s3.delete_file(asset_doc["minio_thumbnail_object_name"])
|
||||||
|
|
||||||
|
res = await self.collection.delete_one({"_id": ObjectId(asset_id)})
|
||||||
|
return res.deleted_count > 0
|
||||||
|
|
||||||
|
async def migrate_to_minio(self) -> dict:
|
||||||
|
"""Переносит данные и thumbnails из Mongo в MinIO."""
|
||||||
|
if not self.s3:
|
||||||
|
return {"error": "MinIO adapter not initialized"}
|
||||||
|
|
||||||
|
# 1. Migrate Data
|
||||||
|
cursor_data = self.collection.find({"data": {"$ne": None}, "minio_object_name": {"$eq": None}})
|
||||||
|
count_data = 0
|
||||||
|
errors_data = 0
|
||||||
|
|
||||||
|
async for doc in cursor_data:
|
||||||
|
try:
|
||||||
|
asset_id = doc["_id"]
|
||||||
|
data = doc.get("data")
|
||||||
|
name = doc.get("name", "unknown")
|
||||||
|
type_ = doc.get("type", "image")
|
||||||
|
created_at = doc.get("created_at")
|
||||||
|
ts = int(created_at.timestamp()) if created_at else 0
|
||||||
|
|
||||||
|
object_name = f"{type_}/{ts}_{asset_id}_{name}"
|
||||||
|
if await self.s3.upload_file(object_name, data):
|
||||||
|
await self.collection.update_one(
|
||||||
|
{"_id": asset_id},
|
||||||
|
{"$set": {"minio_object_name": object_name, "minio_bucket": self.s3.bucket_name, "data": None}}
|
||||||
|
)
|
||||||
|
count_data += 1
|
||||||
|
else:
|
||||||
|
errors_data += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Data migration error for {doc.get('_id')}: {e}")
|
||||||
|
errors_data += 1
|
||||||
|
|
||||||
|
# 2. Migrate Thumbnails
|
||||||
|
cursor_thumb = self.collection.find({"thumbnail": {"$ne": None}, "minio_thumbnail_object_name": {"$eq": None}})
|
||||||
|
count_thumb = 0
|
||||||
|
errors_thumb = 0
|
||||||
|
|
||||||
|
async for doc in cursor_thumb:
|
||||||
|
try:
|
||||||
|
asset_id = doc["_id"]
|
||||||
|
thumb = doc.get("thumbnail")
|
||||||
|
name = doc.get("name", "unknown")
|
||||||
|
type_ = doc.get("type", "image")
|
||||||
|
created_at = doc.get("created_at")
|
||||||
|
ts = int(created_at.timestamp()) if created_at else 0
|
||||||
|
|
||||||
|
thumb_name = f"{type_}/thumbs/{ts}_{asset_id}_{name}_thumb.jpg"
|
||||||
|
if await self.s3.upload_file(thumb_name, thumb):
|
||||||
|
await self.collection.update_one(
|
||||||
|
{"_id": asset_id},
|
||||||
|
{"$set": {"minio_thumbnail_object_name": thumb_name, "minio_bucket": self.s3.bucket_name, "thumbnail": None}}
|
||||||
|
)
|
||||||
|
count_thumb += 1
|
||||||
|
else:
|
||||||
|
errors_thumb += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Thumbnail migration error for {doc.get('_id')}: {e}")
|
||||||
|
errors_thumb += 1
|
||||||
|
|
||||||
|
return {
|
||||||
|
"migrated_data": count_data,
|
||||||
|
"errors_data": errors_data,
|
||||||
|
"migrated_thumbnails": count_thumb,
|
||||||
|
"errors_thumbnails": errors_thumb
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import List
|
from typing import List, Optional
|
||||||
|
|
||||||
from bson import ObjectId
|
from bson import ObjectId
|
||||||
from motor.motor_asyncio import AsyncIOMotorClient
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
@@ -12,7 +12,7 @@ class CharacterRepo:
|
|||||||
|
|
||||||
async def add_character(self, character: Character) -> Character:
|
async def add_character(self, character: Character) -> Character:
|
||||||
op = await self.collection.insert_one(character.model_dump())
|
op = await self.collection.insert_one(character.model_dump())
|
||||||
character.id = op.inserted_id
|
character.id = str(op.inserted_id)
|
||||||
return character
|
return character
|
||||||
|
|
||||||
async def get_character(self, character_id: str, with_image_data: bool = False) -> Character | None:
|
async def get_character(self, character_id: str, with_image_data: bool = False) -> Character | None:
|
||||||
@@ -26,18 +26,25 @@ class CharacterRepo:
|
|||||||
res["id"] = str(res.pop("_id"))
|
res["id"] = str(res.pop("_id"))
|
||||||
return Character(**res)
|
return Character(**res)
|
||||||
|
|
||||||
async def get_all_characters(self) -> List[Character]:
|
async def get_all_characters(self, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Character]:
|
||||||
docs = await self.collection.find({}, {"character_image_data": 0}).to_list(None)
|
filter = {}
|
||||||
|
if created_by:
|
||||||
|
filter["created_by"] = created_by
|
||||||
|
if project_id:
|
||||||
|
filter["project_id"] = project_id
|
||||||
|
|
||||||
characters = []
|
args = {"character_image_data": 0} # don't return image data for list
|
||||||
for doc in docs:
|
res = await self.collection.find(filter, args).to_list(None)
|
||||||
# Конвертируем ObjectId в строку и кладем в поле id
|
chars = []
|
||||||
|
for doc in res:
|
||||||
doc["id"] = str(doc.pop("_id"))
|
doc["id"] = str(doc.pop("_id"))
|
||||||
|
chars.append(Character(**doc))
|
||||||
|
return chars
|
||||||
|
|
||||||
# Создаем объект
|
async def update_char(self, char_id: str, character: Character) -> bool:
|
||||||
characters.append(Character(**doc))
|
result = await self.collection.update_one({"_id": ObjectId(char_id)}, {"$set": character.model_dump()})
|
||||||
|
return result.modified_count > 0
|
||||||
|
|
||||||
return characters
|
async def delete_character(self, char_id: str) -> bool:
|
||||||
|
result = await self.collection.delete_one({"_id": ObjectId(char_id)})
|
||||||
async def update_char(self, char_id: str, character: Character) -> None:
|
return result.deleted_count > 0
|
||||||
await self.collection.update_one({"_id": ObjectId(char_id)}, {"$set": character.model_dump()})
|
|
||||||
|
|||||||
12
repos/dao.py
12
repos/dao.py
@@ -4,10 +4,18 @@ from repos.assets_repo import AssetsRepo
|
|||||||
from repos.char_repo import CharacterRepo
|
from repos.char_repo import CharacterRepo
|
||||||
from repos.generation_repo import GenerationRepo
|
from repos.generation_repo import GenerationRepo
|
||||||
from repos.user_repo import UsersRepo
|
from repos.user_repo import UsersRepo
|
||||||
|
from repos.albums_repo import AlbumsRepo
|
||||||
|
from repos.project_repo import ProjectRepo
|
||||||
|
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
from adapters.s3_adapter import S3Adapter
|
||||||
|
|
||||||
class DAO:
|
class DAO:
|
||||||
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
|
def __init__(self, client: AsyncIOMotorClient, s3_adapter: Optional[S3Adapter] = None, db_name="bot_db"):
|
||||||
self.chars = CharacterRepo(client, db_name)
|
self.chars = CharacterRepo(client, db_name)
|
||||||
self.assets = AssetsRepo(client, db_name)
|
self.assets = AssetsRepo(client, s3_adapter, db_name)
|
||||||
self.generations = GenerationRepo(client, db_name)
|
self.generations = GenerationRepo(client, db_name)
|
||||||
|
self.albums = AlbumsRepo(client, db_name)
|
||||||
|
self.projects = ProjectRepo(client, db_name)
|
||||||
|
self.users = UsersRepo(client, db_name)
|
||||||
|
|||||||
@@ -25,13 +25,20 @@ class GenerationRepo:
|
|||||||
return Generation(**res)
|
return Generation(**res)
|
||||||
|
|
||||||
async def get_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
|
async def get_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
|
||||||
limit: int = 10, offset: int = 10) -> List[Generation]:
|
limit: int = 10, offset: int = 10, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
|
||||||
args = {}
|
|
||||||
|
filter = {"is_deleted": False}
|
||||||
if character_id is not None:
|
if character_id is not None:
|
||||||
args["character_id"] = character_id
|
filter["linked_character_id"] = character_id
|
||||||
if status is not None:
|
if status is not None:
|
||||||
args["status"] = status
|
filter["status"] = status
|
||||||
res = await self.collection.find(args).sort("created_at", -1).skip(
|
if created_by is not None:
|
||||||
|
filter["created_by"] = created_by
|
||||||
|
filter["project_id"] = None
|
||||||
|
if project_id is not None:
|
||||||
|
filter["project_id"] = project_id
|
||||||
|
|
||||||
|
res = await self.collection.find(filter).sort("created_at", -1).skip(
|
||||||
offset).limit(limit).to_list(None)
|
offset).limit(limit).to_list(None)
|
||||||
generations: List[Generation] = []
|
generations: List[Generation] = []
|
||||||
for generation in res:
|
for generation in res:
|
||||||
@@ -39,5 +46,34 @@ class GenerationRepo:
|
|||||||
generations.append(Generation(**generation))
|
generations.append(Generation(**generation))
|
||||||
return generations
|
return generations
|
||||||
|
|
||||||
|
async def count_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
|
||||||
|
album_id: Optional[str] = None, created_by: Optional[str] = None, project_id: Optional[str] = None) -> int:
|
||||||
|
args = {}
|
||||||
|
if character_id is not None:
|
||||||
|
args["linked_character_id"] = character_id
|
||||||
|
if status is not None:
|
||||||
|
args["status"] = status
|
||||||
|
if created_by is not None:
|
||||||
|
args["created_by"] = created_by
|
||||||
|
if project_id is not None:
|
||||||
|
args["project_id"] = project_id
|
||||||
|
return await self.collection.count_documents(args)
|
||||||
|
|
||||||
|
async def get_generations_by_ids(self, generation_ids: List[str]) -> List[Generation]:
|
||||||
|
object_ids = [ObjectId(gen_id) for gen_id in generation_ids if ObjectId.is_valid(gen_id)]
|
||||||
|
res = await self.collection.find({"_id": {"$in": object_ids}}).to_list(None)
|
||||||
|
generations: List[Generation] = []
|
||||||
|
|
||||||
|
# Maintain order of generation_ids
|
||||||
|
gen_map = {str(doc["_id"]): doc for doc in res}
|
||||||
|
|
||||||
|
for gen_id in generation_ids:
|
||||||
|
doc = gen_map.get(gen_id)
|
||||||
|
if doc:
|
||||||
|
doc["id"] = str(doc.pop("_id"))
|
||||||
|
generations.append(Generation(**doc))
|
||||||
|
|
||||||
|
return generations
|
||||||
|
|
||||||
async def update_generation(self, generation: Generation, ):
|
async def update_generation(self, generation: Generation, ):
|
||||||
res = await self.collection.update_one({"_id": ObjectId(generation.id)}, {"$set": generation.model_dump()})
|
res = await self.collection.update_one({"_id": ObjectId(generation.id)}, {"$set": generation.model_dump()})
|
||||||
|
|||||||
62
repos/project_repo.py
Normal file
62
repos/project_repo.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
from bson import ObjectId
|
||||||
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
from models.Project import Project
|
||||||
|
|
||||||
|
class ProjectRepo:
|
||||||
|
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
|
||||||
|
self.collection = client[db_name]["projects"]
|
||||||
|
|
||||||
|
async def create_project(self, project: Project) -> str:
|
||||||
|
res = await self.collection.insert_one(project.model_dump())
|
||||||
|
return str(res.inserted_id)
|
||||||
|
|
||||||
|
async def get_project(self, project_id: str) -> Optional[Project]:
|
||||||
|
if not ObjectId.is_valid(project_id):
|
||||||
|
return None
|
||||||
|
res = await self.collection.find_one({"_id": ObjectId(project_id)})
|
||||||
|
if res:
|
||||||
|
res["id"] = str(res.pop("_id"))
|
||||||
|
return Project(**res)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_projects_by_user(self, user_id: str) -> List[Project]:
|
||||||
|
# Find projects where user is owner OR in members
|
||||||
|
filter = {
|
||||||
|
"$or": [
|
||||||
|
{"owner_id": user_id},
|
||||||
|
{"members": user_id}
|
||||||
|
],
|
||||||
|
"is_deleted": False
|
||||||
|
}
|
||||||
|
cursor = self.collection.find(filter).sort("created_at", -1)
|
||||||
|
projects = []
|
||||||
|
async for doc in cursor:
|
||||||
|
doc["id"] = str(doc.pop("_id"))
|
||||||
|
projects.append(Project(**doc))
|
||||||
|
return projects
|
||||||
|
|
||||||
|
async def add_member(self, project_id: str, user_id: str) -> bool:
|
||||||
|
res = await self.collection.update_one(
|
||||||
|
{"_id": ObjectId(project_id)},
|
||||||
|
{"$addToSet": {"members": user_id}}
|
||||||
|
)
|
||||||
|
return res.modified_count > 0
|
||||||
|
|
||||||
|
async def remove_member(self, project_id: str, user_id: str) -> bool:
|
||||||
|
res = await self.collection.update_one(
|
||||||
|
{"_id": ObjectId(project_id)},
|
||||||
|
{"$pull": {"members": user_id}}
|
||||||
|
)
|
||||||
|
return res.modified_count > 0
|
||||||
|
|
||||||
|
async def update_project(self, project_id: str, updates: dict) -> bool:
|
||||||
|
res = await self.collection.update_one(
|
||||||
|
{"_id": ObjectId(project_id)},
|
||||||
|
{"$set": updates}
|
||||||
|
)
|
||||||
|
return res.modified_count > 0
|
||||||
|
|
||||||
|
async def delete_project(self, project_id: str) -> bool:
|
||||||
|
res = await self.collection.update_one({"_id": ObjectId(project_id)}, {"$set": {"is_deleted": True}})
|
||||||
|
return res.modified_count > 0
|
||||||
@@ -1,8 +1,10 @@
|
|||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from aiogram.types import User
|
from aiogram.types import User
|
||||||
from motor.motor_asyncio import AsyncIOMotorClient
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
from utils.security import get_password_hash
|
||||||
|
|
||||||
|
|
||||||
class UserStatus:
|
class UserStatus:
|
||||||
@@ -17,12 +19,65 @@ class UsersRepo:
|
|||||||
self.collection = client[db_name]["users"]
|
self.collection = client[db_name]["users"]
|
||||||
|
|
||||||
async def get_user(self, user_id: int):
|
async def get_user(self, user_id: int):
|
||||||
return await self.collection.find_one({"user_id": user_id})
|
user = await self.collection.find_one({"user_id": user_id})
|
||||||
|
if user:
|
||||||
|
user["id"] = str(user["_id"])
|
||||||
|
return user
|
||||||
|
|
||||||
|
async def get_user_by_username(self, username: str):
|
||||||
|
user = await self.collection.find_one({"username": username})
|
||||||
|
if user:
|
||||||
|
user["id"] = str(user["_id"])
|
||||||
|
return user
|
||||||
|
|
||||||
|
async def create_user(self, username: str, password: str, full_name: Optional[str] = None):
|
||||||
|
"""Создает нового пользователя с username/паролем"""
|
||||||
|
existing = await self.get_user_by_username(username)
|
||||||
|
if existing:
|
||||||
|
raise ValueError("User with this username already exists")
|
||||||
|
|
||||||
|
user_doc = {
|
||||||
|
"username": username,
|
||||||
|
"hashed_password": get_password_hash(password),
|
||||||
|
"full_name": full_name,
|
||||||
|
"status": UserStatus.PENDING, # По умолчанию PENDING
|
||||||
|
"created_at": datetime.now(),
|
||||||
|
"is_email_user": False, # Теперь это просто "обычный" юзер, не телеграм (хотя поле можно переименовать)
|
||||||
|
"is_web_user": True,
|
||||||
|
"is_admin": False,
|
||||||
|
"project_ids": [],
|
||||||
|
"current_project_id": None
|
||||||
|
}
|
||||||
|
result = await self.collection.insert_one(user_doc)
|
||||||
|
user = await self.collection.find_one({"_id": result.inserted_id})
|
||||||
|
if user:
|
||||||
|
user["id"] = str(user["_id"])
|
||||||
|
return user
|
||||||
|
|
||||||
|
async def get_pending_users(self):
|
||||||
|
"""Возвращает список пользователей со статусом PENDING"""
|
||||||
|
cursor = self.collection.find({"status": UserStatus.PENDING})
|
||||||
|
users = await cursor.to_list(length=100)
|
||||||
|
for user in users:
|
||||||
|
user["id"] = str(user["_id"])
|
||||||
|
return users
|
||||||
|
|
||||||
|
async def approve_user(self, username: str):
|
||||||
|
await self.collection.update_one(
|
||||||
|
{"username": username},
|
||||||
|
{"$set": {"status": UserStatus.ALLOWED}}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def deny_user(self, username: str):
|
||||||
|
await self.collection.update_one(
|
||||||
|
{"username": username},
|
||||||
|
{"$set": {"status": UserStatus.DENIED}}
|
||||||
|
)
|
||||||
|
|
||||||
async def create_or_update_request(self, user: User):
|
async def create_or_update_request(self, user: User):
|
||||||
"""
|
"""
|
||||||
Обновляет дату последнего запроса и ставит статус PENDING.
|
Обновляет дату последнего запроса и ставит статус PENDING.
|
||||||
Сохраняет всю инфу о юзере.
|
Сохраняет всю инфу о юзере (для Telegram пользователей).
|
||||||
"""
|
"""
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
data = {
|
data = {
|
||||||
@@ -30,7 +85,8 @@ class UsersRepo:
|
|||||||
"username": user.username,
|
"username": user.username,
|
||||||
"full_name": user.full_name,
|
"full_name": user.full_name,
|
||||||
"status": UserStatus.PENDING,
|
"status": UserStatus.PENDING,
|
||||||
"last_request_date": now
|
"last_request_date": now,
|
||||||
|
"is_email_user": False
|
||||||
}
|
}
|
||||||
await self.collection.update_one(
|
await self.collection.update_one(
|
||||||
{"user_id": user.id},
|
{"user_id": user.id},
|
||||||
|
|||||||
@@ -45,3 +45,10 @@ urllib3==2.6.3
|
|||||||
uvicorn==0.40.0
|
uvicorn==0.40.0
|
||||||
websockets==15.0.1
|
websockets==15.0.1
|
||||||
yarl==1.22.0
|
yarl==1.22.0
|
||||||
|
aioboto3==13.3.0
|
||||||
|
passlib[argon2]==1.7.4
|
||||||
|
python-jose[cryptography]==3.3.0
|
||||||
|
python-multipart==0.0.22
|
||||||
|
email-validator
|
||||||
|
prometheus-fastapi-instrumentator
|
||||||
|
PyJWT
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from aiogram.fsm.state import State, StatesGroup
|
|||||||
from aiogram.types import *
|
from aiogram.types import *
|
||||||
from aiogram import Router, F, Bot
|
from aiogram import Router, F, Bot
|
||||||
|
|
||||||
from models.Asset import Asset, AssetType
|
from models.Asset import Asset, AssetType, AssetContentType
|
||||||
from models.Character import Character
|
from models.Character import Character
|
||||||
from repos.dao import DAO
|
from repos.dao import DAO
|
||||||
|
|
||||||
@@ -63,7 +63,8 @@ async def new_char_bio(message: Message, state: FSMContext, dao: DAO, bot: Bot):
|
|||||||
character_image_data=file_io.read(),
|
character_image_data=file_io.read(),
|
||||||
character_image_tg_id=None,
|
character_image_tg_id=None,
|
||||||
character_image_doc_tg_id=file_id,
|
character_image_doc_tg_id=file_id,
|
||||||
character_bio=bio
|
character_bio=bio,
|
||||||
|
created_by=str(message.from_user.id)
|
||||||
)
|
)
|
||||||
file_io.close()
|
file_io.close()
|
||||||
|
|
||||||
@@ -74,7 +75,7 @@ async def new_char_bio(message: Message, state: FSMContext, dao: DAO, bot: Bot):
|
|||||||
file_bytes = await bot.download_file(file_info.file_path)
|
file_bytes = await bot.download_file(file_info.file_path)
|
||||||
file_io = file_bytes.read()
|
file_io = file_bytes.read()
|
||||||
avatar_asset = await dao.assets.create_asset(
|
avatar_asset = await dao.assets.create_asset(
|
||||||
Asset(name="avatar.png", type=AssetType.IMAGE, linked_char_id=str(char.id), data=file_io,
|
Asset(name="avatar.png", type=AssetType.UPLOADED, content_type=AssetContentType.IMAGE, linked_char_id=str(char.id), data=file_io,
|
||||||
tg_doc_file_id=file_id))
|
tg_doc_file_id=file_id))
|
||||||
char.avatar_image = avatar_asset.link
|
char.avatar_image = avatar_asset.link
|
||||||
# Отправляем подтверждение
|
# Отправляем подтверждение
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from aiogram.types import *
|
|||||||
|
|
||||||
import keyboards
|
import keyboards
|
||||||
from adapters.google_adapter import GoogleAdapter
|
from adapters.google_adapter import GoogleAdapter
|
||||||
from models.Asset import Asset, AssetType
|
from models.Asset import Asset, AssetType, AssetContentType
|
||||||
from models.Character import Character
|
from models.Character import Character
|
||||||
from models.enums import AspectRatios, Quality, GenType
|
from models.enums import AspectRatios, Quality, GenType
|
||||||
from repos.dao import DAO
|
from repos.dao import DAO
|
||||||
@@ -50,8 +50,8 @@ async def generate_image_cmd(message: Message, state: FSMContext, dao: DAO, gemi
|
|||||||
gemini=gemini)
|
gemini=gemini)
|
||||||
await wait_msg.delete()
|
await wait_msg.delete()
|
||||||
doc = await message.answer_document(res[0], caption="Generated result 💫")
|
doc = await message.answer_document(res[0], caption="Generated result 💫")
|
||||||
await dao.assets.create_asset(Asset(id=None, name=res[0].filename, type=AssetType.IMAGE, data=res[0].data,
|
await dao.assets.create_asset(Asset(id=None, name=res[0].filename, type=AssetType.GENERATED, content_type=AssetContentType.IMAGE, data=res[0].data,
|
||||||
tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None, linked_char_id=None))
|
tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None, linked_char_id=None, created_by=str(message.from_user.id)))
|
||||||
|
|
||||||
|
|
||||||
@router.message(Command("gen_mode"))
|
@router.message(Command("gen_mode"))
|
||||||
@@ -236,9 +236,6 @@ async def handle_album(
|
|||||||
for msg in album:
|
for msg in album:
|
||||||
if msg.photo:
|
if msg.photo:
|
||||||
file_ids.append(msg.photo[-1].file_id)
|
file_ids.append(msg.photo[-1].file_id)
|
||||||
elif msg.video:
|
|
||||||
# Если нужно, можно добавить обработку видео (пока пропускаем)
|
|
||||||
pass
|
|
||||||
|
|
||||||
await message.answer(f"📥 Принято {len(album)} файлов. Начинаю генерацию...")
|
await message.answer(f"📥 Принято {len(album)} файлов. Начинаю генерацию...")
|
||||||
wait_msg = await message.answer("🎨 Генерирую...")
|
wait_msg = await message.answer("🎨 Генерирую...")
|
||||||
@@ -260,9 +257,10 @@ async def handle_album(
|
|||||||
if generated_files:
|
if generated_files:
|
||||||
for file in generated_files:
|
for file in generated_files:
|
||||||
doc = await message.answer_document(file, caption="✨ Generated result")
|
doc = await message.answer_document(file, caption="✨ Generated result")
|
||||||
await dao.assets.create_asset(Asset(id=None, name=file.filename, type=AssetType.IMAGE, data=file.data,
|
await dao.assets.create_asset(Asset(id=None, name=file.filename, type=AssetType.GENERATED, content_type=AssetContentType.IMAGE, data=file.data,
|
||||||
tg_doc_file_id = doc.document.file_id, tg_photo_file_id = None,
|
tg_doc_file_id = doc.document.file_id, tg_photo_file_id = None,
|
||||||
linked_char_id = data["char_id"]))
|
linked_char_id = data["char_id"],
|
||||||
|
created_by=str(message.from_user.id)))
|
||||||
else:
|
else:
|
||||||
await message.answer("❌ Генерация не вернула изображений.")
|
await message.answer("❌ Генерация не вернула изображений.")
|
||||||
await gen_mode_base_msg(message=message, state=state, dao=dao, call_type="start")
|
await gen_mode_base_msg(message=message, state=state, dao=dao, call_type="start")
|
||||||
@@ -315,9 +313,10 @@ async def gen_mode_start(
|
|||||||
if generated_files:
|
if generated_files:
|
||||||
for file in generated_files:
|
for file in generated_files:
|
||||||
doc = await message.answer_document(file, caption="✨ Generated result")
|
doc = await message.answer_document(file, caption="✨ Generated result")
|
||||||
await dao.assets.create_asset(Asset(id=None, name=file.filename, type=AssetType.IMAGE, data=file.data,
|
await dao.assets.create_asset(Asset(id=None, name=file.filename, type=AssetType.GENERATED, content_type=AssetContentType.IMAGE, data=file.data,
|
||||||
tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None,
|
tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None,
|
||||||
linked_char_id=data["char_id"]))
|
linked_char_id=data["char_id"],
|
||||||
|
created_by=str(message.from_user.id)))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
await message.answer("❌ Ничего не сгенерировалось.")
|
await message.answer("❌ Ничего не сгенерировалось.")
|
||||||
|
|||||||
22
tests/test_api_protection.py
Normal file
22
tests/test_api_protection.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from main import app
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
def test_api_protection():
|
||||||
|
# 1. Assets
|
||||||
|
response = client.get("/api/assets")
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
# 2. Characters
|
||||||
|
response = client.get("/api/characters")
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
# 3. Generations
|
||||||
|
response = client.get("/api/generations")
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
# 4. Upload Asset (POST)
|
||||||
|
response = client.post("/api/assets/upload")
|
||||||
|
assert response.status_code == 401
|
||||||
107
tests/test_auth_flow.py
Normal file
107
tests/test_auth_flow.py
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
from datetime import datetime
|
||||||
|
from main import app
|
||||||
|
from api.endpoints.auth import get_users_repo
|
||||||
|
from repos.user_repo import UsersRepo, UserStatus
|
||||||
|
from utils.security import get_password_hash
|
||||||
|
|
||||||
|
# Mock Repository
|
||||||
|
class MockUsersRepo:
|
||||||
|
def __init__(self):
|
||||||
|
self.users = {}
|
||||||
|
|
||||||
|
async def get_user_by_username(self, username: str):
|
||||||
|
return self.users.get(username)
|
||||||
|
|
||||||
|
async def create_user(self, username: str, password: str, full_name: str = None):
|
||||||
|
if username in self.users:
|
||||||
|
raise ValueError("User with this username already exists")
|
||||||
|
|
||||||
|
user = {
|
||||||
|
"username": username,
|
||||||
|
"hashed_password": get_password_hash(password),
|
||||||
|
"full_name": full_name,
|
||||||
|
"status": UserStatus.PENDING,
|
||||||
|
"is_email_user": False,
|
||||||
|
"is_admin": False,
|
||||||
|
"created_at": datetime.now()
|
||||||
|
}
|
||||||
|
self.users[username] = user
|
||||||
|
return user
|
||||||
|
|
||||||
|
async def get_pending_users(self):
|
||||||
|
return [u for u in self.users.values() if u["status"] == UserStatus.PENDING]
|
||||||
|
|
||||||
|
async def approve_user(self, username: str):
|
||||||
|
if username in self.users:
|
||||||
|
self.users[username]["status"] = UserStatus.ALLOWED
|
||||||
|
|
||||||
|
async def deny_user(self, username: str):
|
||||||
|
if username in self.users:
|
||||||
|
self.users[username]["status"] = UserStatus.DENIED
|
||||||
|
|
||||||
|
mock_repo = MockUsersRepo()
|
||||||
|
|
||||||
|
# Override Dependency
|
||||||
|
app.dependency_overrides[get_users_repo] = lambda: mock_repo
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
def test_auth_flow_with_approval():
|
||||||
|
# 1. Register (User)
|
||||||
|
user_data = {
|
||||||
|
"username": "newuser",
|
||||||
|
"password": "password123",
|
||||||
|
"full_name": "New User"
|
||||||
|
}
|
||||||
|
response = client.post("/auth/register", json=user_data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["message"] == "Registration successful. Please wait for administrator approval."
|
||||||
|
|
||||||
|
# 2. Try Login (User) -> Should Fail (Pending)
|
||||||
|
login_data = {
|
||||||
|
"username": "newuser",
|
||||||
|
"password": "password123"
|
||||||
|
}
|
||||||
|
response = client.post("/auth/token", data=login_data)
|
||||||
|
assert response.status_code == 403
|
||||||
|
assert "not approved" in response.json()["detail"]
|
||||||
|
|
||||||
|
# 3. Setup Admin (Backdoor for test)
|
||||||
|
mock_repo.users["admin"] = {
|
||||||
|
"username": "admin",
|
||||||
|
"hashed_password": get_password_hash("adminpass"),
|
||||||
|
"status": UserStatus.ALLOWED,
|
||||||
|
"is_admin": True,
|
||||||
|
"created_at": datetime.now()
|
||||||
|
}
|
||||||
|
|
||||||
|
# 4. Admin Login
|
||||||
|
admin_login = {
|
||||||
|
"username": "admin",
|
||||||
|
"password": "adminpass"
|
||||||
|
}
|
||||||
|
response = client.post("/auth/token", data=admin_login)
|
||||||
|
assert response.status_code == 200
|
||||||
|
admin_token = response.json()["access_token"]
|
||||||
|
admin_auth = {"Authorization": f"Bearer {admin_token}"}
|
||||||
|
|
||||||
|
# 5. List Pending (Admin)
|
||||||
|
response = client.get("/admin/approvals", headers=admin_auth)
|
||||||
|
assert response.status_code == 200
|
||||||
|
users = response.json()
|
||||||
|
assert len(users) >= 1
|
||||||
|
assert users[0]["username"] == "newuser"
|
||||||
|
assert users[0]["status"] == "pending"
|
||||||
|
|
||||||
|
# 6. Approve User (Admin)
|
||||||
|
response = client.post("/admin/approve/newuser", headers=admin_auth)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["message"] == "User newuser approved"
|
||||||
|
|
||||||
|
# 7. Login User (Again) -> Should Success
|
||||||
|
response = client.post("/auth/token", data=login_data)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert "access_token" in response.json()
|
||||||
101
tests/test_character_crud.py
Normal file
101
tests/test_character_crud.py
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from main import app
|
||||||
|
from api.endpoints.auth import get_current_user
|
||||||
|
from api.dependency import get_dao
|
||||||
|
from repos.dao import DAO
|
||||||
|
from models.Character import Character
|
||||||
|
|
||||||
|
# Config for test DB
# NOTE(review): the fallback embeds real-looking credentials and a public IP —
# consider requiring MONGO_HOST from the environment instead of defaulting.
MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://admin:super_secure_password@31.59.58.220:27017")
DB_NAME = "bot_db_test_chars"  # dedicated database so tests never touch production data

# Mock User
MOCK_USER_ID = "507f1f77bcf86cd799439011"  # valid 24-hex ObjectId-style string
MOCK_USER = {
    "_id": MOCK_USER_ID,
    "username": "testuser",
    "is_admin": False,
    "status": "allowed"
}

# Override get_current_user to bypass auth
def mock_get_current_user():
    # Canned user dict so endpoints see an already-authenticated caller.
    return MOCK_USER

app.dependency_overrides[get_current_user] = mock_get_current_user

# Setup Real DAO with Test DB
client_mongo = AsyncIOMotorClient(MONGO_HOST)
dao = DAO(client_mongo, db_name=DB_NAME)

def mock_get_dao():
    # Hands the real DAO (pointed at the test DB) to the app's DI system.
    return dao

app.dependency_overrides[get_dao] = mock_get_dao

client = TestClient(app)
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module", autouse=True)
def setup_teardown():
    """Drop the characters collection before and after this module's tests."""
    event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)

    def _drop_characters():
        # The drop is async; run it to completion on our private loop.
        event_loop.run_until_complete(client_mongo[DB_NAME]["characters"].drop())

    _drop_characters()  # start from a clean state
    yield
    _drop_characters()  # leave no residue behind
    event_loop.close()
|
||||||
|
|
||||||
|
def test_character_crud_flow():
    """Exercise the full create → read → update → delete cycle for characters."""
    # --- Create ---
    payload = {
        "name": "Test Character",
        "character_bio": "A bio for test character",
        "character_image_doc_tg_id": "file_123",
        "avatar_image": "http://example.com/avatar.jpg",
    }
    resp = client.post("/api/characters/", json=payload)
    assert resp.status_code == 200, resp.text
    created = resp.json()
    assert created["name"] == payload["name"]
    assert created["created_by"] == MOCK_USER_ID
    char_id = created["id"]
    assert char_id is not None

    # --- Read ---
    resp = client.get(f"/api/characters/{char_id}")
    assert resp.status_code == 200
    assert resp.json()["id"] == char_id

    # --- Update ---
    changes = {"name": "Updated Name", "character_bio": "Updated bio"}
    resp = client.put(f"/api/characters/{char_id}", json=changes)
    assert resp.status_code == 200
    updated = resp.json()
    assert updated["name"] == "Updated Name"
    assert updated["character_bio"] == "Updated bio"

    # Re-fetch to prove the update was persisted, not just echoed back.
    resp = client.get(f"/api/characters/{char_id}")
    assert resp.json()["name"] == "Updated Name"

    # --- Delete ---
    resp = client.delete(f"/api/characters/{char_id}")
    assert resp.status_code == 204

    # A deleted character must no longer resolve.
    resp = client.get(f"/api/characters/{char_id}")
    assert resp.status_code == 404, "Deleted character should return 404"
|
||||||
64
tests/test_character_integration.py
Normal file
64
tests/test_character_integration.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
# 1. Set Auth Bypass and Test Config
os.environ["DB_NAME"] = "bot_db_test_integration"
# We keep MONGO_HOST as is (it works in verified script)

# 2. Import app AFTER setting env
# Deliberately a late import — presumably main reads DB_NAME at import
# time, so the env var must be set first. TODO confirm against main.py.
from main import app
from api.endpoints.auth import get_current_user

# 3. Override Auth
MOCK_USER_ID = "507f1f77bcf86cd799439011"  # valid 24-hex ObjectId-style string
MOCK_USER = {
    "_id": MOCK_USER_ID,
    "username": "testuser",
    "is_admin": False,
    "status": "allowed",
    "project_ids": []
}

def mock_get_current_user():
    # Canned user so endpoints behave as if a valid token were presented.
    return MOCK_USER

app.dependency_overrides[get_current_user] = mock_get_current_user

client = TestClient(app)
|
||||||
|
|
||||||
|
def test_character_crud_lifecycle():
    """Create, fetch, rename, and delete a character through the HTTP API."""
    # 1. Create
    payload = {
        "name": "Integration Test Char",
        "character_bio": "Testing with real app structure",
        "character_image_doc_tg_id": "doc_123",
        "avatar_image": "http://example.com/img.jpg",
    }
    resp = client.post("/api/characters/", json=payload)
    assert resp.status_code == 200, resp.text
    created = resp.json()
    assert created["name"] == payload["name"]
    char_id = created["id"]

    # 2. Get
    resp = client.get(f"/api/characters/{char_id}")
    assert resp.status_code == 200
    assert resp.json()["id"] == char_id

    # 3. Update
    resp = client.put(f"/api/characters/{char_id}", json={"name": "Updated Int Name"})
    assert resp.status_code == 200
    assert resp.json()["name"] == "Updated Int Name"

    # 4. Delete
    resp = client.delete(f"/api/characters/{char_id}")
    assert resp.status_code == 204

    # 5. Verify Delete
    resp = client.get(f"/api/characters/{char_id}")
    assert resp.status_code == 404
|
||||||
63
tests/test_external_import.py
Executable file
63
tests/test_external_import.py
Executable file
@@ -0,0 +1,63 @@
|
|||||||
|
#!/usr/bin/env python3
"""
Test script for external generation import API.

This script demonstrates how to call the import endpoint with proper HMAC signature.
"""

import hmac
import hashlib
import json
import requests
import base64
import os
from dotenv import load_dotenv

load_dotenv()

# Configuration
API_URL = "http://localhost:8090/api/generations/import"
SECRET = os.getenv("EXTERNAL_API_SECRET", "your_super_secret_key_change_this_in_production")

# Sample generation data
generation_data = {
    "prompt": "A beautiful sunset over mountains",
    "tech_prompt": "High quality landscape photography",
    "image_url": "https://picsum.photos/512/512",  # Sample image URL
    # OR use base64:
    # "image_data": "base64_encoded_image_string_here",
    "aspect_ratio": "9:16",
    "quality": "1k",
    "created_by": "external_user_123",
    "execution_time_seconds": 5.2,
    "token_usage": 1000,
    "input_token_usage": 200,
    "output_token_usage": 800
}

# Serialize ONCE: the HMAC must be computed over the exact bytes sent.
body = json.dumps(generation_data).encode('utf-8')

# Compute HMAC signature (hex SHA-256, matching the server-side check)
signature = hmac.new(
    SECRET.encode('utf-8'),
    body,
    hashlib.sha256
).hexdigest()

# Make request — send the pre-serialized bytes (data=body, not json=...) so
# the transmitted payload matches the signed payload byte-for-byte.
headers = {
    "Content-Type": "application/json",
    "X-Signature": signature
}

print(f"Sending request to {API_URL}")
print(f"Signature: {signature}")

try:
    response = requests.post(API_URL, data=body, headers=headers)
    print(f"\nStatus Code: {response.status_code}")
    print(f"Response: {json.dumps(response.json(), indent=2)}")
except Exception as e:
    print(f"Error: {e}")
    # BUGFIX: requests exceptions always *have* a .response attribute, but it
    # is None for connection-level errors — the old hasattr() check passed and
    # then crashed with AttributeError on .text. Guard for None explicitly.
    err_response = getattr(e, 'response', None)
    if err_response is not None:
        print(f"Response text: {err_response.text}")
|
||||||
44
tests/test_s3_connection.py
Normal file
44
tests/test_s3_connection.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from adapters.s3_adapter import S3Adapter
|
||||||
|
|
||||||
|
async def test_s3():
    """Round-trip a small object through MinIO: upload, download, delete."""
    load_dotenv()

    endpoint = os.getenv("MINIO_ENDPOINT", "http://localhost:9000")
    access_key = os.getenv("MINIO_ACCESS_KEY")
    secret_key = os.getenv("MINIO_SECRET_KEY")
    bucket = os.getenv("MINIO_BUCKET")

    print(f"Connecting to {endpoint}, bucket: {bucket}")

    s3 = S3Adapter(endpoint, access_key, secret_key, bucket)

    test_filename = "test_connection.txt"
    test_data = b"Hello MinIO!"

    print("Uploading...")
    # Guard clause: nothing else makes sense if the upload failed.
    if not await s3.upload_file(test_filename, test_data):
        print("Upload failed!")
        return
    print("Upload successful!")

    print("Downloading...")
    fetched = await s3.get_file(test_filename)
    if fetched == test_data:
        print("Download successful and data matches!")
    else:
        print(f"Download mismatch: {fetched}")

    print("Deleting...")
    print("Delete successful!" if await s3.delete_file(test_filename) else "Delete failed!")


if __name__ == "__main__":
    asyncio.run(test_s3())
|
||||||
91
tests/verify_albums_manual.py
Normal file
91
tests/verify_albums_manual.py
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# Add project root to path
|
||||||
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
|
|
||||||
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
from repos.dao import DAO
|
||||||
|
from models.Album import Album
|
||||||
|
from models.Generation import Generation, GenerationStatus
|
||||||
|
from models.enums import AspectRatios, Quality
|
||||||
|
|
||||||
|
# Mock config
|
||||||
|
# Use the same host as aiws.py but different DB
|
||||||
|
MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://admin:super_secure_password@31.59.58.220:27017")
|
||||||
|
DB_NAME = "bot_db_test_albums"
|
||||||
|
|
||||||
|
async def test_albums():
    """Manually verify the Album/Generation DAO flow against a throwaway DB.

    Walks create → link → fetch → unlink, asserting after each step and
    printing progress so the run can be eyeballed when executed as a script.
    """
    print(f"🚀 Starting Album Manual Verification using {MONGO_HOST}...")

    # Needs to run inside a loop from main
    client = AsyncIOMotorClient(MONGO_HOST)
    dao = DAO(client, db_name=DB_NAME)

    try:
        # 1. Clean up — drop both collections so assertions see only our data.
        await client[DB_NAME]["albums"].drop()
        await client[DB_NAME]["generations"].drop()
        print("✅ Cleaned up test database")

        # 2. Create Album
        album = Album(name="Test Album", description="A test album")
        print("Creating album...")
        album_id = await dao.albums.create_album(album)
        print(f"✅ Created Album: {album_id}")

        # 3. Create Generations
        gen1 = Generation(prompt="Gen 1", aspect_ratio=AspectRatios.NINESIXTEEN, quality=Quality.ONEK)
        gen2 = Generation(prompt="Gen 2", aspect_ratio=AspectRatios.NINESIXTEEN, quality=Quality.ONEK)

        print("Creating generations...")
        gen1_id = await dao.generations.create_generation(gen1)
        gen2_id = await dao.generations.create_generation(gen2)
        print(f"✅ Created Generations: {gen1_id}, {gen2_id}")

        # 4. Add generations to album
        print("Adding generations to album...")
        await dao.albums.add_generation(album_id, gen1_id)
        await dao.albums.add_generation(album_id, gen2_id)
        print("✅ Added generations to album")

        # 5. Fetch album and check generation_ids
        album_fetched = await dao.albums.get_album(album_id)
        assert album_fetched is not None
        assert len(album_fetched.generation_ids) == 2
        assert gen1_id in album_fetched.generation_ids
        assert gen2_id in album_fetched.generation_ids
        print("✅ Verified generations in album")

        # 6. Fetch generations by IDs via GenerationRepo
        generations = await dao.generations.get_generations_by_ids([gen1_id, gen2_id])
        assert len(generations) == 2

        # Ensure ID type match (str vs ObjectId handling in repo)
        gen_ids_fetched = [g.id for g in generations]
        assert gen1_id in gen_ids_fetched
        assert gen2_id in gen_ids_fetched
        print("✅ Verified fetching generations by IDs")

        # 7. Remove generation — album should retain only the other one.
        print("Removing generation...")
        await dao.albums.remove_generation(album_id, gen1_id)
        album_fetched = await dao.albums.get_album(album_id)
        assert len(album_fetched.generation_ids) == 1
        assert album_fetched.generation_ids[0] == gen2_id
        print("✅ Verified removing generation from album")

        print("🎉 Album Verification SUCCESS")

    finally:
        # Cleanup client — always close the Mongo connection, pass or fail.
        client.close()


if __name__ == "__main__":
    from dotenv import load_dotenv
    load_dotenv()
    try:
        asyncio.run(test_albums())
    except Exception as e:
        # NOTE(review): printing only str(e) hides the traceback; consider
        # traceback.print_exc() when debugging failures.
        print(f"Error: {e}")
|
||||||
84
tests/verify_minio_integration.py
Normal file
84
tests/verify_minio_integration.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
|
||||||
|
from models.Asset import Asset, AssetType
|
||||||
|
from repos.assets_repo import AssetsRepo
|
||||||
|
from adapters.s3_adapter import S3Adapter
|
||||||
|
|
||||||
|
# Load env to get credentials
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
async def test_integration():
    """End-to-end check of AssetsRepo + MinIO: create an asset, verify the
    stored object-name metadata, round-trip data and thumbnail, then delete.

    Prints ✅/❌ markers instead of raising, so a single failure doesn't stop
    the remaining checks.
    """
    print("🚀 Starting integration test...")

    # 1. Setup Dependencies
    mongo_uri = os.getenv("MONGO_HOST", "mongodb://localhost:27017")
    client = AsyncIOMotorClient(mongo_uri)
    db_name = os.getenv("DB_NAME", "bot_db_test")

    # NOTE(review): real-looking credential defaults; prefer failing fast when
    # the env vars are unset rather than shipping fallbacks.
    s3_adapter = S3Adapter(
        endpoint_url=os.getenv("MINIO_ENDPOINT", "http://localhost:9000"),
        aws_access_key_id=os.getenv("MINIO_ACCESS_KEY", "admin"),
        aws_secret_access_key=os.getenv("MINIO_SECRET_KEY", "SuperSecretPassword123!"),
        bucket_name=os.getenv("MINIO_BUCKET", "ai-char")
    )

    repo = AssetsRepo(client, s3_adapter, db_name=db_name)

    # 2. Create Asset with Data and Thumbnail
    print("📝 Creating asset...")
    dummy_data = b"image_data_bytes"
    dummy_thumb = b"thumbnail_bytes"

    asset = Asset(
        name="test_asset_with_thumb.png",
        type=AssetType.IMAGE,
        data=dummy_data,
        thumbnail=dummy_thumb
    )

    asset_id = await repo.create_asset(asset)
    print(f"✅ Asset created with ID: {asset_id}")

    # 3. Verify object names in Mongo (Raw check)
    print("🔍 Verifying Mongo metadata...")
    # Used repo to fetch is better
    # with_data=False: metadata only, no S3 round-trip yet.
    fetched_asset = await repo.get_asset(asset_id, with_data=False)

    if fetched_asset.minio_object_name:
        print(f"✅ minio_object_name set: {fetched_asset.minio_object_name}")
    else:
        print("❌ minio_object_name NOT set!")

    if fetched_asset.minio_thumbnail_object_name:
        print(f"✅ minio_thumbnail_object_name set: {fetched_asset.minio_thumbnail_object_name}")
    else:
        print("❌ minio_thumbnail_object_name NOT set!")

    # 4. Fetch Data from S3 via Repo
    print("📥 Fetching data from MinIO...")
    full_asset = await repo.get_asset(asset_id, with_data=True)

    if full_asset.data == dummy_data:
        print("✅ Data matches!")
    else:
        print(f"❌ Data mismatch! Got: {full_asset.data}")

    if full_asset.thumbnail == dummy_thumb:
        print("✅ Thumbnail matches!")
    else:
        print(f"❌ Thumbnail mismatch! Got: {full_asset.thumbnail}")

    # 5. Clean up
    print("🧹 Cleaning up...")
    deleted = await repo.delete_asset(asset_id)
    if deleted:
        print("✅ Asset deleted")
    else:
        print("❌ Failed to delete asset")


if __name__ == "__main__":
    asyncio.run(test_integration())
|
||||||
46
utils/external_auth.py
Normal file
46
utils/external_auth.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
import hmac
|
||||||
|
import hashlib
|
||||||
|
import os
|
||||||
|
from fastapi import Header, HTTPException
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
def verify_signature(body: bytes, signature: str, secret: str) -> bool:
    """
    Validate an HMAC-SHA256 signature over a raw request body.

    Args:
        body: Raw request body bytes
        signature: Hex digest taken from the X-Signature header
        secret: Shared secret key

    Returns:
        True when the signature is authentic, False otherwise
    """
    computed = hmac.new(secret.encode('utf-8'), body, hashlib.sha256).hexdigest()
    # compare_digest runs in constant time, defeating timing side-channels.
    return hmac.compare_digest(signature, computed)
|
||||||
|
|
||||||
|
|
||||||
|
async def verify_external_signature(
    x_signature: Optional[str] = Header(None, alias="X-Signature")
):
    """
    FastAPI dependency for external endpoints: require an X-Signature header.

    Only the header's presence is checked here — the endpoint itself performs
    the HMAC verification, since that needs access to the raw request body.

    Raises:
        HTTPException: 401 when the header is missing
    """
    if not x_signature:
        raise HTTPException(status_code=401, detail="Missing X-Signature header")
    return x_signature
|
||||||
27
utils/image_utils.py
Normal file
27
utils/image_utils.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
from io import BytesIO
|
||||||
|
from typing import Tuple, Optional
|
||||||
|
from PIL import Image
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def create_thumbnail(image_data: bytes, size: Tuple[int, int] = (800, 800)) -> Optional[bytes]:
    """
    Downscale an image (aspect ratio preserved) and re-encode it as JPEG.

    Returns the thumbnail as bytes, or None when decoding/encoding fails.
    """
    try:
        with Image.open(BytesIO(image_data)) as img:
            # JPEG cannot store alpha or palette data, so normalize first.
            if img.mode in ('RGBA', 'P'):
                img = img.convert('RGB')
            img.thumbnail(size)
            output = BytesIO()
            img.save(output, format='JPEG', quality=85)
            return output.getvalue()
    except Exception as e:
        logger.error(f"Failed to create thumbnail: {e}")
        return None
|
||||||
35
utils/security.py
Normal file
35
utils/security.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Optional, Union, Any
|
||||||
|
|
||||||
|
from jose import jwt
|
||||||
|
from passlib.context import CryptContext
|
||||||
|
|
||||||
|
# Security settings (better moved to config/env, but kept here to start)
# SECRET_KEY must be complex and kept secret in production!
# NOTE(review): this hard-coded placeholder signs every JWT issued by
# create_access_token below — load it from an environment variable before
# any real deployment.
SECRET_KEY = "CHANGE_ME_TO_A_SUPER_SECRET_KEY"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30 * 24 * 60  # e.g. 30 days

# Argon2-backed password hashing context shared by the helpers below.
pwd_context = CryptContext(schemes=["argon2"], deprecated="auto")
|
||||||
|
|
||||||
|
|
||||||
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True when *plain_password* matches the stored argon2 hash."""
    return pwd_context.verify(plain_password, hashed_password)
|
||||||
|
|
||||||
|
|
||||||
|
def get_password_hash(password: str) -> str:
    """Hash *password* with the module's configured scheme (argon2)."""
    return pwd_context.hash(password)
|
||||||
|
|
||||||
|
|
||||||
|
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """
    Create a signed JWT containing *data* plus an expiry claim.

    Args:
        data: Claims to embed; copied, so the caller's dict is not mutated.
        expires_delta: Token lifetime; defaults to 15 minutes when falsy.

    Returns:
        The encoded JWT string, signed with SECRET_KEY using ALGORITHM.
    """
    # Local import: the module header only brings in datetime/timedelta.
    from datetime import timezone

    to_encode = data.copy()
    # datetime.utcnow() is naive and deprecated since Python 3.12; use an
    # aware UTC timestamp — jose converts it to the same 'exp' epoch value.
    now = datetime.now(timezone.utc)
    expire = now + (expires_delta if expires_delta else timedelta(minutes=15))

    # Standard JWT 'exp' field
    to_encode.update({"exp": expire})

    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||||
Reference in New Issue
Block a user