Source code for hivetracered.models.base_model
import asyncio
import contextlib
import warnings
from collections.abc import AsyncGenerator
from contextlib import AbstractAsyncContextManager
from abc import ABC, abstractmethod
[docs]
class Model(ABC):
"""
Abstract base class for language model implementations.
Defines the standard interface for interacting with various LLM providers,
supporting both synchronous and asynchronous operations for single requests and batches.
"""
model_name: str
max_concurrency: int = 0
def _concurrency_slot(self) -> AbstractAsyncContextManager:
"""Acquire one concurrency slot for an async call.
Lazily constructs a per-instance ``asyncio.Semaphore`` capped at
``self.max_concurrency``. The semaphore binds to the running event loop
on first acquire; if the model is reused across event loops, the slot
is reconstructed for the new loop.
Returns ``contextlib.nullcontext()`` when ``max_concurrency == 0``
(unlimited). Subclasses' ``ainvoke`` implementations must wrap their
outbound request with ``async with self._concurrency_slot():`` so that
every async entry point (``ainvoke``, ``abatch``, ``stream_abatch``)
observes the per-model cap.
"""
if self.max_concurrency == 0:
return contextlib.nullcontext()
sem = getattr(self, "_concurrency_sem", None)
if sem is None:
sem = asyncio.Semaphore(self.max_concurrency)
self._concurrency_sem = sem
return sem
[docs]
@abstractmethod
def invoke(self, prompt: str | list[dict[str, str]]) -> dict:
"""
Send a single request to the model synchronously.
Args:
prompt: A string or list of messages to send to the model
Returns:
Dictionary containing the model's response with at least a 'content' key
"""
pass
[docs]
@abstractmethod
async def ainvoke(self, prompt: str | list[dict[str, str]]) -> dict:
"""
Send a single request to the model asynchronously.
Args:
prompt: A string or list of messages to send to the model
Returns:
Dictionary containing the model's response with at least a 'content' key
"""
pass
[docs]
@abstractmethod
def batch(self, prompts: list[str | list[dict[str, str]]]) -> list[dict]:
"""
Send multiple requests to the model synchronously.
Args:
prompts: A list of prompts to send to the model
Returns:
List of response dictionaries in the same order as the input prompts
"""
pass
[docs]
@abstractmethod
async def abatch(self, prompts: list[str | list[dict[str, str]]]) -> list[dict]:
"""
Send multiple requests to the model asynchronously.
Args:
prompts: A list of prompts to send to the model
Returns:
List of response dictionaries in the same order as the input prompts
"""
pass
[docs]
def is_answer_blocked(self, answer: dict) -> bool:
"""
Check if the answer is blocked by model's safety guardrails.
Args:
answer: The model response dictionary to check
Returns:
Boolean indicating if the response was blocked
"""
return False
[docs]
def get_params(self) -> dict:
"""
Get the parameters of the model.
Returns:
Dictionary containing the model's configuration parameters
"""
return self.__dict__
@staticmethod
def _resolve_concurrency(
max_concurrency: int | None,
batch_size: int | None,
default: int,
) -> int:
"""
Resolve effective concurrency, honoring the deprecated `batch_size` alias.
Emits DeprecationWarning if batch_size is provided. When both are set,
max_concurrency wins. Falls back to `default` when neither is provided.
"""
if batch_size is not None:
warnings.warn(
"The 'batch_size' parameter is deprecated and will be removed in v2.0.0. "
"Use 'max_concurrency' instead.",
DeprecationWarning,
stacklevel=3,
)
if max_concurrency is None:
max_concurrency = batch_size
if max_concurrency is None:
max_concurrency = default
return max_concurrency
[docs]
@abstractmethod
async def stream_abatch(self, prompts: list[str | list[dict[str, str]]]) -> AsyncGenerator[dict, None]:
"""
Send multiple requests to the model asynchronously and yield results as they complete.
Args:
prompts: A list of prompts to send to the model
Yields:
Response dictionaries. Implementations MUST yield results in the same order
as the input prompts list, since downstream stages match responses to prompts
by sequential index.
"""
pass