Compare commits
12 Commits
c93e577bcf
...
enviroment
| Author | SHA1 | Date | |
|---|---|---|---|
| 5aa6391dc8 | |||
| ffb0463fe0 | |||
| dd0f8a1cb6 | |||
| 4af5134726 | |||
| 7488665d04 | |||
| ecc88aca62 | |||
| 70f50170fc | |||
| f4207fc4c1 | |||
| c50d2c8ad9 | |||
| 4586daac38 | |||
| 198ac44960 | |||
| d820d9145b |
@@ -63,10 +63,12 @@ class S3Adapter:
|
||||
response = await client.get_object(Bucket=self.bucket_name, Key=object_name)
|
||||
# aioboto3 Body is an aiohttp StreamReader wrapper
|
||||
body = response['Body']
|
||||
data = await body.read()
|
||||
# Yield in chunks to avoid holding entire response in StreamingResponse buffer
|
||||
for i in range(0, len(data), chunk_size):
|
||||
yield data[i:i + chunk_size]
|
||||
|
||||
while True:
|
||||
chunk = await body.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
yield chunk
|
||||
except ClientError as e:
|
||||
print(f"Error streaming from S3: {e}")
|
||||
return
|
||||
|
||||
29
aiws.py
29
aiws.py
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from aiogram import Bot, Dispatcher, Router, F
|
||||
@@ -9,7 +8,6 @@ from aiogram.enums import ParseMode
|
||||
from aiogram.filters import CommandStart, Command
|
||||
from aiogram.types import Message
|
||||
from aiogram.fsm.storage.mongo import MongoStorage
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import FastAPI
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from prometheus_client import Info
|
||||
@@ -17,6 +15,7 @@ from starlette.middleware.cors import CORSMiddleware
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
|
||||
# --- ИМПОРТЫ ПРОЕКТА ---
|
||||
from config import settings
|
||||
from adapters.google_adapter import GoogleAdapter
|
||||
from adapters.s3_adapter import S3Adapter
|
||||
from api.service.generation_service import GenerationService
|
||||
@@ -44,17 +43,19 @@ from api.endpoints.admin import router as api_admin_router
|
||||
from api.endpoints.album_router import router as api_album_router
|
||||
from api.endpoints.project_router import router as project_api_router
|
||||
from api.endpoints.idea_router import router as idea_api_router
|
||||
from api.endpoints.post_router import router as post_api_router
|
||||
from api.endpoints.environment_router import router as environment_api_router
|
||||
|
||||
load_dotenv()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# --- КОНФИГУРАЦИЯ ---
|
||||
BOT_TOKEN = os.getenv("BOT_TOKEN")
|
||||
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
|
||||
# Настройки теперь берутся из config.py
|
||||
BOT_TOKEN = settings.BOT_TOKEN
|
||||
GEMINI_API_KEY = settings.GEMINI_API_KEY
|
||||
|
||||
MONGO_HOST = os.getenv("MONGO_HOST") # Например: mongodb://localhost:27017
|
||||
DB_NAME = os.getenv("DB_NAME", "my_bot_db") # Имя базы данных
|
||||
ADMIN_ID = int(os.getenv("ADMIN_ID", 0))
|
||||
MONGO_HOST = settings.MONGO_HOST
|
||||
DB_NAME = settings.DB_NAME
|
||||
ADMIN_ID = settings.ADMIN_ID
|
||||
|
||||
|
||||
def setup_logging():
|
||||
@@ -78,10 +79,10 @@ char_repo = CharacterRepo(mongo_client)
|
||||
|
||||
# S3 Adapter
|
||||
s3_adapter = S3Adapter(
|
||||
endpoint_url=os.getenv("MINIO_ENDPOINT", "http://31.59.58.220:9000"),
|
||||
aws_access_key_id=os.getenv("MINIO_ACCESS_KEY", "minioadmin"),
|
||||
aws_secret_access_key=os.getenv("MINIO_SECRET_KEY", "minioadmin"),
|
||||
bucket_name=os.getenv("MINIO_BUCKET", "ai-char")
|
||||
endpoint_url=settings.MINIO_ENDPOINT,
|
||||
aws_access_key_id=settings.MINIO_ACCESS_KEY,
|
||||
aws_secret_access_key=settings.MINIO_SECRET_KEY,
|
||||
bucket_name=settings.MINIO_BUCKET
|
||||
)
|
||||
|
||||
dao = DAO(mongo_client, s3_adapter) # Главный DAO для бота
|
||||
@@ -219,6 +220,8 @@ app.include_router(api_gen_router)
|
||||
app.include_router(api_album_router)
|
||||
app.include_router(project_api_router)
|
||||
app.include_router(idea_api_router)
|
||||
app.include_router(post_api_router)
|
||||
app.include_router(environment_api_router)
|
||||
|
||||
# Prometheus Metrics (Instrument after all routers are added)
|
||||
Instrumentator(
|
||||
@@ -257,7 +260,7 @@ if __name__ == "__main__":
|
||||
async def main():
|
||||
# Создаем конфигурацию uvicorn вручную
|
||||
# loop="asyncio" заставляет использовать стандартный цикл
|
||||
config = uvicorn.Config(app, host="0.0.0.0", port=8090, loop="asyncio", timeout_keep_alive=120, env_file=".env.development")
|
||||
config = uvicorn.Config(app, host="0.0.0.0", port=8090, loop="asyncio", timeout_keep_alive=120)
|
||||
server = uvicorn.Server(config)
|
||||
|
||||
# Запускаем сервер (lifespan запустится внутри)
|
||||
|
||||
Binary file not shown.
@@ -58,3 +58,8 @@ async def get_project_id(x_project_id: Optional[str] = Header(None, alias="X-Pro
|
||||
|
||||
async def get_album_service(dao: DAO = Depends(get_dao)) -> AlbumService:
|
||||
return AlbumService(dao)
|
||||
|
||||
from api.service.post_service import PostService
|
||||
|
||||
def get_post_service(dao: DAO = Depends(get_dao)) -> PostService:
|
||||
return PostService(dao)
|
||||
@@ -5,6 +5,8 @@ from fastapi.security import OAuth2PasswordBearer
|
||||
from pydantic import BaseModel
|
||||
|
||||
from repos.user_repo import UsersRepo, UserStatus
|
||||
from api.dependency import get_dao
|
||||
from repos.dao import DAO
|
||||
from utils.security import verify_password, create_access_token, ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, SECRET_KEY
|
||||
from jose import JWTError, jwt
|
||||
from starlette.requests import Request
|
||||
|
||||
@@ -12,7 +12,7 @@ from starlette.requests import Request
|
||||
from starlette.responses import Response, JSONResponse, StreamingResponse
|
||||
|
||||
from adapters.s3_adapter import S3Adapter
|
||||
from api.models.AssetDTO import AssetsResponse, AssetResponse
|
||||
from api.models import AssetsResponse, AssetResponse
|
||||
from models.Asset import Asset, AssetType, AssetContentType
|
||||
from repos.dao import DAO
|
||||
from api.dependency import get_dao, get_mongo_client, get_s3_adapter
|
||||
@@ -278,8 +278,7 @@ async def upload_asset(
|
||||
type=asset.type.value if hasattr(asset.type, "value") else asset.type,
|
||||
content_type=asset.content_type.value if hasattr(asset.content_type, "value") else asset.content_type,
|
||||
linked_char_id=asset.linked_char_id,
|
||||
created_at=asset.created_at,
|
||||
url=asset.url
|
||||
created_at=asset.created_at
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -5,11 +5,11 @@ from pydantic import BaseModel
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
|
||||
from api.models.AssetDTO import AssetsResponse, AssetResponse
|
||||
from api.models.GenerationRequest import GenerationRequest, GenerationResponse
|
||||
from api.models import AssetsResponse, AssetResponse
|
||||
from api.models import GenerationRequest, GenerationResponse
|
||||
from models.Asset import Asset
|
||||
from models.Character import Character
|
||||
from api.models.CharacterDTO import CharacterCreateRequest, CharacterUpdateRequest
|
||||
from api.models import CharacterCreateRequest, CharacterUpdateRequest
|
||||
from repos.dao import DAO
|
||||
from api.dependency import get_dao
|
||||
|
||||
@@ -24,8 +24,15 @@ router = APIRouter(prefix="/api/characters", tags=["Characters"], dependencies=[
|
||||
|
||||
|
||||
@router.get("/", response_model=List[Character])
|
||||
async def get_characters(request: Request, dao: DAO = Depends(get_dao), current_user: dict = Depends(get_current_user), project_id: Optional[str] = Depends(get_project_id)) -> List[Character]:
|
||||
logger.info("get_characters called")
|
||||
async def get_characters(
|
||||
request: Request,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
limit: int = 100,
|
||||
offset: int = 0
|
||||
) -> List[Character]:
|
||||
logger.info(f"get_characters called. Limit: {limit}, Offset: {offset}")
|
||||
|
||||
user_id_filter = str(current_user["_id"])
|
||||
if project_id:
|
||||
@@ -34,7 +41,12 @@ async def get_characters(request: Request, dao: DAO = Depends(get_dao), current_
|
||||
raise HTTPException(status_code=403, detail="Project access denied")
|
||||
user_id_filter = None
|
||||
|
||||
characters = await dao.chars.get_all_characters(created_by=user_id_filter, project_id=project_id)
|
||||
characters = await dao.chars.get_all_characters(
|
||||
created_by=user_id_filter,
|
||||
project_id=project_id,
|
||||
limit=limit,
|
||||
offset=offset
|
||||
)
|
||||
return characters
|
||||
|
||||
|
||||
@@ -178,10 +190,3 @@ async def delete_character(
|
||||
raise HTTPException(status_code=500, detail="Failed to delete character")
|
||||
|
||||
return
|
||||
|
||||
|
||||
@router.post("/{character_id}/_run", response_model=GenerationResponse)
|
||||
async def post_character_generation(character_id: str, generation: GenerationRequest,
|
||||
request: Request) -> GenerationResponse:
|
||||
logger.info(f"post_character_generation called. CharacterID: {character_id}")
|
||||
generation_service = request.app.state.generation_service
|
||||
|
||||
180
api/endpoints/environment_router.py
Normal file
180
api/endpoints/environment_router.py
Normal file
@@ -0,0 +1,180 @@
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from starlette import status
|
||||
|
||||
from api.dependency import get_dao
|
||||
from api.endpoints.auth import get_current_user
|
||||
from api.models.EnvironmentRequest import EnvironmentCreate, EnvironmentUpdate, AssetToEnvironment, AssetsToEnvironment
|
||||
from models.Environment import Environment
|
||||
from repos.dao import DAO
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/environments", tags=["Environments"], dependencies=[Depends(get_current_user)])
|
||||
|
||||
|
||||
async def check_character_access(character_id: str, current_user: dict, dao: DAO):
|
||||
character = await dao.chars.get_character(character_id)
|
||||
if not character:
|
||||
raise HTTPException(status_code=404, detail="Character not found")
|
||||
|
||||
is_creator = character.created_by == str(current_user["_id"])
|
||||
is_project_member = False
|
||||
if character.project_id and character.project_id in current_user.get("project_ids", []):
|
||||
is_project_member = True
|
||||
|
||||
if not is_creator and not is_project_member:
|
||||
raise HTTPException(status_code=403, detail="Access denied to character")
|
||||
return character
|
||||
|
||||
|
||||
@router.post("/", response_model=Environment)
|
||||
async def create_environment(
|
||||
env_req: EnvironmentCreate,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
logger.info(f"Creating environment '{env_req.name}' for character {env_req.character_id}")
|
||||
await check_character_access(env_req.character_id, current_user, dao)
|
||||
|
||||
# Verify assets exist if provided
|
||||
if env_req.asset_ids:
|
||||
for aid in env_req.asset_ids:
|
||||
asset = await dao.assets.get_asset(aid)
|
||||
if not asset:
|
||||
raise HTTPException(status_code=400, detail=f"Asset {aid} not found")
|
||||
|
||||
new_env = Environment(**env_req.model_dump())
|
||||
created_env = await dao.environments.create_env(new_env)
|
||||
return created_env
|
||||
|
||||
|
||||
@router.get("/character/{character_id}", response_model=List[Environment])
|
||||
async def get_character_environments(
|
||||
character_id: str,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
logger.info(f"Getting environments for character {character_id}")
|
||||
await check_character_access(character_id, current_user, dao)
|
||||
return await dao.environments.get_character_envs(character_id)
|
||||
|
||||
|
||||
@router.get("/{env_id}", response_model=Environment)
|
||||
async def get_environment(
|
||||
env_id: str,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
env = await dao.environments.get_env(env_id)
|
||||
if not env:
|
||||
raise HTTPException(status_code=404, detail="Environment not found")
|
||||
|
||||
await check_character_access(env.character_id, current_user, dao)
|
||||
return env
|
||||
|
||||
|
||||
@router.put("/{env_id}", response_model=Environment)
|
||||
async def update_environment(
|
||||
env_id: str,
|
||||
env_update: EnvironmentUpdate,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
env = await dao.environments.get_env(env_id)
|
||||
if not env:
|
||||
raise HTTPException(status_code=404, detail="Environment not found")
|
||||
|
||||
await check_character_access(env.character_id, current_user, dao)
|
||||
|
||||
update_data = env_update.model_dump(exclude_unset=True)
|
||||
if not update_data:
|
||||
return env
|
||||
|
||||
success = await dao.environments.update_env(env_id, update_data)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to update environment")
|
||||
|
||||
return await dao.environments.get_env(env_id)
|
||||
|
||||
|
||||
@router.delete("/{env_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_environment(
|
||||
env_id: str,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
env = await dao.environments.get_env(env_id)
|
||||
if not env:
|
||||
raise HTTPException(status_code=404, detail="Environment not found")
|
||||
|
||||
await check_character_access(env.character_id, current_user, dao)
|
||||
|
||||
success = await dao.environments.delete_env(env_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete environment")
|
||||
return None
|
||||
|
||||
|
||||
@router.post("/{env_id}/assets", status_code=status.HTTP_200_OK)
|
||||
async def add_asset_to_environment(
|
||||
env_id: str,
|
||||
req: AssetToEnvironment,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
env = await dao.environments.get_env(env_id)
|
||||
if not env:
|
||||
raise HTTPException(status_code=404, detail="Environment not found")
|
||||
|
||||
await check_character_access(env.character_id, current_user, dao)
|
||||
|
||||
# Verify asset exists
|
||||
asset = await dao.assets.get_asset(req.asset_id)
|
||||
if not asset:
|
||||
raise HTTPException(status_code=404, detail="Asset not found")
|
||||
|
||||
success = await dao.environments.add_asset(env_id, req.asset_id)
|
||||
return {"success": success}
|
||||
|
||||
|
||||
@router.post("/{env_id}/assets/batch", status_code=status.HTTP_200_OK)
|
||||
async def add_assets_batch_to_environment(
|
||||
env_id: str,
|
||||
req: AssetsToEnvironment,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
env = await dao.environments.get_env(env_id)
|
||||
if not env:
|
||||
raise HTTPException(status_code=404, detail="Environment not found")
|
||||
|
||||
await check_character_access(env.character_id, current_user, dao)
|
||||
|
||||
# Verify all assets exist
|
||||
assets = await dao.assets.get_assets_by_ids(req.asset_ids)
|
||||
if len(assets) != len(req.asset_ids):
|
||||
found_ids = {a.id for a in assets}
|
||||
missing_ids = [aid for aid in req.asset_ids if aid not in found_ids]
|
||||
raise HTTPException(status_code=404, detail=f"Some assets not found: {missing_ids}")
|
||||
|
||||
success = await dao.environments.add_assets(env_id, req.asset_ids)
|
||||
return {"success": success}
|
||||
|
||||
|
||||
@router.delete("/{env_id}/assets/{asset_id}", status_code=status.HTTP_200_OK)
|
||||
async def remove_asset_from_environment(
|
||||
env_id: str,
|
||||
asset_id: str,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
env = await dao.environments.get_env(env_id)
|
||||
if not env:
|
||||
raise HTTPException(status_code=404, detail="Environment not found")
|
||||
|
||||
await check_character_access(env.character_id, current_user, dao)
|
||||
|
||||
success = await dao.environments.remove_asset(env_id, asset_id)
|
||||
return {"success": success}
|
||||
@@ -1,25 +1,32 @@
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, UploadFile, File, Form, Header, HTTPException
|
||||
from fastapi.params import Depends
|
||||
from starlette import status
|
||||
from starlette.requests import Request
|
||||
|
||||
from api import service
|
||||
from config import settings
|
||||
from api.dependency import get_generation_service, get_project_id, get_dao
|
||||
from repos.dao import DAO
|
||||
|
||||
from api.models.GenerationRequest import GenerationResponse, GenerationRequest, GenerationsResponse, PromptResponse, PromptRequest, GenerationGroupResponse
|
||||
from api.endpoints.auth import get_current_user
|
||||
from api.models import (
|
||||
GenerationResponse,
|
||||
GenerationRequest,
|
||||
GenerationsResponse,
|
||||
PromptResponse,
|
||||
PromptRequest,
|
||||
GenerationGroupResponse,
|
||||
FinancialReport,
|
||||
ExternalGenerationRequest
|
||||
)
|
||||
from api.service.generation_service import GenerationService
|
||||
from models.Generation import Generation
|
||||
|
||||
from starlette import status
|
||||
|
||||
import logging
|
||||
from repos.dao import DAO
|
||||
from utils.external_auth import verify_signature
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from api.endpoints.auth import get_current_user
|
||||
|
||||
router = APIRouter(prefix='/api/generations', tags=["Generation"])
|
||||
|
||||
|
||||
@@ -68,6 +75,47 @@ async def get_generations(character_id: Optional[str] = None, limit: int = 10, o
|
||||
return await generation_service.get_generations(character_id, limit=limit, offset=offset, user_id=user_id_filter, project_id=project_id)
|
||||
|
||||
|
||||
@router.get("/usage", response_model=FinancialReport)
|
||||
async def get_usage_report(
|
||||
breakdown: Optional[str] = None, # "user" or "project"
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
dao: DAO = Depends(get_dao)
|
||||
) -> FinancialReport:
|
||||
"""
|
||||
Returns usage statistics (runs, tokens, cost) for the current user or project.
|
||||
If project_id is provided, returns stats for that project.
|
||||
Otherwise, returns stats for the current user.
|
||||
"""
|
||||
user_id_filter = str(current_user["_id"])
|
||||
breakdown_by = None
|
||||
|
||||
if project_id:
|
||||
# Permission check
|
||||
project = await dao.projects.get_project(project_id)
|
||||
if not project or str(current_user["_id"]) not in project.members:
|
||||
raise HTTPException(status_code=403, detail="Project access denied")
|
||||
user_id_filter = None # If we are in project, we see stats for the WHOLE project by default
|
||||
if breakdown == "user":
|
||||
breakdown_by = "created_by"
|
||||
elif breakdown == "project":
|
||||
breakdown_by = "project_id"
|
||||
else:
|
||||
# Default: Stats for current user
|
||||
if breakdown == "project":
|
||||
breakdown_by = "project_id"
|
||||
elif breakdown == "user":
|
||||
# This would breakdown personal usage by user (yourself), but could be useful if it included collaborators?
|
||||
# No, if project_id is None, it's personal.
|
||||
breakdown_by = "created_by"
|
||||
|
||||
return await generation_service.get_financial_report(
|
||||
user_id=user_id_filter,
|
||||
project_id=project_id,
|
||||
breakdown_by=breakdown_by
|
||||
)
|
||||
|
||||
@router.post("/_run", response_model=GenerationGroupResponse)
|
||||
async def post_generation(generation: GenerationRequest, request: Request,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
@@ -120,7 +168,15 @@ async def get_generation(generation_id: str,
|
||||
logger.debug(f"get_generation called for ID: {generation_id}")
|
||||
gen = await generation_service.get_generation(generation_id)
|
||||
if gen and gen.created_by != str(current_user["_id"]):
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
# Check project membership
|
||||
is_member = False
|
||||
if gen.project_id:
|
||||
project = await generation_service.dao.projects.get_project(gen.project_id)
|
||||
if project and str(current_user["_id"]) in project.members:
|
||||
is_member = True
|
||||
|
||||
if not is_member:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
return gen
|
||||
|
||||
|
||||
@@ -136,17 +192,13 @@ async def import_external_generation(
|
||||
Import a generation from an external source.
|
||||
Requires server-to-server authentication via HMAC signature.
|
||||
"""
|
||||
import os
|
||||
from utils.external_auth import verify_signature
|
||||
from api.models.ExternalGenerationDTO import ExternalGenerationRequest
|
||||
|
||||
logger.info("import_external_generation called")
|
||||
|
||||
# Get raw request body for signature verification
|
||||
body = await request.body()
|
||||
|
||||
# Verify signature
|
||||
secret = os.getenv("EXTERNAL_API_SECRET")
|
||||
secret = settings.EXTERNAL_API_SECRET
|
||||
if not secret:
|
||||
logger.error("EXTERNAL_API_SECRET not configured")
|
||||
raise HTTPException(status_code=500, detail="Server configuration error")
|
||||
@@ -156,7 +208,6 @@ async def import_external_generation(
|
||||
raise HTTPException(status_code=401, detail="Invalid signature")
|
||||
|
||||
# Parse request body
|
||||
import json
|
||||
try:
|
||||
data = json.loads(body.decode('utf-8'))
|
||||
external_gen = ExternalGenerationRequest(**data)
|
||||
|
||||
@@ -5,7 +5,8 @@ from api.endpoints.auth import get_current_user
|
||||
from api.service.idea_service import IdeaService
|
||||
from api.service.generation_service import GenerationService
|
||||
from models.Idea import Idea
|
||||
from api.models.GenerationRequest import GenerationResponse, GenerationsResponse
|
||||
from api.models import GenerationResponse, GenerationsResponse
|
||||
from api.models import IdeaRequest, PostRequest # Adjusting for general model usage
|
||||
from api.models.IdeaRequest import IdeaCreateRequest, IdeaUpdateRequest, IdeaResponse
|
||||
|
||||
router = APIRouter(prefix="/api/ideas", tags=["ideas"])
|
||||
|
||||
99
api/endpoints/post_router.py
Normal file
99
api/endpoints/post_router.py
Normal file
@@ -0,0 +1,99 @@
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from api.dependency import get_post_service, get_project_id
|
||||
from api.endpoints.auth import get_current_user
|
||||
from api.service.post_service import PostService
|
||||
from api.models import PostRequest, PostCreateRequest, PostUpdateRequest, AddGenerationsRequest
|
||||
from models.Post import Post
|
||||
|
||||
router = APIRouter(prefix="/api/posts", tags=["posts"])
|
||||
|
||||
|
||||
@router.post("", response_model=Post)
|
||||
async def create_post(
|
||||
request: PostCreateRequest,
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
current_user: dict = Depends(get_current_user),
|
||||
post_service: PostService = Depends(get_post_service),
|
||||
):
|
||||
pid = project_id or request.project_id
|
||||
return await post_service.create_post(
|
||||
date=request.date,
|
||||
topic=request.topic,
|
||||
generation_ids=request.generation_ids,
|
||||
project_id=pid,
|
||||
user_id=str(current_user["_id"]),
|
||||
)
|
||||
|
||||
|
||||
@router.get("", response_model=List[Post])
|
||||
async def get_posts(
|
||||
project_id: Optional[str] = Depends(get_project_id),
|
||||
limit: int = 200,
|
||||
offset: int = 0,
|
||||
date_from: Optional[datetime] = None,
|
||||
date_to: Optional[datetime] = None,
|
||||
current_user: dict = Depends(get_current_user),
|
||||
post_service: PostService = Depends(get_post_service),
|
||||
):
|
||||
return await post_service.get_posts(project_id, str(current_user["_id"]), limit, offset, date_from, date_to)
|
||||
|
||||
|
||||
@router.get("/{post_id}", response_model=Post)
|
||||
async def get_post(
|
||||
post_id: str,
|
||||
post_service: PostService = Depends(get_post_service),
|
||||
):
|
||||
post = await post_service.get_post(post_id)
|
||||
if not post:
|
||||
raise HTTPException(status_code=404, detail="Post not found")
|
||||
return post
|
||||
|
||||
|
||||
@router.put("/{post_id}", response_model=Post)
|
||||
async def update_post(
|
||||
post_id: str,
|
||||
request: PostUpdateRequest,
|
||||
post_service: PostService = Depends(get_post_service),
|
||||
):
|
||||
post = await post_service.update_post(post_id, date=request.date, topic=request.topic)
|
||||
if not post:
|
||||
raise HTTPException(status_code=404, detail="Post not found")
|
||||
return post
|
||||
|
||||
|
||||
@router.delete("/{post_id}")
|
||||
async def delete_post(
|
||||
post_id: str,
|
||||
post_service: PostService = Depends(get_post_service),
|
||||
):
|
||||
success = await post_service.delete_post(post_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Post not found or could not be deleted")
|
||||
return {"status": "success"}
|
||||
|
||||
|
||||
@router.post("/{post_id}/generations")
|
||||
async def add_generations(
|
||||
post_id: str,
|
||||
request: AddGenerationsRequest,
|
||||
post_service: PostService = Depends(get_post_service),
|
||||
):
|
||||
success = await post_service.add_generations(post_id, request.generation_ids)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Post not found")
|
||||
return {"status": "success"}
|
||||
|
||||
|
||||
@router.delete("/{post_id}/generations/{generation_id}")
|
||||
async def remove_generation(
|
||||
post_id: str,
|
||||
generation_id: str,
|
||||
post_service: PostService = Depends(get_post_service),
|
||||
):
|
||||
success = await post_service.remove_generation(post_id, generation_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Post not found or generation not linked")
|
||||
return {"status": "success"}
|
||||
@@ -1,4 +1,6 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from bson import ObjectId
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel
|
||||
from api.dependency import get_dao
|
||||
@@ -12,14 +14,46 @@ class ProjectCreate(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
|
||||
class ProjectMemberResponse(BaseModel):
|
||||
id: str
|
||||
username: str
|
||||
|
||||
class ProjectResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
owner_id: str
|
||||
members: List[str]
|
||||
members: List[ProjectMemberResponse]
|
||||
is_owner: bool = False
|
||||
|
||||
async def _get_project_response(project: Project, current_user_id: str, dao: DAO) -> ProjectResponse:
|
||||
member_responses = []
|
||||
for member_id in project.members:
|
||||
# We need a way to get user by ID. Let's check UsersRepo for get_user by ObjectId or string.
|
||||
# Currently UsersRepo has get_user(user_id: int) for Telegram IDs.
|
||||
# But for Web users we might need to search by _id.
|
||||
# Let's try to get user info.
|
||||
# Since project.members contains strings (ObjectIds as strings), we search by _id.
|
||||
user_doc = await dao.users.collection.find_one({"_id": ObjectId(member_id)})
|
||||
if not user_doc and member_id.isdigit():
|
||||
# Fallback for telegram IDs if they are stored as strings of digits
|
||||
user_doc = await dao.users.get_user(int(member_id))
|
||||
|
||||
username = "unknown"
|
||||
if user_doc:
|
||||
username = user_doc.get("username", "unknown")
|
||||
|
||||
member_responses.append(ProjectMemberResponse(id=member_id, username=username))
|
||||
|
||||
return ProjectResponse(
|
||||
id=project.id,
|
||||
name=project.name,
|
||||
description=project.description,
|
||||
owner_id=project.owner_id,
|
||||
members=member_responses,
|
||||
is_owner=(project.owner_id == current_user_id)
|
||||
)
|
||||
|
||||
@router.post("", response_model=ProjectResponse)
|
||||
async def create_project(
|
||||
project_data: ProjectCreate,
|
||||
@@ -34,27 +68,15 @@ async def create_project(
|
||||
members=[user_id]
|
||||
)
|
||||
project_id = await dao.projects.create_project(new_project)
|
||||
new_project.id = project_id
|
||||
|
||||
# Add project to user's project list
|
||||
# Assuming user_repo has a method to add project or we do it directly?
|
||||
# UserRepo doesn't have add_project method yet.
|
||||
# But since UserRepo is just a wrapper around collection, lets add it here or update UserRepo later?
|
||||
# Better to update UserRepo. For now, let's just return success.
|
||||
# But user needs to see it in list.
|
||||
# Update user in DB
|
||||
await dao.users.collection.update_one(
|
||||
{"_id": current_user["_id"]},
|
||||
{"$addToSet": {"project_ids": project_id}}
|
||||
)
|
||||
|
||||
return ProjectResponse(
|
||||
id=project_id,
|
||||
name=new_project.name,
|
||||
description=new_project.description,
|
||||
owner_id=new_project.owner_id,
|
||||
members=new_project.members,
|
||||
is_owner=True
|
||||
)
|
||||
return await _get_project_response(new_project, user_id, dao)
|
||||
|
||||
@router.get("", response_model=List[ProjectResponse])
|
||||
async def get_my_projects(
|
||||
@@ -66,14 +88,7 @@ async def get_my_projects(
|
||||
|
||||
responses = []
|
||||
for p in projects:
|
||||
responses.append(ProjectResponse(
|
||||
id=p.id,
|
||||
name=p.name,
|
||||
description=p.description,
|
||||
owner_id=p.owner_id,
|
||||
members=p.members,
|
||||
is_owner=(p.owner_id == user_id)
|
||||
))
|
||||
responses.append(await _get_project_response(p, user_id, dao))
|
||||
return responses
|
||||
|
||||
class MemberAdd(BaseModel):
|
||||
|
||||
22
api/models/EnvironmentRequest.py
Normal file
22
api/models/EnvironmentRequest.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class EnvironmentCreate(BaseModel):
|
||||
character_id: str
|
||||
name: str = Field(..., min_length=1)
|
||||
description: Optional[str] = None
|
||||
asset_ids: Optional[List[str]] = []
|
||||
|
||||
|
||||
class EnvironmentUpdate(BaseModel):
|
||||
name: Optional[str] = Field(None, min_length=1)
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class AssetToEnvironment(BaseModel):
|
||||
asset_id: str
|
||||
|
||||
|
||||
class AssetsToEnvironment(BaseModel):
|
||||
asset_ids: List[str]
|
||||
18
api/models/FinancialUsageDTO.py
Normal file
18
api/models/FinancialUsageDTO.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
|
||||
class UsageStats(BaseModel):
|
||||
total_runs: int
|
||||
total_tokens: int
|
||||
total_input_tokens: int
|
||||
total_output_tokens: int
|
||||
total_cost: float
|
||||
|
||||
class UsageByEntity(BaseModel):
|
||||
entity_id: Optional[str] = None
|
||||
stats: UsageStats
|
||||
|
||||
class FinancialReport(BaseModel):
|
||||
summary: UsageStats
|
||||
by_user: Optional[List[UsageByEntity]] = None
|
||||
by_project: Optional[List[UsageByEntity]] = None
|
||||
@@ -16,6 +16,7 @@ class GenerationRequest(BaseModel):
|
||||
telegram_id: Optional[int] = None
|
||||
use_profile_image: bool = True
|
||||
assets_list: List[str]
|
||||
environment_id: Optional[str] = None
|
||||
project_id: Optional[str] = None
|
||||
idea_id: Optional[str] = None
|
||||
count: int = Field(default=1, ge=1, le=10)
|
||||
|
||||
19
api/models/PostRequest.py
Normal file
19
api/models/PostRequest.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class PostCreateRequest(BaseModel):
|
||||
date: datetime
|
||||
topic: str
|
||||
generation_ids: List[str] = []
|
||||
project_id: Optional[str] = None
|
||||
|
||||
|
||||
class PostUpdateRequest(BaseModel):
|
||||
date: Optional[datetime] = None
|
||||
topic: Optional[str] = None
|
||||
|
||||
|
||||
class AddGenerationsRequest(BaseModel):
|
||||
generation_ids: List[str]
|
||||
@@ -0,0 +1,7 @@
|
||||
from .AssetDTO import AssetResponse, AssetsResponse
|
||||
from .CharacterDTO import CharacterCreateRequest, CharacterUpdateRequest
|
||||
from .ExternalGenerationDTO import ExternalGenerationRequest
|
||||
from .FinancialUsageDTO import FinancialReport, UsageStats, UsageByEntity
|
||||
from .GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse, PromptRequest, PromptResponse
|
||||
from .IdeaRequest import IdeaCreateRequest, IdeaUpdateRequest, IdeaResponse
|
||||
from .PostRequest import PostCreateRequest, PostUpdateRequest, AddGenerationsRequest
|
||||
|
||||
@@ -1,24 +1,26 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import random
|
||||
import base64
|
||||
from datetime import datetime, UTC
|
||||
from typing import List, Optional, Tuple, Any, Dict
|
||||
from io import BytesIO
|
||||
from uuid import uuid4
|
||||
import httpx
|
||||
|
||||
import httpx
|
||||
from aiogram import Bot
|
||||
from aiogram.types import BufferedInputFile
|
||||
from fastapi import HTTPException
|
||||
|
||||
from adapters.Exception import GoogleGenerationException
|
||||
from adapters.google_adapter import GoogleAdapter
|
||||
from api.models.GenerationRequest import GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse
|
||||
from adapters.s3_adapter import S3Adapter
|
||||
from api.models import FinancialReport, UsageStats, UsageByEntity
|
||||
from api.models import GenerationRequest, GenerationResponse, GenerationsResponse, GenerationGroupResponse
|
||||
# Импортируйте ваши модели DAO, Asset, Generation корректно
|
||||
from models.Asset import Asset, AssetType, AssetContentType
|
||||
from models.Generation import Generation, GenerationStatus
|
||||
from models.enums import AspectRatios, Quality, GenType
|
||||
from models.enums import AspectRatios, Quality
|
||||
from repos.dao import DAO
|
||||
from adapters.s3_adapter import S3Adapter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -132,6 +134,9 @@ class GenerationService:
|
||||
gen_id = None
|
||||
generation_model = None
|
||||
|
||||
if generation_request.environment_id and not generation_request.linked_character_id:
|
||||
raise HTTPException(status_code=400, detail="environment_id can only be used when linked_character_id is provided")
|
||||
|
||||
try:
|
||||
generation_model = Generation(**generation_request.model_dump(exclude={'count'}))
|
||||
if user_id:
|
||||
@@ -185,45 +190,40 @@ class GenerationService:
|
||||
logger.info(f"Processing generation {generation.id}. Character ID: {generation.linked_character_id}")
|
||||
|
||||
# 2. Получаем ассеты-референсы (если они есть)
|
||||
reference_assets: List[Asset] = []
|
||||
media_group_bytes: List[bytes] = []
|
||||
generation_prompt = generation.prompt
|
||||
# generation_prompt = f"""
|
||||
|
||||
# Create detailed image of character in scene.
|
||||
|
||||
# SCENE DESCRIPTION: {generation.prompt}
|
||||
|
||||
# Rules:
|
||||
# - Integrate the character's appearance naturally into the scene description.
|
||||
# - Focus on lighting, texture, and composition.
|
||||
# """
|
||||
# 2.1 Аватар персонажа (всегда первый, если включен)
|
||||
if generation.linked_character_id is not None:
|
||||
char_info = await self.dao.chars.get_character(generation.linked_character_id)
|
||||
if char_info is None:
|
||||
raise Exception(f"Character ID {generation.linked_character_id} not found")
|
||||
if generation.use_profile_image:
|
||||
if char_info.avatar_asset_id is not None:
|
||||
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
|
||||
if avatar_asset and avatar_asset.data:
|
||||
media_group_bytes.append(avatar_asset.data)
|
||||
# generation_prompt = generation_prompt.replace("$char_bio_inserted", f"1. CHARACTER BIO (Must be strictly followed): {char_info.character_bio}")
|
||||
|
||||
reference_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
|
||||
if generation.use_profile_image and char_info.avatar_asset_id:
|
||||
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
|
||||
if avatar_asset:
|
||||
img_data = await self._get_asset_data(avatar_asset)
|
||||
if img_data:
|
||||
media_group_bytes.append(img_data)
|
||||
|
||||
# Извлекаем данные (bytes) из ассетов для отправки в Gemini
|
||||
for asset in reference_assets:
|
||||
if asset.content_type != AssetContentType.IMAGE:
|
||||
continue
|
||||
# 2.2 Явно указанные ассеты
|
||||
if generation.assets_list:
|
||||
explicit_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
|
||||
for asset in explicit_assets:
|
||||
ref_asset_data = await self._get_asset_data(asset)
|
||||
if ref_asset_data:
|
||||
media_group_bytes.append(ref_asset_data)
|
||||
|
||||
img_data = None
|
||||
if asset.minio_object_name:
|
||||
img_data = await self.s3_adapter.get_file(asset.minio_object_name)
|
||||
elif asset.data:
|
||||
img_data = asset.data
|
||||
|
||||
if img_data:
|
||||
media_group_bytes.append(img_data)
|
||||
# 2.3 Ассеты из окружения (в самый конец)
|
||||
if generation.environment_id:
|
||||
env = await self.dao.environments.get_env(generation.environment_id)
|
||||
if env and env.asset_ids:
|
||||
logger.info(f"Loading {len(env.asset_ids)} assets from environment {env.name} ({env.id})")
|
||||
env_assets = await self.dao.assets.get_assets_by_ids(env.asset_ids)
|
||||
for asset in env_assets:
|
||||
img_data = await self._get_asset_data(asset)
|
||||
if img_data:
|
||||
media_group_bytes.append(img_data)
|
||||
|
||||
if media_group_bytes:
|
||||
generation_prompt += " \n\n[Reference Image Guidance]: Use the provided image(s) as the STRICT reference for the main character's facial features and hair, enviroment or clothes. Maintain high fidelity to the reference identity."
|
||||
@@ -340,6 +340,14 @@ class GenerationService:
|
||||
logger.error(f"Failed to send assets to Telegram ID {generation.telegram_id}: {e}")
|
||||
|
||||
|
||||
async def _get_asset_data(self, asset: Asset) -> Optional[bytes]:
|
||||
if asset.content_type != AssetContentType.IMAGE:
|
||||
return None
|
||||
|
||||
if asset.minio_object_name:
|
||||
return await self.s3_adapter.get_file(asset.minio_object_name)
|
||||
return asset.data
|
||||
|
||||
async def _simulate_progress(self, generation: Generation):
|
||||
"""
|
||||
Increments progress from 0 to 90 over ~20 seconds.
|
||||
@@ -377,7 +385,6 @@ class GenerationService:
|
||||
Returns:
|
||||
Created Generation object
|
||||
"""
|
||||
from api.models.ExternalGenerationDTO import ExternalGenerationRequest
|
||||
|
||||
# Validate image source
|
||||
external_gen.validate_image_source()
|
||||
@@ -507,3 +514,28 @@ class GenerationService:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during old data cleanup: {e}")
|
||||
|
||||
async def get_financial_report(self, user_id: Optional[str] = None, project_id: Optional[str] = None, breakdown_by: Optional[str] = None) -> FinancialReport:
|
||||
"""
|
||||
Generates a financial usage report for a specific user or project.
|
||||
'breakdown_by' can be 'created_by' or 'project_id'.
|
||||
"""
|
||||
summary_data = await self.dao.generations.get_usage_stats(created_by=user_id, project_id=project_id)
|
||||
summary = UsageStats(**summary_data)
|
||||
|
||||
by_user = None
|
||||
by_project = None
|
||||
|
||||
if breakdown_by == "created_by":
|
||||
res = await self.dao.generations.get_usage_breakdown(group_by="created_by", project_id=project_id, created_by=user_id)
|
||||
by_user = [UsageByEntity(**item) for item in res]
|
||||
|
||||
if breakdown_by == "project_id":
|
||||
res = await self.dao.generations.get_usage_breakdown(group_by="project_id", project_id=project_id, created_by=user_id)
|
||||
by_project = [UsageByEntity(**item) for item in res]
|
||||
|
||||
return FinancialReport(
|
||||
summary=summary,
|
||||
by_user=by_user,
|
||||
by_project=by_project
|
||||
)
|
||||
79
api/service/post_service.py
Normal file
79
api/service/post_service.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from typing import List, Optional
|
||||
from datetime import datetime, UTC
|
||||
|
||||
from repos.dao import DAO
|
||||
from models.Post import Post
|
||||
|
||||
|
||||
class PostService:
|
||||
def __init__(self, dao: DAO):
|
||||
self.dao = dao
|
||||
|
||||
async def create_post(
|
||||
self,
|
||||
date: datetime,
|
||||
topic: str,
|
||||
generation_ids: List[str],
|
||||
project_id: Optional[str],
|
||||
user_id: str,
|
||||
) -> Post:
|
||||
post = Post(
|
||||
date=date,
|
||||
topic=topic,
|
||||
generation_ids=generation_ids,
|
||||
project_id=project_id,
|
||||
created_by=user_id,
|
||||
)
|
||||
post_id = await self.dao.posts.create_post(post)
|
||||
post.id = post_id
|
||||
return post
|
||||
|
||||
async def get_post(self, post_id: str) -> Optional[Post]:
|
||||
return await self.dao.posts.get_post(post_id)
|
||||
|
||||
async def get_posts(
|
||||
self,
|
||||
project_id: Optional[str],
|
||||
user_id: str,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
date_from: Optional[datetime] = None,
|
||||
date_to: Optional[datetime] = None,
|
||||
) -> List[Post]:
|
||||
return await self.dao.posts.get_posts(project_id, user_id, limit, offset, date_from, date_to)
|
||||
|
||||
async def update_post(
|
||||
self,
|
||||
post_id: str,
|
||||
date: Optional[datetime] = None,
|
||||
topic: Optional[str] = None,
|
||||
) -> Optional[Post]:
|
||||
post = await self.dao.posts.get_post(post_id)
|
||||
if not post:
|
||||
return None
|
||||
|
||||
updates: dict = {"updated_at": datetime.now(UTC)}
|
||||
if date is not None:
|
||||
updates["date"] = date
|
||||
if topic is not None:
|
||||
updates["topic"] = topic
|
||||
|
||||
await self.dao.posts.update_post(post_id, updates)
|
||||
|
||||
# Return refreshed post
|
||||
return await self.dao.posts.get_post(post_id)
|
||||
|
||||
async def delete_post(self, post_id: str) -> bool:
|
||||
return await self.dao.posts.delete_post(post_id)
|
||||
|
||||
async def add_generations(self, post_id: str, generation_ids: List[str]) -> bool:
|
||||
post = await self.dao.posts.get_post(post_id)
|
||||
if not post:
|
||||
return False
|
||||
return await self.dao.posts.add_generations(post_id, generation_ids)
|
||||
|
||||
async def remove_generation(self, post_id: str, generation_id: str) -> bool:
|
||||
post = await self.dao.posts.get_post(post_id)
|
||||
if not post:
|
||||
return False
|
||||
return await self.dao.posts.remove_generation(post_id, generation_id)
|
||||
39
config.py
Normal file
39
config.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
# Telegram Bot
|
||||
BOT_TOKEN: str
|
||||
ADMIN_ID: int = 0
|
||||
|
||||
# AI Service
|
||||
GEMINI_API_KEY: str
|
||||
|
||||
# Database
|
||||
MONGO_HOST: str = "mongodb://localhost:27017"
|
||||
DB_NAME: str = "my_bot_db"
|
||||
|
||||
# S3 Storage (Minio)
|
||||
MINIO_ENDPOINT: str = "http://localhost:9000"
|
||||
MINIO_ACCESS_KEY: str = "minioadmin"
|
||||
MINIO_SECRET_KEY: str = "minioadmin"
|
||||
MINIO_BUCKET: str = "ai-char"
|
||||
|
||||
# External API
|
||||
EXTERNAL_API_SECRET: Optional[str] = None
|
||||
|
||||
# JWT Security
|
||||
SECRET_KEY: str = "CHANGE_ME_TO_A_SUPER_SECRET_KEY"
|
||||
ALGORITHM: str = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 * 24 * 60 # 30 days
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=os.getenv("ENV_FILE", ".env"),
|
||||
env_file_encoding="utf-8",
|
||||
extra="ignore"
|
||||
)
|
||||
|
||||
|
||||
settings = Settings()
|
||||
@@ -27,19 +27,19 @@ class AlbumMiddleware(BaseMiddleware):
|
||||
# Ждем сбора остальных частей
|
||||
await asyncio.sleep(self.latency)
|
||||
|
||||
# Проверяем, что ключ все еще существует (на всякий случай)
|
||||
# Проверяем, что ключ все еще существует
|
||||
if group_id in self.album_data:
|
||||
# Передаем собранный альбом в хендлер
|
||||
# Сортируем по message_id, чтобы порядок был верным
|
||||
self.album_data[group_id].sort(key=lambda x: x.message_id)
|
||||
data["album"] = self.album_data[group_id]
|
||||
current_album = self.album_data[group_id]
|
||||
current_album.sort(key=lambda x: x.message_id)
|
||||
data["album"] = current_album
|
||||
return await handler(event, data)
|
||||
|
||||
finally:
|
||||
# ЧИСТКА: Удаляем всегда, если это "головной" поток, который создал запись
|
||||
# Проверяем, что мы удаляем именно то, что создали, и ключ существует
|
||||
if group_id in self.album_data and self.album_data[group_id][0] == event:
|
||||
del self.album_data[group_id]
|
||||
# ЧИСТКА: Удаляем запись после обработки или таймаута
|
||||
# Используем pop() с дефолтом, чтобы избежать KeyError
|
||||
self.album_data.pop(group_id, None)
|
||||
|
||||
else:
|
||||
# Если группа уже собирается - просто добавляем и выходим
|
||||
|
||||
@@ -63,6 +63,7 @@ class Asset(BaseModel):
|
||||
|
||||
# --- CALCULATED FIELD ---
|
||||
@computed_field
|
||||
@property
|
||||
def url(self) -> str:
|
||||
"""
|
||||
Это поле автоматически вычислится и попадет в model_dump() / .json()
|
||||
|
||||
@@ -9,7 +9,6 @@ class Character(BaseModel):
|
||||
name: str
|
||||
avatar_asset_id: Optional[str] = None
|
||||
avatar_image: Optional[str] = None
|
||||
character_image_data: Optional[bytes] = None
|
||||
character_image_doc_tg_id: Optional[str] = None
|
||||
character_image_tg_id: Optional[str] = None
|
||||
character_bio: Optional[str] = None
|
||||
|
||||
20
models/Environment.py
Normal file
20
models/Environment.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from datetime import datetime
|
||||
from bson import ObjectId
|
||||
|
||||
|
||||
class Environment(BaseModel):
|
||||
id: Optional[str] = Field(None, alias="_id")
|
||||
character_id: str
|
||||
name: str = Field(..., min_length=1)
|
||||
description: Optional[str] = None
|
||||
asset_ids: List[str] = Field(default_factory=list)
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
json_encoders={ObjectId: str},
|
||||
arbitrary_types_allowed=True
|
||||
)
|
||||
@@ -35,6 +35,7 @@ class Generation(BaseModel):
|
||||
output_token_usage: Optional[int] = None
|
||||
is_deleted: bool = False
|
||||
album_id: Optional[str] = None
|
||||
environment_id: Optional[str] = None
|
||||
generation_group_id: Optional[str] = None
|
||||
created_by: Optional[str] = None # Stores User ID (Telegram ID or Web User ObjectId)
|
||||
project_id: Optional[str] = None
|
||||
|
||||
23
models/Post.py
Normal file
23
models/Post.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from datetime import datetime, timezone, UTC
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
class Post(BaseModel):
|
||||
id: Optional[str] = None
|
||||
date: datetime
|
||||
topic: str
|
||||
generation_ids: List[str] = Field(default_factory=list)
|
||||
project_id: Optional[str] = None
|
||||
created_by: str
|
||||
is_deleted: bool = False
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
|
||||
@model_validator(mode="after")
|
||||
def ensure_tz_aware(self):
|
||||
for field in ("date", "created_at", "updated_at"):
|
||||
val = getattr(self, field)
|
||||
if val is not None and val.tzinfo is None:
|
||||
setattr(self, field, val.replace(tzinfo=timezone.utc))
|
||||
return self
|
||||
Binary file not shown.
@@ -175,6 +175,8 @@ class AssetsRepo:
|
||||
filter["linked_char_id"] = character_id
|
||||
if created_by:
|
||||
filter["created_by"] = created_by
|
||||
if project_id is None:
|
||||
filter["project_id"] = None
|
||||
if project_id:
|
||||
filter["project_id"] = project_id
|
||||
return await self.collection.count_documents(filter)
|
||||
|
||||
@@ -15,26 +15,24 @@ class CharacterRepo:
|
||||
character.id = str(op.inserted_id)
|
||||
return character
|
||||
|
||||
async def get_character(self, character_id: str, with_image_data: bool = False) -> Character | None:
|
||||
args = {}
|
||||
if not with_image_data:
|
||||
args["character_image_data"] = 0
|
||||
res = await self.collection.find_one({"_id": ObjectId(character_id)}, args)
|
||||
async def get_character(self, character_id: str) -> Character | None:
|
||||
res = await self.collection.find_one({"_id": ObjectId(character_id)})
|
||||
if res is None:
|
||||
return None
|
||||
else:
|
||||
res["id"] = str(res.pop("_id"))
|
||||
return Character(**res)
|
||||
|
||||
async def get_all_characters(self, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Character]:
|
||||
async def get_all_characters(self, created_by: Optional[str] = None, project_id: Optional[str] = None, limit: int = 100, offset: int = 0) -> List[Character]:
|
||||
filter = {}
|
||||
if created_by:
|
||||
filter["created_by"] = created_by
|
||||
if project_id is None:
|
||||
filter["project_id"] = None
|
||||
if project_id:
|
||||
filter["project_id"] = project_id
|
||||
|
||||
args = {"character_image_data": 0} # don't return image data for list
|
||||
res = await self.collection.find(filter, args).to_list(None)
|
||||
res = await self.collection.find(filter).skip(offset).limit(limit).to_list(None)
|
||||
chars = []
|
||||
for doc in res:
|
||||
doc["id"] = str(doc.pop("_id"))
|
||||
|
||||
@@ -7,6 +7,8 @@ from repos.user_repo import UsersRepo
|
||||
from repos.albums_repo import AlbumsRepo
|
||||
from repos.project_repo import ProjectRepo
|
||||
from repos.idea_repo import IdeaRepo
|
||||
from repos.post_repo import PostRepo
|
||||
from repos.environment_repo import EnvironmentRepo
|
||||
|
||||
|
||||
from typing import Optional
|
||||
@@ -21,3 +23,5 @@ class DAO:
|
||||
self.projects = ProjectRepo(client, db_name)
|
||||
self.users = UsersRepo(client, db_name)
|
||||
self.ideas = IdeaRepo(client, db_name)
|
||||
self.posts = PostRepo(client, db_name)
|
||||
self.environments = EnvironmentRepo(client, db_name)
|
||||
|
||||
73
repos/environment_repo.py
Normal file
73
repos/environment_repo.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from bson import ObjectId
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from models.Environment import Environment
|
||||
|
||||
|
||||
class EnvironmentRepo:
|
||||
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
|
||||
self.collection = client[db_name]["environments"]
|
||||
|
||||
async def create_env(self, env: Environment) -> Environment:
|
||||
env_dict = env.model_dump(exclude={"id"})
|
||||
res = await self.collection.insert_one(env_dict)
|
||||
env.id = str(res.inserted_id)
|
||||
return env
|
||||
|
||||
async def get_env(self, env_id: str) -> Optional[Environment]:
|
||||
res = await self.collection.find_one({"_id": ObjectId(env_id)})
|
||||
if not res:
|
||||
return None
|
||||
res["id"] = str(res.pop("_id"))
|
||||
return Environment(**res)
|
||||
|
||||
async def get_character_envs(self, character_id: str) -> List[Environment]:
|
||||
cursor = self.collection.find({"character_id": character_id})
|
||||
envs = []
|
||||
async for doc in cursor:
|
||||
doc["id"] = str(doc.pop("_id"))
|
||||
envs.append(Environment(**doc))
|
||||
return envs
|
||||
|
||||
async def update_env(self, env_id: str, update_data: dict) -> bool:
|
||||
update_data["updated_at"] = datetime.utcnow()
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(env_id)},
|
||||
{"$set": update_data}
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def delete_env(self, env_id: str) -> bool:
|
||||
res = await self.collection.delete_one({"_id": ObjectId(env_id)})
|
||||
return res.deleted_count > 0
|
||||
|
||||
async def add_asset(self, env_id: str, asset_id: str) -> bool:
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(env_id)},
|
||||
{
|
||||
"$addToSet": {"asset_ids": asset_id},
|
||||
"$set": {"updated_at": datetime.utcnow()}
|
||||
}
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def add_assets(self, env_id: str, asset_ids: List[str]) -> bool:
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(env_id)},
|
||||
{
|
||||
"$addToSet": {"asset_ids": {"$each": asset_ids}},
|
||||
"$set": {"updated_at": datetime.utcnow()}
|
||||
}
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def remove_asset(self, env_id: str, asset_id: str) -> bool:
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(env_id)},
|
||||
{
|
||||
"$pull": {"asset_ids": asset_id},
|
||||
"$set": {"updated_at": datetime.utcnow()}
|
||||
}
|
||||
)
|
||||
return res.modified_count > 0
|
||||
@@ -65,6 +65,8 @@ class GenerationRepo:
|
||||
args["status"] = status
|
||||
if created_by is not None:
|
||||
args["created_by"] = created_by
|
||||
if project_id is None:
|
||||
args["project_id"] = None
|
||||
if project_id is not None:
|
||||
args["project_id"] = project_id
|
||||
if idea_id is not None:
|
||||
@@ -92,6 +94,123 @@ class GenerationRepo:
|
||||
async def update_generation(self, generation: Generation, ):
|
||||
res = await self.collection.update_one({"_id": ObjectId(generation.id)}, {"$set": generation.model_dump()})
|
||||
|
||||
async def get_usage_stats(self, created_by: Optional[str] = None, project_id: Optional[str] = None) -> dict:
|
||||
"""
|
||||
Calculates usage statistics (runs, tokens, cost) using MongoDB aggregation.
|
||||
Includes even soft-deleted generations to reflect actual expenditure.
|
||||
"""
|
||||
pipeline = []
|
||||
|
||||
# 1. Match all done generations (including soft-deleted)
|
||||
match_stage = {"status": GenerationStatus.DONE}
|
||||
if created_by:
|
||||
match_stage["created_by"] = created_by
|
||||
if project_id:
|
||||
match_stage["project_id"] = project_id
|
||||
|
||||
pipeline.append({"$match": match_stage})
|
||||
|
||||
# 2. Group by null (total)
|
||||
pipeline.append({
|
||||
"$group": {
|
||||
"_id": None,
|
||||
"total_runs": {"$sum": 1},
|
||||
"total_tokens": {
|
||||
"$sum": {
|
||||
"$cond": [
|
||||
{"$and": [{"$gt": ["$input_token_usage", 0]}, {"$gt": ["$output_token_usage", 0]}]},
|
||||
{"$add": ["$input_token_usage", "$output_token_usage"]},
|
||||
{"$ifNull": ["$token_usage", 0]}
|
||||
]
|
||||
}
|
||||
},
|
||||
"total_input_tokens": {"$sum": {"$ifNull": ["$input_token_usage", 0]}},
|
||||
"total_output_tokens": {"$sum": {"$ifNull": ["$output_token_usage", 0]}},
|
||||
"total_cost": {
|
||||
"$sum": {
|
||||
"$add": [
|
||||
{"$multiply": [{"$ifNull": ["$input_token_usage", 0]}, 0.000002]},
|
||||
{"$multiply": [{"$ifNull": ["$output_token_usage", 0]}, 0.00012]}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
cursor = self.collection.aggregate(pipeline)
|
||||
res = await cursor.to_list(1)
|
||||
|
||||
if not res:
|
||||
return {
|
||||
"total_runs": 0,
|
||||
"total_tokens": 0,
|
||||
"total_input_tokens": 0,
|
||||
"total_output_tokens": 0,
|
||||
"total_cost": 0.0
|
||||
}
|
||||
|
||||
result = res[0]
|
||||
result.pop("_id")
|
||||
result["total_cost"] = round(result["total_cost"], 4)
|
||||
return result
|
||||
|
||||
async def get_usage_breakdown(self, group_by: str = "created_by", project_id: Optional[str] = None, created_by: Optional[str] = None) -> List[dict]:
|
||||
"""
|
||||
Returns usage statistics grouped by user or project.
|
||||
Includes even soft-deleted generations to reflect actual expenditure.
|
||||
"""
|
||||
pipeline = []
|
||||
|
||||
match_stage = {"status": GenerationStatus.DONE}
|
||||
if project_id:
|
||||
match_stage["project_id"] = project_id
|
||||
if created_by:
|
||||
match_stage["created_by"] = created_by
|
||||
|
||||
pipeline.append({"$match": match_stage})
|
||||
|
||||
pipeline.append({
|
||||
"$group": {
|
||||
"_id": f"${group_by}",
|
||||
"total_runs": {"$sum": 1},
|
||||
"total_tokens": {
|
||||
"$sum": {
|
||||
"$cond": [
|
||||
{"$and": [{"$gt": ["$input_token_usage", 0]}, {"$gt": ["$output_token_usage", 0]}]},
|
||||
{"$add": ["$input_token_usage", "$output_token_usage"]},
|
||||
{"$ifNull": ["$token_usage", 0]}
|
||||
]
|
||||
}
|
||||
},
|
||||
"total_input_tokens": {"$sum": {"$ifNull": ["$input_token_usage", 0]}},
|
||||
"total_output_tokens": {"$sum": {"$ifNull": ["$output_token_usage", 0]}},
|
||||
"total_cost": {
|
||||
"$sum": {
|
||||
"$add": [
|
||||
{"$multiply": [{"$ifNull": ["$input_token_usage", 0]}, 0.000002]},
|
||||
{"$multiply": [{"$ifNull": ["$output_token_usage", 0]}, 0.00012]}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
pipeline.append({"$sort": {"total_cost": -1}})
|
||||
|
||||
cursor = self.collection.aggregate(pipeline)
|
||||
res = await cursor.to_list(None)
|
||||
|
||||
results = []
|
||||
for item in res:
|
||||
entity_id = item.pop("_id")
|
||||
item["total_cost"] = round(item["total_cost"], 4)
|
||||
results.append({
|
||||
"entity_id": str(entity_id) if entity_id else "unknown",
|
||||
"stats": item
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
async def get_generations_by_group(self, group_id: str) -> List[Generation]:
|
||||
res = await self.collection.find({"generation_group_id": group_id, "is_deleted": False}).sort("created_at", 1).to_list(None)
|
||||
generations: List[Generation] = []
|
||||
|
||||
@@ -39,8 +39,17 @@ class IdeaRepo:
|
||||
"from": "generations",
|
||||
"let": {"idea_id": "$str_id"},
|
||||
"pipeline": [
|
||||
{"$match": {"$expr": {"$eq": ["$idea_id", "$$idea_id"]}}},
|
||||
{"$sort": {"created_at": -1}}, # Ensure we get the latest
|
||||
{
|
||||
"$match": {
|
||||
"$and": [
|
||||
{"$expr": {"$eq": ["$idea_id", "$$idea_id"]}},
|
||||
{"status": "done"},
|
||||
{"result_list": {"$exists": True, "$not": {"$size": 0}}},
|
||||
{"is_deleted": False}
|
||||
]
|
||||
}
|
||||
},
|
||||
{"$sort": {"created_at": -1}}, # Ensure we get the latest successful
|
||||
{"$limit": 1}
|
||||
],
|
||||
"as": "generations"
|
||||
|
||||
97
repos/post_repo.py
Normal file
97
repos/post_repo.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
import logging
|
||||
from bson import ObjectId
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
|
||||
from models.Post import Post
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PostRepo:
|
||||
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
|
||||
self.collection = client[db_name]["posts"]
|
||||
|
||||
async def create_post(self, post: Post) -> str:
|
||||
res = await self.collection.insert_one(post.model_dump())
|
||||
return str(res.inserted_id)
|
||||
|
||||
async def get_post(self, post_id: str) -> Optional[Post]:
|
||||
if not ObjectId.is_valid(post_id):
|
||||
return None
|
||||
res = await self.collection.find_one({"_id": ObjectId(post_id), "is_deleted": False})
|
||||
if res:
|
||||
res["id"] = str(res.pop("_id"))
|
||||
return Post(**res)
|
||||
return None
|
||||
|
||||
async def get_posts(
|
||||
self,
|
||||
project_id: Optional[str],
|
||||
user_id: str,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
date_from: Optional[datetime] = None,
|
||||
date_to: Optional[datetime] = None,
|
||||
) -> List[Post]:
|
||||
if project_id:
|
||||
match = {"project_id": project_id, "is_deleted": False}
|
||||
else:
|
||||
match = {"created_by": user_id, "project_id": None, "is_deleted": False}
|
||||
|
||||
if date_from or date_to:
|
||||
date_filter = {}
|
||||
if date_from:
|
||||
date_filter["$gte"] = date_from
|
||||
if date_to:
|
||||
date_filter["$lte"] = date_to
|
||||
match["date"] = date_filter
|
||||
|
||||
cursor = (
|
||||
self.collection.find(match)
|
||||
.sort("date", -1)
|
||||
.skip(offset)
|
||||
.limit(limit)
|
||||
)
|
||||
posts = []
|
||||
async for doc in cursor:
|
||||
doc["id"] = str(doc.pop("_id"))
|
||||
posts.append(Post(**doc))
|
||||
return posts
|
||||
|
||||
async def update_post(self, post_id: str, data: dict) -> bool:
|
||||
if not ObjectId.is_valid(post_id):
|
||||
return False
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(post_id)},
|
||||
{"$set": data},
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def delete_post(self, post_id: str) -> bool:
|
||||
if not ObjectId.is_valid(post_id):
|
||||
return False
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(post_id)},
|
||||
{"$set": {"is_deleted": True}},
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def add_generations(self, post_id: str, generation_ids: List[str]) -> bool:
|
||||
if not ObjectId.is_valid(post_id):
|
||||
return False
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(post_id)},
|
||||
{"$addToSet": {"generation_ids": {"$each": generation_ids}}},
|
||||
)
|
||||
return res.modified_count > 0
|
||||
|
||||
async def remove_generation(self, post_id: str, generation_id: str) -> bool:
|
||||
if not ObjectId.is_valid(post_id):
|
||||
return False
|
||||
res = await self.collection.update_one(
|
||||
{"_id": ObjectId(post_id)},
|
||||
{"$pull": {"generation_ids": generation_id}},
|
||||
)
|
||||
return res.modified_count > 0
|
||||
@@ -51,3 +51,4 @@ python-jose[cryptography]==3.3.0
|
||||
python-multipart==0.0.22
|
||||
email-validator
|
||||
prometheus-fastapi-instrumentator
|
||||
pydantic-settings==2.13.0
|
||||
|
||||
@@ -51,57 +51,66 @@ async def new_char_bio(message: Message, state: FSMContext, dao: DAO, bot: Bot):
|
||||
wait_msg = await message.answer("💾 Сохраняю персонажа...")
|
||||
|
||||
try:
|
||||
# ВОТ ТУТ скачиваем файл (прямо перед сохранением)
|
||||
# 1. Скачиваем файл (один раз)
|
||||
# TODO: Для больших файлов лучше использовать streaming или сохранять во временный файл
|
||||
file_io = await bot.download(file_id)
|
||||
# photo_bytes = file_io.getvalue() # Получаем байты
|
||||
file_bytes = file_io.read()
|
||||
|
||||
|
||||
# Создаем модель
|
||||
# 2. Создаем Character (сначала без ассета, чтобы получить ID)
|
||||
char = Character(
|
||||
id=None,
|
||||
name=name,
|
||||
character_image_data=file_io.read(),
|
||||
character_image_tg_id=None,
|
||||
character_image_doc_tg_id=file_id,
|
||||
character_bio=bio,
|
||||
created_by=str(message.from_user.id)
|
||||
)
|
||||
file_io.close()
|
||||
|
||||
# Сохраняем через DAO
|
||||
|
||||
# Сохраняем, чтобы получить ID
|
||||
await dao.chars.add_character(char)
|
||||
file_info = await bot.get_file(char.character_image_doc_tg_id)
|
||||
file_bytes = await bot.download_file(file_info.file_path)
|
||||
file_io = file_bytes.read()
|
||||
avatar_asset = await dao.assets.create_asset(
|
||||
Asset(name="avatar.png", type=AssetType.UPLOADED, content_type=AssetContentType.IMAGE, linked_char_id=str(char.id), data=file_io,
|
||||
tg_doc_file_id=file_id))
|
||||
char.avatar_image = avatar_asset.link
|
||||
|
||||
# 3. Создаем Asset (связанный с персонажем)
|
||||
avatar_asset_id = await dao.assets.create_asset(
|
||||
Asset(
|
||||
name="avatar.png",
|
||||
type=AssetType.UPLOADED,
|
||||
content_type=AssetContentType.IMAGE,
|
||||
linked_char_id=str(char.id),
|
||||
data=file_bytes,
|
||||
tg_doc_file_id=file_id
|
||||
)
|
||||
)
|
||||
|
||||
# 4. Обновляем персонажа ссылками на ассет
|
||||
char.avatar_asset_id = avatar_asset_id
|
||||
char.avatar_image = f"/api/assets/{avatar_asset_id}" # Формируем ссылку вручную или используем метод, если появится
|
||||
|
||||
# Отправляем подтверждение
|
||||
# Используем байты для отправки обратно
|
||||
photo_msg = await message.answer_photo(
|
||||
photo=BufferedInputFile(file_io,
|
||||
filename="char.jpg") if not char.character_image_tg_id else char.character_image_tg_id,
|
||||
photo=BufferedInputFile(file_bytes, filename="char.jpg"),
|
||||
caption=(
|
||||
"🎉 <b>Персонаж создан!</b>\n\n"
|
||||
f"👤 <b>Имя:</b> {char.name}\n"
|
||||
f"📝 <b>Био:</b> {char.character_bio}"
|
||||
)
|
||||
)
|
||||
file_bytes.close()
|
||||
char.character_image_tg_id = photo_msg.photo[0].file_id
|
||||
|
||||
# Сохраняем TG ID фото (которое отправили как фото, а не документ)
|
||||
char.character_image_tg_id = photo_msg.photo[-1].file_id
|
||||
|
||||
# Финальное обновление персонажа
|
||||
await dao.chars.update_char(char.id, char)
|
||||
|
||||
await wait_msg.delete()
|
||||
file_io.close()
|
||||
|
||||
# Сбрасываем состояние
|
||||
await state.clear()
|
||||
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
logger.error(f"Error creating character: {e}")
|
||||
traceback.print_exc()
|
||||
await wait_msg.edit_text(f"❌ Ошибка при сохранении: {e}")
|
||||
# Не сбрасываем стейт, даем возможность попробовать ввести био снова или начать заново
|
||||
|
||||
|
||||
@router.message(Command("chars"))
|
||||
|
||||
@@ -3,17 +3,17 @@ import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
import os
|
||||
import asyncio
|
||||
from config import settings
|
||||
|
||||
from main import app
|
||||
from aiws import app
|
||||
from api.endpoints.auth import get_current_user
|
||||
from api.dependency import get_dao
|
||||
from repos.dao import DAO
|
||||
from models.Character import Character
|
||||
|
||||
# Config for test DB
|
||||
MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://admin:super_secure_password@31.59.58.220:27017")
|
||||
MONGO_HOST = settings.MONGO_HOST
|
||||
DB_NAME = "bot_db_test_chars"
|
||||
|
||||
# Mock User
|
||||
|
||||
@@ -10,13 +10,13 @@ import json
|
||||
import requests
|
||||
import base64
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from config import settings
|
||||
|
||||
load_dotenv()
|
||||
# Load env is not needed as settings handles it
|
||||
|
||||
# Configuration
|
||||
API_URL = "http://localhost:8090/api/generations/import"
|
||||
SECRET = os.getenv("EXTERNAL_API_SECRET", "your_super_secret_key_change_this_in_production")
|
||||
SECRET = settings.EXTERNAL_API_SECRET or "your_super_secret_key_change_this_in_production"
|
||||
|
||||
# Sample generation data
|
||||
generation_data = {
|
||||
|
||||
@@ -10,11 +10,10 @@ from repos.dao import DAO
|
||||
from models.Idea import Idea
|
||||
from models.Generation import Generation, GenerationStatus
|
||||
from models.enums import AspectRatios, Quality
|
||||
from config import settings
|
||||
|
||||
load_dotenv()
|
||||
|
||||
MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://localhost:27017")
|
||||
DB_NAME = os.getenv("DB_NAME", "bot_db")
|
||||
MONGO_HOST = settings.MONGO_HOST
|
||||
DB_NAME = settings.DB_NAME
|
||||
|
||||
print(f"Connecting to MongoDB: {MONGO_HOST}, DB: {DB_NAME}")
|
||||
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
import asyncio
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from config import settings
|
||||
from adapters.s3_adapter import S3Adapter
|
||||
|
||||
async def test_s3():
|
||||
load_dotenv()
|
||||
|
||||
endpoint = os.getenv("MINIO_ENDPOINT", "http://localhost:9000")
|
||||
access_key = os.getenv("MINIO_ACCESS_KEY")
|
||||
secret_key = os.getenv("MINIO_SECRET_KEY")
|
||||
bucket = os.getenv("MINIO_BUCKET")
|
||||
endpoint = settings.MINIO_ENDPOINT
|
||||
access_key = settings.MINIO_ACCESS_KEY
|
||||
secret_key = settings.MINIO_SECRET_KEY
|
||||
bucket = settings.MINIO_BUCKET
|
||||
|
||||
print(f"Connecting to {endpoint}, bucket: {bucket}")
|
||||
|
||||
|
||||
@@ -4,13 +4,11 @@ from datetime import datetime, timedelta, UTC
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from models.Generation import Generation, GenerationStatus
|
||||
from repos.generation_repo import GenerationRepo
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
from config import settings
|
||||
|
||||
# Mock configs if not present in env
|
||||
MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://localhost:27017")
|
||||
DB_NAME = os.getenv("DB_NAME", "bot_db")
|
||||
MONGO_HOST = settings.MONGO_HOST
|
||||
DB_NAME = settings.DB_NAME
|
||||
|
||||
print(f"Connecting to MongoDB: {MONGO_HOST}, DB: {DB_NAME}")
|
||||
|
||||
|
||||
@@ -10,10 +10,11 @@ from repos.dao import DAO
|
||||
from models.Album import Album
|
||||
from models.Generation import Generation, GenerationStatus
|
||||
from models.enums import AspectRatios, Quality
|
||||
from config import settings
|
||||
|
||||
# Mock config
|
||||
# Use the same host as aiws.py but different DB
|
||||
MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://admin:super_secure_password@31.59.58.220:27017")
|
||||
MONGO_HOST = settings.MONGO_HOST
|
||||
DB_NAME = "bot_db_test_albums"
|
||||
|
||||
async def test_albums():
|
||||
@@ -83,8 +84,6 @@ async def test_albums():
|
||||
client.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
try:
|
||||
asyncio.run(test_albums())
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,29 +1,28 @@
|
||||
import asyncio
|
||||
import os
|
||||
from datetime import datetime
|
||||
from dotenv import load_dotenv
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
|
||||
from config import settings
|
||||
from models.Asset import Asset, AssetType
|
||||
from repos.assets_repo import AssetsRepo
|
||||
from adapters.s3_adapter import S3Adapter
|
||||
|
||||
# Load env to get credentials
|
||||
load_dotenv()
|
||||
# Load env is not needed as settings handles it
|
||||
|
||||
async def test_integration():
|
||||
print("🚀 Starting integration test...")
|
||||
|
||||
# 1. Setup Dependencies
|
||||
mongo_uri = os.getenv("MONGO_HOST", "mongodb://localhost:27017")
|
||||
mongo_uri = settings.MONGO_HOST
|
||||
client = AsyncIOMotorClient(mongo_uri)
|
||||
db_name = os.getenv("DB_NAME", "bot_db_test")
|
||||
db_name = settings.DB_NAME + "_test"
|
||||
|
||||
s3_adapter = S3Adapter(
|
||||
endpoint_url=os.getenv("MINIO_ENDPOINT", "http://localhost:9000"),
|
||||
aws_access_key_id=os.getenv("MINIO_ACCESS_KEY", "admin"),
|
||||
aws_secret_access_key=os.getenv("MINIO_SECRET_KEY", "SuperSecretPassword123!"),
|
||||
bucket_name=os.getenv("MINIO_BUCKET", "ai-char")
|
||||
endpoint_url=settings.MINIO_ENDPOINT,
|
||||
aws_access_key_id=settings.MINIO_ACCESS_KEY,
|
||||
aws_secret_access_key=settings.MINIO_SECRET_KEY,
|
||||
bucket_name=settings.MINIO_BUCKET
|
||||
)
|
||||
|
||||
repo = AssetsRepo(client, s3_adapter, db_name=db_name)
|
||||
|
||||
@@ -3,12 +3,12 @@ from typing import Optional, Union, Any
|
||||
|
||||
from jose import jwt
|
||||
from passlib.context import CryptContext
|
||||
from config import settings
|
||||
|
||||
# Настройки безопасности (лучше вынести в config/env, но для старта здесь)
|
||||
# SECRET_KEY должен быть сложным и секретным в продакшене!
|
||||
SECRET_KEY = "CHANGE_ME_TO_A_SUPER_SECRET_KEY"
|
||||
ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 30 * 24 * 60 # 30 дней, например
|
||||
# Настройки безопасности берутся из config.py
|
||||
SECRET_KEY = settings.SECRET_KEY
|
||||
ALGORITHM = settings.ALGORITHM
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
|
||||
pwd_context = CryptContext(schemes=["argon2"], deprecated="auto")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user