models + refactor
This commit is contained in:
51
tests/test_ai_proxy_logic.py
Normal file
51
tests/test_ai_proxy_logic.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
from api.service.generation_service import GenerationService
|
||||
from models.Settings import SystemSettings
|
||||
from models.Generation import Generation
|
||||
from models.enums import AspectRatios, Quality
|
||||
|
||||
async def test_generation_service_proxy_logic():
|
||||
dao = MagicMock()
|
||||
gemini = MagicMock()
|
||||
s3_adapter = MagicMock()
|
||||
|
||||
# Mock settings to have proxy ENABLED
|
||||
dao.settings.get_settings = AsyncMock(return_value=SystemSettings(use_ai_proxy=True))
|
||||
dao.assets.get_assets_by_ids = AsyncMock(return_value=[])
|
||||
|
||||
service = GenerationService(dao, gemini, s3_adapter)
|
||||
|
||||
# 1. Test ask_prompt_assistant with proxy
|
||||
with patch.object(service.ai_proxy, 'generate_text', new_callable=AsyncMock) as mock_proxy_text:
|
||||
mock_proxy_text.return_value = "Proxy Result"
|
||||
result = await service.ask_prompt_assistant("Test Prompt")
|
||||
assert result == "Proxy Result"
|
||||
mock_proxy_text.assert_called_once()
|
||||
gemini.generate_text.assert_not_called()
|
||||
|
||||
# 2. Test create_generation with proxy
|
||||
generation = Generation(
|
||||
prompt="Test Image",
|
||||
aspect_ratio=AspectRatios.ONEONE,
|
||||
quality=Quality.ONEK,
|
||||
assets_list=[]
|
||||
)
|
||||
# Mock _prepare_generation_input to avoid complex DB calls
|
||||
service._prepare_generation_input = AsyncMock(return_value=([], "Test Image", []))
|
||||
service._process_generated_images = AsyncMock(return_value=[])
|
||||
service._finalize_generation = AsyncMock()
|
||||
|
||||
with patch.object(service.ai_proxy, 'generate_image', new_callable=AsyncMock) as mock_proxy_img:
|
||||
import io
|
||||
mock_img_io = io.BytesIO(b"fake image data")
|
||||
mock_proxy_img.return_value = ([mock_img_io], {"api_execution_time_seconds": 1.0})
|
||||
|
||||
await service.create_generation(generation)
|
||||
mock_proxy_img.assert_called_once()
|
||||
# gemini.generate_image would be called via generate_image_task in else branch
|
||||
|
||||
print("✅ Proxy logic test passed!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_generation_service_proxy_logic())
|
||||
Reference in New Issue
Block a user