3 Commits

26 changed files with 489 additions and 26 deletions

1
.gitignore vendored
View File

@@ -23,3 +23,4 @@ services/*.pyc
utils/__pycache__/
utils/*.pyc
.vscode/launch.json
repos/__pycache__/assets_repo.cpython-313.pyc

View File

@@ -23,10 +23,10 @@ class GoogleAdapter:
self.TEXT_MODEL = "gemini-3-pro-preview"
self.IMAGE_MODEL = "gemini-3-pro-image-preview"
def _prepare_contents(self, prompt: str, images_list: List[bytes] = None) -> tuple:
def _prepare_contents(self, prompt: str, images_list: List[bytes] | None = None) -> tuple:
"""Вспомогательный метод для подготовки контента (текст + картинки).
Returns (contents, opened_images) — caller MUST close opened_images after use."""
contents = [prompt]
contents : list [Any]= [prompt]
opened_images = []
if images_list:
logger.info(f"Preparing content with {len(images_list)} images")
@@ -41,7 +41,7 @@ class GoogleAdapter:
logger.info("Preparing content with no images")
return contents, opened_images
def generate_text(self, prompt: str, images_list: List[bytes] = None) -> str:
def generate_text(self, prompt: str, images_list: List[bytes] | None = None) -> str:
"""
Генерация текста (Чат или Vision).
Возвращает строку с ответом.
@@ -74,7 +74,7 @@ class GoogleAdapter:
for img in opened_images:
img.close()
def generate_image(self, prompt: str, aspect_ratio: AspectRatios, quality: Quality, images_list: List[bytes] = None, ) -> Tuple[List[io.BytesIO], Dict[str, Any]]:
def generate_image(self, prompt: str, aspect_ratio: AspectRatios, quality: Quality, images_list: List[bytes] | None = None, ) -> Tuple[List[io.BytesIO], Dict[str, Any]]:
"""
Генерация изображений (Text-to-Image или Image-to-Image).
Возвращает список байтовых потоков (готовых к отправке).
@@ -130,7 +130,9 @@ class GoogleAdapter:
try:
# 1. Берем сырые байты
raw_data = part.inline_data.data
byte_arr = io.BytesIO(raw_data)
if raw_data is None:
raise GoogleGenerationException("Generation returned no data")
byte_arr : io.BytesIO = io.BytesIO(raw_data)
# 2. Нейминг (формально, для TG)
timestamp = datetime.now().timestamp()

View File

@@ -18,7 +18,7 @@ class S3Adapter:
@asynccontextmanager
async def _get_client(self):
async with self.session.client(
async with self.session.client( # type: ignore[reportGeneralTypeIssues]
"s3",
endpoint_url=self.endpoint_url,
aws_access_key_id=self.aws_access_key_id,

13
aiws.py
View File

@@ -44,6 +44,7 @@ from api.endpoints.admin import router as api_admin_router
from api.endpoints.album_router import router as api_album_router
from api.endpoints.project_router import router as project_api_router
from api.endpoints.idea_router import router as idea_api_router
from api.endpoints.post_router import router as post_api_router
load_dotenv()
logger = logging.getLogger(__name__)
@@ -64,6 +65,8 @@ def setup_logging():
# --- ИНИЦИАЛИЗАЦИЯ ЗАВИСИМОСТЕЙ ---
if BOT_TOKEN is None:
raise ValueError("BOT_TOKEN is not set")
bot = Bot(token=BOT_TOKEN, default=DefaultBotProperties(parse_mode=ParseMode.HTML))
# Клиент БД создаем глобально, чтобы он был доступен и боту (Storage), и API
@@ -83,8 +86,12 @@ s3_adapter = S3Adapter(
)
dao = DAO(mongo_client, s3_adapter) # Главный DAO для бота
if GEMINI_API_KEY is None:
raise ValueError("GEMINI_API_KEY is not set")
gemini = GoogleAdapter(api_key=GEMINI_API_KEY)
generation_service = GenerationService(dao, gemini, bot)
if bot is None:
raise ValueError("bot is not set")
generation_service = GenerationService(dao=dao, gemini=gemini, s3_adapter=s3_adapter, bot=bot)
album_service = AlbumService(dao)
# Dispatcher
@@ -126,11 +133,12 @@ async def start_scheduler(service: GenerationService):
try:
logger.info("Running scheduler for stacked generation killing")
await service.cleanup_stale_generations()
await service.cleanup_old_data(days=2)
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Scheduler error: {e}")
await asyncio.sleep(60) # Check every 10 minutes
await asyncio.sleep(60) # Check every 60 seconds
# --- LIFESPAN (Запуск FastAPI + Bot) ---
@asynccontextmanager
@@ -212,6 +220,7 @@ app.include_router(api_gen_router)
app.include_router(api_album_router)
app.include_router(project_api_router)
app.include_router(idea_api_router)
app.include_router(post_api_router)
# Prometheus Metrics (Instrument after all routers are added)
Instrumentator(

View File

@@ -5,6 +5,7 @@ from motor.motor_asyncio import AsyncIOMotorClient
from adapters.google_adapter import GoogleAdapter
from api.service.generation_service import GenerationService
from repos.dao import DAO
from api.service.album_service import AlbumService
# ... ваши импорты ...
@@ -54,3 +55,11 @@ from fastapi import Header
async def get_project_id(x_project_id: Optional[str] = Header(None, alias="X-Project-ID")) -> Optional[str]:
return x_project_id
async def get_album_service(dao: DAO = Depends(get_dao)) -> AlbumService:
return AlbumService(dao)
from api.service.post_service import PostService
def get_post_service(dao: DAO = Depends(get_dao)) -> PostService:
return PostService(dao)

View File

@@ -23,7 +23,7 @@ async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)], repo:
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str = payload.get("sub")
username: str | None = payload.get("sub")
if username is None:
raise credentials_exception
except JWTError:

View File

@@ -1,10 +1,13 @@
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, status, Request
from fastapi import APIRouter, HTTPException, status, Request
from pydantic import BaseModel
from api.models.GenerationRequest import GenerationResponse
from models.Album import Album
from repos.dao import DAO
from api.dependency import get_album_service
from api.service.album_service import AlbumService
router = APIRouter(prefix="/api/albums", tags=["Albums"])

View File

@@ -0,0 +1,99 @@
from typing import List, Optional
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException
from api.dependency import get_post_service, get_project_id
from api.endpoints.auth import get_current_user
from api.service.post_service import PostService
from api.models.PostRequest import PostCreateRequest, PostUpdateRequest, AddGenerationsRequest
from models.Post import Post
router = APIRouter(prefix="/api/posts", tags=["posts"])
@router.post("", response_model=Post)
async def create_post(
request: PostCreateRequest,
project_id: Optional[str] = Depends(get_project_id),
current_user: dict = Depends(get_current_user),
post_service: PostService = Depends(get_post_service),
):
pid = project_id or request.project_id
return await post_service.create_post(
date=request.date,
topic=request.topic,
generation_ids=request.generation_ids,
project_id=pid,
user_id=str(current_user["_id"]),
)
@router.get("", response_model=List[Post])
async def get_posts(
project_id: Optional[str] = Depends(get_project_id),
limit: int = 200,
offset: int = 0,
date_from: Optional[datetime] = None,
date_to: Optional[datetime] = None,
current_user: dict = Depends(get_current_user),
post_service: PostService = Depends(get_post_service),
):
return await post_service.get_posts(project_id, str(current_user["_id"]), limit, offset, date_from, date_to)
@router.get("/{post_id}", response_model=Post)
async def get_post(
post_id: str,
post_service: PostService = Depends(get_post_service),
):
post = await post_service.get_post(post_id)
if not post:
raise HTTPException(status_code=404, detail="Post not found")
return post
@router.put("/{post_id}", response_model=Post)
async def update_post(
post_id: str,
request: PostUpdateRequest,
post_service: PostService = Depends(get_post_service),
):
post = await post_service.update_post(post_id, date=request.date, topic=request.topic)
if not post:
raise HTTPException(status_code=404, detail="Post not found")
return post
@router.delete("/{post_id}")
async def delete_post(
post_id: str,
post_service: PostService = Depends(get_post_service),
):
success = await post_service.delete_post(post_id)
if not success:
raise HTTPException(status_code=404, detail="Post not found or could not be deleted")
return {"status": "success"}
@router.post("/{post_id}/generations")
async def add_generations(
post_id: str,
request: AddGenerationsRequest,
post_service: PostService = Depends(get_post_service),
):
success = await post_service.add_generations(post_id, request.generation_ids)
if not success:
raise HTTPException(status_code=404, detail="Post not found")
return {"status": "success"}
@router.delete("/{post_id}/generations/{generation_id}")
async def remove_generation(
post_id: str,
generation_id: str,
post_service: PostService = Depends(get_post_service),
):
success = await post_service.remove_generation(post_id, generation_id)
if not success:
raise HTTPException(status_code=404, detail="Post not found or generation not linked")
return {"status": "success"}

19
api/models/PostRequest.py Normal file
View File

@@ -0,0 +1,19 @@
from datetime import datetime
from typing import Optional, List
from pydantic import BaseModel
class PostCreateRequest(BaseModel):
date: datetime
topic: str
generation_ids: List[str] = []
project_id: Optional[str] = None
class PostUpdateRequest(BaseModel):
date: Optional[datetime] = None
topic: Optional[str] = None
class AddGenerationsRequest(BaseModel):
generation_ids: List[str]

View File

@@ -77,7 +77,7 @@ class GenerationService:
self.bot = bot
async def ask_prompt_assistant(self, prompt: str, assets: List[str] = None) -> str:
async def ask_prompt_assistant(self, prompt: str, assets: list[str] | None = None) -> str:
future_prompt = """You are an prompt-assistant. You improving user-entered prompts for image generation. User may upload reference image too.
I will provide sources prompt entered by user. Understand user needs and generate best variation of prompt.
ANSWER ONLY PROMPT STRING!!! USER_ENTERED_PROMPT: """
@@ -157,6 +157,7 @@ class GenerationService:
# если генерация уже пошла и упала — пометим FAILED
try:
db_gen = await self.dao.generations.get_generation(gen.id)
if db_gen is not None:
db_gen.status = GenerationStatus.FAILED
await self.dao.generations.update_generation(db_gen)
except Exception:
@@ -172,6 +173,7 @@ class GenerationService:
if gen_id is not None:
try:
gen = await self.dao.generations.get_generation(gen_id)
if gen is not None:
gen.status = GenerationStatus.FAILED
await self.dao.generations.update_generation(gen)
except Exception:
@@ -201,8 +203,9 @@ class GenerationService:
if char_info is None:
raise Exception(f"Character ID {generation.linked_character_id} not found")
if generation.use_profile_image:
if char_info.avatar_asset_id is not None:
avatar_asset = await self.dao.assets.get_asset(char_info.avatar_asset_id)
if avatar_asset:
if avatar_asset and avatar_asset.data:
media_group_bytes.append(avatar_asset.data)
# generation_prompt = generation_prompt.replace("$char_bio_inserted", f"1. CHARACTER BIO (Must be strictly followed): {char_info.character_bio}")
@@ -304,7 +307,9 @@ class GenerationService:
# 5. (Опционально) Обновляем запись генерации ссылками на результаты
# Предполагаем, что у модели Generation есть поле result_asset_ids
result_ids = [a.id for a in created_assets]
result_ids = []
for a in created_assets:
result_ids.append(a.id)
generation.result_list = result_ids
generation.status = GenerationStatus.DONE
@@ -480,3 +485,25 @@ class GenerationService:
logger.info(f"Cleaned up {count} stale generations (timeout)")
except Exception as e:
logger.error(f"Error cleaning up stale generations: {e}")
async def cleanup_old_data(self, days: int = 2):
"""
Очистка старых данных:
1. Мягко удаляет генерации старше N дней
2. Мягко удаляет связанные ассеты + жёстко удаляет файлы из S3
"""
try:
# 1. Мягко удаляем генерации и собираем asset IDs
gen_count, asset_ids = await self.dao.generations.soft_delete_old_generations(days=days)
if gen_count > 0:
logger.info(f"Soft-deleted {gen_count} generations older than {days} days. "
f"Found {len(asset_ids)} associated asset IDs.")
# 2. Мягко удаляем ассеты + жёстко удаляем файлы из S3
if asset_ids:
purged = await self.dao.assets.soft_delete_and_purge_assets(asset_ids)
logger.info(f"Purged {purged} assets (soft-deleted + S3 files removed).")
except Exception as e:
logger.error(f"Error during old data cleanup: {e}")

View File

@@ -0,0 +1,79 @@
from typing import List, Optional
from datetime import datetime, UTC
from repos.dao import DAO
from models.Post import Post
class PostService:
def __init__(self, dao: DAO):
self.dao = dao
async def create_post(
self,
date: datetime,
topic: str,
generation_ids: List[str],
project_id: Optional[str],
user_id: str,
) -> Post:
post = Post(
date=date,
topic=topic,
generation_ids=generation_ids,
project_id=project_id,
created_by=user_id,
)
post_id = await self.dao.posts.create_post(post)
post.id = post_id
return post
async def get_post(self, post_id: str) -> Optional[Post]:
return await self.dao.posts.get_post(post_id)
async def get_posts(
self,
project_id: Optional[str],
user_id: str,
limit: int = 20,
offset: int = 0,
date_from: Optional[datetime] = None,
date_to: Optional[datetime] = None,
) -> List[Post]:
return await self.dao.posts.get_posts(project_id, user_id, limit, offset, date_from, date_to)
async def update_post(
self,
post_id: str,
date: Optional[datetime] = None,
topic: Optional[str] = None,
) -> Optional[Post]:
post = await self.dao.posts.get_post(post_id)
if not post:
return None
updates: dict = {"updated_at": datetime.now(UTC)}
if date is not None:
updates["date"] = date
if topic is not None:
updates["topic"] = topic
await self.dao.posts.update_post(post_id, updates)
# Return refreshed post
return await self.dao.posts.get_post(post_id)
async def delete_post(self, post_id: str) -> bool:
return await self.dao.posts.delete_post(post_id)
async def add_generations(self, post_id: str, generation_ids: List[str]) -> bool:
post = await self.dao.posts.get_post(post_id)
if not post:
return False
return await self.dao.posts.add_generations(post_id, generation_ids)
async def remove_generation(self, post_id: str, generation_id: str) -> bool:
post = await self.dao.posts.get_post(post_id)
if not post:
return False
return await self.dao.posts.remove_generation(post_id, generation_id)

View File

@@ -30,6 +30,7 @@ class Asset(BaseModel):
tags: List[str] = []
created_by: Optional[str] = None
project_id: Optional[str] = None
is_deleted: bool = False
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))

23
models/Post.py Normal file
View File

@@ -0,0 +1,23 @@
from datetime import datetime, timezone, UTC
from typing import Optional, List
from pydantic import BaseModel, Field, model_validator
class Post(BaseModel):
id: Optional[str] = None
date: datetime
topic: str
generation_ids: List[str] = Field(default_factory=list)
project_id: Optional[str] = None
created_by: str
is_deleted: bool = False
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
@model_validator(mode="after")
def ensure_tz_aware(self):
for field in ("date", "created_at", "updated_at"):
val = getattr(self, field)
if val is not None and val.tzinfo is None:
setattr(self, field, val.replace(tzinfo=timezone.utc))
return self

View File

@@ -1,5 +1,6 @@
from typing import List, Optional
from typing import Any, List, Optional
import logging
from datetime import datetime, UTC
from bson import ObjectId
from uuid import uuid4
from motor.motor_asyncio import AsyncIOMotorClient
@@ -50,7 +51,7 @@ class AssetsRepo:
return str(res.inserted_id)
async def get_assets(self, asset_type: Optional[str] = None, limit: int = 10, offset: int = 0, with_data: bool = False, created_by: Optional[str] = None, project_id: Optional[str] = None) -> List[Asset]:
filter = {}
filter: dict[str, Any]= {"is_deleted": {"$ne": True}}
if asset_type:
filter["type"] = asset_type
args = {}
@@ -202,6 +203,61 @@ class AssetsRepo:
res = await self.collection.delete_one({"_id": ObjectId(asset_id)})
return res.deleted_count > 0
async def soft_delete_and_purge_assets(self, asset_ids: List[str]) -> int:
"""
Мягко удаляет ассеты и жёстко удаляет их файлы из S3.
Возвращает количество обработанных ассетов.
"""
if not asset_ids:
return 0
object_ids = [ObjectId(aid) for aid in asset_ids if ObjectId.is_valid(aid)]
if not object_ids:
return 0
# Находим ассеты, которые ещё не удалены
cursor = self.collection.find(
{"_id": {"$in": object_ids}, "is_deleted": {"$ne": True}},
{"minio_object_name": 1, "minio_thumbnail_object_name": 1}
)
purged_count = 0
ids_to_update = []
async for doc in cursor:
ids_to_update.append(doc["_id"])
# Жёсткое удаление файлов из S3
if self.s3:
if doc.get("minio_object_name"):
try:
await self.s3.delete_file(doc["minio_object_name"])
except Exception as e:
logger.error(f"Failed to delete S3 object {doc['minio_object_name']}: {e}")
if doc.get("minio_thumbnail_object_name"):
try:
await self.s3.delete_file(doc["minio_thumbnail_object_name"])
except Exception as e:
logger.error(f"Failed to delete S3 thumbnail {doc['minio_thumbnail_object_name']}: {e}")
purged_count += 1
# Мягкое удаление + очистка ссылок на S3
if ids_to_update:
await self.collection.update_many(
{"_id": {"$in": ids_to_update}},
{
"$set": {
"is_deleted": True,
"minio_object_name": None,
"minio_thumbnail_object_name": None,
"updated_at": datetime.now(UTC)
}
}
)
return purged_count
async def migrate_to_minio(self) -> dict:
"""Переносит данные и thumbnails из Mongo в MinIO."""
if not self.s3:

View File

@@ -7,6 +7,7 @@ from repos.user_repo import UsersRepo
from repos.albums_repo import AlbumsRepo
from repos.project_repo import ProjectRepo
from repos.idea_repo import IdeaRepo
from repos.post_repo import PostRepo
from typing import Optional
@@ -21,3 +22,4 @@ class DAO:
self.projects = ProjectRepo(client, db_name)
self.users = UsersRepo(client, db_name)
self.ideas = IdeaRepo(client, db_name)
self.posts = PostRepo(client, db_name)

View File

@@ -1,4 +1,4 @@
from typing import Optional, List
from typing import Any, Optional, List
from datetime import datetime, timedelta, UTC
from PIL.ImageChops import offset
@@ -17,7 +17,7 @@ class GenerationRepo:
res = await self.collection.insert_one(generation.model_dump())
return str(res.inserted_id)
async def get_generation(self, generation_id: str) -> Optional[Generation]:
async def get_generation(self, generation_id: str) -> Generation | None:
res = await self.collection.find_one({"_id": ObjectId(generation_id)})
if res is None:
return None
@@ -28,7 +28,7 @@ class GenerationRepo:
async def get_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None,
limit: int = 10, offset: int = 0, created_by: Optional[str] = None, project_id: Optional[str] = None, idea_id: Optional[str] = None) -> List[Generation]:
filter = {"is_deleted": False}
filter: dict[str, Any] = {"is_deleted": False}
if character_id is not None:
filter["linked_character_id"] = character_id
if status is not None:
@@ -69,6 +69,8 @@ class GenerationRepo:
args["project_id"] = project_id
if idea_id is not None:
args["idea_id"] = idea_id
if album_id is not None:
args["album_id"] = album_id
return await self.collection.count_documents(args)
async def get_generations_by_ids(self, generation_ids: List[str]) -> List[Generation]:
@@ -114,3 +116,37 @@ class GenerationRepo:
}
)
return res.modified_count
async def soft_delete_old_generations(self, days: int = 2) -> tuple[int, List[str]]:
"""
Мягко удаляет генерации старше N дней.
Возвращает (количество удалённых, список asset IDs для очистки).
"""
cutoff_time = datetime.now(UTC) - timedelta(days=days)
filter_query = {
"is_deleted": False,
"status": {"$in": [GenerationStatus.DONE, GenerationStatus.FAILED]},
"created_at": {"$lt": cutoff_time}
}
# Сначала собираем asset IDs из удаляемых генераций
asset_ids: List[str] = []
cursor = self.collection.find(filter_query, {"result_list": 1, "assets_list": 1})
async for doc in cursor:
asset_ids.extend(doc.get("result_list", []))
asset_ids.extend(doc.get("assets_list", []))
# Мягкое удаление
res = await self.collection.update_many(
filter_query,
{
"$set": {
"is_deleted": True,
"updated_at": datetime.now(UTC)
}
}
)
# Убираем дубликаты
unique_asset_ids = list(set(asset_ids))
return res.modified_count, unique_asset_ids

97
repos/post_repo.py Normal file
View File

@@ -0,0 +1,97 @@
from typing import List, Optional
from datetime import datetime
import logging
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient
from models.Post import Post
logger = logging.getLogger(__name__)
class PostRepo:
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
self.collection = client[db_name]["posts"]
async def create_post(self, post: Post) -> str:
res = await self.collection.insert_one(post.model_dump())
return str(res.inserted_id)
async def get_post(self, post_id: str) -> Optional[Post]:
if not ObjectId.is_valid(post_id):
return None
res = await self.collection.find_one({"_id": ObjectId(post_id), "is_deleted": False})
if res:
res["id"] = str(res.pop("_id"))
return Post(**res)
return None
async def get_posts(
self,
project_id: Optional[str],
user_id: str,
limit: int = 20,
offset: int = 0,
date_from: Optional[datetime] = None,
date_to: Optional[datetime] = None,
) -> List[Post]:
if project_id:
match = {"project_id": project_id, "is_deleted": False}
else:
match = {"created_by": user_id, "project_id": None, "is_deleted": False}
if date_from or date_to:
date_filter = {}
if date_from:
date_filter["$gte"] = date_from
if date_to:
date_filter["$lte"] = date_to
match["date"] = date_filter
cursor = (
self.collection.find(match)
.sort("date", -1)
.skip(offset)
.limit(limit)
)
posts = []
async for doc in cursor:
doc["id"] = str(doc.pop("_id"))
posts.append(Post(**doc))
return posts
async def update_post(self, post_id: str, data: dict) -> bool:
if not ObjectId.is_valid(post_id):
return False
res = await self.collection.update_one(
{"_id": ObjectId(post_id)},
{"$set": data},
)
return res.modified_count > 0
async def delete_post(self, post_id: str) -> bool:
if not ObjectId.is_valid(post_id):
return False
res = await self.collection.update_one(
{"_id": ObjectId(post_id)},
{"$set": {"is_deleted": True}},
)
return res.modified_count > 0
async def add_generations(self, post_id: str, generation_ids: List[str]) -> bool:
if not ObjectId.is_valid(post_id):
return False
res = await self.collection.update_one(
{"_id": ObjectId(post_id)},
{"$addToSet": {"generation_ids": {"$each": generation_ids}}},
)
return res.modified_count > 0
async def remove_generation(self, post_id: str, generation_id: str) -> bool:
if not ObjectId.is_valid(post_id):
return False
res = await self.collection.update_one(
{"_id": ObjectId(post_id)},
{"$pull": {"generation_ids": generation_id}},
)
return res.modified_count > 0