models + refactor
This commit is contained in:
4
.env
4
.env
@@ -8,4 +8,6 @@ MINIO_ACCESS_KEY=admin
|
|||||||
MINIO_SECRET_KEY=SuperSecretPassword123!
|
MINIO_SECRET_KEY=SuperSecretPassword123!
|
||||||
MINIO_BUCKET=ai-char
|
MINIO_BUCKET=ai-char
|
||||||
MODE=production
|
MODE=production
|
||||||
EXTERNAL_API_SECRET=Gt9TyQ8OAYhcELh2YCbKjdHLflZGufKHJZcG338MQDW
|
EXTERNAL_API_SECRET=Gt9TyQ8OAYhcELh2YCbKjdHLflZGufKHJZcG338MQDW
|
||||||
|
PROXY_SECRET_SALT=AbVJUkwTPcUWJWhPzmjXb5p4SYyKmYC5m1uVW7Dhi7o
|
||||||
|
SCHEDULER_CHARACTER_ID=69931c10721fbd539804589b
|
||||||
@@ -1,33 +0,0 @@
|
|||||||
# Project Context: AI Char Bot
|
|
||||||
|
|
||||||
## Overview
|
|
||||||
Python backend project using FastAPI and MongoDB (Motor).
|
|
||||||
Root: `/Users/xds/develop/py projects/ai-char-bot`
|
|
||||||
|
|
||||||
## Architecture
|
|
||||||
- **API Layer**: `api/endpoints` (FastAPI routers).
|
|
||||||
- **Service Layer**: `api/service` (Business logic).
|
|
||||||
- **Data Layer**: `repos` (DAOs and Repositories).
|
|
||||||
- **Models**: `models` (Domain models) and `api/models` (Request/Response DTOs).
|
|
||||||
- **Adapters**: `adapters` (External services like S3, Google Gemini).
|
|
||||||
|
|
||||||
## Coding Standards & Preferences
|
|
||||||
- **Type Hinting**: Use `Type | None` instead of `Optional[Type]` (Python 3.10+ style).
|
|
||||||
- **Async/Await**: Extensive use of `asyncio` and asynchronous DB drivers.
|
|
||||||
- **Error Handling**:
|
|
||||||
- Repositories should return `None` if an entity is not found (e.g., `toggle_like`).
|
|
||||||
- Services/Routers handle `HTTPException`.
|
|
||||||
|
|
||||||
## Key Features & Implementation Details
|
|
||||||
- **Generations**:
|
|
||||||
- Managed by `GenerationService` and `GenerationRepo`.
|
|
||||||
- `toggle_like` returns `bool | None` (True=Liked, False=Unliked, None=Not Found).
|
|
||||||
- `get_generations` requires `current_user_id` to correctly calculate `is_liked`.
|
|
||||||
- **Ideas**:
|
|
||||||
- Managed by `IdeaService` and `IdeaRepo`.
|
|
||||||
- Can have linked generations.
|
|
||||||
- When fetching generations for an idea, ensure `current_user_id` is passed to `GenerationService`.
|
|
||||||
|
|
||||||
## Recent Changes
|
|
||||||
- Refactored `toggle_like` to handle non-existent generations and return `bool | None`.
|
|
||||||
- Updated `IdeaRouter` to pass `current_user_id` when fetching generations to ensure `is_liked` flag is correct.
|
|
||||||
100
adapters/ai_proxy_adapter.py
Normal file
100
adapters/ai_proxy_adapter.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
import logging
|
||||||
|
import io
|
||||||
|
import httpx
|
||||||
|
import hashlib
|
||||||
|
import time
|
||||||
|
from typing import List, Tuple, Dict, Any, Optional
|
||||||
|
from datetime import datetime
|
||||||
|
from models.enums import AspectRatios, Quality
|
||||||
|
from config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class AIProxyAdapter:
|
||||||
|
def __init__(self, base_url: str = "http://82.22.174.14:8001", salt: str = None):
|
||||||
|
self.base_url = base_url.rstrip("/")
|
||||||
|
self.salt = salt or settings.PROXY_SECRET_SALT
|
||||||
|
|
||||||
|
def _generate_headers(self) -> Dict[str, str]:
|
||||||
|
timestamp = int(time.time())
|
||||||
|
hash_input = f"{timestamp}{self.salt}".encode()
|
||||||
|
signature = hashlib.sha256(hash_input).hexdigest()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"X-Timestamp": str(timestamp),
|
||||||
|
"X-Signature": signature
|
||||||
|
}
|
||||||
|
|
||||||
|
async def generate_text(self, prompt: str, model: str = "gemini-3.1-pro-preview", asset_urls: List[str] | None = None) -> str:
|
||||||
|
"""
|
||||||
|
Generates text using the AI Proxy with signature verification.
|
||||||
|
"""
|
||||||
|
url = f"{self.base_url}/generate_text"
|
||||||
|
|
||||||
|
messages = [{"role": "user", "content": prompt}]
|
||||||
|
payload = {
|
||||||
|
"messages": messages,
|
||||||
|
"asset_urls": asset_urls
|
||||||
|
}
|
||||||
|
|
||||||
|
headers = self._generate_headers()
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
try:
|
||||||
|
response = await client.post(url, json=payload, headers=headers, timeout=60.0)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
if data.get("finish_reason") != "STOP":
|
||||||
|
logger.warning(f"AI Proxy generation finished with reason: {data.get('finish_reason')}")
|
||||||
|
|
||||||
|
return data.get("response") or ""
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"AI Proxy Text Error: {e}")
|
||||||
|
raise Exception(f"AI Proxy Text Error: {e}")
|
||||||
|
|
||||||
|
async def generate_image(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
aspect_ratio: AspectRatios,
|
||||||
|
quality: Quality,
|
||||||
|
model: str = "gemini-3-pro-image-preview",
|
||||||
|
asset_urls: List[str] | None = None
|
||||||
|
) -> Tuple[List[io.BytesIO], Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Generates image using the AI Proxy with signature verification.
|
||||||
|
"""
|
||||||
|
url = f"{self.base_url}/generate_image"
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"prompt": prompt,
|
||||||
|
"asset_urls": asset_urls
|
||||||
|
}
|
||||||
|
|
||||||
|
headers = self._generate_headers()
|
||||||
|
|
||||||
|
start_time = datetime.now()
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
try:
|
||||||
|
response = await client.post(url, json=payload, headers=headers, timeout=120.0)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
image_bytes = response.content
|
||||||
|
byte_arr = io.BytesIO(image_bytes)
|
||||||
|
byte_arr.name = f"{datetime.now().timestamp()}.png"
|
||||||
|
byte_arr.seek(0)
|
||||||
|
|
||||||
|
end_time = datetime.now()
|
||||||
|
api_duration = (end_time - start_time).total_seconds()
|
||||||
|
|
||||||
|
metrics = {
|
||||||
|
"api_execution_time_seconds": api_duration,
|
||||||
|
"token_usage": 0,
|
||||||
|
"input_token_usage": 0,
|
||||||
|
"output_token_usage": 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return [byte_arr], metrics
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"AI Proxy Image Error: {e}")
|
||||||
|
raise Exception(f"AI Proxy Image Error: {e}")
|
||||||
88
adapters/meta_adapter.py
Normal file
88
adapters/meta_adapter.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
META_GRAPH_VERSION = "v18.0"
|
||||||
|
META_GRAPH_BASE = f"https://graph.facebook.com/{META_GRAPH_VERSION}"
|
||||||
|
|
||||||
|
|
||||||
|
class MetaAdapter:
|
||||||
|
"""Adapter for Meta Platform API (Instagram Graph API).
|
||||||
|
|
||||||
|
Requires:
|
||||||
|
- access_token: long-lived Page or Instagram access token
|
||||||
|
- instagram_account_id: Instagram Business Account ID
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, access_token: str, instagram_account_id: str):
|
||||||
|
self.access_token = access_token
|
||||||
|
self.instagram_account_id = instagram_account_id
|
||||||
|
|
||||||
|
async def post_to_feed(self, image_url: str, caption: str) -> Optional[str]:
|
||||||
|
"""Upload image and publish to Instagram feed.
|
||||||
|
|
||||||
|
Returns the post ID on success, raises on failure.
|
||||||
|
"""
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
# Step 1: create media container
|
||||||
|
resp = await client.post(
|
||||||
|
f"{META_GRAPH_BASE}/{self.instagram_account_id}/media",
|
||||||
|
data={
|
||||||
|
"image_url": image_url,
|
||||||
|
"caption": caption,
|
||||||
|
"access_token": self.access_token,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
creation_id = resp.json().get("id")
|
||||||
|
if not creation_id:
|
||||||
|
raise ValueError(f"No creation_id from Meta API: {resp.text}")
|
||||||
|
|
||||||
|
# Step 2: publish
|
||||||
|
resp2 = await client.post(
|
||||||
|
f"{META_GRAPH_BASE}/{self.instagram_account_id}/media_publish",
|
||||||
|
data={
|
||||||
|
"creation_id": creation_id,
|
||||||
|
"access_token": self.access_token,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
resp2.raise_for_status()
|
||||||
|
post_id = resp2.json().get("id")
|
||||||
|
logger.info(f"Published to Instagram feed: {post_id}")
|
||||||
|
return post_id
|
||||||
|
|
||||||
|
async def post_to_story(self, image_url: str) -> Optional[str]:
|
||||||
|
"""Upload image and publish to Instagram story.
|
||||||
|
|
||||||
|
Returns the story ID on success, raises on failure.
|
||||||
|
"""
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
# Step 1: create story container
|
||||||
|
resp = await client.post(
|
||||||
|
f"{META_GRAPH_BASE}/{self.instagram_account_id}/media",
|
||||||
|
data={
|
||||||
|
"image_url": image_url,
|
||||||
|
"media_type": "STORIES",
|
||||||
|
"access_token": self.access_token,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
creation_id = resp.json().get("id")
|
||||||
|
if not creation_id:
|
||||||
|
raise ValueError(f"No creation_id from Meta API: {resp.text}")
|
||||||
|
|
||||||
|
# Step 2: publish
|
||||||
|
resp2 = await client.post(
|
||||||
|
f"{META_GRAPH_BASE}/{self.instagram_account_id}/media_publish",
|
||||||
|
data={
|
||||||
|
"creation_id": creation_id,
|
||||||
|
"access_token": self.access_token,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
resp2.raise_for_status()
|
||||||
|
story_id = resp2.json().get("id")
|
||||||
|
logger.info(f"Published to Instagram story: {story_id}")
|
||||||
|
return story_id
|
||||||
81
aiws.py
81
aiws.py
@@ -23,6 +23,8 @@ from api.service.album_service import AlbumService
|
|||||||
from middlewares.album import AlbumMiddleware
|
from middlewares.album import AlbumMiddleware
|
||||||
from middlewares.auth import AuthMiddleware
|
from middlewares.auth import AuthMiddleware
|
||||||
from middlewares.dao import DaoMiddleware
|
from middlewares.dao import DaoMiddleware
|
||||||
|
from scheduler.daily_scheduler import DailyScheduler
|
||||||
|
from scheduler.telegram_admin_handler import create_daily_scheduler_router
|
||||||
|
|
||||||
# Репозитории и DAO
|
# Репозитории и DAO
|
||||||
from repos.char_repo import CharacterRepo
|
from repos.char_repo import CharacterRepo
|
||||||
@@ -108,7 +110,7 @@ dp["gemini"] = gemini
|
|||||||
# 1. Роутеры без мидлварей (например, auth)
|
# 1. Роутеры без мидлварей (например, auth)
|
||||||
dp.include_router(auth_router)
|
dp.include_router(auth_router)
|
||||||
|
|
||||||
# 2. Основные роутеры
|
# 2. Основные роутеры (daily_scheduler router добавляется в lifespan)
|
||||||
main_router = Router()
|
main_router = Router()
|
||||||
dp.include_router(main_router)
|
dp.include_router(main_router)
|
||||||
dp.include_router(assets_router)
|
dp.include_router(assets_router)
|
||||||
@@ -141,6 +143,34 @@ async def start_scheduler(service: GenerationService):
|
|||||||
logger.error(f"Scheduler error: {e}")
|
logger.error(f"Scheduler error: {e}")
|
||||||
await asyncio.sleep(60) # Check every 60 seconds
|
await asyncio.sleep(60) # Check every 60 seconds
|
||||||
|
|
||||||
|
|
||||||
|
def _build_daily_scheduler() -> DailyScheduler:
|
||||||
|
"""Construct DailyScheduler; MetaAdapter is optional (needs env vars)."""
|
||||||
|
meta_adapter = None
|
||||||
|
if settings.META_ACCESS_TOKEN and settings.META_INSTAGRAM_ACCOUNT_ID:
|
||||||
|
from adapters.meta_adapter import MetaAdapter
|
||||||
|
meta_adapter = MetaAdapter(
|
||||||
|
access_token=settings.META_ACCESS_TOKEN,
|
||||||
|
instagram_account_id=settings.META_INSTAGRAM_ACCOUNT_ID,
|
||||||
|
)
|
||||||
|
logger.info("MetaAdapter initialized")
|
||||||
|
else:
|
||||||
|
logger.warning("META_ACCESS_TOKEN / META_INSTAGRAM_ACCOUNT_ID not set — Instagram publishing disabled")
|
||||||
|
|
||||||
|
if not settings.SCHEDULER_CHARACTER_ID:
|
||||||
|
logger.warning("SCHEDULER_CHARACTER_ID not set — daily scheduler will error at runtime")
|
||||||
|
|
||||||
|
return DailyScheduler(
|
||||||
|
dao=dao,
|
||||||
|
gemini=gemini,
|
||||||
|
s3_adapter=s3_adapter,
|
||||||
|
generation_service=generation_service,
|
||||||
|
bot=bot,
|
||||||
|
admin_id=ADMIN_ID,
|
||||||
|
character_id=settings.SCHEDULER_CHARACTER_ID or "",
|
||||||
|
meta_adapter=meta_adapter,
|
||||||
|
)
|
||||||
|
|
||||||
# --- LIFESPAN (Запуск FastAPI + Bot) ---
|
# --- LIFESPAN (Запуск FastAPI + Bot) ---
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
@@ -164,36 +194,39 @@ async def lifespan(app: FastAPI):
|
|||||||
|
|
||||||
print("✅ DB & DAO initialized")
|
print("✅ DB & DAO initialized")
|
||||||
|
|
||||||
# 2. ЗАПУСК БОТА (в фоне)
|
# 2. Инициализация и регистрация daily_scheduler роутера
|
||||||
# Важно: handle_signals=False, чтобы бот не перехватывал сигналы остановки у uvicorn
|
daily_scheduler = _build_daily_scheduler()
|
||||||
# Мы НЕ передаем сюда dao=..., так как он уже подключен через Middleware выше
|
dp.include_router(create_daily_scheduler_router(daily_scheduler))
|
||||||
# polling_task = asyncio.create_task(
|
print("📅 Daily scheduler router registered")
|
||||||
# dp.start_polling(bot, handle_signals=False)
|
|
||||||
# )
|
|
||||||
# print("🤖 Bot polling started")
|
|
||||||
|
|
||||||
# 3. ЗАПУСК ШЕДУЛЕРА
|
# 3. ЗАПУСК БОТА (в фоне)
|
||||||
|
# handle_signals=False — бот не перехватывает сигналы остановки у uvicorn
|
||||||
|
polling_task = asyncio.create_task(
|
||||||
|
dp.start_polling(bot, handle_signals=False)
|
||||||
|
)
|
||||||
|
print("🤖 Bot polling started")
|
||||||
|
|
||||||
|
# 4. ЗАПУСК ШЕДУЛЕРОВ
|
||||||
scheduler_task = asyncio.create_task(start_scheduler(generation_service))
|
scheduler_task = asyncio.create_task(start_scheduler(generation_service))
|
||||||
print("⏰ Scheduler started")
|
daily_scheduler_task = asyncio.create_task(daily_scheduler.run_loop())
|
||||||
|
print("⏰ Schedulers started")
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# --- SHUTDOWN ---
|
# --- SHUTDOWN ---
|
||||||
print("🛑 Shutting down...")
|
print("🛑 Shutting down...")
|
||||||
|
|
||||||
# 4. Остановка шедулера
|
# Останавливаем все фоновые задачи
|
||||||
scheduler_task.cancel()
|
for task, name in [
|
||||||
try:
|
(polling_task, "Bot polling"),
|
||||||
await scheduler_task
|
(scheduler_task, "Stale-gen scheduler"),
|
||||||
except asyncio.CancelledError:
|
(daily_scheduler_task, "Daily scheduler"),
|
||||||
print("⏰ Scheduler stopped")
|
]:
|
||||||
|
task.cancel()
|
||||||
# 3. Остановка бота
|
try:
|
||||||
# polling_task.cancel()
|
await task
|
||||||
# try:
|
except asyncio.CancelledError:
|
||||||
# await polling_task
|
print(f"⏹ {name} stopped")
|
||||||
# except asyncio.CancelledError:
|
|
||||||
# print("🤖 Bot polling stopped")
|
|
||||||
|
|
||||||
# 4. Отключение БД
|
# 4. Отключение БД
|
||||||
# Обычно Motor закрывать не обязательно при выходе, но хорошим тоном считается
|
# Обычно Motor закрывать не обязательно при выходе, но хорошим тоном считается
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from pydantic import BaseModel
|
|||||||
from repos.user_repo import UsersRepo, UserStatus
|
from repos.user_repo import UsersRepo, UserStatus
|
||||||
from api.dependency import get_dao
|
from api.dependency import get_dao
|
||||||
from repos.dao import DAO
|
from repos.dao import DAO
|
||||||
|
from models.Settings import SystemSettings
|
||||||
from utils.security import verify_password, create_access_token, ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, SECRET_KEY
|
from utils.security import verify_password, create_access_token, ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, SECRET_KEY
|
||||||
from jose import JWTError, jwt
|
from jose import JWTError, jwt
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
@@ -96,3 +97,21 @@ async def deny_user(
|
|||||||
|
|
||||||
await repo.deny_user(username)
|
await repo.deny_user(username)
|
||||||
return {"message": f"User {username} denied"}
|
return {"message": f"User {username} denied"}
|
||||||
|
|
||||||
|
@router.get("/settings", response_model=SystemSettings)
|
||||||
|
async def get_settings(
|
||||||
|
admin: Annotated[dict, Depends(get_current_admin)],
|
||||||
|
dao: Annotated[DAO, Depends(get_dao)]
|
||||||
|
):
|
||||||
|
return await dao.settings.get_settings()
|
||||||
|
|
||||||
|
@router.post("/settings")
|
||||||
|
async def update_settings(
|
||||||
|
settings: SystemSettings,
|
||||||
|
admin: Annotated[dict, Depends(get_current_admin)],
|
||||||
|
dao: Annotated[DAO, Depends(get_dao)]
|
||||||
|
):
|
||||||
|
success = await dao.settings.update_settings(settings)
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to update settings")
|
||||||
|
return {"message": "Settings updated successfully"}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from aiogram.types import BufferedInputFile
|
|||||||
|
|
||||||
from adapters.Exception import GoogleGenerationException
|
from adapters.Exception import GoogleGenerationException
|
||||||
from adapters.google_adapter import GoogleAdapter
|
from adapters.google_adapter import GoogleAdapter
|
||||||
|
from adapters.ai_proxy_adapter import AIProxyAdapter
|
||||||
from adapters.s3_adapter import S3Adapter
|
from adapters.s3_adapter import S3Adapter
|
||||||
from api.models import (
|
from api.models import (
|
||||||
FinancialReport, UsageStats, UsageByEntity,
|
FinancialReport, UsageStats, UsageByEntity,
|
||||||
@@ -72,6 +73,7 @@ class GenerationService:
|
|||||||
def __init__(self, dao: DAO, gemini: GoogleAdapter, s3_adapter: S3Adapter, bot: Optional[Bot] = None):
|
def __init__(self, dao: DAO, gemini: GoogleAdapter, s3_adapter: S3Adapter, bot: Optional[Bot] = None):
|
||||||
self.dao = dao
|
self.dao = dao
|
||||||
self.gemini = gemini
|
self.gemini = gemini
|
||||||
|
self.ai_proxy = AIProxyAdapter()
|
||||||
self.s3_adapter = s3_adapter
|
self.s3_adapter = s3_adapter
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
|
|
||||||
@@ -84,12 +86,19 @@ class GenerationService:
|
|||||||
"Understand user needs and generate best variation of prompt. ANSWER ONLY PROMPT STRING!!! "
|
"Understand user needs and generate best variation of prompt. ANSWER ONLY PROMPT STRING!!! "
|
||||||
f"USER_ENTERED_PROMPT: {prompt}"
|
f"USER_ENTERED_PROMPT: {prompt}"
|
||||||
)
|
)
|
||||||
assets_data = []
|
|
||||||
if assets:
|
|
||||||
assets_db = await self.dao.assets.get_assets_by_ids(assets)
|
|
||||||
assets_data.extend(asset.data for asset in assets_db if asset.data)
|
|
||||||
|
|
||||||
generated_prompt = await asyncio.to_thread(self.gemini.generate_text, future_prompt, model, assets_data)
|
settings = await self.dao.settings.get_settings()
|
||||||
|
if settings.use_ai_proxy:
|
||||||
|
asset_urls = await self._prepare_asset_urls(assets) if assets else None
|
||||||
|
generated_prompt = await self.ai_proxy.generate_text(future_prompt, model, asset_urls)
|
||||||
|
else:
|
||||||
|
assets_data = []
|
||||||
|
if assets:
|
||||||
|
assets_db = await self.dao.assets.get_assets_by_ids(assets)
|
||||||
|
assets_data.extend(asset.data for asset in assets_db if asset.data)
|
||||||
|
|
||||||
|
generated_prompt = await asyncio.to_thread(self.gemini.generate_text, future_prompt, model, assets_data)
|
||||||
|
|
||||||
logger.info(f"Prompt Assistant: {generated_prompt}")
|
logger.info(f"Prompt Assistant: {generated_prompt}")
|
||||||
return generated_prompt
|
return generated_prompt
|
||||||
|
|
||||||
@@ -99,6 +108,15 @@ class GenerationService:
|
|||||||
technical_prompt += f"User also provided this context: {user_prompt}. "
|
technical_prompt += f"User also provided this context: {user_prompt}. "
|
||||||
technical_prompt += "Provide ONLY the detailed prompt."
|
technical_prompt += "Provide ONLY the detailed prompt."
|
||||||
|
|
||||||
|
settings = await self.dao.settings.get_settings()
|
||||||
|
if settings.use_ai_proxy:
|
||||||
|
# Proxy doesn't support raw bytes currently.
|
||||||
|
# In a real scenario we'd upload them to a temp bucket.
|
||||||
|
# For now, we call the proxy with just the prompt,
|
||||||
|
# or we could fall back to GoogleAdapter if images are essential.
|
||||||
|
# To be safe and follow instructions to use proxy, we use it.
|
||||||
|
return await self.ai_proxy.generate_text(prompt=technical_prompt, model=model)
|
||||||
|
|
||||||
return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, model=model, images_list=images)
|
return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, model=model, images_list=images)
|
||||||
|
|
||||||
async def get_generations(self, **kwargs) -> GenerationsResponse:
|
async def get_generations(self, **kwargs) -> GenerationsResponse:
|
||||||
@@ -154,19 +172,36 @@ class GenerationService:
|
|||||||
logger.info(f"Processing generation {generation.id}. Character ID: {generation.linked_character_id}")
|
logger.info(f"Processing generation {generation.id}. Character ID: {generation.linked_character_id}")
|
||||||
|
|
||||||
# 1. Prepare input
|
# 1. Prepare input
|
||||||
media_group_bytes, generation_prompt = await self._prepare_generation_input(generation)
|
media_group_bytes, generation_prompt, asset_ids = await self._prepare_generation_input(generation)
|
||||||
|
|
||||||
# 2. Run generation with progress simulation
|
# 2. Run generation with progress simulation
|
||||||
progress_task = asyncio.create_task(self._simulate_progress(generation))
|
progress_task = asyncio.create_task(self._simulate_progress(generation))
|
||||||
try:
|
try:
|
||||||
generated_bytes_list, metrics = await generate_image_task(
|
settings = await self.dao.settings.get_settings()
|
||||||
prompt=generation_prompt,
|
if settings.use_ai_proxy:
|
||||||
media_group_bytes=media_group_bytes,
|
asset_urls = await self._prepare_asset_urls(asset_ids) if asset_ids else None
|
||||||
aspect_ratio=generation.aspect_ratio,
|
generated_images_io, metrics = await self.ai_proxy.generate_image(
|
||||||
quality=generation.quality,
|
prompt=generation_prompt,
|
||||||
model=generation.model or "gemini-3-pro-image-preview",
|
aspect_ratio=generation.aspect_ratio,
|
||||||
gemini=self.gemini
|
quality=generation.quality,
|
||||||
)
|
model=generation.model or "gemini-3-pro-image-preview",
|
||||||
|
asset_urls=asset_urls
|
||||||
|
)
|
||||||
|
generated_bytes_list = []
|
||||||
|
for img_io in generated_images_io:
|
||||||
|
img_io.seek(0)
|
||||||
|
generated_bytes_list.append(img_io.read())
|
||||||
|
img_io.close()
|
||||||
|
else:
|
||||||
|
generated_bytes_list, metrics = await generate_image_task(
|
||||||
|
prompt=generation_prompt,
|
||||||
|
media_group_bytes=media_group_bytes,
|
||||||
|
aspect_ratio=generation.aspect_ratio,
|
||||||
|
quality=generation.quality,
|
||||||
|
model=generation.model or "gemini-3-pro-image-preview",
|
||||||
|
gemini=self.gemini
|
||||||
|
)
|
||||||
|
|
||||||
self._update_generation_metrics(generation, metrics)
|
self._update_generation_metrics(generation, metrics)
|
||||||
|
|
||||||
# 3. Process results
|
# 3. Process results
|
||||||
@@ -299,36 +334,39 @@ class GenerationService:
|
|||||||
await self._handle_generation_failure(gen, e)
|
await self._handle_generation_failure(gen, e)
|
||||||
logger.exception(f"Background generation task failed for ID: {gen.id}")
|
logger.exception(f"Background generation task failed for ID: {gen.id}")
|
||||||
|
|
||||||
async def _prepare_generation_input(self, generation: Generation) -> Tuple[List[bytes], str]:
|
async def _prepare_generation_input(self, generation: Generation) -> Tuple[List[bytes], str, List[str]]:
|
||||||
media_group_bytes: List[bytes] = []
|
media_group_bytes: List[bytes] = []
|
||||||
prompt = generation.prompt
|
prompt = generation.prompt
|
||||||
|
asset_ids = []
|
||||||
|
|
||||||
# 1. Character Avatar
|
# 1. Character Avatar
|
||||||
if generation.linked_character_id:
|
if generation.linked_character_id:
|
||||||
char_info = await self.dao.chars.get_character(generation.linked_character_id)
|
char_info = await self.dao.chars.get_character(generation.linked_character_id)
|
||||||
if not char_info:
|
if not char_info:
|
||||||
raise ValueError(f"Character {generation.linked_character_id} not found")
|
raise ValueError(f"Character {generation.linked_character_id} not found")
|
||||||
|
|
||||||
if generation.use_profile_image and char_info.avatar_asset_id:
|
if generation.use_profile_image and char_info.avatar_asset_id:
|
||||||
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
|
asset_ids.append(char_info.avatar_asset_id)
|
||||||
if avatar_asset:
|
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id, with_data=True)
|
||||||
data = await self._get_asset_data_bytes(avatar_asset)
|
if avatar_asset and avatar_asset.content_type == AssetContentType.IMAGE and avatar_asset.data:
|
||||||
if data: media_group_bytes.append(data)
|
media_group_bytes.append(avatar_asset.data)
|
||||||
|
|
||||||
# 2. Reference Assets
|
# 2. Reference Assets
|
||||||
if generation.assets_list:
|
if generation.assets_list:
|
||||||
|
asset_ids.extend(generation.assets_list)
|
||||||
assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
|
assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
|
||||||
for asset in assets:
|
for asset in assets:
|
||||||
data = await self._get_asset_data_bytes(asset)
|
data = await self._load_asset_image_data(asset)
|
||||||
if data: media_group_bytes.append(data)
|
if data: media_group_bytes.append(data)
|
||||||
|
|
||||||
# 3. Environment Assets
|
# 3. Environment Assets
|
||||||
if generation.environment_id:
|
if generation.environment_id:
|
||||||
env = await self.dao.environments.get_env(generation.environment_id)
|
env = await self.dao.environments.get_env(generation.environment_id)
|
||||||
if env and env.asset_ids:
|
if env and env.asset_ids:
|
||||||
|
asset_ids.extend(env.asset_ids)
|
||||||
env_assets = await self.dao.assets.get_assets_by_ids(env.asset_ids)
|
env_assets = await self.dao.assets.get_assets_by_ids(env.asset_ids)
|
||||||
for asset in env_assets:
|
for asset in env_assets:
|
||||||
data = await self._get_asset_data_bytes(asset)
|
data = await self._load_asset_image_data(asset)
|
||||||
if data: media_group_bytes.append(data)
|
if data: media_group_bytes.append(data)
|
||||||
|
|
||||||
if media_group_bytes:
|
if media_group_bytes:
|
||||||
@@ -337,14 +375,26 @@ class GenerationService:
|
|||||||
"character's facial features and hair, environment or clothes. Maintain high fidelity to the reference identity."
|
"character's facial features and hair, environment or clothes. Maintain high fidelity to the reference identity."
|
||||||
)
|
)
|
||||||
|
|
||||||
return media_group_bytes, prompt
|
return media_group_bytes, prompt, asset_ids
|
||||||
|
|
||||||
async def _get_asset_data_bytes(self, asset: Asset) -> Optional[bytes]:
|
async def _prepare_asset_urls(self, asset_ids: List[str]) -> List[str]:
|
||||||
|
assets = await self.dao.assets.get_assets_by_ids(asset_ids)
|
||||||
|
urls = []
|
||||||
|
for asset in assets:
|
||||||
|
if asset.minio_object_name:
|
||||||
|
bucket = asset.minio_bucket or self.s3_adapter.bucket_name
|
||||||
|
urls.append(f"{bucket}/{asset.minio_object_name}")
|
||||||
|
return urls
|
||||||
|
|
||||||
|
async def _load_asset_image_data(self, asset: Asset) -> Optional[bytes]:
|
||||||
|
"""Load image bytes for an asset that was fetched without data (e.g. from get_assets_by_ids)."""
|
||||||
if asset.content_type != AssetContentType.IMAGE:
|
if asset.content_type != AssetContentType.IMAGE:
|
||||||
return None
|
return None
|
||||||
|
if asset.data:
|
||||||
|
return asset.data
|
||||||
if asset.minio_object_name:
|
if asset.minio_object_name:
|
||||||
return await self.s3_adapter.get_file(asset.minio_object_name)
|
return await self.s3_adapter.get_file(asset.minio_object_name)
|
||||||
return asset.data
|
return None
|
||||||
|
|
||||||
def _update_generation_metrics(self, generation: Generation, metrics: Dict[str, Any]):
|
def _update_generation_metrics(self, generation: Generation, metrics: Dict[str, Any]):
|
||||||
generation.api_execution_time_seconds = metrics.get("api_execution_time_seconds")
|
generation.api_execution_time_seconds = metrics.get("api_execution_time_seconds")
|
||||||
|
|||||||
10
config.py
10
config.py
@@ -24,6 +24,16 @@ class Settings(BaseSettings):
|
|||||||
# External API
|
# External API
|
||||||
EXTERNAL_API_SECRET: Optional[str] = None
|
EXTERNAL_API_SECRET: Optional[str] = None
|
||||||
|
|
||||||
|
# Daily Scheduler
|
||||||
|
SCHEDULER_CHARACTER_ID: Optional[str] = None # Character ID used for daily generation
|
||||||
|
|
||||||
|
# Meta Platform (Instagram Graph API)
|
||||||
|
META_ACCESS_TOKEN: Optional[str] = None # Long-lived page/Instagram access token
|
||||||
|
META_INSTAGRAM_ACCOUNT_ID: Optional[str] = None # Instagram Business Account ID
|
||||||
|
|
||||||
|
# AI Proxy Security
|
||||||
|
PROXY_SECRET_SALT: str = "AbVJUkwTPcUWJWhPzmjXb5p4SYyKmYC5m1uVW7Dhi7o"
|
||||||
|
|
||||||
# JWT Security
|
# JWT Security
|
||||||
SECRET_KEY: str = "CHANGE_ME_TO_A_SUPER_SECRET_KEY"
|
SECRET_KEY: str = "CHANGE_ME_TO_A_SUPER_SECRET_KEY"
|
||||||
ALGORITHM: str = "HS256"
|
ALGORITHM: str = "HS256"
|
||||||
|
|||||||
10
models/Settings.py
Normal file
10
models/Settings.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from datetime import datetime, UTC
|
||||||
|
|
||||||
|
class SystemSettings(BaseModel):
|
||||||
|
id: str = Field(default="system_settings", alias="_id")
|
||||||
|
use_ai_proxy: bool = False
|
||||||
|
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
populate_by_name = True
|
||||||
@@ -10,6 +10,7 @@ from repos.idea_repo import IdeaRepo
|
|||||||
from repos.post_repo import PostRepo
|
from repos.post_repo import PostRepo
|
||||||
from repos.environment_repo import EnvironmentRepo
|
from repos.environment_repo import EnvironmentRepo
|
||||||
from repos.inspiration_repo import InspirationRepo
|
from repos.inspiration_repo import InspirationRepo
|
||||||
|
from repos.settings_repo import SettingsRepo
|
||||||
|
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@@ -27,3 +28,4 @@ class DAO:
|
|||||||
self.posts = PostRepo(client, db_name)
|
self.posts = PostRepo(client, db_name)
|
||||||
self.environments = EnvironmentRepo(client, db_name)
|
self.environments = EnvironmentRepo(client, db_name)
|
||||||
self.inspirations = InspirationRepo(client, db_name)
|
self.inspirations = InspirationRepo(client, db_name)
|
||||||
|
self.settings = SettingsRepo(client, db_name)
|
||||||
|
|||||||
26
repos/settings_repo.py
Normal file
26
repos/settings_repo.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
from typing import Optional
|
||||||
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
from models.Settings import SystemSettings
|
||||||
|
from datetime import datetime, UTC
|
||||||
|
|
||||||
|
class SettingsRepo:
|
||||||
|
def __init__(self, client: AsyncIOMotorClient, db_name: str):
|
||||||
|
self.collection = client[db_name]["settings"]
|
||||||
|
|
||||||
|
async def get_settings(self) -> SystemSettings:
|
||||||
|
doc = await self.collection.find_one({"_id": "system_settings"})
|
||||||
|
if not doc:
|
||||||
|
# Create default settings if not exists
|
||||||
|
settings = SystemSettings()
|
||||||
|
await self.collection.insert_one(settings.model_dump(by_alias=True))
|
||||||
|
return settings
|
||||||
|
return SystemSettings(**doc)
|
||||||
|
|
||||||
|
async def update_settings(self, settings: SystemSettings) -> bool:
|
||||||
|
settings.updated_at = datetime.now(UTC)
|
||||||
|
result = await self.collection.replace_one(
|
||||||
|
{"_id": "system_settings"},
|
||||||
|
settings.model_dump(by_alias=True),
|
||||||
|
upsert=True
|
||||||
|
)
|
||||||
|
return result.modified_count > 0 or result.upserted_id is not None
|
||||||
0
scheduler/__init__.py
Normal file
0
scheduler/__init__.py
Normal file
456
scheduler/daily_scheduler.py
Normal file
456
scheduler/daily_scheduler.py
Normal file
@@ -0,0 +1,456 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone, timedelta
|
||||||
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
|
from aiogram import Bot
|
||||||
|
from aiogram.types import BufferedInputFile, InlineKeyboardButton, InlineKeyboardMarkup
|
||||||
|
|
||||||
|
from adapters.google_adapter import GoogleAdapter
|
||||||
|
from adapters.ai_proxy_adapter import AIProxyAdapter
|
||||||
|
from adapters.s3_adapter import S3Adapter
|
||||||
|
from api.service.generation_service import GenerationService
|
||||||
|
from models.Asset import Asset
|
||||||
|
from models.Generation import Generation, GenerationStatus
|
||||||
|
from models.enums import AspectRatios, ImageModel, Quality, TextModel
|
||||||
|
from repos.dao import DAO
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
MSK_TZ = timezone(timedelta(hours=3))
|
||||||
|
SCHEDULE_HOUR_MSK = 11
|
||||||
|
SCHEDULE_MINUTE_MSK = 0
|
||||||
|
|
||||||
|
# Callback data prefixes for inline keyboard buttons
|
||||||
|
CB_POST = "daily_post"
|
||||||
|
CB_REGEN_ALL = "daily_regen_all"
|
||||||
|
CB_REGEN_IMG = "daily_regen_img"
|
||||||
|
CB_REGEN_MORE = "daily_regen_more"
|
||||||
|
CB_CANCEL = "daily_cancel"
|
||||||
|
|
||||||
|
|
||||||
|
def make_admin_keyboard(generation_id: str) -> InlineKeyboardMarkup:
|
||||||
|
return InlineKeyboardMarkup(
|
||||||
|
inline_keyboard=[
|
||||||
|
[
|
||||||
|
InlineKeyboardButton(text="✅ Выложить", callback_data=f"{CB_POST}:{generation_id}"),
|
||||||
|
InlineKeyboardButton(text="❌ Отмена", callback_data=f"{CB_CANCEL}:{generation_id}"),
|
||||||
|
],
|
||||||
|
[
|
||||||
|
InlineKeyboardButton(text="🔄 Перегенерить с нуля", callback_data=f"{CB_REGEN_ALL}:{generation_id}"),
|
||||||
|
InlineKeyboardButton(text="🖼 Перегенерить изображение", callback_data=f"{CB_REGEN_IMG}:{generation_id}"),
|
||||||
|
],
|
||||||
|
[
|
||||||
|
InlineKeyboardButton(text="➕ Сгенерировать ещё 2", callback_data=f"{CB_REGEN_MORE}:{generation_id}"),
|
||||||
|
],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DailyScheduler:
|
||||||
|
"""Orchestrates the daily AI-character content generation pipeline.
|
||||||
|
|
||||||
|
Flow:
|
||||||
|
1. Generate image prompt + social caption via LLM (with character avatar).
|
||||||
|
2. Generate image via GenerationService.create_generation() (reuses existing pipeline).
|
||||||
|
3. Send to Telegram admin with action buttons.
|
||||||
|
|
||||||
|
Admin actions (inline keyboard):
|
||||||
|
- Выложить → post to Instagram feed + story via Meta API.
|
||||||
|
- Перегенерить с нуля → restart from step 1.
|
||||||
|
- Перегенерить изображение → restart from step 2 (same prompt/caption).
|
||||||
|
- Сгенерировать ещё 2 → generate 2 pose-varied images.
|
||||||
|
- Отмена → dismiss (no action).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dao: DAO,
|
||||||
|
gemini: GoogleAdapter,
|
||||||
|
s3_adapter: S3Adapter,
|
||||||
|
generation_service: GenerationService,
|
||||||
|
bot: Bot,
|
||||||
|
admin_id: int,
|
||||||
|
character_id: str,
|
||||||
|
meta_adapter=None, # Optional[MetaAdapter]
|
||||||
|
):
|
||||||
|
self.dao = dao
|
||||||
|
self.gemini = gemini
|
||||||
|
self.ai_proxy = AIProxyAdapter()
|
||||||
|
self.s3_adapter = s3_adapter
|
||||||
|
self.generation_service = generation_service
|
||||||
|
self.bot = bot
|
||||||
|
self.admin_id = admin_id
|
||||||
|
self.character_id = character_id
|
||||||
|
self.meta_adapter = meta_adapter
|
||||||
|
|
||||||
|
# Stores session state keyed by generation_id.
|
||||||
|
# Each value: {prompt, caption, asset_id, message_id, chat_id}
|
||||||
|
self.pending_sessions: Dict[str, Dict[str, Any]] = {}
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Scheduler loop
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def run_loop(self):
|
||||||
|
"""Run indefinitely, triggering daily generation at 11:00 MSK."""
|
||||||
|
logger.info("Daily scheduler loop started")
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
await self._wait_until_next_run()
|
||||||
|
logger.info("Daily scheduler: triggering daily generation")
|
||||||
|
await self.run_daily_generation()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Daily scheduler loop error: {e}", exc_info=True)
|
||||||
|
|
||||||
|
async def _wait_until_next_run(self):
|
||||||
|
now = datetime.now(MSK_TZ)
|
||||||
|
next_run = now.replace(
|
||||||
|
hour=SCHEDULE_HOUR_MSK,
|
||||||
|
minute=SCHEDULE_MINUTE_MSK,
|
||||||
|
second=0,
|
||||||
|
microsecond=0,
|
||||||
|
)
|
||||||
|
if now >= next_run:
|
||||||
|
next_run += timedelta(days=1)
|
||||||
|
wait_seconds = (next_run - now).total_seconds()
|
||||||
|
logger.info(
|
||||||
|
f"Next daily generation at {next_run.strftime('%Y-%m-%d %H:%M MSK')} "
|
||||||
|
f"(in {wait_seconds / 3600:.1f}h)"
|
||||||
|
)
|
||||||
|
await asyncio.sleep(wait_seconds)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Main generation pipeline
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def run_daily_generation(self):
|
||||||
|
"""Full pipeline: prompt → image → send to admin."""
|
||||||
|
try:
|
||||||
|
prompt, caption = await self._generate_prompt_and_caption()
|
||||||
|
logger.info(f"Prompt generated ({len(prompt)} chars), caption ({len(caption)} chars)")
|
||||||
|
|
||||||
|
generation, asset = await self._generate_image_and_save(prompt)
|
||||||
|
logger.info(f"Generation done: id={generation.id}, asset={asset.id}")
|
||||||
|
|
||||||
|
await self._send_to_admin(generation, asset, prompt, caption)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Daily generation pipeline failed: {e}", exc_info=True)
|
||||||
|
try:
|
||||||
|
await self.bot.send_message(
|
||||||
|
chat_id=self.admin_id,
|
||||||
|
text=f"❌ <b>Ежедневная генерация провалилась:</b>\n<code>{e}</code>",
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Step 1 — Generate prompt + caption via LLM
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _generate_prompt_and_caption(self) -> Tuple[str, str]:
|
||||||
|
"""Ask Gemini to produce an image prompt and social caption.
|
||||||
|
|
||||||
|
Passes the character's avatar photo to the model so it can create
|
||||||
|
a prompt that is faithful to the character's appearance.
|
||||||
|
"""
|
||||||
|
char = await self.dao.chars.get_character(self.character_id)
|
||||||
|
if not char:
|
||||||
|
raise ValueError(f"Character {self.character_id} not found in DB")
|
||||||
|
|
||||||
|
avatar_bytes_list: list[bytes] = []
|
||||||
|
if char.avatar_asset_id:
|
||||||
|
avatar_asset = await self.dao.assets.get_asset(char.avatar_asset_id, with_data=True)
|
||||||
|
if avatar_asset and avatar_asset.data:
|
||||||
|
avatar_bytes_list.append(avatar_asset.data)
|
||||||
|
|
||||||
|
char_bio = char.character_bio or "An expressive, stylish AI character."
|
||||||
|
system_prompt = (
|
||||||
|
f"You are a creative director for the social media account of an AI character named '{char.name}'.\n"
|
||||||
|
# f"Character description: {char_bio}\n\n"
|
||||||
|
"I'm attaching the character's avatar photo. Based on it, produce TWO things:\n\n"
|
||||||
|
"1. IMAGE_PROMPT: A detailed, vivid image generation prompt in English. "
|
||||||
|
"Describe the pose, environment, lighting, color palette, and artistic style. It must look amateur. "
|
||||||
|
"Make it unique and suitable for a social media post.\n\n"
|
||||||
|
"2. SOCIAL_CAPTION: An engaging caption in English for Instagram and TikTok. "
|
||||||
|
"Include 5-10 relevant hashtags at the end.\n\n"
|
||||||
|
"Reply in EXACTLY this format (two lines, no extra text before IMAGE_PROMPT):\n"
|
||||||
|
"IMAGE_PROMPT: <prompt here>\n"
|
||||||
|
"SOCIAL_CAPTION: <caption here>"
|
||||||
|
)
|
||||||
|
|
||||||
|
settings = await self.dao.settings.get_settings()
|
||||||
|
if settings.use_ai_proxy:
|
||||||
|
asset_urls = await self._prepare_asset_urls([char.avatar_asset_id]) if char.avatar_asset_id else None
|
||||||
|
raw = await self.ai_proxy.generate_text(
|
||||||
|
system_prompt,
|
||||||
|
TextModel.GEMINI_3_1_PRO_PREVIEW.value,
|
||||||
|
asset_urls
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raw = await asyncio.to_thread(
|
||||||
|
self.gemini.generate_text,
|
||||||
|
system_prompt,
|
||||||
|
TextModel.GEMINI_3_1_PRO_PREVIEW.value,
|
||||||
|
avatar_bytes_list or None,
|
||||||
|
)
|
||||||
|
logger.debug(f"LLM raw response: {raw[:500]}")
|
||||||
|
|
||||||
|
prompt, caption = self._parse_prompt_and_caption(raw, char.name)
|
||||||
|
return prompt, caption
|
||||||
|
|
||||||
|
async def _prepare_asset_urls(self, asset_ids: list[str]) -> list[str]:
|
||||||
|
assets = await self.dao.assets.get_assets_by_ids(asset_ids)
|
||||||
|
urls = []
|
||||||
|
for asset in assets:
|
||||||
|
if asset.minio_object_name:
|
||||||
|
bucket = asset.minio_bucket or self.s3_adapter.bucket_name
|
||||||
|
urls.append(f"{bucket}/{asset.minio_object_name}")
|
||||||
|
return urls
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_prompt_and_caption(raw: str, char_name: str) -> Tuple[str, str]:
|
||||||
|
prompt = ""
|
||||||
|
caption = ""
|
||||||
|
|
||||||
|
if "IMAGE_PROMPT:" in raw and "SOCIAL_CAPTION:" in raw:
|
||||||
|
after_label = raw.split("IMAGE_PROMPT:", 1)[1]
|
||||||
|
prompt = after_label.split("SOCIAL_CAPTION:", 1)[0].strip()
|
||||||
|
caption = after_label.split("SOCIAL_CAPTION:", 1)[1].strip()
|
||||||
|
elif "IMAGE_PROMPT:" in raw:
|
||||||
|
prompt = raw.split("IMAGE_PROMPT:", 1)[1].strip()
|
||||||
|
else:
|
||||||
|
prompt = raw.strip()
|
||||||
|
|
||||||
|
if not prompt:
|
||||||
|
raise ValueError(f"LLM did not produce IMAGE_PROMPT. Raw snippet: {raw[:300]}")
|
||||||
|
if not caption:
|
||||||
|
caption = f"✨ Новый контент от {char_name}"
|
||||||
|
|
||||||
|
return prompt, caption
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Step 2 — Generate image via GenerationService
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _generate_image_and_save(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
variation_hint: Optional[str] = None,
|
||||||
|
) -> Tuple[Generation, Asset]:
|
||||||
|
"""Create a Generation record and delegate execution to GenerationService.
|
||||||
|
|
||||||
|
Uses GenerationService.create_generation() which handles:
|
||||||
|
- loading character avatar / reference assets
|
||||||
|
- calling Gemini image generation
|
||||||
|
- saving result as Asset in S3
|
||||||
|
- finalizing the Generation record with metrics
|
||||||
|
|
||||||
|
No telegram_id is set, so the service won't send its own notification —
|
||||||
|
we handle that ourselves in _send_to_admin() with action buttons.
|
||||||
|
"""
|
||||||
|
actual_prompt = prompt
|
||||||
|
if variation_hint:
|
||||||
|
actual_prompt = f"{prompt}. {variation_hint}"
|
||||||
|
|
||||||
|
# Create Generation record (GenerationService.create_generation expects it pre-saved)
|
||||||
|
generation = Generation(
|
||||||
|
status=GenerationStatus.RUNNING,
|
||||||
|
linked_character_id=self.character_id,
|
||||||
|
aspect_ratio=AspectRatios.NINESIXTEEN,
|
||||||
|
quality=Quality.ONEK,
|
||||||
|
prompt=actual_prompt,
|
||||||
|
model=ImageModel.GEMINI_3_PRO_IMAGE_PREVIEW.value,
|
||||||
|
use_profile_image=True,
|
||||||
|
# No telegram_id → service won't send its own notification
|
||||||
|
)
|
||||||
|
gen_id = await self.dao.generations.create_generation(generation)
|
||||||
|
generation.id = gen_id
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Delegate all heavy lifting to the existing service
|
||||||
|
await self.generation_service.create_generation(generation)
|
||||||
|
except Exception:
|
||||||
|
# create_generation doesn't mark FAILED itself — the caller (_queued_generation_runner) does.
|
||||||
|
# So we need to handle failure here.
|
||||||
|
await self.generation_service._handle_generation_failure(generation, Exception("Image generation failed"))
|
||||||
|
raise
|
||||||
|
|
||||||
|
# After create_generation, generation.result_list is populated
|
||||||
|
if not generation.result_list:
|
||||||
|
raise ValueError("Generation completed but produced no assets")
|
||||||
|
|
||||||
|
asset = await self.dao.assets.get_asset(generation.result_list[0], with_data=False)
|
||||||
|
if not asset:
|
||||||
|
raise ValueError(f"Asset {generation.result_list[0]} not found after generation")
|
||||||
|
|
||||||
|
return generation, asset
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Step 3 — Send to admin
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _send_to_admin(
|
||||||
|
self,
|
||||||
|
generation: Generation,
|
||||||
|
asset: Asset,
|
||||||
|
prompt: str,
|
||||||
|
caption: str,
|
||||||
|
):
|
||||||
|
img_data = await self.s3_adapter.get_file(asset.minio_object_name)
|
||||||
|
if not img_data:
|
||||||
|
raise ValueError(f"Cannot load image from S3: {asset.minio_object_name}")
|
||||||
|
|
||||||
|
self.pending_sessions[generation.id] = {
|
||||||
|
"prompt": prompt,
|
||||||
|
"caption": caption,
|
||||||
|
"asset_id": asset.id,
|
||||||
|
}
|
||||||
|
|
||||||
|
msg = await self.bot.send_photo(
|
||||||
|
chat_id=self.admin_id,
|
||||||
|
photo=BufferedInputFile(img_data, filename="daily.png"),
|
||||||
|
caption=(
|
||||||
|
f"📸 <b>Ежедневная генерация</b>\n\n"
|
||||||
|
f"<b>Подпись для соцсетей:</b>\n{caption}\n\n"
|
||||||
|
f"<b>Промпт:</b>\n<code>{prompt[:300]}</code>"
|
||||||
|
),
|
||||||
|
reply_markup=make_admin_keyboard(generation.id),
|
||||||
|
)
|
||||||
|
self.pending_sessions[generation.id]["message_id"] = msg.message_id
|
||||||
|
self.pending_sessions[generation.id]["chat_id"] = msg.chat.id
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Admin action handlers (called from Telegram callback router)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def handle_post(self, generation_id: str, message_id: int, chat_id: int):
|
||||||
|
"""Post to Instagram feed + story."""
|
||||||
|
session = self.pending_sessions.get(generation_id)
|
||||||
|
if not session:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not self.meta_adapter:
|
||||||
|
await self.bot.edit_message_caption(
|
||||||
|
chat_id=chat_id,
|
||||||
|
message_id=message_id,
|
||||||
|
caption="⚠️ Meta API не настроен (META_ACCESS_TOKEN не задан). Публикация недоступна.",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
asset = await self.dao.assets.get_asset(session["asset_id"], with_data=False)
|
||||||
|
if not asset or not asset.minio_object_name:
|
||||||
|
raise ValueError("Asset not found in DB")
|
||||||
|
|
||||||
|
image_url = await self.s3_adapter.get_presigned_url(
|
||||||
|
asset.minio_object_name, expiration=3600
|
||||||
|
)
|
||||||
|
if not image_url:
|
||||||
|
raise ValueError("Could not generate presigned URL for image")
|
||||||
|
|
||||||
|
feed_id = await self.meta_adapter.post_to_feed(image_url, session["caption"])
|
||||||
|
story_id = await self.meta_adapter.post_to_story(image_url)
|
||||||
|
|
||||||
|
self.pending_sessions.pop(generation_id, None)
|
||||||
|
await self.bot.edit_message_caption(
|
||||||
|
chat_id=chat_id,
|
||||||
|
message_id=message_id,
|
||||||
|
caption=(
|
||||||
|
f"✅ <b>Опубликовано!</b>\n\n"
|
||||||
|
f"📰 Feed ID: <code>{feed_id}</code>\n"
|
||||||
|
f"📖 Story ID: <code>{story_id}</code>"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Meta publish failed for generation {generation_id}: {e}", exc_info=True)
|
||||||
|
await self.bot.edit_message_caption(
|
||||||
|
chat_id=chat_id,
|
||||||
|
message_id=message_id,
|
||||||
|
caption=f"❌ <b>Ошибка публикации:</b>\n<code>{e}</code>",
|
||||||
|
reply_markup=make_admin_keyboard(generation_id),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def handle_regen_all(self, generation_id: str, message_id: int, chat_id: int):
|
||||||
|
"""Restart from step 1: generate new prompt, caption, and image."""
|
||||||
|
self.pending_sessions.pop(generation_id, None)
|
||||||
|
await self.bot.edit_message_caption(
|
||||||
|
chat_id=chat_id,
|
||||||
|
message_id=message_id,
|
||||||
|
caption="🔄 <b>Перегенерация с нуля...</b>",
|
||||||
|
)
|
||||||
|
asyncio.create_task(self._run_regen_all(chat_id))
|
||||||
|
|
||||||
|
async def _run_regen_all(self, chat_id: int):
|
||||||
|
try:
|
||||||
|
await self.run_daily_generation()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Regen-all failed: {e}", exc_info=True)
|
||||||
|
await self.bot.send_message(chat_id=chat_id, text=f"❌ Ошибка перегенерации:\n<code>{e}</code>")
|
||||||
|
|
||||||
|
async def handle_regen_image(self, generation_id: str, message_id: int, chat_id: int):
|
||||||
|
"""Restart from step 2: generate new image using existing prompt/caption."""
|
||||||
|
session = self.pending_sessions.pop(generation_id, None)
|
||||||
|
if not session:
|
||||||
|
return
|
||||||
|
|
||||||
|
prompt = session["prompt"]
|
||||||
|
caption = session["caption"]
|
||||||
|
|
||||||
|
await self.bot.edit_message_caption(
|
||||||
|
chat_id=chat_id,
|
||||||
|
message_id=message_id,
|
||||||
|
caption="🖼 <b>Перегенерация изображения...</b>",
|
||||||
|
)
|
||||||
|
asyncio.create_task(self._run_regen_image(prompt, caption, chat_id))
|
||||||
|
|
||||||
|
async def _run_regen_image(self, prompt: str, caption: str, chat_id: int):
|
||||||
|
try:
|
||||||
|
generation, asset = await self._generate_image_and_save(prompt)
|
||||||
|
await self._send_to_admin(generation, asset, prompt, caption)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Regen-image failed: {e}", exc_info=True)
|
||||||
|
await self.bot.send_message(chat_id=chat_id, text=f"❌ Ошибка генерации:\n<code>{e}</code>")
|
||||||
|
|
||||||
|
async def handle_regen_more(self, generation_id: str, message_id: int, chat_id: int):
|
||||||
|
"""Generate 2 more pose-varied images using the existing prompt/caption."""
|
||||||
|
session = self.pending_sessions.get(generation_id)
|
||||||
|
if not session:
|
||||||
|
return
|
||||||
|
|
||||||
|
prompt = session["prompt"]
|
||||||
|
caption = session["caption"]
|
||||||
|
|
||||||
|
await self.bot.edit_message_caption(
|
||||||
|
chat_id=chat_id,
|
||||||
|
message_id=message_id,
|
||||||
|
caption="➕ <b>Генерирую ещё 2 варианта...</b>",
|
||||||
|
)
|
||||||
|
asyncio.create_task(self._run_regen_more(prompt, caption, chat_id))
|
||||||
|
|
||||||
|
async def _run_regen_more(self, prompt: str, caption: str, chat_id: int):
|
||||||
|
variation_hints = [
|
||||||
|
"Slightly vary the pose and camera angle while keeping the same scene, environment and lighting.",
|
||||||
|
"Try a different subtle pose or expression, same background and setting as described.",
|
||||||
|
]
|
||||||
|
for i, hint in enumerate(variation_hints):
|
||||||
|
try:
|
||||||
|
generation, asset = await self._generate_image_and_save(prompt, variation_hint=hint)
|
||||||
|
await self._send_to_admin(generation, asset, prompt, caption)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Regen-more variant {i + 1} failed: {e}", exc_info=True)
|
||||||
|
await self.bot.send_message(
|
||||||
|
chat_id=chat_id,
|
||||||
|
text=f"❌ Ошибка варианта {i + 1}:\n<code>{e}</code>",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def handle_cancel(self, generation_id: str, message_id: int, chat_id: int):
|
||||||
|
"""Dismiss: remove buttons, do nothing else."""
|
||||||
|
self.pending_sessions.pop(generation_id, None)
|
||||||
|
await self.bot.edit_message_caption(
|
||||||
|
chat_id=chat_id,
|
||||||
|
message_id=message_id,
|
||||||
|
caption="🚫 Отменено.",
|
||||||
|
)
|
||||||
82
scheduler/telegram_admin_handler.py
Normal file
82
scheduler/telegram_admin_handler.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
"""Telegram inline-keyboard handlers for the daily scheduler admin flow.
|
||||||
|
|
||||||
|
Usage (in aiws.py):
|
||||||
|
from scheduler.telegram_admin_handler import create_daily_scheduler_router
|
||||||
|
from scheduler.daily_scheduler import DailyScheduler
|
||||||
|
|
||||||
|
daily_scheduler = DailyScheduler(...)
|
||||||
|
dp.include_router(create_daily_scheduler_router(daily_scheduler))
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from aiogram import F, Router
|
||||||
|
from aiogram.types import CallbackQuery
|
||||||
|
|
||||||
|
from scheduler.daily_scheduler import (
|
||||||
|
CB_CANCEL,
|
||||||
|
CB_POST,
|
||||||
|
CB_REGEN_ALL,
|
||||||
|
CB_REGEN_IMG,
|
||||||
|
CB_REGEN_MORE,
|
||||||
|
DailyScheduler,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def create_daily_scheduler_router(scheduler: DailyScheduler) -> Router:
|
||||||
|
"""Return an aiogram Router with all callback handlers bound to *scheduler*."""
|
||||||
|
router = Router(name="daily_scheduler")
|
||||||
|
|
||||||
|
@router.callback_query(F.data.startswith(CB_POST + ":"))
|
||||||
|
async def on_post(callback: CallbackQuery):
|
||||||
|
generation_id = callback.data.split(":", 1)[1]
|
||||||
|
await callback.answer("Публикую в Instagram...")
|
||||||
|
await scheduler.handle_post(
|
||||||
|
generation_id=generation_id,
|
||||||
|
message_id=callback.message.message_id,
|
||||||
|
chat_id=callback.message.chat.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.callback_query(F.data.startswith(CB_REGEN_ALL + ":"))
|
||||||
|
async def on_regen_all(callback: CallbackQuery):
|
||||||
|
generation_id = callback.data.split(":", 1)[1]
|
||||||
|
await callback.answer("Перезапускаю с нуля...")
|
||||||
|
await scheduler.handle_regen_all(
|
||||||
|
generation_id=generation_id,
|
||||||
|
message_id=callback.message.message_id,
|
||||||
|
chat_id=callback.message.chat.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.callback_query(F.data.startswith(CB_REGEN_IMG + ":"))
|
||||||
|
async def on_regen_img(callback: CallbackQuery):
|
||||||
|
generation_id = callback.data.split(":", 1)[1]
|
||||||
|
await callback.answer("Генерирую новое изображение...")
|
||||||
|
await scheduler.handle_regen_image(
|
||||||
|
generation_id=generation_id,
|
||||||
|
message_id=callback.message.message_id,
|
||||||
|
chat_id=callback.message.chat.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.callback_query(F.data.startswith(CB_REGEN_MORE + ":"))
|
||||||
|
async def on_regen_more(callback: CallbackQuery):
|
||||||
|
generation_id = callback.data.split(":", 1)[1]
|
||||||
|
await callback.answer("Генерирую 2 варианта...")
|
||||||
|
await scheduler.handle_regen_more(
|
||||||
|
generation_id=generation_id,
|
||||||
|
message_id=callback.message.message_id,
|
||||||
|
chat_id=callback.message.chat.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.callback_query(F.data.startswith(CB_CANCEL + ":"))
|
||||||
|
async def on_cancel(callback: CallbackQuery):
|
||||||
|
generation_id = callback.data.split(":", 1)[1]
|
||||||
|
await callback.answer("Отменено")
|
||||||
|
await scheduler.handle_cancel(
|
||||||
|
generation_id=generation_id,
|
||||||
|
message_id=callback.message.message_id,
|
||||||
|
chat_id=callback.message.chat.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return router
|
||||||
51
tests/test_ai_proxy_logic.py
Normal file
51
tests/test_ai_proxy_logic.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
import asyncio
|
||||||
|
from unittest.mock import MagicMock, AsyncMock, patch
|
||||||
|
from api.service.generation_service import GenerationService
|
||||||
|
from models.Settings import SystemSettings
|
||||||
|
from models.Generation import Generation
|
||||||
|
from models.enums import AspectRatios, Quality
|
||||||
|
|
||||||
|
async def test_generation_service_proxy_logic():
|
||||||
|
dao = MagicMock()
|
||||||
|
gemini = MagicMock()
|
||||||
|
s3_adapter = MagicMock()
|
||||||
|
|
||||||
|
# Mock settings to have proxy ENABLED
|
||||||
|
dao.settings.get_settings = AsyncMock(return_value=SystemSettings(use_ai_proxy=True))
|
||||||
|
dao.assets.get_assets_by_ids = AsyncMock(return_value=[])
|
||||||
|
|
||||||
|
service = GenerationService(dao, gemini, s3_adapter)
|
||||||
|
|
||||||
|
# 1. Test ask_prompt_assistant with proxy
|
||||||
|
with patch.object(service.ai_proxy, 'generate_text', new_callable=AsyncMock) as mock_proxy_text:
|
||||||
|
mock_proxy_text.return_value = "Proxy Result"
|
||||||
|
result = await service.ask_prompt_assistant("Test Prompt")
|
||||||
|
assert result == "Proxy Result"
|
||||||
|
mock_proxy_text.assert_called_once()
|
||||||
|
gemini.generate_text.assert_not_called()
|
||||||
|
|
||||||
|
# 2. Test create_generation with proxy
|
||||||
|
generation = Generation(
|
||||||
|
prompt="Test Image",
|
||||||
|
aspect_ratio=AspectRatios.ONEONE,
|
||||||
|
quality=Quality.ONEK,
|
||||||
|
assets_list=[]
|
||||||
|
)
|
||||||
|
# Mock _prepare_generation_input to avoid complex DB calls
|
||||||
|
service._prepare_generation_input = AsyncMock(return_value=([], "Test Image", []))
|
||||||
|
service._process_generated_images = AsyncMock(return_value=[])
|
||||||
|
service._finalize_generation = AsyncMock()
|
||||||
|
|
||||||
|
with patch.object(service.ai_proxy, 'generate_image', new_callable=AsyncMock) as mock_proxy_img:
|
||||||
|
import io
|
||||||
|
mock_img_io = io.BytesIO(b"fake image data")
|
||||||
|
mock_proxy_img.return_value = ([mock_img_io], {"api_execution_time_seconds": 1.0})
|
||||||
|
|
||||||
|
await service.create_generation(generation)
|
||||||
|
mock_proxy_img.assert_called_once()
|
||||||
|
# gemini.generate_image would be called via generate_image_task in else branch
|
||||||
|
|
||||||
|
print("✅ Proxy logic test passed!")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(test_generation_service_proxy_logic())
|
||||||
Reference in New Issue
Block a user