From 5aa6391dc84e2ad75e00ab014969683149cf547c Mon Sep 17 00:00:00 2001 From: xds Date: Thu, 19 Feb 2026 21:25:29 +0300 Subject: [PATCH] + env --- aiws.py | 3 +- api/endpoints/environment_router.py | 180 ++++++++++++++++++++++++++++ api/models/EnvironmentRequest.py | 22 ++++ api/models/GenerationRequest.py | 1 + api/service/generation_service.py | 71 ++++++----- models/Environment.py | 20 ++++ models/Generation.py | 1 + repos/dao.py | 2 + repos/environment_repo.py | 73 +++++++++++ repos/generation_repo.py | 8 +- 10 files changed, 345 insertions(+), 36 deletions(-) create mode 100644 api/endpoints/environment_router.py create mode 100644 api/models/EnvironmentRequest.py create mode 100644 models/Environment.py create mode 100644 repos/environment_repo.py diff --git a/aiws.py b/aiws.py index e245dac..ee0cb86 100644 --- a/aiws.py +++ b/aiws.py @@ -44,7 +44,7 @@ from api.endpoints.album_router import router as api_album_router from api.endpoints.project_router import router as project_api_router from api.endpoints.idea_router import router as idea_api_router from api.endpoints.post_router import router as post_api_router - +from api.endpoints.environment_router import router as environment_api_router logger = logging.getLogger(__name__) @@ -221,6 +221,7 @@ app.include_router(api_album_router) app.include_router(project_api_router) app.include_router(idea_api_router) app.include_router(post_api_router) +app.include_router(environment_api_router) # Prometheus Metrics (Instrument after all routers are added) Instrumentator( diff --git a/api/endpoints/environment_router.py b/api/endpoints/environment_router.py new file mode 100644 index 0000000..aafb2c5 --- /dev/null +++ b/api/endpoints/environment_router.py @@ -0,0 +1,180 @@ +import logging +from typing import List, Optional +from fastapi import APIRouter, Depends, HTTPException +from starlette import status + +from api.dependency import get_dao +from api.endpoints.auth import get_current_user +from 
api.models.EnvironmentRequest import EnvironmentCreate, EnvironmentUpdate, AssetToEnvironment, AssetsToEnvironment +from models.Environment import Environment +from repos.dao import DAO + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/environments", tags=["Environments"], dependencies=[Depends(get_current_user)]) + + +async def check_character_access(character_id: str, current_user: dict, dao: DAO): + character = await dao.chars.get_character(character_id) + if not character: + raise HTTPException(status_code=404, detail="Character not found") + + is_creator = character.created_by == str(current_user["_id"]) + is_project_member = False + if character.project_id and character.project_id in current_user.get("project_ids", []): + is_project_member = True + + if not is_creator and not is_project_member: + raise HTTPException(status_code=403, detail="Access denied to character") + return character + + +@router.post("/", response_model=Environment) +async def create_environment( + env_req: EnvironmentCreate, + dao: DAO = Depends(get_dao), + current_user: dict = Depends(get_current_user) +): + logger.info(f"Creating environment '{env_req.name}' for character {env_req.character_id}") + await check_character_access(env_req.character_id, current_user, dao) + + # Verify assets exist if provided + if env_req.asset_ids: + for aid in env_req.asset_ids: + asset = await dao.assets.get_asset(aid) + if not asset: + raise HTTPException(status_code=400, detail=f"Asset {aid} not found") + + new_env = Environment(**env_req.model_dump()) + created_env = await dao.environments.create_env(new_env) + return created_env + + +@router.get("/character/{character_id}", response_model=List[Environment]) +async def get_character_environments( + character_id: str, + dao: DAO = Depends(get_dao), + current_user: dict = Depends(get_current_user) +): + logger.info(f"Getting environments for character {character_id}") + await check_character_access(character_id, current_user, 
dao) + return await dao.environments.get_character_envs(character_id) + + +@router.get("/{env_id}", response_model=Environment) +async def get_environment( + env_id: str, + dao: DAO = Depends(get_dao), + current_user: dict = Depends(get_current_user) +): + env = await dao.environments.get_env(env_id) + if not env: + raise HTTPException(status_code=404, detail="Environment not found") + + await check_character_access(env.character_id, current_user, dao) + return env + + +@router.put("/{env_id}", response_model=Environment) +async def update_environment( + env_id: str, + env_update: EnvironmentUpdate, + dao: DAO = Depends(get_dao), + current_user: dict = Depends(get_current_user) +): + env = await dao.environments.get_env(env_id) + if not env: + raise HTTPException(status_code=404, detail="Environment not found") + + await check_character_access(env.character_id, current_user, dao) + + update_data = env_update.model_dump(exclude_unset=True) + if not update_data: + return env + + success = await dao.environments.update_env(env_id, update_data) + if not success: + raise HTTPException(status_code=500, detail="Failed to update environment") + + return await dao.environments.get_env(env_id) + + +@router.delete("/{env_id}", status_code=status.HTTP_204_NO_CONTENT) +async def delete_environment( + env_id: str, + dao: DAO = Depends(get_dao), + current_user: dict = Depends(get_current_user) +): + env = await dao.environments.get_env(env_id) + if not env: + raise HTTPException(status_code=404, detail="Environment not found") + + await check_character_access(env.character_id, current_user, dao) + + success = await dao.environments.delete_env(env_id) + if not success: + raise HTTPException(status_code=500, detail="Failed to delete environment") + return None + + +@router.post("/{env_id}/assets", status_code=status.HTTP_200_OK) +async def add_asset_to_environment( + env_id: str, + req: AssetToEnvironment, + dao: DAO = Depends(get_dao), + current_user: dict = 
Depends(get_current_user) +): + env = await dao.environments.get_env(env_id) + if not env: + raise HTTPException(status_code=404, detail="Environment not found") + + await check_character_access(env.character_id, current_user, dao) + + # Verify asset exists + asset = await dao.assets.get_asset(req.asset_id) + if not asset: + raise HTTPException(status_code=404, detail="Asset not found") + + success = await dao.environments.add_asset(env_id, req.asset_id) + return {"success": success} + + +@router.post("/{env_id}/assets/batch", status_code=status.HTTP_200_OK) +async def add_assets_batch_to_environment( + env_id: str, + req: AssetsToEnvironment, + dao: DAO = Depends(get_dao), + current_user: dict = Depends(get_current_user) +): + env = await dao.environments.get_env(env_id) + if not env: + raise HTTPException(status_code=404, detail="Environment not found") + + await check_character_access(env.character_id, current_user, dao) + + # Verify all assets exist + assets = await dao.assets.get_assets_by_ids(req.asset_ids) + if len(assets) != len(req.asset_ids): + found_ids = {a.id for a in assets} + missing_ids = [aid for aid in req.asset_ids if aid not in found_ids] + raise HTTPException(status_code=404, detail=f"Some assets not found: {missing_ids}") + + success = await dao.environments.add_assets(env_id, req.asset_ids) + return {"success": success} + + +@router.delete("/{env_id}/assets/{asset_id}", status_code=status.HTTP_200_OK) +async def remove_asset_from_environment( + env_id: str, + asset_id: str, + dao: DAO = Depends(get_dao), + current_user: dict = Depends(get_current_user) +): + env = await dao.environments.get_env(env_id) + if not env: + raise HTTPException(status_code=404, detail="Environment not found") + + await check_character_access(env.character_id, current_user, dao) + + success = await dao.environments.remove_asset(env_id, asset_id) + return {"success": success} diff --git a/api/models/EnvironmentRequest.py b/api/models/EnvironmentRequest.py new file 
mode 100644 index 0000000..2057b08 --- /dev/null +++ b/api/models/EnvironmentRequest.py @@ -0,0 +1,22 @@ +from typing import Optional, List +from pydantic import BaseModel, Field + + +class EnvironmentCreate(BaseModel): + character_id: str + name: str = Field(..., min_length=1) + description: Optional[str] = None + asset_ids: Optional[List[str]] = [] + + +class EnvironmentUpdate(BaseModel): + name: Optional[str] = Field(None, min_length=1) + description: Optional[str] = None + + +class AssetToEnvironment(BaseModel): + asset_id: str + + +class AssetsToEnvironment(BaseModel): + asset_ids: List[str] diff --git a/api/models/GenerationRequest.py b/api/models/GenerationRequest.py index e6b2ae6..0094189 100644 --- a/api/models/GenerationRequest.py +++ b/api/models/GenerationRequest.py @@ -16,6 +16,7 @@ class GenerationRequest(BaseModel): telegram_id: Optional[int] = None use_profile_image: bool = True assets_list: List[str] + environment_id: Optional[str] = None project_id: Optional[str] = None idea_id: Optional[str] = None count: int = Field(default=1, ge=1, le=10) diff --git a/api/service/generation_service.py b/api/service/generation_service.py index 0081bdb..fd675e0 100644 --- a/api/service/generation_service.py +++ b/api/service/generation_service.py @@ -9,6 +9,7 @@ from uuid import uuid4 import httpx from aiogram import Bot from aiogram.types import BufferedInputFile +from fastapi import HTTPException from adapters.Exception import GoogleGenerationException from adapters.google_adapter import GoogleAdapter @@ -133,6 +134,9 @@ class GenerationService: gen_id = None generation_model = None + if generation_request.environment_id and not generation_request.linked_character_id: + raise HTTPException(status_code=400, detail="environment_id can only be used when linked_character_id is provided") + try: generation_model = Generation(**generation_request.model_dump(exclude={'count'})) if user_id: @@ -186,45 +190,40 @@ class GenerationService: logger.info(f"Processing 
generation {generation.id}. Character ID: {generation.linked_character_id}") # 2. Получаем ассеты-референсы (если они есть) - reference_assets: List[Asset] = [] media_group_bytes: List[bytes] = [] generation_prompt = generation.prompt -# generation_prompt = f""" - -# Create detailed image of character in scene. - -# SCENE DESCRIPTION: {generation.prompt} - -# Rules: -# - Integrate the character's appearance naturally into the scene description. -# - Focus on lighting, texture, and composition. -# """ + + # 2.1 Аватар персонажа (всегда первый, если включен) if generation.linked_character_id is not None: char_info = await self.dao.chars.get_character(generation.linked_character_id) if char_info is None: raise Exception(f"Character ID {generation.linked_character_id} not found") - if generation.use_profile_image: - if char_info.avatar_asset_id is not None: - avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id) - if avatar_asset and avatar_asset.data: - media_group_bytes.append(avatar_asset.data) - # generation_prompt = generation_prompt.replace("$char_bio_inserted", f"1. 
CHARACTER BIO (Must be strictly followed): {char_info.character_bio}") + + if generation.use_profile_image and char_info.avatar_asset_id: + avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id) + if avatar_asset: + img_data = await self._get_asset_data(avatar_asset) + if img_data: + media_group_bytes.append(img_data) - reference_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list) - - # Извлекаем данные (bytes) из ассетов для отправки в Gemini - for asset in reference_assets: - if asset.content_type != AssetContentType.IMAGE: - continue - - img_data = None - if asset.minio_object_name: - img_data = await self.s3_adapter.get_file(asset.minio_object_name) - elif asset.data: - img_data = asset.data - - if img_data: - media_group_bytes.append(img_data) + # 2.2 Явно указанные ассеты + if generation.assets_list: + explicit_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list) + for asset in explicit_assets: + ref_asset_data = await self._get_asset_data(asset) + if ref_asset_data: + media_group_bytes.append(ref_asset_data) + + # 2.3 Ассеты из окружения (в самый конец) + if generation.environment_id: + env = await self.dao.environments.get_env(generation.environment_id) + if env and env.asset_ids: + logger.info(f"Loading {len(env.asset_ids)} assets from environment {env.name} ({env.id})") + env_assets = await self.dao.assets.get_assets_by_ids(env.asset_ids) + for asset in env_assets: + img_data = await self._get_asset_data(asset) + if img_data: + media_group_bytes.append(img_data) if media_group_bytes: generation_prompt += " \n\n[Reference Image Guidance]: Use the provided image(s) as the STRICT reference for the main character's facial features and hair, environment or clothes. Maintain high fidelity to the reference identity." 
@@ -341,6 +340,14 @@ class GenerationService: logger.error(f"Failed to send assets to Telegram ID {generation.telegram_id}: {e}") + async def _get_asset_data(self, asset: Asset) -> Optional[bytes]: + if asset.content_type != AssetContentType.IMAGE: + return None + + if asset.minio_object_name: + return await self.s3_adapter.get_file(asset.minio_object_name) + return asset.data + async def _simulate_progress(self, generation: Generation): """ Increments progress from 0 to 90 over ~20 seconds. diff --git a/models/Environment.py b/models/Environment.py new file mode 100644 index 0000000..0b1e3f1 --- /dev/null +++ b/models/Environment.py @@ -0,0 +1,20 @@ +from typing import List, Optional +from pydantic import BaseModel, Field, ConfigDict +from datetime import datetime +from bson import ObjectId + + +class Environment(BaseModel): + id: Optional[str] = Field(None, alias="_id") + character_id: str + name: str = Field(..., min_length=1) + description: Optional[str] = None + asset_ids: List[str] = Field(default_factory=list) + created_at: datetime = Field(default_factory=datetime.utcnow) + updated_at: datetime = Field(default_factory=datetime.utcnow) + + model_config = ConfigDict( + populate_by_name=True, + json_encoders={ObjectId: str}, + arbitrary_types_allowed=True + ) diff --git a/models/Generation.py b/models/Generation.py index 8f164f9..50c535f 100644 --- a/models/Generation.py +++ b/models/Generation.py @@ -35,6 +35,7 @@ class Generation(BaseModel): output_token_usage: Optional[int] = None is_deleted: bool = False album_id: Optional[str] = None + environment_id: Optional[str] = None generation_group_id: Optional[str] = None created_by: Optional[str] = None # Stores User ID (Telegram ID or Web User ObjectId) project_id: Optional[str] = None diff --git a/repos/dao.py b/repos/dao.py index f251a80..6c53045 100644 --- a/repos/dao.py +++ b/repos/dao.py @@ -8,6 +8,7 @@ from repos.albums_repo import AlbumsRepo from repos.project_repo import ProjectRepo from repos.idea_repo 
import IdeaRepo from repos.post_repo import PostRepo +from repos.environment_repo import EnvironmentRepo from typing import Optional @@ -23,3 +24,4 @@ class DAO: self.users = UsersRepo(client, db_name) self.ideas = IdeaRepo(client, db_name) self.posts = PostRepo(client, db_name) + self.environments = EnvironmentRepo(client, db_name) diff --git a/repos/environment_repo.py b/repos/environment_repo.py new file mode 100644 index 0000000..c77834e --- /dev/null +++ b/repos/environment_repo.py @@ -0,0 +1,73 @@ +from typing import List, Optional +from datetime import datetime +from bson import ObjectId +from motor.motor_asyncio import AsyncIOMotorClient +from models.Environment import Environment + + +class EnvironmentRepo: + def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"): + self.collection = client[db_name]["environments"] + + async def create_env(self, env: Environment) -> Environment: + env_dict = env.model_dump(exclude={"id"}) + res = await self.collection.insert_one(env_dict) + env.id = str(res.inserted_id) + return env + + async def get_env(self, env_id: str) -> Optional[Environment]: + res = await self.collection.find_one({"_id": ObjectId(env_id)}) + if not res: + return None + res["id"] = str(res.pop("_id")) + return Environment(**res) + + async def get_character_envs(self, character_id: str) -> List[Environment]: + cursor = self.collection.find({"character_id": character_id}) + envs = [] + async for doc in cursor: + doc["id"] = str(doc.pop("_id")) + envs.append(Environment(**doc)) + return envs + + async def update_env(self, env_id: str, update_data: dict) -> bool: + update_data["updated_at"] = datetime.utcnow() + res = await self.collection.update_one( + {"_id": ObjectId(env_id)}, + {"$set": update_data} + ) + return res.modified_count > 0 + + async def delete_env(self, env_id: str) -> bool: + res = await self.collection.delete_one({"_id": ObjectId(env_id)}) + return res.deleted_count > 0 + + async def add_asset(self, env_id: str, asset_id: str) 
-> bool: + res = await self.collection.update_one( + {"_id": ObjectId(env_id)}, + { + "$addToSet": {"asset_ids": asset_id}, + "$set": {"updated_at": datetime.utcnow()} + } + ) + return res.modified_count > 0 + + async def add_assets(self, env_id: str, asset_ids: List[str]) -> bool: + res = await self.collection.update_one( + {"_id": ObjectId(env_id)}, + { + "$addToSet": {"asset_ids": {"$each": asset_ids}}, + "$set": {"updated_at": datetime.utcnow()} + } + ) + return res.modified_count > 0 + + async def remove_asset(self, env_id: str, asset_id: str) -> bool: + res = await self.collection.update_one( + {"_id": ObjectId(env_id)}, + { + "$pull": {"asset_ids": asset_id}, + "$set": {"updated_at": datetime.utcnow()} + } + ) + return res.modified_count > 0 diff --git a/repos/generation_repo.py b/repos/generation_repo.py index c3165e6..1109196 100644 --- a/repos/generation_repo.py +++ b/repos/generation_repo.py @@ -97,11 +97,12 @@ class GenerationRepo: async def get_usage_stats(self, created_by: Optional[str] = None, project_id: Optional[str] = None) -> dict: """ Calculates usage statistics (runs, tokens, cost) using MongoDB aggregation. + Includes even soft-deleted generations to reflect actual expenditure. """ pipeline = [] - # 1. Match active done generations - match_stage = {"is_deleted": False, "status": GenerationStatus.DONE} + # 1. Match all done generations (including soft-deleted) + match_stage = {"status": GenerationStatus.DONE} if created_by: match_stage["created_by"] = created_by if project_id: @@ -156,10 +157,11 @@ class GenerationRepo: async def get_usage_breakdown(self, group_by: str = "created_by", project_id: Optional[str] = None, created_by: Optional[str] = None) -> List[dict]: """ Returns usage statistics grouped by user or project. + Includes even soft-deleted generations to reflect actual expenditure. 
""" pipeline = [] - match_stage = {"is_deleted": False, "status": GenerationStatus.DONE} + match_stage = {"status": GenerationStatus.DONE} if project_id: match_stage["project_id"] = project_id if created_by: