174 lines
7.4 KiB
Python
174 lines
7.4 KiB
Python
import io
import logging
import time
from datetime import datetime
from typing import List, Union, Tuple, Dict, Any

from google import genai
from google.genai import types
from PIL import Image

from adapters.Exception import GoogleGenerationException
from models.enums import AspectRatios, Quality, TextModel, ImageModel
# Module-level logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(name=__name__)
class GoogleAdapter:
    """Adapter over the ``google-genai`` client for Gemini text and image generation.

    Both public entry points validate the model name against the project's
    ``TextModel`` / ``ImageModel`` enums (``ValueError`` on mismatch), attach
    optional input images, and surface every API-side failure as
    ``GoogleGenerationException``.
    """

    def __init__(self, api_key: str):
        """Create the underlying ``genai.Client``.

        Args:
            api_key: Gemini API key.

        Raises:
            ValueError: if ``api_key`` is empty or ``None``.
        """
        if not api_key:
            raise ValueError("API Key for Gemini is missing")
        self.client = genai.Client(api_key=api_key)

    @staticmethod
    def _validate_model(model: str, model_enum: Any, purpose: str) -> None:
        """Raise ``ValueError`` unless ``model`` equals the ``.value`` of a member of ``model_enum``."""
        allowed = [m.value for m in model_enum]
        if model not in allowed:
            raise ValueError(f"Invalid model for {purpose}: {model}. Expected one of: {allowed}")

    def _prepare_contents(self, prompt: str, images_list: List[bytes] | None = None) -> tuple:
        """Build the request ``contents``: the prompt followed by each decodable image.

        Images that Pillow cannot open are logged and skipped (best-effort), so
        the request still goes out with the remaining inputs.

        Args:
            prompt: Text prompt; always the first element of ``contents``.
            images_list: Optional raw image bytes to attach.

        Returns:
            ``(contents, opened_images)``. The caller MUST close every image in
            ``opened_images`` once the request is finished (``_close_images``).
        """
        contents: list[Any] = [prompt]
        opened_images: list[Any] = []

        if not images_list:
            logger.info("Preparing content with no images")
            return contents, opened_images

        logger.info(f"Preparing content with {len(images_list)} images")
        for index, img_bytes in enumerate(images_list):
            try:
                # Per Pillow docs, Image.open is lazy (header only), so the image
                # is kept open until the request has been sent.
                image = Image.open(io.BytesIO(img_bytes))
            except Exception as e:
                logger.error(f"Error processing input image #{index}: {e}")
                continue
            contents.append(image)
            opened_images.append(image)

        return contents, opened_images

    @staticmethod
    def _close_images(images: List[Any]) -> None:
        """Close every image from ``_prepare_contents``; one failed close does not skip the rest."""
        for img in images:
            try:
                img.close()
            except Exception as e:
                # Runs inside ``finally``: never let cleanup mask the real outcome.
                logger.warning(f"Failed to close input image: {e}")

    @staticmethod
    def _checked_parts(response: Any) -> List[Any]:
        """Return ``response.parts``, raising if Gemini produced no content.

        Raises:
            GoogleGenerationException: on a prompt-level block
                (``prompt_feedback.block_reason`` is set), or when the response
                carries no parts; the latter reports the first candidate's
                ``finish_reason``, or ``"Unknown"`` when there are no candidates.
        """
        feedback = response.prompt_feedback
        if feedback and feedback.block_reason:
            raise GoogleGenerationException(
                f"Generation blocked at prompt level: {feedback.block_reason.value}"
            )

        parts = response.parts
        if parts is None:
            reason = response.candidates[0].finish_reason if response.candidates else "Unknown"
            raise GoogleGenerationException(f"Generation blocked: {reason}")
        return parts

    @staticmethod
    def _extract_images(parts: List[Any]) -> List[io.BytesIO]:
        """Wrap the raw bytes of every ``inline_data`` part in a named ``BytesIO``.

        Parts without inline data are ignored; parts whose bytes are missing or
        cannot be wrapped are logged and skipped.
        """
        images: List[io.BytesIO] = []
        for part in parts:
            blob = part.inline_data
            if not blob:
                continue
            if blob.data is None:
                logger.error("Error processing output image: Generation returned no data")
                continue
            try:
                # A freshly constructed BytesIO is already positioned at 0.
                stream = io.BytesIO(blob.data)
                # Cosmetic filename; the original note says it is "for TG"
                # (presumably Telegram uploads) — confirm with callers.
                stream.name = f'{datetime.now().timestamp()}.png'
            except Exception as e:
                logger.error(f"Error processing output image: {e}")
                continue
            images.append(stream)
        return images

    @staticmethod
    def _usage_metrics(response: Any, api_duration: float) -> Dict[str, Any]:
        """Build the metrics dict; token counts missing from the response are reported as 0."""
        usage = response.usage_metadata
        total_tokens = input_tokens = output_tokens = 0
        if usage:
            # The SDK may leave individual counts unset (None).
            total_tokens = usage.total_token_count or 0
            input_tokens = usage.prompt_token_count or 0
            output_tokens = usage.candidates_token_count or 0
        return {
            "api_execution_time_seconds": api_duration,
            "token_usage": total_tokens,
            "input_token_usage": input_tokens,
            "output_token_usage": output_tokens,
        }

    def generate_text(self, prompt: str, model: str = "gemini-3.1-pro-preview", images_list: List[bytes] | None = None) -> str:
        """Generate text (chat or vision) and return the concatenated reply.

        Args:
            prompt: Text prompt.
            model: A ``TextModel`` value.
            images_list: Optional raw image bytes to send alongside the prompt.

        Returns:
            The text of all non-empty response parts, joined in order.

        Raises:
            ValueError: if ``model`` is not a ``TextModel`` value.
            GoogleGenerationException: if generation is blocked or the API call fails.
        """
        self._validate_model(model, TextModel, "text generation")

        contents, opened_images = self._prepare_contents(prompt, images_list)
        try:
            # Log the length only: prompts may contain user data.
            logger.info(f"Generating text. Prompt length: {len(prompt)}, Model: {model}")
            response = self.client.models.generate_content(
                model=model,
                contents=contents,
                config=types.GenerateContentConfig(
                    response_modalities=['TEXT'],
                    temperature=0.7,
                ),
            )

            parts = self._checked_parts(response)
            result_text = "".join(part.text for part in parts if part.text)
            logger.info(f"Generated text length: {len(result_text)}")
            return result_text
        except GoogleGenerationException as e:
            # Already descriptive (e.g. a block); propagate without re-wrapping.
            logger.error(f"Gemini Text API Error: {e}")
            raise
        except Exception as e:
            logger.exception(f"Gemini Text API Error: {e}")
            raise GoogleGenerationException(f"Gemini Text API Error: {e}") from e
        finally:
            self._close_images(opened_images)

    def generate_image(
        self,
        prompt: str,
        aspect_ratio: AspectRatios,
        quality: Quality,
        model: str = "gemini-3-pro-image-preview",
        images_list: List[bytes] | None = None,
    ) -> Tuple[List[io.BytesIO], Dict[str, Any]]:
        """Generate images (text-to-image or image-to-image).

        Args:
            prompt: Text prompt.
            aspect_ratio: Project enum; its ``value_ratio`` is sent as ``ImageConfig.aspect_ratio``.
            quality: Project enum; its ``value_quality`` is sent as ``ImageConfig.image_size``.
            model: An ``ImageModel`` value.
            images_list: Optional raw reference-image bytes.

        Returns:
            ``(images, metrics)``: named ``BytesIO`` streams positioned at 0, and a
            dict with ``api_execution_time_seconds``, ``token_usage``,
            ``input_token_usage`` and ``output_token_usage``.

        Raises:
            ValueError: if ``model`` is not an ``ImageModel`` value.
            GoogleGenerationException: if generation is blocked or the API call fails.
        """
        self._validate_model(model, ImageModel, "image generation")

        contents, opened_images = self._prepare_contents(prompt, images_list)
        try:
            logger.info(f"Generating image. Prompt length: {len(prompt)}, Ratio: {aspect_ratio}, Quality: {quality}, Model: {model}")

            # Monotonic clock: immune to wall-clock adjustments during the call.
            start = time.perf_counter()
            response = self.client.models.generate_content(
                model=model,
                contents=contents,
                config=types.GenerateContentConfig(
                    response_modalities=['IMAGE'],
                    temperature=1.0,
                    image_config=types.ImageConfig(
                        aspect_ratio=aspect_ratio.value_ratio,
                        image_size=quality.value_quality,
                    ),
                ),
            )
            api_duration = time.perf_counter() - start

            parts = self._checked_parts(response)
            generated_images = self._extract_images(parts)
            metrics = self._usage_metrics(response, api_duration)

            if generated_images:
                logger.info(f"Successfully generated {len(generated_images)} images in {api_duration:.2f}s. Tokens: {metrics['token_usage']}")
            else:
                logger.warning("No images generated from response parts")

            return generated_images, metrics
        except GoogleGenerationException as e:
            # Already descriptive (e.g. a block); propagate without re-wrapping.
            logger.error(f"Gemini Image API Error: {e}")
            raise
        except Exception as e:
            logger.exception(f"Gemini Image API Error: {e}")
            raise GoogleGenerationException(f"Gemini Image API Error: {e}") from e
        finally:
            self._close_images(opened_images)