from typing import Any, Optional, List
from datetime import datetime, timedelta, UTC

# NOTE(review): almost certainly an accidental IDE auto-import — the name is
# immediately shadowed by the `offset` parameter of get_generations and never
# used. Confirm and remove.
from PIL.ImageChops import offset
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient

# NOTE(review): GenerationResponse is not referenced in this file — confirm
# and remove.
from api.models.GenerationRequest import GenerationResponse
from models.Generation import Generation, GenerationStatus


class GenerationRepo:
    """Async repository for the ``generations`` MongoDB collection.

    All read methods translate raw Mongo documents (keyed by ``_id``) into
    ``Generation`` pydantic models (keyed by string ``id``).
    """

    # Per-token pricing (USD) used by the usage/cost aggregations.
    # NOTE(review): hard-coded model pricing — confirm against billing source.
    _INPUT_TOKEN_COST = 0.000002
    _OUTPUT_TOKEN_COST = 0.00012

    def __init__(self, client: AsyncIOMotorClient, db_name: str = "bot_db"):
        self.collection = client[db_name]["generations"]

    @staticmethod
    def _to_model(doc: dict) -> Generation:
        """Convert a raw Mongo document into a ``Generation`` (mutates ``doc``)."""
        doc["id"] = str(doc.pop("_id"))
        return Generation(**doc)

    async def create_generation(self, generation: Generation) -> str:
        """Insert a new generation document; return the inserted id as a string."""
        res = await self.collection.insert_one(generation.model_dump())
        return str(res.inserted_id)

    async def get_generation(self, generation_id: str) -> Generation | None:
        """Fetch one generation by id, or ``None`` when it does not exist."""
        res = await self.collection.find_one({"_id": ObjectId(generation_id)})
        if res is None:
            return None
        return self._to_model(res)

    async def get_generations(self,
                              character_id: Optional[str] = None,
                              status: Optional[GenerationStatus] = None,
                              limit: int = 10,
                              offset: int = 0,
                              created_by: Optional[str] = None,
                              project_id: Optional[str] = None,
                              idea_id: Optional[str] = None) -> List[Generation]:
        """List non-deleted generations matching the given filters.

        Results for an idea are returned oldest-first (chronological); all
        other queries are newest-first. ``offset``/``limit`` paginate.
        """
        query: dict[str, Any] = {"is_deleted": False}
        if character_id is not None:
            query["linked_character_id"] = character_id
        if status is not None:
            query["status"] = status
        if created_by is not None:
            query["created_by"] = created_by
        # project_id is always part of the filter: an explicit value scopes to
        # that project, while None matches only personal-scope documents
        # (e.g. "My Generations" when filtering by created_by).
        query["project_id"] = project_id
        if idea_id is not None:
            query["idea_id"] = idea_id
        # Chronological order when fetching for an idea, newest-first otherwise.
        sort_order = 1 if idea_id else -1
        cursor = (self.collection.find(query)
                  .sort("created_at", sort_order)
                  .skip(offset)
                  .limit(limit))
        docs = await cursor.to_list(None)
        return [self._to_model(doc) for doc in docs]

    async def count_generations(self,
                                character_id: Optional[str] = None,
                                status: Optional[GenerationStatus] = None,
                                album_id: Optional[str] = None,
                                created_by: Optional[str] = None,
                                project_id: Optional[str] = None,
                                idea_id: Optional[str] = None) -> int:
        """Count generations matching the given filters.

        NOTE(review): unlike ``get_generations`` this does NOT exclude
        soft-deleted documents, so counts may exceed listed results —
        confirm whether ``is_deleted: False`` should be added here.
        """
        query: dict[str, Any] = {}
        if character_id is not None:
            query["linked_character_id"] = character_id
        if status is not None:
            query["status"] = status
        if created_by is not None:
            query["created_by"] = created_by
        # Mirrors get_generations: None means personal scope, a value scopes
        # to that project.
        query["project_id"] = project_id
        if idea_id is not None:
            query["idea_id"] = idea_id
        if album_id is not None:
            query["album_id"] = album_id
        return await self.collection.count_documents(query)

    async def get_generations_by_ids(self, generation_ids: List[str]) -> List[Generation]:
        """Fetch generations by id, preserving the order of ``generation_ids``.

        Invalid ObjectIds and ids with no matching document are skipped.
        """
        object_ids = [ObjectId(gen_id) for gen_id in generation_ids
                      if ObjectId.is_valid(gen_id)]
        docs = await self.collection.find({"_id": {"$in": object_ids}}).to_list(None)
        # Re-order results to match the caller-supplied id order.
        gen_map = {str(doc["_id"]): doc for doc in docs}
        return [self._to_model(gen_map[gen_id]) for gen_id in generation_ids
                if gen_id in gen_map]

    async def update_generation(self, generation: Generation) -> None:
        """Overwrite the stored document for ``generation`` (matched by its id)."""
        await self.collection.update_one(
            {"_id": ObjectId(generation.id)},
            {"$set": generation.model_dump()},
        )

    @classmethod
    def _usage_group_stage(cls, group_key) -> dict:
        """Build the ``$group`` stage shared by the usage aggregations.

        ``total_tokens`` prefers input+output counts when both are positive,
        falling back to the legacy ``token_usage`` field. Cost is priced per
        token using the class pricing constants.
        """
        return {
            "$group": {
                "_id": group_key,
                "total_runs": {"$sum": 1},
                "total_tokens": {
                    "$sum": {
                        "$cond": [
                            {"$and": [{"$gt": ["$input_token_usage", 0]},
                                      {"$gt": ["$output_token_usage", 0]}]},
                            {"$add": ["$input_token_usage", "$output_token_usage"]},
                            {"$ifNull": ["$token_usage", 0]},
                        ]
                    }
                },
                "total_input_tokens": {"$sum": {"$ifNull": ["$input_token_usage", 0]}},
                "total_output_tokens": {"$sum": {"$ifNull": ["$output_token_usage", 0]}},
                "total_cost": {
                    "$sum": {
                        "$add": [
                            {"$multiply": [{"$ifNull": ["$input_token_usage", 0]},
                                           cls._INPUT_TOKEN_COST]},
                            {"$multiply": [{"$ifNull": ["$output_token_usage", 0]},
                                           cls._OUTPUT_TOKEN_COST]},
                        ]
                    }
                },
            }
        }

    async def get_usage_stats(self,
                              created_by: Optional[str] = None,
                              project_id: Optional[str] = None) -> dict:
        """Calculate total usage statistics (runs, tokens, cost).

        Matches all DONE generations — including soft-deleted ones — to
        reflect actual expenditure.
        """
        match_stage: dict[str, Any] = {"status": GenerationStatus.DONE}
        if created_by:
            match_stage["created_by"] = created_by
        if project_id:
            match_stage["project_id"] = project_id
        pipeline = [
            {"$match": match_stage},
            # Group everything into one bucket (_id: None) for grand totals.
            self._usage_group_stage(None),
        ]
        res = await self.collection.aggregate(pipeline).to_list(1)
        if not res:
            return {
                "total_runs": 0,
                "total_tokens": 0,
                "total_input_tokens": 0,
                "total_output_tokens": 0,
                "total_cost": 0.0,
            }
        result = res[0]
        result.pop("_id")
        result["total_cost"] = round(result["total_cost"], 4)
        return result

    async def get_usage_breakdown(self,
                                  group_by: str = "created_by",
                                  project_id: Optional[str] = None,
                                  created_by: Optional[str] = None) -> List[dict]:
        """Return usage statistics grouped by a document field (user or project).

        Includes soft-deleted DONE generations to reflect actual expenditure.
        Results are sorted by ``total_cost`` descending.
        """
        match_stage: dict[str, Any] = {"status": GenerationStatus.DONE}
        if project_id:
            match_stage["project_id"] = project_id
        if created_by:
            match_stage["created_by"] = created_by
        pipeline = [
            {"$match": match_stage},
            self._usage_group_stage(f"${group_by}"),
            {"$sort": {"total_cost": -1}},
        ]
        res = await self.collection.aggregate(pipeline).to_list(None)
        results = []
        for item in res:
            entity_id = item.pop("_id")
            item["total_cost"] = round(item["total_cost"], 4)
            results.append({
                "entity_id": str(entity_id) if entity_id else "unknown",
                "stats": item,
            })
        return results

    async def get_generations_by_group(self, group_id: str) -> List[Generation]:
        """List non-deleted generations for a generation group, oldest first."""
        docs = await (self.collection
                      .find({"generation_group_id": group_id, "is_deleted": False})
                      .sort("created_at", 1)
                      .to_list(None))
        return [self._to_model(doc) for doc in docs]

    async def cancel_stale_generations(self, timeout_minutes: int = 5) -> int:
        """Fail RUNNING generations older than ``timeout_minutes``.

        Returns the number of generations transitioned to FAILED.
        """
        cutoff_time = datetime.now(UTC) - timedelta(minutes=timeout_minutes)
        res = await self.collection.update_many(
            {
                "status": GenerationStatus.RUNNING,
                "created_at": {"$lt": cutoff_time},
            },
            {
                "$set": {
                    "status": GenerationStatus.FAILED,
                    "failed_reason": "Timeout: Execution time limit exceeded",
                    "updated_at": datetime.now(UTC),
                }
            },
        )
        return res.modified_count

    async def soft_delete_old_generations(self, days: int = 2) -> tuple[int, List[str]]:
        """Soft-delete generations older than ``days`` days.

        Only DONE/FAILED generations are affected. Returns a tuple of
        (number of generations soft-deleted, de-duplicated list of asset ids
        referenced by them, for downstream cleanup).
        """
        cutoff_time = datetime.now(UTC) - timedelta(days=days)
        filter_query = {
            "is_deleted": False,
            "status": {"$in": [GenerationStatus.DONE, GenerationStatus.FAILED]},
            "created_at": {"$lt": cutoff_time},
        }
        # First collect asset ids from the generations about to be deleted,
        # so the caller can clean up the underlying assets.
        asset_ids: List[str] = []
        cursor = self.collection.find(filter_query, {"result_list": 1, "assets_list": 1})
        async for doc in cursor:
            asset_ids.extend(doc.get("result_list", []))
            asset_ids.extend(doc.get("assets_list", []))
        # Soft delete: mark documents instead of removing them.
        res = await self.collection.update_many(
            filter_query,
            {
                "$set": {
                    "is_deleted": True,
                    "updated_at": datetime.now(UTC),
                }
            },
        )
        # De-duplicate asset ids (order is not significant for cleanup).
        unique_asset_ids = list(set(asset_ids))
        return res.modified_count, unique_asset_ids