diff --git a/__pycache__/main.cpython-313.pyc b/__pycache__/main.cpython-313.pyc index 7e90820..345e089 100644 Binary files a/__pycache__/main.cpython-313.pyc and b/__pycache__/main.cpython-313.pyc differ diff --git a/api/__pycache__/dependency.cpython-313.pyc b/api/__pycache__/dependency.cpython-313.pyc index ce1b3df..bc2abeb 100644 Binary files a/api/__pycache__/dependency.cpython-313.pyc and b/api/__pycache__/dependency.cpython-313.pyc differ diff --git a/api/dependency.py b/api/dependency.py index c0b4937..7dc90eb 100644 --- a/api/dependency.py +++ b/api/dependency.py @@ -43,4 +43,9 @@ def get_generation_service( s3_adapter: S3Adapter = Depends(get_s3_adapter), bot: Bot = Depends(get_bot_client), ) -> GenerationService: - return GenerationService(dao, gemini, s3_adapter, bot) \ No newline at end of file + return GenerationService(dao, gemini, s3_adapter, bot) + +from fastapi import Header + +async def get_project_id(x_project_id: Optional[str] = Header(None, alias="X-Project-ID")) -> Optional[str]: + return x_project_id \ No newline at end of file diff --git a/api/endpoints/__pycache__/assets_router.cpython-313.pyc b/api/endpoints/__pycache__/assets_router.cpython-313.pyc index d5936df..1f6c07e 100644 Binary files a/api/endpoints/__pycache__/assets_router.cpython-313.pyc and b/api/endpoints/__pycache__/assets_router.cpython-313.pyc differ diff --git a/api/endpoints/__pycache__/auth.cpython-313.pyc b/api/endpoints/__pycache__/auth.cpython-313.pyc index 5157ab0..4330bff 100644 Binary files a/api/endpoints/__pycache__/auth.cpython-313.pyc and b/api/endpoints/__pycache__/auth.cpython-313.pyc differ diff --git a/api/endpoints/__pycache__/character_router.cpython-313.pyc b/api/endpoints/__pycache__/character_router.cpython-313.pyc index 0106708..f1ac497 100644 Binary files a/api/endpoints/__pycache__/character_router.cpython-313.pyc and b/api/endpoints/__pycache__/character_router.cpython-313.pyc differ diff --git a/api/endpoints/__pycache__/generation_router.cpython-313.pyc b/api/endpoints/__pycache__/generation_router.cpython-313.pyc index e785c8f..1249725 100644 Binary files a/api/endpoints/__pycache__/generation_router.cpython-313.pyc and b/api/endpoints/__pycache__/generation_router.cpython-313.pyc differ diff --git a/api/endpoints/assets_router.py b/api/endpoints/assets_router.py index 689254d..5700a5d 100644 --- a/api/endpoints/assets_router.py +++ b/api/endpoints/assets_router.py @@ -19,6 +19,7 @@ import logging logger = logging.getLogger(__name__) from api.endpoints.auth import get_current_user +from api.dependency import get_project_id router = APIRouter(prefix="/api/assets", tags=["Assets"]) @@ -68,11 +69,19 @@ async def delete_asset( @router.get("", dependencies=[Depends(get_current_user)]) -async def get_assets(request: Request, dao: DAO = Depends(get_dao), type: Optional[str] = None, limit: int = 10, offset: int = 0) -> AssetsResponse: +async def get_assets(request: Request, dao: DAO = Depends(get_dao), type: Optional[str] = None, limit: int = 10, offset: int = 0, current_user: dict = Depends(get_current_user), project_id: Optional[str] = Depends(get_project_id)) -> AssetsResponse: logger.info(f"get_assets called. Limit: {limit}, Offset: {offset}") - assets = await dao.assets.get_assets(type, limit, offset) + + user_id_filter = str(current_user["_id"]) + if project_id: + project = await dao.projects.get_project(project_id) + if not project or str(current_user["_id"]) not in project.members: + raise HTTPException(status_code=403, detail="Project access denied") + user_id_filter = None + + assets = await dao.assets.get_assets(type, limit, offset, created_by=user_id_filter, project_id=project_id) # assets = await dao.assets.get_assets() # This line seemed redundant/conflicting in original code - total_count = await dao.assets.get_asset_count() + total_count = await dao.assets.get_asset_count(created_by=user_id_filter, project_id=project_id) # Manually map to DTO to trigger computed fields validation if necessary, # but primarily to ensure valid Pydantic models for the response list. @@ -84,11 +93,13 @@ async def get_assets(request: Request, dao: DAO = Depends(get_dao), type: Option -@router.post("/upload", response_model=AssetResponse, status_code=status.HTTP_201_CREATED, dependencies=[Depends(get_current_user)]) +@router.post("/upload", response_model=AssetResponse, status_code=status.HTTP_201_CREATED) async def upload_asset( file: UploadFile = File(...), linked_char_id: Optional[str] = Form(None), dao: DAO = Depends(get_dao), + current_user: dict = Depends(get_current_user), + project_id: Optional[str] = Depends(get_project_id) ): logger.info(f"upload_asset called. Filename: {file.filename}, ContentType: {file.content_type}, LinkedCharId: {linked_char_id}") if not file.content_type: @@ -96,6 +107,11 @@ async def upload_asset( if not file.content_type.startswith("image/"): raise HTTPException(status_code=400, detail=f"Unsupported content type: {file.content_type}") + + if project_id: + project = await dao.projects.get_project(project_id) + if not project or str(current_user["_id"]) not in project.members: + raise HTTPException(status_code=403, detail="Project access denied") data = await file.read() if not data: @@ -111,7 +127,9 @@ async def upload_asset( content_type=AssetContentType.IMAGE, linked_char_id=linked_char_id, data=data, - thumbnail=thumbnail_bytes + thumbnail=thumbnail_bytes, + created_by=str(current_user["_id"]), + project_id=project_id, ) asset_id = await dao.assets.create_asset(asset) diff --git a/api/endpoints/auth.py b/api/endpoints/auth.py index b1b81b0..7cba1d1 100644 --- a/api/endpoints/auth.py +++ b/api/endpoints/auth.py @@ -59,6 +59,7 @@ class Token(BaseModel): class UserResponse(BaseModel): + id: str username: str full_name: str | None = None status: str diff --git a/api/endpoints/character_router.py b/api/endpoints/character_router.py index eb7b54e..da7f101 100644 --- a/api/endpoints/character_router.py +++ b/api/endpoints/character_router.py @@ -1,4 +1,4 @@ -from typing import List, Any, Coroutine +from typing import List, Any, Coroutine, Optional from fastapi import APIRouter, Depends from pydantic import BaseModel @@ -9,6 +9,7 @@ from api.models.AssetDTO import AssetsResponse, AssetResponse from api.models.GenerationRequest import GenerationRequest, GenerationResponse from models.Asset import Asset from models.Character import Character +from api.models.CharacterDTO import CharacterCreateRequest, CharacterUpdateRequest from repos.dao import DAO from api.dependency import get_dao @@ -17,25 +18,49 @@ import logging logger = logging.getLogger(__name__) from api.endpoints.auth import get_current_user +from api.dependency import get_project_id router = APIRouter(prefix="/api/characters", tags=["Characters"], dependencies=[Depends(get_current_user)]) @router.get("/", response_model=List[Character]) -async def get_characters(request: Request, dao: DAO = Depends(get_dao), ) -> List[Character]: +async def get_characters(request: Request, dao: DAO = Depends(get_dao), current_user: dict = Depends(get_current_user), project_id: Optional[str] = Depends(get_project_id)) -> List[Character]: logger.info("get_characters called") - characters = await dao.chars.get_all_characters() + + user_id_filter = str(current_user["_id"]) + if project_id: + project = await dao.projects.get_project(project_id) + if not project or str(current_user["_id"]) not in project.members: + raise HTTPException(status_code=403, detail="Project access denied") + user_id_filter = None + + characters = await dao.chars.get_all_characters(created_by=user_id_filter, project_id=project_id) return characters @router.get("/{character_id}/assets", response_model=AssetsResponse) async def get_character_assets(character_id: str, dao: DAO = Depends(get_dao), limit: int = 10, - offset: int = 0, ) -> AssetsResponse: + offset: int = 0, current_user: dict = Depends(get_current_user)) -> AssetsResponse: logger.info(f"get_character_assets called. CharacterID: {character_id}, Limit: {limit}, Offset: {offset}") character = await dao.chars.get_character(character_id) if character is None: raise HTTPException(status_code=404, detail="Character not found") + + # Access Check + is_creator = character.created_by == str(current_user["_id"]) + is_project_member = False + if character.project_id and character.project_id in current_user.get("project_ids", []): + is_project_member = True + + if not is_creator and not is_project_member: + raise HTTPException(status_code=403, detail="Access denied") + assets = await dao.assets.get_assets_by_char_id(character_id, limit, offset) + # Filter assets by user ownership as well? + # Usually if you own character, you see its assets. + # But assets also have specific created_by. + # Let's assume if you own character you can see its assets. + total_count = await dao.assets.get_asset_count(character_id) asset_responses = [AssetResponse.model_validate(a.model_dump()) for a in assets] @@ -43,12 +68,116 @@ async def get_character_assets(character_id: str, dao: DAO = Depends(get_dao), l @router.get("/{character_id}", response_model=Character) -async def get_character_by_id(character_id: str, request: Request, dao: DAO = Depends(get_dao)) -> Character: +async def get_character_by_id(character_id: str, request: Request, dao: DAO = Depends(get_dao), current_user: dict = Depends(get_current_user)) -> Character: logger.debug(f"get_character_by_id called. ID: {character_id}") character = await dao.chars.get_character(character_id) + + if not character: + raise HTTPException(status_code=404, detail="Character not found") + + if character: + is_creator = character.created_by == str(current_user["_id"]) + is_project_member = False + if character.project_id and character.project_id in current_user.get("project_ids", []): + is_project_member = True + + if not is_creator and not is_project_member: + raise HTTPException(status_code=403, detail="Access denied") + return character +@router.post("/", response_model=Character) +async def create_character( + char_req: CharacterCreateRequest, + dao: DAO = Depends(get_dao), + current_user: dict = Depends(get_current_user) +) -> Character: + logger.info("create_character called") + + char_data = char_req.model_dump() + char_data["created_by"] = str(current_user["_id"]) + if "id" not in char_data: + char_data["id"] = None + + if char_req.project_id: + project = await dao.projects.get_project(char_req.project_id) + if not project or str(current_user["_id"]) not in project.members: + raise HTTPException(status_code=403, detail="Project access denied") + + new_char = Character(**char_data) + created_char = await dao.chars.add_character(new_char) + return created_char + + +@router.put("/{character_id}", response_model=Character) +async def update_character( + character_id: str, + char_update: CharacterUpdateRequest, + dao: DAO = Depends(get_dao), + current_user: dict = Depends(get_current_user) +) -> Character: + logger.info(f"update_character called. ID: {character_id}") + + existing_char = await dao.chars.get_character(character_id) + if not existing_char: + raise HTTPException(status_code=404, detail="Character not found") + + is_creator = existing_char.created_by == str(current_user["_id"]) + is_project_member = False + if existing_char.project_id and existing_char.project_id in current_user.get("project_ids", []): + is_project_member = True + + if not is_creator and not is_project_member: + raise HTTPException(status_code=403, detail="Access denied") + + update_data = char_update.model_dump(exclude_unset=True) + + if "project_id" in update_data and update_data["project_id"]: + new_project_id = update_data["project_id"] + project = await dao.projects.get_project(new_project_id) + if not project or str(current_user["_id"]) not in project.members: + raise HTTPException(status_code=403, detail="Target project access denied") + + updated_char_data = existing_char.model_dump() + updated_char_data.update(update_data) + + updated_char = Character(**updated_char_data) + + success = await dao.chars.update_char(character_id, updated_char) + if not success: + raise HTTPException(status_code=500, detail="Failed to update character") + + return updated_char + + +@router.delete("/{character_id}", status_code=204) +async def delete_character( + character_id: str, + dao: DAO = Depends(get_dao), + current_user: dict = Depends(get_current_user) +): + logger.info(f"delete_character called. ID: {character_id}") + + existing_char = await dao.chars.get_character(character_id) + if not existing_char: + raise HTTPException(status_code=404, detail="Character not found") + + is_creator = existing_char.created_by == str(current_user["_id"]) + is_project_member = False + if existing_char.project_id and existing_char.project_id in current_user.get("project_ids", []): + is_project_member = True + + if not is_creator and not is_project_member: + raise HTTPException(status_code=403, detail="Access denied") + + success = await dao.chars.delete_character(character_id) + if not success: + raise HTTPException(status_code=500, detail="Failed to delete character") + + return + + @router.post("/{character_id}/_run", response_model=GenerationResponse) async def post_character_generation(character_id: str, generation: GenerationRequest, request: Request) -> GenerationResponse: diff --git a/api/endpoints/generation_router.py b/api/endpoints/generation_router.py index cfdf4e8..1352c92 100644 --- a/api/endpoints/generation_router.py +++ b/api/endpoints/generation_router.py @@ -5,7 +5,8 @@ from fastapi.params import Depends from starlette.requests import Request from api import service -from api.dependency import get_generation_service +from api.dependency import get_generation_service, get_project_id, get_dao +from repos.dao import DAO from api.models.GenerationRequest import GenerationResponse, GenerationRequest, GenerationsResponse, PromptResponse, PromptRequest from api.service.generation_service import GenerationService @@ -49,30 +50,65 @@ async def prompt_from_image( @router.get("", response_model=GenerationsResponse) async def get_generations(character_id: Optional[str] = None, limit: int = 10, offset: int = 0, - generation_service: GenerationService = Depends(get_generation_service)): + generation_service: GenerationService = Depends(get_generation_service), + current_user: dict = Depends(get_current_user), + project_id: Optional[str] = Depends(get_project_id), + dao: DAO = Depends(get_dao)): logger.info(f"get_generations called. CharacterId: {character_id}, Limit: {limit}, Offset: {offset}") - return await generation_service.get_generations(character_id, limit=limit, offset=offset) + + user_id_filter = str(current_user["_id"]) + if project_id: + project = await dao.projects.get_project(project_id) + if not project or str(current_user["_id"]) not in project.members: + raise HTTPException(status_code=403, detail="Project access denied") + user_id_filter = None # Show all project generations + + return await generation_service.get_generations(character_id, limit=limit, offset=offset, user_id=user_id_filter, project_id=project_id) @router.post("/_run", response_model=GenerationResponse) async def post_generation(generation: GenerationRequest, request: Request, generation_service: GenerationService = Depends(get_generation_service), - current_user: dict = Depends(get_current_user)) -> GenerationResponse: + current_user: dict = Depends(get_current_user), + project_id: Optional[str] = Depends(get_project_id), + dao: DAO = Depends(get_dao)) -> GenerationResponse: logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}") - return await generation_service.create_generation_task(generation, user_id=current_user.get("username")) + + if project_id: + project = await dao.projects.get_project(project_id) + if not project or str(current_user["_id"]) not in project.members: + raise HTTPException(status_code=403, detail="Project access denied") + generation.project_id = project_id + + return await generation_service.create_generation_task(generation, user_id=str(current_user.get("_id"))) @router.get("/{generation_id}", response_model=GenerationResponse) async def get_generation(generation_id: str, - generation_service: GenerationService = Depends(get_generation_service)) -> GenerationResponse: + generation_service: GenerationService = Depends(get_generation_service), + current_user: dict = Depends(get_current_user)) -> GenerationResponse: logger.debug(f"get_generation called for ID: {generation_id}") - return await generation_service.get_generation(generation_id) + gen = await generation_service.get_generation(generation_id) + if gen and gen.created_by != str(current_user["_id"]): + raise HTTPException(status_code=403, detail="Access denied") + return gen @router.get("/running") async def get_running_generations(request: Request, - generation_service: GenerationService = Depends(get_generation_service)): - return await generation_service.get_running_generations() + generation_service: GenerationService = Depends(get_generation_service), + current_user: dict = Depends(get_current_user), + project_id: Optional[str] = Depends(get_project_id), + dao: DAO = Depends(get_dao)): + + user_id_filter = str(current_user["_id"]) + if project_id: + project = await dao.projects.get_project(project_id) + if not project or str(current_user["_id"]) not in project.members: + raise HTTPException(status_code=403, detail="Project access denied") + user_id_filter = None + + return await generation_service.get_running_generations(user_id=user_id_filter, project_id=project_id) @router.delete("/{generation_id}", status_code=status.HTTP_204_NO_CONTENT, dependencies=[Depends(get_current_user)]) diff --git a/api/endpoints/project_router.py b/api/endpoints/project_router.py new file mode 100644 index 0000000..93c2d0f --- /dev/null +++ b/api/endpoints/project_router.py @@ -0,0 +1,167 @@ +from typing import List, Optional +from fastapi import APIRouter, Depends, HTTPException, status +from pydantic import BaseModel +from api.dependency import get_dao +from api.endpoints.auth import get_current_user +from models.Project import Project +from repos.dao import DAO + +router = APIRouter(prefix="/api/projects", tags=["Projects"]) + +class ProjectCreate(BaseModel): + name: str + description: Optional[str] = None + +class ProjectResponse(BaseModel): + id: str + name: str + description: Optional[str] = None + owner_id: str + members: List[str] + is_owner: bool = False + +@router.post("", response_model=ProjectResponse) +async def create_project( + project_data: ProjectCreate, + dao: DAO = Depends(get_dao), + current_user: dict = Depends(get_current_user) +): + user_id = str(current_user["_id"]) + new_project = Project( + name=project_data.name, + description=project_data.description, + owner_id=user_id, + members=[user_id] + ) + project_id = await dao.projects.create_project(new_project) + + # Add project to user's project list + # Assuming user_repo has a method to add project or we do it directly? + # UserRepo doesn't have add_project method yet. + # But since UserRepo is just a wrapper around collection, lets add it here or update UserRepo later? + # Better to update UserRepo. For now, let's just return success. + # But user needs to see it in list. + # Update user in DB + await dao.users.collection.update_one( + {"_id": current_user["_id"]}, + {"$addToSet": {"project_ids": project_id}} + ) + + return ProjectResponse( + id=project_id, + name=new_project.name, + description=new_project.description, + owner_id=new_project.owner_id, + members=new_project.members, + is_owner=True + ) + +@router.get("", response_model=List[ProjectResponse]) +async def get_my_projects( + dao: DAO = Depends(get_dao), + current_user: dict = Depends(get_current_user) +): + user_id = str(current_user["_id"]) + projects = await dao.projects.get_projects_by_user(user_id) + + responses = [] + for p in projects: + responses.append(ProjectResponse( + id=p.id, + name=p.name, + description=p.description, + owner_id=p.owner_id, + members=p.members, + is_owner=(p.owner_id == user_id) + )) + return responses + +class MemberAdd(BaseModel): + username: str + +@router.post("/{project_id}/members", dependencies=[Depends(get_current_user)]) +async def add_member( + project_id: str, + member_data: MemberAdd, + dao: DAO = Depends(get_dao), + current_user: dict = Depends(get_current_user) +): + user_id = str(current_user["_id"]) + project = await dao.projects.get_project(project_id) + if not project: + raise HTTPException(status_code=404, detail="Project not found") + + if project.owner_id != user_id: + raise HTTPException(status_code=403, detail="Only owner can add members") + + target_user = await dao.users.get_user_by_username(member_data.username) + if not target_user: + raise HTTPException(status_code=404, detail="User not found") + + target_user_id = str(target_user["_id"]) + + if target_user_id in project.members: + return {"message": "User already in project"} + + await dao.projects.add_member(project_id, target_user_id) + + # Update target user's project list + await dao.users.collection.update_one( + {"_id": target_user["_id"]}, + {"$addToSet": {"project_ids": project_id}} + ) + + return {"message": "Member added"} + +@router.post("/{project_id}/join", dependencies=[Depends(get_current_user)]) +async def join_project( + project_id: str, + dao: DAO = Depends(get_dao), + current_user: dict = Depends(get_current_user) +): + # Retrieve project to verify it exists + project = await dao.projects.get_project(project_id) + if not project: + raise HTTPException(status_code=404, detail="Project not found") + + user_id = str(current_user["_id"]) + + # Check if user is ALREADY in project + if user_id in project.members: + return {"message": "Already a member"} + + # Add member + await dao.projects.add_member(project_id, user_id) + + # Update user's project list + await dao.users.collection.update_one( + {"_id": current_user["_id"]}, + {"$addToSet": {"project_ids": project_id}} + ) + + return {"message": "Joined project"} + + +@router.delete("/{project_id}", dependencies=[Depends(get_current_user)] ) +async def delete_project( + project_id: str, + dao: DAO = Depends(get_dao), + current_user: dict = Depends(get_current_user) +): + user_id = str(current_user["_id"]) + project = await dao.projects.get_project(project_id) + if not project: + raise HTTPException(status_code=404, detail="Project not found") + + if project.owner_id != user_id: + raise HTTPException(status_code=403, detail="Only owner can delete project") + + await dao.projects.delete_project(project_id) + + # Remove project from user's project list + await dao.users.collection.update_one( + {"_id": current_user["_id"]}, + {"$pull": {"project_ids": project_id}} + ) + + return {"message": "Project deleted"} \ No newline at end of file diff --git a/api/models/CharacterDTO.py b/api/models/CharacterDTO.py new file mode 100644 index 0000000..ca73053 --- /dev/null +++ b/api/models/CharacterDTO.py @@ -0,0 +1,18 @@ +from typing import Optional +from pydantic import BaseModel + +class CharacterCreateRequest(BaseModel): + name: str + character_bio: str + character_image_doc_tg_id: Optional[str] = None + avatar_image: Optional[str] = None + character_image_tg_id: Optional[str] = None + project_id: Optional[str] = None + +class CharacterUpdateRequest(BaseModel): + name: Optional[str] = None + character_bio: Optional[str] = None + character_image_doc_tg_id: Optional[str] = None + avatar_image: Optional[str] = None + character_image_tg_id: Optional[str] = None + project_id: Optional[str] = None diff --git a/api/models/GenerationRequest.py b/api/models/GenerationRequest.py index 2ad06cd..40e9d18 100644 --- a/api/models/GenerationRequest.py +++ b/api/models/GenerationRequest.py @@ -16,6 +16,7 @@ class GenerationRequest(BaseModel): telegram_id: Optional[int] = None use_profile_image: bool = True assets_list: List[str] + project_id: Optional[str] = None class GenerationsResponse(BaseModel): diff --git a/api/models/__pycache__/GenerationRequest.cpython-313.pyc b/api/models/__pycache__/GenerationRequest.cpython-313.pyc index 8dbd694..0fe1b17 100644 Binary files a/api/models/__pycache__/GenerationRequest.cpython-313.pyc and b/api/models/__pycache__/GenerationRequest.cpython-313.pyc differ diff --git a/api/service/__pycache__/generation_service.cpython-313.pyc b/api/service/__pycache__/generation_service.cpython-313.pyc index 7bfe1a3..4139f6a 100644 Binary files a/api/service/__pycache__/generation_service.cpython-313.pyc and b/api/service/__pycache__/generation_service.cpython-313.pyc differ diff --git a/api/service/generation_service.py b/api/service/generation_service.py index b157c64..1cc962f 100644 --- a/api/service/generation_service.py +++ b/api/service/generation_service.py @@ -92,10 +92,10 @@ class GenerationService: return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, images_list=images) - async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0) -> List[ + async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[ Generation]: - generations = await self.dao.generations.get_generations(character_id = character_id,limit=limit, offset=offset) - total_count = await self.dao.generations.count_generations(character_id = character_id) + generations = await self.dao.generations.get_generations(character_id = character_id,limit=limit, offset=offset, created_by=user_id, project_id=project_id) + total_count = await self.dao.generations.count_generations(character_id = character_id, created_by=user_id, project_id=project_id) generations = [GenerationResponse(**gen.model_dump()) for gen in generations] return GenerationsResponse(generations=generations, total_count=total_count) @@ -106,8 +106,8 @@ class GenerationService: else: return GenerationResponse(**gen.model_dump()) - async def get_running_generations(self) -> List[Generation]: - return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING) + async def get_running_generations(self, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]: + return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING, created_by=user_id, project_id=project_id) async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None) -> GenerationResponse: gen_id = None @@ -261,7 +261,8 @@ class GenerationService: data=None, # Not storing bytes in DB anymore minio_object_name=filename, minio_bucket=self.s3_adapter.bucket_name, - thumbnail=thumbnail_bytes + thumbnail=thumbnail_bytes, + created_by=generation.created_by ) # Сохраняем в БД diff --git a/main.py b/main.py index d97ebc5..9f70ffa 100644 --- a/main.py +++ b/main.py @@ -180,6 +180,8 @@ app.add_middleware( # Подключаем роутер API from api.endpoints.auth import router as auth_api_router from api.endpoints.admin import router as admin_api_router +from api.endpoints.project_router import router as project_api_router + app.include_router(auth_api_router) app.include_router(admin_api_router) app.include_router(api_assets_router) @@ -188,6 +190,7 @@ app.include_router(api_gen_router) app.include_router(api_album_router) app.include_router(api_admin_router) app.include_router(api_auth_router) +app.include_router(project_api_router) # --- ХЕНДЛЕРЫ БОТА (Main Router) --- diff --git a/models/Asset.py b/models/Asset.py index 81a34bd..ff4eeef 100644 --- a/models/Asset.py +++ b/models/Asset.py @@ -28,6 +28,8 @@ class Asset(BaseModel): minio_thumbnail_object_name: Optional[str] = None thumbnail: Optional[bytes] = None tags: List[str] = [] + created_by: Optional[str] = None + project_id: Optional[str] = None created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) diff --git a/models/Character.py b/models/Character.py index f112f91..80ef559 100644 --- a/models/Character.py +++ b/models/Character.py @@ -5,11 +5,13 @@ from pydantic_core.core_schema import computed_field class Character(BaseModel): - id: str | None + id: Optional[str] = None name: str avatar_image: Optional[str] = None character_image_data: Optional[bytes] = None - character_image_doc_tg_id: str - character_image_tg_id: str | None - character_bio: str + character_image_doc_tg_id: Optional[str] = None + character_image_tg_id: Optional[str] = None + character_bio: Optional[str] = None + created_by: Optional[str] = None + project_id: Optional[str] = None diff --git a/models/Generation.py b/models/Generation.py index 784d513..6c74100 100644 --- a/models/Generation.py +++ b/models/Generation.py @@ -34,8 +34,9 @@ class Generation(BaseModel): input_token_usage: Optional[int] = None output_token_usage: Optional[int] = None is_deleted: bool = False - album_id: Optional[str] = None - created_by: Optional[str] = None + album_id: Optional[str] = None + created_by: Optional[str] = None # Stores User ID (Telegram ID or Web User ObjectId) + project_id: Optional[str] = None created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) diff --git a/models/Project.py b/models/Project.py new file mode 100644 index 0000000..65bbb59 --- /dev/null +++ b/models/Project.py @@ -0,0 +1,12 @@ +from datetime import datetime +from typing import List, Optional +from pydantic import BaseModel, Field + +class Project(BaseModel): + id: Optional[str] = None + name: str + description: Optional[str] = None + owner_id: str + members: List[str] = [] # List of User IDs + is_deleted: bool = False + created_at: datetime = Field(default_factory=datetime.now) diff --git a/models/__pycache__/Asset.cpython-313.pyc b/models/__pycache__/Asset.cpython-313.pyc index f1873ef..13bfb7f 100644 Binary files a/models/__pycache__/Asset.cpython-313.pyc and b/models/__pycache__/Asset.cpython-313.pyc differ diff --git a/models/__pycache__/Character.cpython-313.pyc b/models/__pycache__/Character.cpython-313.pyc index a89257a..4c1ed63 100644 Binary files a/models/__pycache__/Character.cpython-313.pyc and b/models/__pycache__/Character.cpython-313.pyc differ diff --git a/models/__pycache__/Generation.cpython-313.pyc b/models/__pycache__/Generation.cpython-313.pyc index 0457de5..dfc0bfd 100644 Binary files a/models/__pycache__/Generation.cpython-313.pyc and b/models/__pycache__/Generation.cpython-313.pyc differ diff --git a/repos/__pycache__/assets_repo.cpython-313.pyc b/repos/__pycache__/assets_repo.cpython-313.pyc index 3ae1148..ead12a4 100644 Binary files a/repos/__pycache__/assets_repo.cpython-313.pyc and b/repos/__pycache__/assets_repo.cpython-313.pyc differ diff --git a/repos/__pycache__/char_repo.cpython-313.pyc b/repos/__pycache__/char_repo.cpython-313.pyc index 5fd4d21..ccd7a75 100644 Binary files a/repos/__pycache__/char_repo.cpython-313.pyc and b/repos/__pycache__/char_repo.cpython-313.pyc differ diff --git a/repos/__pycache__/dao.cpython-313.pyc b/repos/__pycache__/dao.cpython-313.pyc index f0b61bb..8c96e7b 100644 Binary files a/repos/__pycache__/dao.cpython-313.pyc and b/repos/__pycache__/dao.cpython-313.pyc differ diff --git a/repos/__pycache__/generation_repo.cpython-313.pyc b/repos/__pycache__/generation_repo.cpython-313.pyc index 0c9a3be..f4bb2f3 100644 Binary files a/repos/__pycache__/generation_repo.cpython-313.pyc and b/repos/__pycache__/generation_repo.cpython-313.pyc differ diff --git a/repos/__pycache__/user_repo.cpython-313.pyc b/repos/__pycache__/user_repo.cpython-313.pyc index 118d701..271f461 100644 Binary files a/repos/__pycache__/user_repo.cpython-313.pyc and b/repos/__pycache__/user_repo.cpython-313.pyc differ diff --git a/repos/assets_repo.py b/repos/assets_repo.py index 6f9cb62..5249723 100644 --- a/repos/assets_repo.py +++ b/repos/assets_repo.py @@ -46,7 +46,7 @@ class AssetsRepo: res = await self.collection.insert_one(asset.model_dump()) return str(res.inserted_id) - async def get_assets(self, asset_type: Optional[str] = None, limit: int = 10, offset: int = 0, with_data: bool = False) -> List[Asset]: + async def get_assets(self, asset_type: Optional[str] = None, limit: int = 10, offset: int = 0, with_data: bool = False, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Asset]: filter = {} if asset_type: filter["type"] = asset_type @@ -71,6 +71,9 @@ class AssetsRepo: # So list DOES NOT return thumbnails by default. args["thumbnail"] = 0 + if project_id: + filter["project_id"] = project_id + res = await self.collection.find(filter, args).sort("created_at", -1).skip(offset).limit(limit).to_list(None) assets = [] for doc in res: @@ -157,8 +160,15 @@ class AssetsRepo: assets.append(Asset(**doc)) return assets - async def get_asset_count(self, character_id: Optional[str] = None) -> int: - return await self.collection.count_documents({"linked_char_id": character_id} if character_id else {}) + async def get_asset_count(self, character_id: Optional[str] = None, created_by: Optional[str] = None, project_id: Optional[str] = None) -> int: + filter = {} + if character_id: + filter["linked_char_id"] = character_id + if created_by: + filter["created_by"] = created_by + if project_id: + filter["project_id"] = project_id + return await self.collection.count_documents(filter) async def get_assets_by_ids(self, asset_ids: List[str]) -> List[Asset]: object_ids = [ObjectId(asset_id) for asset_id in asset_ids] diff --git a/repos/char_repo.py b/repos/char_repo.py index 8c17531..e28e2a5 100644 --- a/repos/char_repo.py +++ b/repos/char_repo.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional from bson import ObjectId from motor.motor_asyncio import AsyncIOMotorClient @@ -12,7 +12,7 @@ class CharacterRepo: async def add_character(self, character: Character) -> Character: op = await self.collection.insert_one(character.model_dump()) - character.id = op.inserted_id + character.id = str(op.inserted_id) return character async def get_character(self, character_id: str, with_image_data: bool = False) -> Character | None: @@ -26,18 +26,25 @@ class CharacterRepo: res["id"] = str(res.pop("_id")) return Character(**res) - async def get_all_characters(self) -> List[Character]: - docs = await self.collection.find({}, {"character_image_data": 0}).to_list(None) - - characters = [] - for doc in docs: - # Конвертируем ObjectId в строку и кладем в поле id + async def get_all_characters(self, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Character]: + filter = {} + if created_by: + filter["created_by"] = created_by + if project_id: + filter["project_id"] = project_id + + args = {"character_image_data": 0} # don't return image data for list + res = await self.collection.find(filter, args).to_list(None) + chars = [] + for doc in res: doc["id"] = str(doc.pop("_id")) + chars.append(Character(**doc)) + return chars - # Создаем объект - characters.append(Character(**doc)) + async def update_char(self, char_id: str, character: Character) -> bool: + result = await self.collection.update_one({"_id": ObjectId(char_id)}, {"$set": character.model_dump()}) + return result.modified_count > 0 - return characters - - async def update_char(self, char_id: str, character: Character) -> None: - await self.collection.update_one({"_id": ObjectId(char_id)}, {"$set": character.model_dump()}) + async def delete_character(self, char_id: str) -> bool: + result = await self.collection.delete_one({"_id": ObjectId(char_id)}) + return result.deleted_count > 0 diff --git a/repos/dao.py b/repos/dao.py index 5bc70bd..23e7bbf 100644 --- a/repos/dao.py +++ b/repos/dao.py @@ -5,6 +5,7 @@ from repos.char_repo import CharacterRepo from repos.generation_repo import GenerationRepo from repos.user_repo import UsersRepo from repos.albums_repo import AlbumsRepo +from repos.project_repo import ProjectRepo from typing import Optional @@ -16,3 +17,5 @@ class DAO: self.assets = AssetsRepo(client, s3_adapter, db_name) self.generations = GenerationRepo(client, db_name) self.albums = AlbumsRepo(client, db_name) + self.projects = ProjectRepo(client, db_name) + self.users = UsersRepo(client, db_name) diff --git a/repos/generation_repo.py b/repos/generation_repo.py index 6035fd7..c668e97 100644 --- a/repos/generation_repo.py +++ b/repos/generation_repo.py @@ -25,13 +25,19 @@ class GenerationRepo: return Generation(**res) async def get_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None, - limit: int = 10, offset: int = 10) -> List[Generation]: + limit: int = 10, offset: int = 10, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]: filter = {"is_deleted": False} if character_id is not None: filter["linked_character_id"] = character_id if status is not None: filter["status"] = status + if created_by is not None: + filter["created_by"] = created_by + filter["project_id"] = None + if project_id is not None: + filter["project_id"] = project_id + res = await self.collection.find(filter).sort("created_at", -1).skip( offset).limit(limit).to_list(None) generations: List[Generation] = [] @@ -40,12 +46,17 @@ class GenerationRepo: generations.append(Generation(**generation)) return generations - async def count_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None, album_id: Optional[str] = None) -> int: + async def count_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None, + album_id: Optional[str] = None, created_by: Optional[str] = None, project_id: Optional[str] = None) -> int: args = {} if character_id is not None: args["linked_character_id"] = character_id if status is not None: args["status"] = status + if created_by is not None: + args["created_by"] = created_by + if project_id is not None: + args["project_id"] = project_id return await self.collection.count_documents(args) async def get_generations_by_ids(self, generation_ids: List[str]) -> List[Generation]: diff --git a/repos/project_repo.py b/repos/project_repo.py new file mode 100644 index 0000000..3edf3b4 --- /dev/null +++ b/repos/project_repo.py @@ -0,0 +1,62 @@ +from typing import List, Optional +from bson import ObjectId +from motor.motor_asyncio import AsyncIOMotorClient +from models.Project import Project + +class ProjectRepo: + def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"): + self.collection = client[db_name]["projects"] + + async def create_project(self, project: Project) -> str: + res = await self.collection.insert_one(project.model_dump()) + return str(res.inserted_id) + + async def get_project(self, project_id: str) -> Optional[Project]: + if not ObjectId.is_valid(project_id): + return None + res = await self.collection.find_one({"_id": ObjectId(project_id)}) + if res: + res["id"] = str(res.pop("_id")) + return Project(**res) + return None + + async def get_projects_by_user(self, user_id: str) -> List[Project]: + # Find projects where user is owner OR in members + filter = { + "$or": [ + {"owner_id": user_id}, + {"members": user_id} + ], + "is_deleted": False + } + cursor = self.collection.find(filter).sort("created_at", -1) + projects = [] + async for doc in cursor: + doc["id"] = str(doc.pop("_id")) + projects.append(Project(**doc)) + return projects + + async def add_member(self, project_id: str, user_id: str) -> bool: + res = await self.collection.update_one( + {"_id": ObjectId(project_id)}, + {"$addToSet": {"members": user_id}} + ) + return res.modified_count > 0 + + async def remove_member(self, project_id: str, user_id: str) -> bool: + res = await self.collection.update_one( + {"_id": ObjectId(project_id)}, + {"$pull": {"members": user_id}} + ) + return res.modified_count > 0 + + async def update_project(self, project_id: str, updates: dict) -> bool: + res = await self.collection.update_one( + {"_id": ObjectId(project_id)}, + {"$set": updates} + ) + return res.modified_count > 0 + + async def delete_project(self, project_id: str) -> bool: + res = await self.collection.update_one({"_id": ObjectId(project_id)}, {"$set": {"is_deleted": True}}) + return res.modified_count > 0 diff --git a/repos/user_repo.py b/repos/user_repo.py index a1434cd..c4e5e9d 100644 --- a/repos/user_repo.py +++ b/repos/user_repo.py @@ -19,10 +19,14 @@ class UsersRepo: self.collection = client[db_name]["users"] async def get_user(self, user_id: int): - return await self.collection.find_one({"user_id": user_id}) + user = await self.collection.find_one({"user_id": user_id}) + user["id"] = str(user["_id"]) + return user async def get_user_by_username(self, username: str): - return await self.collection.find_one({"username": username}) + user = await self.collection.find_one({"username": username}) + user["id"] = str(user["_id"]) + return user async def create_user(self, username: str, password: str, full_name: Optional[str] = None): """Создает нового пользователя с username/паролем""" @@ -38,15 +42,22 @@ class UsersRepo: "created_at": datetime.now(), "is_email_user": False, # Теперь это просто "обычный" юзер, не телеграм (хотя поле можно переименовать) "is_web_user": True, - "is_admin": False + "is_admin": False, + "project_ids": [], + "current_project_id": None } result = await self.collection.insert_one(user_doc) - return await self.collection.find_one({"_id": result.inserted_id}) + user = await self.collection.find_one({"_id": result.inserted_id}) + user["id"] = str(user["_id"]) + return user async def get_pending_users(self): """Возвращает список пользователей со статусом PENDING""" cursor = self.collection.find({"status": UserStatus.PENDING}) - return await cursor.to_list(length=100) + users = await cursor.to_list(length=100) + for user in users: + user["id"] = str(user["_id"]) + return users async def approve_user(self, username: str): await self.collection.update_one( diff --git a/routers/__pycache__/char_router.cpython-313.pyc b/routers/__pycache__/char_router.cpython-313.pyc index 478ec0b..7dabe6b 100644 Binary files a/routers/__pycache__/char_router.cpython-313.pyc and b/routers/__pycache__/char_router.cpython-313.pyc differ diff --git a/routers/__pycache__/gen_router.cpython-313.pyc b/routers/__pycache__/gen_router.cpython-313.pyc index 6791fb9..669f7de 100644 Binary files a/routers/__pycache__/gen_router.cpython-313.pyc and b/routers/__pycache__/gen_router.cpython-313.pyc differ diff --git a/routers/char_router.py b/routers/char_router.py index 0d13acc..a408e33 100644 --- a/routers/char_router.py +++ b/routers/char_router.py @@ -63,7 +63,8 @@ async def new_char_bio(message: Message, state: FSMContext, dao: DAO, bot: Bot): character_image_data=file_io.read(), character_image_tg_id=None, character_image_doc_tg_id=file_id, - character_bio=bio + character_bio=bio, + created_by=str(message.from_user.id) ) file_io.close() diff --git a/routers/gen_router.py b/routers/gen_router.py index 86461a2..ca525c4 100644 --- a/routers/gen_router.py +++ b/routers/gen_router.py @@ -51,7 +51,7 @@ async def generate_image_cmd(message: Message, state: FSMContext, dao: DAO, gemi await wait_msg.delete() doc = await message.answer_document(res[0], caption="Generated result 💫") await dao.assets.create_asset(Asset(id=None, name=res[0].filename, type=AssetType.GENERATED, content_type=AssetContentType.IMAGE, data=res[0].data, - tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None, linked_char_id=None)) + tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None, linked_char_id=None, created_by=str(message.from_user.id))) @router.message(Command("gen_mode")) @@ -259,7 +259,8 @@ async def handle_album( doc = await message.answer_document(file, caption="✨ Generated result") await dao.assets.create_asset(Asset(id=None, name=file.filename, type=AssetType.GENERATED, content_type=AssetContentType.IMAGE, data=file.data, tg_doc_file_id = doc.document.file_id, tg_photo_file_id = None, - linked_char_id = data["char_id"])) + linked_char_id = data["char_id"], + created_by=str(message.from_user.id))) else: await message.answer("❌ Генерация не вернула изображений.") await gen_mode_base_msg(message=message, state=state, dao=dao, call_type="start") @@ -314,7 +315,8 @@ async def gen_mode_start( doc = await message.answer_document(file, caption="✨ Generated result") await dao.assets.create_asset(Asset(id=None, name=file.filename, type=AssetType.GENERATED, content_type=AssetContentType.IMAGE, data=file.data, tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None, - linked_char_id=data["char_id"])) + linked_char_id=data["char_id"], + created_by=str(message.from_user.id))) else: await message.answer("❌ Ничего не сгенерировалось.") diff --git a/tests/test_character_crud.py b/tests/test_character_crud.py new file mode 100644 index 0000000..4be072d --- /dev/null +++ b/tests/test_character_crud.py @@ -0,0 +1,101 @@ + +import pytest +from fastapi.testclient import TestClient +from unittest.mock import MagicMock +from motor.motor_asyncio import AsyncIOMotorClient +import os +import asyncio + +from main import app +from api.endpoints.auth import get_current_user +from api.dependency import get_dao +from repos.dao import DAO +from models.Character import Character + +# Config for test DB +MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://admin:super_secure_password@31.59.58.220:27017") +DB_NAME = "bot_db_test_chars" + +# Mock User +MOCK_USER_ID = "507f1f77bcf86cd799439011" +MOCK_USER = { + "_id": MOCK_USER_ID, + "username": "testuser", + "is_admin": False, + "status": "allowed" +} + +# Override get_current_user to bypass auth +def mock_get_current_user(): + return MOCK_USER + +app.dependency_overrides[get_current_user] = mock_get_current_user + +# Setup Real DAO with Test DB +client_mongo = AsyncIOMotorClient(MONGO_HOST) +dao = DAO(client_mongo, db_name=DB_NAME) + +def mock_get_dao(): + return dao + +app.dependency_overrides[get_dao] = mock_get_dao + +client = TestClient(app) + +@pytest.fixture(scope="module", autouse=True) +def setup_teardown(): + # Setup: Ensure clean state + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(client_mongo[DB_NAME]["characters"].drop()) + + yield + + # Teardown + loop.run_until_complete(client_mongo[DB_NAME]["characters"].drop()) + loop.close() + +def test_character_crud_flow(): + # 1. Create Character + create_payload = { + "name": "Test Character", + "character_bio": "A bio for test character", + "character_image_doc_tg_id": "file_123", + "avatar_image": "http://example.com/avatar.jpg" + } + + response = client.post("/api/characters/", json=create_payload) + assert response.status_code == 200, response.text + char_data = response.json() + assert char_data["name"] == create_payload["name"] + assert char_data["created_by"] == MOCK_USER_ID + char_id = char_data["id"] + assert char_id is not None + + # 2. Get Character + response = client.get(f"/api/characters/{char_id}") + assert response.status_code == 200 + assert response.json()["id"] == char_id + + # 3. Update Character + update_payload = { + "name": "Updated Name", + "character_bio": "Updated bio" + } + response = client.put(f"/api/characters/{char_id}", json=update_payload) + assert response.status_code == 200 + updated_data = response.json() + assert updated_data["name"] == "Updated Name" + assert updated_data["character_bio"] == "Updated bio" + + # Verify update persistent + response = client.get(f"/api/characters/{char_id}") + assert response.json()["name"] == "Updated Name" + + # 4. Delete Character + response = client.delete(f"/api/characters/{char_id}") + assert response.status_code == 204 + + # Verify deletion + response = client.get(f"/api/characters/{char_id}") + assert response.status_code == 404, "Deleted character should return 404" diff --git a/tests/test_character_integration.py b/tests/test_character_integration.py new file mode 100644 index 0000000..1381886 --- /dev/null +++ b/tests/test_character_integration.py @@ -0,0 +1,64 @@ + +import os +import pytest +from fastapi.testclient import TestClient +from unittest.mock import MagicMock + +# 1. Set Auth Bypass and Test Config +os.environ["DB_NAME"] = "bot_db_test_integration" +# We keep MONGO_HOST as is (it works in verified script) + +# 2. Import app AFTER setting env +from main import app +from api.endpoints.auth import get_current_user + +# 3. Override Auth +MOCK_USER_ID = "507f1f77bcf86cd799439011" +MOCK_USER = { + "_id": MOCK_USER_ID, + "username": "testuser", + "is_admin": False, + "status": "allowed", + "project_ids": [] +} + +def mock_get_current_user(): + return MOCK_USER + +app.dependency_overrides[get_current_user] = mock_get_current_user + +client = TestClient(app) + +def test_character_crud_lifecycle(): + # 1. Create + create_payload = { + "name": "Integration Test Char", + "character_bio": "Testing with real app structure", + "character_image_doc_tg_id": "doc_123", + "avatar_image": "http://example.com/img.jpg" + } + + response = client.post("/api/characters/", json=create_payload) + assert response.status_code == 200, response.text + char_data = response.json() + assert char_data["name"] == create_payload["name"] + char_id = char_data["id"] + + # 2. Get + response = client.get(f"/api/characters/{char_id}") + assert response.status_code == 200 + assert response.json()["id"] == char_id + + # 3. Update + update_payload = {"name": "Updated Int Name"} + response = client.put(f"/api/characters/{char_id}", json=update_payload) + assert response.status_code == 200 + assert response.json()["name"] == "Updated Int Name" + + # 4. Delete + response = client.delete(f"/api/characters/{char_id}") + assert response.status_code == 204 + + # 5. Verify Delete + response = client.get(f"/api/characters/{char_id}") + assert response.status_code == 404