diff --git a/adapters/google_adapter.py b/adapters/google_adapter.py index adcfe0a..a4cd452 100644 --- a/adapters/google_adapter.py +++ b/adapters/google_adapter.py @@ -20,7 +20,7 @@ class GoogleAdapter: self.client = genai.Client(api_key=api_key) # Константы моделей - self.TEXT_MODEL = "gemini-3-pro-preview" + self.TEXT_MODEL = "gemini-3.1-pro-preview" self.IMAGE_MODEL = "gemini-3-pro-image-preview" def _prepare_contents(self, prompt: str, images_list: List[bytes] | None = None) -> tuple: diff --git a/api/endpoints/generation_router.py b/api/endpoints/generation_router.py index 3634a17..3817647 100644 --- a/api/endpoints/generation_router.py +++ b/api/endpoints/generation_router.py @@ -18,7 +18,8 @@ from api.models import ( PromptRequest, GenerationGroupResponse, FinancialReport, - ExternalGenerationRequest + ExternalGenerationRequest, + NsfwRequest ) from api.service.generation_service import GenerationService from repos.dao import DAO @@ -192,6 +193,33 @@ async def toggle_like( return {"is_liked": is_liked} +@router.post("/{generation_id}/nsfw", status_code=status.HTTP_204_NO_CONTENT) +async def mark_generation_nsfw( + generation_id: str, + request: NsfwRequest, + generation_service: GenerationService = Depends(get_generation_service), + current_user: dict = Depends(get_current_user) +): + gen = await generation_service.get_generation(generation_id, current_user_id=str(current_user["_id"])) + if not gen: + raise HTTPException(status_code=404, detail="Generation not found") + + if gen.created_by != str(current_user["_id"]): + is_member = False + if gen.project_id: + project = await generation_service.dao.projects.get_project(gen.project_id) + if project and str(current_user["_id"]) in project.members: + is_member = True + + if not is_member: + raise HTTPException(status_code=403, detail="Access denied") + + if not await generation_service.dao.generations.mark_nsfw(generation_id, request.is_nsfw): + raise HTTPException(status_code=404, detail="Generation not found or already in the requested state") + + return None + + @router.post("/import", response_model=GenerationResponse) async def import_external_generation( request: Request, diff --git a/api/models/GenerationRequest.py b/api/models/GenerationRequest.py index 2019720..e24d9c0 100644 --- a/api/models/GenerationRequest.py +++ b/api/models/GenerationRequest.py @@ -23,6 +23,10 @@ class GenerationRequest(BaseModel): count: int = Field(default=1, ge=1, le=10) +class NsfwRequest(BaseModel): + is_nsfw: bool + + class GenerationsResponse(BaseModel): generations: List["GenerationResponse"] total_count: int @@ -69,4 +73,4 @@ class PromptRequest(BaseModel): class PromptResponse(BaseModel): - prompt: str \ No newline at end of file + prompt: str diff --git a/repos/generation_repo.py b/repos/generation_repo.py index 753f99c..b81df3b 100644 --- a/repos/generation_repo.py +++ b/repos/generation_repo.py @@ -131,6 +131,16 @@ class GenerationRepo: ) return True + async def mark_nsfw(self, generation_id: str, is_nsfw: bool) -> bool: + if not ObjectId.is_valid(generation_id): + return False + + res = await self.collection.update_one( + {"_id": ObjectId(generation_id)}, + {"$set": {"nsfw": is_nsfw}} + ) + return res.modified_count > 0 + async def get_usage_stats(self, created_by: Optional[str] = None, project_id: Optional[str] = None) -> dict: """ Calculates usage statistics (runs, tokens, cost) using MongoDB aggregation.