# Source code for hivetracered.attacks.model_attack

from typing import Union, List, Optional, Dict, Any, AsyncGenerator
from hivetracered.attacks.base_attack import BaseAttack
from hivetracered.models.base_model import Model

class ModelAttack(BaseAttack):
    """
    Attack that uses a language model to transform prompts based on an
    attacker prompt template.

    Leverages the model's abilities to generate adversarial prompts through
    prompt engineering: the user's prompt is substituted into
    ``attacker_prompt`` (via its ``{prompt}`` placeholder), sent to the
    attacker model, and the model's answer becomes the transformed prompt.
    """

    def __init__(
        self,
        model: Model,
        attacker_prompt: str,
        model_kwargs: Optional[Dict[str, Any]] = None,
        name: Optional[str] = None,
        description: Optional[str] = None,
    ):
        """
        Initialize the model attack with a specific model and attacker prompt.

        Args:
            model: The language model to use for the attack
            attacker_prompt: The prompt template to use for the attack,
                with {prompt} as placeholder
            model_kwargs: Optional additional arguments to pass to the model
            name: Optional name for the attack (defaults to class name)
            description: Optional description for the attack
        """
        self.model = model
        self.attacker_prompt = attacker_prompt
        self.model_kwargs = model_kwargs or {}
        self._name = name
        self._description = description

    def post_process_response(self, response: str) -> str:
        """
        Post-process the model's response to clean it and handle refusals.

        Strips known label prefixes (case-insensitively) together with the
        separator that follows them, keeping only the text after the label,
        then collapses refusals to a fixed message.

        Args:
            response: The raw response from the model

        Returns:
            The cleaned and processed response
        """
        # Separators that may follow a label prefix.
        # (The original list contained a duplicate "**\n" entry; removed.)
        pref_suffixes = [
            "** ",
            "**\n",
            "\n\n",
            "\n",
            " ",
        ]
        # Lowercase label prefixes the attacker model tends to emit before
        # the transformed prompt (runtime strings — kept verbatim).
        prefixes_to_remove = [
            "преобразованный промпт:",
            "исходное намерение:",
            "вредоносное намерение:",
        ]
        for prefix in prefixes_to_remove:
            for pref in pref_suffixes:
                # Recompute each pass: `response` may have been trimmed.
                lower_response = response.lower()
                marker = prefix + pref
                if marker in lower_response:
                    index = lower_response.index(marker)
                    # Keep only the text after the label + separator.
                    response = response[index + len(marker):].lstrip()

        # Refusal patterns (lowercase). A response containing any of them is
        # replaced wholesale with a fixed refusal message.
        refusal_indicators = [
            "я не могу"
        ]
        if any(indicator in response.lower() for indicator in refusal_indicators):
            response = "Я не могу ответить на этот вопрос"

        return response

    def _format_for_attack(self, prompt: Union[str, List[Dict[str, str]]]) -> str:
        """
        Format a single prompt with the attacker template.

        Args:
            prompt: A string, or a non-empty message list ending with a
                "human" message.

        Returns:
            The attacker prompt with {prompt} filled in.

        Raises:
            ValueError: If the prompt format is invalid.
        """
        if isinstance(prompt, str):
            return self.attacker_prompt.format(prompt=prompt)
        # The explicit non-empty check avoids an IndexError on [].
        if isinstance(prompt, list) and prompt and prompt[-1].get("role") == "human":
            return self.attacker_prompt.format(prompt=prompt[-1]["content"])
        raise ValueError("Prompt must be either a string or a list with the last message from human")

    def _rebuild_prompt(
        self,
        prompt: Union[str, List[Dict[str, str]]],
        content: str,
    ) -> Union[str, List[Dict[str, str]]]:
        """
        Rebuild a transformed prompt in the shape of the original.

        Args:
            prompt: The original prompt (string or message list).
            content: The raw model response content.

        Returns:
            The post-processed content for string prompts, or the original
            messages with the last one replaced by the transformed human
            message for list prompts.
        """
        processed = self.post_process_response(content)
        if isinstance(prompt, str):
            return processed
        # prompt[:-1] yields a fresh list, so the caller's input is not mutated.
        return prompt[:-1] + [{"role": "human", "content": processed}]

    def apply(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
        """
        Apply the model attack to the given prompt.

        Args:
            prompt: A string or list of messages to apply the attack to

        Returns:
            The transformed prompt with the model attack applied

        Raises:
            ValueError: If the prompt format is invalid (including an empty
                message list, which previously surfaced as an IndexError)
        """
        if isinstance(prompt, str):
            formatted_prompt = self.attacker_prompt.format(prompt=prompt)
            response = self.model.invoke(formatted_prompt, **self.model_kwargs)["content"]
            return self.post_process_response(response)
        elif isinstance(prompt, list):
            if not prompt:
                raise ValueError("Prompt is not a string or list of messages")
            if prompt[-1].get("role") != "human":
                raise ValueError("Last message in prompt is not a human message")
            formatted_prompt = self.attacker_prompt.format(prompt=prompt[-1]["content"])
            response = self.model.invoke(formatted_prompt, **self.model_kwargs)["content"]
            return self._rebuild_prompt(prompt, response)
        else:
            raise ValueError("Prompt is not a string or list of messages")

    async def batch(self, prompts: List[Union[str, List[Dict[str, str]]]]) -> List[Union[str, List[Dict[str, str]]]]:
        """
        Apply the model attack to a batch of prompts in a non-streaming manner.

        Args:
            prompts: A list of prompts to apply the attack to

        Returns:
            List of transformed prompts with the model attack applied

        Raises:
            ValueError: If any prompt has an invalid format
        """
        # Validate and format every prompt before touching the model, so a
        # bad prompt fails fast without spending model calls.
        formatted_prompts = [self._format_for_attack(p) for p in prompts]

        # One abatch call processes the whole batch; responses align with
        # formatted_prompts by index.
        responses = await self.model.abatch(formatted_prompts)

        return [
            self._rebuild_prompt(prompt, response["content"])
            for prompt, response in zip(prompts, responses)
        ]

    async def stream_abatch(self, prompts: List[Union[str, List[Dict[str, str]]]]) -> AsyncGenerator[List[Union[str, List[Dict[str, str]]]], None]:
        """
        Apply the model attack to a batch of prompts asynchronously.

        Args:
            prompts: A list of prompts to apply the attack to

        Returns:
            An async generator yielding transformed prompts as they are
            processed, in the same order as `prompts`

        Raises:
            ValueError: If any prompt has an invalid format
        """
        # Validate/format up front, mirroring batch().
        formatted_prompts = [self._format_for_attack(p) for p in prompts]

        # Assumes the underlying stream yields exactly one response per
        # formatted prompt, in order — TODO confirm against Model.stream_abatch.
        i = 0
        async for response in self.model.stream_abatch(formatted_prompts):
            yield self._rebuild_prompt(prompts[i], response["content"])
            i += 1

    def get_name(self) -> str:
        """
        Get the name of the attack.

        Returns:
            The custom name if provided, otherwise the class name
        """
        if self._name:
            return self._name
        return self.__class__.__name__

    def get_description(self) -> str:
        """
        Get the description of the attack.

        Returns:
            The custom description if provided, otherwise a default
            description naming the model class and the attacker prompt
        """
        if self._description:
            return self._description
        return f"Model-based attack using {self.model.__class__.__name__} with prompt: {self.attacker_prompt}"