feat: Implement external generation import API secured by HMAC-SHA256 signature verification.
This commit is contained in:
3
.env
3
.env
@@ -7,4 +7,5 @@ MINIO_ENDPOINT=http://31.59.58.220:9000
|
||||
MINIO_ACCESS_KEY=admin
|
||||
MINIO_SECRET_KEY=SuperSecretPassword123!
|
||||
MINIO_BUCKET=ai-char
|
||||
MODE=production
|
||||
MODE=production
|
||||
EXTERNAL_API_SECRET=Gt9TyQ8OAYhcELh2YCbKjdHLflZGufKHJZcG338MQDW
|
||||
4
.vscode/launch.json
vendored
4
.vscode/launch.json
vendored
@@ -10,7 +10,9 @@
|
||||
"main:app",
|
||||
"--reload",
|
||||
"--port",
|
||||
"8090"
|
||||
"8090",
|
||||
"--host",
|
||||
"0.0.0.0"
|
||||
],
|
||||
"jinja": true,
|
||||
"justMyCode": true
|
||||
|
||||
Binary file not shown.
@@ -1,6 +1,6 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, UploadFile, File, Form
|
||||
from fastapi import APIRouter, UploadFile, File, Form, Header, HTTPException
|
||||
from fastapi.params import Depends
|
||||
from starlette.requests import Request
|
||||
|
||||
@@ -20,13 +20,14 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
from api.endpoints.auth import get_current_user
|
||||
|
||||
router = APIRouter(prefix='/api/generations', tags=["Generation"], dependencies=[Depends(get_current_user)])
|
||||
router = APIRouter(prefix='/api/generations', tags=["Generation"])
|
||||
|
||||
|
||||
@router.post("/prompt-assistant", response_model=PromptResponse)
|
||||
async def ask_prompt_assistant(prompt_request: PromptRequest, request: Request,
|
||||
generation_service: GenerationService = Depends(
|
||||
get_generation_service)) -> PromptResponse:
|
||||
get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)) -> PromptResponse:
|
||||
logger.info(f"ask_prompt_assistant called with prompt length: {len(prompt_request.prompt)}. Linked assets: {len(prompt_request.linked_assets) if prompt_request.linked_assets else 0}")
|
||||
generated_prompt = await generation_service.ask_prompt_assistant(prompt_request.prompt, prompt_request.linked_assets)
|
||||
return PromptResponse(prompt=generated_prompt)
|
||||
@@ -36,7 +37,8 @@ async def ask_prompt_assistant(prompt_request: PromptRequest, request: Request,
|
||||
async def prompt_from_image(
|
||||
prompt: Optional[str] = Form(None),
|
||||
images: List[UploadFile] = File(...),
|
||||
generation_service: GenerationService = Depends(get_generation_service)
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
) -> PromptResponse:
|
||||
logger.info(f"prompt_from_image called. Images count: {len(images)}. Prompt provided: {bool(prompt)}")
|
||||
images_bytes = []
|
||||
@@ -111,8 +113,59 @@ async def get_running_generations(request: Request,
|
||||
return await generation_service.get_running_generations(user_id=user_id_filter, project_id=project_id)
|
||||
|
||||
|
||||
@router.delete("/{generation_id}", status_code=status.HTTP_204_NO_CONTENT, dependencies=[Depends(get_current_user)])
|
||||
async def delete_generation(generation_id: str, generation_service: GenerationService = Depends(get_generation_service)):
|
||||
|
||||
|
||||
@router.post("/import", response_model=GenerationResponse)
|
||||
async def import_external_generation(
|
||||
request: Request,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
x_signature: str = Header(..., alias="X-Signature")
|
||||
) -> GenerationResponse:
|
||||
"""
|
||||
Import a generation from an external source.
|
||||
Requires server-to-server authentication via HMAC signature.
|
||||
"""
|
||||
import os
|
||||
from utils.external_auth import verify_signature
|
||||
from api.models.ExternalGenerationDTO import ExternalGenerationRequest
|
||||
|
||||
logger.info("import_external_generation called")
|
||||
|
||||
# Get raw request body for signature verification
|
||||
body = await request.body()
|
||||
|
||||
# Verify signature
|
||||
secret = os.getenv("EXTERNAL_API_SECRET")
|
||||
if not secret:
|
||||
logger.error("EXTERNAL_API_SECRET not configured")
|
||||
raise HTTPException(status_code=500, detail="Server configuration error")
|
||||
|
||||
if not verify_signature(body, x_signature, secret):
|
||||
logger.warning("Invalid signature for external generation import")
|
||||
raise HTTPException(status_code=401, detail="Invalid signature")
|
||||
|
||||
# Parse request body
|
||||
import json
|
||||
try:
|
||||
data = json.loads(body.decode('utf-8'))
|
||||
external_gen = ExternalGenerationRequest(**data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse request body: {e}")
|
||||
raise HTTPException(status_code=400, detail=f"Invalid request body: {str(e)}")
|
||||
|
||||
# Import generation
|
||||
try:
|
||||
generation = await generation_service.import_external_generation(external_gen)
|
||||
return GenerationResponse(**generation.model_dump())
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to import external generation: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Import failed: {str(e)}")
|
||||
|
||||
|
||||
@router.delete("/{generation_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_generation(generation_id: str,
|
||||
generation_service: GenerationService = Depends(get_generation_service),
|
||||
current_user: dict = Depends(get_current_user)):
|
||||
logger.info(f"delete_generation called for ID: {generation_id}")
|
||||
deleted = await generation_service.delete_generation(generation_id)
|
||||
if not deleted:
|
||||
|
||||
37
api/models/ExternalGenerationDTO.py
Normal file
37
api/models/ExternalGenerationDTO.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from models.enums import AspectRatios, Quality
|
||||
|
||||
|
||||
class ExternalGenerationRequest(BaseModel):
|
||||
"""Request model for importing external generations."""
|
||||
|
||||
prompt: str
|
||||
tech_prompt: Optional[str] = None
|
||||
|
||||
# Image can be provided as base64 string OR URL (one must be provided)
|
||||
image_data: Optional[str] = Field(None, description="Base64-encoded image data")
|
||||
image_url: Optional[str] = Field(None, description="URL to download image from")
|
||||
|
||||
# Generation metadata
|
||||
aspect_ratio: AspectRatios = AspectRatios.NINESIXTEEN
|
||||
quality: Quality = Quality.ONEK
|
||||
|
||||
# Optional linking
|
||||
linked_character_id: Optional[str] = None
|
||||
created_by: str = Field(..., description="User ID from external system")
|
||||
project_id: Optional[str] = None
|
||||
|
||||
# Performance metrics
|
||||
execution_time_seconds: Optional[float] = None
|
||||
api_execution_time_seconds: Optional[float] = None
|
||||
token_usage: Optional[int] = None
|
||||
input_token_usage: Optional[int] = None
|
||||
output_token_usage: Optional[int] = None
|
||||
|
||||
def validate_image_source(self):
|
||||
"""Ensure at least one image source is provided."""
|
||||
if not self.image_data and not self.image_url:
|
||||
raise ValueError("Either image_data or image_url must be provided")
|
||||
if self.image_data and self.image_url:
|
||||
raise ValueError("Only one of image_data or image_url should be provided")
|
||||
Binary file not shown.
@@ -1,9 +1,11 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
import base64
|
||||
from datetime import datetime, UTC
|
||||
from typing import List, Optional, Tuple, Any, Dict
|
||||
from io import BytesIO
|
||||
import httpx
|
||||
|
||||
from aiogram import Bot
|
||||
from aiogram.types import BufferedInputFile
|
||||
@@ -158,16 +160,17 @@ class GenerationService:
|
||||
# 2. Получаем ассеты-референсы (если они есть)
|
||||
reference_assets: List[Asset] = []
|
||||
media_group_bytes: List[bytes] = []
|
||||
generation_prompt = f"""
|
||||
generation_prompt = generation.prompt
|
||||
# generation_prompt = f"""
|
||||
|
||||
Create detailed image of character in scene.
|
||||
# Create detailed image of character in scene.
|
||||
|
||||
SCENE DESCRIPTION: {generation.prompt}
|
||||
# SCENE DESCRIPTION: {generation.prompt}
|
||||
|
||||
Rules:
|
||||
- Integrate the character's appearance naturally into the scene description.
|
||||
- Focus on lighting, texture, and composition.
|
||||
"""
|
||||
# Rules:
|
||||
# - Integrate the character's appearance naturally into the scene description.
|
||||
# - Focus on lighting, texture, and composition.
|
||||
# """
|
||||
if generation.linked_character_id is not None:
|
||||
char_info = await self.dao.chars.get_character(generation.linked_character_id)
|
||||
if char_info is None:
|
||||
@@ -331,6 +334,99 @@ class GenerationService:
|
||||
logger.error(f"Error in progress simulation: {e}")
|
||||
|
||||
|
||||
|
||||
|
||||
async def import_external_generation(self, external_gen) -> Generation:
|
||||
"""
|
||||
Import a generation from an external source.
|
||||
|
||||
Args:
|
||||
external_gen: ExternalGenerationRequest with generation data and image
|
||||
|
||||
Returns:
|
||||
Created Generation object
|
||||
"""
|
||||
from api.models.ExternalGenerationDTO import ExternalGenerationRequest
|
||||
|
||||
# Validate image source
|
||||
external_gen.validate_image_source()
|
||||
|
||||
logger.info(f"Importing external generation for user: {external_gen.created_by}")
|
||||
|
||||
# 1. Process image (download or decode)
|
||||
image_bytes = None
|
||||
|
||||
if external_gen.image_url:
|
||||
# Download image from URL
|
||||
logger.info(f"Downloading image from URL: {external_gen.image_url}")
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(external_gen.image_url, timeout=30.0)
|
||||
response.raise_for_status()
|
||||
image_bytes = response.content
|
||||
elif external_gen.image_data:
|
||||
# Decode base64 image
|
||||
logger.info("Decoding base64 image data")
|
||||
image_bytes = base64.b64decode(external_gen.image_data)
|
||||
|
||||
if not image_bytes:
|
||||
raise ValueError("Failed to process image data")
|
||||
|
||||
# 2. Generate thumbnail
|
||||
from utils.image_utils import create_thumbnail
|
||||
thumbnail_bytes = await asyncio.to_thread(create_thumbnail, image_bytes)
|
||||
|
||||
# 3. Save to S3
|
||||
filename = f"external/{external_gen.created_by}/{datetime.now().strftime('%Y%m%d_%H%M%S')}_{random.randint(1000, 9999)}.png"
|
||||
await self.s3_adapter.upload_file(filename, image_bytes, content_type="image/png")
|
||||
|
||||
# 4. Create Asset
|
||||
new_asset = Asset(
|
||||
name=f"External_Generated_{external_gen.linked_character_id or 'no_char'}",
|
||||
type=AssetType.GENERATED,
|
||||
content_type=AssetContentType.IMAGE,
|
||||
linked_char_id=external_gen.linked_character_id,
|
||||
data=None, # Not storing bytes in DB
|
||||
minio_object_name=filename,
|
||||
minio_bucket=self.s3_adapter.bucket_name,
|
||||
thumbnail=thumbnail_bytes,
|
||||
created_by=external_gen.created_by,
|
||||
project_id=external_gen.project_id
|
||||
)
|
||||
|
||||
asset_id = await self.dao.assets.create_asset(new_asset)
|
||||
new_asset.id = str(asset_id)
|
||||
|
||||
logger.info(f"Created asset {asset_id} for external generation")
|
||||
|
||||
# 5. Create Generation record
|
||||
generation = Generation(
|
||||
status=GenerationStatus.DONE,
|
||||
linked_character_id=external_gen.linked_character_id,
|
||||
aspect_ratio=external_gen.aspect_ratio,
|
||||
quality=external_gen.quality,
|
||||
prompt=external_gen.prompt,
|
||||
tech_prompt=external_gen.tech_prompt,
|
||||
result_list=[new_asset.id],
|
||||
result=new_asset.id,
|
||||
progress=100,
|
||||
execution_time_seconds=external_gen.execution_time_seconds,
|
||||
api_execution_time_seconds=external_gen.api_execution_time_seconds,
|
||||
token_usage=external_gen.token_usage,
|
||||
input_token_usage=external_gen.input_token_usage,
|
||||
output_token_usage=external_gen.output_token_usage,
|
||||
created_by=external_gen.created_by,
|
||||
project_id=external_gen.project_id,
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC)
|
||||
)
|
||||
|
||||
gen_id = await self.dao.generations.create_generation(generation)
|
||||
generation.id = gen_id
|
||||
|
||||
logger.info(f"Created generation {gen_id} from external source")
|
||||
|
||||
return generation
|
||||
|
||||
async def delete_generation(self, generation_id: str) -> bool:
|
||||
"""
|
||||
Soft delete generation by marking it as deleted.
|
||||
|
||||
63
tests/test_external_import.py
Executable file
63
tests/test_external_import.py
Executable file
@@ -0,0 +1,63 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for external generation import API.
|
||||
This script demonstrates how to call the import endpoint with proper HMAC signature.
|
||||
"""
|
||||
|
||||
import hmac
|
||||
import hashlib
|
||||
import json
|
||||
import requests
|
||||
import base64
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Configuration
|
||||
API_URL = "http://localhost:8090/api/generations/import"
|
||||
SECRET = os.getenv("EXTERNAL_API_SECRET", "your_super_secret_key_change_this_in_production")
|
||||
|
||||
# Sample generation data
|
||||
generation_data = {
|
||||
"prompt": "A beautiful sunset over mountains",
|
||||
"tech_prompt": "High quality landscape photography",
|
||||
"image_url": "https://picsum.photos/512/512", # Sample image URL
|
||||
# OR use base64:
|
||||
# "image_data": "base64_encoded_image_string_here",
|
||||
"aspect_ratio": "9:16",
|
||||
"quality": "1k",
|
||||
"created_by": "external_user_123",
|
||||
"execution_time_seconds": 5.2,
|
||||
"token_usage": 1000,
|
||||
"input_token_usage": 200,
|
||||
"output_token_usage": 800
|
||||
}
|
||||
|
||||
# Convert to JSON
|
||||
body = json.dumps(generation_data).encode('utf-8')
|
||||
|
||||
# Compute HMAC signature
|
||||
signature = hmac.new(
|
||||
SECRET.encode('utf-8'),
|
||||
body,
|
||||
hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
# Make request
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"X-Signature": signature
|
||||
}
|
||||
|
||||
print(f"Sending request to {API_URL}")
|
||||
print(f"Signature: {signature}")
|
||||
|
||||
try:
|
||||
response = requests.post(API_URL, data=body, headers=headers)
|
||||
print(f"\nStatus Code: {response.status_code}")
|
||||
print(f"Response: {json.dumps(response.json(), indent=2)}")
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
if hasattr(e, 'response'):
|
||||
print(f"Response text: {e.response.text}")
|
||||
46
utils/external_auth.py
Normal file
46
utils/external_auth.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import hmac
|
||||
import hashlib
|
||||
import os
|
||||
from fastapi import Header, HTTPException
|
||||
from typing import Optional
|
||||
|
||||
def verify_signature(body: bytes, signature: str, secret: str) -> bool:
|
||||
"""
|
||||
Verify HMAC-SHA256 signature.
|
||||
|
||||
Args:
|
||||
body: Raw request body bytes
|
||||
signature: Signature from X-Signature header
|
||||
secret: Shared secret key
|
||||
|
||||
Returns:
|
||||
True if signature is valid, False otherwise
|
||||
"""
|
||||
expected_signature = hmac.new(
|
||||
secret.encode('utf-8'),
|
||||
body,
|
||||
hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
return hmac.compare_digest(signature, expected_signature)
|
||||
|
||||
|
||||
async def verify_external_signature(
|
||||
x_signature: Optional[str] = Header(None, alias="X-Signature")
|
||||
):
|
||||
"""
|
||||
FastAPI dependency to verify external API signature.
|
||||
|
||||
Raises:
|
||||
HTTPException: If signature is missing or invalid
|
||||
"""
|
||||
if not x_signature:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Missing X-Signature header"
|
||||
)
|
||||
|
||||
# Note: We'll need to access the raw request body in the endpoint
|
||||
# This dependency just validates the header exists
|
||||
# Actual signature verification happens in the endpoint
|
||||
return x_signature
|
||||
Reference in New Issue
Block a user