"""Thin wrapper around the google-genai SDK: lazy client singleton plus
sync, async, and streaming chat helpers configured from app settings."""
from collections.abc import AsyncIterator

from google import genai
from google.genai import types

from backend.app.core.config import settings
# Process-wide client instance; created on first use by get_client().
_client: genai.Client | None = None


def get_client() -> genai.Client:
    """Return the shared Gemini client, constructing it lazily.

    The client is built on first call so the API key is only read from
    settings when Gemini is actually used.
    """
    global _client
    if _client is not None:
        return _client
    _client = genai.Client(api_key=settings.GEMINI_API_KEY)
    return _client
def chat_sync(
    messages: list[dict[str, str]],
    system_instruction: str | None = None,
    temperature: float = 0.7,
    max_tokens: int = 8192,
) -> str:
    """
    Synchronous chat with Gemini.

    Args:
        messages: list of {"role": "user"|"model", "text": "..."} turns.
        system_instruction: optional system prompt; omitted when falsy.
        temperature: sampling temperature.
        max_tokens: cap on generated tokens (maps to max_output_tokens).

    Returns:
        The model's text response, or "" if the response carried no text.
    """
    client = get_client()

    contents = [
        types.Content(
            role=m["role"],
            parts=[types.Part.from_text(text=m["text"])],
        )
        for m in messages
    ]

    # Pass system_instruction at construction time instead of mutating the
    # config afterwards; `or None` keeps the original behavior of skipping
    # falsy values (e.g. "" is treated the same as absent).
    config = types.GenerateContentConfig(
        temperature=temperature,
        max_output_tokens=max_tokens,
        system_instruction=system_instruction or None,
    )

    response = client.models.generate_content(
        model=settings.GEMINI_MODEL,
        contents=contents,
        config=config,
    )

    # response.text can be None (e.g. an empty/blocked candidate); normalize.
    return response.text or ""
async def chat_async(
    messages: list[dict[str, str]],
    system_instruction: str | None = None,
    temperature: float = 0.7,
    max_tokens: int = 8192,
) -> str:
    """
    Async chat with Gemini.

    Args:
        messages: list of {"role": "user"|"model", "text": "..."} turns.
        system_instruction: optional system prompt; omitted when falsy.
        temperature: sampling temperature.
        max_tokens: cap on generated tokens (maps to max_output_tokens).

    Returns:
        The model's text response, or "" if the response carried no text.
    """
    client = get_client()

    contents = [
        types.Content(
            role=m["role"],
            parts=[types.Part.from_text(text=m["text"])],
        )
        for m in messages
    ]

    # Pass system_instruction at construction time instead of mutating the
    # config afterwards; `or None` keeps the original behavior of skipping
    # falsy values (e.g. "" is treated the same as absent).
    config = types.GenerateContentConfig(
        temperature=temperature,
        max_output_tokens=max_tokens,
        system_instruction=system_instruction or None,
    )

    response = await client.aio.models.generate_content(
        model=settings.GEMINI_MODEL,
        contents=contents,
        config=config,
    )

    # response.text can be None (e.g. an empty/blocked candidate); normalize.
    return response.text or ""
async def chat_stream(
    messages: list[dict[str, str]],
    system_instruction: str | None = None,
    temperature: float = 0.7,
    max_tokens: int = 8192,
) -> AsyncIterator[str]:
    """
    Async streaming chat with Gemini. Yields text chunks as they arrive.

    Args:
        messages: list of {"role": "user"|"model", "text": "..."} turns.
        system_instruction: optional system prompt; omitted when falsy.
        temperature: sampling temperature.
        max_tokens: cap on generated tokens (maps to max_output_tokens).

    Yields:
        Non-empty text fragments of the model's response.
    """
    client = get_client()

    contents = [
        types.Content(
            role=m["role"],
            parts=[types.Part.from_text(text=m["text"])],
        )
        for m in messages
    ]

    # Pass system_instruction at construction time instead of mutating the
    # config afterwards; `or None` keeps the original behavior of skipping
    # falsy values (e.g. "" is treated the same as absent).
    config = types.GenerateContentConfig(
        temperature=temperature,
        max_output_tokens=max_tokens,
        system_instruction=system_instruction or None,
    )

    async for chunk in client.aio.models.generate_content_stream(
        model=settings.GEMINI_MODEL,
        contents=contents,
        config=config,
    ):
        # Skip chunks with no text (e.g. metadata-only events).
        if chunk.text:
            yield chunk.text
|