# Source listing metadata (file-viewer residue): 52 lines, 2.1 KiB, Python
import asyncio
|
|
from unittest.mock import MagicMock, AsyncMock, patch
|
|
from api.service.generation_service import GenerationService
|
|
from models.Settings import SystemSettings
|
|
from models.Generation import Generation
|
|
from models.enums import AspectRatios, Quality
|
|
|
|
async def test_generation_service_proxy_logic():
    """Exercise GenerationService with the AI proxy switched on.

    The service is built from mocked collaborators whose settings lookup
    returns ``SystemSettings(use_ai_proxy=True)``. Two paths are checked:

    1. ``ask_prompt_assistant`` returns the value produced by the patched
       ``ai_proxy.generate_text``, which is called exactly once, while
       ``gemini.generate_text`` is never called.
    2. ``create_generation`` calls the patched ``ai_proxy.generate_image``
       exactly once.

    Raises:
        AssertionError: if any of the expectations above does not hold.
    """
    import io  # kept function-local, as in the original script

    dao_mock = MagicMock()
    gemini_mock = MagicMock()
    storage_mock = MagicMock()

    # Settings lookup reports the proxy as ENABLED; asset lookup yields nothing.
    proxy_enabled = SystemSettings(use_ai_proxy=True)
    dao_mock.settings.get_settings = AsyncMock(return_value=proxy_enabled)
    dao_mock.assets.get_assets_by_ids = AsyncMock(return_value=[])

    service = GenerationService(dao_mock, gemini_mock, storage_mock)

    # --- 1. ask_prompt_assistant should go through the proxy ---
    with patch.object(
        service.ai_proxy, "generate_text", new_callable=AsyncMock
    ) as proxy_text:
        proxy_text.return_value = "Proxy Result"

        answer = await service.ask_prompt_assistant("Test Prompt")

        assert answer == "Proxy Result"
        proxy_text.assert_called_once()
        gemini_mock.generate_text.assert_not_called()

    # --- 2. create_generation should go through the proxy ---
    # Stub the service's own helpers so no complex DB calls are needed.
    service._prepare_generation_input = AsyncMock(
        return_value=([], "Test Image", [])
    )
    service._process_generated_images = AsyncMock(return_value=[])
    service._finalize_generation = AsyncMock()

    request = Generation(
        prompt="Test Image",
        aspect_ratio=AspectRatios.ONEONE,
        quality=Quality.ONEK,
        assets_list=[],
    )

    with patch.object(
        service.ai_proxy, "generate_image", new_callable=AsyncMock
    ) as proxy_image:
        fake_image = io.BytesIO(b"fake image data")
        proxy_image.return_value = (
            [fake_image],
            {"api_execution_time_seconds": 1.0},
        )

        await service.create_generation(request)

        proxy_image.assert_called_once()
        # NOTE(review): per the original author's note, gemini.generate_image
        # would be called via generate_image_task in the else branch (proxy
        # disabled); that path is not asserted here.

    print("✅ Proxy logic test passed!")
|
|
|
|
# Allow running this check directly as a script, without a test runner.
if __name__ == "__main__":
    asyncio.run(test_generation_service_proxy_logic())
|