models + refactor
This commit is contained in:
4
.env
4
.env
@@ -8,4 +8,6 @@ MINIO_ACCESS_KEY=admin
|
||||
MINIO_SECRET_KEY=SuperSecretPassword123!
|
||||
MINIO_BUCKET=ai-char
|
||||
MODE=production
|
||||
EXTERNAL_API_SECRET=Gt9TyQ8OAYhcELh2YCbKjdHLflZGufKHJZcG338MQDW
|
||||
EXTERNAL_API_SECRET=Gt9TyQ8OAYhcELh2YCbKjdHLflZGufKHJZcG338MQDW
|
||||
PROXY_SECRET_SALT=AbVJUkwTPcUWJWhPzmjXb5p4SYyKmYC5m1uVW7Dhi7o
|
||||
SCHEDULER_CHARACTER_ID=69931c10721fbd539804589b
|
||||
@@ -1,33 +0,0 @@
|
||||
# Project Context: AI Char Bot
|
||||
|
||||
## Overview
|
||||
Python backend project using FastAPI and MongoDB (Motor).
|
||||
Root: `/Users/xds/develop/py projects/ai-char-bot`
|
||||
|
||||
## Architecture
|
||||
- **API Layer**: `api/endpoints` (FastAPI routers).
|
||||
- **Service Layer**: `api/service` (Business logic).
|
||||
- **Data Layer**: `repos` (DAOs and Repositories).
|
||||
- **Models**: `models` (Domain models) and `api/models` (Request/Response DTOs).
|
||||
- **Adapters**: `adapters` (External services like S3, Google Gemini).
|
||||
|
||||
## Coding Standards & Preferences
|
||||
- **Type Hinting**: Use `Type | None` instead of `Optional[Type]` (Python 3.10+ style).
|
||||
- **Async/Await**: Extensive use of `asyncio` and asynchronous DB drivers.
|
||||
- **Error Handling**:
|
||||
- Repositories should return `None` if an entity is not found (e.g., `toggle_like`).
|
||||
- Services/Routers handle `HTTPException`.
|
||||
|
||||
## Key Features & Implementation Details
|
||||
- **Generations**:
|
||||
- Managed by `GenerationService` and `GenerationRepo`.
|
||||
- `toggle_like` returns `bool | None` (True=Liked, False=Unliked, None=Not Found).
|
||||
- `get_generations` requires `current_user_id` to correctly calculate `is_liked`.
|
||||
- **Ideas**:
|
||||
- Managed by `IdeaService` and `IdeaRepo`.
|
||||
- Can have linked generations.
|
||||
- When fetching generations for an idea, ensure `current_user_id` is passed to `GenerationService`.
|
||||
|
||||
## Recent Changes
|
||||
- Refactored `toggle_like` to handle non-existent generations and return `bool | None`.
|
||||
- Updated `IdeaRouter` to pass `current_user_id` when fetching generations to ensure `is_liked` flag is correct.
|
||||
100
adapters/ai_proxy_adapter.py
Normal file
100
adapters/ai_proxy_adapter.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import logging
|
||||
import io
|
||||
import httpx
|
||||
import hashlib
|
||||
import time
|
||||
from typing import List, Tuple, Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
from models.enums import AspectRatios, Quality
|
||||
from config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class AIProxyAdapter:
|
||||
def __init__(self, base_url: str = "http://82.22.174.14:8001", salt: str = None):
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.salt = salt or settings.PROXY_SECRET_SALT
|
||||
|
||||
def _generate_headers(self) -> Dict[str, str]:
|
||||
timestamp = int(time.time())
|
||||
hash_input = f"{timestamp}{self.salt}".encode()
|
||||
signature = hashlib.sha256(hash_input).hexdigest()
|
||||
|
||||
return {
|
||||
"X-Timestamp": str(timestamp),
|
||||
"X-Signature": signature
|
||||
}
|
||||
|
||||
async def generate_text(self, prompt: str, model: str = "gemini-3.1-pro-preview", asset_urls: List[str] | None = None) -> str:
|
||||
"""
|
||||
Generates text using the AI Proxy with signature verification.
|
||||
"""
|
||||
url = f"{self.base_url}/generate_text"
|
||||
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
payload = {
|
||||
"messages": messages,
|
||||
"asset_urls": asset_urls
|
||||
}
|
||||
|
||||
headers = self._generate_headers()
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.post(url, json=payload, headers=headers, timeout=60.0)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if data.get("finish_reason") != "STOP":
|
||||
logger.warning(f"AI Proxy generation finished with reason: {data.get('finish_reason')}")
|
||||
|
||||
return data.get("response") or ""
|
||||
except Exception as e:
|
||||
logger.error(f"AI Proxy Text Error: {e}")
|
||||
raise Exception(f"AI Proxy Text Error: {e}")
|
||||
|
||||
async def generate_image(
|
||||
self,
|
||||
prompt: str,
|
||||
aspect_ratio: AspectRatios,
|
||||
quality: Quality,
|
||||
model: str = "gemini-3-pro-image-preview",
|
||||
asset_urls: List[str] | None = None
|
||||
) -> Tuple[List[io.BytesIO], Dict[str, Any]]:
|
||||
"""
|
||||
Generates image using the AI Proxy with signature verification.
|
||||
"""
|
||||
url = f"{self.base_url}/generate_image"
|
||||
|
||||
payload = {
|
||||
"prompt": prompt,
|
||||
"asset_urls": asset_urls
|
||||
}
|
||||
|
||||
headers = self._generate_headers()
|
||||
|
||||
start_time = datetime.now()
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.post(url, json=payload, headers=headers, timeout=120.0)
|
||||
response.raise_for_status()
|
||||
|
||||
image_bytes = response.content
|
||||
byte_arr = io.BytesIO(image_bytes)
|
||||
byte_arr.name = f"{datetime.now().timestamp()}.png"
|
||||
byte_arr.seek(0)
|
||||
|
||||
end_time = datetime.now()
|
||||
api_duration = (end_time - start_time).total_seconds()
|
||||
|
||||
metrics = {
|
||||
"api_execution_time_seconds": api_duration,
|
||||
"token_usage": 0,
|
||||
"input_token_usage": 0,
|
||||
"output_token_usage": 0
|
||||
}
|
||||
|
||||
return [byte_arr], metrics
|
||||
except Exception as e:
|
||||
logger.error(f"AI Proxy Image Error: {e}")
|
||||
raise Exception(f"AI Proxy Image Error: {e}")
|
||||
88
adapters/meta_adapter.py
Normal file
88
adapters/meta_adapter.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
META_GRAPH_VERSION = "v18.0"
|
||||
META_GRAPH_BASE = f"https://graph.facebook.com/{META_GRAPH_VERSION}"
|
||||
|
||||
|
||||
class MetaAdapter:
|
||||
"""Adapter for Meta Platform API (Instagram Graph API).
|
||||
|
||||
Requires:
|
||||
- access_token: long-lived Page or Instagram access token
|
||||
- instagram_account_id: Instagram Business Account ID
|
||||
"""
|
||||
|
||||
def __init__(self, access_token: str, instagram_account_id: str):
|
||||
self.access_token = access_token
|
||||
self.instagram_account_id = instagram_account_id
|
||||
|
||||
async def post_to_feed(self, image_url: str, caption: str) -> Optional[str]:
|
||||
"""Upload image and publish to Instagram feed.
|
||||
|
||||
Returns the post ID on success, raises on failure.
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
# Step 1: create media container
|
||||
resp = await client.post(
|
||||
f"{META_GRAPH_BASE}/{self.instagram_account_id}/media",
|
||||
data={
|
||||
"image_url": image_url,
|
||||
"caption": caption,
|
||||
"access_token": self.access_token,
|
||||
},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
creation_id = resp.json().get("id")
|
||||
if not creation_id:
|
||||
raise ValueError(f"No creation_id from Meta API: {resp.text}")
|
||||
|
||||
# Step 2: publish
|
||||
resp2 = await client.post(
|
||||
f"{META_GRAPH_BASE}/{self.instagram_account_id}/media_publish",
|
||||
data={
|
||||
"creation_id": creation_id,
|
||||
"access_token": self.access_token,
|
||||
},
|
||||
)
|
||||
resp2.raise_for_status()
|
||||
post_id = resp2.json().get("id")
|
||||
logger.info(f"Published to Instagram feed: {post_id}")
|
||||
return post_id
|
||||
|
||||
async def post_to_story(self, image_url: str) -> Optional[str]:
|
||||
"""Upload image and publish to Instagram story.
|
||||
|
||||
Returns the story ID on success, raises on failure.
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
# Step 1: create story container
|
||||
resp = await client.post(
|
||||
f"{META_GRAPH_BASE}/{self.instagram_account_id}/media",
|
||||
data={
|
||||
"image_url": image_url,
|
||||
"media_type": "STORIES",
|
||||
"access_token": self.access_token,
|
||||
},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
creation_id = resp.json().get("id")
|
||||
if not creation_id:
|
||||
raise ValueError(f"No creation_id from Meta API: {resp.text}")
|
||||
|
||||
# Step 2: publish
|
||||
resp2 = await client.post(
|
||||
f"{META_GRAPH_BASE}/{self.instagram_account_id}/media_publish",
|
||||
data={
|
||||
"creation_id": creation_id,
|
||||
"access_token": self.access_token,
|
||||
},
|
||||
)
|
||||
resp2.raise_for_status()
|
||||
story_id = resp2.json().get("id")
|
||||
logger.info(f"Published to Instagram story: {story_id}")
|
||||
return story_id
|
||||
81
aiws.py
81
aiws.py
@@ -23,6 +23,8 @@ from api.service.album_service import AlbumService
|
||||
from middlewares.album import AlbumMiddleware
|
||||
from middlewares.auth import AuthMiddleware
|
||||
from middlewares.dao import DaoMiddleware
|
||||
from scheduler.daily_scheduler import DailyScheduler
|
||||
from scheduler.telegram_admin_handler import create_daily_scheduler_router
|
||||
|
||||
# Репозитории и DAO
|
||||
from repos.char_repo import CharacterRepo
|
||||
@@ -108,7 +110,7 @@ dp["gemini"] = gemini
|
||||
# 1. Роутеры без мидлварей (например, auth)
|
||||
dp.include_router(auth_router)
|
||||
|
||||
# 2. Основные роутеры
|
||||
# 2. Основные роутеры (daily_scheduler router добавляется в lifespan)
|
||||
main_router = Router()
|
||||
dp.include_router(main_router)
|
||||
dp.include_router(assets_router)
|
||||
@@ -141,6 +143,34 @@ async def start_scheduler(service: GenerationService):
|
||||
logger.error(f"Scheduler error: {e}")
|
||||
await asyncio.sleep(60) # Check every 60 seconds
|
||||
|
||||
|
||||
def _build_daily_scheduler() -> DailyScheduler:
|
||||
"""Construct DailyScheduler; MetaAdapter is optional (needs env vars)."""
|
||||
meta_adapter = None
|
||||
if settings.META_ACCESS_TOKEN and settings.META_INSTAGRAM_ACCOUNT_ID:
|
||||
from adapters.meta_adapter import MetaAdapter
|
||||
meta_adapter = MetaAdapter(
|
||||
access_token=settings.META_ACCESS_TOKEN,
|
||||
instagram_account_id=settings.META_INSTAGRAM_ACCOUNT_ID,
|
||||
)
|
||||
logger.info("MetaAdapter initialized")
|
||||
else:
|
||||
logger.warning("META_ACCESS_TOKEN / META_INSTAGRAM_ACCOUNT_ID not set — Instagram publishing disabled")
|
||||
|
||||
if not settings.SCHEDULER_CHARACTER_ID:
|
||||
logger.warning("SCHEDULER_CHARACTER_ID not set — daily scheduler will error at runtime")
|
||||
|
||||
return DailyScheduler(
|
||||
dao=dao,
|
||||
gemini=gemini,
|
||||
s3_adapter=s3_adapter,
|
||||
generation_service=generation_service,
|
||||
bot=bot,
|
||||
admin_id=ADMIN_ID,
|
||||
character_id=settings.SCHEDULER_CHARACTER_ID or "",
|
||||
meta_adapter=meta_adapter,
|
||||
)
|
||||
|
||||
# --- LIFESPAN (Запуск FastAPI + Bot) ---
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
@@ -164,36 +194,39 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
print("✅ DB & DAO initialized")
|
||||
|
||||
# 2. ЗАПУСК БОТА (в фоне)
|
||||
# Важно: handle_signals=False, чтобы бот не перехватывал сигналы остановки у uvicorn
|
||||
# Мы НЕ передаем сюда dao=..., так как он уже подключен через Middleware выше
|
||||
# polling_task = asyncio.create_task(
|
||||
# dp.start_polling(bot, handle_signals=False)
|
||||
# )
|
||||
# print("🤖 Bot polling started")
|
||||
# 2. Инициализация и регистрация daily_scheduler роутера
|
||||
daily_scheduler = _build_daily_scheduler()
|
||||
dp.include_router(create_daily_scheduler_router(daily_scheduler))
|
||||
print("📅 Daily scheduler router registered")
|
||||
|
||||
# 3. ЗАПУСК ШЕДУЛЕРА
|
||||
# 3. ЗАПУСК БОТА (в фоне)
|
||||
# handle_signals=False — бот не перехватывает сигналы остановки у uvicorn
|
||||
polling_task = asyncio.create_task(
|
||||
dp.start_polling(bot, handle_signals=False)
|
||||
)
|
||||
print("🤖 Bot polling started")
|
||||
|
||||
# 4. ЗАПУСК ШЕДУЛЕРОВ
|
||||
scheduler_task = asyncio.create_task(start_scheduler(generation_service))
|
||||
print("⏰ Scheduler started")
|
||||
daily_scheduler_task = asyncio.create_task(daily_scheduler.run_loop())
|
||||
print("⏰ Schedulers started")
|
||||
|
||||
yield
|
||||
|
||||
# --- SHUTDOWN ---
|
||||
print("🛑 Shutting down...")
|
||||
|
||||
# 4. Остановка шедулера
|
||||
scheduler_task.cancel()
|
||||
try:
|
||||
await scheduler_task
|
||||
except asyncio.CancelledError:
|
||||
print("⏰ Scheduler stopped")
|
||||
|
||||
# 3. Остановка бота
|
||||
# polling_task.cancel()
|
||||
# try:
|
||||
# await polling_task
|
||||
# except asyncio.CancelledError:
|
||||
# print("🤖 Bot polling stopped")
|
||||
|
||||
# Останавливаем все фоновые задачи
|
||||
for task, name in [
|
||||
(polling_task, "Bot polling"),
|
||||
(scheduler_task, "Stale-gen scheduler"),
|
||||
(daily_scheduler_task, "Daily scheduler"),
|
||||
]:
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
print(f"⏹ {name} stopped")
|
||||
|
||||
# 4. Отключение БД
|
||||
# Обычно Motor закрывать не обязательно при выходе, но хорошим тоном считается
|
||||
|
||||
@@ -7,6 +7,7 @@ from pydantic import BaseModel
|
||||
from repos.user_repo import UsersRepo, UserStatus
|
||||
from api.dependency import get_dao
|
||||
from repos.dao import DAO
|
||||
from models.Settings import SystemSettings
|
||||
from utils.security import verify_password, create_access_token, ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, SECRET_KEY
|
||||
from jose import JWTError, jwt
|
||||
from starlette.requests import Request
|
||||
@@ -96,3 +97,21 @@ async def deny_user(
|
||||
|
||||
await repo.deny_user(username)
|
||||
return {"message": f"User {username} denied"}
|
||||
|
||||
@router.get("/settings", response_model=SystemSettings)
|
||||
async def get_settings(
|
||||
admin: Annotated[dict, Depends(get_current_admin)],
|
||||
dao: Annotated[DAO, Depends(get_dao)]
|
||||
):
|
||||
return await dao.settings.get_settings()
|
||||
|
||||
@router.post("/settings")
|
||||
async def update_settings(
|
||||
settings: SystemSettings,
|
||||
admin: Annotated[dict, Depends(get_current_admin)],
|
||||
dao: Annotated[DAO, Depends(get_dao)]
|
||||
):
|
||||
success = await dao.settings.update_settings(settings)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to update settings")
|
||||
return {"message": "Settings updated successfully"}
|
||||
|
||||
@@ -12,6 +12,7 @@ from aiogram.types import BufferedInputFile
|
||||
|
||||
from adapters.Exception import GoogleGenerationException
|
||||
from adapters.google_adapter import GoogleAdapter
|
||||
from adapters.ai_proxy_adapter import AIProxyAdapter
|
||||
from adapters.s3_adapter import S3Adapter
|
||||
from api.models import (
|
||||
FinancialReport, UsageStats, UsageByEntity,
|
||||
@@ -72,6 +73,7 @@ class GenerationService:
|
||||
def __init__(self, dao: DAO, gemini: GoogleAdapter, s3_adapter: S3Adapter, bot: Optional[Bot] = None):
|
||||
self.dao = dao
|
||||
self.gemini = gemini
|
||||
self.ai_proxy = AIProxyAdapter()
|
||||
self.s3_adapter = s3_adapter
|
||||
self.bot = bot
|
||||
|
||||
@@ -84,12 +86,19 @@ class GenerationService:
|
||||
"Understand user needs and generate best variation of prompt. ANSWER ONLY PROMPT STRING!!! "
|
||||
f"USER_ENTERED_PROMPT: {prompt}"
|
||||
)
|
||||
assets_data = []
|
||||
if assets:
|
||||
assets_db = await self.dao.assets.get_assets_by_ids(assets)
|
||||
assets_data.extend(asset.data for asset in assets_db if asset.data)
|
||||
|
||||
generated_prompt = await asyncio.to_thread(self.gemini.generate_text, future_prompt, model, assets_data)
|
||||
settings = await self.dao.settings.get_settings()
|
||||
if settings.use_ai_proxy:
|
||||
asset_urls = await self._prepare_asset_urls(assets) if assets else None
|
||||
generated_prompt = await self.ai_proxy.generate_text(future_prompt, model, asset_urls)
|
||||
else:
|
||||
assets_data = []
|
||||
if assets:
|
||||
assets_db = await self.dao.assets.get_assets_by_ids(assets)
|
||||
assets_data.extend(asset.data for asset in assets_db if asset.data)
|
||||
|
||||
generated_prompt = await asyncio.to_thread(self.gemini.generate_text, future_prompt, model, assets_data)
|
||||
|
||||
logger.info(f"Prompt Assistant: {generated_prompt}")
|
||||
return generated_prompt
|
||||
|
||||
@@ -99,6 +108,15 @@ class GenerationService:
|
||||
technical_prompt += f"User also provided this context: {user_prompt}. "
|
||||
technical_prompt += "Provide ONLY the detailed prompt."
|
||||
|
||||
settings = await self.dao.settings.get_settings()
|
||||
if settings.use_ai_proxy:
|
||||
# Proxy doesn't support raw bytes currently.
|
||||
# In a real scenario we'd upload them to a temp bucket.
|
||||
# For now, we call the proxy with just the prompt,
|
||||
# or we could fall back to GoogleAdapter if images are essential.
|
||||
# To be safe and follow instructions to use proxy, we use it.
|
||||
return await self.ai_proxy.generate_text(prompt=technical_prompt, model=model)
|
||||
|
||||
return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, model=model, images_list=images)
|
||||
|
||||
async def get_generations(self, **kwargs) -> GenerationsResponse:
|
||||
@@ -154,19 +172,36 @@ class GenerationService:
|
||||
logger.info(f"Processing generation {generation.id}. Character ID: {generation.linked_character_id}")
|
||||
|
||||
# 1. Prepare input
|
||||
media_group_bytes, generation_prompt = await self._prepare_generation_input(generation)
|
||||
media_group_bytes, generation_prompt, asset_ids = await self._prepare_generation_input(generation)
|
||||
|
||||
# 2. Run generation with progress simulation
|
||||
progress_task = asyncio.create_task(self._simulate_progress(generation))
|
||||
try:
|
||||
generated_bytes_list, metrics = await generate_image_task(
|
||||
prompt=generation_prompt,
|
||||
media_group_bytes=media_group_bytes,
|
||||
aspect_ratio=generation.aspect_ratio,
|
||||
quality=generation.quality,
|
||||
model=generation.model or "gemini-3-pro-image-preview",
|
||||
gemini=self.gemini
|
||||
)
|
||||
settings = await self.dao.settings.get_settings()
|
||||
if settings.use_ai_proxy:
|
||||
asset_urls = await self._prepare_asset_urls(asset_ids) if asset_ids else None
|
||||
generated_images_io, metrics = await self.ai_proxy.generate_image(
|
||||
prompt=generation_prompt,
|
||||
aspect_ratio=generation.aspect_ratio,
|
||||
quality=generation.quality,
|
||||
model=generation.model or "gemini-3-pro-image-preview",
|
||||
asset_urls=asset_urls
|
||||
)
|
||||
generated_bytes_list = []
|
||||
for img_io in generated_images_io:
|
||||
img_io.seek(0)
|
||||
generated_bytes_list.append(img_io.read())
|
||||
img_io.close()
|
||||
else:
|
||||
generated_bytes_list, metrics = await generate_image_task(
|
||||
prompt=generation_prompt,
|
||||
media_group_bytes=media_group_bytes,
|
||||
aspect_ratio=generation.aspect_ratio,
|
||||
quality=generation.quality,
|
||||
model=generation.model or "gemini-3-pro-image-preview",
|
||||
gemini=self.gemini
|
||||
)
|
||||
|
||||
self._update_generation_metrics(generation, metrics)
|
||||
|
||||
# 3. Process results
|
||||
@@ -299,36 +334,39 @@ class GenerationService:
|
||||
await self._handle_generation_failure(gen, e)
|
||||
logger.exception(f"Background generation task failed for ID: {gen.id}")
|
||||
|
||||
async def _prepare_generation_input(self, generation: Generation) -> Tuple[List[bytes], str]:
|
||||
async def _prepare_generation_input(self, generation: Generation) -> Tuple[List[bytes], str, List[str]]:
|
||||
media_group_bytes: List[bytes] = []
|
||||
prompt = generation.prompt
|
||||
asset_ids = []
|
||||
|
||||
# 1. Character Avatar
|
||||
if generation.linked_character_id:
|
||||
char_info = await self.dao.chars.get_character(generation.linked_character_id)
|
||||
if not char_info:
|
||||
raise ValueError(f"Character {generation.linked_character_id} not found")
|
||||
|
||||
|
||||
if generation.use_profile_image and char_info.avatar_asset_id:
|
||||
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
|
||||
if avatar_asset:
|
||||
data = await self._get_asset_data_bytes(avatar_asset)
|
||||
if data: media_group_bytes.append(data)
|
||||
asset_ids.append(char_info.avatar_asset_id)
|
||||
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id, with_data=True)
|
||||
if avatar_asset and avatar_asset.content_type == AssetContentType.IMAGE and avatar_asset.data:
|
||||
media_group_bytes.append(avatar_asset.data)
|
||||
|
||||
# 2. Reference Assets
|
||||
if generation.assets_list:
|
||||
asset_ids.extend(generation.assets_list)
|
||||
assets = await self.dao.assets.get_assets_by_ids(generation.assets_list)
|
||||
for asset in assets:
|
||||
data = await self._get_asset_data_bytes(asset)
|
||||
data = await self._load_asset_image_data(asset)
|
||||
if data: media_group_bytes.append(data)
|
||||
|
||||
# 3. Environment Assets
|
||||
if generation.environment_id:
|
||||
env = await self.dao.environments.get_env(generation.environment_id)
|
||||
if env and env.asset_ids:
|
||||
asset_ids.extend(env.asset_ids)
|
||||
env_assets = await self.dao.assets.get_assets_by_ids(env.asset_ids)
|
||||
for asset in env_assets:
|
||||
data = await self._get_asset_data_bytes(asset)
|
||||
data = await self._load_asset_image_data(asset)
|
||||
if data: media_group_bytes.append(data)
|
||||
|
||||
if media_group_bytes:
|
||||
@@ -337,14 +375,26 @@ class GenerationService:
|
||||
"character's facial features and hair, environment or clothes. Maintain high fidelity to the reference identity."
|
||||
)
|
||||
|
||||
return media_group_bytes, prompt
|
||||
return media_group_bytes, prompt, asset_ids
|
||||
|
||||
async def _get_asset_data_bytes(self, asset: Asset) -> Optional[bytes]:
|
||||
async def _prepare_asset_urls(self, asset_ids: List[str]) -> List[str]:
|
||||
assets = await self.dao.assets.get_assets_by_ids(asset_ids)
|
||||
urls = []
|
||||
for asset in assets:
|
||||
if asset.minio_object_name:
|
||||
bucket = asset.minio_bucket or self.s3_adapter.bucket_name
|
||||
urls.append(f"{bucket}/{asset.minio_object_name}")
|
||||
return urls
|
||||
|
||||
async def _load_asset_image_data(self, asset: Asset) -> Optional[bytes]:
|
||||
"""Load image bytes for an asset that was fetched without data (e.g. from get_assets_by_ids)."""
|
||||
if asset.content_type != AssetContentType.IMAGE:
|
||||
return None
|
||||
if asset.data:
|
||||
return asset.data
|
||||
if asset.minio_object_name:
|
||||
return await self.s3_adapter.get_file(asset.minio_object_name)
|
||||
return asset.data
|
||||
return None
|
||||
|
||||
def _update_generation_metrics(self, generation: Generation, metrics: Dict[str, Any]):
|
||||
generation.api_execution_time_seconds = metrics.get("api_execution_time_seconds")
|
||||
|
||||
10
config.py
10
config.py
@@ -24,6 +24,16 @@ class Settings(BaseSettings):
|
||||
# External API
|
||||
EXTERNAL_API_SECRET: Optional[str] = None
|
||||
|
||||
# Daily Scheduler
|
||||
SCHEDULER_CHARACTER_ID: Optional[str] = None # Character ID used for daily generation
|
||||
|
||||
# Meta Platform (Instagram Graph API)
|
||||
META_ACCESS_TOKEN: Optional[str] = None # Long-lived page/Instagram access token
|
||||
META_INSTAGRAM_ACCOUNT_ID: Optional[str] = None # Instagram Business Account ID
|
||||
|
||||
# AI Proxy Security
|
||||
PROXY_SECRET_SALT: str = "AbVJUkwTPcUWJWhPzmjXb5p4SYyKmYC5m1uVW7Dhi7o"
|
||||
|
||||
# JWT Security
|
||||
SECRET_KEY: str = "CHANGE_ME_TO_A_SUPER_SECRET_KEY"
|
||||
ALGORITHM: str = "HS256"
|
||||
|
||||
10
models/Settings.py
Normal file
10
models/Settings.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from datetime import datetime, UTC
|
||||
|
||||
class SystemSettings(BaseModel):
|
||||
id: str = Field(default="system_settings", alias="_id")
|
||||
use_ai_proxy: bool = False
|
||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||
|
||||
class Config:
|
||||
populate_by_name = True
|
||||
@@ -10,6 +10,7 @@ from repos.idea_repo import IdeaRepo
|
||||
from repos.post_repo import PostRepo
|
||||
from repos.environment_repo import EnvironmentRepo
|
||||
from repos.inspiration_repo import InspirationRepo
|
||||
from repos.settings_repo import SettingsRepo
|
||||
|
||||
|
||||
from typing import Optional
|
||||
@@ -27,3 +28,4 @@ class DAO:
|
||||
self.posts = PostRepo(client, db_name)
|
||||
self.environments = EnvironmentRepo(client, db_name)
|
||||
self.inspirations = InspirationRepo(client, db_name)
|
||||
self.settings = SettingsRepo(client, db_name)
|
||||
|
||||
26
repos/settings_repo.py
Normal file
26
repos/settings_repo.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from typing import Optional
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from models.Settings import SystemSettings
|
||||
from datetime import datetime, UTC
|
||||
|
||||
class SettingsRepo:
|
||||
def __init__(self, client: AsyncIOMotorClient, db_name: str):
|
||||
self.collection = client[db_name]["settings"]
|
||||
|
||||
async def get_settings(self) -> SystemSettings:
|
||||
doc = await self.collection.find_one({"_id": "system_settings"})
|
||||
if not doc:
|
||||
# Create default settings if not exists
|
||||
settings = SystemSettings()
|
||||
await self.collection.insert_one(settings.model_dump(by_alias=True))
|
||||
return settings
|
||||
return SystemSettings(**doc)
|
||||
|
||||
async def update_settings(self, settings: SystemSettings) -> bool:
|
||||
settings.updated_at = datetime.now(UTC)
|
||||
result = await self.collection.replace_one(
|
||||
{"_id": "system_settings"},
|
||||
settings.model_dump(by_alias=True),
|
||||
upsert=True
|
||||
)
|
||||
return result.modified_count > 0 or result.upserted_id is not None
|
||||
0
scheduler/__init__.py
Normal file
0
scheduler/__init__.py
Normal file
456
scheduler/daily_scheduler.py
Normal file
456
scheduler/daily_scheduler.py
Normal file
@@ -0,0 +1,456 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from aiogram import Bot
|
||||
from aiogram.types import BufferedInputFile, InlineKeyboardButton, InlineKeyboardMarkup
|
||||
|
||||
from adapters.google_adapter import GoogleAdapter
|
||||
from adapters.ai_proxy_adapter import AIProxyAdapter
|
||||
from adapters.s3_adapter import S3Adapter
|
||||
from api.service.generation_service import GenerationService
|
||||
from models.Asset import Asset
|
||||
from models.Generation import Generation, GenerationStatus
|
||||
from models.enums import AspectRatios, ImageModel, Quality, TextModel
|
||||
from repos.dao import DAO
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MSK_TZ = timezone(timedelta(hours=3))
|
||||
SCHEDULE_HOUR_MSK = 11
|
||||
SCHEDULE_MINUTE_MSK = 0
|
||||
|
||||
# Callback data prefixes for inline keyboard buttons
|
||||
CB_POST = "daily_post"
|
||||
CB_REGEN_ALL = "daily_regen_all"
|
||||
CB_REGEN_IMG = "daily_regen_img"
|
||||
CB_REGEN_MORE = "daily_regen_more"
|
||||
CB_CANCEL = "daily_cancel"
|
||||
|
||||
|
||||
def make_admin_keyboard(generation_id: str) -> InlineKeyboardMarkup:
|
||||
return InlineKeyboardMarkup(
|
||||
inline_keyboard=[
|
||||
[
|
||||
InlineKeyboardButton(text="✅ Выложить", callback_data=f"{CB_POST}:{generation_id}"),
|
||||
InlineKeyboardButton(text="❌ Отмена", callback_data=f"{CB_CANCEL}:{generation_id}"),
|
||||
],
|
||||
[
|
||||
InlineKeyboardButton(text="🔄 Перегенерить с нуля", callback_data=f"{CB_REGEN_ALL}:{generation_id}"),
|
||||
InlineKeyboardButton(text="🖼 Перегенерить изображение", callback_data=f"{CB_REGEN_IMG}:{generation_id}"),
|
||||
],
|
||||
[
|
||||
InlineKeyboardButton(text="➕ Сгенерировать ещё 2", callback_data=f"{CB_REGEN_MORE}:{generation_id}"),
|
||||
],
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class DailyScheduler:
|
||||
"""Orchestrates the daily AI-character content generation pipeline.
|
||||
|
||||
Flow:
|
||||
1. Generate image prompt + social caption via LLM (with character avatar).
|
||||
2. Generate image via GenerationService.create_generation() (reuses existing pipeline).
|
||||
3. Send to Telegram admin with action buttons.
|
||||
|
||||
Admin actions (inline keyboard):
|
||||
- Выложить → post to Instagram feed + story via Meta API.
|
||||
- Перегенерить с нуля → restart from step 1.
|
||||
- Перегенерить изображение → restart from step 2 (same prompt/caption).
|
||||
- Сгенерировать ещё 2 → generate 2 pose-varied images.
|
||||
- Отмена → dismiss (no action).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dao: DAO,
|
||||
gemini: GoogleAdapter,
|
||||
s3_adapter: S3Adapter,
|
||||
generation_service: GenerationService,
|
||||
bot: Bot,
|
||||
admin_id: int,
|
||||
character_id: str,
|
||||
meta_adapter=None, # Optional[MetaAdapter]
|
||||
):
|
||||
self.dao = dao
|
||||
self.gemini = gemini
|
||||
self.ai_proxy = AIProxyAdapter()
|
||||
self.s3_adapter = s3_adapter
|
||||
self.generation_service = generation_service
|
||||
self.bot = bot
|
||||
self.admin_id = admin_id
|
||||
self.character_id = character_id
|
||||
self.meta_adapter = meta_adapter
|
||||
|
||||
# Stores session state keyed by generation_id.
|
||||
# Each value: {prompt, caption, asset_id, message_id, chat_id}
|
||||
self.pending_sessions: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Scheduler loop
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def run_loop(self):
|
||||
"""Run indefinitely, triggering daily generation at 11:00 MSK."""
|
||||
logger.info("Daily scheduler loop started")
|
||||
while True:
|
||||
try:
|
||||
await self._wait_until_next_run()
|
||||
logger.info("Daily scheduler: triggering daily generation")
|
||||
await self.run_daily_generation()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Daily scheduler loop error: {e}", exc_info=True)
|
||||
|
||||
async def _wait_until_next_run(self):
|
||||
now = datetime.now(MSK_TZ)
|
||||
next_run = now.replace(
|
||||
hour=SCHEDULE_HOUR_MSK,
|
||||
minute=SCHEDULE_MINUTE_MSK,
|
||||
second=0,
|
||||
microsecond=0,
|
||||
)
|
||||
if now >= next_run:
|
||||
next_run += timedelta(days=1)
|
||||
wait_seconds = (next_run - now).total_seconds()
|
||||
logger.info(
|
||||
f"Next daily generation at {next_run.strftime('%Y-%m-%d %H:%M MSK')} "
|
||||
f"(in {wait_seconds / 3600:.1f}h)"
|
||||
)
|
||||
await asyncio.sleep(wait_seconds)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Main generation pipeline
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def run_daily_generation(self):
|
||||
"""Full pipeline: prompt → image → send to admin."""
|
||||
try:
|
||||
prompt, caption = await self._generate_prompt_and_caption()
|
||||
logger.info(f"Prompt generated ({len(prompt)} chars), caption ({len(caption)} chars)")
|
||||
|
||||
generation, asset = await self._generate_image_and_save(prompt)
|
||||
logger.info(f"Generation done: id={generation.id}, asset={asset.id}")
|
||||
|
||||
await self._send_to_admin(generation, asset, prompt, caption)
|
||||
except Exception as e:
|
||||
logger.error(f"Daily generation pipeline failed: {e}", exc_info=True)
|
||||
try:
|
||||
await self.bot.send_message(
|
||||
chat_id=self.admin_id,
|
||||
text=f"❌ <b>Ежедневная генерация провалилась:</b>\n<code>{e}</code>",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 1 — Generate prompt + caption via LLM
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _generate_prompt_and_caption(self) -> Tuple[str, str]:
|
||||
"""Ask Gemini to produce an image prompt and social caption.
|
||||
|
||||
Passes the character's avatar photo to the model so it can create
|
||||
a prompt that is faithful to the character's appearance.
|
||||
"""
|
||||
char = await self.dao.chars.get_character(self.character_id)
|
||||
if not char:
|
||||
raise ValueError(f"Character {self.character_id} not found in DB")
|
||||
|
||||
avatar_bytes_list: list[bytes] = []
|
||||
if char.avatar_asset_id:
|
||||
avatar_asset = await self.dao.assets.get_asset(char.avatar_asset_id, with_data=True)
|
||||
if avatar_asset and avatar_asset.data:
|
||||
avatar_bytes_list.append(avatar_asset.data)
|
||||
|
||||
char_bio = char.character_bio or "An expressive, stylish AI character."
|
||||
system_prompt = (
|
||||
f"You are a creative director for the social media account of an AI character named '{char.name}'.\n"
|
||||
# f"Character description: {char_bio}\n\n"
|
||||
"I'm attaching the character's avatar photo. Based on it, produce TWO things:\n\n"
|
||||
"1. IMAGE_PROMPT: A detailed, vivid image generation prompt in English. "
|
||||
"Describe the pose, environment, lighting, color palette, and artistic style. It must look amateur. "
|
||||
"Make it unique and suitable for a social media post.\n\n"
|
||||
"2. SOCIAL_CAPTION: An engaging caption in English for Instagram and TikTok. "
|
||||
"Include 5-10 relevant hashtags at the end.\n\n"
|
||||
"Reply in EXACTLY this format (two lines, no extra text before IMAGE_PROMPT):\n"
|
||||
"IMAGE_PROMPT: <prompt here>\n"
|
||||
"SOCIAL_CAPTION: <caption here>"
|
||||
)
|
||||
|
||||
settings = await self.dao.settings.get_settings()
|
||||
if settings.use_ai_proxy:
|
||||
asset_urls = await self._prepare_asset_urls([char.avatar_asset_id]) if char.avatar_asset_id else None
|
||||
raw = await self.ai_proxy.generate_text(
|
||||
system_prompt,
|
||||
TextModel.GEMINI_3_1_PRO_PREVIEW.value,
|
||||
asset_urls
|
||||
)
|
||||
else:
|
||||
raw = await asyncio.to_thread(
|
||||
self.gemini.generate_text,
|
||||
system_prompt,
|
||||
TextModel.GEMINI_3_1_PRO_PREVIEW.value,
|
||||
avatar_bytes_list or None,
|
||||
)
|
||||
logger.debug(f"LLM raw response: {raw[:500]}")
|
||||
|
||||
prompt, caption = self._parse_prompt_and_caption(raw, char.name)
|
||||
return prompt, caption
|
||||
|
||||
async def _prepare_asset_urls(self, asset_ids: list[str]) -> list[str]:
|
||||
assets = await self.dao.assets.get_assets_by_ids(asset_ids)
|
||||
urls = []
|
||||
for asset in assets:
|
||||
if asset.minio_object_name:
|
||||
bucket = asset.minio_bucket or self.s3_adapter.bucket_name
|
||||
urls.append(f"{bucket}/{asset.minio_object_name}")
|
||||
return urls
|
||||
|
||||
@staticmethod
|
||||
def _parse_prompt_and_caption(raw: str, char_name: str) -> Tuple[str, str]:
|
||||
prompt = ""
|
||||
caption = ""
|
||||
|
||||
if "IMAGE_PROMPT:" in raw and "SOCIAL_CAPTION:" in raw:
|
||||
after_label = raw.split("IMAGE_PROMPT:", 1)[1]
|
||||
prompt = after_label.split("SOCIAL_CAPTION:", 1)[0].strip()
|
||||
caption = after_label.split("SOCIAL_CAPTION:", 1)[1].strip()
|
||||
elif "IMAGE_PROMPT:" in raw:
|
||||
prompt = raw.split("IMAGE_PROMPT:", 1)[1].strip()
|
||||
else:
|
||||
prompt = raw.strip()
|
||||
|
||||
if not prompt:
|
||||
raise ValueError(f"LLM did not produce IMAGE_PROMPT. Raw snippet: {raw[:300]}")
|
||||
if not caption:
|
||||
caption = f"✨ Новый контент от {char_name}"
|
||||
|
||||
return prompt, caption
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 2 — Generate image via GenerationService
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _generate_image_and_save(
|
||||
self,
|
||||
prompt: str,
|
||||
variation_hint: Optional[str] = None,
|
||||
) -> Tuple[Generation, Asset]:
|
||||
"""Create a Generation record and delegate execution to GenerationService.
|
||||
|
||||
Uses GenerationService.create_generation() which handles:
|
||||
- loading character avatar / reference assets
|
||||
- calling Gemini image generation
|
||||
- saving result as Asset in S3
|
||||
- finalizing the Generation record with metrics
|
||||
|
||||
No telegram_id is set, so the service won't send its own notification —
|
||||
we handle that ourselves in _send_to_admin() with action buttons.
|
||||
"""
|
||||
actual_prompt = prompt
|
||||
if variation_hint:
|
||||
actual_prompt = f"{prompt}. {variation_hint}"
|
||||
|
||||
# Create Generation record (GenerationService.create_generation expects it pre-saved)
|
||||
generation = Generation(
|
||||
status=GenerationStatus.RUNNING,
|
||||
linked_character_id=self.character_id,
|
||||
aspect_ratio=AspectRatios.NINESIXTEEN,
|
||||
quality=Quality.ONEK,
|
||||
prompt=actual_prompt,
|
||||
model=ImageModel.GEMINI_3_PRO_IMAGE_PREVIEW.value,
|
||||
use_profile_image=True,
|
||||
# No telegram_id → service won't send its own notification
|
||||
)
|
||||
gen_id = await self.dao.generations.create_generation(generation)
|
||||
generation.id = gen_id
|
||||
|
||||
try:
|
||||
# Delegate all heavy lifting to the existing service
|
||||
await self.generation_service.create_generation(generation)
|
||||
except Exception:
|
||||
# create_generation doesn't mark FAILED itself — the caller (_queued_generation_runner) does.
|
||||
# So we need to handle failure here.
|
||||
await self.generation_service._handle_generation_failure(generation, Exception("Image generation failed"))
|
||||
raise
|
||||
|
||||
# After create_generation, generation.result_list is populated
|
||||
if not generation.result_list:
|
||||
raise ValueError("Generation completed but produced no assets")
|
||||
|
||||
asset = await self.dao.assets.get_asset(generation.result_list[0], with_data=False)
|
||||
if not asset:
|
||||
raise ValueError(f"Asset {generation.result_list[0]} not found after generation")
|
||||
|
||||
return generation, asset
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 3 — Send to admin
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _send_to_admin(
|
||||
self,
|
||||
generation: Generation,
|
||||
asset: Asset,
|
||||
prompt: str,
|
||||
caption: str,
|
||||
):
|
||||
img_data = await self.s3_adapter.get_file(asset.minio_object_name)
|
||||
if not img_data:
|
||||
raise ValueError(f"Cannot load image from S3: {asset.minio_object_name}")
|
||||
|
||||
self.pending_sessions[generation.id] = {
|
||||
"prompt": prompt,
|
||||
"caption": caption,
|
||||
"asset_id": asset.id,
|
||||
}
|
||||
|
||||
msg = await self.bot.send_photo(
|
||||
chat_id=self.admin_id,
|
||||
photo=BufferedInputFile(img_data, filename="daily.png"),
|
||||
caption=(
|
||||
f"📸 <b>Ежедневная генерация</b>\n\n"
|
||||
f"<b>Подпись для соцсетей:</b>\n{caption}\n\n"
|
||||
f"<b>Промпт:</b>\n<code>{prompt[:300]}</code>"
|
||||
),
|
||||
reply_markup=make_admin_keyboard(generation.id),
|
||||
)
|
||||
self.pending_sessions[generation.id]["message_id"] = msg.message_id
|
||||
self.pending_sessions[generation.id]["chat_id"] = msg.chat.id
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Admin action handlers (called from Telegram callback router)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def handle_post(self, generation_id: str, message_id: int, chat_id: int):
|
||||
"""Post to Instagram feed + story."""
|
||||
session = self.pending_sessions.get(generation_id)
|
||||
if not session:
|
||||
return
|
||||
|
||||
if not self.meta_adapter:
|
||||
await self.bot.edit_message_caption(
|
||||
chat_id=chat_id,
|
||||
message_id=message_id,
|
||||
caption="⚠️ Meta API не настроен (META_ACCESS_TOKEN не задан). Публикация недоступна.",
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
asset = await self.dao.assets.get_asset(session["asset_id"], with_data=False)
|
||||
if not asset or not asset.minio_object_name:
|
||||
raise ValueError("Asset not found in DB")
|
||||
|
||||
image_url = await self.s3_adapter.get_presigned_url(
|
||||
asset.minio_object_name, expiration=3600
|
||||
)
|
||||
if not image_url:
|
||||
raise ValueError("Could not generate presigned URL for image")
|
||||
|
||||
feed_id = await self.meta_adapter.post_to_feed(image_url, session["caption"])
|
||||
story_id = await self.meta_adapter.post_to_story(image_url)
|
||||
|
||||
self.pending_sessions.pop(generation_id, None)
|
||||
await self.bot.edit_message_caption(
|
||||
chat_id=chat_id,
|
||||
message_id=message_id,
|
||||
caption=(
|
||||
f"✅ <b>Опубликовано!</b>\n\n"
|
||||
f"📰 Feed ID: <code>{feed_id}</code>\n"
|
||||
f"📖 Story ID: <code>{story_id}</code>"
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Meta publish failed for generation {generation_id}: {e}", exc_info=True)
|
||||
await self.bot.edit_message_caption(
|
||||
chat_id=chat_id,
|
||||
message_id=message_id,
|
||||
caption=f"❌ <b>Ошибка публикации:</b>\n<code>{e}</code>",
|
||||
reply_markup=make_admin_keyboard(generation_id),
|
||||
)
|
||||
|
||||
async def handle_regen_all(self, generation_id: str, message_id: int, chat_id: int):
|
||||
"""Restart from step 1: generate new prompt, caption, and image."""
|
||||
self.pending_sessions.pop(generation_id, None)
|
||||
await self.bot.edit_message_caption(
|
||||
chat_id=chat_id,
|
||||
message_id=message_id,
|
||||
caption="🔄 <b>Перегенерация с нуля...</b>",
|
||||
)
|
||||
asyncio.create_task(self._run_regen_all(chat_id))
|
||||
|
||||
async def _run_regen_all(self, chat_id: int):
|
||||
try:
|
||||
await self.run_daily_generation()
|
||||
except Exception as e:
|
||||
logger.error(f"Regen-all failed: {e}", exc_info=True)
|
||||
await self.bot.send_message(chat_id=chat_id, text=f"❌ Ошибка перегенерации:\n<code>{e}</code>")
|
||||
|
||||
async def handle_regen_image(self, generation_id: str, message_id: int, chat_id: int):
|
||||
"""Restart from step 2: generate new image using existing prompt/caption."""
|
||||
session = self.pending_sessions.pop(generation_id, None)
|
||||
if not session:
|
||||
return
|
||||
|
||||
prompt = session["prompt"]
|
||||
caption = session["caption"]
|
||||
|
||||
await self.bot.edit_message_caption(
|
||||
chat_id=chat_id,
|
||||
message_id=message_id,
|
||||
caption="🖼 <b>Перегенерация изображения...</b>",
|
||||
)
|
||||
asyncio.create_task(self._run_regen_image(prompt, caption, chat_id))
|
||||
|
||||
async def _run_regen_image(self, prompt: str, caption: str, chat_id: int):
|
||||
try:
|
||||
generation, asset = await self._generate_image_and_save(prompt)
|
||||
await self._send_to_admin(generation, asset, prompt, caption)
|
||||
except Exception as e:
|
||||
logger.error(f"Regen-image failed: {e}", exc_info=True)
|
||||
await self.bot.send_message(chat_id=chat_id, text=f"❌ Ошибка генерации:\n<code>{e}</code>")
|
||||
|
||||
async def handle_regen_more(self, generation_id: str, message_id: int, chat_id: int):
|
||||
"""Generate 2 more pose-varied images using the existing prompt/caption."""
|
||||
session = self.pending_sessions.get(generation_id)
|
||||
if not session:
|
||||
return
|
||||
|
||||
prompt = session["prompt"]
|
||||
caption = session["caption"]
|
||||
|
||||
await self.bot.edit_message_caption(
|
||||
chat_id=chat_id,
|
||||
message_id=message_id,
|
||||
caption="➕ <b>Генерирую ещё 2 варианта...</b>",
|
||||
)
|
||||
asyncio.create_task(self._run_regen_more(prompt, caption, chat_id))
|
||||
|
||||
async def _run_regen_more(self, prompt: str, caption: str, chat_id: int):
|
||||
variation_hints = [
|
||||
"Slightly vary the pose and camera angle while keeping the same scene, environment and lighting.",
|
||||
"Try a different subtle pose or expression, same background and setting as described.",
|
||||
]
|
||||
for i, hint in enumerate(variation_hints):
|
||||
try:
|
||||
generation, asset = await self._generate_image_and_save(prompt, variation_hint=hint)
|
||||
await self._send_to_admin(generation, asset, prompt, caption)
|
||||
except Exception as e:
|
||||
logger.error(f"Regen-more variant {i + 1} failed: {e}", exc_info=True)
|
||||
await self.bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text=f"❌ Ошибка варианта {i + 1}:\n<code>{e}</code>",
|
||||
)
|
||||
|
||||
async def handle_cancel(self, generation_id: str, message_id: int, chat_id: int):
|
||||
"""Dismiss: remove buttons, do nothing else."""
|
||||
self.pending_sessions.pop(generation_id, None)
|
||||
await self.bot.edit_message_caption(
|
||||
chat_id=chat_id,
|
||||
message_id=message_id,
|
||||
caption="🚫 Отменено.",
|
||||
)
|
||||
82
scheduler/telegram_admin_handler.py
Normal file
82
scheduler/telegram_admin_handler.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""Telegram inline-keyboard handlers for the daily scheduler admin flow.
|
||||
|
||||
Usage (in aiws.py):
|
||||
from scheduler.telegram_admin_handler import create_daily_scheduler_router
|
||||
from scheduler.daily_scheduler import DailyScheduler
|
||||
|
||||
daily_scheduler = DailyScheduler(...)
|
||||
dp.include_router(create_daily_scheduler_router(daily_scheduler))
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from aiogram import F, Router
|
||||
from aiogram.types import CallbackQuery
|
||||
|
||||
from scheduler.daily_scheduler import (
|
||||
CB_CANCEL,
|
||||
CB_POST,
|
||||
CB_REGEN_ALL,
|
||||
CB_REGEN_IMG,
|
||||
CB_REGEN_MORE,
|
||||
DailyScheduler,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_daily_scheduler_router(scheduler: DailyScheduler) -> Router:
|
||||
"""Return an aiogram Router with all callback handlers bound to *scheduler*."""
|
||||
router = Router(name="daily_scheduler")
|
||||
|
||||
@router.callback_query(F.data.startswith(CB_POST + ":"))
|
||||
async def on_post(callback: CallbackQuery):
|
||||
generation_id = callback.data.split(":", 1)[1]
|
||||
await callback.answer("Публикую в Instagram...")
|
||||
await scheduler.handle_post(
|
||||
generation_id=generation_id,
|
||||
message_id=callback.message.message_id,
|
||||
chat_id=callback.message.chat.id,
|
||||
)
|
||||
|
||||
@router.callback_query(F.data.startswith(CB_REGEN_ALL + ":"))
|
||||
async def on_regen_all(callback: CallbackQuery):
|
||||
generation_id = callback.data.split(":", 1)[1]
|
||||
await callback.answer("Перезапускаю с нуля...")
|
||||
await scheduler.handle_regen_all(
|
||||
generation_id=generation_id,
|
||||
message_id=callback.message.message_id,
|
||||
chat_id=callback.message.chat.id,
|
||||
)
|
||||
|
||||
@router.callback_query(F.data.startswith(CB_REGEN_IMG + ":"))
|
||||
async def on_regen_img(callback: CallbackQuery):
|
||||
generation_id = callback.data.split(":", 1)[1]
|
||||
await callback.answer("Генерирую новое изображение...")
|
||||
await scheduler.handle_regen_image(
|
||||
generation_id=generation_id,
|
||||
message_id=callback.message.message_id,
|
||||
chat_id=callback.message.chat.id,
|
||||
)
|
||||
|
||||
@router.callback_query(F.data.startswith(CB_REGEN_MORE + ":"))
|
||||
async def on_regen_more(callback: CallbackQuery):
|
||||
generation_id = callback.data.split(":", 1)[1]
|
||||
await callback.answer("Генерирую 2 варианта...")
|
||||
await scheduler.handle_regen_more(
|
||||
generation_id=generation_id,
|
||||
message_id=callback.message.message_id,
|
||||
chat_id=callback.message.chat.id,
|
||||
)
|
||||
|
||||
@router.callback_query(F.data.startswith(CB_CANCEL + ":"))
|
||||
async def on_cancel(callback: CallbackQuery):
|
||||
generation_id = callback.data.split(":", 1)[1]
|
||||
await callback.answer("Отменено")
|
||||
await scheduler.handle_cancel(
|
||||
generation_id=generation_id,
|
||||
message_id=callback.message.message_id,
|
||||
chat_id=callback.message.chat.id,
|
||||
)
|
||||
|
||||
return router
|
||||
51
tests/test_ai_proxy_logic.py
Normal file
51
tests/test_ai_proxy_logic.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
from api.service.generation_service import GenerationService
|
||||
from models.Settings import SystemSettings
|
||||
from models.Generation import Generation
|
||||
from models.enums import AspectRatios, Quality
|
||||
|
||||
async def test_generation_service_proxy_logic():
|
||||
dao = MagicMock()
|
||||
gemini = MagicMock()
|
||||
s3_adapter = MagicMock()
|
||||
|
||||
# Mock settings to have proxy ENABLED
|
||||
dao.settings.get_settings = AsyncMock(return_value=SystemSettings(use_ai_proxy=True))
|
||||
dao.assets.get_assets_by_ids = AsyncMock(return_value=[])
|
||||
|
||||
service = GenerationService(dao, gemini, s3_adapter)
|
||||
|
||||
# 1. Test ask_prompt_assistant with proxy
|
||||
with patch.object(service.ai_proxy, 'generate_text', new_callable=AsyncMock) as mock_proxy_text:
|
||||
mock_proxy_text.return_value = "Proxy Result"
|
||||
result = await service.ask_prompt_assistant("Test Prompt")
|
||||
assert result == "Proxy Result"
|
||||
mock_proxy_text.assert_called_once()
|
||||
gemini.generate_text.assert_not_called()
|
||||
|
||||
# 2. Test create_generation with proxy
|
||||
generation = Generation(
|
||||
prompt="Test Image",
|
||||
aspect_ratio=AspectRatios.ONEONE,
|
||||
quality=Quality.ONEK,
|
||||
assets_list=[]
|
||||
)
|
||||
# Mock _prepare_generation_input to avoid complex DB calls
|
||||
service._prepare_generation_input = AsyncMock(return_value=([], "Test Image", []))
|
||||
service._process_generated_images = AsyncMock(return_value=[])
|
||||
service._finalize_generation = AsyncMock()
|
||||
|
||||
with patch.object(service.ai_proxy, 'generate_image', new_callable=AsyncMock) as mock_proxy_img:
|
||||
import io
|
||||
mock_img_io = io.BytesIO(b"fake image data")
|
||||
mock_proxy_img.return_value = ([mock_img_io], {"api_execution_time_seconds": 1.0})
|
||||
|
||||
await service.create_generation(generation)
|
||||
mock_proxy_img.assert_called_once()
|
||||
# gemini.generate_image would be called via generate_image_task in else branch
|
||||
|
||||
print("✅ Proxy logic test passed!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_generation_service_proxy_logic())
|
||||
Reference in New Issue
Block a user