Compare commits
2 Commits
c25b029006
...
8a89b27624
| Author | SHA1 | Date | |
|---|---|---|---|
| 8a89b27624 | |||
| c17c47ccc1 |
14
.gitignore
vendored
14
.gitignore
vendored
@@ -1,15 +1,11 @@
|
|||||||
minio_backup.tar.gz
|
minio_backup.tar.gz
|
||||||
.DS_Store
|
.DS_Store
|
||||||
.DS_Store
|
|
||||||
.DS_Store
|
|
||||||
**/__pycache__/
|
**/__pycache__/
|
||||||
# Игнорируем файлы скомпилированного байт-кода напрямую
|
|
||||||
*.py[cod]
|
*.py[cod]
|
||||||
*$py.class
|
*$py.class
|
||||||
|
|
||||||
# Игнорируем расширения CPython конкретно
|
|
||||||
*.cpython-*.pyc
|
*.cpython-*.pyc
|
||||||
|
**/.DS_Store
|
||||||
# Игнорируем файлы .DS_Store на всех уровнях
|
.idea/ai-char-bot.iml
|
||||||
**/.*.DS_Store
|
.idea
|
||||||
**/.DS_Store
|
.venv
|
||||||
|
.vscode
|
||||||
81
api/endpoints/album_router.py
Normal file
81
api/endpoints/album_router.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from api.models.GenerationRequest import GenerationResponse
|
||||||
|
from models.Album import Album
|
||||||
|
from repos.dao import DAO
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/albums", tags=["Albums"])
|
||||||
|
|
||||||
|
class AlbumCreateRequest(BaseModel):
|
||||||
|
name: str
|
||||||
|
description: Optional[str] = None
|
||||||
|
|
||||||
|
class AlbumUpdateRequest(BaseModel):
|
||||||
|
name: Optional[str] = None
|
||||||
|
description: Optional[str] = None
|
||||||
|
|
||||||
|
class AlbumResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
description: Optional[str] = None
|
||||||
|
generation_ids: List[str] = []
|
||||||
|
cover_asset_id: Optional[str] = None # Not implemented yet
|
||||||
|
|
||||||
|
@router.post("/", response_model=AlbumResponse)
|
||||||
|
async def create_album(request: Request, album_in: AlbumCreateRequest):
|
||||||
|
service: AlbumService = request.app.state.album_service
|
||||||
|
album = await service.create_album(name=album_in.name, description=album_in.description)
|
||||||
|
return AlbumResponse(**album.model_dump())
|
||||||
|
|
||||||
|
@router.get("/", response_model=List[AlbumResponse])
|
||||||
|
async def get_albums(request: Request, limit: int = 10, offset: int = 0):
|
||||||
|
service: AlbumService = request.app.state.album_service
|
||||||
|
albums = await service.get_albums(limit=limit, offset=offset)
|
||||||
|
return [AlbumResponse(**album.model_dump()) for album in albums]
|
||||||
|
|
||||||
|
@router.get("/{album_id}", response_model=AlbumResponse)
|
||||||
|
async def get_album(request: Request, album_id: str):
|
||||||
|
service: AlbumService = request.app.state.album_service
|
||||||
|
album = await service.get_album(album_id)
|
||||||
|
if not album:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album not found")
|
||||||
|
return AlbumResponse(**album.model_dump())
|
||||||
|
|
||||||
|
@router.put("/{album_id}", response_model=AlbumResponse)
|
||||||
|
async def update_album(request: Request, album_id: str, album_in: AlbumUpdateRequest):
|
||||||
|
service: AlbumService = request.app.state.album_service
|
||||||
|
album = await service.update_album(album_id, name=album_in.name, description=album_in.description)
|
||||||
|
if not album:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album not found")
|
||||||
|
return AlbumResponse(**album.model_dump())
|
||||||
|
|
||||||
|
@router.delete("/{album_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
async def delete_album(request: Request, album_id: str):
|
||||||
|
service: AlbumService = request.app.state.album_service
|
||||||
|
deleted = await service.delete_album(album_id)
|
||||||
|
if not deleted:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album not found")
|
||||||
|
|
||||||
|
@router.post("/{album_id}/generations/{generation_id}")
|
||||||
|
async def add_generation_to_album(request: Request, album_id: str, generation_id: str):
|
||||||
|
service: AlbumService = request.app.state.album_service
|
||||||
|
success = await service.add_generation_to_album(album_id, generation_id)
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album or Generation not found")
|
||||||
|
return {"status": "success"}
|
||||||
|
|
||||||
|
@router.delete("/{album_id}/generations/{generation_id}")
|
||||||
|
async def remove_generation_from_album(request: Request, album_id: str, generation_id: str):
|
||||||
|
service: AlbumService = request.app.state.album_service
|
||||||
|
success = await service.remove_generation_from_album(album_id, generation_id)
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Album or Generation not found")
|
||||||
|
return {"status": "success"}
|
||||||
|
|
||||||
|
@router.get("/{album_id}/generations", response_model=List[GenerationResponse])
|
||||||
|
async def get_album_generations(request: Request, album_id: str, limit: int = 10, offset: int = 0):
|
||||||
|
service: AlbumService = request.app.state.album_service
|
||||||
|
generations = await service.get_generations_by_album(album_id, limit=limit, offset=offset)
|
||||||
|
return [GenerationResponse(**gen.model_dump()) for gen in generations]
|
||||||
85
api/service/album_service.py
Normal file
85
api/service/album_service.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
from models.Album import Album
|
||||||
|
from models.Generation import Generation
|
||||||
|
from repos.dao import DAO
|
||||||
|
|
||||||
|
class AlbumService:
|
||||||
|
def __init__(self, dao: DAO):
|
||||||
|
self.dao = dao
|
||||||
|
|
||||||
|
async def create_album(self, name: str, description: Optional[str] = None) -> Album:
|
||||||
|
album = Album(name=name, description=description)
|
||||||
|
album_id = await self.dao.albums.create_album(album)
|
||||||
|
album.id = album_id
|
||||||
|
return album
|
||||||
|
|
||||||
|
async def get_albums(self, limit: int = 10, offset: int = 0) -> List[Album]:
|
||||||
|
return await self.dao.albums.get_albums(limit=limit, offset=offset)
|
||||||
|
|
||||||
|
async def get_album(self, album_id: str) -> Optional[Album]:
|
||||||
|
return await self.dao.albums.get_album(album_id)
|
||||||
|
|
||||||
|
async def update_album(self, album_id: str, name: Optional[str] = None, description: Optional[str] = None) -> Optional[Album]:
|
||||||
|
album = await self.dao.albums.get_album(album_id)
|
||||||
|
if not album:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if name:
|
||||||
|
album.name = name
|
||||||
|
if description is not None:
|
||||||
|
album.description = description
|
||||||
|
|
||||||
|
await self.dao.albums.update_album(album_id, album)
|
||||||
|
return album
|
||||||
|
|
||||||
|
async def delete_album(self, album_id: str) -> bool:
|
||||||
|
return await self.dao.albums.delete_album(album_id)
|
||||||
|
|
||||||
|
async def add_generation_to_album(self, album_id: str, generation_id: str) -> bool:
|
||||||
|
# Verify album exists
|
||||||
|
album = await self.dao.albums.get_album(album_id)
|
||||||
|
if not album:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Verify generation exists (optional but good practice)
|
||||||
|
gen = await self.dao.generations.get_generation(generation_id)
|
||||||
|
if not gen:
|
||||||
|
return False
|
||||||
|
if album.cover_asset_id is None and gen.status == 'done':
|
||||||
|
album.cover_asset_id = gen.result_list[0]
|
||||||
|
return await self.dao.albums.add_generation(album_id, generation_id, album.cover_asset_id)
|
||||||
|
|
||||||
|
async def remove_generation_from_album(self, album_id: str, generation_id: str) -> bool:
|
||||||
|
return await self.dao.albums.remove_generation(album_id, generation_id)
|
||||||
|
|
||||||
|
async def get_generations_by_album(self, album_id: str, limit: int = 10, offset: int = 0) -> List[Generation]:
|
||||||
|
album = await self.dao.albums.get_album(album_id)
|
||||||
|
if not album or not album.generation_ids:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Slice the generation IDs (simple pagination on ID list)
|
||||||
|
# Note: This pagination is on IDs, then we fetch objects.
|
||||||
|
# Ideally, fetch only slice.
|
||||||
|
|
||||||
|
# Reverse to show newest first? Or just follow list order?
|
||||||
|
# Assuming list order is insertion order (which usually is what we want for manual sorting or chronological if always appended).
|
||||||
|
# Let's assume user wants same order as in list.
|
||||||
|
|
||||||
|
sliced_ids = album.generation_ids[offset : offset + limit]
|
||||||
|
if not sliced_ids:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Fetch generations by IDs
|
||||||
|
# We need a method in GenerationRepo to fetch by IDs.
|
||||||
|
# Currently we only have get_generations with filters.
|
||||||
|
# We can add get_generations_by_ids to GenerationRepo or use loop (inefficient).
|
||||||
|
# Let's add get_generations_by_ids to GenerationRepo.
|
||||||
|
|
||||||
|
# For now, I will use a loop if I can't modify Repo immediately,
|
||||||
|
# but I SHOULD modify GenerationRepo.
|
||||||
|
|
||||||
|
# Or I can use get_generations(filter={"_id": {"$in": [ObjectId(id) for id in sliced_ids]}})
|
||||||
|
# But get_generations doesn't support generic filter passing.
|
||||||
|
|
||||||
|
# I'll update GenerationRepo to support fetching by IDs.
|
||||||
|
return await self.dao.generations.get_generations_by_ids(sliced_ids)
|
||||||
5
main.py
5
main.py
@@ -18,6 +18,7 @@ from starlette.middleware.cors import CORSMiddleware
|
|||||||
from adapters.google_adapter import GoogleAdapter
|
from adapters.google_adapter import GoogleAdapter
|
||||||
from adapters.s3_adapter import S3Adapter
|
from adapters.s3_adapter import S3Adapter
|
||||||
from api.service.generation_service import GenerationService
|
from api.service.generation_service import GenerationService
|
||||||
|
from api.service.album_service import AlbumService
|
||||||
from middlewares.album import AlbumMiddleware
|
from middlewares.album import AlbumMiddleware
|
||||||
from middlewares.auth import AuthMiddleware
|
from middlewares.auth import AuthMiddleware
|
||||||
from middlewares.dao import DaoMiddleware
|
from middlewares.dao import DaoMiddleware
|
||||||
@@ -38,6 +39,7 @@ from api.endpoints.character_router import router as api_char_router # Роут
|
|||||||
from api.endpoints.generation_router import router as api_gen_router
|
from api.endpoints.generation_router import router as api_gen_router
|
||||||
from api.endpoints.auth import router as api_auth_router
|
from api.endpoints.auth import router as api_auth_router
|
||||||
from api.endpoints.admin import router as api_admin_router
|
from api.endpoints.admin import router as api_admin_router
|
||||||
|
from api.endpoints.album_router import router as api_album_router
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -79,6 +81,7 @@ s3_adapter = S3Adapter(
|
|||||||
dao = DAO(mongo_client, s3_adapter) # Главный DAO для бота
|
dao = DAO(mongo_client, s3_adapter) # Главный DAO для бота
|
||||||
gemini = GoogleAdapter(api_key=GEMINI_API_KEY)
|
gemini = GoogleAdapter(api_key=GEMINI_API_KEY)
|
||||||
generation_service = GenerationService(dao, gemini, bot)
|
generation_service = GenerationService(dao, gemini, bot)
|
||||||
|
album_service = AlbumService(dao)
|
||||||
|
|
||||||
# Dispatcher
|
# Dispatcher
|
||||||
dp = Dispatcher(storage=MongoStorage(mongo_client, db_name=DB_NAME))
|
dp = Dispatcher(storage=MongoStorage(mongo_client, db_name=DB_NAME))
|
||||||
@@ -132,6 +135,7 @@ async def lifespan(app: FastAPI):
|
|||||||
app.state.gemini_client = gemini
|
app.state.gemini_client = gemini
|
||||||
app.state.bot = bot
|
app.state.bot = bot
|
||||||
app.state.s3_adapter = s3_adapter
|
app.state.s3_adapter = s3_adapter
|
||||||
|
app.state.album_service = album_service
|
||||||
app.state.users_repo = users_repo # Добавляем репозиторий в state
|
app.state.users_repo = users_repo # Добавляем репозиторий в state
|
||||||
|
|
||||||
print("✅ DB & DAO initialized")
|
print("✅ DB & DAO initialized")
|
||||||
@@ -181,6 +185,7 @@ app.include_router(admin_api_router)
|
|||||||
app.include_router(api_assets_router)
|
app.include_router(api_assets_router)
|
||||||
app.include_router(api_char_router)
|
app.include_router(api_char_router)
|
||||||
app.include_router(api_gen_router)
|
app.include_router(api_gen_router)
|
||||||
|
app.include_router(api_album_router)
|
||||||
app.include_router(api_admin_router)
|
app.include_router(api_admin_router)
|
||||||
app.include_router(api_auth_router)
|
app.include_router(api_auth_router)
|
||||||
|
|
||||||
|
|||||||
12
models/Album.py
Normal file
12
models/Album.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
from datetime import datetime, UTC
|
||||||
|
from typing import Optional, List
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
class Album(BaseModel):
|
||||||
|
id: Optional[str] = None
|
||||||
|
name: str
|
||||||
|
description: Optional[str] = None
|
||||||
|
cover_asset_id: Optional[str] = None
|
||||||
|
generation_ids: List[str] = []
|
||||||
|
created_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||||
|
updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC))
|
||||||
61
repos/albums_repo.py
Normal file
61
repos/albums_repo.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
import logging
|
||||||
|
from bson import ObjectId
|
||||||
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
|
||||||
|
from models.Album import Album
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class AlbumsRepo:
|
||||||
|
def __init__(self, client: AsyncIOMotorClient, db_name="bot_db"):
|
||||||
|
self.collection = client[db_name]["albums"]
|
||||||
|
|
||||||
|
async def create_album(self, album: Album) -> str:
|
||||||
|
res = await self.collection.insert_one(album.model_dump())
|
||||||
|
return str(res.inserted_id)
|
||||||
|
|
||||||
|
async def get_album(self, album_id: str) -> Optional[Album]:
|
||||||
|
try:
|
||||||
|
res = await self.collection.find_one({"_id": ObjectId(album_id)})
|
||||||
|
if not res:
|
||||||
|
return None
|
||||||
|
|
||||||
|
res["id"] = str(res.pop("_id"))
|
||||||
|
return Album(**res)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_albums(self, limit: int = 10, offset: int = 0) -> List[Album]:
|
||||||
|
res = await self.collection.find().sort("created_at", -1).skip(offset).limit(limit).to_list(None)
|
||||||
|
albums = []
|
||||||
|
for doc in res:
|
||||||
|
doc["id"] = str(doc.pop("_id"))
|
||||||
|
albums.append(Album(**doc))
|
||||||
|
return albums
|
||||||
|
|
||||||
|
async def update_album(self, album_id: str, album: Album) -> bool:
|
||||||
|
if not album.id:
|
||||||
|
album.id = album_id
|
||||||
|
|
||||||
|
model_dump = album.model_dump()
|
||||||
|
res = await self.collection.update_one({"_id": ObjectId(album_id)}, {"$set": model_dump})
|
||||||
|
return res.modified_count > 0
|
||||||
|
|
||||||
|
async def delete_album(self, album_id: str) -> bool:
|
||||||
|
res = await self.collection.delete_one({"_id": ObjectId(album_id)})
|
||||||
|
return res.deleted_count > 0
|
||||||
|
|
||||||
|
async def add_generation(self, album_id: str, generation_id: str, cover_asset_id: Optional[str] = None) -> bool:
|
||||||
|
res = await self.collection.update_one(
|
||||||
|
{"_id": ObjectId(album_id)},
|
||||||
|
{"$addToSet": {"generation_ids": generation_id}, "$set": {"cover_asset_id": cover_asset_id}}
|
||||||
|
)
|
||||||
|
return res.modified_count > 0
|
||||||
|
|
||||||
|
async def remove_generation(self, album_id: str, generation_id: str) -> bool:
|
||||||
|
res = await self.collection.update_one(
|
||||||
|
{"_id": ObjectId(album_id)},
|
||||||
|
{"$pull": {"generation_ids": generation_id}}
|
||||||
|
)
|
||||||
|
return res.modified_count > 0
|
||||||
@@ -4,6 +4,7 @@ from repos.assets_repo import AssetsRepo
|
|||||||
from repos.char_repo import CharacterRepo
|
from repos.char_repo import CharacterRepo
|
||||||
from repos.generation_repo import GenerationRepo
|
from repos.generation_repo import GenerationRepo
|
||||||
from repos.user_repo import UsersRepo
|
from repos.user_repo import UsersRepo
|
||||||
|
from repos.albums_repo import AlbumsRepo
|
||||||
|
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@@ -14,3 +15,4 @@ class DAO:
|
|||||||
self.chars = CharacterRepo(client, db_name)
|
self.chars = CharacterRepo(client, db_name)
|
||||||
self.assets = AssetsRepo(client, s3_adapter, db_name)
|
self.assets = AssetsRepo(client, s3_adapter, db_name)
|
||||||
self.generations = GenerationRepo(client, db_name)
|
self.generations = GenerationRepo(client, db_name)
|
||||||
|
self.albums = AlbumsRepo(client, db_name)
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ class GenerationRepo:
|
|||||||
generations.append(Generation(**generation))
|
generations.append(Generation(**generation))
|
||||||
return generations
|
return generations
|
||||||
|
|
||||||
async def count_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None) -> int:
|
async def count_generations(self, character_id: Optional[str] = None, status: Optional[GenerationStatus] = None, album_id: Optional[str] = None) -> int:
|
||||||
args = {}
|
args = {}
|
||||||
if character_id is not None:
|
if character_id is not None:
|
||||||
args["linked_character_id"] = character_id
|
args["linked_character_id"] = character_id
|
||||||
@@ -48,5 +48,21 @@ class GenerationRepo:
|
|||||||
args["status"] = status
|
args["status"] = status
|
||||||
return await self.collection.count_documents(args)
|
return await self.collection.count_documents(args)
|
||||||
|
|
||||||
|
async def get_generations_by_ids(self, generation_ids: List[str]) -> List[Generation]:
|
||||||
|
object_ids = [ObjectId(gen_id) for gen_id in generation_ids if ObjectId.is_valid(gen_id)]
|
||||||
|
res = await self.collection.find({"_id": {"$in": object_ids}}).to_list(None)
|
||||||
|
generations: List[Generation] = []
|
||||||
|
|
||||||
|
# Maintain order of generation_ids
|
||||||
|
gen_map = {str(doc["_id"]): doc for doc in res}
|
||||||
|
|
||||||
|
for gen_id in generation_ids:
|
||||||
|
doc = gen_map.get(gen_id)
|
||||||
|
if doc:
|
||||||
|
doc["id"] = str(doc.pop("_id"))
|
||||||
|
generations.append(Generation(**doc))
|
||||||
|
|
||||||
|
return generations
|
||||||
|
|
||||||
async def update_generation(self, generation: Generation, ):
|
async def update_generation(self, generation: Generation, ):
|
||||||
res = await self.collection.update_one({"_id": ObjectId(generation.id)}, {"$set": generation.model_dump()})
|
res = await self.collection.update_one({"_id": ObjectId(generation.id)}, {"$set": generation.model_dump()})
|
||||||
|
|||||||
91
tests/verify_albums_manual.py
Normal file
91
tests/verify_albums_manual.py
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# Add project root to path
|
||||||
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
|
|
||||||
|
from motor.motor_asyncio import AsyncIOMotorClient
|
||||||
|
from repos.dao import DAO
|
||||||
|
from models.Album import Album
|
||||||
|
from models.Generation import Generation, GenerationStatus
|
||||||
|
from models.enums import AspectRatios, Quality
|
||||||
|
|
||||||
|
# Mock config
|
||||||
|
# Use the same host as main.py but different DB
|
||||||
|
MONGO_HOST = os.getenv("MONGO_HOST", "mongodb://admin:super_secure_password@31.59.58.220:27017")
|
||||||
|
DB_NAME = "bot_db_test_albums"
|
||||||
|
|
||||||
|
async def test_albums():
|
||||||
|
print(f"🚀 Starting Album Manual Verification using {MONGO_HOST}...")
|
||||||
|
|
||||||
|
# Needs to run inside a loop from main
|
||||||
|
client = AsyncIOMotorClient(MONGO_HOST)
|
||||||
|
dao = DAO(client, db_name=DB_NAME)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 1. Clean up
|
||||||
|
await client[DB_NAME]["albums"].drop()
|
||||||
|
await client[DB_NAME]["generations"].drop()
|
||||||
|
print("✅ Cleaned up test database")
|
||||||
|
|
||||||
|
# 2. Create Album
|
||||||
|
album = Album(name="Test Album", description="A test album")
|
||||||
|
print("Creating album...")
|
||||||
|
album_id = await dao.albums.create_album(album)
|
||||||
|
print(f"✅ Created Album: {album_id}")
|
||||||
|
|
||||||
|
# 3. Create Generations
|
||||||
|
gen1 = Generation(prompt="Gen 1", aspect_ratio=AspectRatios.NINESIXTEEN, quality=Quality.ONEK)
|
||||||
|
gen2 = Generation(prompt="Gen 2", aspect_ratio=AspectRatios.NINESIXTEEN, quality=Quality.ONEK)
|
||||||
|
|
||||||
|
print("Creating generations...")
|
||||||
|
gen1_id = await dao.generations.create_generation(gen1)
|
||||||
|
gen2_id = await dao.generations.create_generation(gen2)
|
||||||
|
print(f"✅ Created Generations: {gen1_id}, {gen2_id}")
|
||||||
|
|
||||||
|
# 4. Add generations to album
|
||||||
|
print("Adding generations to album...")
|
||||||
|
await dao.albums.add_generation(album_id, gen1_id)
|
||||||
|
await dao.albums.add_generation(album_id, gen2_id)
|
||||||
|
print("✅ Added generations to album")
|
||||||
|
|
||||||
|
# 5. Fetch album and check generation_ids
|
||||||
|
album_fetched = await dao.albums.get_album(album_id)
|
||||||
|
assert album_fetched is not None
|
||||||
|
assert len(album_fetched.generation_ids) == 2
|
||||||
|
assert gen1_id in album_fetched.generation_ids
|
||||||
|
assert gen2_id in album_fetched.generation_ids
|
||||||
|
print("✅ Verified generations in album")
|
||||||
|
|
||||||
|
# 6. Fetch generations by IDs via GenerationRepo
|
||||||
|
generations = await dao.generations.get_generations_by_ids([gen1_id, gen2_id])
|
||||||
|
assert len(generations) == 2
|
||||||
|
|
||||||
|
# Ensure ID type match (str vs ObjectId handling in repo)
|
||||||
|
gen_ids_fetched = [g.id for g in generations]
|
||||||
|
assert gen1_id in gen_ids_fetched
|
||||||
|
assert gen2_id in gen_ids_fetched
|
||||||
|
print("✅ Verified fetching generations by IDs")
|
||||||
|
|
||||||
|
# 7. Remove generation
|
||||||
|
print("Removing generation...")
|
||||||
|
await dao.albums.remove_generation(album_id, gen1_id)
|
||||||
|
album_fetched = await dao.albums.get_album(album_id)
|
||||||
|
assert len(album_fetched.generation_ids) == 1
|
||||||
|
assert album_fetched.generation_ids[0] == gen2_id
|
||||||
|
print("✅ Verified removing generation from album")
|
||||||
|
|
||||||
|
print("🎉 Album Verification SUCCESS")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Cleanup client
|
||||||
|
client.close()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
load_dotenv()
|
||||||
|
try:
|
||||||
|
asyncio.run(test_albums())
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: {e}")
|
||||||
Reference in New Issue
Block a user