13 Commits

Author SHA1 Message Date
xds
e011805186 models + refactor 2026-02-27 20:37:24 +03:00
xds
d9caececd7 nsfw mark api 2026-02-27 14:44:14 +03:00
xds
c1300b7a2d nsfw mark api 2026-02-27 14:33:37 +03:00
xds
f6001f5994 nsfw mark api 2026-02-27 13:51:22 +03:00
xds
e4a39e90c3 nsfw mark api 2026-02-27 09:08:48 +03:00
xds
e976fe1c58 inspirations 2026-02-26 11:26:18 +03:00
xds
ecc8d69039 inspirations 2026-02-24 16:42:46 +03:00
xds
bc9230a49b fixes 2026-02-24 12:11:19 +03:00
xds
f07105b0e5 fixes 2026-02-21 18:31:28 +03:00
xds
9a5d54a373 fixes 2026-02-20 17:54:57 +03:00
xds
1868864f76 fixes 2026-02-20 13:10:37 +03:00
xds
9e0c522b5f fixes (original message "ашчуы" — "fixes" typed on a Russian keyboard layout) 2026-02-20 10:28:56 +03:00
xds
e1d941a2cd + env 2026-02-20 02:02:13 +03:00
45 changed files with 1248 additions and 736 deletions

33
.context.md Normal file
View File

@@ -0,0 +1,33 @@
# Project Context: AI Char Bot
## Overview
Python backend project using FastAPI and MongoDB (Motor).
Root: `/Users/xds/develop/py projects/ai-char-bot`
## Architecture
- **API Layer**: `api/endpoints` (FastAPI routers).
- **Service Layer**: `api/service` (Business logic).
- **Data Layer**: `repos` (DAOs and Repositories).
- **Models**: `models` (Domain models) and `api/models` (Request/Response DTOs).
- **Adapters**: `adapters` (External services like S3, Google Gemini).
## Coding Standards & Preferences
- **Type Hinting**: Use `Type | None` instead of `Optional[Type]` (Python 3.10+ style).
- **Async/Await**: Extensive use of `asyncio` and asynchronous DB drivers.
- **Error Handling**:
- Repositories should return `None` if an entity is not found (e.g., `toggle_like`).
- Services/Routers handle `HTTPException`.
## Key Features & Implementation Details
- **Generations**:
- Managed by `GenerationService` and `GenerationRepo`.
- `toggle_like` returns `bool | None` (True=Liked, False=Unliked, None=Not Found).
- `get_generations` requires `current_user_id` to correctly calculate `is_liked`.
- **Ideas**:
- Managed by `IdeaService` and `IdeaRepo`.
- Can have linked generations.
- When fetching generations for an idea, ensure `current_user_id` is passed to `GenerationService`.
## Recent Changes
- Refactored `toggle_like` to handle non-existent generations and return `bool | None`.
- Updated `IdeaRouter` to pass `current_user_id` when fetching generations to ensure `is_liked` flag is correct.

33
.gemini/AGENTS.md Normal file
View File

@@ -0,0 +1,33 @@
# Project Context: AI Char Bot
## Overview
Python backend project using FastAPI and MongoDB (Motor).
Root: `/Users/xds/develop/py projects/ai-char-bot`
## Architecture
- **API Layer**: `api/endpoints` (FastAPI routers).
- **Service Layer**: `api/service` (Business logic).
- **Data Layer**: `repos` (DAOs and Repositories).
- **Models**: `models` (Domain models) and `api/models` (Request/Response DTOs).
- **Adapters**: `adapters` (External services like S3, Google Gemini).
## Coding Standards & Preferences
- **Type Hinting**: Use `Type | None` instead of `Optional[Type]` (Python 3.10+ style).
- **Async/Await**: Extensive use of `asyncio` and asynchronous DB drivers.
- **Error Handling**:
- Repositories should return `None` if an entity is not found (e.g., `toggle_like`).
- Services/Routers handle `HTTPException`.
## Key Features & Implementation Details
- **Generations**:
- Managed by `GenerationService` and `GenerationRepo`.
- `toggle_like` returns `bool | None` (True=Liked, False=Unliked, None=Not Found).
- `get_generations` requires `current_user_id` to correctly calculate `is_liked`.
- **Ideas**:
- Managed by `IdeaService` and `IdeaRepo`.
- Can have linked generations.
- When fetching generations for an idea, ensure `current_user_id` is passed to `GenerationService`.
## Recent Changes
- Refactored `toggle_like` to handle non-existent generations and return `bool | None`.
- Updated `IdeaRouter` to pass `current_user_id` when fetching generations to ensure `is_liked` flag is correct.

View File

@@ -8,7 +8,7 @@ from google import genai
from google.genai import types
from adapters.Exception import GoogleGenerationException
from models.enums import AspectRatios, Quality
from models.enums import AspectRatios, Quality, TextModel, ImageModel
logger = logging.getLogger(__name__)
@@ -19,10 +19,6 @@ class GoogleAdapter:
raise ValueError("API Key for Gemini is missing")
self.client = genai.Client(api_key=api_key)
# Константы моделей
self.TEXT_MODEL = "gemini-3-pro-preview"
self.IMAGE_MODEL = "gemini-3-pro-image-preview"
def _prepare_contents(self, prompt: str, images_list: List[bytes] | None = None) -> tuple:
"""Вспомогательный метод для подготовки контента (текст + картинки).
Returns (contents, opened_images) — caller MUST close opened_images after use."""
@@ -41,16 +37,19 @@ class GoogleAdapter:
logger.info("Preparing content with no images")
return contents, opened_images
def generate_text(self, prompt: str, images_list: List[bytes] | None = None) -> str:
def generate_text(self, prompt: str, model: str = "gemini-3.1-pro-preview", images_list: List[bytes] | None = None) -> str:
"""
Генерация текста (Чат или Vision).
Возвращает строку с ответом.
"""
if model not in [m.value for m in TextModel]:
raise ValueError(f"Invalid model for text generation: {model}. Expected one of: {[m.value for m in TextModel]}")
contents, opened_images = self._prepare_contents(prompt, images_list)
logger.info(f"Generating text: {prompt}")
logger.info(f"Generating text: {prompt} with model: {model}")
try:
response = self.client.models.generate_content(
model=self.TEXT_MODEL,
model=model,
contents=contents,
config=types.GenerateContentConfig(
response_modalities=['TEXT'],
@@ -74,21 +73,23 @@ class GoogleAdapter:
for img in opened_images:
img.close()
def generate_image(self, prompt: str, aspect_ratio: AspectRatios, quality: Quality, images_list: List[bytes] | None = None, ) -> Tuple[List[io.BytesIO], Dict[str, Any]]:
def generate_image(self, prompt: str, aspect_ratio: AspectRatios, quality: Quality, model: str = "gemini-3-pro-image-preview", images_list: List[bytes] | None = None, ) -> Tuple[List[io.BytesIO], Dict[str, Any]]:
"""
Генерация изображений (Text-to-Image или Image-to-Image).
Возвращает список байтовых потоков (готовых к отправке).
"""
if model not in [m.value for m in ImageModel]:
raise ValueError(f"Invalid model for image generation: {model}. Expected one of: {[m.value for m in ImageModel]}")
contents, opened_images = self._prepare_contents(prompt, images_list)
logger.info(f"Generating image. Prompt length: {len(prompt)}, Ratio: {aspect_ratio}, Quality: {quality}")
logger.info(f"Generating image. Prompt length: {len(prompt)}, Ratio: {aspect_ratio}, Quality: {quality}, Model: {model}")
start_time = datetime.now()
token_usage = 0
try:
response = self.client.models.generate_content(
model=self.IMAGE_MODEL,
model=model,
contents=contents,
config=types.GenerateContentConfig(
response_modalities=['IMAGE'],

View File

@@ -1,5 +1,5 @@
from contextlib import asynccontextmanager
from typing import Optional, BinaryIO
from typing import Optional, BinaryIO, AsyncGenerator
import aioboto3
from botocore.exceptions import ClientError
import os
@@ -56,11 +56,25 @@ class S3Adapter:
print(f"Error downloading from S3: {e}")
return None
async def stream_file(self, object_name: str, chunk_size: int = 65536):
async def get_file_size(self, object_name: str) -> Optional[int]:
"""Returns the size of the file in bytes."""
try:
async with self._get_client() as client:
response = await client.head_object(Bucket=self.bucket_name, Key=object_name)
return response['ContentLength']
except ClientError as e:
print(f"Error getting file size from S3: {e}")
return None
async def stream_file(self, object_name: str, range_header: Optional[str] = None, chunk_size: int = 65536) -> AsyncGenerator[bytes, None]:
"""Streams a file from S3 yielding chunks. Memory-efficient for large files."""
try:
async with self._get_client() as client:
response = await client.get_object(Bucket=self.bucket_name, Key=object_name)
args = {'Bucket': self.bucket_name, 'Key': object_name}
if range_header:
args['Range'] = range_header
response = await client.get_object(**args)
# aioboto3 Body is an aiohttp StreamReader wrapper
body = response['Body']

View File

@@ -45,6 +45,7 @@ from api.endpoints.project_router import router as project_api_router
from api.endpoints.idea_router import router as idea_api_router
from api.endpoints.post_router import router as post_api_router
from api.endpoints.environment_router import router as environment_api_router
from api.endpoints.inspiration_router import router as inspiration_api_router
logger = logging.getLogger(__name__)
@@ -133,7 +134,7 @@ async def start_scheduler(service: GenerationService):
try:
logger.info("Running scheduler for stacked generation killing")
await service.cleanup_stale_generations()
await service.cleanup_old_data(days=2)
await service.cleanup_old_data(days=14)
except asyncio.CancelledError:
break
except Exception as e:
@@ -222,6 +223,7 @@ app.include_router(project_api_router)
app.include_router(idea_api_router)
app.include_router(post_api_router)
app.include_router(environment_api_router)
app.include_router(inspiration_api_router)
# Prometheus Metrics (Instrument after all routers are added)
Instrumentator(

View File

@@ -62,4 +62,9 @@ async def get_album_service(dao: DAO = Depends(get_dao)) -> AlbumService:
from api.service.post_service import PostService
def get_post_service(dao: DAO = Depends(get_dao)) -> PostService:
return PostService(dao)
return PostService(dao)
from api.service.inspiration_service import InspirationService
def get_inspiration_service(dao: DAO = Depends(get_dao), s3_adapter: S3Adapter = Depends(get_s3_adapter)) -> InspirationService:
return InspirationService(dao, s3_adapter)

View File

@@ -1,4 +1,4 @@
from typing import Annotated, List
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
@@ -54,7 +54,7 @@ class UserResponse(BaseModel):
class Config:
from_attributes = True
@router.get("/approvals", response_model=List[UserResponse])
@router.get("/approvals", response_model=list[UserResponse])
async def list_pending_users(
admin: Annotated[dict, Depends(get_current_admin)],
repo: Annotated[UsersRepo, Depends(get_users_repo)]

View File

@@ -1,4 +1,3 @@
from typing import List, Optional
from fastapi import APIRouter, HTTPException, status, Request
from pydantic import BaseModel
@@ -13,18 +12,18 @@ router = APIRouter(prefix="/api/albums", tags=["Albums"])
class AlbumCreateRequest(BaseModel):
name: str
description: Optional[str] = None
description: str | None = None
class AlbumUpdateRequest(BaseModel):
name: Optional[str] = None
description: Optional[str] = None
name: str | None = None
description: str | None = None
class AlbumResponse(BaseModel):
id: str
name: str
description: Optional[str] = None
generation_ids: List[str] = []
cover_asset_id: Optional[str] = None # Not implemented yet
description: str | None = None
generation_ids: list[str] = []
cover_asset_id: str | None = None # Not implemented yet
@router.post("", response_model=AlbumResponse)
async def create_album(request: Request, album_in: AlbumCreateRequest):
@@ -32,7 +31,7 @@ async def create_album(request: Request, album_in: AlbumCreateRequest):
album = await service.create_album(name=album_in.name, description=album_in.description)
return AlbumResponse(**album.model_dump())
@router.get("", response_model=List[AlbumResponse])
@router.get("", response_model=list[AlbumResponse])
async def get_albums(request: Request, limit: int = 10, offset: int = 0):
service: AlbumService = request.app.state.album_service
albums = await service.get_albums(limit=limit, offset=offset)
@@ -77,7 +76,7 @@ async def remove_generation_from_album(request: Request, album_id: str, generati
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album or Generation not found")
return {"status": "success"}
@router.get("/{album_id}/generations", response_model=List[GenerationResponse])
@router.get("/{album_id}/generations", response_model=list[GenerationResponse])
async def get_album_generations(request: Request, album_id: str, limit: int = 10, offset: int = 0):
service: AlbumService = request.app.state.album_service
generations = await service.get_generations_by_album(album_id, limit=limit, offset=offset)

View File

@@ -1,4 +1,4 @@
from typing import List, Optional, Dict, Any
from typing import Any
from aiogram.types import BufferedInputFile
from bson import ObjectId
@@ -42,8 +42,9 @@ async def get_asset(
if not asset:
raise HTTPException(status_code=404, detail="Asset not found")
headers = {
"Cache-Control": "public, max-age=31536000, immutable"
base_headers = {
"Cache-Control": "public, max-age=31536000, immutable",
"Accept-Ranges": "bytes"
}
# Thumbnail: маленький, можно грузить в RAM
@@ -51,17 +52,70 @@ async def get_asset(
if asset.minio_thumbnail_object_name and s3_adapter:
thumb_bytes = await s3_adapter.get_file(asset.minio_thumbnail_object_name)
if thumb_bytes:
return Response(content=thumb_bytes, media_type="image/jpeg", headers=headers)
return Response(content=thumb_bytes, media_type="image/jpeg", headers=base_headers)
# Fallback: thumbnail in DB
if asset.thumbnail:
return Response(content=asset.thumbnail, media_type="image/jpeg", headers=headers)
return Response(content=asset.thumbnail, media_type="image/jpeg", headers=base_headers)
# No thumbnail available — fall through to main content
# Main content: стримим из S3 без загрузки в RAM
if asset.minio_object_name and s3_adapter:
content_type = "image/png"
# if asset.content_type == AssetContentType.VIDEO:
# content_type = "video/mp4"
if asset.content_type == AssetContentType.VIDEO:
content_type = "video/mp4" # Or detect from extension if stored
elif asset.content_type == AssetContentType.IMAGE:
content_type = "image/png" # Default for images
# Better content type detection based on extension if possible, but for now this is okay
if asset.minio_object_name.endswith(".mp4"):
content_type = "video/mp4"
elif asset.minio_object_name.endswith(".mov"):
content_type = "video/quicktime"
elif asset.minio_object_name.endswith(".jpg") or asset.minio_object_name.endswith(".jpeg"):
content_type = "image/jpeg"
# Handle Range requests for video streaming
range_header = request.headers.get("range")
file_size = await s3_adapter.get_file_size(asset.minio_object_name)
if range_header and file_size:
try:
# Parse Range header: bytes=start-end
byte_range = range_header.replace("bytes=", "")
start_str, end_str = byte_range.split("-")
start = int(start_str)
end = int(end_str) if end_str else file_size - 1
# Validate range
if start >= file_size:
# 416 Range Not Satisfiable
return Response(status_code=416, headers={"Content-Range": f"bytes */{file_size}"})
chunk_size = end - start + 1
headers = base_headers.copy()
headers.update({
"Content-Range": f"bytes {start}-{end}/{file_size}",
"Content-Length": str(chunk_size),
})
# Pass the exact range string to S3
s3_range = f"bytes={start}-{end}"
return StreamingResponse(
s3_adapter.stream_file(asset.minio_object_name, range_header=s3_range),
status_code=206,
headers=headers,
media_type=content_type
)
except ValueError:
pass # Fallback to full content if range parsing fails
# Full content response
headers = base_headers.copy()
if file_size:
headers["Content-Length"] = str(file_size)
return StreamingResponse(
s3_adapter.stream_file(asset.minio_object_name),
media_type=content_type,
@@ -70,7 +124,7 @@ async def get_asset(
# Fallback: data stored in DB (legacy)
if asset.data:
return Response(content=asset.data, media_type="image/png", headers=headers)
return Response(content=asset.data, media_type="image/png", headers=base_headers)
raise HTTPException(status_code=404, detail="Asset data not found")
@@ -81,22 +135,22 @@ async def delete_orphan_assets_from_minio(
*,
assets_collection: str = "assets",
generations_collection: str = "generations",
asset_type: Optional[str] = "generated",
project_id: Optional[str] = None,
asset_type: str | None = "generated",
project_id: str | None = None,
dry_run: bool = True,
mark_assets_deleted: bool = False,
batch_size: int = 500,
) -> Dict[str, Any]:
) -> dict[str, Any]:
db = mongo['bot_db'] # БД уже выбрана в get_mongo_client
assets = db[assets_collection]
match_assets: Dict[str, Any] = {}
match_assets: dict[str, Any] = {}
if asset_type is not None:
match_assets["type"] = asset_type
if project_id is not None:
match_assets["project_id"] = project_id
pipeline: List[Dict[str, Any]] = [
pipeline: list[dict[str, Any]] = [
{"$match": match_assets} if match_assets else {"$match": {}},
{
"$lookup": {
@@ -138,8 +192,8 @@ async def delete_orphan_assets_from_minio(
deleted_objects = 0
deleted_assets = 0
errors: List[Dict[str, Any]] = []
orphan_asset_ids: List[ObjectId] = []
errors: list[dict[str, Any]] = []
orphan_asset_ids: list[ObjectId] = []
async for asset in cursor:
aid = asset["_id"]
@@ -205,7 +259,7 @@ async def delete_asset(
@router.get("", dependencies=[Depends(get_current_user)])
async def get_assets(request: Request, dao: DAO = Depends(get_dao), type: Optional[str] = None, limit: int = 10, offset: int = 0, current_user: dict = Depends(get_current_user), project_id: Optional[str] = Depends(get_project_id)) -> AssetsResponse:
async def get_assets(request: Request, dao: DAO = Depends(get_dao), type: str | None = None, limit: int = 10, offset: int = 0, current_user: dict = Depends(get_current_user), project_id: str | None = Depends(get_project_id)) -> AssetsResponse:
logger.info(f"get_assets called. Limit: {limit}, Offset: {offset}")
user_id_filter = current_user["id"]
@@ -232,10 +286,10 @@ async def get_assets(request: Request, dao: DAO = Depends(get_dao), type: Option
@router.post("/upload", response_model=AssetResponse, status_code=status.HTTP_201_CREATED)
async def upload_asset(
file: UploadFile = File(...),
linked_char_id: Optional[str] = Form(None),
linked_char_id: str | None = Form(None),
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id)
project_id: str | None = Depends(get_project_id)
):
logger.info(f"upload_asset called. Filename: {file.filename}, ContentType: {file.content_type}, LinkedCharId: {linked_char_id}")
if not file.content_type:

View File

@@ -1,4 +1,4 @@
from typing import List, Any, Coroutine, Optional
from typing import Any, Coroutine
from fastapi import APIRouter, Depends
from pydantic import BaseModel
@@ -23,15 +23,15 @@ from api.dependency import get_project_id
router = APIRouter(prefix="/api/characters", tags=["Characters"], dependencies=[Depends(get_current_user)])
@router.get("/", response_model=List[Character])
@router.get("/", response_model=list[Character])
async def get_characters(
request: Request,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
project_id: str | None = Depends(get_project_id),
limit: int = 100,
offset: int = 0
) -> List[Character]:
) -> list[Character]:
logger.info(f"get_characters called. Limit: {limit}, Offset: {offset}")
user_id_filter = str(current_user["_id"])
@@ -102,7 +102,7 @@ async def get_character_by_id(character_id: str, request: Request, dao: DAO = De
@router.post("/", response_model=Character)
async def create_character(
char_req: CharacterCreateRequest,
project_id: Optional[str] = Depends(get_project_id),
project_id: str | None = Depends(get_project_id),
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
) -> Character:

View File

@@ -1,5 +1,4 @@
import logging
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException
from starlette import status
@@ -50,7 +49,7 @@ async def create_environment(
return created_env
@router.get("/character/{character_id}", response_model=List[Environment])
@router.get("/character/{character_id}", response_model=list[Environment])
async def get_character_environments(
character_id: str,
dao: DAO = Depends(get_dao),
@@ -91,6 +90,18 @@ async def update_environment(
update_data = env_update.model_dump(exclude_unset=True)
if not update_data:
return env
# Verify assets exist if provided
if "asset_ids" in update_data:
if update_data["asset_ids"] is None:
del update_data["asset_ids"]
elif update_data["asset_ids"]:
# Verify all assets exist using batch check
assets = await dao.assets.get_assets_by_ids(update_data["asset_ids"])
if len(assets) != len(update_data["asset_ids"]):
found_ids = {a.id for a in assets}
missing_ids = [aid for aid in update_data["asset_ids"] if aid not in found_ids]
raise HTTPException(status_code=400, detail=f"Some assets not found: {missing_ids}")
success = await dao.environments.update_env(env_id, update_data)
if not success:

View File

@@ -1,7 +1,5 @@
import logging
import os
import json
from typing import List, Optional
from fastapi import APIRouter, UploadFile, File, Form, Header, HTTPException
from fastapi.params import Depends
@@ -19,7 +17,8 @@ from api.models import (
PromptRequest,
GenerationGroupResponse,
FinancialReport,
ExternalGenerationRequest
ExternalGenerationRequest,
NsfwRequest
)
from api.service.generation_service import GenerationService
from repos.dao import DAO
@@ -30,85 +29,88 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix='/api/generations', tags=["Generation"])
async def check_project_access(project_id: str | None, current_user: dict, dao: DAO):
"""Helper to check if user has access to project."""
if not project_id:
return
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
@router.post("/prompt-assistant", response_model=PromptResponse)
async def ask_prompt_assistant(prompt_request: PromptRequest, request: Request,
generation_service: GenerationService = Depends(
get_generation_service),
current_user: dict = Depends(get_current_user)) -> PromptResponse:
logger.info(f"ask_prompt_assistant called with prompt length: {len(prompt_request.prompt)}. Linked assets: {len(prompt_request.linked_assets) if prompt_request.linked_assets else 0}")
generated_prompt = await generation_service.ask_prompt_assistant(prompt_request.prompt, prompt_request.linked_assets)
async def ask_prompt_assistant(
prompt_request: PromptRequest,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)
) -> PromptResponse:
logger.info(f"ask_prompt_assistant: {len(prompt_request.prompt)} chars")
generated_prompt = await generation_service.ask_prompt_assistant(
prompt_request.prompt,
prompt_request.linked_assets,
prompt_request.model
)
return PromptResponse(prompt=generated_prompt)
@router.post("/prompt-from-image", response_model=PromptResponse)
async def prompt_from_image(
prompt: Optional[str] = Form(None),
images: List[UploadFile] = File(...),
prompt: str | None = Form(None),
model: str = Form("gemini-3.1-pro-preview"),
images: list[UploadFile] = File(...),
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)
) -> PromptResponse:
logger.info(f"prompt_from_image called. Images count: {len(images)}. Prompt provided: {bool(prompt)}")
images_bytes = []
for image in images:
content = await image.read()
images_bytes.append(content)
generated_prompt = await generation_service.generate_prompt_from_images(images_bytes, prompt)
images_bytes = [await img.read() for img in images]
generated_prompt = await generation_service.generate_prompt_from_images(images_bytes, prompt, model)
return PromptResponse(prompt=generated_prompt)
@router.get("", response_model=GenerationsResponse)
async def get_generations(character_id: Optional[str] = None, limit: int = 10, offset: int = 0,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao)):
logger.info(f"get_generations called. CharacterId: {character_id}, Limit: {limit}, Offset: {offset}")
async def get_generations(
character_id: str | None = None,
limit: int = 10,
offset: int = 0,
only_liked: bool = False,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user),
project_id: str | None = Depends(get_project_id),
dao: DAO = Depends(get_dao)
):
await check_project_access(project_id, current_user, dao)
user_id_filter = str(current_user["_id"])
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
user_id_filter = None # Show all project generations
# If project_id is set, we don't filter by user to show all project-wide generations
created_by_filter = None if project_id else str(current_user["_id"])
only_liked_by = str(current_user["_id"]) if only_liked else None
return await generation_service.get_generations(character_id, limit=limit, offset=offset, user_id=user_id_filter, project_id=project_id)
return await generation_service.get_generations(
character_id=character_id,
limit=limit,
offset=offset,
created_by=created_by_filter,
project_id=project_id,
only_liked_by=only_liked_by,
current_user_id=str(current_user["_id"])
)
@router.get("/usage", response_model=FinancialReport)
async def get_usage_report(
breakdown: Optional[str] = None, # "user" or "project"
breakdown: str | None = None, # "user" or "project"
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
project_id: str | None = Depends(get_project_id),
dao: DAO = Depends(get_dao)
) -> FinancialReport:
"""
Returns usage statistics (runs, tokens, cost) for the current user or project.
If project_id is provided, returns stats for that project.
Otherwise, returns stats for the current user.
"""
user_id_filter = str(current_user["_id"])
await check_project_access(project_id, current_user, dao)
user_id_filter = str(current_user["_id"]) if not project_id else None
breakdown_by = None
if project_id:
# Permission check
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
user_id_filter = None # If we are in project, we see stats for the WHOLE project by default
if breakdown == "user":
breakdown_by = "created_by"
elif breakdown == "project":
breakdown_by = "project_id"
else:
# Default: Stats for current user
if breakdown == "project":
breakdown_by = "project_id"
elif breakdown == "user":
# This would breakdown personal usage by user (yourself), but could be useful if it included collaborators?
# No, if project_id is None, it's personal.
breakdown_by = "created_by"
if breakdown == "user":
breakdown_by = "created_by"
elif breakdown == "project":
breakdown_by = "project_id"
return await generation_service.get_financial_report(
user_id=user_id_filter,
@@ -116,58 +118,61 @@ async def get_usage_report(
breakdown_by=breakdown_by
)
@router.post("/_run", response_model=GenerationGroupResponse)
async def post_generation(generation: GenerationRequest, request: Request,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao)) -> GenerationGroupResponse:
logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}")
@router.post("/_run", response_model=GenerationGroupResponse)
async def post_generation(
generation: GenerationRequest,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user),
project_id: str | None = Depends(get_project_id),
dao: DAO = Depends(get_dao)
) -> GenerationGroupResponse:
await check_project_access(project_id, current_user, dao)
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
generation.project_id = project_id
return await generation_service.create_generation_task(generation, user_id=str(current_user.get("_id")))
return await generation_service.create_generation_task(
generation,
user_id=str(current_user.get("_id"))
)
@router.get("/running")
async def get_running_generations(request: Request,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao)):
user_id_filter = str(current_user["_id"])
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
user_id_filter = None
async def get_running_generations(
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user),
project_id: str | None = Depends(get_project_id),
dao: DAO = Depends(get_dao)
):
await check_project_access(project_id, current_user, dao)
user_id_filter = None if project_id else str(current_user["_id"])
return await generation_service.get_running_generations(user_id=user_id_filter, project_id=project_id)
return await generation_service.get_running_generations(
user_id=user_id_filter,
project_id=project_id
)
@router.get("/group/{group_id}", response_model=GenerationGroupResponse)
async def get_generation_group(group_id: str,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)):
logger.info(f"get_generation_group called for group_id: {group_id}")
generations = await generation_service.dao.generations.get_generations_by_group(group_id)
gen_responses = [GenerationResponse(**gen.model_dump()) for gen in generations]
return GenerationGroupResponse(generation_group_id=group_id, generations=gen_responses)
async def get_generation_group(
group_id: str,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)
):
return await generation_service.get_generations_by_group(group_id, current_user_id=str(current_user["_id"]))
@router.get("/{generation_id}", response_model=GenerationResponse)
async def get_generation(generation_id: str,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)) -> GenerationResponse:
logger.debug(f"get_generation called for ID: {generation_id}")
gen = await generation_service.get_generation(generation_id)
if gen and gen.created_by != str(current_user["_id"]):
async def get_generation(
generation_id: str,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)
) -> GenerationResponse:
gen = await generation_service.get_generation(generation_id, current_user_id=str(current_user["_id"]))
if not gen:
raise HTTPException(status_code=404, detail="Generation not found")
if gen.created_by != str(current_user["_id"]):
# Check project membership
is_member = False
if gen.project_id:
@@ -180,6 +185,41 @@ async def get_generation(generation_id: str,
return gen
@router.post("/{generation_id}/like", response_model=dict)
async def toggle_like(
generation_id: str,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)
):
is_liked = await generation_service.toggle_like(generation_id, str(current_user["_id"]))
if is_liked is None:
raise HTTPException(status_code=404, detail="Generation not found")
return {"is_liked": is_liked}
@router.post("/{generation_id}/nsfw", status_code=status.HTTP_204_NO_CONTENT)
async def mark_generation_nsfw(
    generation_id: str,
    request: NsfwRequest,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user)
):
    """Set or clear the NSFW flag on a generation.

    Allowed for the generation's creator, or for members of the
    generation's project; everyone else gets 403.
    """
    user_id = str(current_user["_id"])
    gen = await generation_service.get_generation(generation_id, current_user_id=user_id)
    if gen is None:
        raise HTTPException(status_code=404, detail="Generation not found")
    if gen.created_by != user_id:
        # Non-creators must belong to the generation's project (if any).
        allowed = False
        if gen.project_id:
            project = await generation_service.dao.projects.get_project(gen.project_id)
            allowed = bool(project) and user_id in project.members
        if not allowed:
            raise HTTPException(status_code=403, detail="Access denied")
    await generation_service.dao.generations.mark_nsfw(generation_id, request.is_nsfw)
    return None
@router.post("/import", response_model=GenerationResponse)
@@ -188,35 +228,18 @@ async def import_external_generation(
generation_service: GenerationService = Depends(get_generation_service),
x_signature: str = Header(..., alias="X-Signature")
) -> GenerationResponse:
"""
Import a generation from an external source.
Requires server-to-server authentication via HMAC signature.
"""
logger.info("import_external_generation called")
# Get raw request body for signature verification
body = await request.body()
# Verify signature
secret = settings.EXTERNAL_API_SECRET
if not secret:
logger.error("EXTERNAL_API_SECRET not configured")
raise HTTPException(status_code=500, detail="Server configuration error")
if not verify_signature(body, x_signature, secret):
logger.warning("Invalid signature for external generation import")
raise HTTPException(status_code=401, detail="Invalid signature")
# Parse request body
try:
data = json.loads(body.decode('utf-8'))
external_gen = ExternalGenerationRequest(**data)
except Exception as e:
logger.error(f"Failed to parse request body: {e}")
raise HTTPException(status_code=400, detail=f"Invalid request body: {str(e)}")
# Import generation
try:
generation = await generation_service.import_external_generation(external_gen)
return GenerationResponse(**generation.model_dump())
except Exception as e:
@@ -225,11 +248,11 @@ async def import_external_generation(
@router.delete("/{generation_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_generation(generation_id: str,
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)):
logger.info(f"delete_generation called for ID: {generation_id}")
deleted = await generation_service.delete_generation(generation_id)
if not deleted:
async def delete_generation(
    generation_id: str,
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user)
):
    """Delete a generation by id; 404 when no such record exists."""
    deleted = await generation_service.delete_generation(generation_id)
    if not deleted:
        raise HTTPException(status_code=404, detail="Generation not found")
    return None

View File

@@ -1,4 +1,3 @@
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Body
from api.dependency import get_idea_service, get_project_id, get_generation_service
from api.endpoints.auth import get_current_user
@@ -14,17 +13,23 @@ router = APIRouter(prefix="/api/ideas", tags=["ideas"])
@router.post("", response_model=Idea)
async def create_idea(
request: IdeaCreateRequest,
project_id: Optional[str] = Depends(get_project_id),
project_id: str | None = Depends(get_project_id),
current_user: dict = Depends(get_current_user),
idea_service: IdeaService = Depends(get_idea_service)
):
pid = project_id or request.project_id
return await idea_service.create_idea(request.name, request.description, pid, str(current_user["_id"]))
return await idea_service.create_idea(
name=request.name,
description=request.description,
project_id=pid,
user_id=str(current_user["_id"]),
inspiration_id=request.inspiration_id
)
@router.get("", response_model=List[IdeaResponse])
@router.get("", response_model=list[IdeaResponse])
async def get_ideas(
project_id: Optional[str] = Depends(get_project_id),
project_id: str | None = Depends(get_project_id),
limit: int = 20,
offset: int = 0,
current_user: dict = Depends(get_current_user),
@@ -48,7 +53,12 @@ async def update_idea(
request: IdeaUpdateRequest,
idea_service: IdeaService = Depends(get_idea_service)
):
idea = await idea_service.update_idea(idea_id, request.name, request.description)
idea = await idea_service.update_idea(
idea_id=idea_id,
name=request.name,
description=request.description,
inspiration_id=request.inspiration_id
)
if not idea:
raise HTTPException(status_code=404, detail="Idea not found")
return idea
@@ -68,18 +78,10 @@ async def get_idea_generations(
idea_id: str,
limit: int = 50,
offset: int = 0,
generation_service: GenerationService = Depends(get_generation_service)
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)
):
# Depending on how generation service implements filtering by idea_id.
# We might need to update generation_service to support getting by idea_id directly
# or ensure generic get_generations supports it.
# Looking at generation_router.py, get_generations doesn't have idea_id arg?
# Let's check generation_service.get_generations signature again.
# It has: (character_id, limit, offset, user_id, project_id). NO IDEA_ID.
# I need to update GenerationService.get_generations too!
# For now, let's assume I will update it.
return await generation_service.get_generations(idea_id=idea_id, limit=limit, offset=offset)
return await generation_service.get_generations(idea_id=idea_id, limit=limit, offset=offset, current_user_id=str(current_user["_id"]))
@router.post("/{idea_id}/generations/{generation_id}")
async def add_generation_to_idea(

View File

@@ -0,0 +1,94 @@
from fastapi import APIRouter, Depends, HTTPException, status
from api.dependency import get_inspiration_service, get_project_id
from api.endpoints.auth import get_current_user
from api.models.InspirationRequest import InspirationCreateRequest, InspirationResponse, InspirationListResponse
from api.service.inspiration_service import InspirationService
from models.Inspiration import Inspiration
router = APIRouter(prefix="/api/inspirations", tags=["Inspirations"])
@router.post("", response_model=InspirationResponse, status_code=status.HTTP_201_CREATED)
async def create_inspiration(
    request: InspirationCreateRequest,
    project_id: str | None = Depends(get_project_id),
    current_user: dict = Depends(get_current_user),
    service: InspirationService = Depends(get_inspiration_service)
):
    """Create an inspiration from a source URL.

    Project scope comes from the header dependency when present,
    otherwise from the request body.
    """
    effective_pid = project_id or request.project_id
    created = await service.create_inspiration(
        source_url=request.source_url,
        created_by=str(current_user["_id"]),
        project_id=effective_pid,
        caption=request.caption
    )
    return created
@router.get("", response_model=InspirationListResponse)
async def get_inspirations(
    project_id: str | None = Depends(get_project_id),
    limit: int = 20,
    offset: int = 0,
    current_user: dict = Depends(get_current_user),
    service: InspirationService = Depends(get_inspiration_service)
):
    """List inspirations with a total count for pagination.

    Scoping rule: when a project_id is active, return the project's
    inspirations (shared between members); otherwise return the current
    user's personal inspirations.
    """
    # Personal (created_by) scope only applies when no project filter is active.
    created_by = None if project_id else str(current_user["_id"])
    inspirations = await service.get_inspirations(
        project_id=project_id, created_by=created_by, limit=limit, offset=offset
    )
    total_count = await service.dao.inspirations.count_inspirations(
        project_id=project_id, created_by=created_by
    )
    return InspirationListResponse(
        inspirations=[InspirationResponse(**item.model_dump()) for item in inspirations],
        total_count=total_count
    )
@router.get("/{inspiration_id}", response_model=InspirationResponse)
async def get_inspiration(
    inspiration_id: str,
    service: InspirationService = Depends(get_inspiration_service),
    current_user: dict = Depends(get_current_user)
):
    """Fetch a single inspiration by id; 404 when missing."""
    found = await service.get_inspiration(inspiration_id)
    if found is None:
        raise HTTPException(status_code=404, detail="Inspiration not found")
    return found
@router.patch("/{inspiration_id}/complete", response_model=InspirationResponse)
async def mark_inspiration_complete(
    inspiration_id: str,
    is_completed: bool = True,
    service: InspirationService = Depends(get_inspiration_service),
    current_user: dict = Depends(get_current_user)
):
    """Mark an inspiration complete (or not); 404 when it does not exist."""
    updated = await service.mark_as_completed(inspiration_id, is_completed)
    if updated is None:
        raise HTTPException(status_code=404, detail="Inspiration not found")
    return updated
@router.delete("/{inspiration_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_inspiration(
    inspiration_id: str,
    service: InspirationService = Depends(get_inspiration_service),
    current_user: dict = Depends(get_current_user)
):
    """Delete an inspiration; 404 when nothing was removed."""
    if not await service.delete_inspiration(inspiration_id):
        raise HTTPException(status_code=404, detail="Inspiration not found")
    return None

View File

@@ -1,4 +1,3 @@
from typing import List, Optional
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException
@@ -14,7 +13,7 @@ router = APIRouter(prefix="/api/posts", tags=["posts"])
@router.post("", response_model=Post)
async def create_post(
request: PostCreateRequest,
project_id: Optional[str] = Depends(get_project_id),
project_id: str | None = Depends(get_project_id),
current_user: dict = Depends(get_current_user),
post_service: PostService = Depends(get_post_service),
):
@@ -28,13 +27,13 @@ async def create_post(
)
@router.get("", response_model=List[Post])
@router.get("", response_model=list[Post])
async def get_posts(
project_id: Optional[str] = Depends(get_project_id),
project_id: str | None = Depends(get_project_id),
limit: int = 200,
offset: int = 0,
date_from: Optional[datetime] = None,
date_to: Optional[datetime] = None,
date_from: datetime | None = None,
date_to: datetime | None = None,
current_user: dict = Depends(get_current_user),
post_service: PostService = Depends(get_post_service),
):

View File

@@ -1,4 +1,3 @@
from typing import List, Optional
from bson import ObjectId
from fastapi import APIRouter, Depends, HTTPException, status
@@ -12,7 +11,7 @@ router = APIRouter(prefix="/api/projects", tags=["Projects"])
class ProjectCreate(BaseModel):
name: str
description: Optional[str] = None
description: str | None = None
class ProjectMemberResponse(BaseModel):
id: str
@@ -21,9 +20,9 @@ class ProjectMemberResponse(BaseModel):
class ProjectResponse(BaseModel):
id: str
name: str
description: Optional[str] = None
description: str | None = None
owner_id: str
members: List[ProjectMemberResponse]
members: list[ProjectMemberResponse]
is_owner: bool = False
async def _get_project_response(project: Project, current_user_id: str, dao: DAO) -> ProjectResponse:
@@ -78,7 +77,7 @@ async def create_project(
return await _get_project_response(new_project, user_id, dao)
@router.get("", response_model=List[ProjectResponse])
@router.get("", response_model=list[ProjectResponse])
async def get_my_projects(
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)

View File

@@ -1,5 +1,4 @@
from datetime import datetime
from typing import List, Optional
from pydantic import BaseModel
@@ -11,10 +10,10 @@ class AssetResponse(BaseModel):
name: str
type: str # uploaded / generated
content_type: str # image / prompt
linked_char_id: Optional[str] = None
linked_char_id: str | None = None
created_at: datetime
url: Optional[str] = None
url: str | None = None
class AssetsResponse(BaseModel):
assets: List[AssetResponse]
assets: list[AssetResponse]
total_count: int

View File

@@ -1,18 +1,17 @@
from typing import Optional
from pydantic import BaseModel
class CharacterCreateRequest(BaseModel):
name: str
character_bio: str
character_image_doc_tg_id: Optional[str] = None
avatar_image: Optional[str] = None
character_image_tg_id: Optional[str] = None
project_id: Optional[str] = None
character_image_doc_tg_id: str | None = None
avatar_image: str | None = None
character_image_tg_id: str | None = None
project_id: str | None = None
class CharacterUpdateRequest(BaseModel):
name: Optional[str] = None
character_bio: Optional[str] = None
character_image_doc_tg_id: Optional[str] = None
avatar_image: Optional[str] = None
character_image_tg_id: Optional[str] = None
project_id: Optional[str] = None
name: str | None = None
character_bio: str | None = None
character_image_doc_tg_id: str | None = None
avatar_image: str | None = None
character_image_tg_id: str | None = None
project_id: str | None = None

View File

@@ -1,17 +1,17 @@
from typing import Optional, List
from pydantic import BaseModel, Field
class EnvironmentCreate(BaseModel):
character_id: str
name: str = Field(..., min_length=1)
description: Optional[str] = None
asset_ids: Optional[List[str]] = []
description: str | None = None
asset_ids: list[str] | None = []
class EnvironmentUpdate(BaseModel):
name: Optional[str] = Field(None, min_length=1)
description: Optional[str] = None
name: str | None = Field(None, min_length=1)
description: str | None = None
asset_ids: list[str] | None = None
class AssetToEnvironment(BaseModel):
@@ -19,4 +19,4 @@ class AssetToEnvironment(BaseModel):
class AssetsToEnvironment(BaseModel):
asset_ids: List[str]
asset_ids: list[str]

View File

@@ -1,4 +1,3 @@
from typing import Optional
from pydantic import BaseModel, Field
from models.enums import AspectRatios, Quality
@@ -7,27 +6,31 @@ class ExternalGenerationRequest(BaseModel):
"""Request model for importing external generations."""
prompt: str
tech_prompt: Optional[str] = None
tech_prompt: str | None = None
# Image can be provided as base64 string OR URL (one must be provided)
image_data: Optional[str] = Field(None, description="Base64-encoded image data")
image_url: Optional[str] = Field(None, description="URL to download image from")
image_data: str | None = Field(None, description="Base64-encoded image data")
image_url: str | None = Field(None, description="URL to download image from")
nsfw: bool = False
# Generation metadata
aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN
aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN # "1:1","2:3","3:2","3:4","4:3","4:5","5:4","9:16","16:9","21:9"
quality: Quality = Quality.ONEK
model: str | None = None
seed: int | None = None
# Optional linking
linked_character_id: Optional[str] = None
linked_character_id: str | None = None
created_by: str = Field(..., description="User ID from external system")
project_id: Optional[str] = None
project_id: str | None = None
# Performance metrics
execution_time_seconds: Optional[float] = None
api_execution_time_seconds: Optional[float] = None
token_usage: Optional[int] = None
input_token_usage: Optional[int] = None
output_token_usage: Optional[int] = None
execution_time_seconds: float | None = None
api_execution_time_seconds: float | None = None
token_usage: int | None = None
input_token_usage: int | None = None
output_token_usage: int | None = None
def validate_image_source(self):
"""Ensure at least one image source is provided."""

View File

@@ -1,5 +1,4 @@
from pydantic import BaseModel
from typing import List, Optional
class UsageStats(BaseModel):
total_runs: int
@@ -9,10 +8,10 @@ class UsageStats(BaseModel):
total_cost: float
class UsageByEntity(BaseModel):
entity_id: Optional[str] = None
entity_id: str | None = None
stats: UsageStats
class FinancialReport(BaseModel):
summary: UsageStats
by_user: Optional[List[UsageByEntity]] = None
by_project: Optional[List[UsageByEntity]] = None
by_user: list[UsageByEntity] | None = None
by_project: list[UsageByEntity] | None = None

View File

@@ -1,68 +1,79 @@
from datetime import datetime, UTC
from typing import List, Optional
from pydantic import BaseModel, Field
from models.Asset import Asset
from models.Generation import GenerationStatus
from models.enums import AspectRatios, Quality, GenType
from models.enums import AspectRatios, Quality, GenType, ImageModel, TextModel
class GenerationRequest(BaseModel):
linked_character_id: Optional[str] = None
aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN
linked_character_id: str | None = None
aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN # "1:1","2:3","3:2","3:4","4:3","4:5","5:4","9:16","16:9","21:9"
quality: Quality = Quality.ONEK
prompt: str
telegram_id: Optional[int] = None
model: ImageModel = Field(default=ImageModel.GEMINI_3_PRO_IMAGE_PREVIEW)
telegram_id: int | None = None
use_profile_image: bool = True
assets_list: List[str]
environment_id: Optional[str] = None
project_id: Optional[str] = None
idea_id: Optional[str] = None
assets_list: list[str]
environment_id: str | None = None
project_id: str | None = None
idea_id: str | None = None
nsfw: bool = False
count: int = Field(default=1, ge=1, le=10)
class NsfwRequest(BaseModel):
is_nsfw: bool
class GenerationsResponse(BaseModel):
generations: List["GenerationResponse"]
generations: list["GenerationResponse"]
total_count: int
class GenerationResponse(BaseModel):
id: str
status: GenerationStatus
failed_reason: Optional[str] = None
linked_character_id: Optional[str] = None
failed_reason: str | None = None
project_id: str | None = None
linked_character_id: str | None = None
aspect_ratio: AspectRatios
quality: Quality
prompt: str
tech_prompt: Optional[str] = None
assets_list: List[str]
result_list: List[str] = []
result: Optional[str] = None
execution_time_seconds: Optional[float] = None
api_execution_time_seconds: Optional[float] = None
token_usage: Optional[int] = None
input_token_usage: Optional[int] = None
output_token_usage: Optional[int] = None
model: ImageModel | None = None
seed: int | None = None
tech_prompt: str | None = None
assets_list: list[str]
result_list: list[str] = []
result: str | None = None
execution_time_seconds: float | None = None
api_execution_time_seconds: float | None = None
token_usage: int | None = None
input_token_usage: int | None = None
output_token_usage: int | None = None
progress: int = 0
cost: Optional[float] = None
created_by: Optional[str] = None
generation_group_id: Optional[str] = None
idea_id: Optional[str] = None
cost: float | None = None
created_by: str | None = None
generation_group_id: str | None = None
idea_id: str | None = None
likes_count: int = 0
is_liked: bool = False
nsfw: bool = False
created_at: datetime = datetime.now(UTC)
updated_at: datetime = datetime.now(UTC)
class GenerationGroupResponse(BaseModel):
generation_group_id: str
generations: List[GenerationResponse]
generations: list[GenerationResponse]
class PromptRequest(BaseModel):
prompt: str
linked_assets: List[str] = []
model: TextModel = Field(default=TextModel.GEMINI_3_1_PRO_PREVIEW)
linked_assets: list[str] = []
class PromptResponse(BaseModel):
prompt: str
prompt: str

View File

@@ -1,16 +1,17 @@
from typing import Optional
from pydantic import BaseModel
from models.Idea import Idea
from api.models.GenerationRequest import GenerationResponse
class IdeaCreateRequest(BaseModel):
name: str
description: Optional[str] = None
project_id: Optional[str] = None # Optional in body if passed via header/dependency
description: str | None = None
project_id: str | None = None # Optional in body if passed via header/dependency
inspiration_id: str | None = None
class IdeaUpdateRequest(BaseModel):
name: Optional[str] = None
description: Optional[str] = None
name: str | None = None
description: str | None = None
inspiration_id: str | None = None
class IdeaResponse(Idea):
last_generation: Optional[GenerationResponse] = None
last_generation: GenerationResponse | None = None

View File

@@ -0,0 +1,28 @@
from datetime import datetime
from pydantic import BaseModel
from models.Inspiration import Inspiration
class InspirationCreateRequest(BaseModel):
    """Payload for creating an inspiration from an external source URL."""
    source_url: str  # URL of the reference material
    caption: str | None = None  # optional free-form note
    project_id: str | None = None  # optional in body if scope is passed via header/dependency
class InspirationResponse(BaseModel):
    """API representation of a stored inspiration."""
    id: str
    source_url: str
    caption: str | None = None
    asset_id: str  # presumably the asset ingested from source_url — TODO confirm
    is_completed: bool  # toggled via PATCH /{inspiration_id}/complete
    created_by: str  # user id of the creator
    project_id: str | None = None  # None for personal (non-project) inspirations
    created_at: datetime
    updated_at: datetime
class InspirationListResponse(BaseModel):
    """One page of inspirations plus the total number of matches."""
    inspirations: list[InspirationResponse]
    total_count: int  # total matching rows, not just this page

View File

@@ -1,19 +1,18 @@
from datetime import datetime
from typing import Optional, List
from pydantic import BaseModel
class PostCreateRequest(BaseModel):
date: datetime
topic: str
generation_ids: List[str] = []
project_id: Optional[str] = None
generation_ids: list[str] = []
project_id: str | None = None
class PostUpdateRequest(BaseModel):
date: Optional[datetime] = None
topic: Optional[str] = None
date: datetime | None = None
topic: str | None = None
class AddGenerationsRequest(BaseModel):
generation_ids: List[str]
generation_ids: list[str]

View File

@@ -2,6 +2,6 @@ from .AssetDTO import AssetResponse, AssetsResponse
from .CharacterDTO import CharacterCreateRequest, CharacterUpdateRequest
from .ExternalGenerationDTO import ExternalGenerationRequest
from .FinancialUsageDTO import FinancialReport, UsageStats, UsageByEntity
from .GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse, PromptRequest, PromptResponse
from .GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse, PromptRequest, PromptResponse, NsfwRequest
from .IdeaRequest import IdeaCreateRequest, IdeaUpdateRequest, IdeaResponse
from .PostRequest import PostCreateRequest, PostUpdateRequest, AddGenerationsRequest

View File

@@ -9,18 +9,19 @@ from uuid import uuid4
import httpx
from aiogram import Bot
from aiogram.types import BufferedInputFile
from fastapi import HTTPException
from adapters.Exception import GoogleGenerationException
from adapters.google_adapter import GoogleAdapter
from adapters.s3_adapter import S3Adapter
from api.models import FinancialReport, UsageStats, UsageByEntity
from api.models import GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse
# Импортируйте ваши модели DAO, Asset, Generation корректно
from api.models import (
FinancialReport, UsageStats, UsageByEntity,
GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse
)
from models.Asset import Asset, AssetType, AssetContentType
from models.Generation import Generation, GenerationStatus
from models.enums import AspectRatios, Quality
from repos.dao import DAO
from utils.image_utils import create_thumbnail
logger = logging.getLogger(__name__)
@@ -28,36 +29,32 @@ logger = logging.getLogger(__name__)
generation_semaphore = asyncio.Semaphore(4)
# --- Вспомогательная функция генерации ---
async def generate_image_task(
prompt: str,
media_group_bytes: List[bytes],
aspect_ratio: AspectRatios,
quality: Quality,
model: str,
gemini: GoogleAdapter,
) -> Tuple[List[bytes], Dict[str, Any]]:
"""
Обертка для вызова синхронного метода Gemini в отдельном потоке.
Возвращает список байтов сгенерированных изображений.
Wrapper for calling Gemini's synchronous method in a separate thread.
"""
try :
try:
logger.info(f"Starting generate_image_task with prompt length: {len(prompt)}")
# Запускаем блокирующую операцию в отдельном потоке, чтобы не тормозить Event Loop
result = await asyncio.to_thread(
gemini.generate_image,
prompt=prompt,
images_list=media_group_bytes,
aspect_ratio=aspect_ratio,
quality=quality,
model=model,
)
generated_images_io, metrics = result
logger.info(f"generate_image_task completed, received {len(generated_images_io) if generated_images_io else 0} images")
except GoogleGenerationException as e:
raise e
except GoogleGenerationException:
raise
finally:
# Освобождаем входные данные — они больше не нужны
del media_group_bytes
images_bytes = []
@@ -66,414 +63,176 @@ async def generate_image_task(
img_io.seek(0)
images_bytes.append(img_io.read())
img_io.close()
# Освобождаем список BytesIO сразу
del generated_images_io
return images_bytes, metrics
class GenerationService:
def __init__(self, dao: DAO, gemini: GoogleAdapter, s3_adapter: S3Adapter, bot: Optional[Bot] = None):
def __init__(self, dao: DAO, gemini: GoogleAdapter, s3_adapter: S3Adapter, bot: Optional[Bot] = None):
self.dao = dao
self.gemini = gemini
self.s3_adapter = s3_adapter
self.bot = bot
async def ask_prompt_assistant(self, prompt: str, assets: list[str] | None = None) -> str:
future_prompt = """You are an prompt-assistant. You improving user-entered prompts for image generation. User may upload reference image too.
I will provide sources prompt entered by user. Understand user needs and generate best variation of prompt.
ANSWER ONLY PROMPT STRING!!! USER_ENTERED_PROMPT: """
future_prompt += prompt
# --- Public API ---
async def ask_prompt_assistant(self, prompt: str, assets: list[str] | None = None, model: str = "gemini-3.1-pro-preview") -> str:
future_prompt = (
"You are an prompt-assistant. You improving user-entered prompts for image generation. "
"User may upload reference image too. I will provide sources prompt entered by user. "
"Understand user needs and generate best variation of prompt. ANSWER ONLY PROMPT STRING!!! "
f"USER_ENTERED_PROMPT: {prompt}"
)
assets_data = []
if assets is not None:
if assets:
assets_db = await self.dao.assets.get_assets_by_ids(assets)
assets_data.extend(asset.data for asset in assets_db)
generated_prompt = await asyncio.to_thread(self.gemini.generate_text, future_prompt, assets_data)
logger.info(future_prompt)
logger.info(generated_prompt)
assets_data.extend(asset.data for asset in assets_db if asset.data)
generated_prompt = await asyncio.to_thread(self.gemini.generate_text, future_prompt, model, assets_data)
logger.info(f"Prompt Assistant: {generated_prompt}")
return generated_prompt
async def generate_prompt_from_images(self, images: List[bytes], user_prompt: Optional[str] = None) -> str:
async def generate_prompt_from_images(self, images: List[bytes], user_prompt: Optional[str] = None, model: str = "gemini-3.1-pro-preview") -> str:
technical_prompt = "You are a prompt engineer. Describe this image in detail to create a stable diffusion using this image as reference. "
if user_prompt:
technical_prompt += f"User also provided this context: {user_prompt}. "
technical_prompt += "Provide ONLY the detailed prompt."
return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, images_list=images)
return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, model=model, images_list=images)
async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0, user_id: Optional[str] = None, project_id: Optional[str] = None, idea_id: Optional[str] = None) -> GenerationsResponse:
generations = await self.dao.generations.get_generations(character_id = character_id,limit=limit, offset=offset, created_by=user_id, project_id=project_id, idea_id=idea_id)
total_count = await self.dao.generations.count_generations(character_id = character_id, created_by=user_id, project_id=project_id, idea_id=idea_id)
generations = [GenerationResponse(**gen.model_dump()) for gen in generations]
return GenerationsResponse(generations=generations, total_count=total_count)
async def get_generations(self, **kwargs) -> GenerationsResponse:
    """List generations and the total match count for pagination.

    ``current_user_id`` is consumed here (it is not a DAO filter) and is
    used to compute per-viewer fields such as ``is_liked``.
    """
    viewer_id = kwargs.pop('current_user_id', None)
    items = await self.dao.generations.get_generations(**kwargs)
    # The count query takes only the filter kwargs, never limit/offset.
    count_filters = {
        name: kwargs.get(name)
        for name in ('character_id', 'created_by', 'project_id', 'idea_id', 'only_liked_by')
    }
    total = await self.dao.generations.count_generations(**count_filters)
    mapped = [self._map_to_response(item, viewer_id) for item in items]
    return GenerationsResponse(generations=mapped, total_count=total)
async def get_generation(self, generation_id: str) -> Optional[GenerationResponse]:
async def get_generation(self, generation_id: str, current_user_id: Optional[str] = None) -> Optional[GenerationResponse]:
gen = await self.dao.generations.get_generation(generation_id)
if gen is None:
return None
else:
return GenerationResponse(**gen.model_dump())
return self._map_to_response(gen, current_user_id) if gen else None
async def toggle_like(self, generation_id: str, user_id: str) -> bool | None:
    """Toggle a user's like on a generation.

    Returns True if now liked, False if unliked, None if the generation
    was not found (repo-layer convention).
    """
    return await self.dao.generations.toggle_like(generation_id, user_id)
async def get_generations_by_group(self, group_id: str, current_user_id: Optional[str] = None) -> GenerationGroupResponse:
    """Return every generation in a group, mapped with viewer context."""
    members = await self.dao.generations.get_generations_by_group(group_id)
    responses = [self._map_to_response(member, current_user_id) for member in members]
    return GenerationGroupResponse(generation_group_id=group_id, generations=responses)
def _map_to_response(self, gen: Generation, current_user_id: Optional[str] = None) -> GenerationResponse:
    """Convert a domain Generation into its API DTO, filling like metadata."""
    response = GenerationResponse(**gen.model_dump())
    liked_by = gen.liked_by or []
    response.likes_count = len(liked_by)
    response.is_liked = bool(current_user_id) and current_user_id in liked_by
    return response
async def get_running_generations(self, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
    """Return generations in RUNNING status, optionally scoped to a user and/or project."""
    return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING, created_by=user_id, project_id=project_id)
async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationGroupResponse:
count = generation_request.count
if generation_group_id is None:
generation_group_id = str(uuid4())
results = []
for _ in range(count):
for _ in range(generation_request.count):
gen_response = await self._create_single_generation(generation_request, user_id, generation_group_id)
results.append(gen_response)
return GenerationGroupResponse(generation_group_id=generation_group_id, generations=results)
async def _create_single_generation(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationResponse:
    """Create one generation record and schedule its background processing.

    The DB record is written first; the heavy image generation then runs in a
    fire-and-forget asyncio task throttled by the module-level semaphore.
    Returns the DTO of the freshly created (not yet processed) generation.
    Raises HTTPException(400) on an environment_id without a character.
    """
    gen_id = None
    generation_model = None
    # environment_id only makes sense together with a character reference
    if generation_request.environment_id and not generation_request.linked_character_id:
        raise HTTPException(status_code=400, detail="environment_id can only be used when linked_character_id is provided")
    try:
        generation_model = Generation(**generation_request.model_dump(exclude={'count'}))
        if user_id:
            generation_model.created_by = user_id
        if generation_group_id:
            generation_model.generation_group_id = generation_group_id
        # Explicitly set idea_id from request if present (already in model_dump, but ensuring clarity)
        if generation_request.idea_id:
            generation_model.idea_id = generation_request.idea_id
        gen_id = await self.dao.generations.create_generation(generation_model)
        generation_model.id = gen_id

        async def runner(gen):
            # Background worker: waits for a semaphore slot, then processes.
            logger.info(f"Generation {gen.id} entered queue (waiting for slot)...")
            try:
                async with generation_semaphore:
                    logger.info(f"Starting background generation task for ID: {gen.id}")
                    await self.create_generation(gen)
                    logger.info(f"Background generation task finished for ID: {gen.id}")
            except Exception:
                # If the generation already started and failed — mark it FAILED.
                try:
                    db_gen = await self.dao.generations.get_generation(gen.id)
                    if db_gen is not None:
                        db_gen.status = GenerationStatus.FAILED
                        await self.dao.generations.update_generation(db_gen)
                except Exception:
                    logger.exception("Failed to mark generation as FAILED")
                logger.exception("create_generation task failed")

        asyncio.create_task(runner(generation_model))
        return GenerationResponse(**generation_model.model_dump())
    except Exception:
        # If the record was never created, there is nothing to mark.
        if gen_id is not None:
            try:
                gen = await self.dao.generations.get_generation(gen_id)
                if gen is not None:
                    gen.status = GenerationStatus.FAILED
                    await self.dao.generations.update_generation(gen)
            except Exception:
                logger.exception("Failed to mark generation as FAILED in create_generation_task")
        raise
async def create_generation(self, generation: Generation):
    """Run the full pipeline for an already-persisted Generation.

    Steps: collect reference images and the final prompt, run the image
    model with a simulated progress bar, store results as assets, finalize
    the record and (optionally) notify the Telegram chat.

    On any failure the generation is marked FAILED (with a reason) and the
    exception is re-raised to the background runner.

    Raises:
        GoogleGenerationException: propagated from the image API.
        Exception: any other pipeline failure, after marking FAILED.
    """
    start_time = datetime.now()
    logger.info(f"Processing generation {generation.id}. Character ID: {generation.linked_character_id}")

    # 1. Prepare input: reference images (avatar, explicit assets, environment) + final prompt
    media_group_bytes, generation_prompt = await self._prepare_generation_input(generation)
    logger.info(f"Final generation prompt assembled. Length: {len(generation_prompt)}. Media count: {len(media_group_bytes)}")

    # 2. Run generation while a background task animates the progress field
    progress_task = asyncio.create_task(self._simulate_progress(generation))
    try:
        # Default to Image Generation (Gemini)
        generated_bytes_list, metrics = await generate_image_task(
            prompt=generation_prompt,
            media_group_bytes=media_group_bytes,
            aspect_ratio=generation.aspect_ratio,
            quality=generation.quality,
            model=generation.model or "gemini-3-pro-image-preview",
            gemini=self.gemini
        )
        # Update metrics reported by the API
        self._update_generation_metrics(generation, metrics)
        # 3. Process results into stored assets
        created_assets = await self._process_generated_images(generation, generated_bytes_list)
        # 4. Finalize generation record
        await self._finalize_generation(generation, created_assets, generation_prompt, start_time)
        # 5. Notify
        if generation.telegram_id and self.bot:
            await self._notify_telegram(generation, created_assets)
    except GoogleGenerationException as e:
        await self._handle_generation_failure(generation, e)
        raise
    except Exception as e:
        logger.error(f"Generation failed: {e}")
        await self._handle_generation_failure(generation, e)
        raise
    finally:
        # Always stop the progress simulation, whatever the outcome
        if not progress_task.done():
            progress_task.cancel()
            try:
                await progress_task
            except asyncio.CancelledError:
                pass
async def _get_asset_data(self, asset: Asset) -> Optional[bytes]:
    """Resolve the raw image bytes for an asset.

    Non-image assets yield None. Prefers the S3 object when one is
    recorded; otherwise falls back to bytes embedded in the document.
    """
    if asset.content_type == AssetContentType.IMAGE:
        if asset.minio_object_name:
            return await self.s3_adapter.get_file(asset.minio_object_name)
        return asset.data
    return None
async def _simulate_progress(self, generation: Generation):
    """
    Increments progress from 0 to 90 over ~20 seconds.

    Persists the progress field on every tick so clients polling the
    generation see movement; cancelled by the caller when the real
    generation finishes or fails (progress then jumps to its final value).
    """
    current_progress = 0
    try:
        while current_progress < 90:
            await asyncio.sleep(4)
            # Random increment between 5 and 15
            increment = random.randint(5, 15)
            current_progress = min(current_progress + increment, 90)
            # Fetch latest state (optional, but good practice to avoid overwriting unrelated fields)
            # But for simplicity here we just use the object we have and save it.
            # Ideally, we should fetch-update-save or use partial update if DAO supports it.
            # Assuming simple update is fine for now.
            generation.progress = current_progress
            await self.dao.generations.update_generation(generation)
    except asyncio.CancelledError:
        # Task cancelled, generation finished (or failed)
        pass
    except Exception as e:
        logger.error(f"Error in progress simulation: {e}")
async def import_external_generation(self, external_gen) -> Generation:
    """Import a generation produced outside this service.

    Fetches the image (URL download or base64 decode), stores it as a
    GENERATED asset (S3 upload + thumbnail + DB record) and creates a
    Generation record already in DONE state.

    Args:
        external_gen: ExternalGenerationRequest with generation data and image.

    Returns:
        The created Generation (with its new id).

    Raises:
        ValueError: when no usable image payload could be obtained.
    """
    # Validate image source (exactly one of image_url / image_data)
    external_gen.validate_image_source()
    logger.info(f"Importing external generation for user: {external_gen.created_by}")

    # 1. Obtain the image bytes (download or decode)
    image_bytes = await self._fetch_external_image(external_gen)
    if not image_bytes:
        raise ValueError("Failed to process image data")

    # 2. Store as an Asset — reuse internal processing logic (S3 + thumbnail + DB)
    new_asset = await self._save_asset(
        image_bytes=image_bytes,
        name=f"External_Generated_{external_gen.linked_character_id or 'no_char'}",
        created_by=external_gen.created_by,
        project_id=external_gen.project_id,
        linked_char_id=external_gen.linked_character_id,
        folder="external"
    )
    logger.info(f"Created asset {new_asset.id} for external generation")

    # 3. Create the Generation record, already finished
    generation = Generation(
        status=GenerationStatus.DONE,
        linked_character_id=external_gen.linked_character_id,
        aspect_ratio=external_gen.aspect_ratio,
        quality=external_gen.quality,
        prompt=external_gen.prompt,
        model=external_gen.model,
        tech_prompt=external_gen.tech_prompt,
        seed=external_gen.seed,
        result_list=[new_asset.id],
        result=new_asset.id,
        progress=100,
        nsfw=external_gen.nsfw,
        execution_time_seconds=external_gen.execution_time_seconds,
        api_execution_time_seconds=external_gen.api_execution_time_seconds,
        token_usage=external_gen.token_usage,
        input_token_usage=external_gen.input_token_usage,
        output_token_usage=external_gen.output_token_usage,
        created_by=external_gen.created_by,
        project_id=external_gen.project_id,
        created_at=datetime.now(UTC),
        updated_at=datetime.now(UTC)
    )
    gen_id = await self.dao.generations.create_generation(generation)
    generation.id = gen_id
    logger.info(f"Created generation {gen_id} from external source")
    return generation
async def delete_generation(self, generation_id: str) -> bool:
    """
    Soft delete generation by marking it as deleted.

    Returns:
        True on success; False when the generation does not exist or the
        update fails (failure is logged, never raised).
    """
    try:
        generation = await self.dao.generations.get_generation(generation_id)
        if not generation:
            return False
        generation.is_deleted = True
        generation.updated_at = datetime.now(UTC)
        await self.dao.generations.update_generation(generation)
        return True
    except Exception as e:
        logger.error(f"Error deleting generation {generation_id}: {e}")
        return False
async def cleanup_stale_generations(self, timeout_minutes: int = 5):
    """
    Cancels generations that have been running for longer than
    ``timeout_minutes`` (default 5 minutes).

    Errors are logged and swallowed so a periodic scheduler calling this
    never crashes.
    """
    try:
        count = await self.dao.generations.cancel_stale_generations(timeout_minutes=timeout_minutes)
        if count > 0:
            logger.info(f"Cleaned up {count} stale generations")
    except Exception as e:
        logger.error(f"Error cleaning up stale generations: {e}")
async def cleanup_old_data(self, days: int = 30):
    """
    Purge old data:
    1. Soft-deletes generations older than ``days`` days.
    2. Soft-deletes their assets and hard-deletes the backing S3 files.

    Errors are logged and swallowed (safe for a periodic scheduler).
    """
    try:
        # 1. Soft-delete generations and collect the ids of their assets
        gen_count, asset_ids = await self.dao.generations.soft_delete_old_generations(days=days)
        if gen_count > 0:
            logger.info(f"Soft-deleted {gen_count} generations. Purging {len(asset_ids)} assets.")
        # 2. Soft-delete assets + remove files from S3
        if asset_ids:
            await self.dao.assets.soft_delete_and_purge_assets(asset_ids)
    except Exception as e:
        logger.error(f"Error during old data cleanup: {e}")
async def get_financial_report(self, user_id: Optional[str] = None, project_id: Optional[str] = None, breakdown_by: Optional[str] = None) -> FinancialReport:
    """
    Generates a financial usage report for a specific user or project.

    Args:
        user_id: restrict stats to this creator (optional).
        project_id: restrict stats to this project (optional).
        breakdown_by: 'created_by' or 'project_id' to include a per-entity
            breakdown; the two values are mutually exclusive.

    Returns:
        FinancialReport with the aggregate summary and at most one breakdown.
    """
    summary_data = await self.dao.generations.get_usage_stats(created_by=user_id, project_id=project_id)
    summary = UsageStats(**summary_data)
    by_user, by_project = None, None
    if breakdown_by == "created_by":
        res = await self.dao.generations.get_usage_breakdown(group_by="created_by", project_id=project_id, created_by=user_id)
        by_user = [UsageByEntity(**item) for item in res]
    elif breakdown_by == "project_id":
        res = await self.dao.generations.get_usage_breakdown(group_by="project_id", project_id=project_id, created_by=user_id)
        by_project = [UsageByEntity(**item) for item in res]
    return FinancialReport(summary=summary, by_user=by_user, by_project=by_project)
# --- Private Helpers ---
async def _create_single_generation(self, generation_request: GenerationRequest, user_id: Optional[str], generation_group_id: str) -> GenerationResponse:
    """Persist a Generation built from the request and schedule background processing.

    Returns immediately; the work itself runs in `_queued_generation_runner`,
    gated by the global generation semaphore.

    Raises:
        HTTPException(400): when environment_id is given without linked_character_id
            (validation restored from the pre-refactor implementation).
    """
    # environment_id is only meaningful in the context of a character
    if generation_request.environment_id and not generation_request.linked_character_id:
        raise HTTPException(status_code=400, detail="environment_id can only be used when linked_character_id is provided")
    try:
        gen_model = Generation(**generation_request.model_dump(exclude={'count'}))
        gen_model.created_by = user_id
        gen_model.generation_group_id = generation_group_id
        gen_id = await self.dao.generations.create_generation(gen_model)
        gen_model.id = gen_id
        # Fire-and-forget; the runner waits on the semaphore and marks FAILED on error
        asyncio.create_task(self._queued_generation_runner(gen_model))
        return GenerationResponse(**gen_model.model_dump())
    except Exception:
        logger.exception("Failed to initiate single generation")
        raise
async def _queued_generation_runner(self, gen: Generation):
    """Background worker: wait for a semaphore slot, run the generation.

    Any exception marks the generation FAILED and is logged — never raised,
    since nobody awaits this task.
    """
    logger.info(f"Generation {gen.id} waiting for slot...")
    try:
        async with generation_semaphore:
            await self.create_generation(gen)
    except Exception as exc:
        logger.exception(f"Background generation task failed for ID: {gen.id}")
        await self._handle_generation_failure(gen, exc)
async def _prepare_generation_input(self, generation: Generation) -> Tuple[List[bytes], str]:
    """Collect reference images and assemble the final prompt.

    Reference order: character avatar (if enabled), explicitly attached
    assets, then environment assets. When at least one reference exists,
    a strict identity-guidance suffix is appended to the prompt.
    """
    references: List[bytes] = []

    async def collect(assets):
        # Append the bytes of every readable image asset, skipping the rest.
        for candidate in assets:
            payload = await self._get_asset_data_bytes(candidate)
            if payload:
                references.append(payload)

    # 1. Character avatar comes first when enabled
    if generation.linked_character_id:
        char_info = await self.dao.chars.get_character(generation.linked_character_id)
        if not char_info:
            raise ValueError(f"Character {generation.linked_character_id} not found")
        if generation.use_profile_image and char_info.avatar_asset_id:
            avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
            if avatar_asset:
                await collect([avatar_asset])

    # 2. Explicitly attached reference assets
    if generation.assets_list:
        await collect(await self.dao.assets.get_assets_by_ids(generation.assets_list))

    # 3. Environment assets go last
    if generation.environment_id:
        env = await self.dao.environments.get_env(generation.environment_id)
        if env and env.asset_ids:
            await collect(await self.dao.assets.get_assets_by_ids(env.asset_ids))

    prompt = generation.prompt
    if references:
        prompt += (
            " \n\n[Reference Image Guidance]: Use the provided image(s) as the STRICT reference for the main "
            "character's facial features and hair, environment or clothes. Maintain high fidelity to the reference identity."
        )
    return references, prompt
async def _get_asset_data_bytes(self, asset: Asset) -> Optional[bytes]:
    """Return raw image bytes for *asset*, or None for non-image assets.

    S3-backed assets are fetched from object storage; otherwise falls back
    to bytes stored inline in ``asset.data`` (legacy storage path).
    """
    if asset.content_type != AssetContentType.IMAGE:
        return None
    if asset.minio_object_name:
        return await self.s3_adapter.get_file(asset.minio_object_name)
    return asset.data
def _update_generation_metrics(self, generation: Generation, metrics: Dict[str, Any]):
generation.api_execution_time_seconds = metrics.get("api_execution_time_seconds")
generation.token_usage = metrics.get("token_usage")
generation.input_token_usage = metrics.get("input_token_usage")
generation.output_token_usage = metrics.get("output_token_usage")
async def _handle_generation_failure(self, generation: Generation, error: Optional[Exception]):
    """Mark *generation* as FAILED with a reason and persist it.

    An explicit *error* always becomes the failure reason; otherwise any
    pre-set reason is kept, with "Unknown error" as the last resort.
    """
    logger.error(f"Generation {generation.id} failed: {error}")
    generation.status = GenerationStatus.FAILED
    if error is not None:
        reason = str(error)
    else:
        reason = generation.failed_reason or "Unknown error"
    generation.failed_reason = reason
    generation.updated_at = datetime.now(UTC)
    await self.dao.generations.update_generation(generation)
async def _process_generated_images(self, generation: Generation, bytes_list: List[bytes]) -> List[Asset]:
    """Persist each generated image as a GENERATED asset, preserving order."""
    return [
        await self._save_asset(
            image_bytes=payload,
            name=f"Generated_{generation.linked_character_id}",
            created_by=generation.created_by,
            project_id=generation.project_id,
            linked_char_id=generation.linked_character_id,
            folder="generated"
        )
        for payload in bytes_list
    ]
async def _save_asset(self, image_bytes: bytes, name: str, created_by: str, project_id: str, linked_char_id: str, folder: str) -> Asset:
    """Upload an image to S3 (plus thumbnail) and persist an Asset record.

    The asset stores only S3 coordinates — raw bytes are not kept in the DB.
    Returns the Asset with its database id assigned.
    """
    # Thumbnail creation is CPU-bound — run it off the event loop
    thumbnail_bytes = await asyncio.to_thread(create_thumbnail, image_bytes)
    # Timestamp + random suffix keeps object names unique enough in practice
    filename = f"{folder}/{linked_char_id}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.png"
    await self.s3_adapter.upload_file(filename, image_bytes, content_type="image/png")
    new_asset = Asset(
        name=name,
        type=AssetType.GENERATED,
        content_type=AssetContentType.IMAGE,
        linked_char_id=linked_char_id,
        data=None,  # raw bytes live in S3, not in the document
        minio_object_name=filename,
        minio_bucket=self.s3_adapter.bucket_name,
        thumbnail=thumbnail_bytes,
        created_by=created_by,
        project_id=project_id
    )
    asset_id = await self.dao.assets.create_asset(new_asset)
    new_asset.id = str(asset_id)
    return new_asset
async def _finalize_generation(self, generation: Generation, assets: List[Asset], tech_prompt: str, start_time: datetime):
    """Mark the generation DONE, attach results and timing, and persist it."""
    elapsed = (datetime.now() - start_time).total_seconds()
    result_ids = []
    for produced in assets:
        result_ids.append(produced.id)
    generation.status = GenerationStatus.DONE
    generation.progress = 100
    generation.result_list = result_ids
    generation.tech_prompt = tech_prompt
    generation.execution_time_seconds = elapsed
    generation.updated_at = datetime.now(UTC)
    await self.dao.generations.update_generation(generation)
    logger.info(f"Generation {generation.id} finalized. Time: {generation.execution_time_seconds:.2f}s")
async def _notify_telegram(self, generation: Generation, assets: List[Asset]):
    """Best-effort: send every result image to the generation's Telegram chat.

    Any failure is logged and swallowed — notification must never fail the
    generation pipeline.
    """
    try:
        for asset in assets:
            # Need to get data for telegram if it's not in Asset object
            img_data = await self.s3_adapter.get_file(asset.minio_object_name) if asset.minio_object_name else asset.data
            if img_data:
                await self.bot.send_photo(
                    chat_id=generation.telegram_id,
                    photo=BufferedInputFile(img_data, filename=f"{asset.name}.png"),
                    caption=f"Generated from: {generation.prompt[:100]}..."
                )
    except Exception as e:
        logger.error(f"Failed to send to Telegram: {e}")
async def _simulate_progress(self, generation: Generation):
    """Animate a fake progress bar: bump progress every 4s until it reaches 90.

    Persists the field on each tick; the caller cancels this task once the
    real generation finishes (progress then jumps to its final value).
    """
    progress = 0
    try:
        while True:
            if progress >= 90:
                break
            await asyncio.sleep(4)
            step = random.randint(5, 15)
            progress = min(progress + step, 90)
            generation.progress = progress
            await self.dao.generations.update_generation(generation)
    except asyncio.CancelledError:
        pass
async def _fetch_external_image(self, external_gen) -> bytes:
if external_gen.image_url:
async with httpx.AsyncClient() as client:
response = await client.get(external_gen.image_url, timeout=30.0)
response.raise_for_status()
return response.content
elif external_gen.image_data:
return base64.b64decode(external_gen.image_data)
raise ValueError("No image source provided")

View File

@@ -7,8 +7,14 @@ class IdeaService:
def __init__(self, dao: DAO):
    # Data-access facade shared across services
    self.dao = dao
async def create_idea(self, name: str, description: Optional[str], project_id: Optional[str], user_id: str, inspiration_id: Optional[str] = None) -> Idea:
    """Create and persist a new Idea.

    Args:
        name: Idea title.
        description: Optional free-text description.
        project_id: Optional owning project.
        user_id: Creator, stored as created_by.
        inspiration_id: Optional link to the Inspiration that spawned the idea.

    Returns:
        The Idea with its database id assigned.
    """
    idea = Idea(
        name=name,
        description=description,
        project_id=project_id,
        created_by=user_id,
        inspiration_id=inspiration_id
    )
    idea_id = await self.dao.ideas.create_idea(idea)
    idea.id = idea_id
    return idea
@@ -19,7 +25,7 @@ class IdeaService:
async def get_idea(self, idea_id: str) -> Optional[Idea]:
    """Fetch a single idea by id; None when not found."""
    return await self.dao.ideas.get_idea(idea_id)
async def update_idea(self, idea_id: str, name: Optional[str] = None, description: Optional[str] = None, inspiration_id: Optional[str] = None) -> Optional[Idea]:
    """Partially update an idea; only non-None arguments are applied.

    Returns:
        The updated Idea, or None when the idea does not exist.
    """
    idea = await self.dao.ideas.get_idea(idea_id)
    if not idea:
        return None
    if name is not None:
        idea.name = name
    if description is not None:
        idea.description = description
    if inspiration_id is not None:
        idea.inspiration_id = inspiration_id
    idea.updated_at = datetime.now()
    await self.dao.ideas.update_idea(idea)
    return idea
@@ -72,4 +80,3 @@ class IdeaService:
return True
return False

View File

@@ -0,0 +1,146 @@
import asyncio
import logging
import os
import tempfile
from datetime import datetime
from typing import List, Optional, Tuple
import httpx
from fastapi import HTTPException
from models.Asset import Asset, AssetType, AssetContentType
from models.Inspiration import Inspiration
from repos.dao import DAO
from adapters.s3_adapter import S3Adapter
# Try to import yt_dlp, but don't crash if it's missing (though we added it to requirements)
try:
import yt_dlp
except ImportError:
yt_dlp = None
logger = logging.getLogger(__name__)
class InspirationService:
    """Imports external media as 'inspirations': downloads via yt-dlp,
    stores the file in S3 as an Asset, and tracks an Inspiration record."""

    def __init__(self, dao: DAO, s3_adapter: S3Adapter):
        self.dao = dao
        self.s3_adapter = s3_adapter

    async def create_inspiration(self, source_url: str, created_by: str, project_id: Optional[str] = None, caption: Optional[str] = None) -> Inspiration:
        """Download media from *source_url*, save it as an Asset, and create
        an Inspiration record pointing at that asset.

        Raises:
            HTTPException(400): when the download fails.
        """
        # 1. Download content from Instagram
        try:
            content_bytes, content_type, ext = await self._download_content(source_url)
        except Exception as e:
            logger.error(f"Failed to download content from {source_url}: {e}")
            raise HTTPException(status_code=400, detail=f"Failed to download content: {str(e)}")
        # 2. Save as Asset
        filename = f"inspirations/{datetime.now().strftime('%Y%m%d_%H%M%S')}_insta.{ext}"
        await self.s3_adapter.upload_file(filename, content_bytes, content_type=content_type)
        asset = Asset(
            name=f"Inspiration from {source_url}",
            type=AssetType.INSPIRATION,
            content_type=AssetContentType.VIDEO if content_type.startswith("video") else AssetContentType.IMAGE,
            minio_object_name=filename,
            minio_bucket=self.s3_adapter.bucket_name,
            created_by=created_by,
            project_id=project_id
        )
        asset_id = await self.dao.assets.create_asset(asset)
        # 3. Create Inspiration object
        inspiration = Inspiration(
            source_url=source_url,
            caption=caption,
            asset_id=str(asset_id),
            created_by=created_by,
            project_id=project_id
        )
        insp_id = await self.dao.inspirations.create_inspiration(inspiration)
        inspiration.id = insp_id
        return inspiration

    async def get_inspirations(self, project_id: Optional[str], created_by: str, limit: int = 20, offset: int = 0) -> List[Inspiration]:
        """List inspirations for a user (optionally scoped to a project), paginated."""
        return await self.dao.inspirations.get_inspirations(project_id, created_by, limit, offset)

    async def get_inspiration(self, inspiration_id: str) -> Optional[Inspiration]:
        """Fetch a single inspiration; None when not found."""
        return await self.dao.inspirations.get_inspiration(inspiration_id)

    async def mark_as_completed(self, inspiration_id: str, is_completed: bool = True) -> Optional[Inspiration]:
        """Set the completed flag; returns the updated entity or None when missing."""
        inspiration = await self.dao.inspirations.get_inspiration(inspiration_id)
        if not inspiration:
            return None
        inspiration.is_completed = is_completed
        inspiration.updated_at = datetime.now()
        await self.dao.inspirations.update_inspiration(inspiration)
        return inspiration

    async def delete_inspiration(self, inspiration_id: str) -> bool:
        """Delete an inspiration and its backing asset; False when not found."""
        inspiration = await self.dao.inspirations.get_inspiration(inspiration_id)
        if not inspiration:
            return False
        # Delete associated asset
        if inspiration.asset_id:
            await self.dao.assets.delete_asset(inspiration.asset_id)
        return await self.dao.inspirations.delete_inspiration(inspiration_id)

    async def _download_content(self, url: str) -> Tuple[bytes, str, str]:
        """
        Downloads content using yt-dlp.
        Returns (content_bytes, content_type, extension)

        Raises:
            RuntimeError: when yt-dlp is not importable.
        """
        if not yt_dlp:
            raise RuntimeError("yt-dlp is not installed")
        logger.info(f"Downloading from {url} using yt-dlp...")

        def run_yt_dlp():
            # yt-dlp is blocking — this runs in a worker thread (to_thread below)
            with tempfile.TemporaryDirectory() as tmpdirname:
                ydl_opts = {
                    'outtmpl': f'{tmpdirname}/%(id)s.%(ext)s',
                    'quiet': True,
                    'no_warnings': True,
                    'format': 'best',  # Best quality single file
                    'noplaylist': True,  # Only single video if it's a playlist/profile
                    'writethumbnail': False,
                    'writesubtitles': False,
                }
                with yt_dlp.YoutubeDL(ydl_opts) as ydl:
                    ydl.download([url])
                # Find the downloaded file
                files = os.listdir(tmpdirname)
                if not files:
                    raise Exception("No files downloaded")
                # Pick the largest file if multiple (e.g. if yt-dlp downloaded parts)
                # But with 'format': 'best', it should be one.
                # If carousel, it might be multiple. Let's pick the first one.
                filename = files[0]
                filepath = os.path.join(tmpdirname, filename)
                with open(filepath, 'rb') as f:
                    data = f.read()
                ext = filename.split('.')[-1].lower()
                # Determine content type
                if ext in ['mp4', 'mov', 'avi', 'mkv', 'webm']:
                    content_type = f"video/{ext}"
                    if ext == 'mov': content_type = "video/quicktime"
                elif ext in ['jpg', 'jpeg', 'png', 'webp']:
                    content_type = f"image/{ext}"
                    if ext == 'jpg': content_type = "image/jpeg"
                else:
                    content_type = "application/octet-stream"
                return data, content_type, ext

        return await asyncio.to_thread(run_yt_dlp)

View File

@@ -1,12 +1,11 @@
from datetime import datetime, UTC
from typing import Optional, List
from pydantic import BaseModel, Field
class Album(BaseModel):
    """A user-curated collection of generations."""
    id: str | None = None  # DB id as string
    name: str
    description: str | None = None
    cover_asset_id: str | None = None  # asset shown as the album cover
    generation_ids: list[str] = Field(default_factory=list)
    created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
    updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))

View File

@@ -1,6 +1,6 @@
from datetime import datetime, UTC
from enum import Enum
from typing import Optional, Any, List
from typing import Any
from pydantic import BaseModel, computed_field, Field, model_validator
@@ -8,28 +8,30 @@ from pydantic import BaseModel, computed_field, Field, model_validator
class AssetContentType(str, Enum):
    """What kind of payload an asset stores."""
    IMAGE = 'image'
    PROMPT = 'prompt'
    VIDEO = 'video'
class AssetType(str, Enum):
    """How an asset entered the system."""
    UPLOADED = 'uploaded'
    GENERATED = 'generated'
    INSPIRATION = 'inspiration'
class Asset(BaseModel):
    """A stored media item (image/prompt/video), backed by S3 and/or inline bytes."""
    id: str | None = None
    name: str
    type: AssetType = AssetType.GENERATED
    content_type: AssetContentType = AssetContentType.IMAGE
    linked_char_id: str | None = None
    data: bytes | None = None  # legacy inline payload; new assets live in S3
    tg_doc_file_id: str | None = None
    tg_photo_file_id: str | None = None
    minio_object_name: str | None = None
    minio_bucket: str | None = None
    minio_thumbnail_object_name: str | None = None
    thumbnail: bytes | None = None
    tags: list[str] = Field(default_factory=list)
    created_by: str | None = None
    project_id: str | None = None
    is_deleted: bool = False  # soft-delete flag
    created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
    updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))

View File

@@ -1,16 +1,15 @@
from typing import Optional
from pydantic import BaseModel
from pydantic_core.core_schema import computed_field
class Character(BaseModel):
    """An AI character profile."""
    id: str | None = None
    name: str
    avatar_asset_id: str | None = None  # Asset used as the avatar image
    avatar_image: str | None = None
    character_image_doc_tg_id: str | None = None  # Telegram document file id
    character_image_tg_id: str | None = None  # Telegram photo file id
    character_bio: str | None = None
    created_by: str | None = None
    project_id: str | None = None

View File

@@ -1,15 +1,14 @@
from typing import List, Optional
from pydantic import BaseModel, Field, ConfigDict
from datetime import datetime
from bson import ObjectId
class Environment(BaseModel):
    """A named group of reference assets attached to a character."""
    id: str | None = Field(None, alias="_id")
    character_id: str
    name: str = Field(..., min_length=1)
    description: str | None = None
    asset_ids: list[str] = Field(default_factory=list)
    # NOTE(review): datetime.utcnow is naive and deprecated since 3.12; other
    # models use datetime.now(UTC) — consider aligning (changes tz-awareness,
    # so kept as-is here).
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)

View File

@@ -1,11 +1,9 @@
from datetime import datetime, UTC
from enum import Enum
from typing import List, Optional
from pydantic import BaseModel, Field, computed_field
from models.Asset import Asset
from models.enums import AspectRatios, Quality, GenType
from models.enums import AspectRatios, Quality
class GenerationStatus(str, Enum):
@@ -14,32 +12,36 @@ class GenerationStatus(str, Enum):
FAILED = "failed"
class Generation(BaseModel):
    """A single image-generation job: inputs, lifecycle state, results and usage metrics."""
    id: str | None = None
    status: GenerationStatus = GenerationStatus.RUNNING
    failed_reason: str | None = None
    linked_character_id: str | None = None
    telegram_id: int | None = None  # chat to notify when done
    use_profile_image: bool = True  # include character avatar as a reference
    aspect_ratio: AspectRatios
    quality: Quality
    prompt: str
    model: str | None = None
    seed: int | None = None
    tech_prompt: str | None = None  # final prompt actually sent to the API
    assets_list: list[str] = Field(default_factory=list)  # explicit reference asset ids
    result_list: list[str] = Field(default_factory=list)  # produced asset ids
    result: str | None = None
    progress: int = 0  # 0-100
    execution_time_seconds: float | None = None
    api_execution_time_seconds: float | None = None
    token_usage: int | None = None
    input_token_usage: int | None = None
    output_token_usage: int | None = None
    is_deleted: bool = False
    album_id: str | None = None
    environment_id: str | None = None
    generation_group_id: str | None = None  # groups generations created in one batch
    created_by: str | None = None  # Stores User ID (Telegram ID or Web User ObjectId)
    project_id: str | None = None
    idea_id: str | None = None
    liked_by: list[str] = Field(default_factory=list)  # user ids who liked this generation
    nsfw: bool = False
    created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
    updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
@@ -49,4 +51,4 @@ class Generation(BaseModel):
cost_input = self.input_token_usage * 0.000002
cost_output = self.output_token_usage * 0.00012
return round(cost_input + cost_output, 3)
return 0.0
return 0.0

View File

@@ -1,12 +1,12 @@
from datetime import datetime
from typing import Optional, List
from pydantic import BaseModel, Field
class Idea(BaseModel):
id: Optional[str] = None
id: str | None = None
name: str = "New Idea"
description: Optional[str] = None
project_id: Optional[str] = None
description: str | None = None
project_id: str | None = None
inspiration_id: str | None = None # Link to Inspiration
created_by: str # User ID
is_deleted: bool = False
created_at: datetime = Field(default_factory=datetime.now)

15
models/Inspiration.py Normal file
View File

@@ -0,0 +1,15 @@
from datetime import datetime, UTC
from pydantic import BaseModel, Field
class Inspiration(BaseModel):
    """A saved piece of external reference content (e.g. a downloaded post/video)
    that can later be linked to ideas/generations within a project."""

    # Mongo ObjectId as string; None until the document is persisted
    id: str | None = None
    # Original URL the inspiration was taken from
    source_url: str
    # Optional caption/description scraped from the source
    caption: str | None = None
    # Id of the stored media asset backing this inspiration
    asset_id: str
    # Whether processing (e.g. download/ingest) of the source has finished
    is_completed: bool = False
    # User ID of the creator (Telegram ID or Web User ObjectId, per project convention)
    created_by: str
    # Optional owning project; None means not attached to a project
    project_id: str | None = None
    # Timezone-aware creation/update timestamps (UTC)
    created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
    updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))

View File

@@ -1,14 +1,13 @@
from datetime import datetime, timezone, UTC
from typing import Optional, List
from pydantic import BaseModel, Field, model_validator
class Post(BaseModel):
id: Optional[str] = None
id: str | None = None
date: datetime
topic: str
generation_ids: List[str] = Field(default_factory=list)
project_id: Optional[str] = None
generation_ids: list[str] = Field(default_factory=list)
project_id: str | None = None
created_by: str
is_deleted: bool = False
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))

View File

@@ -1,12 +1,11 @@
from datetime import datetime
from typing import List, Optional
from pydantic import BaseModel, Field
class Project(BaseModel):
id: Optional[str] = None
id: str | None = None
name: str
description: Optional[str] = None
description: str | None = None
owner_id: str
members: List[str] = [] # List of User IDs
members: list[str] = [] # List of User IDs
is_deleted: bool = False
created_at: datetime = Field(default_factory=datetime.now)

View File

@@ -2,19 +2,30 @@ from enum import Enum
class AspectRatios(str, Enum):
NINESIXTEEN = "NINESIXTEEN"
SIXTEENNINE = "SIXTEENNINE"
THREEFOUR = "THREEFOUR"
FOURTHREE = "FOURTHREE"
ONEONE = "1:1"
TWOTHREE = "2:3"
THREETWO = "3:2"
THREEFOUR = "3:4"
FOURTHREE = "4:3"
FOURFIVE = "4:5"
FIVEFOUR = "5:4"
NINESIXTEEN = "9:16"
SIXTEENNINE = "16:9"
TWENTYONENINE = "21:9"
@classmethod
def _missing_(cls, value):
mapping = {
"NINESIXTEEN": cls.NINESIXTEEN,
"SIXTEENNINE": cls.SIXTEENNINE,
"THREEFOUR": cls.THREEFOUR,
"FOURTHREE": cls.FOURTHREE,
}
return mapping.get(value)
@property
def value_ratio(self) -> str:
return {
AspectRatios.NINESIXTEEN: "9:16",
AspectRatios.SIXTEENNINE: "16:9",
AspectRatios.THREEFOUR: "3:4",
AspectRatios.FOURTHREE: "4:3",
}[self]
return self.value
class Quality(str, Enum):
@@ -41,3 +52,20 @@ class GenType(str, Enum):
GenType.TEXT: 'Text',
GenType.IMAGE: 'Image',
}[self]
class TextModel(str, Enum):
GEMINI_3_1_PRO_PREVIEW = "gemini-3.1-pro-preview"
@property
def value_model(self) -> str:
return self.value
class ImageModel(str, Enum):
GEMINI_3_PRO_IMAGE_PREVIEW = "gemini-3-pro-image-preview"
GEMINI_3_1_FLASH_IMAGE_PREVIEW = "gemini-3.1-flash-image-preview"
@property
def value_model(self) -> str:
return self.value

View File

@@ -102,7 +102,7 @@ class AssetsRepo:
return assets
async def get_asset(self, asset_id: str, with_data: bool = True) -> Asset:
async def get_asset(self, asset_id: str, with_data: bool = True) -> Optional[Asset]:
projection = None
if not with_data:
projection = {"data": 0, "thumbnail": 0}
@@ -182,7 +182,9 @@ class AssetsRepo:
return await self.collection.count_documents(filter)
async def get_assets_by_ids(self, asset_ids: List[str]) -> List[Asset]:
object_ids = [ObjectId(asset_id) for asset_id in asset_ids]
object_ids = [ObjectId(asset_id) for asset_id in asset_ids if ObjectId.is_valid(asset_id)]
if not object_ids:
return []
res = self.collection.find({"_id": {"$in": object_ids}}, {"data": 0}) # Exclude data but maybe allow thumbnail if small?
# Original excluded thumbnail too.
assets = []

View File

@@ -9,6 +9,7 @@ from repos.project_repo import ProjectRepo
from repos.idea_repo import IdeaRepo
from repos.post_repo import PostRepo
from repos.environment_repo import EnvironmentRepo
from repos.inspiration_repo import InspirationRepo
from typing import Optional
@@ -25,3 +26,4 @@ class DAO:
self.ideas = IdeaRepo(client, db_name)
self.posts = PostRepo(client, db_name)
self.environments = EnvironmentRepo(client, db_name)
self.inspirations = InspirationRepo(client, db_name)

View File

@@ -26,7 +26,8 @@ class GenerationRepo:
return Generation(**res)
async def get_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
limit: int = 10, offset: int = 0, created_by: Optional[str] = None, project_id: Optional[str] = None, idea_id: Optional[str] = None) -> List[Generation]:
limit: int = 10, offset: int = 0, created_by: Optional[str] = None, project_id: Optional[str] = None,
idea_id: Optional[str] = None, only_liked_by: Optional[str] = None) -> List[Generation]:
filter: dict[str, Any] = {"is_deleted": False}
if character_id is not None:
@@ -43,6 +44,8 @@ class GenerationRepo:
filter["project_id"] = project_id
if idea_id is not None:
filter["idea_id"] = idea_id
if only_liked_by is not None:
filter["liked_by"] = only_liked_by
# If fetching for an idea, sort by created_at ascending (cronological)
# Otherwise typically descending (newest first)
@@ -57,7 +60,8 @@ class GenerationRepo:
return generations
async def count_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
album_id: Optional[str] = None, created_by: Optional[str] = None, project_id: Optional[str] = None, idea_id: Optional[str] = None) -> int:
album_id: Optional[str] = None, created_by: Optional[str] = None, project_id: Optional[str] = None,
idea_id: Optional[str] = None, only_liked_by: Optional[str] = None) -> int:
args = {}
if character_id is not None:
args["linked_character_id"] = character_id
@@ -73,6 +77,8 @@ class GenerationRepo:
args["idea_id"] = idea_id
if album_id is not None:
args["album_id"] = album_id
if only_liked_by is not None:
args["liked_by"] = only_liked_by
return await self.collection.count_documents(args)
async def get_generations_by_ids(self, generation_ids: List[str]) -> List[Generation]:
@@ -94,6 +100,47 @@ class GenerationRepo:
async def update_generation(self, generation: Generation, ):
res = await self.collection.update_one({"_id": ObjectId(generation.id)}, {"$set": generation.model_dump()})
async def toggle_like(self, generation_id: str, user_id: str) -> bool | None:
"""
Toggles like for a user on a generation.
Returns True if liked, False if unliked, None if generation not found.
"""
if not ObjectId.is_valid(generation_id):
return None
oid = ObjectId(generation_id)
# Check if generation exists
gen = await self.collection.find_one({"_id": oid}, {"liked_by": 1})
if not gen:
return None
if user_id in gen.get("liked_by", []):
# Unlike
await self.collection.update_one(
{"_id": oid},
{"$pull": {"liked_by": user_id}}
)
return False
else:
# Like
await self.collection.update_one(
{"_id": oid},
{"$addToSet": {"liked_by": user_id}}
)
return True
async def mark_nsfw(self, generation_id: str, is_nsfw: bool) -> bool:
    """
    Set the ``nsfw`` flag on a generation.

    Returns True if a generation with this id exists (even if the flag
    already had the requested value — the call is idempotent), False if
    the id is invalid or no document matched.
    """
    if not ObjectId.is_valid(generation_id):
        return False
    res = await self.collection.update_one(
        {"_id": ObjectId(generation_id)},
        {"$set": {"nsfw": is_nsfw}}
    )
    # matched_count, not modified_count: re-marking with the same value
    # must still report success, otherwise idempotent calls look like errors.
    return res.matched_count > 0
async def get_usage_stats(self, created_by: Optional[str] = None, project_id: Optional[str] = None) -> dict:
"""
Calculates usage statistics (runs, tokens, cost) using MongoDB aggregation.
@@ -253,7 +300,6 @@ class GenerationRepo:
cursor = self.collection.find(filter_query, {"result_list": 1, "assets_list": 1})
async for doc in cursor:
asset_ids.extend(doc.get("result_list", []))
asset_ids.extend(doc.get("assets_list", []))
# Мягкое удаление
res = await self.collection.update_many(

54
repos/inspiration_repo.py Normal file
View File

@@ -0,0 +1,54 @@
from typing import List, Optional
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient
from models.Inspiration import Inspiration
class InspirationRepo:
    """Data-access layer for Inspiration documents in the ``inspirations`` collection."""

    def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
        self.collection = client[db_name]["inspirations"]

    async def create_inspiration(self, inspiration: Inspiration) -> str:
        """Insert a new inspiration and return the generated id as a string."""
        res = await self.collection.insert_one(inspiration.model_dump(exclude={"id"}))
        return str(res.inserted_id)

    async def get_inspiration(self, inspiration_id: str) -> Optional[Inspiration]:
        """Return the inspiration by id, or None if the id is malformed or not found."""
        # Guard malformed ids: ObjectId(...) raises InvalidId otherwise.
        # Consistent with GenerationRepo's "return None when not found" convention.
        if not ObjectId.is_valid(inspiration_id):
            return None
        res = await self.collection.find_one({"_id": ObjectId(inspiration_id)})
        if not res:
            return None
        res["id"] = str(res.pop("_id"))
        return Inspiration(**res)

    async def get_inspirations(self, project_id: Optional[str] = None, created_by: Optional[str] = None, limit: int = 20, offset: int = 0) -> List[Inspiration]:
        """List inspirations, newest first, optionally filtered by project and/or creator."""
        query = {}
        if project_id:
            query["project_id"] = project_id
        if created_by:
            query["created_by"] = created_by
        cursor = self.collection.find(query).sort("created_at", -1).skip(offset).limit(limit)
        inspirations = []
        async for doc in cursor:
            doc["id"] = str(doc.pop("_id"))
            inspirations.append(Inspiration(**doc))
        return inspirations

    async def count_inspirations(self, project_id: Optional[str] = None, created_by: Optional[str] = None) -> int:
        """Count inspirations matching the same filters as ``get_inspirations``."""
        query = {}
        if project_id:
            query["project_id"] = project_id
        if created_by:
            query["created_by"] = created_by
        return await self.collection.count_documents(query)

    async def update_inspiration(self, inspiration: Inspiration):
        """Overwrite the stored document with the model's current fields (id excluded)."""
        # Skip the write entirely if the id cannot be a valid ObjectId.
        if inspiration.id is None or not ObjectId.is_valid(inspiration.id):
            return
        await self.collection.update_one(
            {"_id": ObjectId(inspiration.id)},
            {"$set": inspiration.model_dump(exclude={"id"})}
        )

    async def delete_inspiration(self, inspiration_id: str) -> bool:
        """Hard-delete an inspiration. Returns True if a document was removed."""
        if not ObjectId.is_valid(inspiration_id):
            return False
        res = await self.collection.delete_one({"_id": ObjectId(inspiration_id)})
        return res.deleted_count > 0

View File

@@ -52,3 +52,4 @@ python-multipart==0.0.22
email-validator
prometheus-fastapi-instrumentator
pydantic-settings==2.13.0
yt-dlp

View File

@@ -126,12 +126,11 @@ async def change_char(call: CallbackQuery, state: FSMContext, dao: DAO):
@router.callback_query(States.gen_mode, F.data == 'gen_mode_change_aspect_ratio')
async def gen_mode_change_aspect_ratio(call: CallbackQuery, state: FSMContext, dao: DAO):
await call.answer()
keyboards = []
for ratio in AspectRatios:
keyboards.append(InlineKeyboardButton(text=ratio.value, callback_data=f'select_ratio_{ratio.name}'))
buttons = [InlineKeyboardButton(text=ratio.value, callback_data=f'select_ratio_{ratio.name}') for ratio in AspectRatios]
keyboard_rows = [buttons[i:i + 4] for i in range(0, len(buttons), 4)]
keyboard_rows.append([InlineKeyboardButton(text="⬅️ Назад", callback_data="gen_mode_cancel_ratio_change")])
await call.message.edit_caption(caption="Выбери соотношение сторон",
reply_markup=InlineKeyboardMarkup(inline_keyboard=[keyboards, [InlineKeyboardButton(
text="⬅️ Назад", callback_data="gen_mode_cancel_ratio_change")]]))
reply_markup=InlineKeyboardMarkup(inline_keyboard=keyboard_rows))
@router.callback_query(States.gen_mode, F.data.startswith('select_ratio_'))