nsfw mark api
This commit is contained in:
@@ -18,7 +18,8 @@ from api.models import (
|
||||
PromptRequest,
|
||||
GenerationGroupResponse,
|
||||
FinancialReport,
|
||||
ExternalGenerationRequest
|
||||
ExternalGenerationRequest,
|
||||
NsfwRequest
|
||||
)
|
||||
from api.service.generation_service import GenerationService
|
||||
from repos.dao import DAO
|
||||
@@ -192,6 +193,33 @@ async def toggle_like(
|
||||
return {"is_liked": is_liked}
|
||||
|
||||
|
||||
@router.post("/{generation_id}/nsfw", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def mark_generation_nsfw(
|
||||
generation_id: str,
|
||||
request: NsfwRequest,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
gen = await generation_service.get_generation(generation_id, current_user_id=str(current_user["_id"]))
|
||||
if not gen:
|
||||
raise HTTPException(status_code=404, detail="Generation not found")
|
||||
|
||||
if gen.created_by != str(current_user["_id"]):
|
||||
is_member = False
|
||||
if gen.project_id:
|
||||
project = await generation_service.dao.projects.get_project(gen.project_id)
|
||||
if project and str(current_user["_id"]) in project.members:
|
||||
is_member = True
|
||||
|
||||
if not is_member:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
if not await generation_service.dao.generations.mark_nsfw(generation_id, request.is_nsfw):
|
||||
raise HTTPException(status_code=404, detail="Generation not found or already in the requested state")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@router.post("/import", response_model=GenerationResponse)
|
||||
async def import_external_generation(
|
||||
request: Request,
|
||||
|
||||
@@ -23,6 +23,10 @@ class GenerationRequest(BaseModel):
|
||||
count: int = Field(default=1, ge=1, le=10)
|
||||
|
||||
|
||||
class NsfwRequest(BaseModel):
|
||||
is_nsfw: bool
|
||||
|
||||
|
||||
class GenerationsResponse(BaseModel):
|
||||
generations: List["GenerationResponse"]
|
||||
total_count: int
|
||||
@@ -69,4 +73,4 @@ class PromptRequest(BaseModel):
|
||||
|
||||
|
||||
class PromptResponse(BaseModel):
|
||||
prompt: str
|
||||
prompt: str
|
||||
|
||||
Reference in New Issue
Block a user