12 Commits

Author SHA1 Message Date
c7c27197c9 Merge pull request '+ env' (#4) from enviroments into main
Reviewed-on: #4
2026-02-19 18:32:51 +00:00
xds
5aa6391dc8 + env 2026-02-19 21:25:29 +03:00
xds
ffb0463fe0 os.getenv -> config.py 2026-02-19 15:28:04 +03:00
xds
dd0f8a1cb6 os.getenv -> config.py 2026-02-19 13:00:51 +03:00
xds
4af5134726 fixes 2026-02-18 17:06:17 +03:00
xds
7488665d04 fixes 2026-02-18 17:01:06 +03:00
xds
ecc88aca62 fixes 2026-02-18 16:53:28 +03:00
xds
70f50170fc fixes 2026-02-18 16:45:39 +03:00
xds
f4207fc4c1 fixes 2026-02-18 16:45:02 +03:00
xds
c50d2c8ad9 fixes 2026-02-18 16:44:04 +03:00
xds
4586daac38 fixes 2026-02-18 16:35:04 +03:00
198ac44960 Merge pull request 'feat: introduce post resource with full CRUD operations and generation linking.' (#3) from posts into main
Reviewed-on: #3
2026-02-17 12:54:47 +00:00
37 changed files with 802 additions and 199 deletions

View File

@@ -63,10 +63,12 @@ class S3Adapter:
response = await client.get_object(Bucket=self.bucket_name, Key=object_name) response = await client.get_object(Bucket=self.bucket_name, Key=object_name)
# aioboto3 Body is an aiohttp StreamReader wrapper # aioboto3 Body is an aiohttp StreamReader wrapper
body = response['Body'] body = response['Body']
data = await body.read()
# Yield in chunks to avoid holding entire response in StreamingResponse buffer while True:
for i in range(0, len(data), chunk_size): chunk = await body.read(chunk_size)
yield data[i:i + chunk_size] if not chunk:
break
yield chunk
except ClientError as e: except ClientError as e:
print(f"Error streaming from S3: {e}") print(f"Error streaming from S3: {e}")
return return

27
aiws.py
View File

@@ -1,6 +1,5 @@
import asyncio import asyncio
import logging import logging
import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from aiogram import Bot, Dispatcher, Router, F from aiogram import Bot, Dispatcher, Router, F
@@ -9,7 +8,6 @@ from aiogram.enums import ParseMode
from aiogram.filters import CommandStart, Command from aiogram.filters import CommandStart, Command
from aiogram.types import Message from aiogram.types import Message
from aiogram.fsm.storage.mongo import MongoStorage from aiogram.fsm.storage.mongo import MongoStorage
from dotenv import load_dotenv
from fastapi import FastAPI from fastapi import FastAPI
from motor.motor_asyncio import AsyncIOMotorClient from motor.motor_asyncio import AsyncIOMotorClient
from prometheus_client import Info from prometheus_client import Info
@@ -17,6 +15,7 @@ from starlette.middleware.cors import CORSMiddleware
from prometheus_fastapi_instrumentator import Instrumentator from prometheus_fastapi_instrumentator import Instrumentator
# --- ИМПОРТЫ ПРОЕКТА --- # --- ИМПОРТЫ ПРОЕКТА ---
from config import settings
from adapters.google_adapter import GoogleAdapter from adapters.google_adapter import GoogleAdapter
from adapters.s3_adapter import S3Adapter from adapters.s3_adapter import S3Adapter
from api.service.generation_service import GenerationService from api.service.generation_service import GenerationService
@@ -45,17 +44,18 @@ from api.endpoints.album_router import router as api_album_router
from api.endpoints.project_router import router as project_api_router from api.endpoints.project_router import router as project_api_router
from api.endpoints.idea_router import router as idea_api_router from api.endpoints.idea_router import router as idea_api_router
from api.endpoints.post_router import router as post_api_router from api.endpoints.post_router import router as post_api_router
from api.endpoints.environment_router import router as environment_api_router
load_dotenv()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# --- КОНФИГУРАЦИЯ --- # --- КОНФИГУРАЦИЯ ---
BOT_TOKEN = os.getenv("BOT_TOKEN") # Настройки теперь берутся из config.py
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") BOT_TOKEN = settings.BOT_TOKEN
GEMINI_API_KEY = settings.GEMINI_API_KEY
MONGO_HOST = os.getenv("MONGO_HOST") # Например: mongodb://localhost:27017 MONGO_HOST = settings.MONGO_HOST
DB_NAME = os.getenv("DB_NAME", "my_bot_db") # Имя базы данных DB_NAME = settings.DB_NAME
ADMIN_ID = int(os.getenv("ADMIN_ID", 0)) ADMIN_ID = settings.ADMIN_ID
def setup_logging(): def setup_logging():
@@ -79,10 +79,10 @@ char_repo = CharacterRepo(mongo_client)
# S3 Adapter # S3 Adapter
s3_adapter = S3Adapter( s3_adapter = S3Adapter(
endpoint_url=os.getenv("MINIO_ENDPOINT", "http://31.59.58.220:9000"), endpoint_url=settings.MINIO_ENDPOINT,
aws_access_key_id=os.getenv("MINIO_ACCESS_KEY", "minioadmin"), aws_access_key_id=settings.MINIO_ACCESS_KEY,
aws_secret_access_key=os.getenv("MINIO_SECRET_KEY", "minioadmin"), aws_secret_access_key=settings.MINIO_SECRET_KEY,
bucket_name=os.getenv("MINIO_BUCKET", "ai-char") bucket_name=settings.MINIO_BUCKET
) )
dao = DAO(mongo_client, s3_adapter) # Главный DAO для бота dao = DAO(mongo_client, s3_adapter) # Главный DAO для бота
@@ -221,6 +221,7 @@ app.include_router(api_album_router)
app.include_router(project_api_router) app.include_router(project_api_router)
app.include_router(idea_api_router) app.include_router(idea_api_router)
app.include_router(post_api_router) app.include_router(post_api_router)
app.include_router(environment_api_router)
# Prometheus Metrics (Instrument after all routers are added) # Prometheus Metrics (Instrument after all routers are added)
Instrumentator( Instrumentator(
@@ -259,7 +260,7 @@ if __name__ == "__main__":
async def main(): async def main():
# Создаем конфигурацию uvicorn вручную # Создаем конфигурацию uvicorn вручную
# loop="asyncio" заставляет использовать стандартный цикл # loop="asyncio" заставляет использовать стандартный цикл
config = uvicorn.Config(app, host="0.0.0.0", port=8090, loop="asyncio", timeout_keep_alive=120, env_file=".env.development") config = uvicorn.Config(app, host="0.0.0.0", port=8090, loop="asyncio", timeout_keep_alive=120)
server = uvicorn.Server(config) server = uvicorn.Server(config)
# Запускаем сервер (lifespan запустится внутри) # Запускаем сервер (lifespan запустится внутри)

View File

@@ -5,6 +5,8 @@ from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel from pydantic import BaseModel
from repos.user_repo import UsersRepo, UserStatus from repos.user_repo import UsersRepo, UserStatus
from api.dependency import get_dao
from repos.dao import DAO
from utils.security import verify_password, create_access_token, ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, SECRET_KEY from utils.security import verify_password, create_access_token, ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, SECRET_KEY
from jose import JWTError, jwt from jose import JWTError, jwt
from starlette.requests import Request from starlette.requests import Request

View File

@@ -12,7 +12,7 @@ from starlette.requests import Request
from starlette.responses import Response, JSONResponse, StreamingResponse from starlette.responses import Response, JSONResponse, StreamingResponse
from adapters.s3_adapter import S3Adapter from adapters.s3_adapter import S3Adapter
from api.models.AssetDTO import AssetsResponse, AssetResponse from api.models import AssetsResponse, AssetResponse
from models.Asset import Asset, AssetType, AssetContentType from models.Asset import Asset, AssetType, AssetContentType
from repos.dao import DAO from repos.dao import DAO
from api.dependency import get_dao, get_mongo_client, get_s3_adapter from api.dependency import get_dao, get_mongo_client, get_s3_adapter
@@ -278,8 +278,7 @@ async def upload_asset(
type=asset.type.value if hasattr(asset.type, "value") else asset.type, type=asset.type.value if hasattr(asset.type, "value") else asset.type,
content_type=asset.content_type.value if hasattr(asset.content_type, "value") else asset.content_type, content_type=asset.content_type.value if hasattr(asset.content_type, "value") else asset.content_type,
linked_char_id=asset.linked_char_id, linked_char_id=asset.linked_char_id,
created_at=asset.created_at, created_at=asset.created_at
url=asset.url
) )

View File

@@ -5,11 +5,11 @@ from pydantic import BaseModel
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import Request
from api.models.AssetDTO import AssetsResponse, AssetResponse from api.models import AssetsResponse, AssetResponse
from api.models.GenerationRequest import GenerationRequest, GenerationResponse from api.models import GenerationRequest, GenerationResponse
from models.Asset import Asset from models.Asset import Asset
from models.Character import Character from models.Character import Character
from api.models.CharacterDTO import CharacterCreateRequest, CharacterUpdateRequest from api.models import CharacterCreateRequest, CharacterUpdateRequest
from repos.dao import DAO from repos.dao import DAO
from api.dependency import get_dao from api.dependency import get_dao
@@ -24,8 +24,15 @@ router = APIRouter(prefix="/api/characters", tags=["Characters"], dependencies=[
@router.get("/", response_model=List[Character]) @router.get("/", response_model=List[Character])
async def get_characters(request: Request, dao: DAO = Depends(get_dao), current_user: dict = Depends(get_current_user), project_id: Optional[str] = Depends(get_project_id)) -> List[Character]: async def get_characters(
logger.info("get_characters called") request: Request,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
limit: int = 100,
offset: int = 0
) -> List[Character]:
logger.info(f"get_characters called. Limit: {limit}, Offset: {offset}")
user_id_filter = str(current_user["_id"]) user_id_filter = str(current_user["_id"])
if project_id: if project_id:
@@ -34,7 +41,12 @@ async def get_characters(request: Request, dao: DAO = Depends(get_dao), current_
raise HTTPException(status_code=403, detail="Project access denied") raise HTTPException(status_code=403, detail="Project access denied")
user_id_filter = None user_id_filter = None
characters = await dao.chars.get_all_characters(created_by=user_id_filter, project_id=project_id) characters = await dao.chars.get_all_characters(
created_by=user_id_filter,
project_id=project_id,
limit=limit,
offset=offset
)
return characters return characters
@@ -178,10 +190,3 @@ async def delete_character(
raise HTTPException(status_code=500, detail="Failed to delete character") raise HTTPException(status_code=500, detail="Failed to delete character")
return return
@router.post("/{character_id}/_run", response_model=GenerationResponse)
async def post_character_generation(character_id: str, generation: GenerationRequest,
request: Request) -> GenerationResponse:
logger.info(f"post_character_generation called. CharacterID: {character_id}")
generation_service = request.app.state.generation_service

View File

@@ -0,0 +1,180 @@
import logging
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException
from starlette import status
from api.dependency import get_dao
from api.endpoints.auth import get_current_user
from api.models.EnvironmentRequest import EnvironmentCreate, EnvironmentUpdate, AssetToEnvironment, AssetsToEnvironment
from models.Environment import Environment
from repos.dao import DAO
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/environments", tags=["Environments"], dependencies=[Depends(get_current_user)])
async def check_character_access(character_id: str, current_user: dict, dao: DAO):
    """Fetch a character and verify *current_user* may access it.

    Access is granted to the character's creator or to members of the
    character's project. Raises 404 when the character does not exist
    and 403 when access is denied; returns the character otherwise.
    """
    character = await dao.chars.get_character(character_id)
    if not character:
        raise HTTPException(status_code=404, detail="Character not found")

    owns_it = character.created_by == str(current_user["_id"])
    shares_project = bool(
        character.project_id
        and character.project_id in current_user.get("project_ids", [])
    )
    if not (owns_it or shares_project):
        raise HTTPException(status_code=403, detail="Access denied to character")
    return character
@router.post("/", response_model=Environment)
async def create_environment(
    env_req: EnvironmentCreate,
    dao: DAO = Depends(get_dao),
    current_user: dict = Depends(get_current_user)
):
    """Create a new environment attached to a character.

    The caller must have access to the character. All referenced assets
    must exist; otherwise a 400 listing the missing IDs is raised.
    """
    logger.info(f"Creating environment '{env_req.name}' for character {env_req.character_id}")
    await check_character_access(env_req.character_id, current_user, dao)

    # Verify referenced assets with one batched query instead of one DB
    # round-trip per asset (consistent with the /assets/batch endpoint),
    # and report every missing ID at once.
    if env_req.asset_ids:
        assets = await dao.assets.get_assets_by_ids(env_req.asset_ids)
        found_ids = {a.id for a in assets}
        missing = [aid for aid in env_req.asset_ids if aid not in found_ids]
        if missing:
            raise HTTPException(status_code=400, detail=f"Assets not found: {missing}")

    new_env = Environment(**env_req.model_dump())
    created_env = await dao.environments.create_env(new_env)
    return created_env
@router.get("/character/{character_id}", response_model=List[Environment])
async def get_character_environments(
    character_id: str,
    dao: DAO = Depends(get_dao),
    current_user: dict = Depends(get_current_user)
):
    """List all environments belonging to a character (access-checked)."""
    logger.info(f"Getting environments for character {character_id}")
    await check_character_access(character_id, current_user, dao)
    envs = await dao.environments.get_character_envs(character_id)
    return envs
@router.get("/{env_id}", response_model=Environment)
async def get_environment(
    env_id: str,
    dao: DAO = Depends(get_dao),
    current_user: dict = Depends(get_current_user)
):
    """Return a single environment by ID.

    404 when the environment is missing; the character access check
    enforces ownership/project membership before returning it.
    """
    env = await dao.environments.get_env(env_id)
    if env is None:
        raise HTTPException(status_code=404, detail="Environment not found")
    await check_character_access(env.character_id, current_user, dao)
    return env
@router.put("/{env_id}", response_model=Environment)
async def update_environment(
    env_id: str,
    env_update: EnvironmentUpdate,
    dao: DAO = Depends(get_dao),
    current_user: dict = Depends(get_current_user)
):
    """Apply a partial update to an environment and return the fresh copy.

    An empty payload is a no-op that returns the current document.
    Raises 404 when missing and 500 when the repo update fails.
    """
    env = await dao.environments.get_env(env_id)
    if not env:
        raise HTTPException(status_code=404, detail="Environment not found")
    await check_character_access(env.character_id, current_user, dao)

    # Only fields the client actually sent are applied.
    changes = env_update.model_dump(exclude_unset=True)
    if not changes:
        return env

    if not await dao.environments.update_env(env_id, changes):
        raise HTTPException(status_code=500, detail="Failed to update environment")
    # Re-read so the response reflects the persisted state.
    return await dao.environments.get_env(env_id)
@router.delete("/{env_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_environment(
    env_id: str,
    dao: DAO = Depends(get_dao),
    current_user: dict = Depends(get_current_user)
):
    """Delete an environment; 204 on success, 404/403/500 on failure."""
    env = await dao.environments.get_env(env_id)
    if not env:
        raise HTTPException(status_code=404, detail="Environment not found")
    await check_character_access(env.character_id, current_user, dao)

    if not await dao.environments.delete_env(env_id):
        raise HTTPException(status_code=500, detail="Failed to delete environment")
    return None
@router.post("/{env_id}/assets", status_code=status.HTTP_200_OK)
async def add_asset_to_environment(
    env_id: str,
    req: AssetToEnvironment,
    dao: DAO = Depends(get_dao),
    current_user: dict = Depends(get_current_user)
):
    """Attach a single existing asset to an environment."""
    env = await dao.environments.get_env(env_id)
    if not env:
        raise HTTPException(status_code=404, detail="Environment not found")
    await check_character_access(env.character_id, current_user, dao)

    # The asset must exist before linking it.
    if not await dao.assets.get_asset(req.asset_id):
        raise HTTPException(status_code=404, detail="Asset not found")

    success = await dao.environments.add_asset(env_id, req.asset_id)
    return {"success": success}
@router.post("/{env_id}/assets/batch", status_code=status.HTTP_200_OK)
async def add_assets_batch_to_environment(
    env_id: str,
    req: AssetsToEnvironment,
    dao: DAO = Depends(get_dao),
    current_user: dict = Depends(get_current_user)
):
    """Attach several existing assets to an environment in one call.

    Rejects the whole batch with 404 (listing the missing IDs) if any
    requested asset does not exist.
    """
    env = await dao.environments.get_env(env_id)
    if not env:
        raise HTTPException(status_code=404, detail="Environment not found")
    await check_character_access(env.character_id, current_user, dao)

    # Batched existence check: fetch all requested assets at once and
    # compare counts before diffing out the missing IDs.
    assets = await dao.assets.get_assets_by_ids(req.asset_ids)
    if len(assets) != len(req.asset_ids):
        found_ids = {a.id for a in assets}
        missing_ids = [aid for aid in req.asset_ids if aid not in found_ids]
        raise HTTPException(status_code=404, detail=f"Some assets not found: {missing_ids}")

    success = await dao.environments.add_assets(env_id, req.asset_ids)
    return {"success": success}
@router.delete("/{env_id}/assets/{asset_id}", status_code=status.HTTP_200_OK)
async def remove_asset_from_environment(
    env_id: str,
    asset_id: str,
    dao: DAO = Depends(get_dao),
    current_user: dict = Depends(get_current_user)
):
    """Detach an asset from an environment; success flag in the body."""
    env = await dao.environments.get_env(env_id)
    if not env:
        raise HTTPException(status_code=404, detail="Environment not found")
    await check_character_access(env.character_id, current_user, dao)

    success = await dao.environments.remove_asset(env_id, asset_id)
    return {"success": success}

View File

@@ -1,25 +1,32 @@
import logging
import os
import json
from typing import List, Optional from typing import List, Optional
from fastapi import APIRouter, UploadFile, File, Form, Header, HTTPException from fastapi import APIRouter, UploadFile, File, Form, Header, HTTPException
from fastapi.params import Depends from fastapi.params import Depends
from starlette import status
from starlette.requests import Request from starlette.requests import Request
from api import service from config import settings
from api.dependency import get_generation_service, get_project_id, get_dao from api.dependency import get_generation_service, get_project_id, get_dao
from repos.dao import DAO from api.endpoints.auth import get_current_user
from api.models import (
from api.models.GenerationRequest import GenerationResponse, GenerationRequest, GenerationsResponse, PromptResponse, PromptRequest, GenerationGroupResponse GenerationResponse,
GenerationRequest,
GenerationsResponse,
PromptResponse,
PromptRequest,
GenerationGroupResponse,
FinancialReport,
ExternalGenerationRequest
)
from api.service.generation_service import GenerationService from api.service.generation_service import GenerationService
from models.Generation import Generation from repos.dao import DAO
from utils.external_auth import verify_signature
from starlette import status
import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from api.endpoints.auth import get_current_user
router = APIRouter(prefix='/api/generations', tags=["Generation"]) router = APIRouter(prefix='/api/generations', tags=["Generation"])
@@ -68,6 +75,47 @@ async def get_generations(character_id: Optional[str] = None, limit: int = 10, o
return await generation_service.get_generations(character_id, limit=limit, offset=offset, user_id=user_id_filter, project_id=project_id) return await generation_service.get_generations(character_id, limit=limit, offset=offset, user_id=user_id_filter, project_id=project_id)
@router.get("/usage", response_model=FinancialReport)
async def get_usage_report(
    breakdown: Optional[str] = None,  # "user" or "project"
    generation_service: GenerationService = Depends(get_generation_service),
    current_user: dict = Depends(get_current_user),
    project_id: Optional[str] = Depends(get_project_id),
    dao: DAO = Depends(get_dao)
) -> FinancialReport:
    """
    Returns usage statistics (runs, tokens, cost) for the current user or project.
    If project_id is provided, returns stats for that project.
    Otherwise, returns stats for the current user.
    """
    user_id_filter = str(current_user["_id"])

    if project_id:
        # Project scope: caller must be a member, and the report then
        # covers the WHOLE project rather than just the caller.
        project = await dao.projects.get_project(project_id)
        if not project or str(current_user["_id"]) not in project.members:
            raise HTTPException(status_code=403, detail="Project access denied")
        user_id_filter = None

    # Map the optional breakdown keyword onto the stored field name;
    # anything else (including None) means no per-entity grouping.
    breakdown_by = {"user": "created_by", "project": "project_id"}.get(breakdown)

    return await generation_service.get_financial_report(
        user_id=user_id_filter,
        project_id=project_id,
        breakdown_by=breakdown_by
    )
@router.post("/_run", response_model=GenerationGroupResponse) @router.post("/_run", response_model=GenerationGroupResponse)
async def post_generation(generation: GenerationRequest, request: Request, async def post_generation(generation: GenerationRequest, request: Request,
generation_service: GenerationService = Depends(get_generation_service), generation_service: GenerationService = Depends(get_generation_service),
@@ -120,6 +168,14 @@ async def get_generation(generation_id: str,
logger.debug(f"get_generation called for ID: {generation_id}") logger.debug(f"get_generation called for ID: {generation_id}")
gen = await generation_service.get_generation(generation_id) gen = await generation_service.get_generation(generation_id)
if gen and gen.created_by != str(current_user["_id"]): if gen and gen.created_by != str(current_user["_id"]):
# Check project membership
is_member = False
if gen.project_id:
project = await generation_service.dao.projects.get_project(gen.project_id)
if project and str(current_user["_id"]) in project.members:
is_member = True
if not is_member:
raise HTTPException(status_code=403, detail="Access denied") raise HTTPException(status_code=403, detail="Access denied")
return gen return gen
@@ -136,17 +192,13 @@ async def import_external_generation(
Import a generation from an external source. Import a generation from an external source.
Requires server-to-server authentication via HMAC signature. Requires server-to-server authentication via HMAC signature.
""" """
import os
from utils.external_auth import verify_signature
from api.models.ExternalGenerationDTO import ExternalGenerationRequest
logger.info("import_external_generation called") logger.info("import_external_generation called")
# Get raw request body for signature verification # Get raw request body for signature verification
body = await request.body() body = await request.body()
# Verify signature # Verify signature
secret = os.getenv("EXTERNAL_API_SECRET") secret = settings.EXTERNAL_API_SECRET
if not secret: if not secret:
logger.error("EXTERNAL_API_SECRET not configured") logger.error("EXTERNAL_API_SECRET not configured")
raise HTTPException(status_code=500, detail="Server configuration error") raise HTTPException(status_code=500, detail="Server configuration error")
@@ -156,7 +208,6 @@ async def import_external_generation(
raise HTTPException(status_code=401, detail="Invalid signature") raise HTTPException(status_code=401, detail="Invalid signature")
# Parse request body # Parse request body
import json
try: try:
data = json.loads(body.decode('utf-8')) data = json.loads(body.decode('utf-8'))
external_gen = ExternalGenerationRequest(**data) external_gen = ExternalGenerationRequest(**data)

View File

@@ -5,7 +5,8 @@ from api.endpoints.auth import get_current_user
from api.service.idea_service import IdeaService from api.service.idea_service import IdeaService
from api.service.generation_service import GenerationService from api.service.generation_service import GenerationService
from models.Idea import Idea from models.Idea import Idea
from api.models.GenerationRequest import GenerationResponse, GenerationsResponse from api.models import GenerationResponse, GenerationsResponse
from api.models import IdeaRequest, PostRequest # Adjusting for general model usage
from api.models.IdeaRequest import IdeaCreateRequest, IdeaUpdateRequest, IdeaResponse from api.models.IdeaRequest import IdeaCreateRequest, IdeaUpdateRequest, IdeaResponse
router = APIRouter(prefix="/api/ideas", tags=["ideas"]) router = APIRouter(prefix="/api/ideas", tags=["ideas"])

View File

@@ -5,7 +5,7 @@ from fastapi import APIRouter, Depends, HTTPException
from api.dependency import get_post_service, get_project_id from api.dependency import get_post_service, get_project_id
from api.endpoints.auth import get_current_user from api.endpoints.auth import get_current_user
from api.service.post_service import PostService from api.service.post_service import PostService
from api.models.PostRequest import PostCreateRequest, PostUpdateRequest, AddGenerationsRequest from api.models import PostRequest, PostCreateRequest, PostUpdateRequest, AddGenerationsRequest
from models.Post import Post from models.Post import Post
router = APIRouter(prefix="/api/posts", tags=["posts"]) router = APIRouter(prefix="/api/posts", tags=["posts"])

View File

@@ -1,4 +1,6 @@
from typing import List, Optional from typing import List, Optional
from bson import ObjectId
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel from pydantic import BaseModel
from api.dependency import get_dao from api.dependency import get_dao
@@ -12,14 +14,46 @@ class ProjectCreate(BaseModel):
name: str name: str
description: Optional[str] = None description: Optional[str] = None
class ProjectMemberResponse(BaseModel):
id: str
username: str
class ProjectResponse(BaseModel): class ProjectResponse(BaseModel):
id: str id: str
name: str name: str
description: Optional[str] = None description: Optional[str] = None
owner_id: str owner_id: str
members: List[str] members: List[ProjectMemberResponse]
is_owner: bool = False is_owner: bool = False
async def _get_project_response(project: Project, current_user_id: str, dao: DAO) -> ProjectResponse:
member_responses = []
for member_id in project.members:
# We need a way to get user by ID. Let's check UsersRepo for get_user by ObjectId or string.
# Currently UsersRepo has get_user(user_id: int) for Telegram IDs.
# But for Web users we might need to search by _id.
# Let's try to get user info.
# Since project.members contains strings (ObjectIds as strings), we search by _id.
user_doc = await dao.users.collection.find_one({"_id": ObjectId(member_id)})
if not user_doc and member_id.isdigit():
# Fallback for telegram IDs if they are stored as strings of digits
user_doc = await dao.users.get_user(int(member_id))
username = "unknown"
if user_doc:
username = user_doc.get("username", "unknown")
member_responses.append(ProjectMemberResponse(id=member_id, username=username))
return ProjectResponse(
id=project.id,
name=project.name,
description=project.description,
owner_id=project.owner_id,
members=member_responses,
is_owner=(project.owner_id == current_user_id)
)
@router.post("", response_model=ProjectResponse) @router.post("", response_model=ProjectResponse)
async def create_project( async def create_project(
project_data: ProjectCreate, project_data: ProjectCreate,
@@ -34,27 +68,15 @@ async def create_project(
members=[user_id] members=[user_id]
) )
project_id = await dao.projects.create_project(new_project) project_id = await dao.projects.create_project(new_project)
new_project.id = project_id
# Add project to user's project list # Add project to user's project list
# Assuming user_repo has a method to add project or we do it directly?
# UserRepo doesn't have add_project method yet.
# But since UserRepo is just a wrapper around collection, lets add it here or update UserRepo later?
# Better to update UserRepo. For now, let's just return success.
# But user needs to see it in list.
# Update user in DB
await dao.users.collection.update_one( await dao.users.collection.update_one(
{"_id": current_user["_id"]}, {"_id": current_user["_id"]},
{"$addToSet": {"project_ids": project_id}} {"$addToSet": {"project_ids": project_id}}
) )
return ProjectResponse( return await _get_project_response(new_project, user_id, dao)
id=project_id,
name=new_project.name,
description=new_project.description,
owner_id=new_project.owner_id,
members=new_project.members,
is_owner=True
)
@router.get("", response_model=List[ProjectResponse]) @router.get("", response_model=List[ProjectResponse])
async def get_my_projects( async def get_my_projects(
@@ -66,14 +88,7 @@ async def get_my_projects(
responses = [] responses = []
for p in projects: for p in projects:
responses.append(ProjectResponse( responses.append(await _get_project_response(p, user_id, dao))
id=p.id,
name=p.name,
description=p.description,
owner_id=p.owner_id,
members=p.members,
is_owner=(p.owner_id == user_id)
))
return responses return responses
class MemberAdd(BaseModel): class MemberAdd(BaseModel):

View File

@@ -0,0 +1,22 @@
from typing import Optional, List
from pydantic import BaseModel, Field
class EnvironmentCreate(BaseModel):
    """Payload for creating an environment bound to a character."""
    character_id: str
    name: str = Field(..., min_length=1)  # required, non-empty display name
    description: Optional[str] = None
    # default_factory instead of a shared `[]` literal as the default —
    # the idiomatic, safe form for a mutable default.
    asset_ids: Optional[List[str]] = Field(default_factory=list)
class EnvironmentUpdate(BaseModel):
    """Partial-update payload for an environment; unset fields are ignored."""
    name: Optional[str] = Field(None, min_length=1)  # if given, must be non-empty
    description: Optional[str] = None
class AssetToEnvironment(BaseModel):
    """Request body for attaching a single asset to an environment."""
    asset_id: str
class AssetsToEnvironment(BaseModel):
    """Request body for attaching a batch of assets to an environment."""
    asset_ids: List[str]

View File

@@ -0,0 +1,18 @@
from pydantic import BaseModel
from typing import List, Optional
class UsageStats(BaseModel):
    """Aggregated generation usage counters for one scope (user/project)."""
    total_runs: int
    total_tokens: int
    total_input_tokens: int
    total_output_tokens: int
    total_cost: float
class UsageByEntity(BaseModel):
    """Usage stats attributed to one entity (a user ID or a project ID)."""
    entity_id: Optional[str] = None
    stats: UsageStats
class FinancialReport(BaseModel):
    """Overall usage summary plus optional per-user / per-project breakdowns."""
    summary: UsageStats
    by_user: Optional[List[UsageByEntity]] = None
    by_project: Optional[List[UsageByEntity]] = None

View File

@@ -16,6 +16,7 @@ class GenerationRequest(BaseModel):
telegram_id: Optional[int] = None telegram_id: Optional[int] = None
use_profile_image: bool = True use_profile_image: bool = True
assets_list: List[str] assets_list: List[str]
environment_id: Optional[str] = None
project_id: Optional[str] = None project_id: Optional[str] = None
idea_id: Optional[str] = None idea_id: Optional[str] = None
count: int = Field(default=1, ge=1, le=10) count: int = Field(default=1, ge=1, le=10)

View File

@@ -0,0 +1,7 @@
from .AssetDTO import AssetResponse, AssetsResponse
from .CharacterDTO import CharacterCreateRequest, CharacterUpdateRequest
from .ExternalGenerationDTO import ExternalGenerationRequest
from .FinancialUsageDTO import FinancialReport, UsageStats, UsageByEntity
from .GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse, PromptRequest, PromptResponse
from .IdeaRequest import IdeaCreateRequest, IdeaUpdateRequest, IdeaResponse
from .PostRequest import PostCreateRequest, PostUpdateRequest, AddGenerationsRequest

View File

@@ -1,24 +1,26 @@
import asyncio import asyncio
import base64
import logging import logging
import random import random
import base64
from datetime import datetime, UTC from datetime import datetime, UTC
from typing import List, Optional, Tuple, Any, Dict from typing import List, Optional, Tuple, Any, Dict
from io import BytesIO
from uuid import uuid4 from uuid import uuid4
import httpx
import httpx
from aiogram import Bot from aiogram import Bot
from aiogram.types import BufferedInputFile from aiogram.types import BufferedInputFile
from fastapi import HTTPException
from adapters.Exception import GoogleGenerationException from adapters.Exception import GoogleGenerationException
from adapters.google_adapter import GoogleAdapter from adapters.google_adapter import GoogleAdapter
from api.models.GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse from adapters.s3_adapter import S3Adapter
from api.models import FinancialReport, UsageStats, UsageByEntity
from api.models import GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse
# Импортируйте ваши модели DAO, Asset, Generation корректно # Импортируйте ваши модели DAO, Asset, Generation корректно
from models.Asset import Asset, AssetType, AssetContentType from models.Asset import Asset, AssetType, AssetContentType
from models.Generation import Generation, GenerationStatus from models.Generation import Generation, GenerationStatus
from models.enums import AspectRatios, Quality, GenType from models.enums import AspectRatios, Quality
from repos.dao import DAO from repos.dao import DAO
from adapters.s3_adapter import S3Adapter
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -132,6 +134,9 @@ class GenerationService:
gen_id = None gen_id = None
generation_model = None generation_model = None
if generation_request.environment_id and not generation_request.linked_character_id:
raise HTTPException(status_code=400, detail="environment_id can only be used when linked_character_id is provided")
try: try:
generation_model = Generation(**generation_request.model_dump(exclude={'count'})) generation_model = Generation(**generation_request.model_dump(exclude={'count'}))
if user_id: if user_id:
@@ -185,43 +190,38 @@ class GenerationService:
logger.info(f"Processing generation {generation.id}. Character ID: {generation.linked_character_id}") logger.info(f"Processing generation {generation.id}. Character ID: {generation.linked_character_id}")
# 2. Получаем ассеты-референсы (если они есть) # 2. Получаем ассеты-референсы (если они есть)
reference_assets: List[Asset] = []
media_group_bytes: List[bytes] = [] media_group_bytes: List[bytes] = []
generation_prompt = generation.prompt generation_prompt = generation.prompt
# generation_prompt = f"""
# Create detailed image of character in scene. # 2.1 Аватар персонажа (всегда первый, если включен)
# SCENE DESCRIPTION: {generation.prompt}
# Rules:
# - Integrate the character's appearance naturally into the scene description.
# - Focus on lighting, texture, and composition.
# """
if generation.linked_character_id is not None: if generation.linked_character_id is not None:
char_info = await self.dao.chars.get_character(generation.linked_character_id) char_info = await self.dao.chars.get_character(generation.linked_character_id)
if char_info is None: if char_info is None:
raise Exception(f"Character ID {generation.linked_character_id} not found") raise Exception(f"Character ID {generation.linked_character_id} not found")
if generation.use_profile_image:
if char_info.avatar_asset_id is not None: if generation.use_profile_image and char_info.avatar_asset_id:
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id) avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
if avatar_asset and avatar_asset.data: if avatar_asset:
media_group_bytes.append(avatar_asset.data) img_data = await self._get_asset_data(avatar_asset)
# generation_prompt = generation_prompt.replace("$char_bio_inserted", f"1. CHARACTER BIO (Must be strictly followed): {char_info.character_bio}") if img_data:
media_group_bytes.append(img_data)
reference_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list) # 2.2 Явно указанные ассеты
if generation.assets_list:
# Извлекаем данные (bytes) из ассетов для отправки в Gemini explicit_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
for asset in reference_assets: for asset in explicit_assets:
if asset.content_type != AssetContentType.IMAGE: ref_asset_data = await self._get_asset_data(asset)
continue if ref_asset_data:
media_group_bytes.append(ref_asset_data)
img_data = None
if asset.minio_object_name:
img_data = await self.s3_adapter.get_file(asset.minio_object_name)
elif asset.data:
img_data = asset.data
# 2.3 Ассеты из окружения (в самый конец)
if generation.environment_id:
env = await self.dao.environments.get_env(generation.environment_id)
if env and env.asset_ids:
logger.info(f"Loading {len(env.asset_ids)} assets from environment {env.name} ({env.id})")
env_assets = await self.dao.assets.get_assets_by_ids(env.asset_ids)
for asset in env_assets:
img_data = await self._get_asset_data(asset)
if img_data: if img_data:
media_group_bytes.append(img_data) media_group_bytes.append(img_data)
@@ -340,6 +340,14 @@ class GenerationService:
logger.error(f"Failed to send assets to Telegram ID {generation.telegram_id}: {e}") logger.error(f"Failed to send assets to Telegram ID {generation.telegram_id}: {e}")
async def _get_asset_data(self, asset: Asset) -> Optional[bytes]:
    """Resolve the raw image bytes for *asset*, or None when unavailable.

    Only IMAGE assets are considered; S3-backed storage (``minio_object_name``)
    takes precedence over the legacy inline ``data`` field.
    """
    # Skip anything that is not an image reference.
    if asset.content_type != AssetContentType.IMAGE:
        return None
    object_name = asset.minio_object_name
    if object_name:
        # Asset lives in S3/Minio — fetch it through the adapter.
        return await self.s3_adapter.get_file(object_name)
    # Fall back to bytes embedded directly in the document (may be None).
    return asset.data
async def _simulate_progress(self, generation: Generation): async def _simulate_progress(self, generation: Generation):
""" """
Increments progress from 0 to 90 over ~20 seconds. Increments progress from 0 to 90 over ~20 seconds.
@@ -377,7 +385,6 @@ class GenerationService:
Returns: Returns:
Created Generation object Created Generation object
""" """
from api.models.ExternalGenerationDTO import ExternalGenerationRequest
# Validate image source # Validate image source
external_gen.validate_image_source() external_gen.validate_image_source()
@@ -507,3 +514,28 @@ class GenerationService:
except Exception as e: except Exception as e:
logger.error(f"Error during old data cleanup: {e}") logger.error(f"Error during old data cleanup: {e}")
async def get_financial_report(self, user_id: Optional[str] = None, project_id: Optional[str] = None, breakdown_by: Optional[str] = None) -> FinancialReport:
    """Generate a financial usage report, optionally scoped to a user or project.

    Args:
        user_id: restrict stats to generations created by this user.
        project_id: restrict stats to this project.
        breakdown_by: 'created_by' for a per-user breakdown, 'project_id'
            for a per-project breakdown; anything else yields totals only.

    Returns:
        FinancialReport with overall summary and the requested breakdown
        (the non-requested breakdown field stays None).
    """
    repo = self.dao.generations

    # Overall totals across the selected scope.
    totals_raw = await repo.get_usage_stats(created_by=user_id, project_id=project_id)
    totals = UsageStats(**totals_raw)

    user_breakdown = None
    project_breakdown = None
    if breakdown_by in ("created_by", "project_id"):
        rows = await repo.get_usage_breakdown(group_by=breakdown_by, project_id=project_id, created_by=user_id)
        entities = [UsageByEntity(**row) for row in rows]
        if breakdown_by == "created_by":
            user_breakdown = entities
        else:
            project_breakdown = entities

    return FinancialReport(
        summary=totals,
        by_user=user_breakdown,
        by_project=project_breakdown
    )

39
config.py Normal file
View File

@@ -0,0 +1,39 @@
import os
from typing import Optional
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
    """Application configuration loaded via pydantic-settings.

    Values are read from process environment variables and, when present,
    from an env file (path overridable through the ENV_FILE variable,
    defaulting to ".env"). Fields declared without defaults (BOT_TOKEN,
    GEMINI_API_KEY) are required: instantiation raises a validation error
    when they are missing.
    """

    # Telegram Bot
    BOT_TOKEN: str      # bot API token — required, no default
    ADMIN_ID: int = 0   # admin Telegram id; 0 means "not configured"

    # AI Service
    GEMINI_API_KEY: str  # Google Gemini API key — required

    # Database
    MONGO_HOST: str = "mongodb://localhost:27017"  # full MongoDB connection URI
    DB_NAME: str = "my_bot_db"

    # S3 Storage (Minio)
    MINIO_ENDPOINT: str = "http://localhost:9000"
    MINIO_ACCESS_KEY: str = "minioadmin"
    MINIO_SECRET_KEY: str = "minioadmin"
    MINIO_BUCKET: str = "ai-char"

    # External API
    # Presumably the shared secret for the external generation import
    # endpoint; None leaves it unset — verify against the API handler.
    EXTERNAL_API_SECRET: Optional[str] = None

    # JWT Security
    # NOTE(review): placeholder secret — must be overridden in production.
    SECRET_KEY: str = "CHANGE_ME_TO_A_SUPER_SECRET_KEY"
    ALGORITHM: str = "HS256"
    ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 * 24 * 60  # 30 days

    model_config = SettingsConfigDict(
        env_file=os.getenv("ENV_FILE", ".env"),
        env_file_encoding="utf-8",
        extra="ignore"  # unknown env vars are ignored, not validation errors
    )


# Module-level singleton imported across the codebase (tests, adapters, auth).
settings = Settings()

View File

@@ -27,19 +27,19 @@ class AlbumMiddleware(BaseMiddleware):
# Ждем сбора остальных частей # Ждем сбора остальных частей
await asyncio.sleep(self.latency) await asyncio.sleep(self.latency)
# Проверяем, что ключ все еще существует (на всякий случай) # Проверяем, что ключ все еще существует
if group_id in self.album_data: if group_id in self.album_data:
# Передаем собранный альбом в хендлер # Передаем собранный альбом в хендлер
# Сортируем по message_id, чтобы порядок был верным # Сортируем по message_id, чтобы порядок был верным
self.album_data[group_id].sort(key=lambda x: x.message_id) current_album = self.album_data[group_id]
data["album"] = self.album_data[group_id] current_album.sort(key=lambda x: x.message_id)
data["album"] = current_album
return await handler(event, data) return await handler(event, data)
finally: finally:
# ЧИСТКА: Удаляем всегда, если это "головной" поток, который создал запись # ЧИСТКА: Удаляем запись после обработки или таймаута
# Проверяем, что мы удаляем именно то, что создали, и ключ существует # Используем pop() с дефолтом, чтобы избежать KeyError
if group_id in self.album_data and self.album_data[group_id][0] == event: self.album_data.pop(group_id, None)
del self.album_data[group_id]
else: else:
# Если группа уже собирается - просто добавляем и выходим # Если группа уже собирается - просто добавляем и выходим

View File

@@ -63,6 +63,7 @@ class Asset(BaseModel):
# --- CALCULATED FIELD --- # --- CALCULATED FIELD ---
@computed_field @computed_field
@property
def url(self) -> str: def url(self) -> str:
""" """
Это поле автоматически вычислится и попадет в model_dump() / .json() Это поле автоматически вычислится и попадет в model_dump() / .json()

View File

@@ -9,7 +9,6 @@ class Character(BaseModel):
name: str name: str
avatar_asset_id: Optional[str] = None avatar_asset_id: Optional[str] = None
avatar_image: Optional[str] = None avatar_image: Optional[str] = None
character_image_data: Optional[bytes] = None
character_image_doc_tg_id: Optional[str] = None character_image_doc_tg_id: Optional[str] = None
character_image_tg_id: Optional[str] = None character_image_tg_id: Optional[str] = None
character_bio: Optional[str] = None character_bio: Optional[str] = None

20
models/Environment.py Normal file
View File

@@ -0,0 +1,20 @@
from typing import List, Optional
from pydantic import BaseModel, Field, ConfigDict
from datetime import datetime
from bson import ObjectId
class Environment(BaseModel):
    """A named collection of reference assets attached to a character.

    During generation, the environment's assets are loaded and appended to
    the media group sent to the image model as scene/background references.
    """

    # Mongo ObjectId serialized as str; populated from "_id" via the alias.
    id: Optional[str] = Field(None, alias="_id")
    character_id: str  # id of the owning character
    name: str = Field(..., min_length=1)
    description: Optional[str] = None
    asset_ids: List[str] = Field(default_factory=list)  # linked asset ids
    # NOTE(review): datetime.utcnow produces naive timestamps and is
    # deprecated since Python 3.12; kept as-is for consistency with
    # EnvironmentRepo, which writes the same naive form on updates.
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)

    model_config = ConfigDict(
        populate_by_name=True,           # accept either "id" or "_id" on input
        json_encoders={ObjectId: str},   # render ObjectId values as strings
        arbitrary_types_allowed=True
    )

View File

@@ -35,6 +35,7 @@ class Generation(BaseModel):
output_token_usage: Optional[int] = None output_token_usage: Optional[int] = None
is_deleted: bool = False is_deleted: bool = False
album_id: Optional[str] = None album_id: Optional[str] = None
environment_id: Optional[str] = None
generation_group_id: Optional[str] = None generation_group_id: Optional[str] = None
created_by: Optional[str] = None # Stores User ID (Telegram ID or Web User ObjectId) created_by: Optional[str] = None # Stores User ID (Telegram ID or Web User ObjectId)
project_id: Optional[str] = None project_id: Optional[str] = None

View File

@@ -175,6 +175,8 @@ class AssetsRepo:
filter["linked_char_id"] = character_id filter["linked_char_id"] = character_id
if created_by: if created_by:
filter["created_by"] = created_by filter["created_by"] = created_by
if project_id is None:
filter["project_id"] = None
if project_id: if project_id:
filter["project_id"] = project_id filter["project_id"] = project_id
return await self.collection.count_documents(filter) return await self.collection.count_documents(filter)

View File

@@ -15,26 +15,24 @@ class CharacterRepo:
character.id = str(op.inserted_id) character.id = str(op.inserted_id)
return character return character
async def get_character(self, character_id: str, with_image_data: bool = False) -> Character | None: async def get_character(self, character_id: str) -> Character | None:
args = {} res = await self.collection.find_one({"_id": ObjectId(character_id)})
if not with_image_data:
args["character_image_data"] = 0
res = await self.collection.find_one({"_id": ObjectId(character_id)}, args)
if res is None: if res is None:
return None return None
else: else:
res["id"] = str(res.pop("_id")) res["id"] = str(res.pop("_id"))
return Character(**res) return Character(**res)
async def get_all_characters(self, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Character]: async def get_all_characters(self, created_by: Optional[str] = None, project_id: Optional[str] = None, limit: int = 100, offset: int = 0) -> List[Character]:
filter = {} filter = {}
if created_by: if created_by:
filter["created_by"] = created_by filter["created_by"] = created_by
if project_id is None:
filter["project_id"] = None
if project_id: if project_id:
filter["project_id"] = project_id filter["project_id"] = project_id
args = {"character_image_data": 0} # don't return image data for list res = await self.collection.find(filter).skip(offset).limit(limit).to_list(None)
res = await self.collection.find(filter, args).to_list(None)
chars = [] chars = []
for doc in res: for doc in res:
doc["id"] = str(doc.pop("_id")) doc["id"] = str(doc.pop("_id"))

View File

@@ -8,6 +8,7 @@ from repos.albums_repo import AlbumsRepo
from repos.project_repo import ProjectRepo from repos.project_repo import ProjectRepo
from repos.idea_repo import IdeaRepo from repos.idea_repo import IdeaRepo
from repos.post_repo import PostRepo from repos.post_repo import PostRepo
from repos.environment_repo import EnvironmentRepo
from typing import Optional from typing import Optional
@@ -23,3 +24,4 @@ class DAO:
self.users = UsersRepo(client, db_name) self.users = UsersRepo(client, db_name)
self.ideas = IdeaRepo(client, db_name) self.ideas = IdeaRepo(client, db_name)
self.posts = PostRepo(client, db_name) self.posts = PostRepo(client, db_name)
self.environments = EnvironmentRepo(client, db_name)

73
repos/environment_repo.py Normal file
View File

@@ -0,0 +1,73 @@
from typing import List, Optional
from datetime import datetime
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient
from models.Environment import Environment
class EnvironmentRepo:
    """MongoDB persistence layer for Environment documents."""

    def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
        self.collection = client[db_name]["environments"]

    async def create_env(self, env: Environment) -> Environment:
        """Insert a new environment and backfill its generated string id."""
        payload = env.model_dump(exclude={"id"})
        inserted = await self.collection.insert_one(payload)
        env.id = str(inserted.inserted_id)
        return env

    async def get_env(self, env_id: str) -> Optional[Environment]:
        """Fetch one environment by id; None when it does not exist."""
        raw = await self.collection.find_one({"_id": ObjectId(env_id)})
        if not raw:
            return None
        raw["id"] = str(raw.pop("_id"))
        return Environment(**raw)

    async def get_character_envs(self, character_id: str) -> List[Environment]:
        """Return every environment belonging to the given character."""
        found: List[Environment] = []
        async for raw in self.collection.find({"character_id": character_id}):
            raw["id"] = str(raw.pop("_id"))
            found.append(Environment(**raw))
        return found

    async def update_env(self, env_id: str, update_data: dict) -> bool:
        """Apply a partial $set update; True when the document was modified.

        Note: mutates *update_data* by stamping updated_at.
        """
        update_data["updated_at"] = datetime.utcnow()
        outcome = await self.collection.update_one(
            {"_id": ObjectId(env_id)},
            {"$set": update_data}
        )
        return outcome.modified_count > 0

    async def delete_env(self, env_id: str) -> bool:
        """Hard-delete an environment; True when a document was removed."""
        outcome = await self.collection.delete_one({"_id": ObjectId(env_id)})
        return outcome.deleted_count > 0

    async def add_asset(self, env_id: str, asset_id: str) -> bool:
        """Attach one asset id (deduplicated via $addToSet) and bump updated_at."""
        mutation = {
            "$addToSet": {"asset_ids": asset_id},
            "$set": {"updated_at": datetime.utcnow()},
        }
        outcome = await self.collection.update_one({"_id": ObjectId(env_id)}, mutation)
        return outcome.modified_count > 0

    async def add_assets(self, env_id: str, asset_ids: List[str]) -> bool:
        """Attach several asset ids in one call (deduplicated) and bump updated_at."""
        mutation = {
            "$addToSet": {"asset_ids": {"$each": asset_ids}},
            "$set": {"updated_at": datetime.utcnow()},
        }
        outcome = await self.collection.update_one({"_id": ObjectId(env_id)}, mutation)
        return outcome.modified_count > 0

    async def remove_asset(self, env_id: str, asset_id: str) -> bool:
        """Detach one asset id via $pull and bump updated_at."""
        mutation = {
            "$pull": {"asset_ids": asset_id},
            "$set": {"updated_at": datetime.utcnow()},
        }
        outcome = await self.collection.update_one({"_id": ObjectId(env_id)}, mutation)
        return outcome.modified_count > 0

View File

@@ -65,6 +65,8 @@ class GenerationRepo:
args["status"] = status args["status"] = status
if created_by is not None: if created_by is not None:
args["created_by"] = created_by args["created_by"] = created_by
if project_id is None:
args["project_id"] = None
if project_id is not None: if project_id is not None:
args["project_id"] = project_id args["project_id"] = project_id
if idea_id is not None: if idea_id is not None:
@@ -92,6 +94,123 @@ class GenerationRepo:
async def update_generation(self, generation: Generation, ): async def update_generation(self, generation: Generation, ):
res = await self.collection.update_one({"_id": ObjectId(generation.id)}, {"$set": generation.model_dump()}) res = await self.collection.update_one({"_id": ObjectId(generation.id)}, {"$set": generation.model_dump()})
async def get_usage_stats(self, created_by: Optional[str] = None, project_id: Optional[str] = None) -> dict:
    """
    Calculates usage statistics (runs, tokens, cost) using MongoDB aggregation.
    Includes even soft-deleted generations to reflect actual expenditure.

    Args:
        created_by: optional filter on the generation creator's id.
        project_id: optional filter on the owning project's id.

    Returns:
        dict with total_runs, total_tokens, total_input_tokens,
        total_output_tokens and total_cost (rounded to 4 decimals);
        all zeros when nothing matches.
    """
    pipeline = []
    # 1. Match all done generations (including soft-deleted — is_deleted is
    #    deliberately NOT filtered here, since cost was incurred regardless)
    match_stage = {"status": GenerationStatus.DONE}
    if created_by:
        match_stage["created_by"] = created_by
    if project_id:
        match_stage["project_id"] = project_id
    pipeline.append({"$match": match_stage})
    # 2. Group by null (total)
    pipeline.append({
        "$group": {
            "_id": None,
            "total_runs": {"$sum": 1},
            # Prefer the split input/output counters when both are positive;
            # otherwise fall back to the legacy combined token_usage field.
            "total_tokens": {
                "$sum": {
                    "$cond": [
                        {"$and": [{"$gt": ["$input_token_usage", 0]}, {"$gt": ["$output_token_usage", 0]}]},
                        {"$add": ["$input_token_usage", "$output_token_usage"]},
                        {"$ifNull": ["$token_usage", 0]}
                    ]
                }
            },
            "total_input_tokens": {"$sum": {"$ifNull": ["$input_token_usage", 0]}},
            "total_output_tokens": {"$sum": {"$ifNull": ["$output_token_usage", 0]}},
            # Hard-coded per-token USD rates (input 0.000002, output 0.00012)
            # — presumably the Gemini pricing tiers; TODO confirm and
            # consider moving into config so pricing changes don't need a
            # code edit (same constants duplicated in get_usage_breakdown).
            "total_cost": {
                "$sum": {
                    "$add": [
                        {"$multiply": [{"$ifNull": ["$input_token_usage", 0]}, 0.000002]},
                        {"$multiply": [{"$ifNull": ["$output_token_usage", 0]}, 0.00012]}
                    ]
                }
            }
        }
    })
    cursor = self.collection.aggregate(pipeline)
    res = await cursor.to_list(1)
    # No matching documents: $group emits nothing, so return explicit zeros.
    if not res:
        return {
            "total_runs": 0,
            "total_tokens": 0,
            "total_input_tokens": 0,
            "total_output_tokens": 0,
            "total_cost": 0.0
        }
    result = res[0]
    result.pop("_id")
    result["total_cost"] = round(result["total_cost"], 4)
    return result
async def get_usage_breakdown(self, group_by: str = "created_by", project_id: Optional[str] = None, created_by: Optional[str] = None) -> List[dict]:
    """
    Returns usage statistics grouped by user or project.
    Includes even soft-deleted generations to reflect actual expenditure.

    Args:
        group_by: field to group on — "created_by" or "project_id".
        project_id: optional pre-filter on project.
        created_by: optional pre-filter on creator.

    Returns:
        list of {"entity_id": str, "stats": {...}} dicts sorted by
        total_cost descending; documents with a null group key are
        reported under entity_id "unknown".
    """
    pipeline = []
    # Only completed generations count toward spend; is_deleted is
    # deliberately not filtered (cost was incurred regardless).
    match_stage = {"status": GenerationStatus.DONE}
    if project_id:
        match_stage["project_id"] = project_id
    if created_by:
        match_stage["created_by"] = created_by
    pipeline.append({"$match": match_stage})
    pipeline.append({
        "$group": {
            "_id": f"${group_by}",
            "total_runs": {"$sum": 1},
            # Prefer split input/output counters when both are positive;
            # fall back to the legacy combined token_usage field.
            "total_tokens": {
                "$sum": {
                    "$cond": [
                        {"$and": [{"$gt": ["$input_token_usage", 0]}, {"$gt": ["$output_token_usage", 0]}]},
                        {"$add": ["$input_token_usage", "$output_token_usage"]},
                        {"$ifNull": ["$token_usage", 0]}
                    ]
                }
            },
            "total_input_tokens": {"$sum": {"$ifNull": ["$input_token_usage", 0]}},
            "total_output_tokens": {"$sum": {"$ifNull": ["$output_token_usage", 0]}},
            # Per-token USD rates duplicated from get_usage_stats — keep the
            # two in sync (ideally hoist to shared constants/config).
            "total_cost": {
                "$sum": {
                    "$add": [
                        {"$multiply": [{"$ifNull": ["$input_token_usage", 0]}, 0.000002]},
                        {"$multiply": [{"$ifNull": ["$output_token_usage", 0]}, 0.00012]}
                    ]
                }
            }
        }
    })
    # Most expensive entities first.
    pipeline.append({"$sort": {"total_cost": -1}})
    cursor = self.collection.aggregate(pipeline)
    res = await cursor.to_list(None)
    results = []
    for item in res:
        entity_id = item.pop("_id")
        item["total_cost"] = round(item["total_cost"], 4)
        results.append({
            "entity_id": str(entity_id) if entity_id else "unknown",
            "stats": item
        })
    return results
async def get_generations_by_group(self, group_id: str) -> List[Generation]: async def get_generations_by_group(self, group_id: str) -> List[Generation]:
res = await self.collection.find({"generation_group_id": group_id, "is_deleted": False}).sort("created_at", 1).to_list(None) res = await self.collection.find({"generation_group_id": group_id, "is_deleted": False}).sort("created_at", 1).to_list(None)
generations: List[Generation] = [] generations: List[Generation] = []

View File

@@ -39,8 +39,17 @@ class IdeaRepo:
"from": "generations", "from": "generations",
"let": {"idea_id": "$str_id"}, "let": {"idea_id": "$str_id"},
"pipeline": [ "pipeline": [
{"$match": {"$expr": {"$eq": ["$idea_id", "$$idea_id"]}}}, {
{"$sort": {"created_at": -1}}, # Ensure we get the latest "$match": {
"$and": [
{"$expr": {"$eq": ["$idea_id", "$$idea_id"]}},
{"status": "done"},
{"result_list": {"$exists": True, "$not": {"$size": 0}}},
{"is_deleted": False}
]
}
},
{"$sort": {"created_at": -1}}, # Ensure we get the latest successful
{"$limit": 1} {"$limit": 1}
], ],
"as": "generations" "as": "generations"

View File

@@ -51,3 +51,4 @@ python-jose[cryptography]==3.3.0
python-multipart==0.0.22 python-multipart==0.0.22
email-validator email-validator
prometheus-fastapi-instrumentator prometheus-fastapi-instrumentator
pydantic-settings==2.13.0

View File

@@ -51,57 +51,66 @@ async def new_char_bio(message: Message, state: FSMContext, dao: DAO, bot: Bot):
wait_msg = await message.answer("💾 Сохраняю персонажа...") wait_msg = await message.answer("💾 Сохраняю персонажа...")
try: try:
# ВОТ ТУТ скачиваем файл (прямо перед сохранением) # 1. Скачиваем файл (один раз)
# TODO: Для больших файлов лучше использовать streaming или сохранять во временный файл
file_io = await bot.download(file_id) file_io = await bot.download(file_id)
# photo_bytes = file_io.getvalue() # Получаем байты file_bytes = file_io.read()
# 2. Создаем Character (сначала без ассета, чтобы получить ID)
# Создаем модель
char = Character( char = Character(
id=None, id=None,
name=name, name=name,
character_image_data=file_io.read(),
character_image_tg_id=None, character_image_tg_id=None,
character_image_doc_tg_id=file_id, character_image_doc_tg_id=file_id,
character_bio=bio, character_bio=bio,
created_by=str(message.from_user.id) created_by=str(message.from_user.id)
) )
file_io.close()
# Сохраняем через DAO
# Сохраняем, чтобы получить ID
await dao.chars.add_character(char) await dao.chars.add_character(char)
file_info = await bot.get_file(char.character_image_doc_tg_id)
file_bytes = await bot.download_file(file_info.file_path) # 3. Создаем Asset (связанный с персонажем)
file_io = file_bytes.read() avatar_asset_id = await dao.assets.create_asset(
avatar_asset = await dao.assets.create_asset( Asset(
Asset(name="avatar.png", type=AssetType.UPLOADED, content_type=AssetContentType.IMAGE, linked_char_id=str(char.id), data=file_io, name="avatar.png",
tg_doc_file_id=file_id)) type=AssetType.UPLOADED,
char.avatar_image = avatar_asset.link content_type=AssetContentType.IMAGE,
linked_char_id=str(char.id),
data=file_bytes,
tg_doc_file_id=file_id
)
)
# 4. Обновляем персонажа ссылками на ассет
char.avatar_asset_id = avatar_asset_id
char.avatar_image = f"/api/assets/{avatar_asset_id}" # Формируем ссылку вручную или используем метод, если появится
# Отправляем подтверждение # Отправляем подтверждение
# Используем байты для отправки обратно
photo_msg = await message.answer_photo( photo_msg = await message.answer_photo(
photo=BufferedInputFile(file_io, photo=BufferedInputFile(file_bytes, filename="char.jpg"),
filename="char.jpg") if not char.character_image_tg_id else char.character_image_tg_id,
caption=( caption=(
"🎉 <b>Персонаж создан!</b>\n\n" "🎉 <b>Персонаж создан!</b>\n\n"
f"👤 <b>Имя:</b> {char.name}\n" f"👤 <b>Имя:</b> {char.name}\n"
f"📝 <b>Био:</b> {char.character_bio}" f"📝 <b>Био:</b> {char.character_bio}"
) )
) )
file_bytes.close()
char.character_image_tg_id = photo_msg.photo[0].file_id
# Сохраняем TG ID фото (которое отправили как фото, а не документ)
char.character_image_tg_id = photo_msg.photo[-1].file_id
# Финальное обновление персонажа
await dao.chars.update_char(char.id, char) await dao.chars.update_char(char.id, char)
await wait_msg.delete() await wait_msg.delete()
file_io.close()
# Сбрасываем состояние # Сбрасываем состояние
await state.clear() await state.clear()
except Exception as e: except Exception as e:
logging.error(e) logger.error(f"Error creating character: {e}")
traceback.print_exc()
await wait_msg.edit_text(f"❌ Ошибка при сохранении: {e}") await wait_msg.edit_text(f"❌ Ошибка при сохранении: {e}")
# Не сбрасываем стейт, даем возможность попробовать ввести био снова или начать заново
@router.message(Command("chars")) @router.message(Command("chars"))

View File

@@ -3,17 +3,17 @@ import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from unittest.mock import MagicMock from unittest.mock import MagicMock
from motor.motor_asyncio import AsyncIOMotorClient from motor.motor_asyncio import AsyncIOMotorClient
import os
import asyncio import asyncio
from config import settings
from main import app from aiws import app
from api.endpoints.auth import get_current_user from api.endpoints.auth import get_current_user
from api.dependency import get_dao from api.dependency import get_dao
from repos.dao import DAO from repos.dao import DAO
from models.Character import Character from models.Character import Character
# Config for test DB # Config for test DB
MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://admin:super_secure_password@31.59.58.220:27017") MONGO_HOST = settings.MONGO_HOST
DB_NAME = "bot_db_test_chars" DB_NAME = "bot_db_test_chars"
# Mock User # Mock User

View File

@@ -10,13 +10,13 @@ import json
import requests import requests
import base64 import base64
import os import os
from dotenv import load_dotenv from config import settings
load_dotenv() # Load env is not needed as settings handles it
# Configuration # Configuration
API_URL = "http://localhost:8090/api/generations/import" API_URL = "http://localhost:8090/api/generations/import"
SECRET = os.getenv("EXTERNAL_API_SECRET", "your_super_secret_key_change_this_in_production") SECRET = settings.EXTERNAL_API_SECRET or "your_super_secret_key_change_this_in_production"
# Sample generation data # Sample generation data
generation_data = { generation_data = {

View File

@@ -10,11 +10,10 @@ from repos.dao import DAO
from models.Idea import Idea from models.Idea import Idea
from models.Generation import Generation, GenerationStatus from models.Generation import Generation, GenerationStatus
from models.enums import AspectRatios, Quality from models.enums import AspectRatios, Quality
from config import settings
load_dotenv() MONGO_HOST = settings.MONGO_HOST
DB_NAME = settings.DB_NAME
MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://localhost:27017")
DB_NAME = os.getenv("DB_NAME", "bot_db")
print(f"Connecting to MongoDB: {MONGO_HOST}, DB: {DB_NAME}") print(f"Connecting to MongoDB: {MONGO_HOST}, DB: {DB_NAME}")

View File

@@ -1,15 +1,14 @@
import asyncio import asyncio
import os import os
from dotenv import load_dotenv from config import settings
from adapters.s3_adapter import S3Adapter from adapters.s3_adapter import S3Adapter
async def test_s3(): async def test_s3():
load_dotenv()
endpoint = os.getenv("MINIO_ENDPOINT", "http://localhost:9000") endpoint = settings.MINIO_ENDPOINT
access_key = os.getenv("MINIO_ACCESS_KEY") access_key = settings.MINIO_ACCESS_KEY
secret_key = os.getenv("MINIO_SECRET_KEY") secret_key = settings.MINIO_SECRET_KEY
bucket = os.getenv("MINIO_BUCKET") bucket = settings.MINIO_BUCKET
print(f"Connecting to {endpoint}, bucket: {bucket}") print(f"Connecting to {endpoint}, bucket: {bucket}")

View File

@@ -4,13 +4,11 @@ from datetime import datetime, timedelta, UTC
from motor.motor_asyncio import AsyncIOMotorClient from motor.motor_asyncio import AsyncIOMotorClient
from models.Generation import Generation, GenerationStatus from models.Generation import Generation, GenerationStatus
from repos.generation_repo import GenerationRepo from repos.generation_repo import GenerationRepo
from dotenv import load_dotenv from config import settings
load_dotenv()
# Mock configs if not present in env # Mock configs if not present in env
MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://localhost:27017") MONGO_HOST = settings.MONGO_HOST
DB_NAME = os.getenv("DB_NAME", "bot_db") DB_NAME = settings.DB_NAME
print(f"Connecting to MongoDB: {MONGO_HOST}, DB: {DB_NAME}") print(f"Connecting to MongoDB: {MONGO_HOST}, DB: {DB_NAME}")

View File

@@ -10,10 +10,11 @@ from repos.dao import DAO
from models.Album import Album from models.Album import Album
from models.Generation import Generation, GenerationStatus from models.Generation import Generation, GenerationStatus
from models.enums import AspectRatios, Quality from models.enums import AspectRatios, Quality
from config import settings
# Mock config # Mock config
# Use the same host as aiws.py but different DB # Use the same host as aiws.py but different DB
MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://admin:super_secure_password@31.59.58.220:27017") MONGO_HOST = settings.MONGO_HOST
DB_NAME = "bot_db_test_albums" DB_NAME = "bot_db_test_albums"
async def test_albums(): async def test_albums():
@@ -83,8 +84,6 @@ async def test_albums():
client.close() client.close()
if __name__ == "__main__": if __name__ == "__main__":
from dotenv import load_dotenv
load_dotenv()
try: try:
asyncio.run(test_albums()) asyncio.run(test_albums())
except Exception as e: except Exception as e:

View File

@@ -1,29 +1,28 @@
import asyncio import asyncio
import os import os
from datetime import datetime from datetime import datetime
from dotenv import load_dotenv
from motor.motor_asyncio import AsyncIOMotorClient from motor.motor_asyncio import AsyncIOMotorClient
from config import settings
from models.Asset import Asset, AssetType from models.Asset import Asset, AssetType
from repos.assets_repo import AssetsRepo from repos.assets_repo import AssetsRepo
from adapters.s3_adapter import S3Adapter from adapters.s3_adapter import S3Adapter
# Load env to get credentials # Load env is not needed as settings handles it
load_dotenv()
async def test_integration(): async def test_integration():
print("🚀 Starting integration test...") print("🚀 Starting integration test...")
# 1. Setup Dependencies # 1. Setup Dependencies
mongo_uri = os.getenv("MONGO_HOST", "mongodb://localhost:27017") mongo_uri = settings.MONGO_HOST
client = AsyncIOMotorClient(mongo_uri) client = AsyncIOMotorClient(mongo_uri)
db_name = os.getenv("DB_NAME", "bot_db_test") db_name = settings.DB_NAME + "_test"
s3_adapter = S3Adapter( s3_adapter = S3Adapter(
endpoint_url=os.getenv("MINIO_ENDPOINT", "http://localhost:9000"), endpoint_url=settings.MINIO_ENDPOINT,
aws_access_key_id=os.getenv("MINIO_ACCESS_KEY", "admin"), aws_access_key_id=settings.MINIO_ACCESS_KEY,
aws_secret_access_key=os.getenv("MINIO_SECRET_KEY", "SuperSecretPassword123!"), aws_secret_access_key=settings.MINIO_SECRET_KEY,
bucket_name=os.getenv("MINIO_BUCKET", "ai-char") bucket_name=settings.MINIO_BUCKET
) )
repo = AssetsRepo(client, s3_adapter, db_name=db_name) repo = AssetsRepo(client, s3_adapter, db_name=db_name)

View File

@@ -3,12 +3,12 @@ from typing import Optional, Union, Any
from jose import jwt from jose import jwt
from passlib.context import CryptContext from passlib.context import CryptContext
from config import settings
# Настройки безопасности (лучше вынести в config/env, но для старта здесь) # Настройки безопасности берутся из config.py
# SECRET_KEY должен быть сложным и секретным в продакшене! SECRET_KEY = settings.SECRET_KEY
SECRET_KEY = "CHANGE_ME_TO_A_SUPER_SECRET_KEY" ALGORITHM = settings.ALGORITHM
ALGORITHM = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES = settings.ACCESS_TOKEN_EXPIRE_MINUTES
ACCESS_TOKEN_EXPIRE_MINUTES = 30 * 24 * 60 # 30 дней, например
pwd_context = CryptContext(schemes=["argon2"], deprecated="auto") pwd_context = CryptContext(schemes=["argon2"], deprecated="auto")