Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ API Reference
:toctree: _autosummary/

class_name_to_snake_case
AttackIdentifier
ConverterIdentifier
Identifiable
Identifier
Expand Down
2 changes: 1 addition & 1 deletion pyrit/analytics/result_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def analyze_results(attack_results: list[AttackResult]) -> dict[str, AttackStats
raise TypeError(f"Expected AttackResult, got {type(attack).__name__}: {attack!r}")

outcome = attack.outcome
attack_type = attack.attack_identifier.get("type", "unknown")
attack_type = attack.attack_identifier.class_name if attack.attack_identifier else "unknown"

if outcome == AttackOutcome.SUCCESS:
overall_counts["successes"] += 1
Expand Down
30 changes: 11 additions & 19 deletions pyrit/exceptions/exception_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
from contextvars import ContextVar
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, Optional, Union
from typing import Any, Optional

from pyrit.identifiers import Identifier
from pyrit.identifiers import AttackIdentifier, Identifier


class ComponentRole(Enum):
Expand Down Expand Up @@ -61,11 +61,11 @@ class ExecutionContext:
# The attack strategy class name (e.g., "PromptSendingAttack")
attack_strategy_name: Optional[str] = None

# The identifier from the attack strategy's get_identifier()
attack_identifier: Optional[Dict[str, Any]] = None
# The identifier for the attack strategy
attack_identifier: Optional[AttackIdentifier] = None

# The identifier from the component's get_identifier() (target, scorer, etc.)
component_identifier: Optional[Dict[str, Any]] = None
component_identifier: Optional[Identifier] = None

# The objective target conversation ID if available
objective_target_conversation_id: Optional[str] = None
Expand Down Expand Up @@ -192,8 +192,8 @@ def execution_context(
*,
component_role: ComponentRole,
attack_strategy_name: Optional[str] = None,
attack_identifier: Optional[Dict[str, Any]] = None,
component_identifier: Optional[Union[Identifier, Dict[str, Any]]] = None,
attack_identifier: Optional[AttackIdentifier] = None,
component_identifier: Optional[Identifier] = None,
objective_target_conversation_id: Optional[str] = None,
objective: Optional[str] = None,
) -> ExecutionContextManager:
Expand All @@ -203,9 +203,8 @@ def execution_context(
Args:
component_role: The role of the component being executed.
attack_strategy_name: The name of the attack strategy class.
attack_identifier: The identifier from attack.get_identifier().
attack_identifier: The attack identifier.
component_identifier: The identifier from component.get_identifier().
Can be an Identifier object or a dict (legacy format).
objective_target_conversation_id: The objective target conversation ID if available.
objective: The attack objective if available.

Expand All @@ -215,22 +214,15 @@ def execution_context(
# Extract endpoint and component_name from component_identifier if available
endpoint = None
component_name = None
component_id_dict: Optional[Dict[str, Any]] = None
if component_identifier:
if isinstance(component_identifier, Identifier):
endpoint = getattr(component_identifier, "endpoint", None)
component_name = component_identifier.class_name
component_id_dict = component_identifier.to_dict()
else:
endpoint = component_identifier.get("endpoint")
component_name = component_identifier.get("__type__")
component_id_dict = component_identifier
endpoint = getattr(component_identifier, "endpoint", None)
component_name = component_identifier.class_name

context = ExecutionContext(
component_role=component_role,
attack_strategy_name=attack_strategy_name,
attack_identifier=attack_identifier,
component_identifier=component_id_dict,
component_identifier=component_identifier,
objective_target_conversation_id=objective_target_conversation_id,
endpoint=endpoint,
component_name=component_name,
Expand Down
10 changes: 5 additions & 5 deletions pyrit/executor/attack/component/conversation_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
import logging
import uuid
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence

from pyrit.common.utils import combine_dict
from pyrit.executor.attack.component.prepended_conversation_config import (
PrependedConversationConfig,
)
from pyrit.identifiers import TargetIdentifier
from pyrit.identifiers import AttackIdentifier, TargetIdentifier
from pyrit.memory import CentralMemory
from pyrit.message_normalizer import ConversationContextNormalizer
from pyrit.models import ChatMessageRole, Message, MessagePiece, Score
Expand Down Expand Up @@ -54,8 +54,8 @@ def get_adversarial_chat_messages(
prepended_conversation: List[Message],
*,
adversarial_chat_conversation_id: str,
attack_identifier: Dict[str, str],
adversarial_chat_target_identifier: Union[TargetIdentifier, Dict[str, Any]],
attack_identifier: AttackIdentifier,
adversarial_chat_target_identifier: TargetIdentifier,
labels: Optional[Dict[str, str]] = None,
) -> List[Message]:
"""
Expand Down Expand Up @@ -183,7 +183,7 @@ class ConversationManager:
def __init__(
self,
*,
attack_identifier: Dict[str, str],
attack_identifier: AttackIdentifier,
prompt_normalizer: Optional[PromptNormalizer] = None,
):
"""
Expand Down
60 changes: 59 additions & 1 deletion pyrit/executor/attack/core/attack_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
StrategyEventData,
StrategyEventHandler,
)
from pyrit.identifiers import AttackIdentifier, Identifiable
from pyrit.memory.central_memory import CentralMemory
from pyrit.models import (
AttackOutcome,
Expand Down Expand Up @@ -224,7 +225,7 @@ def _log_attack_outcome(self, result: AttackResult) -> None:
self._logger.info(message)


class AttackStrategy(Strategy[AttackStrategyContextT, AttackStrategyResultT], ABC):
class AttackStrategy(Strategy[AttackStrategyContextT, AttackStrategyResultT], Identifiable[AttackIdentifier], ABC):
"""
Abstract base class for attack strategies.
Defines the interface for executing attacks and handling results.
Expand Down Expand Up @@ -258,6 +259,54 @@ def __init__(
)
self._objective_target = objective_target
self._params_type = params_type
# Guard so subclasses that set converters before calling super() aren't clobbered
if not hasattr(self, "_request_converters"):
self._request_converters: list[Any] = []
if not hasattr(self, "_response_converters"):
self._response_converters: list[Any] = []

def _build_identifier(self) -> AttackIdentifier:
"""
Build the typed identifier for this attack strategy.

Captures the objective target, optional scorer, and converter pipeline.
This is the *stable* strategy-level identifier that does not change
between calls to ``execute_async``.

Returns:
AttackIdentifier: The constructed identifier.
"""
# Get target identifier
objective_target_identifier = self.get_objective_target().get_identifier()

# Get scorer identifier if present
scorer_identifier = None
scoring_config = self.get_attack_scoring_config()
if scoring_config and scoring_config.objective_scorer:
scorer_identifier = scoring_config.objective_scorer.get_identifier()

# Get request converter identifiers if present
request_converter_ids = None
if self._request_converters:
request_converter_ids = [
converter.get_identifier() for config in self._request_converters for converter in config.converters
]

# Get response converter identifiers if present
response_converter_ids = None
if self._response_converters:
response_converter_ids = [
converter.get_identifier() for config in self._response_converters for converter in config.converters
]

return AttackIdentifier(
class_name=self.__class__.__name__,
class_module=self.__class__.__module__,
objective_target_identifier=objective_target_identifier,
objective_scorer_identifier=scorer_identifier,
request_converter_identifiers=request_converter_ids or None,
response_converter_identifiers=response_converter_ids or None,
)

@property
def params_type(self) -> Type[AttackParameters]:
Expand Down Expand Up @@ -291,6 +340,15 @@ def get_attack_scoring_config(self) -> Optional[AttackScoringConfig]:
"""
return None

def get_request_converters(self) -> list[Any]:
"""
Get request converter configurations used by this strategy.

Returns:
list[Any]: The list of request PromptConverterConfiguration objects.
"""
return self._request_converters

@overload
async def execute_async(
self,
Expand Down
5 changes: 3 additions & 2 deletions pyrit/executor/attack/multi_turn/tree_of_attacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
)
from pyrit.executor.attack.core.attack_strategy import AttackStrategy
from pyrit.executor.attack.multi_turn import MultiTurnAttackContext
from pyrit.identifiers import AttackIdentifier
from pyrit.memory import CentralMemory
from pyrit.models import (
AttackOutcome,
Expand Down Expand Up @@ -267,7 +268,7 @@ def __init__(
request_converters: List[PromptConverterConfiguration],
response_converters: List[PromptConverterConfiguration],
auxiliary_scorers: Optional[List[Scorer]],
attack_id: dict[str, str],
attack_id: AttackIdentifier,
attack_strategy_name: str,
memory_labels: Optional[dict[str, str]] = None,
parent_id: Optional[str] = None,
Expand All @@ -289,7 +290,7 @@ def __init__(
request_converters (List[PromptConverterConfiguration]): Converters for request normalization
response_converters (List[PromptConverterConfiguration]): Converters for response normalization
auxiliary_scorers (Optional[List[Scorer]]): Additional scorers for the response
attack_id (dict[str, str]): Unique identifier for the attack.
attack_id (AttackIdentifier): Unique identifier for the attack.
attack_strategy_name (str): Name of the attack strategy for execution context.
memory_labels (Optional[dict[str, str]]): Labels for memory storage.
parent_id (Optional[str]): ID of the parent node, if this is a child node
Expand Down
6 changes: 2 additions & 4 deletions pyrit/executor/attack/printer/console_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,10 +258,8 @@ async def print_summary_async(self, result: AttackResult) -> None:

# Extract attack type name from attack_identifier
attack_type = "Unknown"
if isinstance(result.attack_identifier, dict) and "__type__" in result.attack_identifier:
attack_type = result.attack_identifier["__type__"]
elif isinstance(result.attack_identifier, str):
attack_type = result.attack_identifier
if result.attack_identifier:
attack_type = result.attack_identifier.class_name

self._print_colored(f"{self._indent * 2}• Attack Type: {attack_type}", Fore.CYAN)
self._print_colored(f"{self._indent * 2}• Conversation ID: {result.conversation_id}", Fore.CYAN)
Expand Down
2 changes: 1 addition & 1 deletion pyrit/executor/attack/printer/markdown_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ async def _get_summary_markdown_async(self, result: AttackResult) -> List[str]:
markdown_lines.append("|-------|-------|")
markdown_lines.append(f"| **Objective** | {result.objective} |")

attack_type = result.attack_identifier.get("__type__", "Unknown")
attack_type = result.attack_identifier.class_name if result.attack_identifier else "Unknown"

markdown_lines.append(f"| **Attack Type** | `{attack_type}` |")
markdown_lines.append(f"| **Conversation ID** | `{result.conversation_id}` |")
Expand Down
6 changes: 5 additions & 1 deletion pyrit/executor/benchmark/fairness_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
PromptSendingAttack,
)
from pyrit.executor.core import Strategy, StrategyContext
from pyrit.identifiers import AttackIdentifier
from pyrit.memory import CentralMemory
from pyrit.models import (
AttackOutcome,
Expand Down Expand Up @@ -195,7 +196,10 @@ async def _perform_async(self, *, context: FairnessBiasBenchmarkContext) -> Atta
conversation_id=str(uuid.UUID(int=0)),
objective=context.generated_objective,
outcome=AttackOutcome.FAILURE,
attack_identifier=self.get_identifier(),
attack_identifier=AttackIdentifier(
class_name=self.__class__.__name__,
class_module=self.__class__.__module__,
),
)

return last_attack_result
Expand Down
13 changes: 0 additions & 13 deletions pyrit/executor/core/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,19 +176,6 @@ def __init__(
default_values.get_non_required_value(env_var_name="GLOBAL_MEMORY_LABELS") or "{}"
)

def get_identifier(self) -> Dict[str, str]:
"""
Get a serializable identifier for the strategy instance.

Returns:
dict: A dictionary containing the type, module, and unique ID of the strategy.
"""
return {
"__type__": self.__class__.__name__,
"__module__": self.__class__.__module__,
"id": str(self._id),
}

def _register_event_handler(self, event_handler: StrategyEventHandler[StrategyContextT, StrategyResultT]) -> None:
"""
Register an event handler for strategy events.
Expand Down
21 changes: 20 additions & 1 deletion pyrit/executor/promptgen/anecdoctor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
PromptGeneratorStrategyContext,
PromptGeneratorStrategyResult,
)
from pyrit.identifiers import AttackIdentifier, Identifiable
from pyrit.models import (
Message,
)
Expand Down Expand Up @@ -67,7 +68,10 @@ class AnecdoctorResult(PromptGeneratorStrategyResult):
generated_content: Message


class AnecdoctorGenerator(PromptGeneratorStrategy[AnecdoctorContext, AnecdoctorResult]):
class AnecdoctorGenerator(
PromptGeneratorStrategy[AnecdoctorContext, AnecdoctorResult],
Identifiable[AttackIdentifier],
):
"""
Implementation of the Anecdoctor prompt generation strategy.

Expand Down Expand Up @@ -131,6 +135,21 @@ def __init__(
else:
self._system_prompt_template = self._load_prompt_from_yaml(yaml_filename=self._ANECDOCTOR_USE_FEWSHOT_YAML)

def _build_identifier(self) -> AttackIdentifier:
"""
Build the typed identifier for this prompt generator.

Returns:
AttackIdentifier: The constructed identifier.
"""
objective_target_identifier = self._objective_target.get_identifier()

return AttackIdentifier(
class_name=self.__class__.__name__,
class_module=self.__class__.__module__,
objective_target_identifier=objective_target_identifier,
)

def _validate_context(self, *, context: AnecdoctorContext) -> None:
"""
Validate the context before executing the prompt generation.
Expand Down
21 changes: 20 additions & 1 deletion pyrit/executor/promptgen/fuzzer/fuzzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
PromptGeneratorStrategyResult,
)
from pyrit.executor.promptgen.fuzzer.fuzzer_converter_base import FuzzerConverter
from pyrit.identifiers import AttackIdentifier, Identifiable
from pyrit.memory import CentralMemory
from pyrit.models import (
Message,
Expand Down Expand Up @@ -492,7 +493,10 @@ def print_templates_only(self, result: FuzzerResult) -> None:
print("No successful templates found.")


class FuzzerGenerator(PromptGeneratorStrategy[FuzzerContext, FuzzerResult]):
class FuzzerGenerator(
PromptGeneratorStrategy[FuzzerContext, FuzzerResult],
Identifiable[AttackIdentifier],
):
"""
Implementation of the Fuzzer prompt generation strategy using Monte Carlo Tree Search (MCTS).
Expand Down Expand Up @@ -675,6 +679,21 @@ def __init__(
# Initialize utilities
self._prompt_normalizer = prompt_normalizer or PromptNormalizer()

def _build_identifier(self) -> AttackIdentifier:
"""
Build the typed identifier for this prompt generator.
Returns:
AttackIdentifier: The constructed identifier.
"""
objective_target_identifier = self._objective_target.get_identifier()

return AttackIdentifier(
class_name=self.__class__.__name__,
class_module=self.__class__.__module__,
objective_target_identifier=objective_target_identifier,
)

def _validate_inputs(
self,
*,
Expand Down
Loading