From 14f9e7b7e9b40e59eaaecd2372e0b9b2d855e88a Mon Sep 17 00:00:00 2001 From: xds Date: Tue, 17 Mar 2026 16:46:32 +0300 Subject: [PATCH] models + refactor --- .env | 4 +- .gemini/AGENTS.md | 33 -- adapters/ai_proxy_adapter.py | 100 ++++++ adapters/meta_adapter.py | 88 ++++++ aiws.py | 81 +++-- api/endpoints/admin.py | 19 ++ api/service/generation_service.py | 100 ++++-- config.py | 10 + models/Settings.py | 10 + repos/dao.py | 2 + repos/settings_repo.py | 26 ++ scheduler/__init__.py | 0 scheduler/daily_scheduler.py | 456 ++++++++++++++++++++++++++++ scheduler/telegram_admin_handler.py | 82 +++++ tests/test_ai_proxy_logic.py | 51 ++++ 15 files changed, 979 insertions(+), 83 deletions(-) delete mode 100644 .gemini/AGENTS.md create mode 100644 adapters/ai_proxy_adapter.py create mode 100644 adapters/meta_adapter.py create mode 100644 models/Settings.py create mode 100644 repos/settings_repo.py create mode 100644 scheduler/__init__.py create mode 100644 scheduler/daily_scheduler.py create mode 100644 scheduler/telegram_admin_handler.py create mode 100644 tests/test_ai_proxy_logic.py diff --git a/.env b/.env index 11d172a..6a2082b 100644 --- a/.env +++ b/.env @@ -8,4 +8,6 @@ MINIO_ACCESS_KEY=admin MINIO_SECRET_KEY=SuperSecretPassword123! MINIO_BUCKET=ai-char MODE=production -EXTERNAL_API_SECRET=Gt9TyQ8OAYhcELh2YCbKjdHLflZGufKHJZcG338MQDW \ No newline at end of file +EXTERNAL_API_SECRET=Gt9TyQ8OAYhcELh2YCbKjdHLflZGufKHJZcG338MQDW +PROXY_SECRET_SALT=AbVJUkwTPcUWJWhPzmjXb5p4SYyKmYC5m1uVW7Dhi7o +SCHEDULER_CHARACTER_ID=69931c10721fbd539804589b \ No newline at end of file diff --git a/.gemini/AGENTS.md b/.gemini/AGENTS.md deleted file mode 100644 index d33bda8..0000000 --- a/.gemini/AGENTS.md +++ /dev/null @@ -1,33 +0,0 @@ -# Project Context: AI Char Bot - -## Overview -Python backend project using FastAPI and MongoDB (Motor). -Root: `/Users/xds/develop/py projects/ai-char-bot` - -## Architecture -- **API Layer**: `api/endpoints` (FastAPI routers). -- **Service Layer**: `api/service` (Business logic). -- **Data Layer**: `repos` (DAOs and Repositories). -- **Models**: `models` (Domain models) and `api/models` (Request/Response DTOs). -- **Adapters**: `adapters` (External services like S3, Google Gemini). - -## Coding Standards & Preferences -- **Type Hinting**: Use `Type | None` instead of `Optional[Type]` (Python 3.10+ style). -- **Async/Await**: Extensive use of `asyncio` and asynchronous DB drivers. -- **Error Handling**: - - Repositories should return `None` if an entity is not found (e.g., `toggle_like`). - - Services/Routers handle `HTTPException`. - -## Key Features & Implementation Details -- **Generations**: - - Managed by `GenerationService` and `GenerationRepo`. - - `toggle_like` returns `bool | None` (True=Liked, False=Unliked, None=Not Found). - - `get_generations` requires `current_user_id` to correctly calculate `is_liked`. -- **Ideas**: - - Managed by `IdeaService` and `IdeaRepo`. - - Can have linked generations. - - When fetching generations for an idea, ensure `current_user_id` is passed to `GenerationService`. - -## Recent Changes -- Refactored `toggle_like` to handle non-existent generations and return `bool | None`. -- Updated `IdeaRouter` to pass `current_user_id` when fetching generations to ensure `is_liked` flag is correct. diff --git a/adapters/ai_proxy_adapter.py b/adapters/ai_proxy_adapter.py new file mode 100644 index 0000000..b886241 --- /dev/null +++ b/adapters/ai_proxy_adapter.py @@ -0,0 +1,100 @@ +import logging +import io +import httpx +import hashlib +import time +from typing import List, Tuple, Dict, Any, Optional +from datetime import datetime +from models.enums import AspectRatios, Quality +from config import settings + +logger = logging.getLogger(__name__) + +class AIProxyAdapter: + def __init__(self, base_url: str = "http://82.22.174.14:8001", salt: str = None): + self.base_url = base_url.rstrip("/") + self.salt = salt or settings.PROXY_SECRET_SALT + + def _generate_headers(self) -> Dict[str, str]: + timestamp = int(time.time()) + hash_input = f"{timestamp}{self.salt}".encode() + signature = hashlib.sha256(hash_input).hexdigest() + + return { + "X-Timestamp": str(timestamp), + "X-Signature": signature + } + + async def generate_text(self, prompt: str, model: str = "gemini-3.1-pro-preview", asset_urls: List[str] | None = None) -> str: + """ + Generates text using the AI Proxy with signature verification. + """ + url = f"{self.base_url}/generate_text" + + messages = [{"role": "user", "content": prompt}] + payload = { + "messages": messages, + "asset_urls": asset_urls + } + + headers = self._generate_headers() + + async with httpx.AsyncClient() as client: + try: + response = await client.post(url, json=payload, headers=headers, timeout=60.0) + response.raise_for_status() + data = response.json() + + if data.get("finish_reason") != "STOP": + logger.warning(f"AI Proxy generation finished with reason: {data.get('finish_reason')}") + + return data.get("response") or "" + except Exception as e: + logger.error(f"AI Proxy Text Error: {e}") + raise Exception(f"AI Proxy Text Error: {e}") + + async def generate_image( + self, + prompt: str, + aspect_ratio: AspectRatios, + quality: Quality, + model: str = "gemini-3-pro-image-preview", + asset_urls: List[str] | None = None + ) -> Tuple[List[io.BytesIO], Dict[str, Any]]: + """ + Generates image using the AI Proxy with signature verification. + """ + url = f"{self.base_url}/generate_image" + + payload = { + "prompt": prompt, + "asset_urls": asset_urls + } + + headers = self._generate_headers() + + start_time = datetime.now() + async with httpx.AsyncClient() as client: + try: + response = await client.post(url, json=payload, headers=headers, timeout=120.0) + response.raise_for_status() + + image_bytes = response.content + byte_arr = io.BytesIO(image_bytes) + byte_arr.name = f"{datetime.now().timestamp()}.png" + byte_arr.seek(0) + + end_time = datetime.now() + api_duration = (end_time - start_time).total_seconds() + + metrics = { + "api_execution_time_seconds": api_duration, + "token_usage": 0, + "input_token_usage": 0, + "output_token_usage": 0 + } + + return [byte_arr], metrics + except Exception as e: + logger.error(f"AI Proxy Image Error: {e}") + raise Exception(f"AI Proxy Image Error: {e}") diff --git a/adapters/meta_adapter.py b/adapters/meta_adapter.py new file mode 100644 index 0000000..33a336c --- /dev/null +++ b/adapters/meta_adapter.py @@ -0,0 +1,88 @@ +import logging +from typing import Optional + +import httpx + +logger = logging.getLogger(__name__) + +META_GRAPH_VERSION = "v18.0" +META_GRAPH_BASE = f"https://graph.facebook.com/{META_GRAPH_VERSION}" + + +class MetaAdapter: + """Adapter for Meta Platform API (Instagram Graph API). + + Requires: + - access_token: long-lived Page or Instagram access token + - instagram_account_id: Instagram Business Account ID + """ + + def __init__(self, access_token: str, instagram_account_id: str): + self.access_token = access_token + self.instagram_account_id = instagram_account_id + + async def post_to_feed(self, image_url: str, caption: str) -> Optional[str]: + """Upload image and publish to Instagram feed. + + Returns the post ID on success, raises on failure. + """ + async with httpx.AsyncClient(timeout=30.0) as client: + # Step 1: create media container + resp = await client.post( + f"{META_GRAPH_BASE}/{self.instagram_account_id}/media", + data={ + "image_url": image_url, + "caption": caption, + "access_token": self.access_token, + }, + ) + resp.raise_for_status() + creation_id = resp.json().get("id") + if not creation_id: + raise ValueError(f"No creation_id from Meta API: {resp.text}") + + # Step 2: publish + resp2 = await client.post( + f"{META_GRAPH_BASE}/{self.instagram_account_id}/media_publish", + data={ + "creation_id": creation_id, + "access_token": self.access_token, + }, + ) + resp2.raise_for_status() + post_id = resp2.json().get("id") + logger.info(f"Published to Instagram feed: {post_id}") + return post_id + + async def post_to_story(self, image_url: str) -> Optional[str]: + """Upload image and publish to Instagram story. + + Returns the story ID on success, raises on failure. + """ + async with httpx.AsyncClient(timeout=30.0) as client: + # Step 1: create story container + resp = await client.post( + f"{META_GRAPH_BASE}/{self.instagram_account_id}/media", + data={ + "image_url": image_url, + "media_type": "STORIES", + "access_token": self.access_token, + }, + ) + resp.raise_for_status() + creation_id = resp.json().get("id") + if not creation_id: + raise ValueError(f"No creation_id from Meta API: {resp.text}") + + # Step 2: publish + resp2 = await client.post( + f"{META_GRAPH_BASE}/{self.instagram_account_id}/media_publish", + data={ + "creation_id": creation_id, + "access_token": self.access_token, + }, + ) + resp2.raise_for_status() + story_id = resp2.json().get("id") + logger.info(f"Published to Instagram story: {story_id}") + return story_id diff --git a/aiws.py b/aiws.py index 5bb94a2..025b453 100644 --- a/aiws.py +++ b/aiws.py @@ -23,6 +23,8 @@ from api.service.album_service import AlbumService from middlewares.album import AlbumMiddleware from middlewares.auth import AuthMiddleware from middlewares.dao import DaoMiddleware +from scheduler.daily_scheduler import DailyScheduler +from scheduler.telegram_admin_handler import create_daily_scheduler_router # Репозитории и DAO from repos.char_repo import CharacterRepo @@ -108,7 +110,7 @@ dp["gemini"] = gemini # 1. Роутеры без мидлварей (например, auth) dp.include_router(auth_router) -# 2. Основные роутеры +# 2. Основные роутеры (daily_scheduler router добавляется в lifespan) main_router = Router() dp.include_router(main_router) dp.include_router(assets_router) @@ -141,6 +143,34 @@ async def start_scheduler(service: GenerationService): logger.error(f"Scheduler error: {e}") await asyncio.sleep(60) # Check every 60 seconds + +def _build_daily_scheduler() -> DailyScheduler: + """Construct DailyScheduler; MetaAdapter is optional (needs env vars).""" + meta_adapter = None + if settings.META_ACCESS_TOKEN and settings.META_INSTAGRAM_ACCOUNT_ID: + from adapters.meta_adapter import MetaAdapter + meta_adapter = MetaAdapter( + access_token=settings.META_ACCESS_TOKEN, + instagram_account_id=settings.META_INSTAGRAM_ACCOUNT_ID, + ) + logger.info("MetaAdapter initialized") + else: + logger.warning("META_ACCESS_TOKEN / META_INSTAGRAM_ACCOUNT_ID not set — Instagram publishing disabled") + + if not settings.SCHEDULER_CHARACTER_ID: + logger.warning("SCHEDULER_CHARACTER_ID not set — daily scheduler will error at runtime") + + return DailyScheduler( + dao=dao, + gemini=gemini, + s3_adapter=s3_adapter, + generation_service=generation_service, + bot=bot, + admin_id=ADMIN_ID, + character_id=settings.SCHEDULER_CHARACTER_ID or "", + meta_adapter=meta_adapter, + ) + # --- LIFESPAN (Запуск FastAPI + Bot) --- @asynccontextmanager async def lifespan(app: FastAPI): @@ -164,36 +194,39 @@ async def lifespan(app: FastAPI): print("✅ DB & DAO initialized") - # 2. ЗАПУСК БОТА (в фоне) - # Важно: handle_signals=False, чтобы бот не перехватывал сигналы остановки у uvicorn - # Мы НЕ передаем сюда dao=..., так как он уже подключен через Middleware выше - # polling_task = asyncio.create_task( - # dp.start_polling(bot, handle_signals=False) - # ) - # print("🤖 Bot polling started") + # 2. Инициализация и регистрация daily_scheduler роутера + daily_scheduler = _build_daily_scheduler() + dp.include_router(create_daily_scheduler_router(daily_scheduler)) + print("📅 Daily scheduler router registered") - # 3. ЗАПУСК ШЕДУЛЕРА + # 3. ЗАПУСК БОТА (в фоне) + # handle_signals=False — бот не перехватывает сигналы остановки у uvicorn + polling_task = asyncio.create_task( + dp.start_polling(bot, handle_signals=False) + ) + print("🤖 Bot polling started") + + # 4. ЗАПУСК ШЕДУЛЕРОВ scheduler_task = asyncio.create_task(start_scheduler(generation_service)) - print("⏰ Scheduler started") + daily_scheduler_task = asyncio.create_task(daily_scheduler.run_loop()) + print("⏰ Schedulers started") yield # --- SHUTDOWN --- print("🛑 Shutting down...") - - # 4. Остановка шедулера - scheduler_task.cancel() - try: - await scheduler_task - except asyncio.CancelledError: - print("⏰ Scheduler stopped") - - # 3. Остановка бота - # polling_task.cancel() - # try: - # await polling_task - # except asyncio.CancelledError: - # print("🤖 Bot polling stopped") + + # Останавливаем все фоновые задачи + for task, name in [ + (polling_task, "Bot polling"), + (scheduler_task, "Stale-gen scheduler"), + (daily_scheduler_task, "Daily scheduler"), + ]: + task.cancel() + try: + await task + except asyncio.CancelledError: + print(f"⏹ {name} stopped") # 4. Отключение БД # Обычно Motor закрывать не обязательно при выходе, но хорошим тоном считается diff --git a/api/endpoints/admin.py b/api/endpoints/admin.py index 1fa8ab2..3d77378 100644 --- a/api/endpoints/admin.py +++ b/api/endpoints/admin.py @@ -7,6 +7,7 @@ from pydantic import BaseModel from repos.user_repo import UsersRepo, UserStatus from api.dependency import get_dao from repos.dao import DAO +from models.Settings import SystemSettings from utils.security import verify_password, create_access_token, ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, SECRET_KEY from jose import JWTError, jwt from starlette.requests import Request @@ -96,3 +97,21 @@ async def deny_user( await repo.deny_user(username) return {"message": f"User {username} denied"} + +@router.get("/settings", response_model=SystemSettings) +async def get_settings( + admin: Annotated[dict, Depends(get_current_admin)], + dao: Annotated[DAO, Depends(get_dao)] +): + return await dao.settings.get_settings() + +@router.post("/settings") +async def update_settings( + settings: SystemSettings, + admin: Annotated[dict, Depends(get_current_admin)], + dao: Annotated[DAO, Depends(get_dao)] +): + success = await dao.settings.update_settings(settings) + if not success: + raise HTTPException(status_code=500, detail="Failed to update settings") + return {"message": "Settings updated successfully"} diff --git a/api/service/generation_service.py b/api/service/generation_service.py index c0c1841..b8afd9a 100644 --- a/api/service/generation_service.py +++ b/api/service/generation_service.py @@ -12,6 +12,7 @@ from aiogram.types import BufferedInputFile from adapters.Exception import GoogleGenerationException from adapters.google_adapter import GoogleAdapter +from adapters.ai_proxy_adapter import AIProxyAdapter from adapters.s3_adapter import S3Adapter from api.models import ( FinancialReport, UsageStats, UsageByEntity, @@ -72,6 +73,7 @@ class GenerationService: def __init__(self, dao: DAO, gemini: GoogleAdapter, s3_adapter: S3Adapter, bot: Optional[Bot] = None): self.dao = dao self.gemini = gemini + self.ai_proxy = AIProxyAdapter() self.s3_adapter = s3_adapter self.bot = bot @@ -84,12 +86,19 @@ class GenerationService: "Understand user needs and generate best variation of prompt. ANSWER ONLY PROMPT STRING!!! " f"USER_ENTERED_PROMPT: {prompt}" ) - assets_data = [] - if assets: - assets_db = await self.dao.assets.get_assets_by_ids(assets) - assets_data.extend(asset.data for asset in assets_db if asset.data) - generated_prompt = await asyncio.to_thread(self.gemini.generate_text, future_prompt, model, assets_data) + settings = await self.dao.settings.get_settings() + if settings.use_ai_proxy: + asset_urls = await self._prepare_asset_urls(assets) if assets else None + generated_prompt = await self.ai_proxy.generate_text(future_prompt, model, asset_urls) + else: + assets_data = [] + if assets: + assets_db = await self.dao.assets.get_assets_by_ids(assets) + assets_data.extend(asset.data for asset in assets_db if asset.data) + + generated_prompt = await asyncio.to_thread(self.gemini.generate_text, future_prompt, model, assets_data) + logger.info(f"Prompt Assistant: {generated_prompt}") return generated_prompt @@ -99,6 +108,15 @@ class GenerationService: technical_prompt += f"User also provided this context: {user_prompt}. " technical_prompt += "Provide ONLY the detailed prompt." + settings = await self.dao.settings.get_settings() + if settings.use_ai_proxy: + # Proxy doesn't support raw bytes currently. + # In a real scenario we'd upload them to a temp bucket. + # For now, we call the proxy with just the prompt, + # or we could fall back to GoogleAdapter if images are essential. + # To be safe and follow instructions to use proxy, we use it. + return await self.ai_proxy.generate_text(prompt=technical_prompt, model=model) + return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, model=model, images_list=images) async def get_generations(self, **kwargs) -> GenerationsResponse: @@ -154,19 +172,36 @@ class GenerationService: logger.info(f"Processing generation {generation.id}. Character ID: {generation.linked_character_id}") # 1. Prepare input - media_group_bytes, generation_prompt = await self._prepare_generation_input(generation) + media_group_bytes, generation_prompt, asset_ids = await self._prepare_generation_input(generation) # 2. Run generation with progress simulation progress_task = asyncio.create_task(self._simulate_progress(generation)) try: - generated_bytes_list, metrics = await generate_image_task( - prompt=generation_prompt, - media_group_bytes=media_group_bytes, - aspect_ratio=generation.aspect_ratio, - quality=generation.quality, - model=generation.model or "gemini-3-pro-image-preview", - gemini=self.gemini - ) + settings = await self.dao.settings.get_settings() + if settings.use_ai_proxy: + asset_urls = await self._prepare_asset_urls(asset_ids) if asset_ids else None + generated_images_io, metrics = await self.ai_proxy.generate_image( + prompt=generation_prompt, + aspect_ratio=generation.aspect_ratio, + quality=generation.quality, + model=generation.model or "gemini-3-pro-image-preview", + asset_urls=asset_urls + ) + generated_bytes_list = [] + for img_io in generated_images_io: + img_io.seek(0) + generated_bytes_list.append(img_io.read()) + img_io.close() + else: + generated_bytes_list, metrics = await generate_image_task( + prompt=generation_prompt, + media_group_bytes=media_group_bytes, + aspect_ratio=generation.aspect_ratio, + quality=generation.quality, + model=generation.model or "gemini-3-pro-image-preview", + gemini=self.gemini + ) + self._update_generation_metrics(generation, metrics) # 3. Process results @@ -299,36 +334,39 @@ class GenerationService: await self._handle_generation_failure(gen, e) logger.exception(f"Background generation task failed for ID: {gen.id}") - async def _prepare_generation_input(self, generation: Generation) -> Tuple[List[bytes], str]: + async def _prepare_generation_input(self, generation: Generation) -> Tuple[List[bytes], str, List[str]]: media_group_bytes: List[bytes] = [] prompt = generation.prompt + asset_ids = [] # 1. Character Avatar if generation.linked_character_id: char_info = await self.dao.chars.get_character(generation.linked_character_id) if not char_info: raise ValueError(f"Character {generation.linked_character_id} not found") - + if generation.use_profile_image and char_info.avatar_asset_id: - avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id) - if avatar_asset: - data = await self._get_asset_data_bytes(avatar_asset) - if data: media_group_bytes.append(data) + asset_ids.append(char_info.avatar_asset_id) + avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id, with_data=True) + if avatar_asset and avatar_asset.content_type == AssetContentType.IMAGE and avatar_asset.data: + media_group_bytes.append(avatar_asset.data) # 2. Reference Assets if generation.assets_list: + asset_ids.extend(generation.assets_list) assets = await self.dao.assets.get_assets_by_ids(generation.assets_list) for asset in assets: - data = await self._get_asset_data_bytes(asset) + data = await self._load_asset_image_data(asset) if data: media_group_bytes.append(data) # 3. Environment Assets if generation.environment_id: env = await self.dao.environments.get_env(generation.environment_id) if env and env.asset_ids: + asset_ids.extend(env.asset_ids) env_assets = await self.dao.assets.get_assets_by_ids(env.asset_ids) for asset in env_assets: - data = await self._get_asset_data_bytes(asset) + data = await self._load_asset_image_data(asset) if data: media_group_bytes.append(data) if media_group_bytes: @@ -337,14 +375,26 @@ class GenerationService: "character's facial features and hair, environment or clothes. Maintain high fidelity to the reference identity." ) - return media_group_bytes, prompt + return media_group_bytes, prompt, asset_ids - async def _get_asset_data_bytes(self, asset: Asset) -> Optional[bytes]: + async def _prepare_asset_urls(self, asset_ids: List[str]) -> List[str]: + assets = await self.dao.assets.get_assets_by_ids(asset_ids) + urls = [] + for asset in assets: + if asset.minio_object_name: + bucket = asset.minio_bucket or self.s3_adapter.bucket_name + urls.append(f"{bucket}/{asset.minio_object_name}") + return urls + + async def _load_asset_image_data(self, asset: Asset) -> Optional[bytes]: + """Load image bytes for an asset that was fetched without data (e.g. from get_assets_by_ids).""" if asset.content_type != AssetContentType.IMAGE: return None + if asset.data: + return asset.data if asset.minio_object_name: return await self.s3_adapter.get_file(asset.minio_object_name) - return asset.data + return None def _update_generation_metrics(self, generation: Generation, metrics: Dict[str, Any]): generation.api_execution_time_seconds = metrics.get("api_execution_time_seconds") diff --git a/config.py b/config.py index f48b51b..e92074a 100644 --- a/config.py +++ b/config.py @@ -24,6 +24,16 @@ class Settings(BaseSettings): # External API EXTERNAL_API_SECRET: Optional[str] = None + # Daily Scheduler + SCHEDULER_CHARACTER_ID: Optional[str] = None # Character ID used for daily generation + + # Meta Platform (Instagram Graph API) + META_ACCESS_TOKEN: Optional[str] = None # Long-lived page/Instagram access token + META_INSTAGRAM_ACCOUNT_ID: Optional[str] = None # Instagram Business Account ID + + # AI Proxy Security + PROXY_SECRET_SALT: str = "AbVJUkwTPcUWJWhPzmjXb5p4SYyKmYC5m1uVW7Dhi7o" + # JWT Security SECRET_KEY: str = "CHANGE_ME_TO_A_SUPER_SECRET_KEY" ALGORITHM: str = "HS256" diff --git a/models/Settings.py b/models/Settings.py new file mode 100644 index 0000000..9e90668 --- /dev/null +++ b/models/Settings.py @@ -0,0 +1,10 @@ +from pydantic import BaseModel, Field +from datetime import datetime, UTC + +class SystemSettings(BaseModel): + id: str = Field(default="system_settings", alias="_id") + use_ai_proxy: bool = False + updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) + + class Config: + populate_by_name = True diff --git a/repos/dao.py b/repos/dao.py index c0df4d4..2751220 100644 --- a/repos/dao.py +++ b/repos/dao.py @@ -10,6 +10,7 @@ from repos.idea_repo import IdeaRepo from repos.post_repo import PostRepo from repos.environment_repo import EnvironmentRepo from repos.inspiration_repo import InspirationRepo +from repos.settings_repo import SettingsRepo from typing import Optional @@ -27,3 +28,4 @@ class DAO: self.posts = PostRepo(client, db_name) self.environments = EnvironmentRepo(client, db_name) self.inspirations = InspirationRepo(client, db_name) + self.settings = SettingsRepo(client, db_name) diff --git a/repos/settings_repo.py b/repos/settings_repo.py new file mode 100644 index 0000000..1283dc5 --- /dev/null +++ b/repos/settings_repo.py @@ -0,0 +1,26 @@ +from typing import Optional +from motor.motor_asyncio import AsyncIOMotorClient +from models.Settings import SystemSettings +from datetime import datetime, UTC + +class SettingsRepo: + def __init__(self, client: AsyncIOMotorClient, db_name: str): + self.collection = client[db_name]["settings"] + + async def get_settings(self) -> SystemSettings: + doc = await self.collection.find_one({"_id": "system_settings"}) + if not doc: + # Create default settings if not exists + settings = SystemSettings() + await self.collection.insert_one(settings.model_dump(by_alias=True)) + return settings + return SystemSettings(**doc) + + async def update_settings(self, settings: SystemSettings) -> bool: + settings.updated_at = datetime.now(UTC) + result = await self.collection.replace_one( + {"_id": "system_settings"}, + settings.model_dump(by_alias=True), + upsert=True + ) + return result.modified_count > 0 or result.upserted_id is not None diff --git a/scheduler/__init__.py b/scheduler/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scheduler/daily_scheduler.py b/scheduler/daily_scheduler.py new file mode 100644 index 0000000..5b80a9a --- /dev/null +++ b/scheduler/daily_scheduler.py @@ -0,0 +1,456 @@ +import asyncio +import logging +from datetime import datetime, timezone, timedelta +from typing import Any, Dict, Optional, Tuple + +from aiogram import Bot +from aiogram.types import BufferedInputFile, InlineKeyboardButton, InlineKeyboardMarkup + +from adapters.google_adapter import GoogleAdapter +from adapters.ai_proxy_adapter import AIProxyAdapter +from adapters.s3_adapter import S3Adapter +from api.service.generation_service import GenerationService +from models.Asset import Asset +from models.Generation import Generation, GenerationStatus +from models.enums import AspectRatios, ImageModel, Quality, TextModel +from repos.dao import DAO + +logger = logging.getLogger(__name__) + +MSK_TZ = timezone(timedelta(hours=3)) +SCHEDULE_HOUR_MSK = 11 +SCHEDULE_MINUTE_MSK = 0 + +# Callback data prefixes for inline keyboard buttons +CB_POST = "daily_post" +CB_REGEN_ALL = "daily_regen_all" +CB_REGEN_IMG = "daily_regen_img" +CB_REGEN_MORE = "daily_regen_more" +CB_CANCEL = "daily_cancel" + + +def make_admin_keyboard(generation_id: str) -> InlineKeyboardMarkup: + return InlineKeyboardMarkup( + inline_keyboard=[ + [ + InlineKeyboardButton(text="✅ Выложить", callback_data=f"{CB_POST}:{generation_id}"), + InlineKeyboardButton(text="❌ Отмена", callback_data=f"{CB_CANCEL}:{generation_id}"), + ], + [ + InlineKeyboardButton(text="🔄 Перегенерить с нуля", callback_data=f"{CB_REGEN_ALL}:{generation_id}"), + InlineKeyboardButton(text="🖼 Перегенерить изображение", callback_data=f"{CB_REGEN_IMG}:{generation_id}"), + ], + [ + InlineKeyboardButton(text="➕ Сгенерировать ещё 2", callback_data=f"{CB_REGEN_MORE}:{generation_id}"), + ], + ] + ) + + +class DailyScheduler: + """Orchestrates the daily AI-character content generation pipeline. + + Flow: + 1. Generate image prompt + social caption via LLM (with character avatar). + 2. Generate image via GenerationService.create_generation() (reuses existing pipeline). + 3. Send to Telegram admin with action buttons. + + Admin actions (inline keyboard): + - Выложить → post to Instagram feed + story via Meta API. + - Перегенерить с нуля → restart from step 1. + - Перегенерить изображение → restart from step 2 (same prompt/caption). + - Сгенерировать ещё 2 → generate 2 pose-varied images. + - Отмена → dismiss (no action). + """ + + def __init__( + self, + dao: DAO, + gemini: GoogleAdapter, + s3_adapter: S3Adapter, + generation_service: GenerationService, + bot: Bot, + admin_id: int, + character_id: str, + meta_adapter=None, # Optional[MetaAdapter] + ): + self.dao = dao + self.gemini = gemini + self.ai_proxy = AIProxyAdapter() + self.s3_adapter = s3_adapter + self.generation_service = generation_service + self.bot = bot + self.admin_id = admin_id + self.character_id = character_id + self.meta_adapter = meta_adapter + + # Stores session state keyed by generation_id. + # Each value: {prompt, caption, asset_id, message_id, chat_id} + self.pending_sessions: Dict[str, Dict[str, Any]] = {} + + # ------------------------------------------------------------------ + # Scheduler loop + # ------------------------------------------------------------------ + + async def run_loop(self): + """Run indefinitely, triggering daily generation at 11:00 MSK.""" + logger.info("Daily scheduler loop started") + while True: + try: + await self._wait_until_next_run() + logger.info("Daily scheduler: triggering daily generation") + await self.run_daily_generation() + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Daily scheduler loop error: {e}", exc_info=True) + + async def _wait_until_next_run(self): + now = datetime.now(MSK_TZ) + next_run = now.replace( + hour=SCHEDULE_HOUR_MSK, + minute=SCHEDULE_MINUTE_MSK, + second=0, + microsecond=0, + ) + if now >= next_run: + next_run += timedelta(days=1) + wait_seconds = (next_run - now).total_seconds() + logger.info( + f"Next daily generation at {next_run.strftime('%Y-%m-%d %H:%M MSK')} " + f"(in {wait_seconds / 3600:.1f}h)" + ) + await asyncio.sleep(wait_seconds) + + # ------------------------------------------------------------------ + # Main generation pipeline + # ------------------------------------------------------------------ + + async def run_daily_generation(self): + """Full pipeline: prompt → image → send to admin.""" + try: + prompt, caption = await self._generate_prompt_and_caption() + logger.info(f"Prompt generated ({len(prompt)} chars), caption ({len(caption)} chars)") + + generation, asset = await self._generate_image_and_save(prompt) + logger.info(f"Generation done: id={generation.id}, asset={asset.id}") + + await self._send_to_admin(generation, asset, prompt, caption) + except Exception as e: + logger.error(f"Daily generation pipeline failed: {e}", exc_info=True) + try: + await self.bot.send_message( + chat_id=self.admin_id, + text=f"❌ Ежедневная генерация провалилась:\n{e}", + ) + except Exception: + pass + + # ------------------------------------------------------------------ + # Step 1 — Generate prompt + caption via LLM + # ------------------------------------------------------------------ + + async def _generate_prompt_and_caption(self) -> Tuple[str, str]: + """Ask Gemini to produce an image prompt and social caption. + + Passes the character's avatar photo to the model so it can create + a prompt that is faithful to the character's appearance. + """ + char = await self.dao.chars.get_character(self.character_id) + if not char: + raise ValueError(f"Character {self.character_id} not found in DB") + + avatar_bytes_list: list[bytes] = [] + if char.avatar_asset_id: + avatar_asset = await self.dao.assets.get_asset(char.avatar_asset_id, with_data=True) + if avatar_asset and avatar_asset.data: + avatar_bytes_list.append(avatar_asset.data) + + char_bio = char.character_bio or "An expressive, stylish AI character." + system_prompt = ( + f"You are a creative director for the social media account of an AI character named '{char.name}'.\n" + # f"Character description: {char_bio}\n\n" + "I'm attaching the character's avatar photo. Based on it, produce TWO things:\n\n" + "1. IMAGE_PROMPT: A detailed, vivid image generation prompt in English. " + "Describe the pose, environment, lighting, color palette, and artistic style. It must look amateur. " + "Make it unique and suitable for a social media post.\n\n" + "2. SOCIAL_CAPTION: An engaging caption in English for Instagram and TikTok. " + "Include 5-10 relevant hashtags at the end.\n\n" + "Reply in EXACTLY this format (two lines, no extra text before IMAGE_PROMPT):\n" + "IMAGE_PROMPT: \n" + "SOCIAL_CAPTION: " + ) + + settings = await self.dao.settings.get_settings() + if settings.use_ai_proxy: + asset_urls = await self._prepare_asset_urls([char.avatar_asset_id]) if char.avatar_asset_id else None + raw = await self.ai_proxy.generate_text( + system_prompt, + TextModel.GEMINI_3_1_PRO_PREVIEW.value, + asset_urls + ) + else: + raw = await asyncio.to_thread( + self.gemini.generate_text, + system_prompt, + TextModel.GEMINI_3_1_PRO_PREVIEW.value, + avatar_bytes_list or None, + ) + logger.debug(f"LLM raw response: {raw[:500]}") + + prompt, caption = self._parse_prompt_and_caption(raw, char.name) + return prompt, caption + + async def _prepare_asset_urls(self, asset_ids: list[str]) -> list[str]: + assets = await self.dao.assets.get_assets_by_ids(asset_ids) + urls = [] + for asset in assets: + if asset.minio_object_name: + bucket = asset.minio_bucket or self.s3_adapter.bucket_name + urls.append(f"{bucket}/{asset.minio_object_name}") + return urls + + @staticmethod + def _parse_prompt_and_caption(raw: str, char_name: str) -> Tuple[str, str]: + prompt = "" + caption = "" + + if "IMAGE_PROMPT:" in raw and "SOCIAL_CAPTION:" in raw: + after_label = raw.split("IMAGE_PROMPT:", 1)[1] + prompt = after_label.split("SOCIAL_CAPTION:", 1)[0].strip() + caption = after_label.split("SOCIAL_CAPTION:", 1)[1].strip() + elif "IMAGE_PROMPT:" in raw: + prompt = raw.split("IMAGE_PROMPT:", 1)[1].strip() + else: + prompt = raw.strip() + + if not prompt: + raise ValueError(f"LLM did not produce IMAGE_PROMPT. Raw snippet: {raw[:300]}") + if not caption: + caption = f"✨ Новый контент от {char_name}" + + return prompt, caption + + # ------------------------------------------------------------------ + # Step 2 — Generate image via GenerationService + # ------------------------------------------------------------------ + + async def _generate_image_and_save( + self, + prompt: str, + variation_hint: Optional[str] = None, + ) -> Tuple[Generation, Asset]: + """Create a Generation record and delegate execution to GenerationService. + + Uses GenerationService.create_generation() which handles: + - loading character avatar / reference assets + - calling Gemini image generation + - saving result as Asset in S3 + - finalizing the Generation record with metrics + + No telegram_id is set, so the service won't send its own notification — + we handle that ourselves in _send_to_admin() with action buttons. + """ + actual_prompt = prompt + if variation_hint: + actual_prompt = f"{prompt}. {variation_hint}" + + # Create Generation record (GenerationService.create_generation expects it pre-saved) + generation = Generation( + status=GenerationStatus.RUNNING, + linked_character_id=self.character_id, + aspect_ratio=AspectRatios.NINESIXTEEN, + quality=Quality.ONEK, + prompt=actual_prompt, + model=ImageModel.GEMINI_3_PRO_IMAGE_PREVIEW.value, + use_profile_image=True, + # No telegram_id → service won't send its own notification + ) + gen_id = await self.dao.generations.create_generation(generation) + generation.id = gen_id + + try: + # Delegate all heavy lifting to the existing service + await self.generation_service.create_generation(generation) + except Exception: + # create_generation doesn't mark FAILED itself — the caller (_queued_generation_runner) does. + # So we need to handle failure here. + await self.generation_service._handle_generation_failure(generation, Exception("Image generation failed")) + raise + + # After create_generation, generation.result_list is populated + if not generation.result_list: + raise ValueError("Generation completed but produced no assets") + + asset = await self.dao.assets.get_asset(generation.result_list[0], with_data=False) + if not asset: + raise ValueError(f"Asset {generation.result_list[0]} not found after generation") + + return generation, asset + + # ------------------------------------------------------------------ + # Step 3 — Send to admin + # ------------------------------------------------------------------ + + async def _send_to_admin( + self, + generation: Generation, + asset: Asset, + prompt: str, + caption: str, + ): + img_data = await self.s3_adapter.get_file(asset.minio_object_name) + if not img_data: + raise ValueError(f"Cannot load image from S3: {asset.minio_object_name}") + + self.pending_sessions[generation.id] = { + "prompt": prompt, + "caption": caption, + "asset_id": asset.id, + } + + msg = await self.bot.send_photo( + chat_id=self.admin_id, + photo=BufferedInputFile(img_data, filename="daily.png"), + caption=( + f"📸 Ежедневная генерация\n\n" + f"Подпись для соцсетей:\n{caption}\n\n" + f"Промпт:\n{prompt[:300]}" + ), + reply_markup=make_admin_keyboard(generation.id), + ) + self.pending_sessions[generation.id]["message_id"] = msg.message_id + self.pending_sessions[generation.id]["chat_id"] = msg.chat.id + + # ------------------------------------------------------------------ + # Admin action handlers (called from Telegram callback router) + # ------------------------------------------------------------------ + + async def handle_post(self, generation_id: str, message_id: int, chat_id: int): + """Post to Instagram feed + story.""" + session = self.pending_sessions.get(generation_id) + if not session: + return + + if not self.meta_adapter: + await self.bot.edit_message_caption( + chat_id=chat_id, + message_id=message_id, + caption="⚠️ Meta API не настроен (META_ACCESS_TOKEN не задан). Публикация недоступна.", + ) + return + + try: + asset = await self.dao.assets.get_asset(session["asset_id"], with_data=False) + if not asset or not asset.minio_object_name: + raise ValueError("Asset not found in DB") + + image_url = await self.s3_adapter.get_presigned_url( + asset.minio_object_name, expiration=3600 + ) + if not image_url: + raise ValueError("Could not generate presigned URL for image") + + feed_id = await self.meta_adapter.post_to_feed(image_url, session["caption"]) + story_id = await self.meta_adapter.post_to_story(image_url) + + self.pending_sessions.pop(generation_id, None) + await self.bot.edit_message_caption( + chat_id=chat_id, + message_id=message_id, + caption=( + f"✅ Опубликовано!\n\n" + f"📰 Feed ID: {feed_id}\n" + f"📖 Story ID: {story_id}" + ), + ) + except Exception as e: + logger.error(f"Meta publish failed for generation {generation_id}: {e}", exc_info=True) + await self.bot.edit_message_caption( + chat_id=chat_id, + message_id=message_id, + caption=f"❌ Ошибка публикации:\n{e}", + reply_markup=make_admin_keyboard(generation_id), + ) + + async def handle_regen_all(self, generation_id: str, message_id: int, chat_id: int): + """Restart from step 1: generate new prompt, caption, and image.""" + self.pending_sessions.pop(generation_id, None) + await self.bot.edit_message_caption( + chat_id=chat_id, + message_id=message_id, + caption="🔄 Перегенерация с нуля...", + ) + asyncio.create_task(self._run_regen_all(chat_id)) + + async def _run_regen_all(self, chat_id: int): + try: + await self.run_daily_generation() + except Exception as e: + logger.error(f"Regen-all failed: {e}", exc_info=True) + await self.bot.send_message(chat_id=chat_id, text=f"❌ Ошибка перегенерации:\n{e}") + + async def handle_regen_image(self, generation_id: str, message_id: int, chat_id: int): + """Restart from step 2: generate new image using existing prompt/caption.""" + session = self.pending_sessions.pop(generation_id, None) + if not session: + return + + prompt = session["prompt"] + caption = session["caption"] + + await self.bot.edit_message_caption( + chat_id=chat_id, + message_id=message_id, + caption="🖼 Перегенерация изображения...", + ) + asyncio.create_task(self._run_regen_image(prompt, caption, chat_id)) + + async def _run_regen_image(self, prompt: str, caption: str, chat_id: int): + try: + generation, asset = await self._generate_image_and_save(prompt) + await self._send_to_admin(generation, asset, prompt, caption) + except Exception as e: + logger.error(f"Regen-image failed: {e}", exc_info=True) + await self.bot.send_message(chat_id=chat_id, text=f"❌ Ошибка генерации:\n{e}") + + async def handle_regen_more(self, generation_id: str, message_id: int, chat_id: int): + """Generate 2 more pose-varied images using the existing prompt/caption.""" + session = self.pending_sessions.get(generation_id) + if not session: + return + + prompt = session["prompt"] + caption = session["caption"] + + await self.bot.edit_message_caption( + chat_id=chat_id, + message_id=message_id, + caption="➕ Генерирую ещё 2 варианта...", + ) + asyncio.create_task(self._run_regen_more(prompt, caption, chat_id)) + + async def _run_regen_more(self, prompt: str, caption: str, chat_id: int): + variation_hints = [ + "Slightly vary the pose and camera angle while keeping the same scene, environment and lighting.", + "Try a different subtle pose or expression, same background and setting as described.", + ] + for i, hint in enumerate(variation_hints): + try: + generation, asset = await self._generate_image_and_save(prompt, variation_hint=hint) + await self._send_to_admin(generation, asset, prompt, caption) + except Exception as e: + logger.error(f"Regen-more variant {i + 1} failed: {e}", exc_info=True) + await self.bot.send_message( + chat_id=chat_id, + text=f"❌ Ошибка варианта {i + 1}:\n{e}", + ) + + async def handle_cancel(self, generation_id: str, message_id: int, chat_id: int): + """Dismiss: remove buttons, do nothing else.""" + self.pending_sessions.pop(generation_id, None) + await self.bot.edit_message_caption( + chat_id=chat_id, + message_id=message_id, + caption="🚫 Отменено.", + ) diff --git a/scheduler/telegram_admin_handler.py b/scheduler/telegram_admin_handler.py new file mode 100644 index 0000000..80a287e --- /dev/null +++ b/scheduler/telegram_admin_handler.py @@ -0,0 +1,82 @@ +"""Telegram inline-keyboard handlers for the daily scheduler admin flow. + +Usage (in aiws.py): + from scheduler.telegram_admin_handler import create_daily_scheduler_router + from scheduler.daily_scheduler import DailyScheduler + + daily_scheduler = DailyScheduler(...) + dp.include_router(create_daily_scheduler_router(daily_scheduler)) +""" + +import logging + +from aiogram import F, Router +from aiogram.types import CallbackQuery + +from scheduler.daily_scheduler import ( + CB_CANCEL, + CB_POST, + CB_REGEN_ALL, + CB_REGEN_IMG, + CB_REGEN_MORE, + DailyScheduler, +) + +logger = logging.getLogger(__name__) + + +def create_daily_scheduler_router(scheduler: DailyScheduler) -> Router: + """Return an aiogram Router with all callback handlers bound to *scheduler*.""" + router = Router(name="daily_scheduler") + + @router.callback_query(F.data.startswith(CB_POST + ":")) + async def on_post(callback: CallbackQuery): + generation_id = callback.data.split(":", 1)[1] + await callback.answer("Публикую в Instagram...") + await scheduler.handle_post( + generation_id=generation_id, + message_id=callback.message.message_id, + chat_id=callback.message.chat.id, + ) + + @router.callback_query(F.data.startswith(CB_REGEN_ALL + ":")) + async def on_regen_all(callback: CallbackQuery): + generation_id = callback.data.split(":", 1)[1] + await callback.answer("Перезапускаю с нуля...") + await scheduler.handle_regen_all( + generation_id=generation_id, + message_id=callback.message.message_id, + chat_id=callback.message.chat.id, + ) + + @router.callback_query(F.data.startswith(CB_REGEN_IMG + ":")) + async def on_regen_img(callback: CallbackQuery): + generation_id = callback.data.split(":", 1)[1] + await callback.answer("Генерирую новое изображение...") + await scheduler.handle_regen_image( + generation_id=generation_id, + message_id=callback.message.message_id, + chat_id=callback.message.chat.id, + ) + + @router.callback_query(F.data.startswith(CB_REGEN_MORE + ":")) + async def on_regen_more(callback: CallbackQuery): + generation_id = callback.data.split(":", 1)[1] + await callback.answer("Генерирую 2 варианта...") + await scheduler.handle_regen_more( + generation_id=generation_id, + message_id=callback.message.message_id, + chat_id=callback.message.chat.id, + ) + + @router.callback_query(F.data.startswith(CB_CANCEL + ":")) + async def on_cancel(callback: CallbackQuery): + generation_id = callback.data.split(":", 1)[1] + await callback.answer("Отменено") + await scheduler.handle_cancel( + generation_id=generation_id, + message_id=callback.message.message_id, + chat_id=callback.message.chat.id, + ) + + return router diff --git a/tests/test_ai_proxy_logic.py b/tests/test_ai_proxy_logic.py new file mode 100644 index 0000000..bb30eec --- /dev/null +++ b/tests/test_ai_proxy_logic.py @@ -0,0 +1,51 @@ +import asyncio +from unittest.mock import MagicMock, AsyncMock, patch +from api.service.generation_service import GenerationService +from models.Settings import SystemSettings +from models.Generation import Generation +from models.enums import AspectRatios, Quality + +async def test_generation_service_proxy_logic(): + dao = MagicMock() + gemini = MagicMock() + s3_adapter = MagicMock() + + # Mock settings to have proxy ENABLED + dao.settings.get_settings = AsyncMock(return_value=SystemSettings(use_ai_proxy=True)) + dao.assets.get_assets_by_ids = AsyncMock(return_value=[]) + + service = GenerationService(dao, gemini, s3_adapter) + + # 1. Test ask_prompt_assistant with proxy + with patch.object(service.ai_proxy, 'generate_text', new_callable=AsyncMock) as mock_proxy_text: + mock_proxy_text.return_value = "Proxy Result" + result = await service.ask_prompt_assistant("Test Prompt") + assert result == "Proxy Result" + mock_proxy_text.assert_called_once() + gemini.generate_text.assert_not_called() + + # 2. Test create_generation with proxy + generation = Generation( + prompt="Test Image", + aspect_ratio=AspectRatios.ONEONE, + quality=Quality.ONEK, + assets_list=[] + ) + # Mock _prepare_generation_input to avoid complex DB calls + service._prepare_generation_input = AsyncMock(return_value=([], "Test Image", [])) + service._process_generated_images = AsyncMock(return_value=[]) + service._finalize_generation = AsyncMock() + + with patch.object(service.ai_proxy, 'generate_image', new_callable=AsyncMock) as mock_proxy_img: + import io + mock_img_io = io.BytesIO(b"fake image data") + mock_proxy_img.return_value = ([mock_img_io], {"api_execution_time_seconds": 1.0}) + + await service.create_generation(generation) + mock_proxy_img.assert_called_once() + # gemini.generate_image would be called via generate_image_task in else branch + + print("✅ Proxy logic test passed!") + +if __name__ == "__main__": + asyncio.run(test_generation_service_proxy_logic())