+ env
This commit is contained in:
180
api/endpoints/environment_router.py
Normal file
180
api/endpoints/environment_router.py
Normal file
@@ -0,0 +1,180 @@
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from starlette import status
|
||||
|
||||
from api.dependency import get_dao
|
||||
from api.endpoints.auth import get_current_user
|
||||
from api.models.EnvironmentRequest import EnvironmentCreate, EnvironmentUpdate, AssetToEnvironment, AssetsToEnvironment
|
||||
from models.Environment import Environment
|
||||
from repos.dao import DAO
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/environments", tags=["Environments"], dependencies=[Depends(get_current_user)])
|
||||
|
||||
|
||||
async def check_character_access(character_id: str, current_user: dict, dao: DAO):
|
||||
character = await dao.chars.get_character(character_id)
|
||||
if not character:
|
||||
raise HTTPException(status_code=404, detail="Character not found")
|
||||
|
||||
is_creator = character.created_by == str(current_user["_id"])
|
||||
is_project_member = False
|
||||
if character.project_id and character.project_id in current_user.get("project_ids", []):
|
||||
is_project_member = True
|
||||
|
||||
if not is_creator and not is_project_member:
|
||||
raise HTTPException(status_code=403, detail="Access denied to character")
|
||||
return character
|
||||
|
||||
|
||||
@router.post("/", response_model=Environment)
|
||||
async def create_environment(
|
||||
env_req: EnvironmentCreate,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
logger.info(f"Creating environment '{env_req.name}' for character {env_req.character_id}")
|
||||
await check_character_access(env_req.character_id, current_user, dao)
|
||||
|
||||
# Verify assets exist if provided
|
||||
if env_req.asset_ids:
|
||||
for aid in env_req.asset_ids:
|
||||
asset = await dao.assets.get_asset(aid)
|
||||
if not asset:
|
||||
raise HTTPException(status_code=400, detail=f"Asset {aid} not found")
|
||||
|
||||
new_env = Environment(**env_req.model_dump())
|
||||
created_env = await dao.environments.create_env(new_env)
|
||||
return created_env
|
||||
|
||||
|
||||
@router.get("/character/{character_id}", response_model=List[Environment])
|
||||
async def get_character_environments(
|
||||
character_id: str,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
logger.info(f"Getting environments for character {character_id}")
|
||||
await check_character_access(character_id, current_user, dao)
|
||||
return await dao.environments.get_character_envs(character_id)
|
||||
|
||||
|
||||
@router.get("/{env_id}", response_model=Environment)
|
||||
async def get_environment(
|
||||
env_id: str,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
env = await dao.environments.get_env(env_id)
|
||||
if not env:
|
||||
raise HTTPException(status_code=404, detail="Environment not found")
|
||||
|
||||
await check_character_access(env.character_id, current_user, dao)
|
||||
return env
|
||||
|
||||
|
||||
@router.put("/{env_id}", response_model=Environment)
|
||||
async def update_environment(
|
||||
env_id: str,
|
||||
env_update: EnvironmentUpdate,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
env = await dao.environments.get_env(env_id)
|
||||
if not env:
|
||||
raise HTTPException(status_code=404, detail="Environment not found")
|
||||
|
||||
await check_character_access(env.character_id, current_user, dao)
|
||||
|
||||
update_data = env_update.model_dump(exclude_unset=True)
|
||||
if not update_data:
|
||||
return env
|
||||
|
||||
success = await dao.environments.update_env(env_id, update_data)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to update environment")
|
||||
|
||||
return await dao.environments.get_env(env_id)
|
||||
|
||||
|
||||
@router.delete("/{env_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_environment(
|
||||
env_id: str,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
env = await dao.environments.get_env(env_id)
|
||||
if not env:
|
||||
raise HTTPException(status_code=404, detail="Environment not found")
|
||||
|
||||
await check_character_access(env.character_id, current_user, dao)
|
||||
|
||||
success = await dao.environments.delete_env(env_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete environment")
|
||||
return None
|
||||
|
||||
|
||||
@router.post("/{env_id}/assets", status_code=status.HTTP_200_OK)
|
||||
async def add_asset_to_environment(
|
||||
env_id: str,
|
||||
req: AssetToEnvironment,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
env = await dao.environments.get_env(env_id)
|
||||
if not env:
|
||||
raise HTTPException(status_code=404, detail="Environment not found")
|
||||
|
||||
await check_character_access(env.character_id, current_user, dao)
|
||||
|
||||
# Verify asset exists
|
||||
asset = await dao.assets.get_asset(req.asset_id)
|
||||
if not asset:
|
||||
raise HTTPException(status_code=404, detail="Asset not found")
|
||||
|
||||
success = await dao.environments.add_asset(env_id, req.asset_id)
|
||||
return {"success": success}
|
||||
|
||||
|
||||
@router.post("/{env_id}/assets/batch", status_code=status.HTTP_200_OK)
|
||||
async def add_assets_batch_to_environment(
|
||||
env_id: str,
|
||||
req: AssetsToEnvironment,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
env = await dao.environments.get_env(env_id)
|
||||
if not env:
|
||||
raise HTTPException(status_code=404, detail="Environment not found")
|
||||
|
||||
await check_character_access(env.character_id, current_user, dao)
|
||||
|
||||
# Verify all assets exist
|
||||
assets = await dao.assets.get_assets_by_ids(req.asset_ids)
|
||||
if len(assets) != len(req.asset_ids):
|
||||
found_ids = {a.id for a in assets}
|
||||
missing_ids = [aid for aid in req.asset_ids if aid not in found_ids]
|
||||
raise HTTPException(status_code=404, detail=f"Some assets not found: {missing_ids}")
|
||||
|
||||
success = await dao.environments.add_assets(env_id, req.asset_ids)
|
||||
return {"success": success}
|
||||
|
||||
|
||||
@router.delete("/{env_id}/assets/{asset_id}", status_code=status.HTTP_200_OK)
|
||||
async def remove_asset_from_environment(
|
||||
env_id: str,
|
||||
asset_id: str,
|
||||
dao: DAO = Depends(get_dao),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
env = await dao.environments.get_env(env_id)
|
||||
if not env:
|
||||
raise HTTPException(status_code=404, detail="Environment not found")
|
||||
|
||||
await check_character_access(env.character_id, current_user, dao)
|
||||
|
||||
success = await dao.environments.remove_asset(env_id, asset_id)
|
||||
return {"success": success}
|
||||
22
api/models/EnvironmentRequest.py
Normal file
22
api/models/EnvironmentRequest.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class EnvironmentCreate(BaseModel):
|
||||
character_id: str
|
||||
name: str = Field(..., min_length=1)
|
||||
description: Optional[str] = None
|
||||
asset_ids: Optional[List[str]] = []
|
||||
|
||||
|
||||
class EnvironmentUpdate(BaseModel):
|
||||
name: Optional[str] = Field(None, min_length=1)
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class AssetToEnvironment(BaseModel):
|
||||
asset_id: str
|
||||
|
||||
|
||||
class AssetsToEnvironment(BaseModel):
|
||||
asset_ids: List[str]
|
||||
@@ -16,6 +16,7 @@ class GenerationRequest(BaseModel):
|
||||
telegram_id: Optional[int] = None
|
||||
use_profile_image: bool = True
|
||||
assets_list: List[str]
|
||||
environment_id: Optional[str] = None
|
||||
project_id: Optional[str] = None
|
||||
idea_id: Optional[str] = None
|
||||
count: int = Field(default=1, ge=1, le=10)
|
||||
|
||||
@@ -9,6 +9,7 @@ from uuid import uuid4
|
||||
import httpx
|
||||
from aiogram import Bot
|
||||
from aiogram.types import BufferedInputFile
|
||||
from fastapi import HTTPException
|
||||
|
||||
from adapters.Exception import GoogleGenerationException
|
||||
from adapters.google_adapter import GoogleAdapter
|
||||
@@ -133,6 +134,9 @@ class GenerationService:
|
||||
gen_id = None
|
||||
generation_model = None
|
||||
|
||||
if generation_request.environment_id and not generation_request.linked_character_id:
|
||||
raise HTTPException(status_code=400, detail="environment_id can only be used when linked_character_id is provided")
|
||||
|
||||
try:
|
||||
generation_model = Generation(**generation_request.model_dump(exclude={'count'}))
|
||||
if user_id:
|
||||
@@ -186,45 +190,40 @@ class GenerationService:
|
||||
logger.info(f"Processing generation {generation.id}. Character ID: {generation.linked_character_id}")
|
||||
|
||||
# 2. Получаем ассеты-референсы (если они есть)
|
||||
reference_assets: List[Asset] = []
|
||||
media_group_bytes: List[bytes] = []
|
||||
generation_prompt = generation.prompt
|
||||
# generation_prompt = f"""
|
||||
|
||||
# Create detailed image of character in scene.
|
||||
|
||||
# SCENE DESCRIPTION: {generation.prompt}
|
||||
|
||||
# Rules:
|
||||
# - Integrate the character's appearance naturally into the scene description.
|
||||
# - Focus on lighting, texture, and composition.
|
||||
# """
|
||||
|
||||
# 2.1 Аватар персонажа (всегда первый, если включен)
|
||||
if generation.linked_character_id is not None:
|
||||
char_info = await self.dao.chars.get_character(generation.linked_character_id)
|
||||
if char_info is None:
|
||||
raise Exception(f"Character ID {generation.linked_character_id} not found")
|
||||
if generation.use_profile_image:
|
||||
if char_info.avatar_asset_id is not None:
|
||||
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
|
||||
if avatar_asset and avatar_asset.data:
|
||||
media_group_bytes.append(avatar_asset.data)
|
||||
# generation_prompt = generation_prompt.replace("$char_bio_inserted", f"1. CHARACTER BIO (Must be strictly followed): {char_info.character_bio}")
|
||||
|
||||
if generation.use_profile_image and char_info.avatar_asset_id:
|
||||
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
|
||||
if avatar_asset:
|
||||
img_data = await self._get_asset_data(avatar_asset)
|
||||
if img_data:
|
||||
media_group_bytes.append(img_data)
|
||||
|
||||
reference_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
|
||||
|
||||
# Извлекаем данные (bytes) из ассетов для отправки в Gemini
|
||||
for asset in reference_assets:
|
||||
if asset.content_type != AssetContentType.IMAGE:
|
||||
continue
|
||||
|
||||
img_data = None
|
||||
if asset.minio_object_name:
|
||||
img_data = await self.s3_adapter.get_file(asset.minio_object_name)
|
||||
elif asset.data:
|
||||
img_data = asset.data
|
||||
|
||||
if img_data:
|
||||
media_group_bytes.append(img_data)
|
||||
# 2.2 Явно указанные ассеты
|
||||
if generation.assets_list:
|
||||
explicit_assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
|
||||
for asset in explicit_assets:
|
||||
ref_asset_data = await self._get_asset_data(asset)
|
||||
if ref_asset_data:
|
||||
media_group_bytes.append(ref_asset_data)
|
||||
|
||||
# 2.3 Ассеты из окружения (в самый конец)
|
||||
if generation.environment_id:
|
||||
env = await self.dao.environments.get_env(generation.environment_id)
|
||||
if env and env.asset_ids:
|
||||
logger.info(f"Loading {len(env.asset_ids)} assets from environment {env.name} ({env.id})")
|
||||
env_assets = await self.dao.assets.get_assets_by_ids(env.asset_ids)
|
||||
for asset in env_assets:
|
||||
img_data = await self._get_asset_data(asset)
|
||||
if img_data:
|
||||
media_group_bytes.append(img_data)
|
||||
|
||||
if media_group_bytes:
|
||||
generation_prompt += " \n\n[Reference Image Guidance]: Use the provided image(s) as the STRICT reference for the main character's facial features and hair, enviroment or clothes. Maintain high fidelity to the reference identity."
|
||||
@@ -341,6 +340,14 @@ class GenerationService:
|
||||
logger.error(f"Failed to send assets to Telegram ID {generation.telegram_id}: {e}")
|
||||
|
||||
|
||||
async def _get_asset_data(self, asset: Asset) -> Optional[bytes]:
|
||||
if asset.content_type != AssetContentType.IMAGE:
|
||||
return None
|
||||
|
||||
if asset.minio_object_name:
|
||||
return await self.s3_adapter.get_file(asset.minio_object_name)
|
||||
return asset.data
|
||||
|
||||
async def _simulate_progress(self, generation: Generation):
|
||||
"""
|
||||
Increments progress from 0 to 90 over ~20 seconds.
|
||||
|
||||
Reference in New Issue
Block a user