Source code for hivetracered.models.openai_model

from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from langchain_openai import ChatOpenAI
from hivetracered.models.langchain_model import LangchainModel
from dotenv import load_dotenv
import os
import asyncio
from tqdm import tqdm
import warnings

from langchain_core.rate_limiters import InMemoryRateLimiter
from langchain_core.runnables import RunnableLambda

class OpenAIModel(LangchainModel):
    """
    OpenAI-compatible language model implementation using the LangChain integration.

    Provides a standardized interface to OpenAI's API (or any OpenAI-compatible
    endpoint) with rate limiting support and both synchronous and asynchronous
    processing capabilities.
    """

    def __init__(self,
                 model: str = "gpt-4.1-nano",
                 base_url: str = "https://api.openai.com/v1",
                 max_concurrency: Optional[int] = None,
                 batch_size: Optional[int] = None,
                 rpm: int = 300,
                 api_key: Optional[str] = None,
                 max_retries: int = 3,
                 **kwargs):
        """
        Initialize the OpenAI model client with the specified configuration.

        Args:
            model: Model identifier (e.g., "gpt-4", "gpt-3.5-turbo", or any
                model name for compatible APIs).
            base_url: API base URL (default: "https://api.openai.com/v1").
                Override for OpenAI-compatible endpoints.
            max_concurrency: Maximum number of concurrent requests (replaces
                batch_size).
            batch_size: (Deprecated) Use max_concurrency instead. Will be
                removed in v2.0.0.
            rpm: Rate limit in requests per minute.
            api_key: API key; defaults to the OPENAI_API_KEY environment
                variable.
            max_retries: Maximum number of retry attempts on transient errors
                (default: 3).
            **kwargs: Additional parameters passed to the ChatOpenAI
                constructor.
        """
        load_dotenv(override=True)
        if api_key is None:
            api_key = os.getenv("OPENAI_API_KEY")

        self.base_url = base_url
        self.model_name = model
        self.max_retries = max_retries

        # Handle deprecation of batch_size in favor of max_concurrency.
        if batch_size is not None:
            warnings.warn(
                "The 'batch_size' parameter is deprecated and will be removed "
                "in v2.0.0. Use 'max_concurrency' instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            if max_concurrency is None:
                max_concurrency = batch_size

        # Default to sequential processing if neither parameter was provided.
        if max_concurrency is None:
            max_concurrency = 1
        self.max_concurrency = max_concurrency

        # Keep batch_size for backward compatibility in get_params().
        self.batch_size = self.max_concurrency

        self.kwargs = kwargs or {}
        if "temperature" not in self.kwargs:
            self.kwargs["temperature"] = 0.000001

        # Convert the per-minute limit into the per-second rate expected by
        # InMemoryRateLimiter.
        rate_limiter = InMemoryRateLimiter(
            requests_per_second=rpm / 60,
            check_every_n_seconds=0.1,
        )

        self.client = ChatOpenAI(
            model=model,
            rate_limiter=rate_limiter,
            base_url=base_url,
            api_key=api_key,
            **self.kwargs,
        )
        self.client = self._add_retry_policy(self.client)
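
The `_add_retry_policy` helper is inherited from `LangchainModel` and is not shown
on this page. A minimal sketch of what such a wrapper could look like, assuming it
relies on LangChain's built-in `Runnable.with_retry` support (the actual base-class
implementation may differ):

    from langchain_core.runnables import Runnable

    def _add_retry_policy(self, client: Runnable) -> Runnable:
        # Retry transient failures with exponential backoff and jitter,
        # giving up after self.max_retries attempts in total.
        return client.with_retry(
            wait_exponential_jitter=True,
            stop_after_attempt=self.max_retries,
        )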
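
A minimal usage sketch (not part of the module): assuming OPENAI_API_KEY is set in
the environment or a .env file, the class can be instantiated with throttling and
concurrency tuned per workload:

    model = OpenAIModel(
        model="gpt-4.1-nano",
        max_concurrency=4,  # up to 4 requests in flight at once
        rpm=120,            # throttle to 120 requests per minute
    )
    # model.client is now a rate-limited, retry-wrapped ChatOpenAI runnable.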