Compare commits
14 Commits
enviroment
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| e011805186 | |||
| d9caececd7 | |||
| c1300b7a2d | |||
| f6001f5994 | |||
| e4a39e90c3 | |||
| e976fe1c58 | |||
| ecc8d69039 | |||
| bc9230a49b | |||
| f07105b0e5 | |||
| 9a5d54a373 | |||
| 1868864f76 | |||
| 9e0c522b5f | |||
| e1d941a2cd | |||
| c7c27197c9 |
33
.context.md
Normal file
33
.context.md
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
# Project Context: AI Char Bot
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
Python backend project using FastAPI and MongoDB (Motor).
|
||||||
|
Root: `/Users/xds/develop/py projects/ai-char-bot`
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
- **API Layer**: `api/endpoints` (FastAPI routers).
|
||||||
|
- **Service Layer**: `api/service` (Business logic).
|
||||||
|
- **Data Layer**: `repos` (DAOs and Repositories).
|
||||||
|
- **Models**: `models` (Domain models) and `api/models` (Request/Response DTOs).
|
||||||
|
- **Adapters**: `adapters` (External services like S3, Google Gemini).
|
||||||
|
|
||||||
|
## Coding Standards & Preferences
|
||||||
|
- **Type Hinting**: Use `Type | None` instead of `Optional[Type]` (Python 3.10+ style).
|
||||||
|
- **Async/Await**: Extensive use of `asyncio` and asynchronous DB drivers.
|
||||||
|
- **Error Handling**:
|
||||||
|
- Repositories should return `None` if an entity is not found (e.g., `toggle_like`).
|
||||||
|
- Services/Routers handle `HTTPException`.
|
||||||
|
|
||||||
|
## Key Features & Implementation Details
|
||||||
|
- **Generations**:
|
||||||
|
- Managed by `GenerationService` and `GenerationRepo`.
|
||||||
|
- `toggle_like` returns `bool | None` (True=Liked, False=Unliked, None=Not Found).
|
||||||
|
- `get_generations` requires `current_user_id` to correctly calculate `is_liked`.
|
||||||
|
- **Ideas**:
|
||||||
|
- Managed by `IdeaService` and `IdeaRepo`.
|
||||||
|
- Can have linked generations.
|
||||||
|
- When fetching generations for an idea, ensure `current_user_id` is passed to `GenerationService`.
|
||||||
|
|
||||||
|
## Recent Changes
|
||||||
|
- Refactored `toggle_like` to handle non-existent generations and return `bool | None`.
|
||||||
|
- Updated `IdeaRouter` to pass `current_user_id` when fetching generations to ensure `is_liked` flag is correct.
|
||||||
33
.gemini/AGENTS.md
Normal file
33
.gemini/AGENTS.md
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
# Project Context: AI Char Bot
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
Python backend project using FastAPI and MongoDB (Motor).
|
||||||
|
Root: `/Users/xds/develop/py projects/ai-char-bot`
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
- **API Layer**: `api/endpoints` (FastAPI routers).
|
||||||
|
- **Service Layer**: `api/service` (Business logic).
|
||||||
|
- **Data Layer**: `repos` (DAOs and Repositories).
|
||||||
|
- **Models**: `models` (Domain models) and `api/models` (Request/Response DTOs).
|
||||||
|
- **Adapters**: `adapters` (External services like S3, Google Gemini).
|
||||||
|
|
||||||
|
## Coding Standards & Preferences
|
||||||
|
- **Type Hinting**: Use `Type | None` instead of `Optional[Type]` (Python 3.10+ style).
|
||||||
|
- **Async/Await**: Extensive use of `asyncio` and asynchronous DB drivers.
|
||||||
|
- **Error Handling**:
|
||||||
|
- Repositories should return `None` if an entity is not found (e.g., `toggle_like`).
|
||||||
|
- Services/Routers handle `HTTPException`.
|
||||||
|
|
||||||
|
## Key Features & Implementation Details
|
||||||
|
- **Generations**:
|
||||||
|
- Managed by `GenerationService` and `GenerationRepo`.
|
||||||
|
- `toggle_like` returns `bool | None` (True=Liked, False=Unliked, None=Not Found).
|
||||||
|
- `get_generations` requires `current_user_id` to correctly calculate `is_liked`.
|
||||||
|
- **Ideas**:
|
||||||
|
- Managed by `IdeaService` and `IdeaRepo`.
|
||||||
|
- Can have linked generations.
|
||||||
|
- When fetching generations for an idea, ensure `current_user_id` is passed to `GenerationService`.
|
||||||
|
|
||||||
|
## Recent Changes
|
||||||
|
- Refactored `toggle_like` to handle non-existent generations and return `bool | None`.
|
||||||
|
- Updated `IdeaRouter` to pass `current_user_id` when fetching generations to ensure `is_liked` flag is correct.
|
||||||
@@ -8,7 +8,7 @@ from google import genai
|
|||||||
from google.genai import types
|
from google.genai import types
|
||||||
|
|
||||||
from adapters.Exception import GoogleGenerationException
|
from adapters.Exception import GoogleGenerationException
|
||||||
from models.enums import AspectRatios, Quality
|
from models.enums import AspectRatios, Quality, TextModel, ImageModel
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -19,10 +19,6 @@ class GoogleAdapter:
|
|||||||
raise ValueError("API Key for Gemini is missing")
|
raise ValueError("API Key for Gemini is missing")
|
||||||
self.client = genai.Client(api_key=api_key)
|
self.client = genai.Client(api_key=api_key)
|
||||||
|
|
||||||
# Константы моделей
|
|
||||||
self.TEXT_MODEL = "gemini-3-pro-preview"
|
|
||||||
self.IMAGE_MODEL = "gemini-3-pro-image-preview"
|
|
||||||
|
|
||||||
def _prepare_contents(self, prompt: str, images_list: List[bytes] | None = None) -> tuple:
|
def _prepare_contents(self, prompt: str, images_list: List[bytes] | None = None) -> tuple:
|
||||||
"""Вспомогательный метод для подготовки контента (текст + картинки).
|
"""Вспомогательный метод для подготовки контента (текст + картинки).
|
||||||
Returns (contents, opened_images) — caller MUST close opened_images after use."""
|
Returns (contents, opened_images) — caller MUST close opened_images after use."""
|
||||||
@@ -41,16 +37,19 @@ class GoogleAdapter:
|
|||||||
logger.info("Preparing content with no images")
|
logger.info("Preparing content with no images")
|
||||||
return contents, opened_images
|
return contents, opened_images
|
||||||
|
|
||||||
def generate_text(self, prompt: str, images_list: List[bytes] | None = None) -> str:
|
def generate_text(self, prompt: str, model: str = "gemini-3.1-pro-preview", images_list: List[bytes] | None = None) -> str:
|
||||||
"""
|
"""
|
||||||
Генерация текста (Чат или Vision).
|
Генерация текста (Чат или Vision).
|
||||||
Возвращает строку с ответом.
|
Возвращает строку с ответом.
|
||||||
"""
|
"""
|
||||||
|
if model not in [m.value for m in TextModel]:
|
||||||
|
raise ValueError(f"Invalid model for text generation: {model}. Expected one of: {[m.value for m in TextModel]}")
|
||||||
|
|
||||||
contents, opened_images = self._prepare_contents(prompt, images_list)
|
contents, opened_images = self._prepare_contents(prompt, images_list)
|
||||||
logger.info(f"Generating text: {prompt}")
|
logger.info(f"Generating text: {prompt} with model: {model}")
|
||||||
try:
|
try:
|
||||||
response = self.client.models.generate_content(
|
response = self.client.models.generate_content(
|
||||||
model=self.TEXT_MODEL,
|
model=model,
|
||||||
contents=contents,
|
contents=contents,
|
||||||
config=types.GenerateContentConfig(
|
config=types.GenerateContentConfig(
|
||||||
response_modalities=['TEXT'],
|
response_modalities=['TEXT'],
|
||||||
@@ -74,21 +73,23 @@ class GoogleAdapter:
|
|||||||
for img in opened_images:
|
for img in opened_images:
|
||||||
img.close()
|
img.close()
|
||||||
|
|
||||||
def generate_image(self, prompt: str, aspect_ratio: AspectRatios, quality: Quality, images_list: List[bytes] | None = None, ) -> Tuple[List[io.BytesIO], Dict[str, Any]]:
|
def generate_image(self, prompt: str, aspect_ratio: AspectRatios, quality: Quality, model: str = "gemini-3-pro-image-preview", images_list: List[bytes] | None = None, ) -> Tuple[List[io.BytesIO], Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Генерация изображений (Text-to-Image или Image-to-Image).
|
Генерация изображений (Text-to-Image или Image-to-Image).
|
||||||
Возвращает список байтовых потоков (готовых к отправке).
|
Возвращает список байтовых потоков (готовых к отправке).
|
||||||
"""
|
"""
|
||||||
|
if model not in [m.value for m in ImageModel]:
|
||||||
|
raise ValueError(f"Invalid model for image generation: {model}. Expected one of: {[m.value for m in ImageModel]}")
|
||||||
|
|
||||||
contents, opened_images = self._prepare_contents(prompt, images_list)
|
contents, opened_images = self._prepare_contents(prompt, images_list)
|
||||||
logger.info(f"Generating image. Prompt length: {len(prompt)}, Ratio: {aspect_ratio}, Quality: {quality}")
|
logger.info(f"Generating image. Prompt length: {len(prompt)}, Ratio: {aspect_ratio}, Quality: {quality}, Model: {model}")
|
||||||
|
|
||||||
start_time = datetime.now()
|
start_time = datetime.now()
|
||||||
token_usage = 0
|
token_usage = 0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = self.client.models.generate_content(
|
response = self.client.models.generate_content(
|
||||||
model=self.IMAGE_MODEL,
|
model=model,
|
||||||
contents=contents,
|
contents=contents,
|
||||||
config=types.GenerateContentConfig(
|
config=types.GenerateContentConfig(
|
||||||
response_modalities=['IMAGE'],
|
response_modalities=['IMAGE'],
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import Optional, BinaryIO
|
from typing import Optional, BinaryIO, AsyncGenerator
|
||||||
import aioboto3
|
import aioboto3
|
||||||
from botocore.exceptions import ClientError
|
from botocore.exceptions import ClientError
|
||||||
import os
|
import os
|
||||||
@@ -56,11 +56,25 @@ class S3Adapter:
|
|||||||
print(f"Error downloading from S3: {e}")
|
print(f"Error downloading from S3: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def stream_file(self, object_name: str, chunk_size: int = 65536):
|
async def get_file_size(self, object_name: str) -> Optional[int]:
|
||||||
|
"""Returns the size of the file in bytes."""
|
||||||
|
try:
|
||||||
|
async with self._get_client() as client:
|
||||||
|
response = await client.head_object(Bucket=self.bucket_name, Key=object_name)
|
||||||
|
return response['ContentLength']
|
||||||
|
except ClientError as e:
|
||||||
|
print(f"Error getting file size from S3: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def stream_file(self, object_name: str, range_header: Optional[str] = None, chunk_size: int = 65536) -> AsyncGenerator[bytes, None]:
|
||||||
"""Streams a file from S3 yielding chunks. Memory-efficient for large files."""
|
"""Streams a file from S3 yielding chunks. Memory-efficient for large files."""
|
||||||
try:
|
try:
|
||||||
async with self._get_client() as client:
|
async with self._get_client() as client:
|
||||||
response = await client.get_object(Bucket=self.bucket_name, Key=object_name)
|
args = {'Bucket': self.bucket_name, 'Key': object_name}
|
||||||
|
if range_header:
|
||||||
|
args['Range'] = range_header
|
||||||
|
|
||||||
|
response = await client.get_object(**args)
|
||||||
# aioboto3 Body is an aiohttp StreamReader wrapper
|
# aioboto3 Body is an aiohttp StreamReader wrapper
|
||||||
body = response['Body']
|
body = response['Body']
|
||||||
|
|
||||||
|
|||||||
4
aiws.py
4
aiws.py
@@ -45,6 +45,7 @@ from api.endpoints.project_router import router as project_api_router
|
|||||||
from api.endpoints.idea_router import router as idea_api_router
|
from api.endpoints.idea_router import router as idea_api_router
|
||||||
from api.endpoints.post_router import router as post_api_router
|
from api.endpoints.post_router import router as post_api_router
|
||||||
from api.endpoints.environment_router import router as environment_api_router
|
from api.endpoints.environment_router import router as environment_api_router
|
||||||
|
from api.endpoints.inspiration_router import router as inspiration_api_router
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -133,7 +134,7 @@ async def start_scheduler(service: GenerationService):
|
|||||||
try:
|
try:
|
||||||
logger.info("Running scheduler for stacked generation killing")
|
logger.info("Running scheduler for stacked generation killing")
|
||||||
await service.cleanup_stale_generations()
|
await service.cleanup_stale_generations()
|
||||||
await service.cleanup_old_data(days=2)
|
await service.cleanup_old_data(days=14)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -222,6 +223,7 @@ app.include_router(project_api_router)
|
|||||||
app.include_router(idea_api_router)
|
app.include_router(idea_api_router)
|
||||||
app.include_router(post_api_router)
|
app.include_router(post_api_router)
|
||||||
app.include_router(environment_api_router)
|
app.include_router(environment_api_router)
|
||||||
|
app.include_router(inspiration_api_router)
|
||||||
|
|
||||||
# Prometheus Metrics (Instrument after all routers are added)
|
# Prometheus Metrics (Instrument after all routers are added)
|
||||||
Instrumentator(
|
Instrumentator(
|
||||||
|
|||||||
@@ -63,3 +63,8 @@ from api.service.post_service import PostService
|
|||||||
|
|
||||||
def get_post_service(dao: DAO = Depends(get_dao)) -> PostService:
|
def get_post_service(dao: DAO = Depends(get_dao)) -> PostService:
|
||||||
return PostService(dao)
|
return PostService(dao)
|
||||||
|
|
||||||
|
from api.service.inspiration_service import InspirationService
|
||||||
|
|
||||||
|
def get_inspiration_service(dao: DAO = Depends(get_dao), s3_adapter: S3Adapter = Depends(get_s3_adapter)) -> InspirationService:
|
||||||
|
return InspirationService(dao, s3_adapter)
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Annotated, List
|
from typing import Annotated
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
@@ -54,7 +54,7 @@ class UserResponse(BaseModel):
|
|||||||
class Config:
|
class Config:
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
||||||
@router.get("/approvals", response_model=List[UserResponse])
|
@router.get("/approvals", response_model=list[UserResponse])
|
||||||
async def list_pending_users(
|
async def list_pending_users(
|
||||||
admin: Annotated[dict, Depends(get_current_admin)],
|
admin: Annotated[dict, Depends(get_current_admin)],
|
||||||
repo: Annotated[UsersRepo, Depends(get_users_repo)]
|
repo: Annotated[UsersRepo, Depends(get_users_repo)]
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
from typing import List, Optional
|
|
||||||
from fastapi import APIRouter, HTTPException, status, Request
|
from fastapi import APIRouter, HTTPException, status, Request
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@@ -13,18 +12,18 @@ router = APIRouter(prefix="/api/albums", tags=["Albums"])
|
|||||||
|
|
||||||
class AlbumCreateRequest(BaseModel):
|
class AlbumCreateRequest(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
description: Optional[str] = None
|
description: str | None = None
|
||||||
|
|
||||||
class AlbumUpdateRequest(BaseModel):
|
class AlbumUpdateRequest(BaseModel):
|
||||||
name: Optional[str] = None
|
name: str | None = None
|
||||||
description: Optional[str] = None
|
description: str | None = None
|
||||||
|
|
||||||
class AlbumResponse(BaseModel):
|
class AlbumResponse(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
description: Optional[str] = None
|
description: str | None = None
|
||||||
generation_ids: List[str] = []
|
generation_ids: list[str] = []
|
||||||
cover_asset_id: Optional[str] = None # Not implemented yet
|
cover_asset_id: str | None = None # Not implemented yet
|
||||||
|
|
||||||
@router.post("", response_model=AlbumResponse)
|
@router.post("", response_model=AlbumResponse)
|
||||||
async def create_album(request: Request, album_in: AlbumCreateRequest):
|
async def create_album(request: Request, album_in: AlbumCreateRequest):
|
||||||
@@ -32,7 +31,7 @@ async def create_album(request: Request, album_in: AlbumCreateRequest):
|
|||||||
album = await service.create_album(name=album_in.name, description=album_in.description)
|
album = await service.create_album(name=album_in.name, description=album_in.description)
|
||||||
return AlbumResponse(**album.model_dump())
|
return AlbumResponse(**album.model_dump())
|
||||||
|
|
||||||
@router.get("", response_model=List[AlbumResponse])
|
@router.get("", response_model=list[AlbumResponse])
|
||||||
async def get_albums(request: Request, limit: int = 10, offset: int = 0):
|
async def get_albums(request: Request, limit: int = 10, offset: int = 0):
|
||||||
service: AlbumService = request.app.state.album_service
|
service: AlbumService = request.app.state.album_service
|
||||||
albums = await service.get_albums(limit=limit, offset=offset)
|
albums = await service.get_albums(limit=limit, offset=offset)
|
||||||
@@ -77,7 +76,7 @@ async def remove_generation_from_album(request: Request, album_id: str, generati
|
|||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album or Generation not found")
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album or Generation not found")
|
||||||
return {"status": "success"}
|
return {"status": "success"}
|
||||||
|
|
||||||
@router.get("/{album_id}/generations", response_model=List[GenerationResponse])
|
@router.get("/{album_id}/generations", response_model=list[GenerationResponse])
|
||||||
async def get_album_generations(request: Request, album_id: str, limit: int = 10, offset: int = 0):
|
async def get_album_generations(request: Request, album_id: str, limit: int = 10, offset: int = 0):
|
||||||
service: AlbumService = request.app.state.album_service
|
service: AlbumService = request.app.state.album_service
|
||||||
generations = await service.get_generations_by_album(album_id, limit=limit, offset=offset)
|
generations = await service.get_generations_by_album(album_id, limit=limit, offset=offset)
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import List, Optional, Dict, Any
|
from typing import Any
|
||||||
|
|
||||||
from aiogram.types import BufferedInputFile
|
from aiogram.types import BufferedInputFile
|
||||||
from bson import ObjectId
|
from bson import ObjectId
|
||||||
@@ -42,8 +42,9 @@ async def get_asset(
|
|||||||
if not asset:
|
if not asset:
|
||||||
raise HTTPException(status_code=404, detail="Asset not found")
|
raise HTTPException(status_code=404, detail="Asset not found")
|
||||||
|
|
||||||
headers = {
|
base_headers = {
|
||||||
"Cache-Control": "public, max-age=31536000, immutable"
|
"Cache-Control": "public, max-age=31536000, immutable",
|
||||||
|
"Accept-Ranges": "bytes"
|
||||||
}
|
}
|
||||||
|
|
||||||
# Thumbnail: маленький, можно грузить в RAM
|
# Thumbnail: маленький, можно грузить в RAM
|
||||||
@@ -51,17 +52,70 @@ async def get_asset(
|
|||||||
if asset.minio_thumbnail_object_name and s3_adapter:
|
if asset.minio_thumbnail_object_name and s3_adapter:
|
||||||
thumb_bytes = await s3_adapter.get_file(asset.minio_thumbnail_object_name)
|
thumb_bytes = await s3_adapter.get_file(asset.minio_thumbnail_object_name)
|
||||||
if thumb_bytes:
|
if thumb_bytes:
|
||||||
return Response(content=thumb_bytes, media_type="image/jpeg", headers=headers)
|
return Response(content=thumb_bytes, media_type="image/jpeg", headers=base_headers)
|
||||||
# Fallback: thumbnail in DB
|
# Fallback: thumbnail in DB
|
||||||
if asset.thumbnail:
|
if asset.thumbnail:
|
||||||
return Response(content=asset.thumbnail, media_type="image/jpeg", headers=headers)
|
return Response(content=asset.thumbnail, media_type="image/jpeg", headers=base_headers)
|
||||||
# No thumbnail available — fall through to main content
|
# No thumbnail available — fall through to main content
|
||||||
|
|
||||||
# Main content: стримим из S3 без загрузки в RAM
|
# Main content: стримим из S3 без загрузки в RAM
|
||||||
if asset.minio_object_name and s3_adapter:
|
if asset.minio_object_name and s3_adapter:
|
||||||
content_type = "image/png"
|
content_type = "image/png"
|
||||||
# if asset.content_type == AssetContentType.VIDEO:
|
if asset.content_type == AssetContentType.VIDEO:
|
||||||
# content_type = "video/mp4"
|
content_type = "video/mp4" # Or detect from extension if stored
|
||||||
|
elif asset.content_type == AssetContentType.IMAGE:
|
||||||
|
content_type = "image/png" # Default for images
|
||||||
|
|
||||||
|
# Better content type detection based on extension if possible, but for now this is okay
|
||||||
|
if asset.minio_object_name.endswith(".mp4"):
|
||||||
|
content_type = "video/mp4"
|
||||||
|
elif asset.minio_object_name.endswith(".mov"):
|
||||||
|
content_type = "video/quicktime"
|
||||||
|
elif asset.minio_object_name.endswith(".jpg") or asset.minio_object_name.endswith(".jpeg"):
|
||||||
|
content_type = "image/jpeg"
|
||||||
|
|
||||||
|
# Handle Range requests for video streaming
|
||||||
|
range_header = request.headers.get("range")
|
||||||
|
file_size = await s3_adapter.get_file_size(asset.minio_object_name)
|
||||||
|
|
||||||
|
if range_header and file_size:
|
||||||
|
try:
|
||||||
|
# Parse Range header: bytes=start-end
|
||||||
|
byte_range = range_header.replace("bytes=", "")
|
||||||
|
start_str, end_str = byte_range.split("-")
|
||||||
|
start = int(start_str)
|
||||||
|
end = int(end_str) if end_str else file_size - 1
|
||||||
|
|
||||||
|
# Validate range
|
||||||
|
if start >= file_size:
|
||||||
|
# 416 Range Not Satisfiable
|
||||||
|
return Response(status_code=416, headers={"Content-Range": f"bytes */{file_size}"})
|
||||||
|
|
||||||
|
chunk_size = end - start + 1
|
||||||
|
|
||||||
|
headers = base_headers.copy()
|
||||||
|
headers.update({
|
||||||
|
"Content-Range": f"bytes {start}-{end}/{file_size}",
|
||||||
|
"Content-Length": str(chunk_size),
|
||||||
|
})
|
||||||
|
|
||||||
|
# Pass the exact range string to S3
|
||||||
|
s3_range = f"bytes={start}-{end}"
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
s3_adapter.stream_file(asset.minio_object_name, range_header=s3_range),
|
||||||
|
status_code=206,
|
||||||
|
headers=headers,
|
||||||
|
media_type=content_type
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
pass # Fallback to full content if range parsing fails
|
||||||
|
|
||||||
|
# Full content response
|
||||||
|
headers = base_headers.copy()
|
||||||
|
if file_size:
|
||||||
|
headers["Content-Length"] = str(file_size)
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
s3_adapter.stream_file(asset.minio_object_name),
|
s3_adapter.stream_file(asset.minio_object_name),
|
||||||
media_type=content_type,
|
media_type=content_type,
|
||||||
@@ -70,7 +124,7 @@ async def get_asset(
|
|||||||
|
|
||||||
# Fallback: data stored in DB (legacy)
|
# Fallback: data stored in DB (legacy)
|
||||||
if asset.data:
|
if asset.data:
|
||||||
return Response(content=asset.data, media_type="image/png", headers=headers)
|
return Response(content=asset.data, media_type="image/png", headers=base_headers)
|
||||||
|
|
||||||
raise HTTPException(status_code=404, detail="Asset data not found")
|
raise HTTPException(status_code=404, detail="Asset data not found")
|
||||||
|
|
||||||
@@ -81,22 +135,22 @@ async def delete_orphan_assets_from_minio(
|
|||||||
*,
|
*,
|
||||||
assets_collection: str = "assets",
|
assets_collection: str = "assets",
|
||||||
generations_collection: str = "generations",
|
generations_collection: str = "generations",
|
||||||
asset_type: Optional[str] = "generated",
|
asset_type: str | None = "generated",
|
||||||
project_id: Optional[str] = None,
|
project_id: str | None = None,
|
||||||
dry_run: bool = True,
|
dry_run: bool = True,
|
||||||
mark_assets_deleted: bool = False,
|
mark_assets_deleted: bool = False,
|
||||||
batch_size: int = 500,
|
batch_size: int = 500,
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
db = mongo['bot_db'] # БД уже выбрана в get_mongo_client
|
db = mongo['bot_db'] # БД уже выбрана в get_mongo_client
|
||||||
assets = db[assets_collection]
|
assets = db[assets_collection]
|
||||||
|
|
||||||
match_assets: Dict[str, Any] = {}
|
match_assets: dict[str, Any] = {}
|
||||||
if asset_type is not None:
|
if asset_type is not None:
|
||||||
match_assets["type"] = asset_type
|
match_assets["type"] = asset_type
|
||||||
if project_id is not None:
|
if project_id is not None:
|
||||||
match_assets["project_id"] = project_id
|
match_assets["project_id"] = project_id
|
||||||
|
|
||||||
pipeline: List[Dict[str, Any]] = [
|
pipeline: list[dict[str, Any]] = [
|
||||||
{"$match": match_assets} if match_assets else {"$match": {}},
|
{"$match": match_assets} if match_assets else {"$match": {}},
|
||||||
{
|
{
|
||||||
"$lookup": {
|
"$lookup": {
|
||||||
@@ -138,8 +192,8 @@ async def delete_orphan_assets_from_minio(
|
|||||||
|
|
||||||
deleted_objects = 0
|
deleted_objects = 0
|
||||||
deleted_assets = 0
|
deleted_assets = 0
|
||||||
errors: List[Dict[str, Any]] = []
|
errors: list[dict[str, Any]] = []
|
||||||
orphan_asset_ids: List[ObjectId] = []
|
orphan_asset_ids: list[ObjectId] = []
|
||||||
|
|
||||||
async for asset in cursor:
|
async for asset in cursor:
|
||||||
aid = asset["_id"]
|
aid = asset["_id"]
|
||||||
@@ -205,7 +259,7 @@ async def delete_asset(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("", dependencies=[Depends(get_current_user)])
|
@router.get("", dependencies=[Depends(get_current_user)])
|
||||||
async def get_assets(request: Request, dao: DAO = Depends(get_dao), type: Optional[str] = None, limit: int = 10, offset: int = 0, current_user: dict = Depends(get_current_user), project_id: Optional[str] = Depends(get_project_id)) -> AssetsResponse:
|
async def get_assets(request: Request, dao: DAO = Depends(get_dao), type: str | None = None, limit: int = 10, offset: int = 0, current_user: dict = Depends(get_current_user), project_id: str | None = Depends(get_project_id)) -> AssetsResponse:
|
||||||
logger.info(f"get_assets called. Limit: {limit}, Offset: {offset}")
|
logger.info(f"get_assets called. Limit: {limit}, Offset: {offset}")
|
||||||
|
|
||||||
user_id_filter = current_user["id"]
|
user_id_filter = current_user["id"]
|
||||||
@@ -232,10 +286,10 @@ async def get_assets(request: Request, dao: DAO = Depends(get_dao), type: Option
|
|||||||
@router.post("/upload", response_model=AssetResponse, status_code=status.HTTP_201_CREATED)
|
@router.post("/upload", response_model=AssetResponse, status_code=status.HTTP_201_CREATED)
|
||||||
async def upload_asset(
|
async def upload_asset(
|
||||||
file: UploadFile = File(...),
|
file: UploadFile = File(...),
|
||||||
linked_char_id: Optional[str] = Form(None),
|
linked_char_id: str | None = Form(None),
|
||||||
dao: DAO = Depends(get_dao),
|
dao: DAO = Depends(get_dao),
|
||||||
current_user: dict = Depends(get_current_user),
|
current_user: dict = Depends(get_current_user),
|
||||||
project_id: Optional[str] = Depends(get_project_id)
|
project_id: str | None = Depends(get_project_id)
|
||||||
):
|
):
|
||||||
logger.info(f"upload_asset called. Filename: {file.filename}, ContentType: {file.content_type}, LinkedCharId: {linked_char_id}")
|
logger.info(f"upload_asset called. Filename: {file.filename}, ContentType: {file.content_type}, LinkedCharId: {linked_char_id}")
|
||||||
if not file.content_type:
|
if not file.content_type:
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import List, Any, Coroutine, Optional
|
from typing import Any, Coroutine
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -23,15 +23,15 @@ from api.dependency import get_project_id
|
|||||||
router = APIRouter(prefix="/api/characters", tags=["Characters"], dependencies=[Depends(get_current_user)])
|
router = APIRouter(prefix="/api/characters", tags=["Characters"], dependencies=[Depends(get_current_user)])
|
||||||
|
|
||||||
|
|
||||||
@router.get("/", response_model=List[Character])
|
@router.get("/", response_model=list[Character])
|
||||||
async def get_characters(
|
async def get_characters(
|
||||||
request: Request,
|
request: Request,
|
||||||
dao: DAO = Depends(get_dao),
|
dao: DAO = Depends(get_dao),
|
||||||
current_user: dict = Depends(get_current_user),
|
current_user: dict = Depends(get_current_user),
|
||||||
project_id: Optional[str] = Depends(get_project_id),
|
project_id: str | None = Depends(get_project_id),
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
offset: int = 0
|
offset: int = 0
|
||||||
) -> List[Character]:
|
) -> list[Character]:
|
||||||
logger.info(f"get_characters called. Limit: {limit}, Offset: {offset}")
|
logger.info(f"get_characters called. Limit: {limit}, Offset: {offset}")
|
||||||
|
|
||||||
user_id_filter = str(current_user["_id"])
|
user_id_filter = str(current_user["_id"])
|
||||||
@@ -102,7 +102,7 @@ async def get_character_by_id(character_id: str, request: Request, dao: DAO = De
|
|||||||
@router.post("/", response_model=Character)
|
@router.post("/", response_model=Character)
|
||||||
async def create_character(
|
async def create_character(
|
||||||
char_req: CharacterCreateRequest,
|
char_req: CharacterCreateRequest,
|
||||||
project_id: Optional[str] = Depends(get_project_id),
|
project_id: str | None = Depends(get_project_id),
|
||||||
dao: DAO = Depends(get_dao),
|
dao: DAO = Depends(get_dao),
|
||||||
current_user: dict = Depends(get_current_user)
|
current_user: dict = Depends(get_current_user)
|
||||||
) -> Character:
|
) -> Character:
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import List, Optional
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from starlette import status
|
from starlette import status
|
||||||
|
|
||||||
@@ -50,7 +49,7 @@ async def create_environment(
|
|||||||
return created_env
|
return created_env
|
||||||
|
|
||||||
|
|
||||||
@router.get("/character/{character_id}", response_model=List[Environment])
|
@router.get("/character/{character_id}", response_model=list[Environment])
|
||||||
async def get_character_environments(
|
async def get_character_environments(
|
||||||
character_id: str,
|
character_id: str,
|
||||||
dao: DAO = Depends(get_dao),
|
dao: DAO = Depends(get_dao),
|
||||||
@@ -92,6 +91,18 @@ async def update_environment(
|
|||||||
if not update_data:
|
if not update_data:
|
||||||
return env
|
return env
|
||||||
|
|
||||||
|
# Verify assets exist if provided
|
||||||
|
if "asset_ids" in update_data:
|
||||||
|
if update_data["asset_ids"] is None:
|
||||||
|
del update_data["asset_ids"]
|
||||||
|
elif update_data["asset_ids"]:
|
||||||
|
# Verify all assets exist using batch check
|
||||||
|
assets = await dao.assets.get_assets_by_ids(update_data["asset_ids"])
|
||||||
|
if len(assets) != len(update_data["asset_ids"]):
|
||||||
|
found_ids = {a.id for a in assets}
|
||||||
|
missing_ids = [aid for aid in update_data["asset_ids"] if aid not in found_ids]
|
||||||
|
raise HTTPException(status_code=400, detail=f"Some assets not found: {missing_ids}")
|
||||||
|
|
||||||
success = await dao.environments.update_env(env_id, update_data)
|
success = await dao.environments.update_env(env_id, update_data)
|
||||||
if not success:
|
if not success:
|
||||||
raise HTTPException(status_code=500, detail="Failed to update environment")
|
raise HTTPException(status_code=500, detail="Failed to update environment")
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import json
|
import json
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from fastapi import APIRouter, UploadFile, File, Form, Header, HTTPException
|
from fastapi import APIRouter, UploadFile, File, Form, Header, HTTPException
|
||||||
from fastapi.params import Depends
|
from fastapi.params import Depends
|
||||||
@@ -19,7 +17,8 @@ from api.models import (
|
|||||||
PromptRequest,
|
PromptRequest,
|
||||||
GenerationGroupResponse,
|
GenerationGroupResponse,
|
||||||
FinancialReport,
|
FinancialReport,
|
||||||
ExternalGenerationRequest
|
ExternalGenerationRequest,
|
||||||
|
NsfwRequest
|
||||||
)
|
)
|
||||||
from api.service.generation_service import GenerationService
|
from api.service.generation_service import GenerationService
|
||||||
from repos.dao import DAO
|
from repos.dao import DAO
|
||||||
@@ -30,85 +29,88 @@ logger = logging.getLogger(__name__)
|
|||||||
router = APIRouter(prefix='/api/generations', tags=["Generation"])
|
router = APIRouter(prefix='/api/generations', tags=["Generation"])
|
||||||
|
|
||||||
|
|
||||||
|
async def check_project_access(project_id: str | None, current_user: dict, dao: DAO):
|
||||||
|
"""Helper to check if user has access to project."""
|
||||||
|
if not project_id:
|
||||||
|
return
|
||||||
|
project = await dao.projects.get_project(project_id)
|
||||||
|
if not project or str(current_user["_id"]) not in project.members:
|
||||||
|
raise HTTPException(status_code=403, detail="Project access denied")
|
||||||
|
|
||||||
|
|
||||||
@router.post("/prompt-assistant", response_model=PromptResponse)
|
@router.post("/prompt-assistant", response_model=PromptResponse)
|
||||||
async def ask_prompt_assistant(prompt_request: PromptRequest, request: Request,
|
async def ask_prompt_assistant(
|
||||||
generation_service: GenerationService = Depends(
|
prompt_request: PromptRequest,
|
||||||
get_generation_service),
|
generation_service: GenerationService = Depends(get_generation_service),
|
||||||
current_user: dict = Depends(get_current_user)) -> PromptResponse:
|
current_user: dict = Depends(get_current_user)
|
||||||
logger.info(f"ask_prompt_assistant called with prompt length: {len(prompt_request.prompt)}. Linked assets: {len(prompt_request.linked_assets) if prompt_request.linked_assets else 0}")
|
) -> PromptResponse:
|
||||||
generated_prompt = await generation_service.ask_prompt_assistant(prompt_request.prompt, prompt_request.linked_assets)
|
logger.info(f"ask_prompt_assistant: {len(prompt_request.prompt)} chars")
|
||||||
|
generated_prompt = await generation_service.ask_prompt_assistant(
|
||||||
|
prompt_request.prompt,
|
||||||
|
prompt_request.linked_assets,
|
||||||
|
prompt_request.model
|
||||||
|
)
|
||||||
return PromptResponse(prompt=generated_prompt)
|
return PromptResponse(prompt=generated_prompt)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/prompt-from-image", response_model=PromptResponse)
|
@router.post("/prompt-from-image", response_model=PromptResponse)
|
||||||
async def prompt_from_image(
|
async def prompt_from_image(
|
||||||
prompt: Optional[str] = Form(None),
|
prompt: str | None = Form(None),
|
||||||
images: List[UploadFile] = File(...),
|
model: str = Form("gemini-3.1-pro-preview"),
|
||||||
|
images: list[UploadFile] = File(...),
|
||||||
generation_service: GenerationService = Depends(get_generation_service),
|
generation_service: GenerationService = Depends(get_generation_service),
|
||||||
current_user: dict = Depends(get_current_user)
|
current_user: dict = Depends(get_current_user)
|
||||||
) -> PromptResponse:
|
) -> PromptResponse:
|
||||||
logger.info(f"prompt_from_image called. Images count: {len(images)}. Prompt provided: {bool(prompt)}")
|
images_bytes = [await img.read() for img in images]
|
||||||
images_bytes = []
|
generated_prompt = await generation_service.generate_prompt_from_images(images_bytes, prompt, model)
|
||||||
for image in images:
|
|
||||||
content = await image.read()
|
|
||||||
images_bytes.append(content)
|
|
||||||
|
|
||||||
generated_prompt = await generation_service.generate_prompt_from_images(images_bytes, prompt)
|
|
||||||
return PromptResponse(prompt=generated_prompt)
|
return PromptResponse(prompt=generated_prompt)
|
||||||
|
|
||||||
|
|
||||||
@router.get("", response_model=GenerationsResponse)
|
@router.get("", response_model=GenerationsResponse)
|
||||||
async def get_generations(character_id: Optional[str] = None, limit: int = 10, offset: int = 0,
|
async def get_generations(
|
||||||
generation_service: GenerationService = Depends(get_generation_service),
|
character_id: str | None = None,
|
||||||
current_user: dict = Depends(get_current_user),
|
limit: int = 10,
|
||||||
project_id: Optional[str] = Depends(get_project_id),
|
offset: int = 0,
|
||||||
dao: DAO = Depends(get_dao)):
|
only_liked: bool = False,
|
||||||
logger.info(f"get_generations called. CharacterId: {character_id}, Limit: {limit}, Offset: {offset}")
|
generation_service: GenerationService = Depends(get_generation_service),
|
||||||
|
current_user: dict = Depends(get_current_user),
|
||||||
|
project_id: str | None = Depends(get_project_id),
|
||||||
|
dao: DAO = Depends(get_dao)
|
||||||
|
):
|
||||||
|
await check_project_access(project_id, current_user, dao)
|
||||||
|
|
||||||
user_id_filter = str(current_user["_id"])
|
# If project_id is set, we don't filter by user to show all project-wide generations
|
||||||
if project_id:
|
created_by_filter = None if project_id else str(current_user["_id"])
|
||||||
project = await dao.projects.get_project(project_id)
|
only_liked_by = str(current_user["_id"]) if only_liked else None
|
||||||
if not project or str(current_user["_id"]) not in project.members:
|
|
||||||
raise HTTPException(status_code=403, detail="Project access denied")
|
|
||||||
user_id_filter = None # Show all project generations
|
|
||||||
|
|
||||||
return await generation_service.get_generations(character_id, limit=limit, offset=offset, user_id=user_id_filter, project_id=project_id)
|
return await generation_service.get_generations(
|
||||||
|
character_id=character_id,
|
||||||
|
limit=limit,
|
||||||
|
offset=offset,
|
||||||
|
created_by=created_by_filter,
|
||||||
|
project_id=project_id,
|
||||||
|
only_liked_by=only_liked_by,
|
||||||
|
current_user_id=str(current_user["_id"])
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/usage", response_model=FinancialReport)
|
@router.get("/usage", response_model=FinancialReport)
|
||||||
async def get_usage_report(
|
async def get_usage_report(
|
||||||
breakdown: Optional[str] = None, # "user" or "project"
|
breakdown: str | None = None, # "user" or "project"
|
||||||
generation_service: GenerationService = Depends(get_generation_service),
|
generation_service: GenerationService = Depends(get_generation_service),
|
||||||
current_user: dict = Depends(get_current_user),
|
current_user: dict = Depends(get_current_user),
|
||||||
project_id: Optional[str] = Depends(get_project_id),
|
project_id: str | None = Depends(get_project_id),
|
||||||
dao: DAO = Depends(get_dao)
|
dao: DAO = Depends(get_dao)
|
||||||
) -> FinancialReport:
|
) -> FinancialReport:
|
||||||
"""
|
await check_project_access(project_id, current_user, dao)
|
||||||
Returns usage statistics (runs, tokens, cost) for the current user or project.
|
|
||||||
If project_id is provided, returns stats for that project.
|
user_id_filter = str(current_user["_id"]) if not project_id else None
|
||||||
Otherwise, returns stats for the current user.
|
|
||||||
"""
|
|
||||||
user_id_filter = str(current_user["_id"])
|
|
||||||
breakdown_by = None
|
breakdown_by = None
|
||||||
|
|
||||||
if project_id:
|
if breakdown == "user":
|
||||||
# Permission check
|
breakdown_by = "created_by"
|
||||||
project = await dao.projects.get_project(project_id)
|
elif breakdown == "project":
|
||||||
if not project or str(current_user["_id"]) not in project.members:
|
breakdown_by = "project_id"
|
||||||
raise HTTPException(status_code=403, detail="Project access denied")
|
|
||||||
user_id_filter = None # If we are in project, we see stats for the WHOLE project by default
|
|
||||||
if breakdown == "user":
|
|
||||||
breakdown_by = "created_by"
|
|
||||||
elif breakdown == "project":
|
|
||||||
breakdown_by = "project_id"
|
|
||||||
else:
|
|
||||||
# Default: Stats for current user
|
|
||||||
if breakdown == "project":
|
|
||||||
breakdown_by = "project_id"
|
|
||||||
elif breakdown == "user":
|
|
||||||
# This would breakdown personal usage by user (yourself), but could be useful if it included collaborators?
|
|
||||||
# No, if project_id is None, it's personal.
|
|
||||||
breakdown_by = "created_by"
|
|
||||||
|
|
||||||
return await generation_service.get_financial_report(
|
return await generation_service.get_financial_report(
|
||||||
user_id=user_id_filter,
|
user_id=user_id_filter,
|
||||||
@@ -116,58 +118,61 @@ async def get_usage_report(
|
|||||||
breakdown_by=breakdown_by
|
breakdown_by=breakdown_by
|
||||||
)
|
)
|
||||||
|
|
||||||
@router.post("/_run", response_model=GenerationGroupResponse)
|
|
||||||
async def post_generation(generation: GenerationRequest, request: Request,
|
|
||||||
generation_service: GenerationService = Depends(get_generation_service),
|
|
||||||
current_user: dict = Depends(get_current_user),
|
|
||||||
project_id: Optional[str] = Depends(get_project_id),
|
|
||||||
dao: DAO = Depends(get_dao)) -> GenerationGroupResponse:
|
|
||||||
logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}")
|
|
||||||
|
|
||||||
|
@router.post("/_run", response_model=GenerationGroupResponse)
|
||||||
|
async def post_generation(
|
||||||
|
generation: GenerationRequest,
|
||||||
|
generation_service: GenerationService = Depends(get_generation_service),
|
||||||
|
current_user: dict = Depends(get_current_user),
|
||||||
|
project_id: str | None = Depends(get_project_id),
|
||||||
|
dao: DAO = Depends(get_dao)
|
||||||
|
) -> GenerationGroupResponse:
|
||||||
|
await check_project_access(project_id, current_user, dao)
|
||||||
if project_id:
|
if project_id:
|
||||||
project = await dao.projects.get_project(project_id)
|
|
||||||
if not project or str(current_user["_id"]) not in project.members:
|
|
||||||
raise HTTPException(status_code=403, detail="Project access denied")
|
|
||||||
generation.project_id = project_id
|
generation.project_id = project_id
|
||||||
|
|
||||||
return await generation_service.create_generation_task(generation, user_id=str(current_user.get("_id")))
|
return await generation_service.create_generation_task(
|
||||||
|
generation,
|
||||||
|
user_id=str(current_user.get("_id"))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/running")
|
@router.get("/running")
|
||||||
async def get_running_generations(request: Request,
|
async def get_running_generations(
|
||||||
generation_service: GenerationService = Depends(get_generation_service),
|
generation_service: GenerationService = Depends(get_generation_service),
|
||||||
current_user: dict = Depends(get_current_user),
|
current_user: dict = Depends(get_current_user),
|
||||||
project_id: Optional[str] = Depends(get_project_id),
|
project_id: str | None = Depends(get_project_id),
|
||||||
dao: DAO = Depends(get_dao)):
|
dao: DAO = Depends(get_dao)
|
||||||
|
):
|
||||||
|
await check_project_access(project_id, current_user, dao)
|
||||||
|
user_id_filter = None if project_id else str(current_user["_id"])
|
||||||
|
|
||||||
user_id_filter = str(current_user["_id"])
|
return await generation_service.get_running_generations(
|
||||||
if project_id:
|
user_id=user_id_filter,
|
||||||
project = await dao.projects.get_project(project_id)
|
project_id=project_id
|
||||||
if not project or str(current_user["_id"]) not in project.members:
|
)
|
||||||
raise HTTPException(status_code=403, detail="Project access denied")
|
|
||||||
user_id_filter = None
|
|
||||||
|
|
||||||
return await generation_service.get_running_generations(user_id=user_id_filter, project_id=project_id)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/group/{group_id}", response_model=GenerationGroupResponse)
|
@router.get("/group/{group_id}", response_model=GenerationGroupResponse)
|
||||||
async def get_generation_group(group_id: str,
|
async def get_generation_group(
|
||||||
generation_service: GenerationService = Depends(get_generation_service),
|
group_id: str,
|
||||||
current_user: dict = Depends(get_current_user)):
|
generation_service: GenerationService = Depends(get_generation_service),
|
||||||
logger.info(f"get_generation_group called for group_id: {group_id}")
|
current_user: dict = Depends(get_current_user)
|
||||||
generations = await generation_service.dao.generations.get_generations_by_group(group_id)
|
):
|
||||||
gen_responses = [GenerationResponse(**gen.model_dump()) for gen in generations]
|
return await generation_service.get_generations_by_group(group_id, current_user_id=str(current_user["_id"]))
|
||||||
return GenerationGroupResponse(generation_group_id=group_id, generations=gen_responses)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{generation_id}", response_model=GenerationResponse)
|
@router.get("/{generation_id}", response_model=GenerationResponse)
|
||||||
async def get_generation(generation_id: str,
|
async def get_generation(
|
||||||
generation_service: GenerationService = Depends(get_generation_service),
|
generation_id: str,
|
||||||
current_user: dict = Depends(get_current_user)) -> GenerationResponse:
|
generation_service: GenerationService = Depends(get_generation_service),
|
||||||
logger.debug(f"get_generation called for ID: {generation_id}")
|
current_user: dict = Depends(get_current_user)
|
||||||
gen = await generation_service.get_generation(generation_id)
|
) -> GenerationResponse:
|
||||||
if gen and gen.created_by != str(current_user["_id"]):
|
gen = await generation_service.get_generation(generation_id, current_user_id=str(current_user["_id"]))
|
||||||
|
if not gen:
|
||||||
|
raise HTTPException(status_code=404, detail="Generation not found")
|
||||||
|
|
||||||
|
if gen.created_by != str(current_user["_id"]):
|
||||||
# Check project membership
|
# Check project membership
|
||||||
is_member = False
|
is_member = False
|
||||||
if gen.project_id:
|
if gen.project_id:
|
||||||
@@ -180,6 +185,41 @@ async def get_generation(generation_id: str,
|
|||||||
return gen
|
return gen
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{generation_id}/like", response_model=dict)
|
||||||
|
async def toggle_like(
|
||||||
|
generation_id: str,
|
||||||
|
generation_service: GenerationService = Depends(get_generation_service),
|
||||||
|
current_user: dict = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
is_liked = await generation_service.toggle_like(generation_id, str(current_user["_id"]))
|
||||||
|
if is_liked is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Generation not found")
|
||||||
|
return {"is_liked": is_liked}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{generation_id}/nsfw", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
async def mark_generation_nsfw(
|
||||||
|
generation_id: str,
|
||||||
|
request: NsfwRequest,
|
||||||
|
generation_service: GenerationService = Depends(get_generation_service),
|
||||||
|
current_user: dict = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
gen = await generation_service.get_generation(generation_id, current_user_id=str(current_user["_id"]))
|
||||||
|
if not gen:
|
||||||
|
raise HTTPException(status_code=404, detail="Generation not found")
|
||||||
|
|
||||||
|
if gen.created_by != str(current_user["_id"]):
|
||||||
|
is_member = False
|
||||||
|
if gen.project_id:
|
||||||
|
project = await generation_service.dao.projects.get_project(gen.project_id)
|
||||||
|
if project and str(current_user["_id"]) in project.members:
|
||||||
|
is_member = True
|
||||||
|
|
||||||
|
if not is_member:
|
||||||
|
raise HTTPException(status_code=403, detail="Access denied")
|
||||||
|
|
||||||
|
await generation_service.dao.generations.mark_nsfw(generation_id, request.is_nsfw)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
@router.post("/import", response_model=GenerationResponse)
|
@router.post("/import", response_model=GenerationResponse)
|
||||||
@@ -188,35 +228,18 @@ async def import_external_generation(
|
|||||||
generation_service: GenerationService = Depends(get_generation_service),
|
generation_service: GenerationService = Depends(get_generation_service),
|
||||||
x_signature: str = Header(..., alias="X-Signature")
|
x_signature: str = Header(..., alias="X-Signature")
|
||||||
) -> GenerationResponse:
|
) -> GenerationResponse:
|
||||||
"""
|
|
||||||
Import a generation from an external source.
|
|
||||||
Requires server-to-server authentication via HMAC signature.
|
|
||||||
"""
|
|
||||||
|
|
||||||
logger.info("import_external_generation called")
|
|
||||||
# Get raw request body for signature verification
|
|
||||||
body = await request.body()
|
body = await request.body()
|
||||||
|
|
||||||
# Verify signature
|
|
||||||
secret = settings.EXTERNAL_API_SECRET
|
secret = settings.EXTERNAL_API_SECRET
|
||||||
if not secret:
|
if not secret:
|
||||||
logger.error("EXTERNAL_API_SECRET not configured")
|
|
||||||
raise HTTPException(status_code=500, detail="Server configuration error")
|
raise HTTPException(status_code=500, detail="Server configuration error")
|
||||||
|
|
||||||
if not verify_signature(body, x_signature, secret):
|
if not verify_signature(body, x_signature, secret):
|
||||||
logger.warning("Invalid signature for external generation import")
|
|
||||||
raise HTTPException(status_code=401, detail="Invalid signature")
|
raise HTTPException(status_code=401, detail="Invalid signature")
|
||||||
|
|
||||||
# Parse request body
|
|
||||||
try:
|
try:
|
||||||
data = json.loads(body.decode('utf-8'))
|
data = json.loads(body.decode('utf-8'))
|
||||||
external_gen = ExternalGenerationRequest(**data)
|
external_gen = ExternalGenerationRequest(**data)
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to parse request body: {e}")
|
|
||||||
raise HTTPException(status_code=400, detail=f"Invalid request body: {str(e)}")
|
|
||||||
|
|
||||||
# Import generation
|
|
||||||
try:
|
|
||||||
generation = await generation_service.import_external_generation(external_gen)
|
generation = await generation_service.import_external_generation(external_gen)
|
||||||
return GenerationResponse(**generation.model_dump())
|
return GenerationResponse(**generation.model_dump())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -225,11 +248,11 @@ async def import_external_generation(
|
|||||||
|
|
||||||
|
|
||||||
@router.delete("/{generation_id}", status_code=status.HTTP_204_NO_CONTENT)
|
@router.delete("/{generation_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
async def delete_generation(generation_id: str,
|
async def delete_generation(
|
||||||
generation_service: GenerationService = Depends(get_generation_service),
|
generation_id: str,
|
||||||
current_user: dict = Depends(get_current_user)):
|
generation_service: GenerationService = Depends(get_generation_service),
|
||||||
logger.info(f"delete_generation called for ID: {generation_id}")
|
current_user: dict = Depends(get_current_user)
|
||||||
deleted = await generation_service.delete_generation(generation_id)
|
):
|
||||||
if not deleted:
|
if not await generation_service.delete_generation(generation_id):
|
||||||
raise HTTPException(status_code=404, detail="Generation not found")
|
raise HTTPException(status_code=404, detail="Generation not found")
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
from typing import List, Optional
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, Body
|
from fastapi import APIRouter, Depends, HTTPException, Query, Body
|
||||||
from api.dependency import get_idea_service, get_project_id, get_generation_service
|
from api.dependency import get_idea_service, get_project_id, get_generation_service
|
||||||
from api.endpoints.auth import get_current_user
|
from api.endpoints.auth import get_current_user
|
||||||
@@ -14,17 +13,23 @@ router = APIRouter(prefix="/api/ideas", tags=["ideas"])
|
|||||||
@router.post("", response_model=Idea)
|
@router.post("", response_model=Idea)
|
||||||
async def create_idea(
|
async def create_idea(
|
||||||
request: IdeaCreateRequest,
|
request: IdeaCreateRequest,
|
||||||
project_id: Optional[str] = Depends(get_project_id),
|
project_id: str | None = Depends(get_project_id),
|
||||||
current_user: dict = Depends(get_current_user),
|
current_user: dict = Depends(get_current_user),
|
||||||
idea_service: IdeaService = Depends(get_idea_service)
|
idea_service: IdeaService = Depends(get_idea_service)
|
||||||
):
|
):
|
||||||
pid = project_id or request.project_id
|
pid = project_id or request.project_id
|
||||||
|
|
||||||
return await idea_service.create_idea(request.name, request.description, pid, str(current_user["_id"]))
|
return await idea_service.create_idea(
|
||||||
|
name=request.name,
|
||||||
|
description=request.description,
|
||||||
|
project_id=pid,
|
||||||
|
user_id=str(current_user["_id"]),
|
||||||
|
inspiration_id=request.inspiration_id
|
||||||
|
)
|
||||||
|
|
||||||
@router.get("", response_model=List[IdeaResponse])
|
@router.get("", response_model=list[IdeaResponse])
|
||||||
async def get_ideas(
|
async def get_ideas(
|
||||||
project_id: Optional[str] = Depends(get_project_id),
|
project_id: str | None = Depends(get_project_id),
|
||||||
limit: int = 20,
|
limit: int = 20,
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
current_user: dict = Depends(get_current_user),
|
current_user: dict = Depends(get_current_user),
|
||||||
@@ -48,7 +53,12 @@ async def update_idea(
|
|||||||
request: IdeaUpdateRequest,
|
request: IdeaUpdateRequest,
|
||||||
idea_service: IdeaService = Depends(get_idea_service)
|
idea_service: IdeaService = Depends(get_idea_service)
|
||||||
):
|
):
|
||||||
idea = await idea_service.update_idea(idea_id, request.name, request.description)
|
idea = await idea_service.update_idea(
|
||||||
|
idea_id=idea_id,
|
||||||
|
name=request.name,
|
||||||
|
description=request.description,
|
||||||
|
inspiration_id=request.inspiration_id
|
||||||
|
)
|
||||||
if not idea:
|
if not idea:
|
||||||
raise HTTPException(status_code=404, detail="Idea not found")
|
raise HTTPException(status_code=404, detail="Idea not found")
|
||||||
return idea
|
return idea
|
||||||
@@ -68,18 +78,10 @@ async def get_idea_generations(
|
|||||||
idea_id: str,
|
idea_id: str,
|
||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
generation_service: GenerationService = Depends(get_generation_service)
|
generation_service: GenerationService = Depends(get_generation_service),
|
||||||
|
current_user: dict = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
# Depending on how generation service implements filtering by idea_id.
|
return await generation_service.get_generations(idea_id=idea_id, limit=limit, offset=offset, current_user_id=str(current_user["_id"]))
|
||||||
# We might need to update generation_service to support getting by idea_id directly
|
|
||||||
# or ensure generic get_generations supports it.
|
|
||||||
# Looking at generation_router.py, get_generations doesn't have idea_id arg?
|
|
||||||
# Let's check generation_service.get_generations signature again.
|
|
||||||
# It has: (character_id, limit, offset, user_id, project_id). NO IDEA_ID.
|
|
||||||
# I need to update GenerationService.get_generations too!
|
|
||||||
|
|
||||||
# For now, let's assume I will update it.
|
|
||||||
return await generation_service.get_generations(idea_id=idea_id, limit=limit, offset=offset)
|
|
||||||
|
|
||||||
@router.post("/{idea_id}/generations/{generation_id}")
|
@router.post("/{idea_id}/generations/{generation_id}")
|
||||||
async def add_generation_to_idea(
|
async def add_generation_to_idea(
|
||||||
|
|||||||
94
api/endpoints/inspiration_router.py
Normal file
94
api/endpoints/inspiration_router.py
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from api.dependency import get_inspiration_service, get_project_id
|
||||||
|
from api.endpoints.auth import get_current_user
|
||||||
|
from api.models.InspirationRequest import InspirationCreateRequest, InspirationResponse, InspirationListResponse
|
||||||
|
from api.service.inspiration_service import InspirationService
|
||||||
|
from models.Inspiration import Inspiration
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/inspirations", tags=["Inspirations"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("", response_model=InspirationResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def create_inspiration(
|
||||||
|
request: InspirationCreateRequest,
|
||||||
|
project_id: str | None = Depends(get_project_id),
|
||||||
|
current_user: dict = Depends(get_current_user),
|
||||||
|
service: InspirationService = Depends(get_inspiration_service)
|
||||||
|
):
|
||||||
|
pid = project_id or request.project_id
|
||||||
|
|
||||||
|
inspiration = await service.create_inspiration(
|
||||||
|
source_url=request.source_url,
|
||||||
|
created_by=str(current_user["_id"]),
|
||||||
|
project_id=pid,
|
||||||
|
caption=request.caption
|
||||||
|
)
|
||||||
|
return inspiration
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", response_model=InspirationListResponse)
|
||||||
|
async def get_inspirations(
|
||||||
|
project_id: str | None = Depends(get_project_id),
|
||||||
|
limit: int = 20,
|
||||||
|
offset: int = 0,
|
||||||
|
current_user: dict = Depends(get_current_user),
|
||||||
|
service: InspirationService = Depends(get_inspiration_service)
|
||||||
|
):
|
||||||
|
# If project_id is provided, filter by it. Otherwise, filter by user.
|
||||||
|
# Or maybe we want to see all user's inspirations if no project is selected?
|
||||||
|
# Let's follow the pattern: if project_id is present, show project's inspirations.
|
||||||
|
# If not, show user's personal inspirations (where project_id is None) OR all user's inspirations?
|
||||||
|
# Usually "My Inspirations" means created by me.
|
||||||
|
|
||||||
|
# Let's assume:
|
||||||
|
# If project_id -> filter by project_id (and maybe created_by if we want strict ownership, but usually project members share)
|
||||||
|
# If no project_id -> filter by created_by (personal)
|
||||||
|
|
||||||
|
pid = project_id
|
||||||
|
uid = str(current_user["_id"])
|
||||||
|
|
||||||
|
inspirations = await service.get_inspirations(project_id=pid, created_by=uid if not pid else None, limit=limit, offset=offset)
|
||||||
|
total_count = await service.dao.inspirations.count_inspirations(project_id=pid, created_by=uid if not pid else None)
|
||||||
|
|
||||||
|
return InspirationListResponse(
|
||||||
|
inspirations=[InspirationResponse(**inspiration.model_dump()) for inspiration in inspirations],
|
||||||
|
total_count=total_count
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{inspiration_id}", response_model=InspirationResponse)
|
||||||
|
async def get_inspiration(
|
||||||
|
inspiration_id: str,
|
||||||
|
service: InspirationService = Depends(get_inspiration_service),
|
||||||
|
current_user: dict = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
inspiration = await service.get_inspiration(inspiration_id)
|
||||||
|
if not inspiration:
|
||||||
|
raise HTTPException(status_code=404, detail="Inspiration not found")
|
||||||
|
return inspiration
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch("/{inspiration_id}/complete", response_model=InspirationResponse)
|
||||||
|
async def mark_inspiration_complete(
|
||||||
|
inspiration_id: str,
|
||||||
|
is_completed: bool = True,
|
||||||
|
service: InspirationService = Depends(get_inspiration_service),
|
||||||
|
current_user: dict = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
inspiration = await service.mark_as_completed(inspiration_id, is_completed)
|
||||||
|
if not inspiration:
|
||||||
|
raise HTTPException(status_code=404, detail="Inspiration not found")
|
||||||
|
return inspiration
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{inspiration_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
async def delete_inspiration(
|
||||||
|
inspiration_id: str,
|
||||||
|
service: InspirationService = Depends(get_inspiration_service),
|
||||||
|
current_user: dict = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
success = await service.delete_inspiration(inspiration_id)
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(status_code=404, detail="Inspiration not found")
|
||||||
|
return None
|
||||||
@@ -1,4 +1,3 @@
|
|||||||
from typing import List, Optional
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
|
||||||
@@ -14,7 +13,7 @@ router = APIRouter(prefix="/api/posts", tags=["posts"])
|
|||||||
@router.post("", response_model=Post)
|
@router.post("", response_model=Post)
|
||||||
async def create_post(
|
async def create_post(
|
||||||
request: PostCreateRequest,
|
request: PostCreateRequest,
|
||||||
project_id: Optional[str] = Depends(get_project_id),
|
project_id: str | None = Depends(get_project_id),
|
||||||
current_user: dict = Depends(get_current_user),
|
current_user: dict = Depends(get_current_user),
|
||||||
post_service: PostService = Depends(get_post_service),
|
post_service: PostService = Depends(get_post_service),
|
||||||
):
|
):
|
||||||
@@ -28,13 +27,13 @@ async def create_post(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("", response_model=List[Post])
|
@router.get("", response_model=list[Post])
|
||||||
async def get_posts(
|
async def get_posts(
|
||||||
project_id: Optional[str] = Depends(get_project_id),
|
project_id: str | None = Depends(get_project_id),
|
||||||
limit: int = 200,
|
limit: int = 200,
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
date_from: Optional[datetime] = None,
|
date_from: datetime | None = None,
|
||||||
date_to: Optional[datetime] = None,
|
date_to: datetime | None = None,
|
||||||
current_user: dict = Depends(get_current_user),
|
current_user: dict = Depends(get_current_user),
|
||||||
post_service: PostService = Depends(get_post_service),
|
post_service: PostService = Depends(get_post_service),
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from bson import ObjectId
|
from bson import ObjectId
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
@@ -12,7 +11,7 @@ router = APIRouter(prefix="/api/projects", tags=["Projects"])
|
|||||||
|
|
||||||
class ProjectCreate(BaseModel):
|
class ProjectCreate(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
description: Optional[str] = None
|
description: str | None = None
|
||||||
|
|
||||||
class ProjectMemberResponse(BaseModel):
|
class ProjectMemberResponse(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
@@ -21,9 +20,9 @@ class ProjectMemberResponse(BaseModel):
|
|||||||
class ProjectResponse(BaseModel):
|
class ProjectResponse(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
description: Optional[str] = None
|
description: str | None = None
|
||||||
owner_id: str
|
owner_id: str
|
||||||
members: List[ProjectMemberResponse]
|
members: list[ProjectMemberResponse]
|
||||||
is_owner: bool = False
|
is_owner: bool = False
|
||||||
|
|
||||||
async def _get_project_response(project: Project, current_user_id: str, dao: DAO) -> ProjectResponse:
|
async def _get_project_response(project: Project, current_user_id: str, dao: DAO) -> ProjectResponse:
|
||||||
@@ -78,7 +77,7 @@ async def create_project(
|
|||||||
|
|
||||||
return await _get_project_response(new_project, user_id, dao)
|
return await _get_project_response(new_project, user_id, dao)
|
||||||
|
|
||||||
@router.get("", response_model=List[ProjectResponse])
|
@router.get("", response_model=list[ProjectResponse])
|
||||||
async def get_my_projects(
|
async def get_my_projects(
|
||||||
dao: DAO = Depends(get_dao),
|
dao: DAO = Depends(get_dao),
|
||||||
current_user: dict = Depends(get_current_user)
|
current_user: dict = Depends(get_current_user)
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@@ -11,10 +10,10 @@ class AssetResponse(BaseModel):
|
|||||||
name: str
|
name: str
|
||||||
type: str # uploaded / generated
|
type: str # uploaded / generated
|
||||||
content_type: str # image / prompt
|
content_type: str # image / prompt
|
||||||
linked_char_id: Optional[str] = None
|
linked_char_id: str | None = None
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
url: Optional[str] = None
|
url: str | None = None
|
||||||
|
|
||||||
class AssetsResponse(BaseModel):
|
class AssetsResponse(BaseModel):
|
||||||
assets: List[AssetResponse]
|
assets: list[AssetResponse]
|
||||||
total_count: int
|
total_count: int
|
||||||
@@ -1,18 +1,17 @@
|
|||||||
from typing import Optional
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
class CharacterCreateRequest(BaseModel):
|
class CharacterCreateRequest(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
character_bio: str
|
character_bio: str
|
||||||
character_image_doc_tg_id: Optional[str] = None
|
character_image_doc_tg_id: str | None = None
|
||||||
avatar_image: Optional[str] = None
|
avatar_image: str | None = None
|
||||||
character_image_tg_id: Optional[str] = None
|
character_image_tg_id: str | None = None
|
||||||
project_id: Optional[str] = None
|
project_id: str | None = None
|
||||||
|
|
||||||
class CharacterUpdateRequest(BaseModel):
|
class CharacterUpdateRequest(BaseModel):
|
||||||
name: Optional[str] = None
|
name: str | None = None
|
||||||
character_bio: Optional[str] = None
|
character_bio: str | None = None
|
||||||
character_image_doc_tg_id: Optional[str] = None
|
character_image_doc_tg_id: str | None = None
|
||||||
avatar_image: Optional[str] = None
|
avatar_image: str | None = None
|
||||||
character_image_tg_id: Optional[str] = None
|
character_image_tg_id: str | None = None
|
||||||
project_id: Optional[str] = None
|
project_id: str | None = None
|
||||||
|
|||||||
@@ -1,17 +1,17 @@
|
|||||||
from typing import Optional, List
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
class EnvironmentCreate(BaseModel):
|
class EnvironmentCreate(BaseModel):
|
||||||
character_id: str
|
character_id: str
|
||||||
name: str = Field(..., min_length=1)
|
name: str = Field(..., min_length=1)
|
||||||
description: Optional[str] = None
|
description: str | None = None
|
||||||
asset_ids: Optional[List[str]] = []
|
asset_ids: list[str] | None = []
|
||||||
|
|
||||||
|
|
||||||
class EnvironmentUpdate(BaseModel):
|
class EnvironmentUpdate(BaseModel):
|
||||||
name: Optional[str] = Field(None, min_length=1)
|
name: str | None = Field(None, min_length=1)
|
||||||
description: Optional[str] = None
|
description: str | None = None
|
||||||
|
asset_ids: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
class AssetToEnvironment(BaseModel):
|
class AssetToEnvironment(BaseModel):
|
||||||
@@ -19,4 +19,4 @@ class AssetToEnvironment(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class AssetsToEnvironment(BaseModel):
|
class AssetsToEnvironment(BaseModel):
|
||||||
asset_ids: List[str]
|
asset_ids: list[str]
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
from typing import Optional
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from models.enums import AspectRatios, Quality
|
from models.enums import AspectRatios, Quality
|
||||||
|
|
||||||
@@ -7,27 +6,31 @@ class ExternalGenerationRequest(BaseModel):
|
|||||||
"""Request model for importing external generations."""
|
"""Request model for importing external generations."""
|
||||||
|
|
||||||
prompt: str
|
prompt: str
|
||||||
tech_prompt: Optional[str] = None
|
tech_prompt: str | None = None
|
||||||
|
|
||||||
# Image can be provided as base64 string OR URL (one must be provided)
|
# Image can be provided as base64 string OR URL (one must be provided)
|
||||||
image_data: Optional[str] = Field(None, description="Base64-encoded image data")
|
image_data: str | None = Field(None, description="Base64-encoded image data")
|
||||||
image_url: Optional[str] = Field(None, description="URL to download image from")
|
image_url: str | None = Field(None, description="URL to download image from")
|
||||||
|
|
||||||
|
nsfw: bool = False
|
||||||
|
|
||||||
# Generation metadata
|
# Generation metadata
|
||||||
aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN
|
aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN # "1:1","2:3","3:2","3:4","4:3","4:5","5:4","9:16","16:9","21:9"
|
||||||
quality: Quality = Quality.ONEK
|
quality: Quality = Quality.ONEK
|
||||||
|
model: str | None = None
|
||||||
|
seed: int | None = None
|
||||||
|
|
||||||
# Optional linking
|
# Optional linking
|
||||||
linked_character_id: Optional[str] = None
|
linked_character_id: str | None = None
|
||||||
created_by: str = Field(..., description="User ID from external system")
|
created_by: str = Field(..., description="User ID from external system")
|
||||||
project_id: Optional[str] = None
|
project_id: str | None = None
|
||||||
|
|
||||||
# Performance metrics
|
# Performance metrics
|
||||||
execution_time_seconds: Optional[float] = None
|
execution_time_seconds: float | None = None
|
||||||
api_execution_time_seconds: Optional[float] = None
|
api_execution_time_seconds: float | None = None
|
||||||
token_usage: Optional[int] = None
|
token_usage: int | None = None
|
||||||
input_token_usage: Optional[int] = None
|
input_token_usage: int | None = None
|
||||||
output_token_usage: Optional[int] = None
|
output_token_usage: int | None = None
|
||||||
|
|
||||||
def validate_image_source(self):
|
def validate_image_source(self):
|
||||||
"""Ensure at least one image source is provided."""
|
"""Ensure at least one image source is provided."""
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
class UsageStats(BaseModel):
|
class UsageStats(BaseModel):
|
||||||
total_runs: int
|
total_runs: int
|
||||||
@@ -9,10 +8,10 @@ class UsageStats(BaseModel):
|
|||||||
total_cost: float
|
total_cost: float
|
||||||
|
|
||||||
class UsageByEntity(BaseModel):
|
class UsageByEntity(BaseModel):
|
||||||
entity_id: Optional[str] = None
|
entity_id: str | None = None
|
||||||
stats: UsageStats
|
stats: UsageStats
|
||||||
|
|
||||||
class FinancialReport(BaseModel):
|
class FinancialReport(BaseModel):
|
||||||
summary: UsageStats
|
summary: UsageStats
|
||||||
by_user: Optional[List[UsageByEntity]] = None
|
by_user: list[UsageByEntity] | None = None
|
||||||
by_project: Optional[List[UsageByEntity]] = None
|
by_project: list[UsageByEntity] | None = None
|
||||||
|
|||||||
@@ -1,67 +1,78 @@
|
|||||||
from datetime import datetime, UTC
|
from datetime import datetime, UTC
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from models.Asset import Asset
|
from models.Asset import Asset
|
||||||
from models.Generation import GenerationStatus
|
from models.Generation import GenerationStatus
|
||||||
from models.enums import AspectRatios, Quality, GenType
|
from models.enums import AspectRatios, Quality, GenType, ImageModel, TextModel
|
||||||
|
|
||||||
|
|
||||||
class GenerationRequest(BaseModel):
|
class GenerationRequest(BaseModel):
|
||||||
linked_character_id: Optional[str] = None
|
linked_character_id: str | None = None
|
||||||
aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN
|
aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN # "1:1","2:3","3:2","3:4","4:3","4:5","5:4","9:16","16:9","21:9"
|
||||||
quality: Quality = Quality.ONEK
|
quality: Quality = Quality.ONEK
|
||||||
prompt: str
|
prompt: str
|
||||||
telegram_id: Optional[int] = None
|
model: ImageModel = Field(default=ImageModel.GEMINI_3_PRO_IMAGE_PREVIEW)
|
||||||
|
telegram_id: int | None = None
|
||||||
use_profile_image: bool = True
|
use_profile_image: bool = True
|
||||||
assets_list: List[str]
|
assets_list: list[str]
|
||||||
environment_id: Optional[str] = None
|
environment_id: str | None = None
|
||||||
project_id: Optional[str] = None
|
project_id: str | None = None
|
||||||
idea_id: Optional[str] = None
|
idea_id: str | None = None
|
||||||
|
nsfw: bool = False
|
||||||
count: int = Field(default=1, ge=1, le=10)
|
count: int = Field(default=1, ge=1, le=10)
|
||||||
|
|
||||||
|
|
||||||
|
class NsfwRequest(BaseModel):
|
||||||
|
is_nsfw: bool
|
||||||
|
|
||||||
|
|
||||||
class GenerationsResponse(BaseModel):
|
class GenerationsResponse(BaseModel):
|
||||||
generations: List["GenerationResponse"]
|
generations: list["GenerationResponse"]
|
||||||
total_count: int
|
total_count: int
|
||||||
|
|
||||||
|
|
||||||
class GenerationResponse(BaseModel):
|
class GenerationResponse(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
status: GenerationStatus
|
status: GenerationStatus
|
||||||
failed_reason: Optional[str] = None
|
failed_reason: str | None = None
|
||||||
|
project_id: str | None = None
|
||||||
linked_character_id: Optional[str] = None
|
linked_character_id: str | None = None
|
||||||
aspect_ratio: AspectRatios
|
aspect_ratio: AspectRatios
|
||||||
quality: Quality
|
quality: Quality
|
||||||
prompt: str
|
prompt: str
|
||||||
tech_prompt: Optional[str] = None
|
model: ImageModel | None = None
|
||||||
assets_list: List[str]
|
seed: int | None = None
|
||||||
result_list: List[str] = []
|
tech_prompt: str | None = None
|
||||||
result: Optional[str] = None
|
assets_list: list[str]
|
||||||
execution_time_seconds: Optional[float] = None
|
result_list: list[str] = []
|
||||||
api_execution_time_seconds: Optional[float] = None
|
result: str | None = None
|
||||||
token_usage: Optional[int] = None
|
execution_time_seconds: float | None = None
|
||||||
input_token_usage: Optional[int] = None
|
api_execution_time_seconds: float | None = None
|
||||||
output_token_usage: Optional[int] = None
|
token_usage: int | None = None
|
||||||
|
input_token_usage: int | None = None
|
||||||
|
output_token_usage: int | None = None
|
||||||
progress: int = 0
|
progress: int = 0
|
||||||
cost: Optional[float] = None
|
cost: float | None = None
|
||||||
created_by: Optional[str] = None
|
created_by: str | None = None
|
||||||
generation_group_id: Optional[str] = None
|
generation_group_id: str | None = None
|
||||||
idea_id: Optional[str] = None
|
idea_id: str | None = None
|
||||||
|
likes_count: int = 0
|
||||||
|
is_liked: bool = False
|
||||||
|
nsfw: bool = False
|
||||||
created_at: datetime = datetime.now(UTC)
|
created_at: datetime = datetime.now(UTC)
|
||||||
updated_at: datetime = datetime.now(UTC)
|
updated_at: datetime = datetime.now(UTC)
|
||||||
|
|
||||||
|
|
||||||
class GenerationGroupResponse(BaseModel):
|
class GenerationGroupResponse(BaseModel):
|
||||||
generation_group_id: str
|
generation_group_id: str
|
||||||
generations: List[GenerationResponse]
|
generations: list[GenerationResponse]
|
||||||
|
|
||||||
|
|
||||||
class PromptRequest(BaseModel):
|
class PromptRequest(BaseModel):
|
||||||
prompt: str
|
prompt: str
|
||||||
linked_assets: List[str] = []
|
model: TextModel = Field(default=TextModel.GEMINI_3_1_PRO_PREVIEW)
|
||||||
|
linked_assets: list[str] = []
|
||||||
|
|
||||||
|
|
||||||
class PromptResponse(BaseModel):
|
class PromptResponse(BaseModel):
|
||||||
|
|||||||
@@ -1,16 +1,17 @@
|
|||||||
from typing import Optional
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from models.Idea import Idea
|
from models.Idea import Idea
|
||||||
from api.models.GenerationRequest import GenerationResponse
|
from api.models.GenerationRequest import GenerationResponse
|
||||||
|
|
||||||
class IdeaCreateRequest(BaseModel):
|
class IdeaCreateRequest(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
description: Optional[str] = None
|
description: str | None = None
|
||||||
project_id: Optional[str] = None # Optional in body if passed via header/dependency
|
project_id: str | None = None # Optional in body if passed via header/dependency
|
||||||
|
inspiration_id: str | None = None
|
||||||
|
|
||||||
class IdeaUpdateRequest(BaseModel):
|
class IdeaUpdateRequest(BaseModel):
|
||||||
name: Optional[str] = None
|
name: str | None = None
|
||||||
description: Optional[str] = None
|
description: str | None = None
|
||||||
|
inspiration_id: str | None = None
|
||||||
|
|
||||||
class IdeaResponse(Idea):
|
class IdeaResponse(Idea):
|
||||||
last_generation: Optional[GenerationResponse] = None
|
last_generation: GenerationResponse | None = None
|
||||||
|
|||||||
28
api/models/InspirationRequest.py
Normal file
28
api/models/InspirationRequest.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from models.Inspiration import Inspiration
|
||||||
|
|
||||||
|
|
||||||
|
class InspirationCreateRequest(BaseModel):
|
||||||
|
source_url: str
|
||||||
|
caption: str | None = None
|
||||||
|
project_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class InspirationResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
source_url: str
|
||||||
|
caption: str | None = None
|
||||||
|
asset_id: str
|
||||||
|
is_completed: bool
|
||||||
|
created_by: str
|
||||||
|
project_id: str | None = None
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
class InspirationListResponse(BaseModel):
|
||||||
|
inspirations: list[InspirationResponse]
|
||||||
|
total_count: int
|
||||||
@@ -1,19 +1,18 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, List
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
class PostCreateRequest(BaseModel):
|
class PostCreateRequest(BaseModel):
|
||||||
date: datetime
|
date: datetime
|
||||||
topic: str
|
topic: str
|
||||||
generation_ids: List[str] = []
|
generation_ids: list[str] = []
|
||||||
project_id: Optional[str] = None
|
project_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class PostUpdateRequest(BaseModel):
|
class PostUpdateRequest(BaseModel):
|
||||||
date: Optional[datetime] = None
|
date: datetime | None = None
|
||||||
topic: Optional[str] = None
|
topic: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class AddGenerationsRequest(BaseModel):
|
class AddGenerationsRequest(BaseModel):
|
||||||
generation_ids: List[str]
|
generation_ids: list[str]
|
||||||
|
|||||||
@@ -2,6 +2,6 @@ from .AssetDTO import AssetResponse, AssetsResponse
|
|||||||
from .CharacterDTO import CharacterCreateRequest, CharacterUpdateRequest
|
from .CharacterDTO import CharacterCreateRequest, CharacterUpdateRequest
|
||||||
from .ExternalGenerationDTO import ExternalGenerationRequest
|
from .ExternalGenerationDTO import ExternalGenerationRequest
|
||||||
from .FinancialUsageDTO import FinancialReport, UsageStats, UsageByEntity
|
from .FinancialUsageDTO import FinancialReport, UsageStats, UsageByEntity
|
||||||
from .GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse, PromptRequest, PromptResponse
|
from .GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse, PromptRequest, PromptResponse, NsfwRequest
|
||||||
from .IdeaRequest import IdeaCreateRequest, IdeaUpdateRequest, IdeaResponse
|
from .IdeaRequest import IdeaCreateRequest, IdeaUpdateRequest, IdeaResponse
|
||||||
from .PostRequest import PostCreateRequest, PostUpdateRequest, AddGenerationsRequest
|
from .PostRequest import PostCreateRequest, PostUpdateRequest, AddGenerationsRequest
|
||||||
|
|||||||
@@ -9,18 +9,19 @@ from uuid import uuid4
|
|||||||
import httpx
|
import httpx
|
||||||
from aiogram import Bot
|
from aiogram import Bot
|
||||||
from aiogram.types import BufferedInputFile
|
from aiogram.types import BufferedInputFile
|
||||||
from fastapi import HTTPException
|
|
||||||
|
|
||||||
from adapters.Exception import GoogleGenerationException
|
from adapters.Exception import GoogleGenerationException
|
||||||
from adapters.google_adapter import GoogleAdapter
|
from adapters.google_adapter import GoogleAdapter
|
||||||
from adapters.s3_adapter import S3Adapter
|
from adapters.s3_adapter import S3Adapter
|
||||||
from api.models import FinancialReport, UsageStats, UsageByEntity
|
from api.models import (
|
||||||
from api.models import GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse
|
FinancialReport, UsageStats, UsageByEntity,
|
||||||
# Импортируйте ваши модели DAO, Asset, Generation корректно
|
GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse
|
||||||
|
)
|
||||||
from models.Asset import Asset, AssetType, AssetContentType
|
from models.Asset import Asset, AssetType, AssetContentType
|
||||||
from models.Generation import Generation, GenerationStatus
|
from models.Generation import Generation, GenerationStatus
|
||||||
from models.enums import AspectRatios, Quality
|
from models.enums import AspectRatios, Quality
|
||||||
from repos.dao import DAO
|
from repos.dao import DAO
|
||||||
|
from utils.image_utils import create_thumbnail
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -28,36 +29,32 @@ logger = logging.getLogger(__name__)
|
|||||||
generation_semaphore = asyncio.Semaphore(4)
|
generation_semaphore = asyncio.Semaphore(4)
|
||||||
|
|
||||||
|
|
||||||
# --- Вспомогательная функция генерации ---
|
|
||||||
async def generate_image_task(
|
async def generate_image_task(
|
||||||
prompt: str,
|
prompt: str,
|
||||||
media_group_bytes: List[bytes],
|
media_group_bytes: List[bytes],
|
||||||
aspect_ratio: AspectRatios,
|
aspect_ratio: AspectRatios,
|
||||||
quality: Quality,
|
quality: Quality,
|
||||||
|
model: str,
|
||||||
gemini: GoogleAdapter,
|
gemini: GoogleAdapter,
|
||||||
|
|
||||||
) -> Tuple[List[bytes], Dict[str, Any]]:
|
) -> Tuple[List[bytes], Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Обертка для вызова синхронного метода Gemini в отдельном потоке.
|
Wrapper for calling Gemini's synchronous method in a separate thread.
|
||||||
Возвращает список байтов сгенерированных изображений.
|
|
||||||
"""
|
"""
|
||||||
try :
|
try:
|
||||||
logger.info(f"Starting generate_image_task with prompt length: {len(prompt)}")
|
logger.info(f"Starting generate_image_task with prompt length: {len(prompt)}")
|
||||||
# Запускаем блокирующую операцию в отдельном потоке, чтобы не тормозить Event Loop
|
|
||||||
result = await asyncio.to_thread(
|
result = await asyncio.to_thread(
|
||||||
gemini.generate_image,
|
gemini.generate_image,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
images_list=media_group_bytes,
|
images_list=media_group_bytes,
|
||||||
aspect_ratio=aspect_ratio,
|
aspect_ratio=aspect_ratio,
|
||||||
quality=quality,
|
quality=quality,
|
||||||
|
model=model,
|
||||||
)
|
)
|
||||||
generated_images_io, metrics = result
|
generated_images_io, metrics = result
|
||||||
|
|
||||||
logger.info(f"generate_image_task completed, received {len(generated_images_io) if generated_images_io else 0} images")
|
logger.info(f"generate_image_task completed, received {len(generated_images_io) if generated_images_io else 0} images")
|
||||||
except GoogleGenerationException as e:
|
except GoogleGenerationException:
|
||||||
raise e
|
raise
|
||||||
finally:
|
finally:
|
||||||
# Освобождаем входные данные — они больше не нужны
|
|
||||||
del media_group_bytes
|
del media_group_bytes
|
||||||
|
|
||||||
images_bytes = []
|
images_bytes = []
|
||||||
@@ -66,414 +63,176 @@ async def generate_image_task(
|
|||||||
img_io.seek(0)
|
img_io.seek(0)
|
||||||
images_bytes.append(img_io.read())
|
images_bytes.append(img_io.read())
|
||||||
img_io.close()
|
img_io.close()
|
||||||
# Освобождаем список BytesIO сразу
|
|
||||||
del generated_images_io
|
del generated_images_io
|
||||||
|
|
||||||
return images_bytes, metrics
|
return images_bytes, metrics
|
||||||
|
|
||||||
|
|
||||||
class GenerationService:
|
class GenerationService:
|
||||||
def __init__(self, dao: DAO, gemini: GoogleAdapter, s3_adapter: S3Adapter, bot: Optional[Bot] = None):
|
def __init__(self, dao: DAO, gemini: GoogleAdapter, s3_adapter: S3Adapter, bot: Optional[Bot] = None):
|
||||||
self.dao = dao
|
self.dao = dao
|
||||||
self.gemini = gemini
|
self.gemini = gemini
|
||||||
self.s3_adapter = s3_adapter
|
self.s3_adapter = s3_adapter
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
|
|
||||||
|
# --- Public API ---
|
||||||
|
|
||||||
async def ask_prompt_assistant(self, prompt: str, assets: list[str] | None = None) -> str:
|
async def ask_prompt_assistant(self, prompt: str, assets: list[str] | None = None, model: str = "gemini-3.1-pro-preview") -> str:
|
||||||
future_prompt = """You are an prompt-assistant. You improving user-entered prompts for image generation. User may upload reference image too.
|
future_prompt = (
|
||||||
I will provide sources prompt entered by user. Understand user needs and generate best variation of prompt.
|
"You are an prompt-assistant. You improving user-entered prompts for image generation. "
|
||||||
ANSWER ONLY PROMPT STRING!!! USER_ENTERED_PROMPT: """
|
"User may upload reference image too. I will provide sources prompt entered by user. "
|
||||||
future_prompt += prompt
|
"Understand user needs and generate best variation of prompt. ANSWER ONLY PROMPT STRING!!! "
|
||||||
|
f"USER_ENTERED_PROMPT: {prompt}"
|
||||||
|
)
|
||||||
assets_data = []
|
assets_data = []
|
||||||
if assets is not None:
|
if assets:
|
||||||
assets_db = await self.dao.assets.get_assets_by_ids(assets)
|
assets_db = await self.dao.assets.get_assets_by_ids(assets)
|
||||||
assets_data.extend(asset.data for asset in assets_db)
|
assets_data.extend(asset.data for asset in assets_db if asset.data)
|
||||||
generated_prompt = await asyncio.to_thread(self.gemini.generate_text, future_prompt, assets_data)
|
|
||||||
logger.info(future_prompt)
|
generated_prompt = await asyncio.to_thread(self.gemini.generate_text, future_prompt, model, assets_data)
|
||||||
logger.info(generated_prompt)
|
logger.info(f"Prompt Assistant: {generated_prompt}")
|
||||||
return generated_prompt
|
return generated_prompt
|
||||||
|
|
||||||
async def generate_prompt_from_images(self, images: List[bytes], user_prompt: Optional[str] = None) -> str:
|
async def generate_prompt_from_images(self, images: List[bytes], user_prompt: Optional[str] = None, model: str = "gemini-3.1-pro-preview") -> str:
|
||||||
technical_prompt = "You are a prompt engineer. Describe this image in detail to create a stable diffusion using this image as reference. "
|
technical_prompt = "You are a prompt engineer. Describe this image in detail to create a stable diffusion using this image as reference. "
|
||||||
if user_prompt:
|
if user_prompt:
|
||||||
technical_prompt += f"User also provided this context: {user_prompt}. "
|
technical_prompt += f"User also provided this context: {user_prompt}. "
|
||||||
|
|
||||||
technical_prompt += "Provide ONLY the detailed prompt."
|
technical_prompt += "Provide ONLY the detailed prompt."
|
||||||
|
|
||||||
return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, images_list=images)
|
return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, model=model, images_list=images)
|
||||||
|
|
||||||
async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0, user_id: Optional[str] = None, project_id: Optional[str] = None, idea_id: Optional[str] = None) -> GenerationsResponse:
|
async def get_generations(self, **kwargs) -> GenerationsResponse:
|
||||||
generations = await self.dao.generations.get_generations(character_id = character_id,limit=limit, offset=offset, created_by=user_id, project_id=project_id, idea_id=idea_id)
|
current_user_id = kwargs.pop('current_user_id', None)
|
||||||
total_count = await self.dao.generations.count_generations(character_id = character_id, created_by=user_id, project_id=project_id, idea_id=idea_id)
|
generations = await self.dao.generations.get_generations(**kwargs)
|
||||||
generations = [GenerationResponse(**gen.model_dump()) for gen in generations]
|
total_count = await self.dao.generations.count_generations(
|
||||||
return GenerationsResponse(generations=generations, total_count=total_count)
|
character_id=kwargs.get('character_id'),
|
||||||
|
created_by=kwargs.get('created_by'),
|
||||||
|
project_id=kwargs.get('project_id'),
|
||||||
|
idea_id=kwargs.get('idea_id'),
|
||||||
|
only_liked_by=kwargs.get('only_liked_by')
|
||||||
|
)
|
||||||
|
return GenerationsResponse(
|
||||||
|
generations=[self._map_to_response(gen, current_user_id) for gen in generations],
|
||||||
|
total_count=total_count
|
||||||
|
)
|
||||||
|
|
||||||
async def get_generation(self, generation_id: str) -> Optional[GenerationResponse]:
|
async def get_generation(self, generation_id: str, current_user_id: Optional[str] = None) -> Optional[GenerationResponse]:
|
||||||
gen = await self.dao.generations.get_generation(generation_id)
|
gen = await self.dao.generations.get_generation(generation_id)
|
||||||
if gen is None:
|
return self._map_to_response(gen, current_user_id) if gen else None
|
||||||
return None
|
|
||||||
else:
|
async def toggle_like(self, generation_id: str, user_id: str) -> bool | None:
|
||||||
return GenerationResponse(**gen.model_dump())
|
return await self.dao.generations.toggle_like(generation_id, user_id)
|
||||||
|
|
||||||
|
async def get_generations_by_group(self, group_id: str, current_user_id: Optional[str] = None) -> GenerationGroupResponse:
|
||||||
|
generations = await self.dao.generations.get_generations_by_group(group_id)
|
||||||
|
return GenerationGroupResponse(
|
||||||
|
generation_group_id=group_id,
|
||||||
|
generations=[self._map_to_response(gen, current_user_id) for gen in generations]
|
||||||
|
)
|
||||||
|
|
||||||
|
def _map_to_response(self, gen: Generation, current_user_id: Optional[str] = None) -> GenerationResponse:
|
||||||
|
res = GenerationResponse(**gen.model_dump())
|
||||||
|
res.likes_count = len(gen.liked_by) if gen.liked_by else 0
|
||||||
|
res.is_liked = current_user_id in gen.liked_by if current_user_id and gen.liked_by else False
|
||||||
|
return res
|
||||||
|
|
||||||
async def get_running_generations(self, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
|
async def get_running_generations(self, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
|
||||||
return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING, created_by=user_id, project_id=project_id)
|
return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING, created_by=user_id, project_id=project_id)
|
||||||
|
|
||||||
async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationGroupResponse:
|
async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationGroupResponse:
|
||||||
count = generation_request.count
|
|
||||||
|
|
||||||
if generation_group_id is None:
|
if generation_group_id is None:
|
||||||
generation_group_id = str(uuid4())
|
generation_group_id = str(uuid4())
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for _ in range(count):
|
for _ in range(generation_request.count):
|
||||||
gen_response = await self._create_single_generation(generation_request, user_id, generation_group_id)
|
gen_response = await self._create_single_generation(generation_request, user_id, generation_group_id)
|
||||||
results.append(gen_response)
|
results.append(gen_response)
|
||||||
return GenerationGroupResponse(generation_group_id=generation_group_id, generations=results)
|
return GenerationGroupResponse(generation_group_id=generation_group_id, generations=results)
|
||||||
|
|
||||||
async def _create_single_generation(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationResponse:
|
|
||||||
gen_id = None
|
|
||||||
generation_model = None
|
|
||||||
|
|
||||||
if generation_request.environment_id and not generation_request.linked_character_id:
|
|
||||||
raise HTTPException(status_code=400, detail="environment_id can only be used when linked_character_id is provided")
|
|
||||||
|
|
||||||
try:
|
|
||||||
generation_model = Generation(**generation_request.model_dump(exclude={'count'}))
|
|
||||||
if user_id:
|
|
||||||
generation_model.created_by = user_id
|
|
||||||
if generation_group_id:
|
|
||||||
generation_model.generation_group_id = generation_group_id
|
|
||||||
|
|
||||||
# Explicitly set idea_id from request if present (already in model_dump, but ensuring clarity)
|
|
||||||
if generation_request.idea_id:
|
|
||||||
generation_model.idea_id = generation_request.idea_id
|
|
||||||
|
|
||||||
gen_id = await self.dao.generations.create_generation(generation_model)
|
|
||||||
generation_model.id = gen_id
|
|
||||||
|
|
||||||
async def runner(gen):
|
|
||||||
logger.info(f"Generation {gen.id} entered queue (waiting for slot)...")
|
|
||||||
try:
|
|
||||||
async with generation_semaphore:
|
|
||||||
logger.info(f"Starting background generation task for ID: {gen.id}")
|
|
||||||
await self.create_generation(gen)
|
|
||||||
logger.info(f"Background generation task finished for ID: {gen.id}")
|
|
||||||
except Exception:
|
|
||||||
# если генерация уже пошла и упала — пометим FAILED
|
|
||||||
try:
|
|
||||||
db_gen = await self.dao.generations.get_generation(gen.id)
|
|
||||||
if db_gen is not None:
|
|
||||||
db_gen.status = GenerationStatus.FAILED
|
|
||||||
await self.dao.generations.update_generation(db_gen)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to mark generation as FAILED")
|
|
||||||
logger.exception("create_generation task failed")
|
|
||||||
|
|
||||||
asyncio.create_task(runner(generation_model))
|
|
||||||
|
|
||||||
return GenerationResponse(**generation_model.model_dump())
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
# если не успели создать запись — нечего помечать
|
|
||||||
if gen_id is not None:
|
|
||||||
try:
|
|
||||||
gen = await self.dao.generations.get_generation(gen_id)
|
|
||||||
if gen is not None:
|
|
||||||
gen.status = GenerationStatus.FAILED
|
|
||||||
await self.dao.generations.update_generation(gen)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to mark generation as FAILED in create_generation_task")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def create_generation(self, generation: Generation):
|
async def create_generation(self, generation: Generation):
|
||||||
start_time = datetime.now()
|
start_time = datetime.now()
|
||||||
logger.info(f"Processing generation {generation.id}. Character ID: {generation.linked_character_id}")
|
logger.info(f"Processing generation {generation.id}. Character ID: {generation.linked_character_id}")
|
||||||
|
|
||||||
# 2. Получаем ассеты-референсы (если они есть)
|
# 1. Prepare input
|
||||||
media_group_bytes: List[bytes] = []
|
media_group_bytes, generation_prompt = await self._prepare_generation_input(generation)
|
||||||
generation_prompt = generation.prompt
|
|
||||||
|
|
||||||
# 2.1 Аватар персонажа (всегда первый, если включен)
|
# 2. Run generation with progress simulation
|
||||||
if generation.linked_character_id is not None:
|
|
||||||
char_info = await self.dao.chars.get_character(generation.linked_character_id)
|
|
||||||
if char_info is None:
|
|
||||||
raise Exception(f"Character ID {generation.linked_character_id} not found")
|
|
||||||
|
|
||||||
if generation.use_profile_image and char_info.avatar_asset_id:
|
|
||||||
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
|
|
||||||
if avatar_asset:
|
|
||||||
img_data = await self._get_asset_data(avatar_asset)
|
|
||||||
if img_data:
|
|
||||||
media_group_bytes.append(img_data)
|
|
||||||
|
|
||||||
# 2.2 Явно указанные ассеты
|
|
||||||
if generation.assets_list:
|
|
||||||
explicit_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
|
|
||||||
for asset in explicit_assets:
|
|
||||||
ref_asset_data = await self._get_asset_data(asset)
|
|
||||||
if ref_asset_data:
|
|
||||||
media_group_bytes.append(ref_asset_data)
|
|
||||||
|
|
||||||
# 2.3 Ассеты из окружения (в самый конец)
|
|
||||||
if generation.environment_id:
|
|
||||||
env = await self.dao.environments.get_env(generation.environment_id)
|
|
||||||
if env and env.asset_ids:
|
|
||||||
logger.info(f"Loading {len(env.asset_ids)} assets from environment {env.name} ({env.id})")
|
|
||||||
env_assets = await self.dao.assets.get_assets_by_ids(env.asset_ids)
|
|
||||||
for asset in env_assets:
|
|
||||||
img_data = await self._get_asset_data(asset)
|
|
||||||
if img_data:
|
|
||||||
media_group_bytes.append(img_data)
|
|
||||||
|
|
||||||
if media_group_bytes:
|
|
||||||
generation_prompt += " \n\n[Reference Image Guidance]: Use the provided image(s) as the STRICT reference for the main character's facial features and hair, enviroment or clothes. Maintain high fidelity to the reference identity."
|
|
||||||
|
|
||||||
logger.info(f"Final generation prompt assembled. Length: {len(generation_prompt)}. Media count: {len(media_group_bytes)}")
|
|
||||||
|
|
||||||
# 3. Запускаем процесс генерации и симуляцию прогресса
|
|
||||||
progress_task = asyncio.create_task(self._simulate_progress(generation))
|
progress_task = asyncio.create_task(self._simulate_progress(generation))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
# Default to Image Generation (Gemini)
|
|
||||||
generated_bytes_list, metrics = await generate_image_task(
|
generated_bytes_list, metrics = await generate_image_task(
|
||||||
prompt=generation_prompt, # или request.prompt
|
prompt=generation_prompt,
|
||||||
media_group_bytes=media_group_bytes,
|
media_group_bytes=media_group_bytes,
|
||||||
aspect_ratio=generation.aspect_ratio, # предполагаем поля в request
|
aspect_ratio=generation.aspect_ratio,
|
||||||
quality=generation.quality,
|
quality=generation.quality,
|
||||||
|
model=generation.model or "gemini-3-pro-image-preview",
|
||||||
gemini=self.gemini
|
gemini=self.gemini
|
||||||
)
|
)
|
||||||
|
self._update_generation_metrics(generation, metrics)
|
||||||
|
|
||||||
|
# 3. Process results
|
||||||
|
created_assets = await self._process_generated_images(generation, generated_bytes_list)
|
||||||
|
|
||||||
# Update metrics from API (Common for both)
|
# 4. Finalize generation record
|
||||||
generation.api_execution_time_seconds = metrics.get("api_execution_time_seconds")
|
await self._finalize_generation(generation, created_assets, generation_prompt, start_time)
|
||||||
generation.token_usage = metrics.get("token_usage")
|
|
||||||
generation.input_token_usage = metrics.get("input_token_usage")
|
|
||||||
generation.output_token_usage = metrics.get("output_token_usage")
|
|
||||||
|
|
||||||
except GoogleGenerationException as e:
|
# 5. Notify
|
||||||
generation.status = GenerationStatus.FAILED
|
if generation.telegram_id and self.bot:
|
||||||
generation.failed_reason = str(e)
|
await self._notify_telegram(generation, created_assets)
|
||||||
generation.updated_at = datetime.now(UTC)
|
|
||||||
await self.dao.generations.update_generation(generation)
|
|
||||||
raise e
|
|
||||||
except Exception as e:
|
|
||||||
# Тут стоит добавить логирование ошибки
|
|
||||||
logging.error(f"Generation failed: {e}")
|
|
||||||
generation.status = GenerationStatus.FAILED
|
|
||||||
generation.failed_reason = str(e)
|
|
||||||
generation.updated_at = datetime.now(UTC)
|
|
||||||
await self.dao.generations.update_generation(generation)
|
|
||||||
raise e
|
|
||||||
finally:
|
finally:
|
||||||
if not progress_task.done():
|
if not progress_task.done():
|
||||||
progress_task.cancel()
|
progress_task.cancel()
|
||||||
try:
|
try:
|
||||||
await progress_task
|
await progress_task
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# 4. Сохраняем полученные изображения как новые Ассеты
|
|
||||||
created_assets: List[Asset] = []
|
|
||||||
|
|
||||||
for idx, img_bytes in enumerate(generated_bytes_list):
|
|
||||||
# Generate thumbnail
|
|
||||||
thumbnail_bytes = None
|
|
||||||
from utils.image_utils import create_thumbnail
|
|
||||||
thumbnail_bytes = await asyncio.to_thread(create_thumbnail, img_bytes)
|
|
||||||
|
|
||||||
# Save to S3
|
|
||||||
filename = f"generated/{generation.linked_character_id}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.png"
|
|
||||||
await self.s3_adapter.upload_file(filename, img_bytes, content_type="image/png")
|
|
||||||
|
|
||||||
new_asset = Asset(
|
|
||||||
name=f"Generated_{generation.linked_character_id}",
|
|
||||||
type=AssetType.GENERATED,
|
|
||||||
content_type=AssetContentType.IMAGE,
|
|
||||||
linked_char_id=generation.linked_character_id,
|
|
||||||
data=None, # Not storing bytes in DB anymore
|
|
||||||
minio_object_name=filename,
|
|
||||||
minio_bucket=self.s3_adapter.bucket_name,
|
|
||||||
thumbnail=thumbnail_bytes,
|
|
||||||
created_by=generation.created_by,
|
|
||||||
project_id=generation.project_id
|
|
||||||
)
|
|
||||||
|
|
||||||
# Сохраняем в БД
|
|
||||||
asset_id = await self.dao.assets.create_asset(new_asset)
|
|
||||||
new_asset.id = str(asset_id) # Присваиваем ID, полученный от базы
|
|
||||||
|
|
||||||
created_assets.append(new_asset)
|
|
||||||
|
|
||||||
# 5. (Опционально) Обновляем запись генерации ссылками на результаты
|
|
||||||
# Предполагаем, что у модели Generation есть поле result_asset_ids
|
|
||||||
result_ids = []
|
|
||||||
for a in created_assets:
|
|
||||||
result_ids.append(a.id)
|
|
||||||
|
|
||||||
generation.result_list = result_ids
|
|
||||||
generation.status = GenerationStatus.DONE
|
|
||||||
generation.progress = 100
|
|
||||||
generation.updated_at = datetime.now(UTC)
|
|
||||||
generation.tech_prompt = generation_prompt
|
|
||||||
|
|
||||||
end_time = datetime.now()
|
|
||||||
generation.execution_time_seconds = (end_time - start_time).total_seconds()
|
|
||||||
|
|
||||||
logger.info(f"DEBUG: Saving generation {generation.id}. Metrics: api_exec={generation.api_execution_time_seconds}, tokens={generation.token_usage}, in_tokens={generation.input_token_usage}, out_tokens={generation.output_token_usage}, exec={generation.execution_time_seconds}")
|
|
||||||
|
|
||||||
await self.dao.generations.update_generation(generation)
|
|
||||||
logger.info(f"Generation {generation.id} completed successfully. {len(created_assets)} assets created. Total Time: {generation.execution_time_seconds:.2f}s")
|
|
||||||
|
|
||||||
# 6. Send to Telegram if telegram_id is provided
|
|
||||||
if generation.telegram_id and self.bot:
|
|
||||||
try:
|
|
||||||
for asset in created_assets:
|
|
||||||
if asset.data:
|
|
||||||
await self.bot.send_photo(
|
|
||||||
chat_id=generation.telegram_id,
|
|
||||||
photo=BufferedInputFile(asset.data, filename=f"{asset.name}.jpg"),
|
|
||||||
caption=f"Generated from prompt: {generation.prompt[:100]}..."
|
|
||||||
)
|
|
||||||
logger.info(f"Sent {len(created_assets)} assets to Telegram ID: {generation.telegram_id}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to send assets to Telegram ID {generation.telegram_id}: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_asset_data(self, asset: Asset) -> Optional[bytes]:
|
|
||||||
if asset.content_type != AssetContentType.IMAGE:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if asset.minio_object_name:
|
|
||||||
return await self.s3_adapter.get_file(asset.minio_object_name)
|
|
||||||
return asset.data
|
|
||||||
|
|
||||||
async def _simulate_progress(self, generation: Generation):
|
|
||||||
"""
|
|
||||||
Increments progress from 0 to 90 over ~20 seconds.
|
|
||||||
"""
|
|
||||||
current_progress = 0
|
|
||||||
try:
|
|
||||||
while current_progress < 90:
|
|
||||||
await asyncio.sleep(4)
|
|
||||||
# Random increment between 5 and 15
|
|
||||||
increment = random.randint(5, 15)
|
|
||||||
current_progress = min(current_progress + increment, 90)
|
|
||||||
|
|
||||||
# Fetch latest state (optional, but good practice to avoid overwriting unrelated fields)
|
|
||||||
# But for simplicity here we just use the object we have and save it.
|
|
||||||
# Ideally, we should fetch-update-save or use partial update if DAO supports it.
|
|
||||||
# Assuming simple update is fine for now.
|
|
||||||
generation.progress = current_progress
|
|
||||||
await self.dao.generations.update_generation(generation)
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
# Task cancelled, generation finished (or failed)
|
|
||||||
pass
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error in progress simulation: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def import_external_generation(self, external_gen) -> Generation:
|
async def import_external_generation(self, external_gen) -> Generation:
|
||||||
"""
|
|
||||||
Import a generation from an external source.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
external_gen: ExternalGenerationRequest with generation data and image
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Created Generation object
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Validate image source
|
|
||||||
external_gen.validate_image_source()
|
external_gen.validate_image_source()
|
||||||
|
|
||||||
logger.info(f"Importing external generation for user: {external_gen.created_by}")
|
logger.info(f"Importing external generation for user: {external_gen.created_by}")
|
||||||
|
|
||||||
# 1. Process image (download or decode)
|
image_bytes = await self._fetch_external_image(external_gen)
|
||||||
image_bytes = None
|
|
||||||
|
|
||||||
if external_gen.image_url:
|
# Reuse internal processing logic
|
||||||
# Download image from URL
|
new_asset = await self._save_asset(
|
||||||
logger.info(f"Downloading image from URL: {external_gen.image_url}")
|
image_bytes=image_bytes,
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
response = await client.get(external_gen.image_url, timeout=30.0)
|
|
||||||
response.raise_for_status()
|
|
||||||
image_bytes = response.content
|
|
||||||
elif external_gen.image_data:
|
|
||||||
# Decode base64 image
|
|
||||||
logger.info("Decoding base64 image data")
|
|
||||||
image_bytes = base64.b64decode(external_gen.image_data)
|
|
||||||
|
|
||||||
if not image_bytes:
|
|
||||||
raise ValueError("Failed to process image data")
|
|
||||||
|
|
||||||
# 2. Generate thumbnail
|
|
||||||
from utils.image_utils import create_thumbnail
|
|
||||||
thumbnail_bytes = await asyncio.to_thread(create_thumbnail, image_bytes)
|
|
||||||
|
|
||||||
# 3. Save to S3
|
|
||||||
filename = f"external/{external_gen.created_by}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.png"
|
|
||||||
await self.s3_adapter.upload_file(filename, image_bytes, content_type="image/png")
|
|
||||||
|
|
||||||
# 4. Create Asset
|
|
||||||
new_asset = Asset(
|
|
||||||
name=f"External_Generated_{external_gen.linked_character_id or 'no_char'}",
|
name=f"External_Generated_{external_gen.linked_character_id or 'no_char'}",
|
||||||
type=AssetType.GENERATED,
|
|
||||||
content_type=AssetContentType.IMAGE,
|
|
||||||
linked_char_id=external_gen.linked_character_id,
|
|
||||||
data=None, # Not storing bytes in DB
|
|
||||||
minio_object_name=filename,
|
|
||||||
minio_bucket=self.s3_adapter.bucket_name,
|
|
||||||
thumbnail=thumbnail_bytes,
|
|
||||||
created_by=external_gen.created_by,
|
created_by=external_gen.created_by,
|
||||||
project_id=external_gen.project_id
|
project_id=external_gen.project_id,
|
||||||
|
linked_char_id=external_gen.linked_character_id,
|
||||||
|
folder="external"
|
||||||
)
|
)
|
||||||
|
|
||||||
asset_id = await self.dao.assets.create_asset(new_asset)
|
|
||||||
new_asset.id = str(asset_id)
|
|
||||||
|
|
||||||
logger.info(f"Created asset {asset_id} for external generation")
|
|
||||||
|
|
||||||
# 5. Create Generation record
|
|
||||||
generation = Generation(
|
generation = Generation(
|
||||||
status=GenerationStatus.DONE,
|
status=GenerationStatus.DONE,
|
||||||
linked_character_id=external_gen.linked_character_id,
|
linked_character_id=external_gen.linked_character_id,
|
||||||
aspect_ratio=external_gen.aspect_ratio,
|
aspect_ratio=external_gen.aspect_ratio,
|
||||||
quality=external_gen.quality,
|
quality=external_gen.quality,
|
||||||
prompt=external_gen.prompt,
|
prompt=external_gen.prompt,
|
||||||
|
model=external_gen.model,
|
||||||
tech_prompt=external_gen.tech_prompt,
|
tech_prompt=external_gen.tech_prompt,
|
||||||
|
seed=external_gen.seed,
|
||||||
result_list=[new_asset.id],
|
result_list=[new_asset.id],
|
||||||
result=new_asset.id,
|
result=new_asset.id,
|
||||||
progress=100,
|
progress=100,
|
||||||
|
nsfw=external_gen.nsfw,
|
||||||
execution_time_seconds=external_gen.execution_time_seconds,
|
execution_time_seconds=external_gen.execution_time_seconds,
|
||||||
api_execution_time_seconds=external_gen.api_execution_time_seconds,
|
api_execution_time_seconds=external_gen.api_execution_time_seconds,
|
||||||
token_usage=external_gen.token_usage,
|
token_usage=external_gen.token_usage,
|
||||||
input_token_usage=external_gen.input_token_usage,
|
input_token_usage=external_gen.input_token_usage,
|
||||||
output_token_usage=external_gen.output_token_usage,
|
output_token_usage=external_gen.output_token_usage,
|
||||||
created_by=external_gen.created_by,
|
created_by=external_gen.created_by,
|
||||||
project_id=external_gen.project_id,
|
project_id=external_gen.project_id
|
||||||
created_at=datetime.now(UTC),
|
|
||||||
updated_at=datetime.now(UTC)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
gen_id = await self.dao.generations.create_generation(generation)
|
gen_id = await self.dao.generations.create_generation(generation)
|
||||||
generation.id = gen_id
|
generation.id = gen_id
|
||||||
|
|
||||||
logger.info(f"Created generation {gen_id} from external source")
|
|
||||||
|
|
||||||
return generation
|
return generation
|
||||||
|
|
||||||
async def delete_generation(self, generation_id: str) -> bool:
|
async def delete_generation(self, generation_id: str) -> bool:
|
||||||
"""
|
|
||||||
Soft delete generation by marking it as deleted.
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
generation = await self.dao.generations.get_generation(generation_id)
|
generation = await self.dao.generations.get_generation(generation_id)
|
||||||
if not generation:
|
if not generation:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
generation.is_deleted = True
|
generation.is_deleted = True
|
||||||
generation.updated_at = datetime.now(UTC)
|
generation.updated_at = datetime.now(UTC)
|
||||||
await self.dao.generations.update_generation(generation)
|
await self.dao.generations.update_generation(generation)
|
||||||
@@ -483,59 +242,205 @@ class GenerationService:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
async def cleanup_stale_generations(self):
|
async def cleanup_stale_generations(self):
|
||||||
"""
|
|
||||||
Cancels generations that have been running for more than 1 hour.
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
count = await self.dao.generations.cancel_stale_generations(timeout_minutes=60)
|
count = await self.dao.generations.cancel_stale_generations(timeout_minutes=5)
|
||||||
if count > 0:
|
if count > 0:
|
||||||
logger.info(f"Cleaned up {count} stale generations (timeout)")
|
logger.info(f"Cleaned up {count} stale generations")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error cleaning up stale generations: {e}")
|
logger.error(f"Error cleaning up stale generations: {e}")
|
||||||
|
|
||||||
async def cleanup_old_data(self, days: int = 2):
|
async def cleanup_old_data(self, days: int = 30):
|
||||||
"""
|
|
||||||
Очистка старых данных:
|
|
||||||
1. Мягко удаляет генерации старше N дней
|
|
||||||
2. Мягко удаляет связанные ассеты + жёстко удаляет файлы из S3
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
# 1. Мягко удаляем генерации и собираем asset IDs
|
|
||||||
gen_count, asset_ids = await self.dao.generations.soft_delete_old_generations(days=days)
|
gen_count, asset_ids = await self.dao.generations.soft_delete_old_generations(days=days)
|
||||||
|
|
||||||
if gen_count > 0:
|
if gen_count > 0:
|
||||||
logger.info(f"Soft-deleted {gen_count} generations older than {days} days. "
|
logger.info(f"Soft-deleted {gen_count} generations. Purging {len(asset_ids)} assets.")
|
||||||
f"Found {len(asset_ids)} associated asset IDs.")
|
|
||||||
|
|
||||||
# 2. Мягко удаляем ассеты + жёстко удаляем файлы из S3
|
|
||||||
if asset_ids:
|
if asset_ids:
|
||||||
purged = await self.dao.assets.soft_delete_and_purge_assets(asset_ids)
|
await self.dao.assets.soft_delete_and_purge_assets(asset_ids)
|
||||||
logger.info(f"Purged {purged} assets (soft-deleted + S3 files removed).")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error during old data cleanup: {e}")
|
logger.error(f"Error during old data cleanup: {e}")
|
||||||
|
|
||||||
async def get_financial_report(self, user_id: Optional[str] = None, project_id: Optional[str] = None, breakdown_by: Optional[str] = None) -> FinancialReport:
|
async def get_financial_report(self, user_id: Optional[str] = None, project_id: Optional[str] = None, breakdown_by: Optional[str] = None) -> FinancialReport:
|
||||||
"""
|
|
||||||
Generates a financial usage report for a specific user or project.
|
|
||||||
'breakdown_by' can be 'created_by' or 'project_id'.
|
|
||||||
"""
|
|
||||||
summary_data = await self.dao.generations.get_usage_stats(created_by=user_id, project_id=project_id)
|
summary_data = await self.dao.generations.get_usage_stats(created_by=user_id, project_id=project_id)
|
||||||
summary = UsageStats(**summary_data)
|
summary = UsageStats(**summary_data)
|
||||||
|
|
||||||
by_user = None
|
by_user, by_project = None, None
|
||||||
by_project = None
|
|
||||||
|
|
||||||
if breakdown_by == "created_by":
|
if breakdown_by == "created_by":
|
||||||
res = await self.dao.generations.get_usage_breakdown(group_by="created_by", project_id=project_id, created_by=user_id)
|
res = await self.dao.generations.get_usage_breakdown(group_by="created_by", project_id=project_id, created_by=user_id)
|
||||||
by_user = [UsageByEntity(**item) for item in res]
|
by_user = [UsageByEntity(**item) for item in res]
|
||||||
|
|
||||||
if breakdown_by == "project_id":
|
if breakdown_by == "project_id":
|
||||||
res = await self.dao.generations.get_usage_breakdown(group_by="project_id", project_id=project_id, created_by=user_id)
|
res = await self.dao.generations.get_usage_breakdown(group_by="project_id", project_id=project_id, created_by=user_id)
|
||||||
by_project = [UsageByEntity(**item) for item in res]
|
by_project = [UsageByEntity(**item) for item in res]
|
||||||
|
|
||||||
return FinancialReport(
|
return FinancialReport(summary=summary, by_user=by_user, by_project=by_project)
|
||||||
summary=summary,
|
|
||||||
by_user=by_user,
|
# --- Private Helpers ---
|
||||||
by_project=by_project
|
|
||||||
|
async def _create_single_generation(self, generation_request: GenerationRequest, user_id: Optional[str], generation_group_id: str) -> GenerationResponse:
|
||||||
|
try:
|
||||||
|
gen_model = Generation(**generation_request.model_dump(exclude={'count'}))
|
||||||
|
gen_model.created_by = user_id
|
||||||
|
gen_model.generation_group_id = generation_group_id
|
||||||
|
|
||||||
|
gen_id = await self.dao.generations.create_generation(gen_model)
|
||||||
|
gen_model.id = gen_id
|
||||||
|
|
||||||
|
asyncio.create_task(self._queued_generation_runner(gen_model))
|
||||||
|
return GenerationResponse(**gen_model.model_dump())
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to initiate single generation")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _queued_generation_runner(self, gen: Generation):
|
||||||
|
logger.info(f"Generation {gen.id} waiting for slot...")
|
||||||
|
try:
|
||||||
|
async with generation_semaphore:
|
||||||
|
await self.create_generation(gen)
|
||||||
|
except Exception as e:
|
||||||
|
await self._handle_generation_failure(gen, e)
|
||||||
|
logger.exception(f"Background generation task failed for ID: {gen.id}")
|
||||||
|
|
||||||
|
async def _prepare_generation_input(self, generation: Generation) -> Tuple[List[bytes], str]:
|
||||||
|
media_group_bytes: List[bytes] = []
|
||||||
|
prompt = generation.prompt
|
||||||
|
|
||||||
|
# 1. Character Avatar
|
||||||
|
if generation.linked_character_id:
|
||||||
|
char_info = await self.dao.chars.get_character(generation.linked_character_id)
|
||||||
|
if not char_info:
|
||||||
|
raise ValueError(f"Character {generation.linked_character_id} not found")
|
||||||
|
|
||||||
|
if generation.use_profile_image and char_info.avatar_asset_id:
|
||||||
|
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
|
||||||
|
if avatar_asset:
|
||||||
|
data = await self._get_asset_data_bytes(avatar_asset)
|
||||||
|
if data: media_group_bytes.append(data)
|
||||||
|
|
||||||
|
# 2. Reference Assets
|
||||||
|
if generation.assets_list:
|
||||||
|
assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
|
||||||
|
for asset in assets:
|
||||||
|
data = await self._get_asset_data_bytes(asset)
|
||||||
|
if data: media_group_bytes.append(data)
|
||||||
|
|
||||||
|
# 3. Environment Assets
|
||||||
|
if generation.environment_id:
|
||||||
|
env = await self.dao.environments.get_env(generation.environment_id)
|
||||||
|
if env and env.asset_ids:
|
||||||
|
env_assets = await self.dao.assets.get_assets_by_ids(env.asset_ids)
|
||||||
|
for asset in env_assets:
|
||||||
|
data = await self._get_asset_data_bytes(asset)
|
||||||
|
if data: media_group_bytes.append(data)
|
||||||
|
|
||||||
|
if media_group_bytes:
|
||||||
|
prompt += (
|
||||||
|
" \n\n[Reference Image Guidance]: Use the provided image(s) as the STRICT reference for the main "
|
||||||
|
"character's facial features and hair, environment or clothes. Maintain high fidelity to the reference identity."
|
||||||
|
)
|
||||||
|
|
||||||
|
return media_group_bytes, prompt
|
||||||
|
|
||||||
|
async def _get_asset_data_bytes(self, asset: Asset) -> Optional[bytes]:
|
||||||
|
if asset.content_type != AssetContentType.IMAGE:
|
||||||
|
return None
|
||||||
|
if asset.minio_object_name:
|
||||||
|
return await self.s3_adapter.get_file(asset.minio_object_name)
|
||||||
|
return asset.data
|
||||||
|
|
||||||
|
def _update_generation_metrics(self, generation: Generation, metrics: Dict[str, Any]):
|
||||||
|
generation.api_execution_time_seconds = metrics.get("api_execution_time_seconds")
|
||||||
|
generation.token_usage = metrics.get("token_usage")
|
||||||
|
generation.input_token_usage = metrics.get("input_token_usage")
|
||||||
|
generation.output_token_usage = metrics.get("output_token_usage")
|
||||||
|
|
||||||
|
async def _handle_generation_failure(self, generation: Generation, error: Optional[Exception]):
|
||||||
|
logger.error(f"Generation {generation.id} failed: {error}")
|
||||||
|
generation.status = GenerationStatus.FAILED
|
||||||
|
# Don't overwrite if reason is already set, unless a new error is provided
|
||||||
|
if error:
|
||||||
|
generation.failed_reason = str(error)
|
||||||
|
elif not generation.failed_reason:
|
||||||
|
generation.failed_reason = "Unknown error"
|
||||||
|
|
||||||
|
generation.updated_at = datetime.now(UTC)
|
||||||
|
await self.dao.generations.update_generation(generation)
|
||||||
|
|
||||||
|
async def _process_generated_images(self, generation: Generation, bytes_list: List[bytes]) -> List[Asset]:
|
||||||
|
created_assets = []
|
||||||
|
for img_bytes in bytes_list:
|
||||||
|
asset = await self._save_asset(
|
||||||
|
image_bytes=img_bytes,
|
||||||
|
name=f"Generated_{generation.linked_character_id}",
|
||||||
|
created_by=generation.created_by,
|
||||||
|
project_id=generation.project_id,
|
||||||
|
linked_char_id=generation.linked_character_id,
|
||||||
|
folder="generated"
|
||||||
|
)
|
||||||
|
created_assets.append(asset)
|
||||||
|
return created_assets
|
||||||
|
|
||||||
|
async def _save_asset(self, image_bytes: bytes, name: str, created_by: str, project_id: str, linked_char_id: str, folder: str) -> Asset:
|
||||||
|
thumbnail_bytes = await asyncio.to_thread(create_thumbnail, image_bytes)
|
||||||
|
filename = f"{folder}/{linked_char_id}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.png"
|
||||||
|
|
||||||
|
await self.s3_adapter.upload_file(filename, image_bytes, content_type="image/png")
|
||||||
|
|
||||||
|
new_asset = Asset(
|
||||||
|
name=name,
|
||||||
|
type=AssetType.GENERATED,
|
||||||
|
content_type=AssetContentType.IMAGE,
|
||||||
|
linked_char_id=linked_char_id,
|
||||||
|
data=None,
|
||||||
|
minio_object_name=filename,
|
||||||
|
minio_bucket=self.s3_adapter.bucket_name,
|
||||||
|
thumbnail=thumbnail_bytes,
|
||||||
|
created_by=created_by,
|
||||||
|
project_id=project_id
|
||||||
)
|
)
|
||||||
|
asset_id = await self.dao.assets.create_asset(new_asset)
|
||||||
|
new_asset.id = str(asset_id)
|
||||||
|
return new_asset
|
||||||
|
|
||||||
|
async def _finalize_generation(self, generation: Generation, assets: List[Asset], tech_prompt: str, start_time: datetime):
|
||||||
|
generation.result_list = [a.id for a in assets]
|
||||||
|
generation.status = GenerationStatus.DONE
|
||||||
|
generation.progress = 100
|
||||||
|
generation.updated_at = datetime.now(UTC)
|
||||||
|
generation.tech_prompt = tech_prompt
|
||||||
|
generation.execution_time_seconds = (datetime.now() - start_time).total_seconds()
|
||||||
|
await self.dao.generations.update_generation(generation)
|
||||||
|
logger.info(f"Generation {generation.id} finalized. Time: {generation.execution_time_seconds:.2f}s")
|
||||||
|
|
||||||
|
async def _notify_telegram(self, generation: Generation, assets: List[Asset]):
|
||||||
|
try:
|
||||||
|
for asset in assets:
|
||||||
|
# Need to get data for telegram if it's not in Asset object
|
||||||
|
img_data = await self.s3_adapter.get_file(asset.minio_object_name) if asset.minio_object_name else asset.data
|
||||||
|
if img_data:
|
||||||
|
await self.bot.send_photo(
|
||||||
|
chat_id=generation.telegram_id,
|
||||||
|
photo=BufferedInputFile(img_data, filename=f"{asset.name}.png"),
|
||||||
|
caption=f"Generated from: {generation.prompt[:100]}..."
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to send to Telegram: {e}")
|
||||||
|
|
||||||
|
async def _simulate_progress(self, generation: Generation):
|
||||||
|
current_progress = 0
|
||||||
|
try:
|
||||||
|
while current_progress < 90:
|
||||||
|
await asyncio.sleep(4)
|
||||||
|
current_progress = min(current_progress + random.randint(5, 15), 90)
|
||||||
|
generation.progress = current_progress
|
||||||
|
await self.dao.generations.update_generation(generation)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def _fetch_external_image(self, external_gen) -> bytes:
|
||||||
|
if external_gen.image_url:
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(external_gen.image_url, timeout=30.0)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.content
|
||||||
|
elif external_gen.image_data:
|
||||||
|
return base64.b64decode(external_gen.image_data)
|
||||||
|
raise ValueError("No image source provided")
|
||||||
|
|||||||
@@ -7,8 +7,14 @@ class IdeaService:
|
|||||||
def __init__(self, dao: DAO):
|
def __init__(self, dao: DAO):
|
||||||
self.dao = dao
|
self.dao = dao
|
||||||
|
|
||||||
async def create_idea(self, name: str, description: Optional[str], project_id: Optional[str], user_id: str) -> Idea:
|
async def create_idea(self, name: str, description: Optional[str], project_id: Optional[str], user_id: str, inspiration_id: Optional[str] = None) -> Idea:
|
||||||
idea = Idea(name=name, description=description, project_id=project_id, created_by=user_id)
|
idea = Idea(
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
project_id=project_id,
|
||||||
|
created_by=user_id,
|
||||||
|
inspiration_id=inspiration_id
|
||||||
|
)
|
||||||
idea_id = await self.dao.ideas.create_idea(idea)
|
idea_id = await self.dao.ideas.create_idea(idea)
|
||||||
idea.id = idea_id
|
idea.id = idea_id
|
||||||
return idea
|
return idea
|
||||||
@@ -19,7 +25,7 @@ class IdeaService:
|
|||||||
async def get_idea(self, idea_id: str) -> Optional[Idea]:
|
async def get_idea(self, idea_id: str) -> Optional[Idea]:
|
||||||
return await self.dao.ideas.get_idea(idea_id)
|
return await self.dao.ideas.get_idea(idea_id)
|
||||||
|
|
||||||
async def update_idea(self, idea_id: str, name: Optional[str] = None, description: Optional[str] = None) -> Optional[Idea]:
|
async def update_idea(self, idea_id: str, name: Optional[str] = None, description: Optional[str] = None, inspiration_id: Optional[str] = None) -> Optional[Idea]:
|
||||||
idea = await self.dao.ideas.get_idea(idea_id)
|
idea = await self.dao.ideas.get_idea(idea_id)
|
||||||
if not idea:
|
if not idea:
|
||||||
return None
|
return None
|
||||||
@@ -28,6 +34,8 @@ class IdeaService:
|
|||||||
idea.name = name
|
idea.name = name
|
||||||
if description is not None:
|
if description is not None:
|
||||||
idea.description = description
|
idea.description = description
|
||||||
|
if inspiration_id is not None:
|
||||||
|
idea.inspiration_id = inspiration_id
|
||||||
|
|
||||||
idea.updated_at = datetime.now()
|
idea.updated_at = datetime.now()
|
||||||
await self.dao.ideas.update_idea(idea)
|
await self.dao.ideas.update_idea(idea)
|
||||||
@@ -72,4 +80,3 @@ class IdeaService:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|||||||
146
api/service/inspiration_service.py
Normal file
146
api/service/inspiration_service.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from models.Asset import Asset, AssetType, AssetContentType
|
||||||
|
from models.Inspiration import Inspiration
|
||||||
|
from repos.dao import DAO
|
||||||
|
from adapters.s3_adapter import S3Adapter
|
||||||
|
|
||||||
|
# Try to import yt_dlp, but don't crash if it's missing (though we added it to requirements)
|
||||||
|
try:
|
||||||
|
import yt_dlp
|
||||||
|
except ImportError:
|
||||||
|
yt_dlp = None
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class InspirationService:
|
||||||
|
def __init__(self, dao: DAO, s3_adapter: S3Adapter):
|
||||||
|
self.dao = dao
|
||||||
|
self.s3_adapter = s3_adapter
|
||||||
|
|
||||||
|
async def create_inspiration(self, source_url: str, created_by: str, project_id: Optional[str] = None, caption: Optional[str] = None) -> Inspiration:
|
||||||
|
# 1. Download content from Instagram
|
||||||
|
try:
|
||||||
|
content_bytes, content_type, ext = await self._download_content(source_url)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to download content from {source_url}: {e}")
|
||||||
|
raise HTTPException(status_code=400, detail=f"Failed to download content: {str(e)}")
|
||||||
|
|
||||||
|
# 2. Save as Asset
|
||||||
|
filename = f"inspirations/{datetime.now().strftime('%Y%m%d_%H%M%S')}_insta.{ext}"
|
||||||
|
|
||||||
|
await self.s3_adapter.upload_file(filename, content_bytes, content_type=content_type)
|
||||||
|
|
||||||
|
asset = Asset(
|
||||||
|
name=f"Inspiration from {source_url}",
|
||||||
|
type=AssetType.INSPIRATION,
|
||||||
|
content_type=AssetContentType.VIDEO if content_type.startswith("video") else AssetContentType.IMAGE,
|
||||||
|
minio_object_name=filename,
|
||||||
|
minio_bucket=self.s3_adapter.bucket_name,
|
||||||
|
created_by=created_by,
|
||||||
|
project_id=project_id
|
||||||
|
)
|
||||||
|
asset_id = await self.dao.assets.create_asset(asset)
|
||||||
|
|
||||||
|
# 3. Create Inspiration object
|
||||||
|
inspiration = Inspiration(
|
||||||
|
source_url=source_url,
|
||||||
|
caption=caption,
|
||||||
|
asset_id=str(asset_id),
|
||||||
|
created_by=created_by,
|
||||||
|
project_id=project_id
|
||||||
|
)
|
||||||
|
insp_id = await self.dao.inspirations.create_inspiration(inspiration)
|
||||||
|
inspiration.id = insp_id
|
||||||
|
|
||||||
|
return inspiration
|
||||||
|
|
||||||
|
async def get_inspirations(self, project_id: Optional[str], created_by: str, limit: int = 20, offset: int = 0) -> List[Inspiration]:
|
||||||
|
return await self.dao.inspirations.get_inspirations(project_id, created_by, limit, offset)
|
||||||
|
|
||||||
|
async def get_inspiration(self, inspiration_id: str) -> Optional[Inspiration]:
|
||||||
|
return await self.dao.inspirations.get_inspiration(inspiration_id)
|
||||||
|
|
||||||
|
async def mark_as_completed(self, inspiration_id: str, is_completed: bool = True) -> Optional[Inspiration]:
|
||||||
|
inspiration = await self.dao.inspirations.get_inspiration(inspiration_id)
|
||||||
|
if not inspiration:
|
||||||
|
return None
|
||||||
|
|
||||||
|
inspiration.is_completed = is_completed
|
||||||
|
inspiration.updated_at = datetime.now()
|
||||||
|
await self.dao.inspirations.update_inspiration(inspiration)
|
||||||
|
return inspiration
|
||||||
|
|
||||||
|
async def delete_inspiration(self, inspiration_id: str) -> bool:
|
||||||
|
inspiration = await self.dao.inspirations.get_inspiration(inspiration_id)
|
||||||
|
if not inspiration:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Delete associated asset
|
||||||
|
if inspiration.asset_id:
|
||||||
|
await self.dao.assets.delete_asset(inspiration.asset_id)
|
||||||
|
|
||||||
|
return await self.dao.inspirations.delete_inspiration(inspiration_id)
|
||||||
|
|
||||||
|
async def _download_content(self, url: str) -> Tuple[bytes, str, str]:
|
||||||
|
"""
|
||||||
|
Downloads content using yt-dlp.
|
||||||
|
Returns (content_bytes, content_type, extension)
|
||||||
|
"""
|
||||||
|
if not yt_dlp:
|
||||||
|
raise RuntimeError("yt-dlp is not installed")
|
||||||
|
|
||||||
|
logger.info(f"Downloading from {url} using yt-dlp...")
|
||||||
|
|
||||||
|
def run_yt_dlp():
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
ydl_opts = {
|
||||||
|
'outtmpl': f'{tmpdirname}/%(id)s.%(ext)s',
|
||||||
|
'quiet': True,
|
||||||
|
'no_warnings': True,
|
||||||
|
'format': 'best', # Best quality single file
|
||||||
|
'noplaylist': True, # Only single video if it's a playlist/profile
|
||||||
|
'writethumbnail': False,
|
||||||
|
'writesubtitles': False,
|
||||||
|
}
|
||||||
|
|
||||||
|
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
|
||||||
|
ydl.download([url])
|
||||||
|
|
||||||
|
# Find the downloaded file
|
||||||
|
files = os.listdir(tmpdirname)
|
||||||
|
if not files:
|
||||||
|
raise Exception("No files downloaded")
|
||||||
|
|
||||||
|
# Pick the largest file if multiple (e.g. if yt-dlp downloaded parts)
|
||||||
|
# But with 'format': 'best', it should be one.
|
||||||
|
# If carousel, it might be multiple. Let's pick the first one.
|
||||||
|
filename = files[0]
|
||||||
|
filepath = os.path.join(tmpdirname, filename)
|
||||||
|
|
||||||
|
with open(filepath, 'rb') as f:
|
||||||
|
data = f.read()
|
||||||
|
|
||||||
|
ext = filename.split('.')[-1].lower()
|
||||||
|
|
||||||
|
# Determine content type
|
||||||
|
if ext in ['mp4', 'mov', 'avi', 'mkv', 'webm']:
|
||||||
|
content_type = f"video/{ext}"
|
||||||
|
if ext == 'mov': content_type = "video/quicktime"
|
||||||
|
elif ext in ['jpg', 'jpeg', 'png', 'webp']:
|
||||||
|
content_type = f"image/{ext}"
|
||||||
|
if ext == 'jpg': content_type = "image/jpeg"
|
||||||
|
else:
|
||||||
|
content_type = "application/octet-stream"
|
||||||
|
|
||||||
|
return data, content_type, ext
|
||||||
|
|
||||||
|
return await asyncio.to_thread(run_yt_dlp)
|
||||||
@@ -1,12 +1,11 @@
|
|||||||
from datetime import datetime, UTC
|
from datetime import datetime, UTC
|
||||||
from typing import Optional, List
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
class Album(BaseModel):
|
class Album(BaseModel):
|
||||||
id: Optional[str] = None
|
id: str | None = None
|
||||||
name: str
|
name: str
|
||||||
description: Optional[str] = None
|
description: str | None = None
|
||||||
cover_asset_id: Optional[str] = None
|
cover_asset_id: str | None = None
|
||||||
generation_ids: List[str] = []
|
generation_ids: list[str] = []
|
||||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from datetime import datetime, UTC
|
from datetime import datetime, UTC
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional, Any, List
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, computed_field, Field, model_validator
|
from pydantic import BaseModel, computed_field, Field, model_validator
|
||||||
|
|
||||||
@@ -8,28 +8,30 @@ from pydantic import BaseModel, computed_field, Field, model_validator
|
|||||||
class AssetContentType(str, Enum):
|
class AssetContentType(str, Enum):
|
||||||
IMAGE = 'image'
|
IMAGE = 'image'
|
||||||
PROMPT = 'prompt'
|
PROMPT = 'prompt'
|
||||||
|
VIDEO = 'video'
|
||||||
|
|
||||||
class AssetType(str, Enum):
|
class AssetType(str, Enum):
|
||||||
UPLOADED = 'uploaded'
|
UPLOADED = 'uploaded'
|
||||||
GENERATED = 'generated'
|
GENERATED = 'generated'
|
||||||
|
INSPIRATION = 'inspiration'
|
||||||
|
|
||||||
|
|
||||||
class Asset(BaseModel):
|
class Asset(BaseModel):
|
||||||
id: Optional[str] = None
|
id: str | None = None
|
||||||
name: str
|
name: str
|
||||||
type: AssetType = AssetType.GENERATED
|
type: AssetType = AssetType.GENERATED
|
||||||
content_type: AssetContentType = AssetContentType.IMAGE
|
content_type: AssetContentType = AssetContentType.IMAGE
|
||||||
linked_char_id: Optional[str] = None
|
linked_char_id: str | None = None
|
||||||
data: Optional[bytes] = None
|
data: bytes | None = None
|
||||||
tg_doc_file_id: Optional[str] = None
|
tg_doc_file_id: str | None = None
|
||||||
tg_photo_file_id: Optional[str] = None
|
tg_photo_file_id: str | None = None
|
||||||
minio_object_name: Optional[str] = None
|
minio_object_name: str | None = None
|
||||||
minio_bucket: Optional[str] = None
|
minio_bucket: str | None = None
|
||||||
minio_thumbnail_object_name: Optional[str] = None
|
minio_thumbnail_object_name: str | None = None
|
||||||
thumbnail: Optional[bytes] = None
|
thumbnail: bytes | None = None
|
||||||
tags: List[str] = []
|
tags: list[str] = []
|
||||||
created_by: Optional[str] = None
|
created_by: str | None = None
|
||||||
project_id: Optional[str] = None
|
project_id: str | None = None
|
||||||
is_deleted: bool = False
|
is_deleted: bool = False
|
||||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||||
|
|||||||
@@ -1,16 +1,15 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from pydantic_core.core_schema import computed_field
|
from pydantic_core.core_schema import computed_field
|
||||||
|
|
||||||
|
|
||||||
class Character(BaseModel):
|
class Character(BaseModel):
|
||||||
id: Optional[str] = None
|
id: str | None = None
|
||||||
name: str
|
name: str
|
||||||
avatar_asset_id: Optional[str] = None
|
avatar_asset_id: str | None = None
|
||||||
avatar_image: Optional[str] = None
|
avatar_image: str | None = None
|
||||||
character_image_doc_tg_id: Optional[str] = None
|
character_image_doc_tg_id: str | None = None
|
||||||
character_image_tg_id: Optional[str] = None
|
character_image_tg_id: str | None = None
|
||||||
character_bio: Optional[str] = None
|
character_bio: str | None = None
|
||||||
created_by: Optional[str] = None
|
created_by: str | None = None
|
||||||
project_id: Optional[str] = None
|
project_id: str | None = None
|
||||||
|
|||||||
@@ -1,15 +1,14 @@
|
|||||||
from typing import List, Optional
|
|
||||||
from pydantic import BaseModel, Field, ConfigDict
|
from pydantic import BaseModel, Field, ConfigDict
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from bson import ObjectId
|
from bson import ObjectId
|
||||||
|
|
||||||
|
|
||||||
class Environment(BaseModel):
|
class Environment(BaseModel):
|
||||||
id: Optional[str] = Field(None, alias="_id")
|
id: str | None = Field(None, alias="_id")
|
||||||
character_id: str
|
character_id: str
|
||||||
name: str = Field(..., min_length=1)
|
name: str = Field(..., min_length=1)
|
||||||
description: Optional[str] = None
|
description: str | None = None
|
||||||
asset_ids: List[str] = Field(default_factory=list)
|
asset_ids: list[str] = Field(default_factory=list)
|
||||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,9 @@
|
|||||||
from datetime import datetime, UTC
|
from datetime import datetime, UTC
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, computed_field
|
from pydantic import BaseModel, Field, computed_field
|
||||||
|
|
||||||
from models.Asset import Asset
|
from models.enums import AspectRatios, Quality
|
||||||
from models.enums import AspectRatios, Quality, GenType
|
|
||||||
|
|
||||||
|
|
||||||
class GenerationStatus(str, Enum):
|
class GenerationStatus(str, Enum):
|
||||||
@@ -14,32 +12,36 @@ class GenerationStatus(str, Enum):
|
|||||||
FAILED = "failed"
|
FAILED = "failed"
|
||||||
|
|
||||||
class Generation(BaseModel):
|
class Generation(BaseModel):
|
||||||
id: Optional[str] = None
|
id: str | None = None
|
||||||
status: GenerationStatus = GenerationStatus.RUNNING
|
status: GenerationStatus = GenerationStatus.RUNNING
|
||||||
failed_reason: Optional[str] = None
|
failed_reason: str | None = None
|
||||||
linked_character_id: Optional[str] = None
|
linked_character_id: str | None = None
|
||||||
telegram_id: Optional[int] = None
|
telegram_id: int | None = None
|
||||||
use_profile_image: bool = True
|
use_profile_image: bool = True
|
||||||
aspect_ratio: AspectRatios
|
aspect_ratio: AspectRatios
|
||||||
quality: Quality
|
quality: Quality
|
||||||
prompt: str
|
prompt: str
|
||||||
tech_prompt: Optional[str] = None
|
model: str | None = None
|
||||||
assets_list: List[str] = Field(default_factory=list)
|
seed: int | None = None
|
||||||
result_list: List[str] = Field(default_factory=list)
|
tech_prompt: str | None = None
|
||||||
result: Optional[str] = None
|
assets_list: list[str] = Field(default_factory=list)
|
||||||
|
result_list: list[str] = Field(default_factory=list)
|
||||||
|
result: str | None = None
|
||||||
progress: int = 0
|
progress: int = 0
|
||||||
execution_time_seconds: Optional[float] = None
|
execution_time_seconds: float | None = None
|
||||||
api_execution_time_seconds: Optional[float] = None
|
api_execution_time_seconds: float | None = None
|
||||||
token_usage: Optional[int] = None
|
token_usage: int | None = None
|
||||||
input_token_usage: Optional[int] = None
|
input_token_usage: int | None = None
|
||||||
output_token_usage: Optional[int] = None
|
output_token_usage: int | None = None
|
||||||
is_deleted: bool = False
|
is_deleted: bool = False
|
||||||
album_id: Optional[str] = None
|
album_id: str | None = None
|
||||||
environment_id: Optional[str] = None
|
environment_id: str | None = None
|
||||||
generation_group_id: Optional[str] = None
|
generation_group_id: str | None = None
|
||||||
created_by: Optional[str] = None # Stores User ID (Telegram ID or Web User ObjectId)
|
created_by: str | None = None # Stores User ID (Telegram ID or Web User ObjectId)
|
||||||
project_id: Optional[str] = None
|
project_id: str | None = None
|
||||||
idea_id: Optional[str] = None
|
idea_id: str | None = None
|
||||||
|
liked_by: list[str] = Field(default_factory=list)
|
||||||
|
nsfw: bool = False
|
||||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, List
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
class Idea(BaseModel):
|
class Idea(BaseModel):
|
||||||
id: Optional[str] = None
|
id: str | None = None
|
||||||
name: str = "New Idea"
|
name: str = "New Idea"
|
||||||
description: Optional[str] = None
|
description: str | None = None
|
||||||
project_id: Optional[str] = None
|
project_id: str | None = None
|
||||||
|
inspiration_id: str | None = None # Link to Inspiration
|
||||||
created_by: str # User ID
|
created_by: str # User ID
|
||||||
is_deleted: bool = False
|
is_deleted: bool = False
|
||||||
created_at: datetime = Field(default_factory=datetime.now)
|
created_at: datetime = Field(default_factory=datetime.now)
|
||||||
|
|||||||
15
models/Inspiration.py
Normal file
15
models/Inspiration.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
from datetime import datetime, UTC
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class Inspiration(BaseModel):
|
||||||
|
id: str | None = None
|
||||||
|
source_url: str
|
||||||
|
caption: str | None = None
|
||||||
|
asset_id: str
|
||||||
|
is_completed: bool = False
|
||||||
|
created_by: str
|
||||||
|
project_id: str | None = None
|
||||||
|
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||||
|
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||||
@@ -1,14 +1,13 @@
|
|||||||
from datetime import datetime, timezone, UTC
|
from datetime import datetime, timezone, UTC
|
||||||
from typing import Optional, List
|
|
||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
|
|
||||||
class Post(BaseModel):
|
class Post(BaseModel):
|
||||||
id: Optional[str] = None
|
id: str | None = None
|
||||||
date: datetime
|
date: datetime
|
||||||
topic: str
|
topic: str
|
||||||
generation_ids: List[str] = Field(default_factory=list)
|
generation_ids: list[str] = Field(default_factory=list)
|
||||||
project_id: Optional[str] = None
|
project_id: str | None = None
|
||||||
created_by: str
|
created_by: str
|
||||||
is_deleted: bool = False
|
is_deleted: bool = False
|
||||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||||
|
|||||||
@@ -1,12 +1,11 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List, Optional
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
class Project(BaseModel):
|
class Project(BaseModel):
|
||||||
id: Optional[str] = None
|
id: str | None = None
|
||||||
name: str
|
name: str
|
||||||
description: Optional[str] = None
|
description: str | None = None
|
||||||
owner_id: str
|
owner_id: str
|
||||||
members: List[str] = [] # List of User IDs
|
members: list[str] = [] # List of User IDs
|
||||||
is_deleted: bool = False
|
is_deleted: bool = False
|
||||||
created_at: datetime = Field(default_factory=datetime.now)
|
created_at: datetime = Field(default_factory=datetime.now)
|
||||||
|
|||||||
@@ -2,19 +2,30 @@ from enum import Enum
|
|||||||
|
|
||||||
|
|
||||||
class AspectRatios(str, Enum):
|
class AspectRatios(str, Enum):
|
||||||
NINESIXTEEN = "NINESIXTEEN"
|
ONEONE = "1:1"
|
||||||
SIXTEENNINE = "SIXTEENNINE"
|
TWOTHREE = "2:3"
|
||||||
THREEFOUR = "THREEFOUR"
|
THREETWO = "3:2"
|
||||||
FOURTHREE = "FOURTHREE"
|
THREEFOUR = "3:4"
|
||||||
|
FOURTHREE = "4:3"
|
||||||
|
FOURFIVE = "4:5"
|
||||||
|
FIVEFOUR = "5:4"
|
||||||
|
NINESIXTEEN = "9:16"
|
||||||
|
SIXTEENNINE = "16:9"
|
||||||
|
TWENTYONENINE = "21:9"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _missing_(cls, value):
|
||||||
|
mapping = {
|
||||||
|
"NINESIXTEEN": cls.NINESIXTEEN,
|
||||||
|
"SIXTEENNINE": cls.SIXTEENNINE,
|
||||||
|
"THREEFOUR": cls.THREEFOUR,
|
||||||
|
"FOURTHREE": cls.FOURTHREE,
|
||||||
|
}
|
||||||
|
return mapping.get(value)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def value_ratio(self) -> str:
|
def value_ratio(self) -> str:
|
||||||
return {
|
return self.value
|
||||||
AspectRatios.NINESIXTEEN: "9:16",
|
|
||||||
AspectRatios.SIXTEENNINE: "16:9",
|
|
||||||
AspectRatios.THREEFOUR: "3:4",
|
|
||||||
AspectRatios.FOURTHREE: "4:3",
|
|
||||||
}[self]
|
|
||||||
|
|
||||||
|
|
||||||
class Quality(str, Enum):
|
class Quality(str, Enum):
|
||||||
@@ -41,3 +52,20 @@ class GenType(str, Enum):
|
|||||||
GenType.TEXT: 'Text',
|
GenType.TEXT: 'Text',
|
||||||
GenType.IMAGE: 'Image',
|
GenType.IMAGE: 'Image',
|
||||||
}[self]
|
}[self]
|
||||||
|
|
||||||
|
|
||||||
|
class TextModel(str, Enum):
|
||||||
|
GEMINI_3_1_PRO_PREVIEW = "gemini-3.1-pro-preview"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def value_model(self) -> str:
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
|
||||||
|
class ImageModel(str, Enum):
|
||||||
|
GEMINI_3_PRO_IMAGE_PREVIEW = "gemini-3-pro-image-preview"
|
||||||
|
GEMINI_3_1_FLASH_IMAGE_PREVIEW = "gemini-3.1-flash-image-preview"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def value_model(self) -> str:
|
||||||
|
return self.value
|
||||||
|
|||||||
@@ -102,7 +102,7 @@ class AssetsRepo:
|
|||||||
|
|
||||||
return assets
|
return assets
|
||||||
|
|
||||||
async def get_asset(self, asset_id: str, with_data: bool = True) -> Asset:
|
async def get_asset(self, asset_id: str, with_data: bool = True) -> Optional[Asset]:
|
||||||
projection = None
|
projection = None
|
||||||
if not with_data:
|
if not with_data:
|
||||||
projection = {"data": 0, "thumbnail": 0}
|
projection = {"data": 0, "thumbnail": 0}
|
||||||
@@ -182,7 +182,9 @@ class AssetsRepo:
|
|||||||
return await self.collection.count_documents(filter)
|
return await self.collection.count_documents(filter)
|
||||||
|
|
||||||
async def get_assets_by_ids(self, asset_ids: List[str]) -> List[Asset]:
|
async def get_assets_by_ids(self, asset_ids: List[str]) -> List[Asset]:
|
||||||
object_ids = [ObjectId(asset_id) for asset_id in asset_ids]
|
object_ids = [ObjectId(asset_id) for asset_id in asset_ids if ObjectId.is_valid(asset_id)]
|
||||||
|
if not object_ids:
|
||||||
|
return []
|
||||||
res = self.collection.find({"_id": {"$in": object_ids}}, {"data": 0}) # Exclude data but maybe allow thumbnail if small?
|
res = self.collection.find({"_id": {"$in": object_ids}}, {"data": 0}) # Exclude data but maybe allow thumbnail if small?
|
||||||
# Original excluded thumbnail too.
|
# Original excluded thumbnail too.
|
||||||
assets = []
|
assets = []
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from repos.project_repo import ProjectRepo
|
|||||||
from repos.idea_repo import IdeaRepo
|
from repos.idea_repo import IdeaRepo
|
||||||
from repos.post_repo import PostRepo
|
from repos.post_repo import PostRepo
|
||||||
from repos.environment_repo import EnvironmentRepo
|
from repos.environment_repo import EnvironmentRepo
|
||||||
|
from repos.inspiration_repo import InspirationRepo
|
||||||
|
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@@ -25,3 +26,4 @@ class DAO:
|
|||||||
self.ideas = IdeaRepo(client, db_name)
|
self.ideas = IdeaRepo(client, db_name)
|
||||||
self.posts = PostRepo(client, db_name)
|
self.posts = PostRepo(client, db_name)
|
||||||
self.environments = EnvironmentRepo(client, db_name)
|
self.environments = EnvironmentRepo(client, db_name)
|
||||||
|
self.inspirations = InspirationRepo(client, db_name)
|
||||||
|
|||||||
@@ -26,7 +26,8 @@ class GenerationRepo:
|
|||||||
return Generation(**res)
|
return Generation(**res)
|
||||||
|
|
||||||
async def get_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
|
async def get_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
|
||||||
limit: int = 10, offset: int = 0, created_by: Optional[str] = None, project_id: Optional[str] = None, idea_id: Optional[str] = None) -> List[Generation]:
|
limit: int = 10, offset: int = 0, created_by: Optional[str] = None, project_id: Optional[str] = None,
|
||||||
|
idea_id: Optional[str] = None, only_liked_by: Optional[str] = None) -> List[Generation]:
|
||||||
|
|
||||||
filter: dict[str, Any] = {"is_deleted": False}
|
filter: dict[str, Any] = {"is_deleted": False}
|
||||||
if character_id is not None:
|
if character_id is not None:
|
||||||
@@ -43,6 +44,8 @@ class GenerationRepo:
|
|||||||
filter["project_id"] = project_id
|
filter["project_id"] = project_id
|
||||||
if idea_id is not None:
|
if idea_id is not None:
|
||||||
filter["idea_id"] = idea_id
|
filter["idea_id"] = idea_id
|
||||||
|
if only_liked_by is not None:
|
||||||
|
filter["liked_by"] = only_liked_by
|
||||||
|
|
||||||
# If fetching for an idea, sort by created_at ascending (cronological)
|
# If fetching for an idea, sort by created_at ascending (cronological)
|
||||||
# Otherwise typically descending (newest first)
|
# Otherwise typically descending (newest first)
|
||||||
@@ -57,7 +60,8 @@ class GenerationRepo:
|
|||||||
return generations
|
return generations
|
||||||
|
|
||||||
async def count_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
|
async def count_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
|
||||||
album_id: Optional[str] = None, created_by: Optional[str] = None, project_id: Optional[str] = None, idea_id: Optional[str] = None) -> int:
|
album_id: Optional[str] = None, created_by: Optional[str] = None, project_id: Optional[str] = None,
|
||||||
|
idea_id: Optional[str] = None, only_liked_by: Optional[str] = None) -> int:
|
||||||
args = {}
|
args = {}
|
||||||
if character_id is not None:
|
if character_id is not None:
|
||||||
args["linked_character_id"] = character_id
|
args["linked_character_id"] = character_id
|
||||||
@@ -73,6 +77,8 @@ class GenerationRepo:
|
|||||||
args["idea_id"] = idea_id
|
args["idea_id"] = idea_id
|
||||||
if album_id is not None:
|
if album_id is not None:
|
||||||
args["album_id"] = album_id
|
args["album_id"] = album_id
|
||||||
|
if only_liked_by is not None:
|
||||||
|
args["liked_by"] = only_liked_by
|
||||||
return await self.collection.count_documents(args)
|
return await self.collection.count_documents(args)
|
||||||
|
|
||||||
async def get_generations_by_ids(self, generation_ids: List[str]) -> List[Generation]:
|
async def get_generations_by_ids(self, generation_ids: List[str]) -> List[Generation]:
|
||||||
@@ -94,6 +100,47 @@ class GenerationRepo:
|
|||||||
async def update_generation(self, generation: Generation, ):
|
async def update_generation(self, generation: Generation, ):
|
||||||
res = await self.collection.update_one({"_id": ObjectId(generation.id)}, {"$set": generation.model_dump()})
|
res = await self.collection.update_one({"_id": ObjectId(generation.id)}, {"$set": generation.model_dump()})
|
||||||
|
|
||||||
|
async def toggle_like(self, generation_id: str, user_id: str) -> bool | None:
|
||||||
|
"""
|
||||||
|
Toggles like for a user on a generation.
|
||||||
|
Returns True if liked, False if unliked, None if generation not found.
|
||||||
|
"""
|
||||||
|
if not ObjectId.is_valid(generation_id):
|
||||||
|
return None
|
||||||
|
|
||||||
|
oid = ObjectId(generation_id)
|
||||||
|
|
||||||
|
# Check if generation exists
|
||||||
|
gen = await self.collection.find_one({"_id": oid}, {"liked_by": 1})
|
||||||
|
|
||||||
|
if not gen:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if user_id in gen.get("liked_by", []):
|
||||||
|
# Unlike
|
||||||
|
await self.collection.update_one(
|
||||||
|
{"_id": oid},
|
||||||
|
{"$pull": {"liked_by": user_id}}
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
# Like
|
||||||
|
await self.collection.update_one(
|
||||||
|
{"_id": oid},
|
||||||
|
{"$addToSet": {"liked_by": user_id}}
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def mark_nsfw(self, generation_id: str, is_nsfw: bool) -> bool:
|
||||||
|
if not ObjectId.is_valid(generation_id):
|
||||||
|
return False
|
||||||
|
|
||||||
|
res = await self.collection.update_one(
|
||||||
|
{"_id": ObjectId(generation_id)},
|
||||||
|
{"$set": {"nsfw": is_nsfw}}
|
||||||
|
)
|
||||||
|
return res.modified_count > 0
|
||||||
|
|
||||||
async def get_usage_stats(self, created_by: Optional[str] = None, project_id: Optional[str] = None) -> dict:
|
async def get_usage_stats(self, created_by: Optional[str] = None, project_id: Optional[str] = None) -> dict:
|
||||||
"""
|
"""
|
||||||
Calculates usage statistics (runs, tokens, cost) using MongoDB aggregation.
|
Calculates usage statistics (runs, tokens, cost) using MongoDB aggregation.
|
||||||
@@ -253,7 +300,6 @@ class GenerationRepo:
|
|||||||
cursor = self.collection.find(filter_query, {"result_list": 1, "assets_list": 1})
|
cursor = self.collection.find(filter_query, {"result_list": 1, "assets_list": 1})
|
||||||
async for doc in cursor:
|
async for doc in cursor:
|
||||||
asset_ids.extend(doc.get("result_list", []))
|
asset_ids.extend(doc.get("result_list", []))
|
||||||
asset_ids.extend(doc.get("assets_list", []))
|
|
||||||
|
|
||||||
# Мягкое удаление
|
# Мягкое удаление
|
||||||
res = await self.collection.update_many(
|
res = await self.collection.update_many(
|
||||||
|
|||||||
54
repos/inspiration_repo.py
Normal file
54
repos/inspiration_repo.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from bson import ObjectId
|
||||||
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
|
||||||
|
from models.Inspiration import Inspiration
|
||||||
|
|
||||||
|
|
||||||
|
class InspirationRepo:
|
||||||
|
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
|
||||||
|
self.collection = client[db_name]["inspirations"]
|
||||||
|
|
||||||
|
async def create_inspiration(self, inspiration: Inspiration) -> str:
|
||||||
|
res = await self.collection.insert_one(inspiration.model_dump(exclude={"id"}))
|
||||||
|
return str(res.inserted_id)
|
||||||
|
|
||||||
|
async def get_inspiration(self, inspiration_id: str) -> Optional[Inspiration]:
|
||||||
|
res = await self.collection.find_one({"_id": ObjectId(inspiration_id)})
|
||||||
|
if res:
|
||||||
|
res["id"] = str(res.pop("_id"))
|
||||||
|
return Inspiration(**res)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_inspirations(self, project_id: Optional[str] = None, created_by: Optional[str] = None, limit: int = 20, offset: int = 0) -> List[Inspiration]:
|
||||||
|
query = {}
|
||||||
|
if project_id:
|
||||||
|
query["project_id"] = project_id
|
||||||
|
if created_by:
|
||||||
|
query["created_by"] = created_by
|
||||||
|
|
||||||
|
cursor = self.collection.find(query).sort("created_at", -1).skip(offset).limit(limit)
|
||||||
|
inspirations = []
|
||||||
|
async for doc in cursor:
|
||||||
|
doc["id"] = str(doc.pop("_id"))
|
||||||
|
inspirations.append(Inspiration(**doc))
|
||||||
|
return inspirations
|
||||||
|
|
||||||
|
async def count_inspirations(self, project_id: Optional[str] = None, created_by: Optional[str] = None) -> int:
|
||||||
|
query = {}
|
||||||
|
if project_id:
|
||||||
|
query["project_id"] = project_id
|
||||||
|
if created_by:
|
||||||
|
query["created_by"] = created_by
|
||||||
|
return await self.collection.count_documents(query)
|
||||||
|
|
||||||
|
async def update_inspiration(self, inspiration: Inspiration):
|
||||||
|
await self.collection.update_one(
|
||||||
|
{"_id": ObjectId(inspiration.id)},
|
||||||
|
{"$set": inspiration.model_dump(exclude={"id"})}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def delete_inspiration(self, inspiration_id: str) -> bool:
|
||||||
|
res = await self.collection.delete_one({"_id": ObjectId(inspiration_id)})
|
||||||
|
return res.deleted_count > 0
|
||||||
@@ -52,3 +52,4 @@ python-multipart==0.0.22
|
|||||||
email-validator
|
email-validator
|
||||||
prometheus-fastapi-instrumentator
|
prometheus-fastapi-instrumentator
|
||||||
pydantic-settings==2.13.0
|
pydantic-settings==2.13.0
|
||||||
|
yt-dlp
|
||||||
|
|||||||
@@ -126,12 +126,11 @@ async def change_char(call: CallbackQuery, state: FSMContext, dao: DAO):
|
|||||||
@router.callback_query(States.gen_mode, F.data == 'gen_mode_change_aspect_ratio')
|
@router.callback_query(States.gen_mode, F.data == 'gen_mode_change_aspect_ratio')
|
||||||
async def gen_mode_change_aspect_ratio(call: CallbackQuery, state: FSMContext, dao: DAO):
|
async def gen_mode_change_aspect_ratio(call: CallbackQuery, state: FSMContext, dao: DAO):
|
||||||
await call.answer()
|
await call.answer()
|
||||||
keyboards = []
|
buttons = [InlineKeyboardButton(text=ratio.value, callback_data=f'select_ratio_{ratio.name}') for ratio in AspectRatios]
|
||||||
for ratio in AspectRatios:
|
keyboard_rows = [buttons[i:i + 4] for i in range(0, len(buttons), 4)]
|
||||||
keyboards.append(InlineKeyboardButton(text=ratio.value, callback_data=f'select_ratio_{ratio.name}'))
|
keyboard_rows.append([InlineKeyboardButton(text="⬅️ Назад", callback_data="gen_mode_cancel_ratio_change")])
|
||||||
await call.message.edit_caption(caption="Выбери соотношение сторон",
|
await call.message.edit_caption(caption="Выбери соотношение сторон",
|
||||||
reply_markup=InlineKeyboardMarkup(inline_keyboard=[keyboards, [InlineKeyboardButton(
|
reply_markup=InlineKeyboardMarkup(inline_keyboard=keyboard_rows))
|
||||||
text="⬅️ Назад", callback_data="gen_mode_cancel_ratio_change")]]))
|
|
||||||
|
|
||||||
|
|
||||||
@router.callback_query(States.gen_mode, F.data.startswith('select_ratio_'))
|
@router.callback_query(States.gen_mode, F.data.startswith('select_ratio_'))
|
||||||
|
|||||||
Reference in New Issue
Block a user