15 Commits

50 changed files with 956 additions and 194 deletions

1
.gitignore vendored
View File

@@ -23,3 +23,4 @@ services/*.pyc
utils/__pycache__/
utils/*.pyc
.vscode/launch.json
repos/__pycache__/assets_repo.cpython-313.pyc

View File

@@ -23,10 +23,10 @@ class GoogleAdapter:
self.TEXT_MODEL = "gemini-3-pro-preview"
self.IMAGE_MODEL = "gemini-3-pro-image-preview"
def _prepare_contents(self, prompt: str, images_list: List[bytes] = None) -> tuple:
def _prepare_contents(self, prompt: str, images_list: List[bytes] | None = None) -> tuple:
"""Вспомогательный метод для подготовки контента (текст + картинки).
Returns (contents, opened_images) — caller MUST close opened_images after use."""
contents = [prompt]
contents : list [Any]= [prompt]
opened_images = []
if images_list:
logger.info(f"Preparing content with {len(images_list)} images")
@@ -41,7 +41,7 @@ class GoogleAdapter:
logger.info("Preparing content with no images")
return contents, opened_images
def generate_text(self, prompt: str, images_list: List[bytes] = None) -> str:
def generate_text(self, prompt: str, images_list: List[bytes] | None = None) -> str:
"""
Генерация текста (Чат или Vision).
Возвращает строку с ответом.
@@ -74,7 +74,7 @@ class GoogleAdapter:
for img in opened_images:
img.close()
def generate_image(self, prompt: str, aspect_ratio: AspectRatios, quality: Quality, images_list: List[bytes] = None, ) -> Tuple[List[io.BytesIO], Dict[str, Any]]:
def generate_image(self, prompt: str, aspect_ratio: AspectRatios, quality: Quality, images_list: List[bytes] | None = None, ) -> Tuple[List[io.BytesIO], Dict[str, Any]]:
"""
Генерация изображений (Text-to-Image или Image-to-Image).
Возвращает список байтовых потоков (готовых к отправке).
@@ -130,7 +130,9 @@ class GoogleAdapter:
try:
# 1. Берем сырые байты
raw_data = part.inline_data.data
byte_arr = io.BytesIO(raw_data)
if raw_data is None:
raise GoogleGenerationException("Generation returned no data")
byte_arr : io.BytesIO = io.BytesIO(raw_data)
# 2. Нейминг (формально, для TG)
timestamp = datetime.now().timestamp()

View File

@@ -18,7 +18,7 @@ class S3Adapter:
@asynccontextmanager
async def _get_client(self):
async with self.session.client(
async with self.session.client( # type: ignore[reportGeneralTypeIssues]
"s3",
endpoint_url=self.endpoint_url,
aws_access_key_id=self.aws_access_key_id,
@@ -63,10 +63,12 @@ class S3Adapter:
response = await client.get_object(Bucket=self.bucket_name, Key=object_name)
# aioboto3 Body is an aiohttp StreamReader wrapper
body = response['Body']
data = await body.read()
# Yield in chunks to avoid holding entire response in StreamingResponse buffer
for i in range(0, len(data), chunk_size):
yield data[i:i + chunk_size]
while True:
chunk = await body.read(chunk_size)
if not chunk:
break
yield chunk
except ClientError as e:
print(f"Error streaming from S3: {e}")
return

39
aiws.py
View File

@@ -1,6 +1,5 @@
import asyncio
import logging
import os
from contextlib import asynccontextmanager
from aiogram import Bot, Dispatcher, Router, F
@@ -9,7 +8,6 @@ from aiogram.enums import ParseMode
from aiogram.filters import CommandStart, Command
from aiogram.types import Message
from aiogram.fsm.storage.mongo import MongoStorage
from dotenv import load_dotenv
from fastapi import FastAPI
from motor.motor_asyncio import AsyncIOMotorClient
from prometheus_client import Info
@@ -17,6 +15,7 @@ from starlette.middleware.cors import CORSMiddleware
from prometheus_fastapi_instrumentator import Instrumentator
# --- ИМПОРТЫ ПРОЕКТА ---
from config import settings
from adapters.google_adapter import GoogleAdapter
from adapters.s3_adapter import S3Adapter
from api.service.generation_service import GenerationService
@@ -44,17 +43,19 @@ from api.endpoints.admin import router as api_admin_router
from api.endpoints.album_router import router as api_album_router
from api.endpoints.project_router import router as project_api_router
from api.endpoints.idea_router import router as idea_api_router
from api.endpoints.post_router import router as post_api_router
load_dotenv()
logger = logging.getLogger(__name__)
# --- КОНФИГУРАЦИЯ ---
BOT_TOKEN = os.getenv("BOT_TOKEN")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
# Настройки теперь берутся из config.py
BOT_TOKEN = settings.BOT_TOKEN
GEMINI_API_KEY = settings.GEMINI_API_KEY
MONGO_HOST = os.getenv("MONGO_HOST") # Например: mongodb://localhost:27017
DB_NAME = os.getenv("DB_NAME", "my_bot_db") # Имя базы данных
ADMIN_ID = int(os.getenv("ADMIN_ID", 0))
MONGO_HOST = settings.MONGO_HOST
DB_NAME = settings.DB_NAME
ADMIN_ID = settings.ADMIN_ID
def setup_logging():
@@ -64,6 +65,8 @@ def setup_logging():
# --- ИНИЦИАЛИЗАЦИЯ ЗАВИСИМОСТЕЙ ---
if BOT_TOKEN is None:
raise ValueError("BOT_TOKEN is not set")
bot = Bot(token=BOT_TOKEN, default=DefaultBotProperties(parse_mode=ParseMode.HTML))
# Клиент БД создаем глобально, чтобы он был доступен и боту (Storage), и API
@@ -76,15 +79,19 @@ char_repo = CharacterRepo(mongo_client)
# S3 Adapter
s3_adapter = S3Adapter(
endpoint_url=os.getenv("MINIO_ENDPOINT", "http://31.59.58.220:9000"),
aws_access_key_id=os.getenv("MINIO_ACCESS_KEY", "minioadmin"),
aws_secret_access_key=os.getenv("MINIO_SECRET_KEY", "minioadmin"),
bucket_name=os.getenv("MINIO_BUCKET", "ai-char")
endpoint_url=settings.MINIO_ENDPOINT,
aws_access_key_id=settings.MINIO_ACCESS_KEY,
aws_secret_access_key=settings.MINIO_SECRET_KEY,
bucket_name=settings.MINIO_BUCKET
)
dao = DAO(mongo_client, s3_adapter) # Главный DAO для бота
if GEMINI_API_KEY is None:
raise ValueError("GEMINI_API_KEY is not set")
gemini = GoogleAdapter(api_key=GEMINI_API_KEY)
generation_service = GenerationService(dao, gemini, bot)
if bot is None:
raise ValueError("bot is not set")
generation_service = GenerationService(dao=dao, gemini=gemini, s3_adapter=s3_adapter, bot=bot)
album_service = AlbumService(dao)
# Dispatcher
@@ -126,11 +133,12 @@ async def start_scheduler(service: GenerationService):
try:
logger.info("Running scheduler for stacked generation killing")
await service.cleanup_stale_generations()
await service.cleanup_old_data(days=2)
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Scheduler error: {e}")
await asyncio.sleep(60) # Check every 10 minutes
await asyncio.sleep(60) # Check every 60 seconds
# --- LIFESPAN (Запуск FastAPI + Bot) ---
@asynccontextmanager
@@ -212,6 +220,7 @@ app.include_router(api_gen_router)
app.include_router(api_album_router)
app.include_router(project_api_router)
app.include_router(idea_api_router)
app.include_router(post_api_router)
# Prometheus Metrics (Instrument after all routers are added)
Instrumentator(
@@ -250,7 +259,7 @@ if __name__ == "__main__":
async def main():
# Создаем конфигурацию uvicorn вручную
# loop="asyncio" заставляет использовать стандартный цикл
config = uvicorn.Config(app, host="0.0.0.0", port=8090, loop="asyncio", timeout_keep_alive=120, env_file=".env.development")
config = uvicorn.Config(app, host="0.0.0.0", port=8090, loop="asyncio", timeout_keep_alive=120)
server = uvicorn.Server(config)
# Запускаем сервер (lifespan запустится внутри)

View File

@@ -5,6 +5,7 @@ from motor.motor_asyncio import AsyncIOMotorClient
from adapters.google_adapter import GoogleAdapter
from api.service.generation_service import GenerationService
from repos.dao import DAO
from api.service.album_service import AlbumService
# ... ваши импорты ...
@@ -53,4 +54,12 @@ def get_idea_service(dao: DAO = Depends(get_dao)) -> IdeaService:
from fastapi import Header
async def get_project_id(x_project_id: Optional[str] = Header(None, alias="X-Project-ID")) -> Optional[str]:
return x_project_id
return x_project_id
async def get_album_service(dao: DAO = Depends(get_dao)) -> AlbumService:
return AlbumService(dao)
from api.service.post_service import PostService
def get_post_service(dao: DAO = Depends(get_dao)) -> PostService:
return PostService(dao)

View File

@@ -5,6 +5,8 @@ from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel
from repos.user_repo import UsersRepo, UserStatus
from api.dependency import get_dao
from repos.dao import DAO
from utils.security import verify_password, create_access_token, ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, SECRET_KEY
from jose import JWTError, jwt
from starlette.requests import Request
@@ -23,7 +25,7 @@ async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)], repo:
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
username: str | None = payload.get("sub")
if username is None:
raise credentials_exception
except JWTError:

View File

@@ -1,10 +1,13 @@
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, status, Request
from fastapi import APIRouter, HTTPException, status, Request
from pydantic import BaseModel
from api.models.GenerationRequest import GenerationResponse
from models.Album import Album
from repos.dao import DAO
from api.dependency import get_album_service
from api.service.album_service import AlbumService
router = APIRouter(prefix="/api/albums", tags=["Albums"])

View File

@@ -12,7 +12,7 @@ from starlette.requests import Request
from starlette.responses import Response, JSONResponse, StreamingResponse
from adapters.s3_adapter import S3Adapter
from api.models.AssetDTO import AssetsResponse, AssetResponse
from api.models import AssetsResponse, AssetResponse
from models.Asset import Asset, AssetType, AssetContentType
from repos.dao import DAO
from api.dependency import get_dao, get_mongo_client, get_s3_adapter
@@ -278,8 +278,7 @@ async def upload_asset(
type=asset.type.value if hasattr(asset.type, "value") else asset.type,
content_type=asset.content_type.value if hasattr(asset.content_type, "value") else asset.content_type,
linked_char_id=asset.linked_char_id,
created_at=asset.created_at,
url=asset.url
created_at=asset.created_at
)

View File

@@ -5,11 +5,11 @@ from pydantic import BaseModel
from starlette.exceptions import HTTPException
from starlette.requests import Request
from api.models.AssetDTO import AssetsResponse, AssetResponse
from api.models.GenerationRequest import GenerationRequest, GenerationResponse
from api.models import AssetsResponse, AssetResponse
from api.models import GenerationRequest, GenerationResponse
from models.Asset import Asset
from models.Character import Character
from api.models.CharacterDTO import CharacterCreateRequest, CharacterUpdateRequest
from api.models import CharacterCreateRequest, CharacterUpdateRequest
from repos.dao import DAO
from api.dependency import get_dao
@@ -24,8 +24,15 @@ router = APIRouter(prefix="/api/characters", tags=["Characters"], dependencies=[
@router.get("/", response_model=List[Character])
async def get_characters(request: Request, dao: DAO = Depends(get_dao), current_user: dict = Depends(get_current_user), project_id: Optional[str] = Depends(get_project_id)) -> List[Character]:
logger.info("get_characters called")
async def get_characters(
request: Request,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
limit: int = 100,
offset: int = 0
) -> List[Character]:
logger.info(f"get_characters called. Limit: {limit}, Offset: {offset}")
user_id_filter = str(current_user["_id"])
if project_id:
@@ -34,7 +41,12 @@ async def get_characters(request: Request, dao: DAO = Depends(get_dao), current_
raise HTTPException(status_code=403, detail="Project access denied")
user_id_filter = None
characters = await dao.chars.get_all_characters(created_by=user_id_filter, project_id=project_id)
characters = await dao.chars.get_all_characters(
created_by=user_id_filter,
project_id=project_id,
limit=limit,
offset=offset
)
return characters
@@ -178,10 +190,3 @@ async def delete_character(
raise HTTPException(status_code=500, detail="Failed to delete character")
return
@router.post("/{character_id}/_run", response_model=GenerationResponse)
async def post_character_generation(character_id: str, generation: GenerationRequest,
request: Request) -> GenerationResponse:
logger.info(f"post_character_generation called. CharacterID: {character_id}")
generation_service = request.app.state.generation_service

View File

@@ -1,25 +1,32 @@
import logging
import os
import json
from typing import List, Optional
from fastapi import APIRouter, UploadFile, File, Form, Header, HTTPException
from fastapi.params import Depends
from starlette import status
from starlette.requests import Request
from api import service
from config import settings
from api.dependency import get_generation_service, get_project_id, get_dao
from repos.dao import DAO
from api.models.GenerationRequest import GenerationResponse, GenerationRequest, GenerationsResponse, PromptResponse, PromptRequest, GenerationGroupResponse
from api.endpoints.auth import get_current_user
from api.models import (
GenerationResponse,
GenerationRequest,
GenerationsResponse,
PromptResponse,
PromptRequest,
GenerationGroupResponse,
FinancialReport,
ExternalGenerationRequest
)
from api.service.generation_service import GenerationService
from models.Generation import Generation
from starlette import status
import logging
from repos.dao import DAO
from utils.external_auth import verify_signature
logger = logging.getLogger(__name__)
from api.endpoints.auth import get_current_user
router = APIRouter(prefix='/api/generations', tags=["Generation"])
@@ -68,6 +75,47 @@ async def get_generations(character_id: Optional[str] = None, limit: int = 10, o
return await generation_service.get_generations(character_id, limit=limit, offset=offset, user_id=user_id_filter, project_id=project_id)
@router.get("/usage", response_model=FinancialReport)
async def get_usage_report(
breakdown: Optional[str] = None, # "user" or "project"
generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao)
) -> FinancialReport:
"""
Returns usage statistics (runs, tokens, cost) for the current user or project.
If project_id is provided, returns stats for that project.
Otherwise, returns stats for the current user.
"""
user_id_filter = str(current_user["_id"])
breakdown_by = None
if project_id:
# Permission check
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
user_id_filter = None # If we are in project, we see stats for the WHOLE project by default
if breakdown == "user":
breakdown_by = "created_by"
elif breakdown == "project":
breakdown_by = "project_id"
else:
# Default: Stats for current user
if breakdown == "project":
breakdown_by = "project_id"
elif breakdown == "user":
# This would breakdown personal usage by user (yourself), but could be useful if it included collaborators?
# No, if project_id is None, it's personal.
breakdown_by = "created_by"
return await generation_service.get_financial_report(
user_id=user_id_filter,
project_id=project_id,
breakdown_by=breakdown_by
)
@router.post("/_run", response_model=GenerationGroupResponse)
async def post_generation(generation: GenerationRequest, request: Request,
generation_service: GenerationService = Depends(get_generation_service),
@@ -120,7 +168,15 @@ async def get_generation(generation_id: str,
logger.debug(f"get_generation called for ID: {generation_id}")
gen = await generation_service.get_generation(generation_id)
if gen and gen.created_by != str(current_user["_id"]):
raise HTTPException(status_code=403, detail="Access denied")
# Check project membership
is_member = False
if gen.project_id:
project = await generation_service.dao.projects.get_project(gen.project_id)
if project and str(current_user["_id"]) in project.members:
is_member = True
if not is_member:
raise HTTPException(status_code=403, detail="Access denied")
return gen
@@ -136,17 +192,13 @@ async def import_external_generation(
Import a generation from an external source.
Requires server-to-server authentication via HMAC signature.
"""
import os
from utils.external_auth import verify_signature
from api.models.ExternalGenerationDTO import ExternalGenerationRequest
logger.info("import_external_generation called")
# Get raw request body for signature verification
body = await request.body()
# Verify signature
secret = os.getenv("EXTERNAL_API_SECRET")
secret = settings.EXTERNAL_API_SECRET
if not secret:
logger.error("EXTERNAL_API_SECRET not configured")
raise HTTPException(status_code=500, detail="Server configuration error")
@@ -156,7 +208,6 @@ async def import_external_generation(
raise HTTPException(status_code=401, detail="Invalid signature")
# Parse request body
import json
try:
data = json.loads(body.decode('utf-8'))
external_gen = ExternalGenerationRequest(**data)
@@ -181,4 +232,4 @@ async def delete_generation(generation_id: str,
deleted = await generation_service.delete_generation(generation_id)
if not deleted:
raise HTTPException(status_code=404, detail="Generation not found")
return None
return None

View File

@@ -5,7 +5,8 @@ from api.endpoints.auth import get_current_user
from api.service.idea_service import IdeaService
from api.service.generation_service import GenerationService
from models.Idea import Idea
from api.models.GenerationRequest import GenerationResponse, GenerationsResponse
from api.models import GenerationResponse, GenerationsResponse
from api.models import IdeaRequest, PostRequest # Adjusting for general model usage
from api.models.IdeaRequest import IdeaCreateRequest, IdeaUpdateRequest, IdeaResponse
router = APIRouter(prefix="/api/ideas", tags=["ideas"])

View File

@@ -0,0 +1,99 @@
from typing import List, Optional
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException
from api.dependency import get_post_service, get_project_id
from api.endpoints.auth import get_current_user
from api.service.post_service import PostService
from api.models import PostRequest, PostCreateRequest, PostUpdateRequest, AddGenerationsRequest
from models.Post import Post
router = APIRouter(prefix="/api/posts", tags=["posts"])
@router.post("", response_model=Post)
async def create_post(
request: PostCreateRequest,
project_id: Optional[str] = Depends(get_project_id),
current_user: dict = Depends(get_current_user),
post_service: PostService = Depends(get_post_service),
):
pid = project_id or request.project_id
return await post_service.create_post(
date=request.date,
topic=request.topic,
generation_ids=request.generation_ids,
project_id=pid,
user_id=str(current_user["_id"]),
)
@router.get("", response_model=List[Post])
async def get_posts(
project_id: Optional[str] = Depends(get_project_id),
limit: int = 200,
offset: int = 0,
date_from: Optional[datetime] = None,
date_to: Optional[datetime] = None,
current_user: dict = Depends(get_current_user),
post_service: PostService = Depends(get_post_service),
):
return await post_service.get_posts(project_id, str(current_user["_id"]), limit, offset, date_from, date_to)
@router.get("/{post_id}", response_model=Post)
async def get_post(
post_id: str,
post_service: PostService = Depends(get_post_service),
):
post = await post_service.get_post(post_id)
if not post:
raise HTTPException(status_code=404, detail="Post not found")
return post
@router.put("/{post_id}", response_model=Post)
async def update_post(
post_id: str,
request: PostUpdateRequest,
post_service: PostService = Depends(get_post_service),
):
post = await post_service.update_post(post_id, date=request.date, topic=request.topic)
if not post:
raise HTTPException(status_code=404, detail="Post not found")
return post
@router.delete("/{post_id}")
async def delete_post(
post_id: str,
post_service: PostService = Depends(get_post_service),
):
success = await post_service.delete_post(post_id)
if not success:
raise HTTPException(status_code=404, detail="Post not found or could not be deleted")
return {"status": "success"}
@router.post("/{post_id}/generations")
async def add_generations(
post_id: str,
request: AddGenerationsRequest,
post_service: PostService = Depends(get_post_service),
):
success = await post_service.add_generations(post_id, request.generation_ids)
if not success:
raise HTTPException(status_code=404, detail="Post not found")
return {"status": "success"}
@router.delete("/{post_id}/generations/{generation_id}")
async def remove_generation(
post_id: str,
generation_id: str,
post_service: PostService = Depends(get_post_service),
):
success = await post_service.remove_generation(post_id, generation_id)
if not success:
raise HTTPException(status_code=404, detail="Post not found or generation not linked")
return {"status": "success"}

View File

@@ -1,4 +1,6 @@
from typing import List, Optional
from bson import ObjectId
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from api.dependency import get_dao
@@ -12,14 +14,46 @@ class ProjectCreate(BaseModel):
name: str
description: Optional[str] = None
class ProjectMemberResponse(BaseModel):
id: str
username: str
class ProjectResponse(BaseModel):
id: str
name: str
description: Optional[str] = None
owner_id: str
members: List[str]
members: List[ProjectMemberResponse]
is_owner: bool = False
async def _get_project_response(project: Project, current_user_id: str, dao: DAO) -> ProjectResponse:
member_responses = []
for member_id in project.members:
# We need a way to get user by ID. Let's check UsersRepo for get_user by ObjectId or string.
# Currently UsersRepo has get_user(user_id: int) for Telegram IDs.
# But for Web users we might need to search by _id.
# Let's try to get user info.
# Since project.members contains strings (ObjectIds as strings), we search by _id.
user_doc = await dao.users.collection.find_one({"_id": ObjectId(member_id)})
if not user_doc and member_id.isdigit():
# Fallback for telegram IDs if they are stored as strings of digits
user_doc = await dao.users.get_user(int(member_id))
username = "unknown"
if user_doc:
username = user_doc.get("username", "unknown")
member_responses.append(ProjectMemberResponse(id=member_id, username=username))
return ProjectResponse(
id=project.id,
name=project.name,
description=project.description,
owner_id=project.owner_id,
members=member_responses,
is_owner=(project.owner_id == current_user_id)
)
@router.post("", response_model=ProjectResponse)
async def create_project(
project_data: ProjectCreate,
@@ -34,27 +68,15 @@ async def create_project(
members=[user_id]
)
project_id = await dao.projects.create_project(new_project)
new_project.id = project_id
# Add project to user's project list
# Assuming user_repo has a method to add project or we do it directly?
# UserRepo doesn't have add_project method yet.
# But since UserRepo is just a wrapper around collection, lets add it here or update UserRepo later?
# Better to update UserRepo. For now, let's just return success.
# But user needs to see it in list.
# Update user in DB
await dao.users.collection.update_one(
{"_id": current_user["_id"]},
{"$addToSet": {"project_ids": project_id}}
)
return ProjectResponse(
id=project_id,
name=new_project.name,
description=new_project.description,
owner_id=new_project.owner_id,
members=new_project.members,
is_owner=True
)
return await _get_project_response(new_project, user_id, dao)
@router.get("", response_model=List[ProjectResponse])
async def get_my_projects(
@@ -66,14 +88,7 @@ async def get_my_projects(
responses = []
for p in projects:
responses.append(ProjectResponse(
id=p.id,
name=p.name,
description=p.description,
owner_id=p.owner_id,
members=p.members,
is_owner=(p.owner_id == user_id)
))
responses.append(await _get_project_response(p, user_id, dao))
return responses
class MemberAdd(BaseModel):

View File

@@ -0,0 +1,18 @@
from pydantic import BaseModel
from typing import List, Optional
class UsageStats(BaseModel):
total_runs: int
total_tokens: int
total_input_tokens: int
total_output_tokens: int
total_cost: float
class UsageByEntity(BaseModel):
entity_id: Optional[str] = None
stats: UsageStats
class FinancialReport(BaseModel):
summary: UsageStats
by_user: Optional[List[UsageByEntity]] = None
by_project: Optional[List[UsageByEntity]] = None

19
api/models/PostRequest.py Normal file
View File

@@ -0,0 +1,19 @@
from datetime import datetime
from typing import Optional, List
from pydantic import BaseModel
class PostCreateRequest(BaseModel):
date: datetime
topic: str
generation_ids: List[str] = []
project_id: Optional[str] = None
class PostUpdateRequest(BaseModel):
date: Optional[datetime] = None
topic: Optional[str] = None
class AddGenerationsRequest(BaseModel):
generation_ids: List[str]

View File

@@ -0,0 +1,7 @@
from .AssetDTO import AssetResponse, AssetsResponse
from .CharacterDTO import CharacterCreateRequest, CharacterUpdateRequest
from .ExternalGenerationDTO import ExternalGenerationRequest
from .FinancialUsageDTO import FinancialReport, UsageStats, UsageByEntity
from .GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse, PromptRequest, PromptResponse
from .IdeaRequest import IdeaCreateRequest, IdeaUpdateRequest, IdeaResponse
from .PostRequest import PostCreateRequest, PostUpdateRequest, AddGenerationsRequest

View File

@@ -1,27 +1,31 @@
import asyncio
import base64
import logging
import random
import base64
from datetime import datetime, UTC
from typing import List, Optional, Tuple, Any, Dict
from io import BytesIO
from uuid import uuid4
import httpx
import httpx
from aiogram import Bot
from aiogram.types import BufferedInputFile
from adapters.Exception import GoogleGenerationException
from adapters.google_adapter import GoogleAdapter
from api.models.GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse
from adapters.s3_adapter import S3Adapter
from api.models import FinancialReport, UsageStats, UsageByEntity
from api.models import GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse
# Импортируйте ваши модели DAO, Asset, Generation корректно
from models.Asset import Asset, AssetType, AssetContentType
from models.Generation import Generation, GenerationStatus
from models.enums import AspectRatios, Quality, GenType
from models.enums import AspectRatios, Quality
from repos.dao import DAO
from adapters.s3_adapter import S3Adapter
logger = logging.getLogger(__name__)
# Limit concurrent generations to 4
generation_semaphore = asyncio.Semaphore(4)
# --- Вспомогательная функция генерации ---
async def generate_image_task(
@@ -74,7 +78,7 @@ class GenerationService:
self.bot = bot
async def ask_prompt_assistant(self, prompt: str, assets: List[str] = None) -> str:
async def ask_prompt_assistant(self, prompt: str, assets: list[str] | None = None) -> str:
future_prompt = """You are an prompt-assistant. You improving user-entered prompts for image generation. User may upload reference image too.
I will provide sources prompt entered by user. Understand user needs and generate best variation of prompt.
ANSWER ONLY PROMPT STRING!!! USER_ENTERED_PROMPT: """
@@ -144,16 +148,19 @@ class GenerationService:
generation_model.id = gen_id
async def runner(gen):
logger.info(f"Starting background generation task for ID: {gen.id}")
logger.info(f"Generation {gen.id} entered queue (waiting for slot)...")
try:
await self.create_generation(gen)
logger.info(f"Background generation task finished for ID: {gen.id}")
async with generation_semaphore:
logger.info(f"Starting background generation task for ID: {gen.id}")
await self.create_generation(gen)
logger.info(f"Background generation task finished for ID: {gen.id}")
except Exception:
# если генерация уже пошла и упала — пометим FAILED
try:
db_gen = await self.dao.generations.get_generation(gen.id)
db_gen.status = GenerationStatus.FAILED
await self.dao.generations.update_generation(db_gen)
if db_gen is not None:
db_gen.status = GenerationStatus.FAILED
await self.dao.generations.update_generation(db_gen)
except Exception:
logger.exception("Failed to mark generation as FAILED")
logger.exception("create_generation task failed")
@@ -167,8 +174,9 @@ class GenerationService:
if gen_id is not None:
try:
gen = await self.dao.generations.get_generation(gen_id)
gen.status = GenerationStatus.FAILED
await self.dao.generations.update_generation(gen)
if gen is not None:
gen.status = GenerationStatus.FAILED
await self.dao.generations.update_generation(gen)
except Exception:
logger.exception("Failed to mark generation as FAILED in create_generation_task")
raise
@@ -196,9 +204,10 @@ class GenerationService:
if char_info is None:
raise Exception(f"Character ID {generation.linked_character_id} not found")
if generation.use_profile_image:
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
if avatar_asset:
media_group_bytes.append(avatar_asset.data)
if char_info.avatar_asset_id is not None:
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
if avatar_asset and avatar_asset.data:
media_group_bytes.append(avatar_asset.data)
# generation_prompt = generation_prompt.replace("$char_bio_inserted", f"1. CHARACTER BIO (Must be strictly followed): {char_info.character_bio}")
reference_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
@@ -299,7 +308,9 @@ class GenerationService:
# 5. (Опционально) Обновляем запись генерации ссылками на результаты
# Предполагаем, что у модели Generation есть поле result_asset_ids
result_ids = [a.id for a in created_assets]
result_ids = []
for a in created_assets:
result_ids.append(a.id)
generation.result_list = result_ids
generation.status = GenerationStatus.DONE
@@ -367,8 +378,7 @@ class GenerationService:
Returns:
Created Generation object
"""
from api.models.ExternalGenerationDTO import ExternalGenerationRequest
# Validate image source
external_gen.validate_image_source()
@@ -474,4 +484,51 @@ class GenerationService:
if count > 0:
logger.info(f"Cleaned up {count} stale generations (timeout)")
except Exception as e:
logger.error(f"Error cleaning up stale generations: {e}")
logger.error(f"Error cleaning up stale generations: {e}")
async def cleanup_old_data(self, days: int = 2):
"""
Очистка старых данных:
1. Мягко удаляет генерации старше N дней
2. Мягко удаляет связанные ассеты + жёстко удаляет файлы из S3
"""
try:
# 1. Мягко удаляем генерации и собираем asset IDs
gen_count, asset_ids = await self.dao.generations.soft_delete_old_generations(days=days)
if gen_count > 0:
logger.info(f"Soft-deleted {gen_count} generations older than {days} days. "
f"Found {len(asset_ids)} associated asset IDs.")
# 2. Мягко удаляем ассеты + жёстко удаляем файлы из S3
if asset_ids:
purged = await self.dao.assets.soft_delete_and_purge_assets(asset_ids)
logger.info(f"Purged {purged} assets (soft-deleted + S3 files removed).")
except Exception as e:
logger.error(f"Error during old data cleanup: {e}")
async def get_financial_report(self, user_id: Optional[str] = None, project_id: Optional[str] = None, breakdown_by: Optional[str] = None) -> FinancialReport:
"""
Generates a financial usage report for a specific user or project.
'breakdown_by' can be 'created_by' or 'project_id'.
"""
summary_data = await self.dao.generations.get_usage_stats(created_by=user_id, project_id=project_id)
summary = UsageStats(**summary_data)
by_user = None
by_project = None
if breakdown_by == "created_by":
res = await self.dao.generations.get_usage_breakdown(group_by="created_by", project_id=project_id, created_by=user_id)
by_user = [UsageByEntity(**item) for item in res]
if breakdown_by == "project_id":
res = await self.dao.generations.get_usage_breakdown(group_by="project_id", project_id=project_id, created_by=user_id)
by_project = [UsageByEntity(**item) for item in res]
return FinancialReport(
summary=summary,
by_user=by_user,
by_project=by_project
)

View File

@@ -0,0 +1,79 @@
from typing import List, Optional
from datetime import datetime, UTC
from repos.dao import DAO
from models.Post import Post
class PostService:
def __init__(self, dao: DAO):
self.dao = dao
async def create_post(
self,
date: datetime,
topic: str,
generation_ids: List[str],
project_id: Optional[str],
user_id: str,
) -> Post:
post = Post(
date=date,
topic=topic,
generation_ids=generation_ids,
project_id=project_id,
created_by=user_id,
)
post_id = await self.dao.posts.create_post(post)
post.id = post_id
return post
async def get_post(self, post_id: str) -> Optional[Post]:
return await self.dao.posts.get_post(post_id)
async def get_posts(
self,
project_id: Optional[str],
user_id: str,
limit: int = 20,
offset: int = 0,
date_from: Optional[datetime] = None,
date_to: Optional[datetime] = None,
) -> List[Post]:
return await self.dao.posts.get_posts(project_id, user_id, limit, offset, date_from, date_to)
async def update_post(
self,
post_id: str,
date: Optional[datetime] = None,
topic: Optional[str] = None,
) -> Optional[Post]:
post = await self.dao.posts.get_post(post_id)
if not post:
return None
updates: dict = {"updated_at": datetime.now(UTC)}
if date is not None:
updates["date"] = date
if topic is not None:
updates["topic"] = topic
await self.dao.posts.update_post(post_id, updates)
# Return refreshed post
return await self.dao.posts.get_post(post_id)
async def delete_post(self, post_id: str) -> bool:
return await self.dao.posts.delete_post(post_id)
async def add_generations(self, post_id: str, generation_ids: List[str]) -> bool:
post = await self.dao.posts.get_post(post_id)
if not post:
return False
return await self.dao.posts.add_generations(post_id, generation_ids)
async def remove_generation(self, post_id: str, generation_id: str) -> bool:
post = await self.dao.posts.get_post(post_id)
if not post:
return False
return await self.dao.posts.remove_generation(post_id, generation_id)

39
config.py Normal file
View File

@@ -0,0 +1,39 @@
import os
from typing import Optional
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
# Telegram Bot
BOT_TOKEN: str
ADMIN_ID: int = 0
# AI Service
GEMINI_API_KEY: str
# Database
MONGO_HOST: str = "mongodb://localhost:27017"
DB_NAME: str = "my_bot_db"
# S3 Storage (Minio)
MINIO_ENDPOINT: str = "http://localhost:9000"
MINIO_ACCESS_KEY: str = "minioadmin"
MINIO_SECRET_KEY: str = "minioadmin"
MINIO_BUCKET: str = "ai-char"
# External API
EXTERNAL_API_SECRET: Optional[str] = None
# JWT Security
SECRET_KEY: str = "CHANGE_ME_TO_A_SUPER_SECRET_KEY"
ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 * 24 * 60 # 30 days
model_config = SettingsConfigDict(
env_file=os.getenv("ENV_FILE", ".env"),
env_file_encoding="utf-8",
extra="ignore"
)
settings = Settings()

View File

@@ -27,19 +27,19 @@ class AlbumMiddleware(BaseMiddleware):
# Ждем сбора остальных частей
await asyncio.sleep(self.latency)
# Проверяем, что ключ все еще существует (на всякий случай)
# Проверяем, что ключ все еще существует
if group_id in self.album_data:
# Передаем собранный альбом в хендлер
# Сортируем по message_id, чтобы порядок был верным
self.album_data[group_id].sort(key=lambda x: x.message_id)
data["album"] = self.album_data[group_id]
current_album = self.album_data[group_id]
current_album.sort(key=lambda x: x.message_id)
data["album"] = current_album
return await handler(event, data)
finally:
# ЧИСТКА: Удаляем всегда, если это "головной" поток, который создал запись
# Проверяем, что мы удаляем именно то, что создали, и ключ существует
if group_id in self.album_data and self.album_data[group_id][0] == event:
del self.album_data[group_id]
# ЧИСТКА: Удаляем запись после обработки или таймаута
# Используем pop() с дефолтом, чтобы избежать KeyError
self.album_data.pop(group_id, None)
else:
# Если группа уже собирается - просто добавляем и выходим

View File

@@ -30,6 +30,7 @@ class Asset(BaseModel):
tags: List[str] = []
created_by: Optional[str] = None
project_id: Optional[str] = None
is_deleted: bool = False
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
@@ -62,6 +63,7 @@ class Asset(BaseModel):
# --- CALCULATED FIELD ---
@computed_field
@property
def url(self) -> str:
"""
Это поле автоматически вычислится и попадет в model_dump() / .json()

View File

@@ -9,7 +9,6 @@ class Character(BaseModel):
name: str
avatar_asset_id: Optional[str] = None
avatar_image: Optional[str] = None
character_image_data: Optional[bytes] = None
character_image_doc_tg_id: Optional[str] = None
character_image_tg_id: Optional[str] = None
character_bio: Optional[str] = None

23
models/Post.py Normal file
View File

@@ -0,0 +1,23 @@
from datetime import datetime, timezone, UTC
from typing import Optional, List
from pydantic import BaseModel, Field, model_validator
class Post(BaseModel):
id: Optional[str] = None
date: datetime
topic: str
generation_ids: List[str] = Field(default_factory=list)
project_id: Optional[str] = None
created_by: str
is_deleted: bool = False
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
@model_validator(mode="after")
def ensure_tz_aware(self):
for field in ("date", "created_at", "updated_at"):
val = getattr(self, field)
if val is not None and val.tzinfo is None:
setattr(self, field, val.replace(tzinfo=timezone.utc))
return self

View File

@@ -1,5 +1,6 @@
from typing import List, Optional
from typing import Any, List, Optional
import logging
from datetime import datetime, UTC
from bson import ObjectId
from uuid import uuid4
from motor.motor_asyncio import AsyncIOMotorClient
@@ -50,7 +51,7 @@ class AssetsRepo:
return str(res.inserted_id)
async def get_assets(self, asset_type: Optional[str] = None, limit: int = 10, offset: int = 0, with_data: bool = False, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Asset]:
filter = {}
filter: dict[str, Any]= {"is_deleted": {"$ne": True}}
if asset_type:
filter["type"] = asset_type
args = {}
@@ -174,6 +175,8 @@ class AssetsRepo:
filter["linked_char_id"] = character_id
if created_by:
filter["created_by"] = created_by
if project_id is None:
filter["project_id"] = None
if project_id:
filter["project_id"] = project_id
return await self.collection.count_documents(filter)
@@ -202,6 +205,61 @@ class AssetsRepo:
res = await self.collection.delete_one({"_id": ObjectId(asset_id)})
return res.deleted_count > 0
async def soft_delete_and_purge_assets(self, asset_ids: List[str]) -> int:
"""
Мягко удаляет ассеты и жёстко удаляет их файлы из S3.
Возвращает количество обработанных ассетов.
"""
if not asset_ids:
return 0
object_ids = [ObjectId(aid) for aid in asset_ids if ObjectId.is_valid(aid)]
if not object_ids:
return 0
# Находим ассеты, которые ещё не удалены
cursor = self.collection.find(
{"_id": {"$in": object_ids}, "is_deleted": {"$ne": True}},
{"minio_object_name": 1, "minio_thumbnail_object_name": 1}
)
purged_count = 0
ids_to_update = []
async for doc in cursor:
ids_to_update.append(doc["_id"])
# Жёсткое удаление файлов из S3
if self.s3:
if doc.get("minio_object_name"):
try:
await self.s3.delete_file(doc["minio_object_name"])
except Exception as e:
logger.error(f"Failed to delete S3 object {doc['minio_object_name']}: {e}")
if doc.get("minio_thumbnail_object_name"):
try:
await self.s3.delete_file(doc["minio_thumbnail_object_name"])
except Exception as e:
logger.error(f"Failed to delete S3 thumbnail {doc['minio_thumbnail_object_name']}: {e}")
purged_count += 1
# Мягкое удаление + очистка ссылок на S3
if ids_to_update:
await self.collection.update_many(
{"_id": {"$in": ids_to_update}},
{
"$set": {
"is_deleted": True,
"minio_object_name": None,
"minio_thumbnail_object_name": None,
"updated_at": datetime.now(UTC)
}
}
)
return purged_count
async def migrate_to_minio(self) -> dict:
"""Переносит данные и thumbnails из Mongo в MinIO."""
if not self.s3:

View File

@@ -15,26 +15,24 @@ class CharacterRepo:
character.id = str(op.inserted_id)
return character
async def get_character(self, character_id: str, with_image_data: bool = False) -> Character | None:
args = {}
if not with_image_data:
args["character_image_data"] = 0
res = await self.collection.find_one({"_id": ObjectId(character_id)}, args)
async def get_character(self, character_id: str) -> Character | None:
res = await self.collection.find_one({"_id": ObjectId(character_id)})
if res is None:
return None
else:
res["id"] = str(res.pop("_id"))
return Character(**res)
async def get_all_characters(self, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Character]:
async def get_all_characters(self, created_by: Optional[str] = None, project_id: Optional[str] = None, limit: int = 100, offset: int = 0) -> List[Character]:
filter = {}
if created_by:
filter["created_by"] = created_by
if project_id is None:
filter["project_id"] = None
if project_id:
filter["project_id"] = project_id
args = {"character_image_data": 0} # don't return image data for list
res = await self.collection.find(filter, args).to_list(None)
res = await self.collection.find(filter).skip(offset).limit(limit).to_list(None)
chars = []
for doc in res:
doc["id"] = str(doc.pop("_id"))

View File

@@ -7,6 +7,7 @@ from repos.user_repo import UsersRepo
from repos.albums_repo import AlbumsRepo
from repos.project_repo import ProjectRepo
from repos.idea_repo import IdeaRepo
from repos.post_repo import PostRepo
from typing import Optional
@@ -21,3 +22,4 @@ class DAO:
self.projects = ProjectRepo(client, db_name)
self.users = UsersRepo(client, db_name)
self.ideas = IdeaRepo(client, db_name)
self.posts = PostRepo(client, db_name)

View File

@@ -1,4 +1,4 @@
from typing import Optional, List
from typing import Any, Optional, List
from datetime import datetime, timedelta, UTC
from PIL.ImageChops import offset
@@ -17,7 +17,7 @@ class GenerationRepo:
res = await self.collection.insert_one(generation.model_dump())
return str(res.inserted_id)
async def get_generation(self, generation_id: str) -> Optional[Generation]:
async def get_generation(self, generation_id: str) -> Generation | None:
res = await self.collection.find_one({"_id": ObjectId(generation_id)})
if res is None:
return None
@@ -28,7 +28,7 @@ class GenerationRepo:
async def get_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
limit: int = 10, offset: int = 0, created_by: Optional[str] = None, project_id: Optional[str] = None, idea_id: Optional[str] = None) -> List[Generation]:
filter = {"is_deleted": False}
filter: dict[str, Any] = {"is_deleted": False}
if character_id is not None:
filter["linked_character_id"] = character_id
if status is not None:
@@ -65,10 +65,14 @@ class GenerationRepo:
args["status"] = status
if created_by is not None:
args["created_by"] = created_by
if project_id is None:
args["project_id"] = None
if project_id is not None:
args["project_id"] = project_id
if idea_id is not None:
args["idea_id"] = idea_id
if album_id is not None:
args["album_id"] = album_id
return await self.collection.count_documents(args)
async def get_generations_by_ids(self, generation_ids: List[str]) -> List[Generation]:
@@ -90,6 +94,121 @@ class GenerationRepo:
async def update_generation(self, generation: Generation, ):
res = await self.collection.update_one({"_id": ObjectId(generation.id)}, {"$set": generation.model_dump()})
async def get_usage_stats(self, created_by: Optional[str] = None, project_id: Optional[str] = None) -> dict:
"""
Calculates usage statistics (runs, tokens, cost) using MongoDB aggregation.
"""
pipeline = []
# 1. Match active done generations
match_stage = {"is_deleted": False, "status": GenerationStatus.DONE}
if created_by:
match_stage["created_by"] = created_by
if project_id:
match_stage["project_id"] = project_id
pipeline.append({"$match": match_stage})
# 2. Group by null (total)
pipeline.append({
"$group": {
"_id": None,
"total_runs": {"$sum": 1},
"total_tokens": {
"$sum": {
"$cond": [
{"$and": [{"$gt": ["$input_token_usage", 0]}, {"$gt": ["$output_token_usage", 0]}]},
{"$add": ["$input_token_usage", "$output_token_usage"]},
{"$ifNull": ["$token_usage", 0]}
]
}
},
"total_input_tokens": {"$sum": {"$ifNull": ["$input_token_usage", 0]}},
"total_output_tokens": {"$sum": {"$ifNull": ["$output_token_usage", 0]}},
"total_cost": {
"$sum": {
"$add": [
{"$multiply": [{"$ifNull": ["$input_token_usage", 0]}, 0.000002]},
{"$multiply": [{"$ifNull": ["$output_token_usage", 0]}, 0.00012]}
]
}
}
}
})
cursor = self.collection.aggregate(pipeline)
res = await cursor.to_list(1)
if not res:
return {
"total_runs": 0,
"total_tokens": 0,
"total_input_tokens": 0,
"total_output_tokens": 0,
"total_cost": 0.0
}
result = res[0]
result.pop("_id")
result["total_cost"] = round(result["total_cost"], 4)
return result
async def get_usage_breakdown(self, group_by: str = "created_by", project_id: Optional[str] = None, created_by: Optional[str] = None) -> List[dict]:
"""
Returns usage statistics grouped by user or project.
"""
pipeline = []
match_stage = {"is_deleted": False, "status": GenerationStatus.DONE}
if project_id:
match_stage["project_id"] = project_id
if created_by:
match_stage["created_by"] = created_by
pipeline.append({"$match": match_stage})
pipeline.append({
"$group": {
"_id": f"${group_by}",
"total_runs": {"$sum": 1},
"total_tokens": {
"$sum": {
"$cond": [
{"$and": [{"$gt": ["$input_token_usage", 0]}, {"$gt": ["$output_token_usage", 0]}]},
{"$add": ["$input_token_usage", "$output_token_usage"]},
{"$ifNull": ["$token_usage", 0]}
]
}
},
"total_input_tokens": {"$sum": {"$ifNull": ["$input_token_usage", 0]}},
"total_output_tokens": {"$sum": {"$ifNull": ["$output_token_usage", 0]}},
"total_cost": {
"$sum": {
"$add": [
{"$multiply": [{"$ifNull": ["$input_token_usage", 0]}, 0.000002]},
{"$multiply": [{"$ifNull": ["$output_token_usage", 0]}, 0.00012]}
]
}
}
}
})
pipeline.append({"$sort": {"total_cost": -1}})
cursor = self.collection.aggregate(pipeline)
res = await cursor.to_list(None)
results = []
for item in res:
entity_id = item.pop("_id")
item["total_cost"] = round(item["total_cost"], 4)
results.append({
"entity_id": str(entity_id) if entity_id else "unknown",
"stats": item
})
return results
async def get_generations_by_group(self, group_id: str) -> List[Generation]:
res = await self.collection.find({"generation_group_id": group_id, "is_deleted": False}).sort("created_at", 1).to_list(None)
generations: List[Generation] = []
@@ -114,3 +233,37 @@ class GenerationRepo:
}
)
return res.modified_count
async def soft_delete_old_generations(self, days: int = 2) -> tuple[int, List[str]]:
"""
Мягко удаляет генерации старше N дней.
Возвращает (количество удалённых, список asset IDs для очистки).
"""
cutoff_time = datetime.now(UTC) - timedelta(days=days)
filter_query = {
"is_deleted": False,
"status": {"$in": [GenerationStatus.DONE, GenerationStatus.FAILED]},
"created_at": {"$lt": cutoff_time}
}
# Сначала собираем asset IDs из удаляемых генераций
asset_ids: List[str] = []
cursor = self.collection.find(filter_query, {"result_list": 1, "assets_list": 1})
async for doc in cursor:
asset_ids.extend(doc.get("result_list", []))
asset_ids.extend(doc.get("assets_list", []))
# Мягкое удаление
res = await self.collection.update_many(
filter_query,
{
"$set": {
"is_deleted": True,
"updated_at": datetime.now(UTC)
}
}
)
# Убираем дубликаты
unique_asset_ids = list(set(asset_ids))
return res.modified_count, unique_asset_ids

View File

@@ -39,8 +39,17 @@ class IdeaRepo:
"from": "generations",
"let": {"idea_id": "$str_id"},
"pipeline": [
{"$match": {"$expr": {"$eq": ["$idea_id", "$$idea_id"]}}},
{"$sort": {"created_at": -1}}, # Ensure we get the latest
{
"$match": {
"$and": [
{"$expr": {"$eq": ["$idea_id", "$$idea_id"]}},
{"status": "done"},
{"result_list": {"$exists": True, "$not": {"$size": 0}}},
{"is_deleted": False}
]
}
},
{"$sort": {"created_at": -1}}, # Ensure we get the latest successful
{"$limit": 1}
],
"as": "generations"

97
repos/post_repo.py Normal file
View File

@@ -0,0 +1,97 @@
from typing import List, Optional
from datetime import datetime
import logging
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient
from models.Post import Post
logger = logging.getLogger(__name__)
class PostRepo:
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
self.collection = client[db_name]["posts"]
async def create_post(self, post: Post) -> str:
res = await self.collection.insert_one(post.model_dump())
return str(res.inserted_id)
async def get_post(self, post_id: str) -> Optional[Post]:
if not ObjectId.is_valid(post_id):
return None
res = await self.collection.find_one({"_id": ObjectId(post_id), "is_deleted": False})
if res:
res["id"] = str(res.pop("_id"))
return Post(**res)
return None
async def get_posts(
self,
project_id: Optional[str],
user_id: str,
limit: int = 20,
offset: int = 0,
date_from: Optional[datetime] = None,
date_to: Optional[datetime] = None,
) -> List[Post]:
if project_id:
match = {"project_id": project_id, "is_deleted": False}
else:
match = {"created_by": user_id, "project_id": None, "is_deleted": False}
if date_from or date_to:
date_filter = {}
if date_from:
date_filter["$gte"] = date_from
if date_to:
date_filter["$lte"] = date_to
match["date"] = date_filter
cursor = (
self.collection.find(match)
.sort("date", -1)
.skip(offset)
.limit(limit)
)
posts = []
async for doc in cursor:
doc["id"] = str(doc.pop("_id"))
posts.append(Post(**doc))
return posts
async def update_post(self, post_id: str, data: dict) -> bool:
if not ObjectId.is_valid(post_id):
return False
res = await self.collection.update_one(
{"_id": ObjectId(post_id)},
{"$set": data},
)
return res.modified_count > 0
async def delete_post(self, post_id: str) -> bool:
if not ObjectId.is_valid(post_id):
return False
res = await self.collection.update_one(
{"_id": ObjectId(post_id)},
{"$set": {"is_deleted": True}},
)
return res.modified_count > 0
async def add_generations(self, post_id: str, generation_ids: List[str]) -> bool:
if not ObjectId.is_valid(post_id):
return False
res = await self.collection.update_one(
{"_id": ObjectId(post_id)},
{"$addToSet": {"generation_ids": {"$each": generation_ids}}},
)
return res.modified_count > 0
async def remove_generation(self, post_id: str, generation_id: str) -> bool:
if not ObjectId.is_valid(post_id):
return False
res = await self.collection.update_one(
{"_id": ObjectId(post_id)},
{"$pull": {"generation_ids": generation_id}},
)
return res.modified_count > 0

View File

@@ -51,3 +51,4 @@ python-jose[cryptography]==3.3.0
python-multipart==0.0.22
email-validator
prometheus-fastapi-instrumentator
pydantic-settings==2.13.0

View File

@@ -51,57 +51,66 @@ async def new_char_bio(message: Message, state: FSMContext, dao: DAO, bot: Bot):
wait_msg = await message.answer("💾 Сохраняю персонажа...")
try:
# ВОТ ТУТ скачиваем файл (прямо перед сохранением)
# 1. Скачиваем файл (один раз)
# TODO: Для больших файлов лучше использовать streaming или сохранять во временный файл
file_io = await bot.download(file_id)
# photo_bytes = file_io.getvalue() # Получаем байты
# Создаем модель
file_bytes = file_io.read()
# 2. Создаем Character (сначала без ассета, чтобы получить ID)
char = Character(
id=None,
name=name,
character_image_data=file_io.read(),
character_image_tg_id=None,
character_image_doc_tg_id=file_id,
character_bio=bio,
created_by=str(message.from_user.id)
)
file_io.close()
# Сохраняем через DAO
# Сохраняем, чтобы получить ID
await dao.chars.add_character(char)
file_info = await bot.get_file(char.character_image_doc_tg_id)
file_bytes = await bot.download_file(file_info.file_path)
file_io = file_bytes.read()
avatar_asset = await dao.assets.create_asset(
Asset(name="avatar.png", type=AssetType.UPLOADED, content_type=AssetContentType.IMAGE, linked_char_id=str(char.id), data=file_io,
tg_doc_file_id=file_id))
char.avatar_image = avatar_asset.link
# 3. Создаем Asset (связанный с персонажем)
avatar_asset_id = await dao.assets.create_asset(
Asset(
name="avatar.png",
type=AssetType.UPLOADED,
content_type=AssetContentType.IMAGE,
linked_char_id=str(char.id),
data=file_bytes,
tg_doc_file_id=file_id
)
)
# 4. Обновляем персонажа ссылками на ассет
char.avatar_asset_id = avatar_asset_id
char.avatar_image = f"/api/assets/{avatar_asset_id}" # Формируем ссылку вручную или используем метод, если появится
# Отправляем подтверждение
# Используем байты для отправки обратно
photo_msg = await message.answer_photo(
photo=BufferedInputFile(file_io,
filename="char.jpg") if not char.character_image_tg_id else char.character_image_tg_id,
photo=BufferedInputFile(file_bytes, filename="char.jpg"),
caption=(
"🎉 <b>Персонаж создан!</b>\n\n"
f"👤 <b>Имя:</b> {char.name}\n"
f"📝 <b>Био:</b> {char.character_bio}"
)
)
file_bytes.close()
char.character_image_tg_id = photo_msg.photo[0].file_id
# Сохраняем TG ID фото (которое отправили как фото, а не документ)
char.character_image_tg_id = photo_msg.photo[-1].file_id
# Финальное обновление персонажа
await dao.chars.update_char(char.id, char)
await wait_msg.delete()
file_io.close()
# Сбрасываем состояние
await state.clear()
except Exception as e:
logging.error(e)
logger.error(f"Error creating character: {e}")
traceback.print_exc()
await wait_msg.edit_text(f"❌ Ошибка при сохранении: {e}")
# Не сбрасываем стейт, даем возможность попробовать ввести био снова или начать заново
@router.message(Command("chars"))

View File

@@ -3,17 +3,17 @@ import pytest
from fastapi.testclient import TestClient
from unittest.mock import MagicMock
from motor.motor_asyncio import AsyncIOMotorClient
import os
import asyncio
from config import settings
from main import app
from aiws import app
from api.endpoints.auth import get_current_user
from api.dependency import get_dao
from repos.dao import DAO
from models.Character import Character
# Config for test DB
MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://admin:super_secure_password@31.59.58.220:27017")
MONGO_HOST = settings.MONGO_HOST
DB_NAME = "bot_db_test_chars"
# Mock User

View File

@@ -10,13 +10,13 @@ import json
import requests
import base64
import os
from dotenv import load_dotenv
from config import settings
load_dotenv()
# Load env is not needed as settings handles it
# Configuration
API_URL = "http://localhost:8090/api/generations/import"
SECRET = os.getenv("EXTERNAL_API_SECRET", "your_super_secret_key_change_this_in_production")
SECRET = settings.EXTERNAL_API_SECRET or "your_super_secret_key_change_this_in_production"
# Sample generation data
generation_data = {

View File

@@ -10,11 +10,10 @@ from repos.dao import DAO
from models.Idea import Idea
from models.Generation import Generation, GenerationStatus
from models.enums import AspectRatios, Quality
from config import settings
load_dotenv()
MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://localhost:27017")
DB_NAME = os.getenv("DB_NAME", "bot_db")
MONGO_HOST = settings.MONGO_HOST
DB_NAME = settings.DB_NAME
print(f"Connecting to MongoDB: {MONGO_HOST}, DB: {DB_NAME}")

View File

@@ -1,15 +1,14 @@
import asyncio
import os
from dotenv import load_dotenv
from config import settings
from adapters.s3_adapter import S3Adapter
async def test_s3():
load_dotenv()
endpoint = os.getenv("MINIO_ENDPOINT", "http://localhost:9000")
access_key = os.getenv("MINIO_ACCESS_KEY")
secret_key = os.getenv("MINIO_SECRET_KEY")
bucket = os.getenv("MINIO_BUCKET")
endpoint = settings.MINIO_ENDPOINT
access_key = settings.MINIO_ACCESS_KEY
secret_key = settings.MINIO_SECRET_KEY
bucket = settings.MINIO_BUCKET
print(f"Connecting to {endpoint}, bucket: {bucket}")

View File

@@ -4,13 +4,11 @@ from datetime import datetime, timedelta, UTC
from motor.motor_asyncio import AsyncIOMotorClient
from models.Generation import Generation, GenerationStatus
from repos.generation_repo import GenerationRepo
from dotenv import load_dotenv
load_dotenv()
from config import settings
# Mock configs if not present in env
MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://localhost:27017")
DB_NAME = os.getenv("DB_NAME", "bot_db")
MONGO_HOST = settings.MONGO_HOST
DB_NAME = settings.DB_NAME
print(f"Connecting to MongoDB: {MONGO_HOST}, DB: {DB_NAME}")

View File

@@ -10,10 +10,11 @@ from repos.dao import DAO
from models.Album import Album
from models.Generation import Generation, GenerationStatus
from models.enums import AspectRatios, Quality
from config import settings
# Mock config
# Use the same host as aiws.py but different DB
MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://admin:super_secure_password@31.59.58.220:27017")
MONGO_HOST = settings.MONGO_HOST
DB_NAME = "bot_db_test_albums"
async def test_albums():
@@ -83,8 +84,6 @@ async def test_albums():
client.close()
if __name__ == "__main__":
from dotenv import load_dotenv
load_dotenv()
try:
asyncio.run(test_albums())
except Exception as e:

View File

@@ -1,29 +1,28 @@
import asyncio
import os
from datetime import datetime
from dotenv import load_dotenv
from motor.motor_asyncio import AsyncIOMotorClient
from config import settings
from models.Asset import Asset, AssetType
from repos.assets_repo import AssetsRepo
from adapters.s3_adapter import S3Adapter
# Load env to get credentials
load_dotenv()
# Load env is not needed as settings handles it
async def test_integration():
print("🚀 Starting integration test...")
# 1. Setup Dependencies
mongo_uri = os.getenv("MONGO_HOST", "mongodb://localhost:27017")
mongo_uri = settings.MONGO_HOST
client = AsyncIOMotorClient(mongo_uri)
db_name = os.getenv("DB_NAME", "bot_db_test")
db_name = settings.DB_NAME + "_test"
s3_adapter = S3Adapter(
endpoint_url=os.getenv("MINIO_ENDPOINT", "http://localhost:9000"),
aws_access_key_id=os.getenv("MINIO_ACCESS_KEY", "admin"),
aws_secret_access_key=os.getenv("MINIO_SECRET_KEY", "SuperSecretPassword123!"),
bucket_name=os.getenv("MINIO_BUCKET", "ai-char")
endpoint_url=settings.MINIO_ENDPOINT,
aws_access_key_id=settings.MINIO_ACCESS_KEY,
aws_secret_access_key=settings.MINIO_SECRET_KEY,
bucket_name=settings.MINIO_BUCKET
)
repo = AssetsRepo(client, s3_adapter, db_name=db_name)

View File

@@ -3,12 +3,12 @@ from typing import Optional, Union, Any
from jose import jwt
from passlib.context import CryptContext
from config import settings
# Настройки безопасности (лучше вынести в config/env, но для старта здесь)
# SECRET_KEY должен быть сложным и секретным в продакшене!
SECRET_KEY = "CHANGE_ME_TO_A_SUPER_SECRET_KEY"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30 * 24 * 60 # 30 дней, например
# Настройки безопасности берутся из config.py
SECRET_KEY = settings.SECRET_KEY
ALGORITHM = settings.ALGORITHM
ACCESS_TOKEN_EXPIRE_MINUTES = settings.ACCESS_TOKEN_EXPIRE_MINUTES
pwd_context = CryptContext(schemes=["argon2"], deprecated="auto")