nsfw mark api
This commit is contained in:
@@ -20,7 +20,7 @@ class GoogleAdapter:
|
|||||||
self.client = genai.Client(api_key=api_key)
|
self.client = genai.Client(api_key=api_key)
|
||||||
|
|
||||||
# Константы моделей
|
# Константы моделей
|
||||||
self.TEXT_MODEL = "gemini-3-pro-preview"
|
self.TEXT_MODEL = "gemini-3.1-pro-preview"
|
||||||
self.IMAGE_MODEL = "gemini-3-pro-image-preview"
|
self.IMAGE_MODEL = "gemini-3-pro-image-preview"
|
||||||
|
|
||||||
def _prepare_contents(self, prompt: str, images_list: List[bytes] | None = None) -> tuple:
|
def _prepare_contents(self, prompt: str, images_list: List[bytes] | None = None) -> tuple:
|
||||||
|
|||||||
@@ -18,7 +18,8 @@ from api.models import (
|
|||||||
PromptRequest,
|
PromptRequest,
|
||||||
GenerationGroupResponse,
|
GenerationGroupResponse,
|
||||||
FinancialReport,
|
FinancialReport,
|
||||||
ExternalGenerationRequest
|
ExternalGenerationRequest,
|
||||||
|
NsfwRequest
|
||||||
)
|
)
|
||||||
from api.service.generation_service import GenerationService
|
from api.service.generation_service import GenerationService
|
||||||
from repos.dao import DAO
|
from repos.dao import DAO
|
||||||
@@ -192,6 +193,33 @@ async def toggle_like(
|
|||||||
return {"is_liked": is_liked}
|
return {"is_liked": is_liked}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{generation_id}/nsfw", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
async def mark_generation_nsfw(
|
||||||
|
generation_id: str,
|
||||||
|
request: NsfwRequest,
|
||||||
|
generation_service: GenerationService = Depends(get_generation_service),
|
||||||
|
current_user: dict = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
gen = await generation_service.get_generation(generation_id, current_user_id=str(current_user["_id"]))
|
||||||
|
if not gen:
|
||||||
|
raise HTTPException(status_code=404, detail="Generation not found")
|
||||||
|
|
||||||
|
if gen.created_by != str(current_user["_id"]):
|
||||||
|
is_member = False
|
||||||
|
if gen.project_id:
|
||||||
|
project = await generation_service.dao.projects.get_project(gen.project_id)
|
||||||
|
if project and str(current_user["_id"]) in project.members:
|
||||||
|
is_member = True
|
||||||
|
|
||||||
|
if not is_member:
|
||||||
|
raise HTTPException(status_code=403, detail="Access denied")
|
||||||
|
|
||||||
|
if not await generation_service.dao.generations.mark_nsfw(generation_id, request.is_nsfw):
|
||||||
|
raise HTTPException(status_code=404, detail="Generation not found or already in the requested state")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
@router.post("/import", response_model=GenerationResponse)
|
@router.post("/import", response_model=GenerationResponse)
|
||||||
async def import_external_generation(
|
async def import_external_generation(
|
||||||
request: Request,
|
request: Request,
|
||||||
|
|||||||
@@ -23,6 +23,10 @@ class GenerationRequest(BaseModel):
|
|||||||
count: int = Field(default=1, ge=1, le=10)
|
count: int = Field(default=1, ge=1, le=10)
|
||||||
|
|
||||||
|
|
||||||
|
class NsfwRequest(BaseModel):
|
||||||
|
is_nsfw: bool
|
||||||
|
|
||||||
|
|
||||||
class GenerationsResponse(BaseModel):
|
class GenerationsResponse(BaseModel):
|
||||||
generations: List["GenerationResponse"]
|
generations: List["GenerationResponse"]
|
||||||
total_count: int
|
total_count: int
|
||||||
|
|||||||
@@ -131,6 +131,16 @@ class GenerationRepo:
|
|||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
async def mark_nsfw(self, generation_id: str, is_nsfw: bool) -> bool:
|
||||||
|
if not ObjectId.is_valid(generation_id):
|
||||||
|
return False
|
||||||
|
|
||||||
|
res = await self.collection.update_one(
|
||||||
|
{"_id": ObjectId(generation_id)},
|
||||||
|
{"$set": {"nsfw": is_nsfw}}
|
||||||
|
)
|
||||||
|
return res.modified_count > 0
|
||||||
|
|
||||||
async def get_usage_stats(self, created_by: Optional[str] = None, project_id: Optional[str] = None) -> dict:
|
async def get_usage_stats(self, created_by: Optional[str] = None, project_id: Optional[str] = None) -> dict:
|
||||||
"""
|
"""
|
||||||
Calculates usage statistics (runs, tokens, cost) using MongoDB aggregation.
|
Calculates usage statistics (runs, tokens, cost) using MongoDB aggregation.
|
||||||
|
|||||||
Reference in New Issue
Block a user