Source code for hivetracered.attacks.template_attack
from typing import Union, List, Optional, Dict, AsyncGenerator
from hivetracered.attacks.base_attack import BaseAttack
[docs]
class TemplateAttack(BaseAttack):
"""
A base class for template-based attacks.
Allows creating new attacks by defining a template string with a '{prompt}' placeholder where the original prompt will be inserted.
"""
[docs]
def __init__(self, template: str = "{prompt}", name: Optional[str] = None, description: Optional[str] = None):
"""
Initialize the template attack with a specific template string.
Args:
template: A format string with a '{prompt}' placeholder
name: Optional name for the attack (defaults to class name)
description: Optional description for the attack
"""
self.template = template
self._name = name
self._description = description
[docs]
def apply(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
"""
Apply the template attack to the given prompt.
Args:
prompt: A string or list of messages to apply the attack to.
If the prompt is a list, the template will be applied to the last message.
Returns:
The transformed prompt with the template applied
Raises:
ValueError: If the prompt is invalid or the last message is not a human message
"""
if isinstance(prompt, str):
return self.template.format(prompt=prompt)
elif isinstance(prompt, list):
transformed_messages = prompt[:-1]
if prompt[-1].get("role") == "human":
transformed_messages.append({"role": "human", "content": self.template.format(prompt=prompt[-1]["content"])})
else:
raise ValueError("Last message in prompt is not a human message")
return transformed_messages
else:
raise ValueError("Prompt is not a string or list of messages")
[docs]
async def stream_abatch(self, prompts: List[Union[str, List[Dict[str, str]]]]) -> AsyncGenerator[List[Union[str, List[Dict[str, str]]]], None]:
"""
Apply the template attack to a batch of prompts asynchronously.
Args:
prompts: A list of prompts to apply the attack to
Returns:
An async generator yielding transformed prompts
"""
for prompt in prompts:
yield self.apply(prompt)
[docs]
def get_name(self) -> str:
"""
Get the name of the attack.
Returns:
The custom name if provided, otherwise the class name
"""
if self._name:
return self._name
return self.__class__.__name__
[docs]
def get_description(self) -> str:
"""
Get the description of the attack.
Returns:
The custom description if provided, otherwise a default description based on the template
"""
if self._description:
return self._description
return f"Template attack using template: {self.template}"
[docs]
def get_params(self):
"""
Get the parameters of the attack.
Returns:
A dictionary containing the attack's parameters
"""
return {
"template": self.template,
"name": self.get_name(),
"description": self.get_description()
}