Files
ai-char-bot/adapters/google_adapter.py
2026-02-12 18:41:01 +03:00

171 lines
6.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import io
import logging
import time
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union

from PIL import Image
from google import genai
from google.genai import types

from adapters.Exception import GoogleGenerationException
from models.enums import AspectRatios, Quality
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)
class GoogleAdapter:
    """Adapter around ``google.genai.Client`` for Gemini text and image generation.

    Errors raised while generating are surfaced to callers as
    ``GoogleGenerationException``.
    """

    def __init__(self, api_key: str):
        """Create the Gemini client.

        Args:
            api_key: Gemini API key; must be non-empty.

        Raises:
            ValueError: If ``api_key`` is empty or ``None``.
        """
        if not api_key:
            raise ValueError("API Key for Gemini is missing")
        self.client = genai.Client(api_key=api_key)
        # Model identifiers used by generate_text / generate_image.
        self.TEXT_MODEL = "gemini-3-pro-preview"
        self.IMAGE_MODEL = "gemini-3-pro-image-preview"

    def _prepare_contents(
        self, prompt: str, images_list: Optional[List[bytes]] = None
    ) -> Tuple[List[Any], List[Any]]:
        """Build the request ``contents``: the prompt followed by decoded images.

        Images Pillow cannot open are logged and skipped, so the request still
        goes out with the remaining inputs.

        Returns:
            ``(contents, opened_images)``. The caller MUST close every image in
            ``opened_images`` after the request has been made.
        """
        contents: List[Any] = [prompt]
        opened_images: List[Any] = []
        if images_list:
            logger.info(f"Preparing content with {len(images_list)} images")
            for img_bytes in images_list:
                try:
                    image = Image.open(io.BytesIO(img_bytes))
                except Exception as e:
                    # Best effort: one bad input image should not abort the request.
                    logger.error(f"Error processing input image: {e}")
                    continue
                contents.append(image)
                opened_images.append(image)
        else:
            logger.info("Preparing content with no images")
        return contents, opened_images

    def generate_text(self, prompt: str, images_list: Optional[List[bytes]] = None) -> str:
        """Generate a text reply (chat or vision).

        Args:
            prompt: Text prompt, sent as the first content item.
            images_list: Optional raw image bytes sent after the prompt.

        Returns:
            The concatenated text of all response parts, or an empty string if
            the response contains no text parts.

        Raises:
            GoogleGenerationException: If the API call or response handling fails.
        """
        contents, opened_images = self._prepare_contents(prompt, images_list)
        # Log only the length: prompts are user content and may be large.
        logger.info(f"Generating text. Prompt length: {len(prompt)}")
        try:
            response = self.client.models.generate_content(
                model=self.TEXT_MODEL,
                contents=contents,
                config=types.GenerateContentConfig(
                    response_modalities=['TEXT'],
                    temperature=0.7,
                ),
            )
            # ``parts`` may be None (e.g. a blocked response), hence ``or []``.
            result_text = "".join(
                part.text for part in (response.parts or []) if part.text
            )
            logger.info(f"Generated text length: {len(result_text)}")
            if not result_text:
                logger.warning("Gemini returned no text parts")
            return result_text
        except Exception as e:
            logger.error(f"Gemini Text API Error: {e}")
            raise GoogleGenerationException(f"Gemini Text API Error: {e}") from e
        finally:
            for img in opened_images:
                img.close()

    def generate_image(
        self,
        prompt: str,
        aspect_ratio: AspectRatios,
        quality: Quality,
        images_list: Optional[List[bytes]] = None,
    ) -> Tuple[List[io.BytesIO], Dict[str, Any]]:
        """Generate images (text-to-image or image-to-image).

        Args:
            prompt: Text prompt.
            aspect_ratio: Requested aspect ratio; ``aspect_ratio.value_ratio`` is
                sent to the API.
            quality: Requested quality; ``quality.value_quality`` is sent as
                ``image_size``.
            images_list: Optional raw reference images sent after the prompt.

        Returns:
            ``(images, metrics)``. ``images`` is a (possibly empty) list of named,
            rewound ``BytesIO`` streams ready to send. ``metrics`` holds the
            elapsed API time in seconds and the token counts (0 when the
            response does not report them).

        Raises:
            GoogleGenerationException: If the prompt or candidate was blocked, or
                if the API call or response handling failed.
        """
        contents, opened_images = self._prepare_contents(prompt, images_list)
        logger.info(f"Generating image. Prompt length: {len(prompt)}, Ratio: {aspect_ratio}, Quality: {quality}")
        # perf_counter is unaffected by system clock adjustments during the call.
        start = time.perf_counter()
        try:
            response = self.client.models.generate_content(
                model=self.IMAGE_MODEL,
                contents=contents,
                config=types.GenerateContentConfig(
                    response_modalities=['IMAGE'],
                    temperature=1.0,
                    image_config=types.ImageConfig(
                        aspect_ratio=aspect_ratio.value_ratio,
                        image_size=quality.value_quality,
                    ),
                ),
            )
            api_duration = time.perf_counter() - start

            self._raise_if_blocked(response)
            generated_images = self._extract_images(response)
            metrics = self._build_metrics(response, api_duration)

            if generated_images:
                logger.info(
                    f"Successfully generated {len(generated_images)} images in {api_duration:.2f}s. "
                    f"Tokens: {metrics['token_usage']}"
                )
            else:
                logger.warning("No images generated from response parts")
            return generated_images, metrics
        except GoogleGenerationException as e:
            # Already specific (e.g. a content block): re-raise as-is instead of
            # wrapping it in a second "Gemini Image API Error" layer.
            logger.error(f"Gemini Image API Error: {e}")
            raise
        except Exception as e:
            logger.error(f"Gemini Image API Error: {e}")
            raise GoogleGenerationException(f"Gemini Image API Error: {e}") from e
        finally:
            for img in opened_images:
                img.close()
            # Drop the reference so a propagating traceback's frame does not
            # keep the prompt/image list alive.
            del contents

    @staticmethod
    def _raise_if_blocked(response) -> None:
        """Raise ``GoogleGenerationException`` if Gemini blocked the prompt or the candidate."""
        # Prompt-level block (e.g. PROHIBITED_CONTENT): no candidates in this case.
        feedback = response.prompt_feedback
        if feedback and feedback.block_reason:
            # Use the enum value when present; fall back to the raw object otherwise.
            reason = getattr(feedback.block_reason, "value", feedback.block_reason)
            raise GoogleGenerationException(
                f"Generation blocked at prompt level: {reason}"
            )
        # Candidate-level block: no parts; report the first candidate's finish reason.
        if response.parts is None:
            reason = (
                response.candidates[0].finish_reason
                if response.candidates
                else "Unknown"
            )
            raise GoogleGenerationException(f"Generation blocked: {reason}")

    @staticmethod
    def _extract_images(response) -> List[io.BytesIO]:
        """Wrap each non-empty inline-data part of ``response`` in a named ``BytesIO``."""
        images: List[io.BytesIO] = []
        for part in response.parts or []:
            inline = part.inline_data
            # Skip text parts and inline parts that carry no payload; an empty
            # BytesIO would otherwise be returned as an "image".
            if not inline or not inline.data:
                continue
            try:
                buffer = io.BytesIO(inline.data)
                # File name is a formality for Telegram uploads.
                buffer.name = f'{datetime.now().timestamp()}.png'
                # Defensive rewind so the stream is ready to be read/sent.
                buffer.seek(0)
                images.append(buffer)
            except Exception as e:
                logger.error(f"Error processing output image: {e}")
        return images

    @staticmethod
    def _build_metrics(response, api_duration: float) -> Dict[str, Any]:
        """Collect the API duration and token counts; missing counts become 0."""
        usage = response.usage_metadata

        def _count(attr: str) -> int:
            # usage (or an individual count) may be None; normalize to 0.
            return getattr(usage, attr, None) or 0

        return {
            "api_execution_time_seconds": api_duration,
            "token_usage": _count("total_token_count"),
            "input_token_usage": _count("prompt_token_count"),
            "output_token_usage": _count("candidates_token_count"),
        }