feat: Implement project management with new models, repositories, and API endpoints, and enhance character management with project association and DTOs.

This commit is contained in:
xds
2026-02-09 16:06:54 +03:00
parent 668aadcdc9
commit 458b6ebfc3
42 changed files with 728 additions and 60 deletions

Binary file not shown.

View File

@@ -44,3 +44,8 @@ def get_generation_service(
bot: Bot = Depends(get_bot_client), bot: Bot = Depends(get_bot_client),
) -> GenerationService: ) -> GenerationService:
return GenerationService(dao, gemini, s3_adapter, bot) return GenerationService(dao, gemini, s3_adapter, bot)
from fastapi import Header
async def get_project_id(x_project_id: Optional[str] = Header(None, alias="X-Project-ID")) -> Optional[str]:
return x_project_id

View File

@@ -19,6 +19,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from api.endpoints.auth import get_current_user from api.endpoints.auth import get_current_user
from api.dependency import get_project_id
router = APIRouter(prefix="/api/assets", tags=["Assets"]) router = APIRouter(prefix="/api/assets", tags=["Assets"])
@@ -68,11 +69,19 @@ async def delete_asset(
@router.get("", dependencies=[Depends(get_current_user)]) @router.get("", dependencies=[Depends(get_current_user)])
async def get_assets(request: Request, dao: DAO = Depends(get_dao), type: Optional[str] = None, limit: int = 10, offset: int = 0) -> AssetsResponse: async def get_assets(request: Request, dao: DAO = Depends(get_dao), type: Optional[str] = None, limit: int = 10, offset: int = 0, current_user: dict = Depends(get_current_user), project_id: Optional[str] = Depends(get_project_id)) -> AssetsResponse:
logger.info(f"get_assets called. Limit: {limit}, Offset: {offset}") logger.info(f"get_assets called. Limit: {limit}, Offset: {offset}")
assets = await dao.assets.get_assets(type, limit, offset)
user_id_filter = str(current_user["_id"])
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
user_id_filter = None
assets = await dao.assets.get_assets(type, limit, offset, created_by=user_id_filter, project_id=project_id)
# assets = await dao.assets.get_assets() # This line seemed redundant/conflicting in original code # assets = await dao.assets.get_assets() # This line seemed redundant/conflicting in original code
total_count = await dao.assets.get_asset_count() total_count = await dao.assets.get_asset_count(created_by=user_id_filter, project_id=project_id)
# Manually map to DTO to trigger computed fields validation if necessary, # Manually map to DTO to trigger computed fields validation if necessary,
# but primarily to ensure valid Pydantic models for the response list. # but primarily to ensure valid Pydantic models for the response list.
@@ -84,11 +93,13 @@ async def get_assets(request: Request, dao: DAO = Depends(get_dao), type: Option
@router.post("/upload", response_model=AssetResponse, status_code=status.HTTP_201_CREATED, dependencies=[Depends(get_current_user)]) @router.post("/upload", response_model=AssetResponse, status_code=status.HTTP_201_CREATED)
async def upload_asset( async def upload_asset(
file: UploadFile = File(...), file: UploadFile = File(...),
linked_char_id: Optional[str] = Form(None), linked_char_id: Optional[str] = Form(None),
dao: DAO = Depends(get_dao), dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id)
): ):
logger.info(f"upload_asset called. Filename: {file.filename}, ContentType: {file.content_type}, LinkedCharId: {linked_char_id}") logger.info(f"upload_asset called. Filename: {file.filename}, ContentType: {file.content_type}, LinkedCharId: {linked_char_id}")
if not file.content_type: if not file.content_type:
@@ -97,6 +108,11 @@ async def upload_asset(
if not file.content_type.startswith("image/"): if not file.content_type.startswith("image/"):
raise HTTPException(status_code=400, detail=f"Unsupported content type: {file.content_type}") raise HTTPException(status_code=400, detail=f"Unsupported content type: {file.content_type}")
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
data = await file.read() data = await file.read()
if not data: if not data:
raise HTTPException(status_code=400, detail="Empty file") raise HTTPException(status_code=400, detail="Empty file")
@@ -111,7 +127,9 @@ async def upload_asset(
content_type=AssetContentType.IMAGE, content_type=AssetContentType.IMAGE,
linked_char_id=linked_char_id, linked_char_id=linked_char_id,
data=data, data=data,
thumbnail=thumbnail_bytes thumbnail=thumbnail_bytes,
created_by=str(current_user["_id"]),
project_id=project_id,
) )
asset_id = await dao.assets.create_asset(asset) asset_id = await dao.assets.create_asset(asset)

View File

@@ -59,6 +59,7 @@ class Token(BaseModel):
class UserResponse(BaseModel): class UserResponse(BaseModel):
id: str
username: str username: str
full_name: str | None = None full_name: str | None = None
status: str status: str

View File

@@ -1,4 +1,4 @@
from typing import List, Any, Coroutine from typing import List, Any, Coroutine, Optional
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from pydantic import BaseModel from pydantic import BaseModel
@@ -9,6 +9,7 @@ from api.models.AssetDTO import AssetsResponse, AssetResponse
from api.models.GenerationRequest import GenerationRequest, GenerationResponse from api.models.GenerationRequest import GenerationRequest, GenerationResponse
from models.Asset import Asset from models.Asset import Asset
from models.Character import Character from models.Character import Character
from api.models.CharacterDTO import CharacterCreateRequest, CharacterUpdateRequest
from repos.dao import DAO from repos.dao import DAO
from api.dependency import get_dao from api.dependency import get_dao
@@ -17,25 +18,49 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from api.endpoints.auth import get_current_user from api.endpoints.auth import get_current_user
from api.dependency import get_project_id
router = APIRouter(prefix="/api/characters", tags=["Characters"], dependencies=[Depends(get_current_user)]) router = APIRouter(prefix="/api/characters", tags=["Characters"], dependencies=[Depends(get_current_user)])
@router.get("/", response_model=List[Character]) @router.get("/", response_model=List[Character])
async def get_characters(request: Request, dao: DAO = Depends(get_dao), ) -> List[Character]: async def get_characters(request: Request, dao: DAO = Depends(get_dao), current_user: dict = Depends(get_current_user), project_id: Optional[str] = Depends(get_project_id)) -> List[Character]:
logger.info("get_characters called") logger.info("get_characters called")
characters = await dao.chars.get_all_characters()
user_id_filter = str(current_user["_id"])
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
user_id_filter = None
characters = await dao.chars.get_all_characters(created_by=user_id_filter, project_id=project_id)
return characters return characters
@router.get("/{character_id}/assets", response_model=AssetsResponse) @router.get("/{character_id}/assets", response_model=AssetsResponse)
async def get_character_assets(character_id: str, dao: DAO = Depends(get_dao), limit: int = 10, async def get_character_assets(character_id: str, dao: DAO = Depends(get_dao), limit: int = 10,
offset: int = 0, ) -> AssetsResponse: offset: int = 0, current_user: dict = Depends(get_current_user)) -> AssetsResponse:
logger.info(f"get_character_assets called. CharacterID: {character_id}, Limit: {limit}, Offset: {offset}") logger.info(f"get_character_assets called. CharacterID: {character_id}, Limit: {limit}, Offset: {offset}")
character = await dao.chars.get_character(character_id) character = await dao.chars.get_character(character_id)
if character is None: if character is None:
raise HTTPException(status_code=404, detail="Character not found") raise HTTPException(status_code=404, detail="Character not found")
# Access Check
is_creator = character.created_by == str(current_user["_id"])
is_project_member = False
if character.project_id and character.project_id in current_user.get("project_ids", []):
is_project_member = True
if not is_creator and not is_project_member:
raise HTTPException(status_code=403, detail="Access denied")
assets = await dao.assets.get_assets_by_char_id(character_id, limit, offset) assets = await dao.assets.get_assets_by_char_id(character_id, limit, offset)
# Filter assets by user ownership as well?
# Usually if you own character, you see its assets.
# But assets also have specific created_by.
# Let's assume if you own character you can see its assets.
total_count = await dao.assets.get_asset_count(character_id) total_count = await dao.assets.get_asset_count(character_id)
asset_responses = [AssetResponse.model_validate(a.model_dump()) for a in assets] asset_responses = [AssetResponse.model_validate(a.model_dump()) for a in assets]
@@ -43,12 +68,116 @@ async def get_character_assets(character_id: str, dao: DAO = Depends(get_dao), l
@router.get("/{character_id}", response_model=Character) @router.get("/{character_id}", response_model=Character)
async def get_character_by_id(character_id: str, request: Request, dao: DAO = Depends(get_dao)) -> Character: async def get_character_by_id(character_id: str, request: Request, dao: DAO = Depends(get_dao), current_user: dict = Depends(get_current_user)) -> Character:
logger.debug(f"get_character_by_id called. ID: {character_id}") logger.debug(f"get_character_by_id called. ID: {character_id}")
character = await dao.chars.get_character(character_id) character = await dao.chars.get_character(character_id)
if not character:
raise HTTPException(status_code=404, detail="Character not found")
if character:
is_creator = character.created_by == str(current_user["_id"])
is_project_member = False
if character.project_id and character.project_id in current_user.get("project_ids", []):
is_project_member = True
if not is_creator and not is_project_member:
raise HTTPException(status_code=403, detail="Access denied")
return character return character
@router.post("/", response_model=Character)
async def create_character(
char_req: CharacterCreateRequest,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
) -> Character:
logger.info("create_character called")
char_data = char_req.model_dump()
char_data["created_by"] = str(current_user["_id"])
if "id" not in char_data:
char_data["id"] = None
if char_req.project_id:
project = await dao.projects.get_project(char_req.project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
new_char = Character(**char_data)
created_char = await dao.chars.add_character(new_char)
return created_char
@router.put("/{character_id}", response_model=Character)
async def update_character(
character_id: str,
char_update: CharacterUpdateRequest,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
) -> Character:
logger.info(f"update_character called. ID: {character_id}")
existing_char = await dao.chars.get_character(character_id)
if not existing_char:
raise HTTPException(status_code=404, detail="Character not found")
is_creator = existing_char.created_by == str(current_user["_id"])
is_project_member = False
if existing_char.project_id and existing_char.project_id in current_user.get("project_ids", []):
is_project_member = True
if not is_creator and not is_project_member:
raise HTTPException(status_code=403, detail="Access denied")
update_data = char_update.model_dump(exclude_unset=True)
if "project_id" in update_data and update_data["project_id"]:
new_project_id = update_data["project_id"]
project = await dao.projects.get_project(new_project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Target project access denied")
updated_char_data = existing_char.model_dump()
updated_char_data.update(update_data)
updated_char = Character(**updated_char_data)
success = await dao.chars.update_char(character_id, updated_char)
if not success:
raise HTTPException(status_code=500, detail="Failed to update character")
return updated_char
@router.delete("/{character_id}", status_code=204)
async def delete_character(
character_id: str,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
):
logger.info(f"delete_character called. ID: {character_id}")
existing_char = await dao.chars.get_character(character_id)
if not existing_char:
raise HTTPException(status_code=404, detail="Character not found")
is_creator = existing_char.created_by == str(current_user["_id"])
is_project_member = False
if existing_char.project_id and existing_char.project_id in current_user.get("project_ids", []):
is_project_member = True
if not is_creator and not is_project_member:
raise HTTPException(status_code=403, detail="Access denied")
success = await dao.chars.delete_character(character_id)
if not success:
raise HTTPException(status_code=500, detail="Failed to delete character")
return
@router.post("/{character_id}/_run", response_model=GenerationResponse) @router.post("/{character_id}/_run", response_model=GenerationResponse)
async def post_character_generation(character_id: str, generation: GenerationRequest, async def post_character_generation(character_id: str, generation: GenerationRequest,
request: Request) -> GenerationResponse: request: Request) -> GenerationResponse:

View File

@@ -5,7 +5,8 @@ from fastapi.params import Depends
from starlette.requests import Request from starlette.requests import Request
from api import service from api import service
from api.dependency import get_generation_service from api.dependency import get_generation_service, get_project_id, get_dao
from repos.dao import DAO
from api.models.GenerationRequest import GenerationResponse, GenerationRequest, GenerationsResponse, PromptResponse, PromptRequest from api.models.GenerationRequest import GenerationResponse, GenerationRequest, GenerationsResponse, PromptResponse, PromptRequest
from api.service.generation_service import GenerationService from api.service.generation_service import GenerationService
@@ -49,30 +50,65 @@ async def prompt_from_image(
@router.get("", response_model=GenerationsResponse) @router.get("", response_model=GenerationsResponse)
async def get_generations(character_id: Optional[str] = None, limit: int = 10, offset: int = 0, async def get_generations(character_id: Optional[str] = None, limit: int = 10, offset: int = 0,
generation_service: GenerationService = Depends(get_generation_service)): generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao)):
logger.info(f"get_generations called. CharacterId: {character_id}, Limit: {limit}, Offset: {offset}") logger.info(f"get_generations called. CharacterId: {character_id}, Limit: {limit}, Offset: {offset}")
return await generation_service.get_generations(character_id, limit=limit, offset=offset)
user_id_filter = str(current_user["_id"])
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
user_id_filter = None # Show all project generations
return await generation_service.get_generations(character_id, limit=limit, offset=offset, user_id=user_id_filter, project_id=project_id)
@router.post("/_run", response_model=GenerationResponse) @router.post("/_run", response_model=GenerationResponse)
async def post_generation(generation: GenerationRequest, request: Request, async def post_generation(generation: GenerationRequest, request: Request,
generation_service: GenerationService = Depends(get_generation_service), generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)) -> GenerationResponse: current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao)) -> GenerationResponse:
logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}") logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}")
return await generation_service.create_generation_task(generation, user_id=current_user.get("username"))
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
generation.project_id = project_id
return await generation_service.create_generation_task(generation, user_id=str(current_user.get("_id")))
@router.get("/{generation_id}", response_model=GenerationResponse) @router.get("/{generation_id}", response_model=GenerationResponse)
async def get_generation(generation_id: str, async def get_generation(generation_id: str,
generation_service: GenerationService = Depends(get_generation_service)) -> GenerationResponse: generation_service: GenerationService = Depends(get_generation_service),
current_user: dict = Depends(get_current_user)) -> GenerationResponse:
logger.debug(f"get_generation called for ID: {generation_id}") logger.debug(f"get_generation called for ID: {generation_id}")
return await generation_service.get_generation(generation_id) gen = await generation_service.get_generation(generation_id)
if gen and gen.created_by != str(current_user["_id"]):
raise HTTPException(status_code=403, detail="Access denied")
return gen
@router.get("/running") @router.get("/running")
async def get_running_generations(request: Request, async def get_running_generations(request: Request,
generation_service: GenerationService = Depends(get_generation_service)): generation_service: GenerationService = Depends(get_generation_service),
return await generation_service.get_running_generations() current_user: dict = Depends(get_current_user),
project_id: Optional[str] = Depends(get_project_id),
dao: DAO = Depends(get_dao)):
user_id_filter = str(current_user["_id"])
if project_id:
project = await dao.projects.get_project(project_id)
if not project or str(current_user["_id"]) not in project.members:
raise HTTPException(status_code=403, detail="Project access denied")
user_id_filter = None
return await generation_service.get_running_generations(user_id=user_id_filter, project_id=project_id)
@router.delete("/{generation_id}", status_code=status.HTTP_204_NO_CONTENT, dependencies=[Depends(get_current_user)]) @router.delete("/{generation_id}", status_code=status.HTTP_204_NO_CONTENT, dependencies=[Depends(get_current_user)])

View File

@@ -0,0 +1,167 @@
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from api.dependency import get_dao
from api.endpoints.auth import get_current_user
from models.Project import Project
from repos.dao import DAO
router = APIRouter(prefix="/api/projects", tags=["Projects"])
class ProjectCreate(BaseModel):
name: str
description: Optional[str] = None
class ProjectResponse(BaseModel):
id: str
name: str
description: Optional[str] = None
owner_id: str
members: List[str]
is_owner: bool = False
@router.post("", response_model=ProjectResponse)
async def create_project(
project_data: ProjectCreate,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
):
user_id = str(current_user["_id"])
new_project = Project(
name=project_data.name,
description=project_data.description,
owner_id=user_id,
members=[user_id]
)
project_id = await dao.projects.create_project(new_project)
# Add project to user's project list
# Assuming user_repo has a method to add project or we do it directly?
# UserRepo doesn't have add_project method yet.
# But since UserRepo is just a wrapper around collection, lets add it here or update UserRepo later?
# Better to update UserRepo. For now, let's just return success.
# But user needs to see it in list.
# Update user in DB
await dao.users.collection.update_one(
{"_id": current_user["_id"]},
{"$addToSet": {"project_ids": project_id}}
)
return ProjectResponse(
id=project_id,
name=new_project.name,
description=new_project.description,
owner_id=new_project.owner_id,
members=new_project.members,
is_owner=True
)
@router.get("", response_model=List[ProjectResponse])
async def get_my_projects(
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
):
user_id = str(current_user["_id"])
projects = await dao.projects.get_projects_by_user(user_id)
responses = []
for p in projects:
responses.append(ProjectResponse(
id=p.id,
name=p.name,
description=p.description,
owner_id=p.owner_id,
members=p.members,
is_owner=(p.owner_id == user_id)
))
return responses
class MemberAdd(BaseModel):
username: str
@router.post("/{project_id}/members", dependencies=[Depends(get_current_user)])
async def add_member(
project_id: str,
member_data: MemberAdd,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
):
user_id = str(current_user["_id"])
project = await dao.projects.get_project(project_id)
if not project:
raise HTTPException(status_code=404, detail="Project not found")
if project.owner_id != user_id:
raise HTTPException(status_code=403, detail="Only owner can add members")
target_user = await dao.users.get_user_by_username(member_data.username)
if not target_user:
raise HTTPException(status_code=404, detail="User not found")
target_user_id = str(target_user["_id"])
if target_user_id in project.members:
return {"message": "User already in project"}
await dao.projects.add_member(project_id, target_user_id)
# Update target user's project list
await dao.users.collection.update_one(
{"_id": target_user["_id"]},
{"$addToSet": {"project_ids": project_id}}
)
return {"message": "Member added"}
@router.post("/{project_id}/join", dependencies=[Depends(get_current_user)])
async def join_project(
project_id: str,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
):
# Retrieve project to verify it exists
project = await dao.projects.get_project(project_id)
if not project:
raise HTTPException(status_code=404, detail="Project not found")
user_id = str(current_user["_id"])
# Check if user is ALREADY in project
if user_id in project.members:
return {"message": "Already a member"}
# Add member
await dao.projects.add_member(project_id, user_id)
# Update user's project list
await dao.users.collection.update_one(
{"_id": current_user["_id"]},
{"$addToSet": {"project_ids": project_id}}
)
return {"message": "Joined project"}
@router.delete("/{project_id}", dependencies=[Depends(get_current_user)] )
async def delete_project(
project_id: str,
dao: DAO = Depends(get_dao),
current_user: dict = Depends(get_current_user)
):
user_id = str(current_user["_id"])
project = await dao.projects.get_project(project_id)
if not project:
raise HTTPException(status_code=404, detail="Project not found")
if project.owner_id != user_id:
raise HTTPException(status_code=403, detail="Only owner can delete project")
await dao.projects.delete_project(project_id)
# Remove project from user's project list
await dao.users.collection.update_one(
{"_id": current_user["_id"]},
{"$pull": {"project_ids": project_id}}
)
return {"message": "Project deleted"}

View File

@@ -0,0 +1,18 @@
from typing import Optional
from pydantic import BaseModel
class CharacterCreateRequest(BaseModel):
name: str
character_bio: str
character_image_doc_tg_id: Optional[str] = None
avatar_image: Optional[str] = None
character_image_tg_id: Optional[str] = None
project_id: Optional[str] = None
class CharacterUpdateRequest(BaseModel):
name: Optional[str] = None
character_bio: Optional[str] = None
character_image_doc_tg_id: Optional[str] = None
avatar_image: Optional[str] = None
character_image_tg_id: Optional[str] = None
project_id: Optional[str] = None

View File

@@ -16,6 +16,7 @@ class GenerationRequest(BaseModel):
telegram_id: Optional[int] = None telegram_id: Optional[int] = None
use_profile_image: bool = True use_profile_image: bool = True
assets_list: List[str] assets_list: List[str]
project_id: Optional[str] = None
class GenerationsResponse(BaseModel): class GenerationsResponse(BaseModel):

View File

@@ -92,10 +92,10 @@ class GenerationService:
return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, images_list=images) return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, images_list=images)
async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0) -> List[ async def get_generations(self, character_id: Optional[str] = None, limit: int = 10, offset: int = 0, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[
Generation]: Generation]:
generations = await self.dao.generations.get_generations(character_id = character_id,limit=limit, offset=offset) generations = await self.dao.generations.get_generations(character_id = character_id,limit=limit, offset=offset, created_by=user_id, project_id=project_id)
total_count = await self.dao.generations.count_generations(character_id = character_id) total_count = await self.dao.generations.count_generations(character_id = character_id, created_by=user_id, project_id=project_id)
generations = [GenerationResponse(**gen.model_dump()) for gen in generations] generations = [GenerationResponse(**gen.model_dump()) for gen in generations]
return GenerationsResponse(generations=generations, total_count=total_count) return GenerationsResponse(generations=generations, total_count=total_count)
@@ -106,8 +106,8 @@ class GenerationService:
else: else:
return GenerationResponse(**gen.model_dump()) return GenerationResponse(**gen.model_dump())
async def get_running_generations(self) -> List[Generation]: async def get_running_generations(self, user_id: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING) return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING, created_by=user_id, project_id=project_id)
async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None) -> GenerationResponse: async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None) -> GenerationResponse:
gen_id = None gen_id = None
@@ -261,7 +261,8 @@ class GenerationService:
data=None, # Not storing bytes in DB anymore data=None, # Not storing bytes in DB anymore
minio_object_name=filename, minio_object_name=filename,
minio_bucket=self.s3_adapter.bucket_name, minio_bucket=self.s3_adapter.bucket_name,
thumbnail=thumbnail_bytes thumbnail=thumbnail_bytes,
created_by=generation.created_by
) )
# Сохраняем в БД # Сохраняем в БД

View File

@@ -180,6 +180,8 @@ app.add_middleware(
# Подключаем роутер API # Подключаем роутер API
from api.endpoints.auth import router as auth_api_router from api.endpoints.auth import router as auth_api_router
from api.endpoints.admin import router as admin_api_router from api.endpoints.admin import router as admin_api_router
from api.endpoints.project_router import router as project_api_router
app.include_router(auth_api_router) app.include_router(auth_api_router)
app.include_router(admin_api_router) app.include_router(admin_api_router)
app.include_router(api_assets_router) app.include_router(api_assets_router)
@@ -188,6 +190,7 @@ app.include_router(api_gen_router)
app.include_router(api_album_router) app.include_router(api_album_router)
app.include_router(api_admin_router) app.include_router(api_admin_router)
app.include_router(api_auth_router) app.include_router(api_auth_router)
app.include_router(project_api_router)
# --- ХЕНДЛЕРЫ БОТА (Main Router) --- # --- ХЕНДЛЕРЫ БОТА (Main Router) ---

View File

@@ -28,6 +28,8 @@ class Asset(BaseModel):
minio_thumbnail_object_name: Optional[str] = None minio_thumbnail_object_name: Optional[str] = None
thumbnail: Optional[bytes] = None thumbnail: Optional[bytes] = None
tags: List[str] = [] tags: List[str] = []
created_by: Optional[str] = None
project_id: Optional[str] = None
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))

View File

@@ -5,11 +5,13 @@ from pydantic_core.core_schema import computed_field
class Character(BaseModel): class Character(BaseModel):
id: str | None id: Optional[str] = None
name: str name: str
avatar_image: Optional[str] = None avatar_image: Optional[str] = None
character_image_data: Optional[bytes] = None character_image_data: Optional[bytes] = None
character_image_doc_tg_id: str character_image_doc_tg_id: Optional[str] = None
character_image_tg_id: str | None character_image_tg_id: Optional[str] = None
character_bio: str character_bio: Optional[str] = None
created_by: Optional[str] = None
project_id: Optional[str] = None

View File

@@ -35,7 +35,8 @@ class Generation(BaseModel):
output_token_usage: Optional[int] = None output_token_usage: Optional[int] = None
is_deleted: bool = False is_deleted: bool = False
album_id: Optional[str] = None album_id: Optional[str] = None
created_by: Optional[str] = None created_by: Optional[str] = None # Stores User ID (Telegram ID or Web User ObjectId)
project_id: Optional[str] = None
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))

12
models/Project.py Normal file
View File

@@ -0,0 +1,12 @@
from datetime import datetime
from typing import List, Optional
from pydantic import BaseModel, Field
class Project(BaseModel):
id: Optional[str] = None
name: str
description: Optional[str] = None
owner_id: str
members: List[str] = [] # List of User IDs
is_deleted: bool = False
created_at: datetime = Field(default_factory=datetime.now)

View File

@@ -46,7 +46,7 @@ class AssetsRepo:
res = await self.collection.insert_one(asset.model_dump()) res = await self.collection.insert_one(asset.model_dump())
return str(res.inserted_id) return str(res.inserted_id)
async def get_assets(self, asset_type: Optional[str] = None, limit: int = 10, offset: int = 0, with_data: bool = False) -> List[Asset]: async def get_assets(self, asset_type: Optional[str] = None, limit: int = 10, offset: int = 0, with_data: bool = False, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Asset]:
filter = {} filter = {}
if asset_type: if asset_type:
filter["type"] = asset_type filter["type"] = asset_type
@@ -71,6 +71,9 @@ class AssetsRepo:
# So list DOES NOT return thumbnails by default. # So list DOES NOT return thumbnails by default.
args["thumbnail"] = 0 args["thumbnail"] = 0
if project_id:
filter["project_id"] = project_id
res = await self.collection.find(filter, args).sort("created_at", -1).skip(offset).limit(limit).to_list(None) res = await self.collection.find(filter, args).sort("created_at", -1).skip(offset).limit(limit).to_list(None)
assets = [] assets = []
for doc in res: for doc in res:
@@ -157,8 +160,15 @@ class AssetsRepo:
assets.append(Asset(**doc)) assets.append(Asset(**doc))
return assets return assets
async def get_asset_count(self, character_id: Optional[str] = None) -> int: async def get_asset_count(self, character_id: Optional[str] = None, created_by: Optional[str] = None, project_id: Optional[str] = None) -> int:
return await self.collection.count_documents({"linked_char_id": character_id} if character_id else {}) filter = {}
if character_id:
filter["linked_char_id"] = character_id
if created_by:
filter["created_by"] = created_by
if project_id:
filter["project_id"] = project_id
return await self.collection.count_documents(filter)
async def get_assets_by_ids(self, asset_ids: List[str]) -> List[Asset]: async def get_assets_by_ids(self, asset_ids: List[str]) -> List[Asset]:
object_ids = [ObjectId(asset_id) for asset_id in asset_ids] object_ids = [ObjectId(asset_id) for asset_id in asset_ids]

View File

@@ -1,4 +1,4 @@
from typing import List from typing import List, Optional
from bson import ObjectId from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient from motor.motor_asyncio import AsyncIOMotorClient
@@ -12,7 +12,7 @@ class CharacterRepo:
async def add_character(self, character: Character) -> Character: async def add_character(self, character: Character) -> Character:
op = await self.collection.insert_one(character.model_dump()) op = await self.collection.insert_one(character.model_dump())
character.id = op.inserted_id character.id = str(op.inserted_id)
return character return character
async def get_character(self, character_id: str, with_image_data: bool = False) -> Character | None: async def get_character(self, character_id: str, with_image_data: bool = False) -> Character | None:
@@ -26,18 +26,25 @@ class CharacterRepo:
res["id"] = str(res.pop("_id")) res["id"] = str(res.pop("_id"))
return Character(**res) return Character(**res)
async def get_all_characters(self) -> List[Character]: async def get_all_characters(self, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Character]:
docs = await self.collection.find({}, {"character_image_data": 0}).to_list(None) filter = {}
if created_by:
filter["created_by"] = created_by
if project_id:
filter["project_id"] = project_id
characters = [] args = {"character_image_data": 0} # don't return image data for list
for doc in docs: res = await self.collection.find(filter, args).to_list(None)
# Конвертируем ObjectId в строку и кладем в поле id chars = []
for doc in res:
doc["id"] = str(doc.pop("_id")) doc["id"] = str(doc.pop("_id"))
chars.append(Character(**doc))
return chars
# Создаем объект async def update_char(self, char_id: str, character: Character) -> bool:
characters.append(Character(**doc)) result = await self.collection.update_one({"_id": ObjectId(char_id)}, {"$set": character.model_dump()})
return result.modified_count > 0
return characters async def delete_character(self, char_id: str) -> bool:
result = await self.collection.delete_one({"_id": ObjectId(char_id)})
async def update_char(self, char_id: str, character: Character) -> None: return result.deleted_count > 0
await self.collection.update_one({"_id": ObjectId(char_id)}, {"$set": character.model_dump()})

View File

@@ -5,6 +5,7 @@ from repos.char_repo import CharacterRepo
from repos.generation_repo import GenerationRepo from repos.generation_repo import GenerationRepo
from repos.user_repo import UsersRepo from repos.user_repo import UsersRepo
from repos.albums_repo import AlbumsRepo from repos.albums_repo import AlbumsRepo
from repos.project_repo import ProjectRepo
from typing import Optional from typing import Optional
@@ -16,3 +17,5 @@ class DAO:
self.assets = AssetsRepo(client, s3_adapter, db_name) self.assets = AssetsRepo(client, s3_adapter, db_name)
self.generations = GenerationRepo(client, db_name) self.generations = GenerationRepo(client, db_name)
self.albums = AlbumsRepo(client, db_name) self.albums = AlbumsRepo(client, db_name)
self.projects = ProjectRepo(client, db_name)
self.users = UsersRepo(client, db_name)

View File

@@ -25,13 +25,19 @@ class GenerationRepo:
return Generation(**res) return Generation(**res)
async def get_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None, async def get_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
limit: int = 10, offset: int = 10) -> List[Generation]: limit: int = 10, offset: int = 10, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Generation]:
filter = {"is_deleted": False} filter = {"is_deleted": False}
if character_id is not None: if character_id is not None:
filter["linked_character_id"] = character_id filter["linked_character_id"] = character_id
if status is not None: if status is not None:
filter["status"] = status filter["status"] = status
if created_by is not None:
filter["created_by"] = created_by
filter["project_id"] = None
if project_id is not None:
filter["project_id"] = project_id
res = await self.collection.find(filter).sort("created_at", -1).skip( res = await self.collection.find(filter).sort("created_at", -1).skip(
offset).limit(limit).to_list(None) offset).limit(limit).to_list(None)
generations: List[Generation] = [] generations: List[Generation] = []
@@ -40,12 +46,17 @@ class GenerationRepo:
generations.append(Generation(**generation)) generations.append(Generation(**generation))
return generations return generations
async def count_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None, album_id: Optional[str] = None) -> int: async def count_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
album_id: Optional[str] = None, created_by: Optional[str] = None, project_id: Optional[str] = None) -> int:
args = {} args = {}
if character_id is not None: if character_id is not None:
args["linked_character_id"] = character_id args["linked_character_id"] = character_id
if status is not None: if status is not None:
args["status"] = status args["status"] = status
if created_by is not None:
args["created_by"] = created_by
if project_id is not None:
args["project_id"] = project_id
return await self.collection.count_documents(args) return await self.collection.count_documents(args)
async def get_generations_by_ids(self, generation_ids: List[str]) -> List[Generation]: async def get_generations_by_ids(self, generation_ids: List[str]) -> List[Generation]:

62
repos/project_repo.py Normal file
View File

@@ -0,0 +1,62 @@
from typing import List, Optional
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient
from models.Project import Project
class ProjectRepo:
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
self.collection = client[db_name]["projects"]
async def create_project(self, project: Project) -> str:
res = await self.collection.insert_one(project.model_dump())
return str(res.inserted_id)
async def get_project(self, project_id: str) -> Optional[Project]:
if not ObjectId.is_valid(project_id):
return None
res = await self.collection.find_one({"_id": ObjectId(project_id)})
if res:
res["id"] = str(res.pop("_id"))
return Project(**res)
return None
async def get_projects_by_user(self, user_id: str) -> List[Project]:
# Find projects where user is owner OR in members
filter = {
"$or": [
{"owner_id": user_id},
{"members": user_id}
],
"is_deleted": False
}
cursor = self.collection.find(filter).sort("created_at", -1)
projects = []
async for doc in cursor:
doc["id"] = str(doc.pop("_id"))
projects.append(Project(**doc))
return projects
async def add_member(self, project_id: str, user_id: str) -> bool:
res = await self.collection.update_one(
{"_id": ObjectId(project_id)},
{"$addToSet": {"members": user_id}}
)
return res.modified_count > 0
async def remove_member(self, project_id: str, user_id: str) -> bool:
res = await self.collection.update_one(
{"_id": ObjectId(project_id)},
{"$pull": {"members": user_id}}
)
return res.modified_count > 0
async def update_project(self, project_id: str, updates: dict) -> bool:
res = await self.collection.update_one(
{"_id": ObjectId(project_id)},
{"$set": updates}
)
return res.modified_count > 0
async def delete_project(self, project_id: str) -> bool:
res = await self.collection.update_one({"_id": ObjectId(project_id)}, {"$set": {"is_deleted": True}})
return res.modified_count > 0

View File

@@ -19,10 +19,14 @@ class UsersRepo:
self.collection = client[db_name]["users"] self.collection = client[db_name]["users"]
async def get_user(self, user_id: int): async def get_user(self, user_id: int):
return await self.collection.find_one({"user_id": user_id}) user = await self.collection.find_one({"user_id": user_id})
user["id"] = str(user["_id"])
return user
async def get_user_by_username(self, username: str): async def get_user_by_username(self, username: str):
return await self.collection.find_one({"username": username}) user = await self.collection.find_one({"username": username})
user["id"] = str(user["_id"])
return user
async def create_user(self, username: str, password: str, full_name: Optional[str] = None): async def create_user(self, username: str, password: str, full_name: Optional[str] = None):
"""Создает нового пользователя с username/паролем""" """Создает нового пользователя с username/паролем"""
@@ -38,15 +42,22 @@ class UsersRepo:
"created_at": datetime.now(), "created_at": datetime.now(),
"is_email_user": False, # Теперь это просто "обычный" юзер, не телеграм (хотя поле можно переименовать) "is_email_user": False, # Теперь это просто "обычный" юзер, не телеграм (хотя поле можно переименовать)
"is_web_user": True, "is_web_user": True,
"is_admin": False "is_admin": False,
"project_ids": [],
"current_project_id": None
} }
result = await self.collection.insert_one(user_doc) result = await self.collection.insert_one(user_doc)
return await self.collection.find_one({"_id": result.inserted_id}) user = await self.collection.find_one({"_id": result.inserted_id})
user["id"] = str(user["_id"])
return user
async def get_pending_users(self): async def get_pending_users(self):
"""Возвращает список пользователей со статусом PENDING""" """Возвращает список пользователей со статусом PENDING"""
cursor = self.collection.find({"status": UserStatus.PENDING}) cursor = self.collection.find({"status": UserStatus.PENDING})
return await cursor.to_list(length=100) users = await cursor.to_list(length=100)
for user in users:
user["id"] = str(user["_id"])
return users
async def approve_user(self, username: str): async def approve_user(self, username: str):
await self.collection.update_one( await self.collection.update_one(

View File

@@ -63,7 +63,8 @@ async def new_char_bio(message: Message, state: FSMContext, dao: DAO, bot: Bot):
character_image_data=file_io.read(), character_image_data=file_io.read(),
character_image_tg_id=None, character_image_tg_id=None,
character_image_doc_tg_id=file_id, character_image_doc_tg_id=file_id,
character_bio=bio character_bio=bio,
created_by=str(message.from_user.id)
) )
file_io.close() file_io.close()

View File

@@ -51,7 +51,7 @@ async def generate_image_cmd(message: Message, state: FSMContext, dao: DAO, gemi
await wait_msg.delete() await wait_msg.delete()
doc = await message.answer_document(res[0], caption="Generated result 💫") doc = await message.answer_document(res[0], caption="Generated result 💫")
await dao.assets.create_asset(Asset(id=None, name=res[0].filename, type=AssetType.GENERATED, content_type=AssetContentType.IMAGE, data=res[0].data, await dao.assets.create_asset(Asset(id=None, name=res[0].filename, type=AssetType.GENERATED, content_type=AssetContentType.IMAGE, data=res[0].data,
tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None, linked_char_id=None)) tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None, linked_char_id=None, created_by=str(message.from_user.id)))
@router.message(Command("gen_mode")) @router.message(Command("gen_mode"))
@@ -259,7 +259,8 @@ async def handle_album(
doc = await message.answer_document(file, caption="✨ Generated result") doc = await message.answer_document(file, caption="✨ Generated result")
await dao.assets.create_asset(Asset(id=None, name=file.filename, type=AssetType.GENERATED, content_type=AssetContentType.IMAGE, data=file.data, await dao.assets.create_asset(Asset(id=None, name=file.filename, type=AssetType.GENERATED, content_type=AssetContentType.IMAGE, data=file.data,
tg_doc_file_id = doc.document.file_id, tg_photo_file_id = None, tg_doc_file_id = doc.document.file_id, tg_photo_file_id = None,
linked_char_id = data["char_id"])) linked_char_id = data["char_id"],
created_by=str(message.from_user.id)))
else: else:
await message.answer("❌ Генерация не вернула изображений.") await message.answer("❌ Генерация не вернула изображений.")
await gen_mode_base_msg(message=message, state=state, dao=dao, call_type="start") await gen_mode_base_msg(message=message, state=state, dao=dao, call_type="start")
@@ -314,7 +315,8 @@ async def gen_mode_start(
doc = await message.answer_document(file, caption="✨ Generated result") doc = await message.answer_document(file, caption="✨ Generated result")
await dao.assets.create_asset(Asset(id=None, name=file.filename, type=AssetType.GENERATED, content_type=AssetContentType.IMAGE, data=file.data, await dao.assets.create_asset(Asset(id=None, name=file.filename, type=AssetType.GENERATED, content_type=AssetContentType.IMAGE, data=file.data,
tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None, tg_doc_file_id=doc.document.file_id, tg_photo_file_id=None,
linked_char_id=data["char_id"])) linked_char_id=data["char_id"],
created_by=str(message.from_user.id)))
else: else:
await message.answer("❌ Ничего не сгенерировалось.") await message.answer("❌ Ничего не сгенерировалось.")

View File

@@ -0,0 +1,101 @@
import pytest
from fastapi.testclient import TestClient
from unittest.mock import MagicMock
from motor.motor_asyncio import AsyncIOMotorClient
import os
import asyncio
from main import app
from api.endpoints.auth import get_current_user
from api.dependency import get_dao
from repos.dao import DAO
from models.Character import Character
# Config for test DB
MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://admin:super_secure_password@31.59.58.220:27017")
DB_NAME = "bot_db_test_chars"
# Mock User
MOCK_USER_ID = "507f1f77bcf86cd799439011"
MOCK_USER = {
"_id": MOCK_USER_ID,
"username": "testuser",
"is_admin": False,
"status": "allowed"
}
# Override get_current_user to bypass auth
def mock_get_current_user():
return MOCK_USER
app.dependency_overrides[get_current_user] = mock_get_current_user
# Setup Real DAO with Test DB
client_mongo = AsyncIOMotorClient(MONGO_HOST)
dao = DAO(client_mongo, db_name=DB_NAME)
def mock_get_dao():
return dao
app.dependency_overrides[get_dao] = mock_get_dao
client = TestClient(app)
@pytest.fixture(scope="module", autouse=True)
def setup_teardown():
# Setup: Ensure clean state
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(client_mongo[DB_NAME]["characters"].drop())
yield
# Teardown
loop.run_until_complete(client_mongo[DB_NAME]["characters"].drop())
loop.close()
def test_character_crud_flow():
# 1. Create Character
create_payload = {
"name": "Test Character",
"character_bio": "A bio for test character",
"character_image_doc_tg_id": "file_123",
"avatar_image": "http://example.com/avatar.jpg"
}
response = client.post("/api/characters/", json=create_payload)
assert response.status_code == 200, response.text
char_data = response.json()
assert char_data["name"] == create_payload["name"]
assert char_data["created_by"] == MOCK_USER_ID
char_id = char_data["id"]
assert char_id is not None
# 2. Get Character
response = client.get(f"/api/characters/{char_id}")
assert response.status_code == 200
assert response.json()["id"] == char_id
# 3. Update Character
update_payload = {
"name": "Updated Name",
"character_bio": "Updated bio"
}
response = client.put(f"/api/characters/{char_id}", json=update_payload)
assert response.status_code == 200
updated_data = response.json()
assert updated_data["name"] == "Updated Name"
assert updated_data["character_bio"] == "Updated bio"
# Verify update persistent
response = client.get(f"/api/characters/{char_id}")
assert response.json()["name"] == "Updated Name"
# 4. Delete Character
response = client.delete(f"/api/characters/{char_id}")
assert response.status_code == 204
# Verify deletion
response = client.get(f"/api/characters/{char_id}")
assert response.status_code == 404, "Deleted character should return 404"

View File

@@ -0,0 +1,64 @@
import os
import pytest
from fastapi.testclient import TestClient
from unittest.mock import MagicMock
# 1. Set Auth Bypass and Test Config
os.environ["DB_NAME"] = "bot_db_test_integration"
# We keep MONGO_HOST as is (it works in verified script)
# 2. Import app AFTER setting env
from main import app
from api.endpoints.auth import get_current_user
# 3. Override Auth
MOCK_USER_ID = "507f1f77bcf86cd799439011"
MOCK_USER = {
"_id": MOCK_USER_ID,
"username": "testuser",
"is_admin": False,
"status": "allowed",
"project_ids": []
}
def mock_get_current_user():
return MOCK_USER
app.dependency_overrides[get_current_user] = mock_get_current_user
client = TestClient(app)
def test_character_crud_lifecycle():
# 1. Create
create_payload = {
"name": "Integration Test Char",
"character_bio": "Testing with real app structure",
"character_image_doc_tg_id": "doc_123",
"avatar_image": "http://example.com/img.jpg"
}
response = client.post("/api/characters/", json=create_payload)
assert response.status_code == 200, response.text
char_data = response.json()
assert char_data["name"] == create_payload["name"]
char_id = char_data["id"]
# 2. Get
response = client.get(f"/api/characters/{char_id}")
assert response.status_code == 200
assert response.json()["id"] == char_id
# 3. Update
update_payload = {"name": "Updated Int Name"}
response = client.put(f"/api/characters/{char_id}", json=update_payload)
assert response.status_code == 200
assert response.json()["name"] == "Updated Int Name"
# 4. Delete
response = client.delete(f"/api/characters/{char_id}")
assert response.status_code == 204
# 5. Verify Delete
response = client.get(f"/api/characters/{char_id}")
assert response.status_code == 404