Source code for hivetracered.models.gigachat_model
from typing import List, Any, Optional, Union, Dict
from langchain_gigachat import GigaChat
from hivetracered.models.langchain_model import LangchainModel
import os
from dotenv import load_dotenv
import warnings
[docs]
class GigaChatModel(LangchainModel):
"""
GigaChat language model implementation using LangChain integration.
Provides standardized access to Sber's GigaChat models with support for
both synchronous and asynchronous request processing.
"""
[docs]
def __init__(
    self,
    model: str = "GigaChat",
    max_concurrency: Optional[int] = None,
    batch_size: Optional[int] = None,
    scope: Optional[str] = None,
    credentials: Optional[str] = None,
    verify_ssl_certs: bool = False,
    max_retries: int = 3,
    **kwargs: Any,
) -> None:
    """
    Initialize the GigaChat model client with the specified configuration.

    Side effect: calls ``load_dotenv(override=True)``, so values found in a
    ``.env`` file replace variables already present in the process
    environment before the fallbacks below are read.

    Args:
        model: GigaChat model variant (e.g., "GigaChat", "GigaChat-Pro").
        max_concurrency: Maximum number of concurrent requests. When omitted,
            falls back to ``batch_size`` and then to 1.
        batch_size: (Deprecated) Use ``max_concurrency`` instead. Will be
            removed in v2.0.0. Passing it emits a ``DeprecationWarning``; it
            is ignored when ``max_concurrency`` is also given.
        scope: API scope for authorization. Defaults to the
            ``GIGACHAT_API_SCOPE`` environment variable.
        credentials: API credentials for authentication. Defaults to the
            ``GIGACHAT_CREDENTIALS`` environment variable.
        verify_ssl_certs: Whether to verify SSL certificates for API
            connections. Note that the default is ``False``.
        max_retries: Maximum number of retry attempts on transient errors
            (default: 3).
        **kwargs: Additional parameters forwarded unchanged to ``GigaChat``:
            - profanity_check: Whether to enable profanity filtering
            - temperature: Sampling temperature (lower = more deterministic);
              defaults to 0.000001 when not supplied
            - max_tokens: Maximum tokens in generated responses
            - top_p: Top-p sampling parameter for response diversity
    """
    load_dotenv(override=True)

    # Explicit arguments win; the environment is consulted only for values
    # the caller left as None.
    if scope is None:
        scope = os.getenv("GIGACHAT_API_SCOPE")
    if credentials is None:
        credentials = os.getenv("GIGACHAT_CREDENTIALS")

    self.model_name = model
    self.max_retries = max_retries

    # Deprecated alias: warn whenever batch_size is passed, but let an
    # explicit max_concurrency take precedence over it.
    if batch_size is not None:
        warnings.warn(
            "The 'batch_size' parameter is deprecated and will be removed in v2.0.0. "
            "Use 'max_concurrency' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        if max_concurrency is None:
            max_concurrency = batch_size

    # Neither knob supplied: default to a single request at a time.
    if max_concurrency is None:
        max_concurrency = 1
    self.max_concurrency = max_concurrency
    # Keep for backward compatibility in get_params()
    self.batch_size = self.max_concurrency

    # **kwargs always binds a fresh dict, so it can be stored directly.
    self.kwargs = kwargs
    # NOTE(review): the tiny non-zero default presumably aims for
    # near-deterministic output without sending an exact 0 -- confirm.
    self.kwargs.setdefault("temperature", 0.000001)

    self.client = GigaChat(
        credentials=credentials,
        model=model,
        scope=scope,
        verify_ssl_certs=verify_ssl_certs,
        **self.kwargs,
    )
    # _add_retry_policy is defined outside this view; it replaces the raw
    # client with whatever it returns (presumably using self.max_retries,
    # which is set above -- verify against the base class).
    self.client = self._add_retry_policy(self.client)