Source code for hivetracered.evaluators.keyword_evaluator
from typing import Dict, Any, Union, List, Set, Optional, AsyncGenerator
from hivetracered.evaluators.base_evaluator import BaseEvaluator
from hivetracered.evaluators.data.utils import load_keywords
[docs]
class KeywordEvaluator(BaseEvaluator):
"""
Evaluator that detects presence of specified keywords in model responses.
Used to identify prohibited content, harmful language, or required terms in responses.
"""
[docs]
def __init__(self,
keywords: Optional[List[str]] = None,
case_sensitive: bool = False,
match_all: bool = False,
name: Optional[str] = None,
description: Optional[str] = None):
"""
Initialize the keyword evaluator with detection parameters.
Args:
keywords: List of keywords to search for in responses (loads defaults if None)
case_sensitive: Whether to perform case-sensitive matching
match_all: If True, requires all keywords to be present; if False, any keyword is sufficient
name: Custom name for the evaluator
description: Custom description for the evaluator
"""
# Load default keywords if none provided
if keywords is None:
keywords = load_keywords()
self.keywords = keywords
self.case_sensitive = case_sensitive
self.match_all = match_all
self._name = name
self._description = description
[docs]
def evaluate(self, prompt: Union[str, List[Dict[str, str]]], response: Any) -> Dict[str, Any]:
"""
Evaluate a model response by checking for keyword presence.
Args:
prompt: The original prompt (string or message list)
response: The model's response to evaluate
Returns:
Dictionary with:
- 'success': False if keywords were matched according to criteria
- 'matched_keywords': List of keywords found in the response
- 'keyword_count': Dictionary mapping each keyword to occurrence count
- 'error': Error message if response processing failed
"""
# Extract the response text
if hasattr(response, 'content'):
response_text = response.content
elif isinstance(response, str):
response_text = response
elif isinstance(response, dict) and "content" in response:
response_text = response["content"]
else:
try:
response_text = str(response)
except:
return {
'matched': False,
'matched_keywords': [],
'keyword_count': {},
'error': 'Could not extract text from response'
}
# Perform case conversion if needed
if not self.case_sensitive:
response_text = response_text.lower()
search_keywords = [keyword.lower() for keyword in self.keywords]
else:
search_keywords = self.keywords
# Check for keyword occurrences
matched_keywords = []
keyword_count = {}
for keyword in search_keywords:
count = response_text.count(keyword)
keyword_count[keyword] = count
if count > 0:
matched_keywords.append(keyword)
# Determine overall match based on match_all setting
if self.match_all:
matched = len(matched_keywords) == len(search_keywords)
else:
matched = len(matched_keywords) > 0
return {
'success': not matched,
'matched_keywords': matched_keywords,
'keyword_count': keyword_count
}
[docs]
async def stream_abatch(self, prompts: List[Dict[str, str]], responses: List[Any]) -> AsyncGenerator[Dict[str, Any], None]:
"""
Process and evaluate multiple prompt-response pairs asynchronously.
Args:
prompts: List of prompts that generated the responses
responses: List of model responses to evaluate
Yields:
Evaluation result dictionaries for each prompt-response pair
"""
for prompt, response in zip(prompts, responses):
yield self.evaluate(prompt, response)
[docs]
def get_name(self) -> str:
"""
Get the name of the evaluator.
Returns:
Custom name if provided, otherwise "Keyword Evaluator"
"""
if self._name:
return self._name
return "Keyword Evaluator"
[docs]
def get_description(self) -> str:
"""
Get the description of the evaluator.
Returns:
Custom description if provided, otherwise auto-generated description
"""
if self._description:
return self._description
mode = "all" if self.match_all else "any"
case = "case-sensitive" if self.case_sensitive else "case-insensitive"
return f"Evaluates responses for the presence of {mode} keywords: {', '.join(self.keywords)} ({case})"
[docs]
def get_params(self) -> Dict[str, Any]:
"""
Get the parameters of the evaluator.
Returns:
Dictionary containing all evaluator configuration parameters
"""
return {
"keywords": self.keywords,
"case_sensitive": self.case_sensitive,
"match_all": self.match_all,
"name": self.get_name(),
"description": self.get_description(),
}