Compare commits
1 Commits
ffb0463fe0
...
enviroment
| Author | SHA1 | Date | |
|---|---|---|---|
| 5aa6391dc8 |
3
aiws.py
3
aiws.py
@@ -44,7 +44,7 @@ from api.endpoints.album_router import router as api_album_router
|
|||||||
from api.endpoints.project_router import router as project_api_router
|
from api.endpoints.project_router import router as project_api_router
|
||||||
from api.endpoints.idea_router import router as idea_api_router
|
from api.endpoints.idea_router import router as idea_api_router
|
||||||
from api.endpoints.post_router import router as post_api_router
|
from api.endpoints.post_router import router as post_api_router
|
||||||
|
from api.endpoints.environment_router import router as environment_api_router
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -221,6 +221,7 @@ app.include_router(api_album_router)
|
|||||||
app.include_router(project_api_router)
|
app.include_router(project_api_router)
|
||||||
app.include_router(idea_api_router)
|
app.include_router(idea_api_router)
|
||||||
app.include_router(post_api_router)
|
app.include_router(post_api_router)
|
||||||
|
app.include_router(environment_api_router)
|
||||||
|
|
||||||
# Prometheus Metrics (Instrument after all routers are added)
|
# Prometheus Metrics (Instrument after all routers are added)
|
||||||
Instrumentator(
|
Instrumentator(
|
||||||
|
|||||||
180
api/endpoints/environment_router.py
Normal file
180
api/endpoints/environment_router.py
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
import logging
|
||||||
|
from typing import List, Optional
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from starlette import status
|
||||||
|
|
||||||
|
from api.dependency import get_dao
|
||||||
|
from api.endpoints.auth import get_current_user
|
||||||
|
from api.models.EnvironmentRequest import EnvironmentCreate, EnvironmentUpdate, AssetToEnvironment, AssetsToEnvironment
|
||||||
|
from models.Environment import Environment
|
||||||
|
from repos.dao import DAO
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/environments", tags=["Environments"], dependencies=[Depends(get_current_user)])
|
||||||
|
|
||||||
|
|
||||||
|
async def check_character_access(character_id: str, current_user: dict, dao: DAO):
|
||||||
|
character = await dao.chars.get_character(character_id)
|
||||||
|
if not character:
|
||||||
|
raise HTTPException(status_code=404, detail="Character not found")
|
||||||
|
|
||||||
|
is_creator = character.created_by == str(current_user["_id"])
|
||||||
|
is_project_member = False
|
||||||
|
if character.project_id and character.project_id in current_user.get("project_ids", []):
|
||||||
|
is_project_member = True
|
||||||
|
|
||||||
|
if not is_creator and not is_project_member:
|
||||||
|
raise HTTPException(status_code=403, detail="Access denied to character")
|
||||||
|
return character
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/", response_model=Environment)
|
||||||
|
async def create_environment(
|
||||||
|
env_req: EnvironmentCreate,
|
||||||
|
dao: DAO = Depends(get_dao),
|
||||||
|
current_user: dict = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
logger.info(f"Creating environment '{env_req.name}' for character {env_req.character_id}")
|
||||||
|
await check_character_access(env_req.character_id, current_user, dao)
|
||||||
|
|
||||||
|
# Verify assets exist if provided
|
||||||
|
if env_req.asset_ids:
|
||||||
|
for aid in env_req.asset_ids:
|
||||||
|
asset = await dao.assets.get_asset(aid)
|
||||||
|
if not asset:
|
||||||
|
raise HTTPException(status_code=400, detail=f"Asset {aid} not found")
|
||||||
|
|
||||||
|
new_env = Environment(**env_req.model_dump())
|
||||||
|
created_env = await dao.environments.create_env(new_env)
|
||||||
|
return created_env
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/character/{character_id}", response_model=List[Environment])
|
||||||
|
async def get_character_environments(
|
||||||
|
character_id: str,
|
||||||
|
dao: DAO = Depends(get_dao),
|
||||||
|
current_user: dict = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
logger.info(f"Getting environments for character {character_id}")
|
||||||
|
await check_character_access(character_id, current_user, dao)
|
||||||
|
return await dao.environments.get_character_envs(character_id)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{env_id}", response_model=Environment)
|
||||||
|
async def get_environment(
|
||||||
|
env_id: str,
|
||||||
|
dao: DAO = Depends(get_dao),
|
||||||
|
current_user: dict = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
env = await dao.environments.get_env(env_id)
|
||||||
|
if not env:
|
||||||
|
raise HTTPException(status_code=404, detail="Environment not found")
|
||||||
|
|
||||||
|
await check_character_access(env.character_id, current_user, dao)
|
||||||
|
return env
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{env_id}", response_model=Environment)
|
||||||
|
async def update_environment(
|
||||||
|
env_id: str,
|
||||||
|
env_update: EnvironmentUpdate,
|
||||||
|
dao: DAO = Depends(get_dao),
|
||||||
|
current_user: dict = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
env = await dao.environments.get_env(env_id)
|
||||||
|
if not env:
|
||||||
|
raise HTTPException(status_code=404, detail="Environment not found")
|
||||||
|
|
||||||
|
await check_character_access(env.character_id, current_user, dao)
|
||||||
|
|
||||||
|
update_data = env_update.model_dump(exclude_unset=True)
|
||||||
|
if not update_data:
|
||||||
|
return env
|
||||||
|
|
||||||
|
success = await dao.environments.update_env(env_id, update_data)
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to update environment")
|
||||||
|
|
||||||
|
return await dao.environments.get_env(env_id)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{env_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
async def delete_environment(
|
||||||
|
env_id: str,
|
||||||
|
dao: DAO = Depends(get_dao),
|
||||||
|
current_user: dict = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
env = await dao.environments.get_env(env_id)
|
||||||
|
if not env:
|
||||||
|
raise HTTPException(status_code=404, detail="Environment not found")
|
||||||
|
|
||||||
|
await check_character_access(env.character_id, current_user, dao)
|
||||||
|
|
||||||
|
success = await dao.environments.delete_env(env_id)
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to delete environment")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{env_id}/assets", status_code=status.HTTP_200_OK)
|
||||||
|
async def add_asset_to_environment(
|
||||||
|
env_id: str,
|
||||||
|
req: AssetToEnvironment,
|
||||||
|
dao: DAO = Depends(get_dao),
|
||||||
|
current_user: dict = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
env = await dao.environments.get_env(env_id)
|
||||||
|
if not env:
|
||||||
|
raise HTTPException(status_code=404, detail="Environment not found")
|
||||||
|
|
||||||
|
await check_character_access(env.character_id, current_user, dao)
|
||||||
|
|
||||||
|
# Verify asset exists
|
||||||
|
asset = await dao.assets.get_asset(req.asset_id)
|
||||||
|
if not asset:
|
||||||
|
raise HTTPException(status_code=404, detail="Asset not found")
|
||||||
|
|
||||||
|
success = await dao.environments.add_asset(env_id, req.asset_id)
|
||||||
|
return {"success": success}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{env_id}/assets/batch", status_code=status.HTTP_200_OK)
|
||||||
|
async def add_assets_batch_to_environment(
|
||||||
|
env_id: str,
|
||||||
|
req: AssetsToEnvironment,
|
||||||
|
dao: DAO = Depends(get_dao),
|
||||||
|
current_user: dict = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
env = await dao.environments.get_env(env_id)
|
||||||
|
if not env:
|
||||||
|
raise HTTPException(status_code=404, detail="Environment not found")
|
||||||
|
|
||||||
|
await check_character_access(env.character_id, current_user, dao)
|
||||||
|
|
||||||
|
# Verify all assets exist
|
||||||
|
assets = await dao.assets.get_assets_by_ids(req.asset_ids)
|
||||||
|
if len(assets) != len(req.asset_ids):
|
||||||
|
found_ids = {a.id for a in assets}
|
||||||
|
missing_ids = [aid for aid in req.asset_ids if aid not in found_ids]
|
||||||
|
raise HTTPException(status_code=404, detail=f"Some assets not found: {missing_ids}")
|
||||||
|
|
||||||
|
success = await dao.environments.add_assets(env_id, req.asset_ids)
|
||||||
|
return {"success": success}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{env_id}/assets/{asset_id}", status_code=status.HTTP_200_OK)
|
||||||
|
async def remove_asset_from_environment(
|
||||||
|
env_id: str,
|
||||||
|
asset_id: str,
|
||||||
|
dao: DAO = Depends(get_dao),
|
||||||
|
current_user: dict = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
env = await dao.environments.get_env(env_id)
|
||||||
|
if not env:
|
||||||
|
raise HTTPException(status_code=404, detail="Environment not found")
|
||||||
|
|
||||||
|
await check_character_access(env.character_id, current_user, dao)
|
||||||
|
|
||||||
|
success = await dao.environments.remove_asset(env_id, asset_id)
|
||||||
|
return {"success": success}
|
||||||
22
api/models/EnvironmentRequest.py
Normal file
22
api/models/EnvironmentRequest.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
from typing import Optional, List
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class EnvironmentCreate(BaseModel):
|
||||||
|
character_id: str
|
||||||
|
name: str = Field(..., min_length=1)
|
||||||
|
description: Optional[str] = None
|
||||||
|
asset_ids: Optional[List[str]] = []
|
||||||
|
|
||||||
|
|
||||||
|
class EnvironmentUpdate(BaseModel):
|
||||||
|
name: Optional[str] = Field(None, min_length=1)
|
||||||
|
description: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class AssetToEnvironment(BaseModel):
|
||||||
|
asset_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class AssetsToEnvironment(BaseModel):
|
||||||
|
asset_ids: List[str]
|
||||||
@@ -16,6 +16,7 @@ class GenerationRequest(BaseModel):
|
|||||||
telegram_id: Optional[int] = None
|
telegram_id: Optional[int] = None
|
||||||
use_profile_image: bool = True
|
use_profile_image: bool = True
|
||||||
assets_list: List[str]
|
assets_list: List[str]
|
||||||
|
environment_id: Optional[str] = None
|
||||||
project_id: Optional[str] = None
|
project_id: Optional[str] = None
|
||||||
idea_id: Optional[str] = None
|
idea_id: Optional[str] = None
|
||||||
count: int = Field(default=1, ge=1, le=10)
|
count: int = Field(default=1, ge=1, le=10)
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from uuid import uuid4
|
|||||||
import httpx
|
import httpx
|
||||||
from aiogram import Bot
|
from aiogram import Bot
|
||||||
from aiogram.types import BufferedInputFile
|
from aiogram.types import BufferedInputFile
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
from adapters.Exception import GoogleGenerationException
|
from adapters.Exception import GoogleGenerationException
|
||||||
from adapters.google_adapter import GoogleAdapter
|
from adapters.google_adapter import GoogleAdapter
|
||||||
@@ -133,6 +134,9 @@ class GenerationService:
|
|||||||
gen_id = None
|
gen_id = None
|
||||||
generation_model = None
|
generation_model = None
|
||||||
|
|
||||||
|
if generation_request.environment_id and not generation_request.linked_character_id:
|
||||||
|
raise HTTPException(status_code=400, detail="environment_id can only be used when linked_character_id is provided")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
generation_model = Generation(**generation_request.model_dump(exclude={'count'}))
|
generation_model = Generation(**generation_request.model_dump(exclude={'count'}))
|
||||||
if user_id:
|
if user_id:
|
||||||
@@ -186,45 +190,40 @@ class GenerationService:
|
|||||||
logger.info(f"Processing generation {generation.id}. Character ID: {generation.linked_character_id}")
|
logger.info(f"Processing generation {generation.id}. Character ID: {generation.linked_character_id}")
|
||||||
|
|
||||||
# 2. Получаем ассеты-референсы (если они есть)
|
# 2. Получаем ассеты-референсы (если они есть)
|
||||||
reference_assets: List[Asset] = []
|
|
||||||
media_group_bytes: List[bytes] = []
|
media_group_bytes: List[bytes] = []
|
||||||
generation_prompt = generation.prompt
|
generation_prompt = generation.prompt
|
||||||
# generation_prompt = f"""
|
|
||||||
|
# 2.1 Аватар персонажа (всегда первый, если включен)
|
||||||
# Create detailed image of character in scene.
|
|
||||||
|
|
||||||
# SCENE DESCRIPTION: {generation.prompt}
|
|
||||||
|
|
||||||
# Rules:
|
|
||||||
# - Integrate the character's appearance naturally into the scene description.
|
|
||||||
# - Focus on lighting, texture, and composition.
|
|
||||||
# """
|
|
||||||
if generation.linked_character_id is not None:
|
if generation.linked_character_id is not None:
|
||||||
char_info = await self.dao.chars.get_character(generation.linked_character_id)
|
char_info = await self.dao.chars.get_character(generation.linked_character_id)
|
||||||
if char_info is None:
|
if char_info is None:
|
||||||
raise Exception(f"Character ID {generation.linked_character_id} not found")
|
raise Exception(f"Character ID {generation.linked_character_id} not found")
|
||||||
if generation.use_profile_image:
|
|
||||||
if char_info.avatar_asset_id is not None:
|
if generation.use_profile_image and char_info.avatar_asset_id:
|
||||||
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
|
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
|
||||||
if avatar_asset and avatar_asset.data:
|
if avatar_asset:
|
||||||
media_group_bytes.append(avatar_asset.data)
|
img_data = await self._get_asset_data(avatar_asset)
|
||||||
# generation_prompt = generation_prompt.replace("$char_bio_inserted", f"1. CHARACTER BIO (Must be strictly followed): {char_info.character_bio}")
|
if img_data:
|
||||||
|
media_group_bytes.append(img_data)
|
||||||
|
|
||||||
reference_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
|
# 2.2 Явно указанные ассеты
|
||||||
|
if generation.assets_list:
|
||||||
# Извлекаем данные (bytes) из ассетов для отправки в Gemini
|
explicit_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
|
||||||
for asset in reference_assets:
|
for asset in explicit_assets:
|
||||||
if asset.content_type != AssetContentType.IMAGE:
|
ref_asset_data = await self._get_asset_data(asset)
|
||||||
continue
|
if ref_asset_data:
|
||||||
|
media_group_bytes.append(ref_asset_data)
|
||||||
img_data = None
|
|
||||||
if asset.minio_object_name:
|
# 2.3 Ассеты из окружения (в самый конец)
|
||||||
img_data = await self.s3_adapter.get_file(asset.minio_object_name)
|
if generation.environment_id:
|
||||||
elif asset.data:
|
env = await self.dao.environments.get_env(generation.environment_id)
|
||||||
img_data = asset.data
|
if env and env.asset_ids:
|
||||||
|
logger.info(f"Loading {len(env.asset_ids)} assets from environment {env.name} ({env.id})")
|
||||||
if img_data:
|
env_assets = await self.dao.assets.get_assets_by_ids(env.asset_ids)
|
||||||
media_group_bytes.append(img_data)
|
for asset in env_assets:
|
||||||
|
img_data = await self._get_asset_data(asset)
|
||||||
|
if img_data:
|
||||||
|
media_group_bytes.append(img_data)
|
||||||
|
|
||||||
if media_group_bytes:
|
if media_group_bytes:
|
||||||
generation_prompt += " \n\n[Reference Image Guidance]: Use the provided image(s) as the STRICT reference for the main character's facial features and hair, enviroment or clothes. Maintain high fidelity to the reference identity."
|
generation_prompt += " \n\n[Reference Image Guidance]: Use the provided image(s) as the STRICT reference for the main character's facial features and hair, enviroment or clothes. Maintain high fidelity to the reference identity."
|
||||||
@@ -341,6 +340,14 @@ class GenerationService:
|
|||||||
logger.error(f"Failed to send assets to Telegram ID {generation.telegram_id}: {e}")
|
logger.error(f"Failed to send assets to Telegram ID {generation.telegram_id}: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_asset_data(self, asset: Asset) -> Optional[bytes]:
|
||||||
|
if asset.content_type != AssetContentType.IMAGE:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if asset.minio_object_name:
|
||||||
|
return await self.s3_adapter.get_file(asset.minio_object_name)
|
||||||
|
return asset.data
|
||||||
|
|
||||||
async def _simulate_progress(self, generation: Generation):
|
async def _simulate_progress(self, generation: Generation):
|
||||||
"""
|
"""
|
||||||
Increments progress from 0 to 90 over ~20 seconds.
|
Increments progress from 0 to 90 over ~20 seconds.
|
||||||
|
|||||||
20
models/Environment.py
Normal file
20
models/Environment.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
from pydantic import BaseModel, Field, ConfigDict
|
||||||
|
from datetime import datetime
|
||||||
|
from bson import ObjectId
|
||||||
|
|
||||||
|
|
||||||
|
class Environment(BaseModel):
|
||||||
|
id: Optional[str] = Field(None, alias="_id")
|
||||||
|
character_id: str
|
||||||
|
name: str = Field(..., min_length=1)
|
||||||
|
description: Optional[str] = None
|
||||||
|
asset_ids: List[str] = Field(default_factory=list)
|
||||||
|
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||||
|
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
populate_by_name=True,
|
||||||
|
json_encoders={ObjectId: str},
|
||||||
|
arbitrary_types_allowed=True
|
||||||
|
)
|
||||||
@@ -35,6 +35,7 @@ class Generation(BaseModel):
|
|||||||
output_token_usage: Optional[int] = None
|
output_token_usage: Optional[int] = None
|
||||||
is_deleted: bool = False
|
is_deleted: bool = False
|
||||||
album_id: Optional[str] = None
|
album_id: Optional[str] = None
|
||||||
|
environment_id: Optional[str] = None
|
||||||
generation_group_id: Optional[str] = None
|
generation_group_id: Optional[str] = None
|
||||||
created_by: Optional[str] = None # Stores User ID (Telegram ID or Web User ObjectId)
|
created_by: Optional[str] = None # Stores User ID (Telegram ID or Web User ObjectId)
|
||||||
project_id: Optional[str] = None
|
project_id: Optional[str] = None
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from repos.albums_repo import AlbumsRepo
|
|||||||
from repos.project_repo import ProjectRepo
|
from repos.project_repo import ProjectRepo
|
||||||
from repos.idea_repo import IdeaRepo
|
from repos.idea_repo import IdeaRepo
|
||||||
from repos.post_repo import PostRepo
|
from repos.post_repo import PostRepo
|
||||||
|
from repos.environment_repo import EnvironmentRepo
|
||||||
|
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@@ -23,3 +24,4 @@ class DAO:
|
|||||||
self.users = UsersRepo(client, db_name)
|
self.users = UsersRepo(client, db_name)
|
||||||
self.ideas = IdeaRepo(client, db_name)
|
self.ideas = IdeaRepo(client, db_name)
|
||||||
self.posts = PostRepo(client, db_name)
|
self.posts = PostRepo(client, db_name)
|
||||||
|
self.environments = EnvironmentRepo(client, db_name)
|
||||||
|
|||||||
73
repos/environment_repo.py
Normal file
73
repos/environment_repo.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
from datetime import datetime
|
||||||
|
from bson import ObjectId
|
||||||
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
from models.Environment import Environment
|
||||||
|
|
||||||
|
|
||||||
|
class EnvironmentRepo:
|
||||||
|
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
|
||||||
|
self.collection = client[db_name]["environments"]
|
||||||
|
|
||||||
|
async def create_env(self, env: Environment) -> Environment:
|
||||||
|
env_dict = env.model_dump(exclude={"id"})
|
||||||
|
res = await self.collection.insert_one(env_dict)
|
||||||
|
env.id = str(res.inserted_id)
|
||||||
|
return env
|
||||||
|
|
||||||
|
async def get_env(self, env_id: str) -> Optional[Environment]:
|
||||||
|
res = await self.collection.find_one({"_id": ObjectId(env_id)})
|
||||||
|
if not res:
|
||||||
|
return None
|
||||||
|
res["id"] = str(res.pop("_id"))
|
||||||
|
return Environment(**res)
|
||||||
|
|
||||||
|
async def get_character_envs(self, character_id: str) -> List[Environment]:
|
||||||
|
cursor = self.collection.find({"character_id": character_id})
|
||||||
|
envs = []
|
||||||
|
async for doc in cursor:
|
||||||
|
doc["id"] = str(doc.pop("_id"))
|
||||||
|
envs.append(Environment(**doc))
|
||||||
|
return envs
|
||||||
|
|
||||||
|
async def update_env(self, env_id: str, update_data: dict) -> bool:
|
||||||
|
update_data["updated_at"] = datetime.utcnow()
|
||||||
|
res = await self.collection.update_one(
|
||||||
|
{"_id": ObjectId(env_id)},
|
||||||
|
{"$set": update_data}
|
||||||
|
)
|
||||||
|
return res.modified_count > 0
|
||||||
|
|
||||||
|
async def delete_env(self, env_id: str) -> bool:
|
||||||
|
res = await self.collection.delete_one({"_id": ObjectId(env_id)})
|
||||||
|
return res.deleted_count > 0
|
||||||
|
|
||||||
|
async def add_asset(self, env_id: str, asset_id: str) -> bool:
|
||||||
|
res = await self.collection.update_one(
|
||||||
|
{"_id": ObjectId(env_id)},
|
||||||
|
{
|
||||||
|
"$addToSet": {"asset_ids": asset_id},
|
||||||
|
"$set": {"updated_at": datetime.utcnow()}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return res.modified_count > 0
|
||||||
|
|
||||||
|
async def add_assets(self, env_id: str, asset_ids: List[str]) -> bool:
|
||||||
|
res = await self.collection.update_one(
|
||||||
|
{"_id": ObjectId(env_id)},
|
||||||
|
{
|
||||||
|
"$addToSet": {"asset_ids": {"$each": asset_ids}},
|
||||||
|
"$set": {"updated_at": datetime.utcnow()}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return res.modified_count > 0
|
||||||
|
|
||||||
|
async def remove_asset(self, env_id: str, asset_id: str) -> bool:
|
||||||
|
res = await self.collection.update_one(
|
||||||
|
{"_id": ObjectId(env_id)},
|
||||||
|
{
|
||||||
|
"$pull": {"asset_ids": asset_id},
|
||||||
|
"$set": {"updated_at": datetime.utcnow()}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return res.modified_count > 0
|
||||||
@@ -97,11 +97,12 @@ class GenerationRepo:
|
|||||||
async def get_usage_stats(self, created_by: Optional[str] = None, project_id: Optional[str] = None) -> dict:
|
async def get_usage_stats(self, created_by: Optional[str] = None, project_id: Optional[str] = None) -> dict:
|
||||||
"""
|
"""
|
||||||
Calculates usage statistics (runs, tokens, cost) using MongoDB aggregation.
|
Calculates usage statistics (runs, tokens, cost) using MongoDB aggregation.
|
||||||
|
Includes even soft-deleted generations to reflect actual expenditure.
|
||||||
"""
|
"""
|
||||||
pipeline = []
|
pipeline = []
|
||||||
|
|
||||||
# 1. Match active done generations
|
# 1. Match all done generations (including soft-deleted)
|
||||||
match_stage = {"is_deleted": False, "status": GenerationStatus.DONE}
|
match_stage = {"status": GenerationStatus.DONE}
|
||||||
if created_by:
|
if created_by:
|
||||||
match_stage["created_by"] = created_by
|
match_stage["created_by"] = created_by
|
||||||
if project_id:
|
if project_id:
|
||||||
@@ -156,10 +157,11 @@ class GenerationRepo:
|
|||||||
async def get_usage_breakdown(self, group_by: str = "created_by", project_id: Optional[str] = None, created_by: Optional[str] = None) -> List[dict]:
|
async def get_usage_breakdown(self, group_by: str = "created_by", project_id: Optional[str] = None, created_by: Optional[str] = None) -> List[dict]:
|
||||||
"""
|
"""
|
||||||
Returns usage statistics grouped by user or project.
|
Returns usage statistics grouped by user or project.
|
||||||
|
Includes even soft-deleted generations to reflect actual expenditure.
|
||||||
"""
|
"""
|
||||||
pipeline = []
|
pipeline = []
|
||||||
|
|
||||||
match_stage = {"is_deleted": False, "status": GenerationStatus.DONE}
|
match_stage = {"status": GenerationStatus.DONE}
|
||||||
if project_id:
|
if project_id:
|
||||||
match_stage["project_id"] = project_id
|
match_stage["project_id"] = project_id
|
||||||
if created_by:
|
if created_by:
|
||||||
|
|||||||
Reference in New Issue
Block a user