diff --git a/.gitignore b/.gitignore index c12fb5d..4a15c02 100644 --- a/.gitignore +++ b/.gitignore @@ -1,14 +1,9 @@ minio_backup.tar.gz .DS_Store **/__pycache__/ -# Игнорируем файлы скомпилированного байт-кода напрямую *.py[cod] *$py.class - -# Игнорируем расширения CPython конкретно *.cpython-*.pyc - -# Игнорируем файлы .DS_Store на всех уровнях **/.DS_Store .idea/ai-char-bot.iml .idea diff --git a/api/endpoints/album_router.py b/api/endpoints/album_router.py new file mode 100644 index 0000000..207dde0 --- /dev/null +++ b/api/endpoints/album_router.py @@ -0,0 +1,81 @@ +from typing import List, Optional +from fastapi import APIRouter, Depends, HTTPException, status, Request +from pydantic import BaseModel + +from api.models.GenerationRequest import GenerationResponse +from models.Album import Album +from repos.dao import DAO + +router = APIRouter(prefix="/api/albums", tags=["Albums"]) + +class AlbumCreateRequest(BaseModel): + name: str + description: Optional[str] = None + +class AlbumUpdateRequest(BaseModel): + name: Optional[str] = None + description: Optional[str] = None + +class AlbumResponse(BaseModel): + id: str + name: str + description: Optional[str] = None + generation_ids: List[str] = [] + cover_asset_id: Optional[str] = None # Not implemented yet + +@router.post("/", response_model=AlbumResponse) +async def create_album(request: Request, album_in: AlbumCreateRequest): + service: AlbumService = request.app.state.album_service + album = await service.create_album(name=album_in.name, description=album_in.description) + return AlbumResponse(**album.model_dump()) + +@router.get("/", response_model=List[AlbumResponse]) +async def get_albums(request: Request, limit: int = 10, offset: int = 0): + service: AlbumService = request.app.state.album_service + albums = await service.get_albums(limit=limit, offset=offset) + return [AlbumResponse(**album.model_dump()) for album in albums] + +@router.get("/{album_id}", response_model=AlbumResponse) +async def get_album(request: Request, album_id: str): + service: AlbumService = request.app.state.album_service + album = await service.get_album(album_id) + if not album: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album not found") + return AlbumResponse(**album.model_dump()) + +@router.put("/{album_id}", response_model=AlbumResponse) +async def update_album(request: Request, album_id: str, album_in: AlbumUpdateRequest): + service: AlbumService = request.app.state.album_service + album = await service.update_album(album_id, name=album_in.name, description=album_in.description) + if not album: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album not found") + return AlbumResponse(**album.model_dump()) + +@router.delete("/{album_id}", status_code=status.HTTP_204_NO_CONTENT) +async def delete_album(request: Request, album_id: str): + service: AlbumService = request.app.state.album_service + deleted = await service.delete_album(album_id) + if not deleted: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album not found") + +@router.post("/{album_id}/generations/{generation_id}") +async def add_generation_to_album(request: Request, album_id: str, generation_id: str): + service: AlbumService = request.app.state.album_service + success = await service.add_generation_to_album(album_id, generation_id) + if not success: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album or Generation not found") + return {"status": "success"} + +@router.delete("/{album_id}/generations/{generation_id}") +async def remove_generation_from_album(request: Request, album_id: str, generation_id: str): + service: AlbumService = request.app.state.album_service + success = await service.remove_generation_from_album(album_id, generation_id) + if not success: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album or Generation not found") + return {"status": "success"} + +@router.get("/{album_id}/generations", response_model=List[GenerationResponse]) +async def get_album_generations(request: Request, album_id: str, limit: int = 10, offset: int = 0): + service: AlbumService = request.app.state.album_service + generations = await service.get_generations_by_album(album_id, limit=limit, offset=offset) + return [GenerationResponse(**gen.model_dump()) for gen in generations] diff --git a/api/service/album_service.py b/api/service/album_service.py new file mode 100644 index 0000000..6d062d8 --- /dev/null +++ b/api/service/album_service.py @@ -0,0 +1,85 @@ +from typing import List, Optional +from models.Album import Album +from models.Generation import Generation +from repos.dao import DAO + +class AlbumService: + def __init__(self, dao: DAO): + self.dao = dao + + async def create_album(self, name: str, description: Optional[str] = None) -> Album: + album = Album(name=name, description=description) + album_id = await self.dao.albums.create_album(album) + album.id = album_id + return album + + async def get_albums(self, limit: int = 10, offset: int = 0) -> List[Album]: + return await self.dao.albums.get_albums(limit=limit, offset=offset) + + async def get_album(self, album_id: str) -> Optional[Album]: + return await self.dao.albums.get_album(album_id) + + async def update_album(self, album_id: str, name: Optional[str] = None, description: Optional[str] = None) -> Optional[Album]: + album = await self.dao.albums.get_album(album_id) + if not album: + return None + + if name: + album.name = name + if description is not None: + album.description = description + + await self.dao.albums.update_album(album_id, album) + return album + + async def delete_album(self, album_id: str) -> bool: + return await self.dao.albums.delete_album(album_id) + + async def add_generation_to_album(self, album_id: str, generation_id: str) -> bool: + # Verify album exists + album = await self.dao.albums.get_album(album_id) + if not album: + return False + + # Verify generation exists (optional but good practice) + gen = await self.dao.generations.get_generation(generation_id) + if not gen: + return False + if album.cover_asset_id is None and gen.status == 'done': + album.cover_asset_id = gen.result_list[0] + return await self.dao.albums.add_generation(album_id, generation_id, album.cover_asset_id) + + async def remove_generation_from_album(self, album_id: str, generation_id: str) -> bool: + return await self.dao.albums.remove_generation(album_id, generation_id) + + async def get_generations_by_album(self, album_id: str, limit: int = 10, offset: int = 0) -> List[Generation]: + album = await self.dao.albums.get_album(album_id) + if not album or not album.generation_ids: + return [] + + # Slice the generation IDs (simple pagination on ID list) + # Note: This pagination is on IDs, then we fetch objects. + # Ideally, fetch only slice. + + # Reverse to show newest first? Or just follow list order? + # Assuming list order is insertion order (which usually is what we want for manual sorting or chronological if always appended). + # Let's assume user wants same order as in list. + + sliced_ids = album.generation_ids[offset : offset + limit] + if not sliced_ids: + return [] + + # Fetch generations by IDs + # We need a method in GenerationRepo to fetch by IDs. + # Currently we only have get_generations with filters. + # We can add get_generations_by_ids to GenerationRepo or use loop (inefficient). + # Let's add get_generations_by_ids to GenerationRepo. + + # For now, I will use a loop if I can't modify Repo immediately, + # but I SHOULD modify GenerationRepo. + + # Or I can use get_generations(filter={"_id": {"$in": [ObjectId(id) for id in sliced_ids]}}) + # But get_generations doesn't support generic filter passing. + + # I'll update GenerationRepo to support fetching by IDs. + return await self.dao.generations.get_generations_by_ids(sliced_ids) diff --git a/main.py b/main.py index d454420..d97ebc5 100644 --- a/main.py +++ b/main.py @@ -18,6 +18,7 @@ from starlette.middleware.cors import CORSMiddleware from adapters.google_adapter import GoogleAdapter from adapters.s3_adapter import S3Adapter from api.service.generation_service import GenerationService +from api.service.album_service import AlbumService from middlewares.album import AlbumMiddleware from middlewares.auth import AuthMiddleware from middlewares.dao import DaoMiddleware @@ -38,6 +39,7 @@ from api.endpoints.character_router import router as api_char_router # Роут from api.endpoints.generation_router import router as api_gen_router from api.endpoints.auth import router as api_auth_router from api.endpoints.admin import router as api_admin_router +from api.endpoints.album_router import router as api_album_router load_dotenv() logger = logging.getLogger(__name__) @@ -79,6 +81,7 @@ s3_adapter = S3Adapter( dao = DAO(mongo_client, s3_adapter) # Главный DAO для бота gemini = GoogleAdapter(api_key=GEMINI_API_KEY) generation_service = GenerationService(dao, gemini, bot) +album_service = AlbumService(dao) # Dispatcher dp = Dispatcher(storage=MongoStorage(mongo_client, db_name=DB_NAME)) @@ -132,6 +135,7 @@ async def lifespan(app: FastAPI): app.state.gemini_client = gemini app.state.bot = bot app.state.s3_adapter = s3_adapter + app.state.album_service = album_service app.state.users_repo = users_repo # Добавляем репозиторий в state print("✅ DB & DAO initialized") @@ -181,6 +185,7 @@ app.include_router(admin_api_router) app.include_router(api_assets_router) app.include_router(api_char_router) app.include_router(api_gen_router) +app.include_router(api_album_router) app.include_router(api_admin_router) app.include_router(api_auth_router) diff --git a/models/Album.py b/models/Album.py new file mode 100644 index 0000000..26e6b18 --- /dev/null +++ b/models/Album.py @@ -0,0 +1,12 @@ +from datetime import datetime, UTC +from typing import Optional, List +from pydantic import BaseModel, Field + +class Album(BaseModel): + id: Optional[str] = None + name: str + description: Optional[str] = None + cover_asset_id: Optional[str] = None + generation_ids: List[str] = [] + created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) + updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) diff --git a/repos/albums_repo.py b/repos/albums_repo.py new file mode 100644 index 0000000..d645b27 --- /dev/null +++ b/repos/albums_repo.py @@ -0,0 +1,61 @@ +from typing import List, Optional +import logging +from bson import ObjectId +from motor.motor_asyncio import AsyncIOMotorClient + +from models.Album import Album + +logger = logging.getLogger(__name__) + +class AlbumsRepo: + def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"): + self.collection = client[db_name]["albums"] + + async def create_album(self, album: Album) -> str: + res = await self.collection.insert_one(album.model_dump()) + return str(res.inserted_id) + + async def get_album(self, album_id: str) -> Optional[Album]: + try: + res = await self.collection.find_one({"_id": ObjectId(album_id)}) + if not res: + return None + + res["id"] = str(res.pop("_id")) + return Album(**res) + except Exception: + return None + + async def get_albums(self, limit: int = 10, offset: int = 0) -> List[Album]: + res = await self.collection.find().sort("created_at", -1).skip(offset).limit(limit).to_list(None) + albums = [] + for doc in res: + doc["id"] = str(doc.pop("_id")) + albums.append(Album(**doc)) + return albums + + async def update_album(self, album_id: str, album: Album) -> bool: + if not album.id: + album.id = album_id + + model_dump = album.model_dump() + res = await self.collection.update_one({"_id": ObjectId(album_id)}, {"$set": model_dump}) + return res.modified_count > 0 + + async def delete_album(self, album_id: str) -> bool: + res = await self.collection.delete_one({"_id": ObjectId(album_id)}) + return res.deleted_count > 0 + + async def add_generation(self, album_id: str, generation_id: str, cover_asset_id: Optional[str] = None) -> bool: + res = await self.collection.update_one( + {"_id": ObjectId(album_id)}, + {"$addToSet": {"generation_ids": generation_id}, "$set": {"cover_asset_id": cover_asset_id}} + ) + return res.modified_count > 0 + + async def remove_generation(self, album_id: str, generation_id: str) -> bool: + res = await self.collection.update_one( + {"_id": ObjectId(album_id)}, + {"$pull": {"generation_ids": generation_id}} + ) + return res.modified_count > 0 diff --git a/repos/dao.py b/repos/dao.py index f43a5e3..5bc70bd 100644 --- a/repos/dao.py +++ b/repos/dao.py @@ -4,6 +4,7 @@ from repos.assets_repo import AssetsRepo from repos.char_repo import CharacterRepo from repos.generation_repo import GenerationRepo from repos.user_repo import UsersRepo +from repos.albums_repo import AlbumsRepo from typing import Optional @@ -14,3 +15,4 @@ class DAO: self.chars = CharacterRepo(client, db_name) self.assets = AssetsRepo(client, s3_adapter, db_name) self.generations = GenerationRepo(client, db_name) + self.albums = AlbumsRepo(client, db_name) diff --git a/repos/generation_repo.py b/repos/generation_repo.py index 5d03370..6035fd7 100644 --- a/repos/generation_repo.py +++ b/repos/generation_repo.py @@ -40,7 +40,7 @@ class GenerationRepo: generations.append(Generation(**generation)) return generations - async def count_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None) -> int: + async def count_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None, album_id: Optional[str] = None) -> int: args = {} if character_id is not None: args["linked_character_id"] = character_id @@ -48,5 +48,21 @@ class GenerationRepo: args["status"] = status return await self.collection.count_documents(args) + async def get_generations_by_ids(self, generation_ids: List[str]) -> List[Generation]: + object_ids = [ObjectId(gen_id) for gen_id in generation_ids if ObjectId.is_valid(gen_id)] + res = await self.collection.find({"_id": {"$in": object_ids}}).to_list(None) + generations: List[Generation] = [] + + # Maintain order of generation_ids + gen_map = {str(doc["_id"]): doc for doc in res} + + for gen_id in generation_ids: + doc = gen_map.get(gen_id) + if doc: + doc["id"] = str(doc.pop("_id")) + generations.append(Generation(**doc)) + + return generations + async def update_generation(self, generation: Generation, ): res = await self.collection.update_one({"_id": ObjectId(generation.id)}, {"$set": generation.model_dump()}) diff --git a/tests/verify_albums_manual.py b/tests/verify_albums_manual.py new file mode 100644 index 0000000..c87a933 --- /dev/null +++ b/tests/verify_albums_manual.py @@ -0,0 +1,91 @@ +import asyncio +import os +import sys + +# Add project root to path +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +from motor.motor_asyncio import AsyncIOMotorClient +from repos.dao import DAO +from models.Album import Album +from models.Generation import Generation, GenerationStatus +from models.enums import AspectRatios, Quality + +# Mock config +# Use the same host as main.py but different DB +MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://admin:super_secure_password@31.59.58.220:27017") +DB_NAME = "bot_db_test_albums" + +async def test_albums(): + print(f"🚀 Starting Album Manual Verification using {MONGO_HOST}...") + + # Needs to run inside a loop from main + client = AsyncIOMotorClient(MONGO_HOST) + dao = DAO(client, db_name=DB_NAME) + + try: + # 1. Clean up + await client[DB_NAME]["albums"].drop() + await client[DB_NAME]["generations"].drop() + print("✅ Cleaned up test database") + + # 2. Create Album + album = Album(name="Test Album", description="A test album") + print("Creating album...") + album_id = await dao.albums.create_album(album) + print(f"✅ Created Album: {album_id}") + + # 3. Create Generations + gen1 = Generation(prompt="Gen 1", aspect_ratio=AspectRatios.NINESIXTEEN, quality=Quality.ONEK) + gen2 = Generation(prompt="Gen 2", aspect_ratio=AspectRatios.NINESIXTEEN, quality=Quality.ONEK) + + print("Creating generations...") + gen1_id = await dao.generations.create_generation(gen1) + gen2_id = await dao.generations.create_generation(gen2) + print(f"✅ Created Generations: {gen1_id}, {gen2_id}") + + # 4. Add generations to album + print("Adding generations to album...") + await dao.albums.add_generation(album_id, gen1_id) + await dao.albums.add_generation(album_id, gen2_id) + print("✅ Added generations to album") + + # 5. Fetch album and check generation_ids + album_fetched = await dao.albums.get_album(album_id) + assert album_fetched is not None + assert len(album_fetched.generation_ids) == 2 + assert gen1_id in album_fetched.generation_ids + assert gen2_id in album_fetched.generation_ids + print("✅ Verified generations in album") + + # 6. Fetch generations by IDs via GenerationRepo + generations = await dao.generations.get_generations_by_ids([gen1_id, gen2_id]) + assert len(generations) == 2 + + # Ensure ID type match (str vs ObjectId handling in repo) + gen_ids_fetched = [g.id for g in generations] + assert gen1_id in gen_ids_fetched + assert gen2_id in gen_ids_fetched + print("✅ Verified fetching generations by IDs") + + # 7. Remove generation + print("Removing generation...") + await dao.albums.remove_generation(album_id, gen1_id) + album_fetched = await dao.albums.get_album(album_id) + assert len(album_fetched.generation_ids) == 1 + assert album_fetched.generation_ids[0] == gen2_id + print("✅ Verified removing generation from album") + + print("🎉 Album Verification SUCCESS") + + finally: + # Cleanup client + client.close() + +if __name__ == "__main__": + from dotenv import load_dotenv + load_dotenv() + try: + asyncio.run(test_albums()) + except Exception as e: + print(f"Error: {e}")