diff --git a/api/endpoints/environment_router.py b/api/endpoints/environment_router.py index aafb2c5..3b1f6b0 100644 --- a/api/endpoints/environment_router.py +++ b/api/endpoints/environment_router.py @@ -91,6 +91,18 @@ async def update_environment( update_data = env_update.model_dump(exclude_unset=True) if not update_data: return env + + # Verify assets exist if provided + if "asset_ids" in update_data: + if update_data["asset_ids"] is None: + del update_data["asset_ids"] + elif update_data["asset_ids"]: + # Verify all assets exist using batch check + assets = await dao.assets.get_assets_by_ids(update_data["asset_ids"]) + if len(assets) != len(update_data["asset_ids"]): + found_ids = {a.id for a in assets} + missing_ids = [aid for aid in update_data["asset_ids"] if aid not in found_ids] + raise HTTPException(status_code=400, detail=f"Some assets not found: {missing_ids}") success = await dao.environments.update_env(env_id, update_data) if not success: diff --git a/api/models/EnvironmentRequest.py b/api/models/EnvironmentRequest.py index 2057b08..430bc0a 100644 --- a/api/models/EnvironmentRequest.py +++ b/api/models/EnvironmentRequest.py @@ -12,6 +12,7 @@ class EnvironmentCreate(BaseModel): class EnvironmentUpdate(BaseModel): name: Optional[str] = Field(None, min_length=1) description: Optional[str] = None + asset_ids: Optional[List[str]] = None class AssetToEnvironment(BaseModel): diff --git a/repos/assets_repo.py b/repos/assets_repo.py index 07ae690..7df5dba 100644 --- a/repos/assets_repo.py +++ b/repos/assets_repo.py @@ -102,7 +102,7 @@ class AssetsRepo: return assets - async def get_asset(self, asset_id: str, with_data: bool = True) -> Asset: + async def get_asset(self, asset_id: str, with_data: bool = True) -> Optional[Asset]: projection = None if not with_data: projection = {"data": 0, "thumbnail": 0} @@ -182,7 +182,9 @@ class AssetsRepo: return await self.collection.count_documents(filter) async def get_assets_by_ids(self, asset_ids: List[str]) -> List[Asset]: - object_ids = [ObjectId(asset_id) for asset_id in asset_ids] + object_ids = [ObjectId(asset_id) for asset_id in asset_ids if ObjectId.is_valid(asset_id)] + if not object_ids: + return [] res = self.collection.find({"_id": {"$in": object_ids}}, {"data": 0}) # Exclude data but maybe allow thumbnail if small? # Original excluded thumbnail too. assets = []