models + refactor
This commit is contained in:
@@ -34,6 +34,7 @@ async def generate_image_task(
|
||||
media_group_bytes: List[bytes],
|
||||
aspect_ratio: AspectRatios,
|
||||
quality: Quality,
|
||||
model: str,
|
||||
gemini: GoogleAdapter,
|
||||
) -> Tuple[List[bytes], Dict[str, Any]]:
|
||||
"""
|
||||
@@ -47,6 +48,7 @@ async def generate_image_task(
|
||||
images_list=media_group_bytes,
|
||||
aspect_ratio=aspect_ratio,
|
||||
quality=quality,
|
||||
model=model,
|
||||
)
|
||||
generated_images_io, metrics = result
|
||||
logger.info(f"generate_image_task completed, received {len(generated_images_io) if generated_images_io else 0} images")
|
||||
@@ -75,7 +77,7 @@ class GenerationService:
|
||||
|
||||
# --- Public API ---
|
||||
|
||||
async def ask_prompt_assistant(self, prompt: str, assets: list[str] | None = None) -> str:
|
||||
async def ask_prompt_assistant(self, prompt: str, assets: list[str] | None = None, model: str = "gemini-3.1-pro-preview") -> str:
|
||||
future_prompt = (
|
||||
"You are an prompt-assistant. You improving user-entered prompts for image generation. "
|
||||
"User may upload reference image too. I will provide sources prompt entered by user. "
|
||||
@@ -87,17 +89,17 @@ class GenerationService:
|
||||
assets_db = await self.dao.assets.get_assets_by_ids(assets)
|
||||
assets_data.extend(asset.data for asset in assets_db if asset.data)
|
||||
|
||||
generated_prompt = await asyncio.to_thread(self.gemini.generate_text, future_prompt, assets_data)
|
||||
generated_prompt = await asyncio.to_thread(self.gemini.generate_text, future_prompt, model, assets_data)
|
||||
logger.info(f"Prompt Assistant: {generated_prompt}")
|
||||
return generated_prompt
|
||||
|
||||
async def generate_prompt_from_images(self, images: List[bytes], user_prompt: Optional[str] = None) -> str:
|
||||
async def generate_prompt_from_images(self, images: List[bytes], user_prompt: Optional[str] = None, model: str = "gemini-3.1-pro-preview") -> str:
|
||||
technical_prompt = "You are a prompt engineer. Describe this image in detail to create a stable diffusion using this image as reference. "
|
||||
if user_prompt:
|
||||
technical_prompt += f"User also provided this context: {user_prompt}. "
|
||||
technical_prompt += "Provide ONLY the detailed prompt."
|
||||
|
||||
return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, images_list=images)
|
||||
return await asyncio.to_thread(self.gemini.generate_text, prompt=technical_prompt, model=model, images_list=images)
|
||||
|
||||
async def get_generations(self, **kwargs) -> GenerationsResponse:
|
||||
current_user_id = kwargs.pop('current_user_id', None)
|
||||
@@ -162,6 +164,7 @@ class GenerationService:
|
||||
media_group_bytes=media_group_bytes,
|
||||
aspect_ratio=generation.aspect_ratio,
|
||||
quality=generation.quality,
|
||||
model=generation.model or "gemini-3-pro-image-preview",
|
||||
gemini=self.gemini
|
||||
)
|
||||
self._update_generation_metrics(generation, metrics)
|
||||
@@ -205,7 +208,9 @@ class GenerationService:
|
||||
aspect_ratio=external_gen.aspect_ratio,
|
||||
quality=external_gen.quality,
|
||||
prompt=external_gen.prompt,
|
||||
model=external_gen.model,
|
||||
tech_prompt=external_gen.tech_prompt,
|
||||
seed=external_gen.seed,
|
||||
result_list=[new_asset.id],
|
||||
result=new_asset.id,
|
||||
progress=100,
|
||||
|
||||
Reference in New Issue
Block a user