diff --git a/doc/api.rst b/doc/api.rst index 780ae04206..523efb0ad3 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -272,6 +272,7 @@ API Reference :toctree: _autosummary/ class_name_to_snake_case + AttackIdentifier ConverterIdentifier Identifiable Identifier diff --git a/pyrit/analytics/result_analysis.py b/pyrit/analytics/result_analysis.py index 7db3c02e87..a6e260af39 100644 --- a/pyrit/analytics/result_analysis.py +++ b/pyrit/analytics/result_analysis.py @@ -62,7 +62,7 @@ def analyze_results(attack_results: list[AttackResult]) -> dict[str, AttackStats raise TypeError(f"Expected AttackResult, got {type(attack).__name__}: {attack!r}") outcome = attack.outcome - attack_type = attack.attack_identifier.get("type", "unknown") + attack_type = attack.attack_identifier.class_name if attack.attack_identifier else "unknown" if outcome == AttackOutcome.SUCCESS: overall_counts["successes"] += 1 diff --git a/pyrit/exceptions/exception_context.py b/pyrit/exceptions/exception_context.py index bce5b92382..b88c92a017 100644 --- a/pyrit/exceptions/exception_context.py +++ b/pyrit/exceptions/exception_context.py @@ -13,9 +13,9 @@ from contextvars import ContextVar from dataclasses import dataclass, field from enum import Enum -from typing import Any, Dict, Optional, Union +from typing import Any, Optional -from pyrit.identifiers import Identifier +from pyrit.identifiers import AttackIdentifier, Identifier class ComponentRole(Enum): @@ -61,11 +61,11 @@ class ExecutionContext: # The attack strategy class name (e.g., "PromptSendingAttack") attack_strategy_name: Optional[str] = None - # The identifier from the attack strategy's get_identifier() - attack_identifier: Optional[Dict[str, Any]] = None + # The identifier for the attack strategy + attack_identifier: Optional[AttackIdentifier] = None # The identifier from the component's get_identifier() (target, scorer, etc.) - component_identifier: Optional[Dict[str, Any]] = None + component_identifier: Optional[Identifier] = None # The objective target conversation ID if available objective_target_conversation_id: Optional[str] = None @@ -192,8 +192,8 @@ def execution_context( *, component_role: ComponentRole, attack_strategy_name: Optional[str] = None, - attack_identifier: Optional[Dict[str, Any]] = None, - component_identifier: Optional[Union[Identifier, Dict[str, Any]]] = None, + attack_identifier: Optional[AttackIdentifier] = None, + component_identifier: Optional[Identifier] = None, objective_target_conversation_id: Optional[str] = None, objective: Optional[str] = None, ) -> ExecutionContextManager: @@ -203,9 +203,8 @@ def execution_context( Args: component_role: The role of the component being executed. attack_strategy_name: The name of the attack strategy class. - attack_identifier: The identifier from attack.get_identifier(). + attack_identifier: The attack identifier. component_identifier: The identifier from component.get_identifier(). - Can be an Identifier object or a dict (legacy format). objective_target_conversation_id: The objective target conversation ID if available. objective: The attack objective if available. @@ -215,22 +214,15 @@ def execution_context( # Extract endpoint and component_name from component_identifier if available endpoint = None component_name = None - component_id_dict: Optional[Dict[str, Any]] = None if component_identifier: - if isinstance(component_identifier, Identifier): - endpoint = getattr(component_identifier, "endpoint", None) - component_name = component_identifier.class_name - component_id_dict = component_identifier.to_dict() - else: - endpoint = component_identifier.get("endpoint") - component_name = component_identifier.get("__type__") - component_id_dict = component_identifier + endpoint = getattr(component_identifier, "endpoint", None) + component_name = component_identifier.class_name context = ExecutionContext( component_role=component_role, attack_strategy_name=attack_strategy_name, attack_identifier=attack_identifier, - component_identifier=component_id_dict, + component_identifier=component_identifier, objective_target_conversation_id=objective_target_conversation_id, endpoint=endpoint, component_name=component_name, diff --git a/pyrit/executor/attack/component/conversation_manager.py b/pyrit/executor/attack/component/conversation_manager.py index 75228c0f2d..6b6e648da3 100644 --- a/pyrit/executor/attack/component/conversation_manager.py +++ b/pyrit/executor/attack/component/conversation_manager.py @@ -4,13 +4,13 @@ import logging import uuid from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence from pyrit.common.utils import combine_dict from pyrit.executor.attack.component.prepended_conversation_config import ( PrependedConversationConfig, ) -from pyrit.identifiers import TargetIdentifier +from pyrit.identifiers import AttackIdentifier, TargetIdentifier from pyrit.memory import CentralMemory from pyrit.message_normalizer import ConversationContextNormalizer from pyrit.models import ChatMessageRole, Message, MessagePiece, Score @@ -54,8 +54,8 @@ def get_adversarial_chat_messages( prepended_conversation: List[Message], *, adversarial_chat_conversation_id: str, - attack_identifier: Dict[str, str], - adversarial_chat_target_identifier: Union[TargetIdentifier, Dict[str, Any]], + attack_identifier: AttackIdentifier, + adversarial_chat_target_identifier: TargetIdentifier, labels: Optional[Dict[str, str]] = None, ) -> List[Message]: """ @@ -183,7 +183,7 @@ class ConversationManager: def __init__( self, *, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, prompt_normalizer: Optional[PromptNormalizer] = None, ): """ diff --git a/pyrit/executor/attack/core/attack_strategy.py b/pyrit/executor/attack/core/attack_strategy.py index 98fc4c6abe..a442d81548 100644 --- a/pyrit/executor/attack/core/attack_strategy.py +++ b/pyrit/executor/attack/core/attack_strategy.py @@ -20,6 +20,7 @@ StrategyEventData, StrategyEventHandler, ) +from pyrit.identifiers import AttackIdentifier, Identifiable from pyrit.memory.central_memory import CentralMemory from pyrit.models import ( AttackOutcome, @@ -224,7 +225,7 @@ def _log_attack_outcome(self, result: AttackResult) -> None: self._logger.info(message) -class AttackStrategy(Strategy[AttackStrategyContextT, AttackStrategyResultT], ABC): +class AttackStrategy(Strategy[AttackStrategyContextT, AttackStrategyResultT], Identifiable[AttackIdentifier], ABC): """ Abstract base class for attack strategies. Defines the interface for executing attacks and handling results. @@ -258,6 +259,54 @@ def __init__( ) self._objective_target = objective_target self._params_type = params_type + # Guard so subclasses that set converters before calling super() aren't clobbered + if not hasattr(self, "_request_converters"): + self._request_converters: list[Any] = [] + if not hasattr(self, "_response_converters"): + self._response_converters: list[Any] = [] + + def _build_identifier(self) -> AttackIdentifier: + """ + Build the typed identifier for this attack strategy. + + Captures the objective target, optional scorer, and converter pipeline. + This is the *stable* strategy-level identifier that does not change + between calls to ``execute_async``. + + Returns: + AttackIdentifier: The constructed identifier. + """ + # Get target identifier + objective_target_identifier = self.get_objective_target().get_identifier() + + # Get scorer identifier if present + scorer_identifier = None + scoring_config = self.get_attack_scoring_config() + if scoring_config and scoring_config.objective_scorer: + scorer_identifier = scoring_config.objective_scorer.get_identifier() + + # Get request converter identifiers if present + request_converter_ids = None + if self._request_converters: + request_converter_ids = [ + converter.get_identifier() for config in self._request_converters for converter in config.converters + ] + + # Get response converter identifiers if present + response_converter_ids = None + if self._response_converters: + response_converter_ids = [ + converter.get_identifier() for config in self._response_converters for converter in config.converters + ] + + return AttackIdentifier( + class_name=self.__class__.__name__, + class_module=self.__class__.__module__, + objective_target_identifier=objective_target_identifier, + objective_scorer_identifier=scorer_identifier, + request_converter_identifiers=request_converter_ids or None, + response_converter_identifiers=response_converter_ids or None, + ) @property def params_type(self) -> Type[AttackParameters]: @@ -291,6 +340,15 @@ def get_attack_scoring_config(self) -> Optional[AttackScoringConfig]: """ return None + def get_request_converters(self) -> list[Any]: + """ + Get request converter configurations used by this strategy. + + Returns: + list[Any]: The list of request PromptConverterConfiguration objects. + """ + return self._request_converters + @overload async def execute_async( self, diff --git a/pyrit/executor/attack/multi_turn/tree_of_attacks.py b/pyrit/executor/attack/multi_turn/tree_of_attacks.py index cfad81d3ab..48d6025a47 100644 --- a/pyrit/executor/attack/multi_turn/tree_of_attacks.py +++ b/pyrit/executor/attack/multi_turn/tree_of_attacks.py @@ -37,6 +37,7 @@ ) from pyrit.executor.attack.core.attack_strategy import AttackStrategy from pyrit.executor.attack.multi_turn import MultiTurnAttackContext +from pyrit.identifiers import AttackIdentifier from pyrit.memory import CentralMemory from pyrit.models import ( AttackOutcome, @@ -267,7 +268,7 @@ def __init__( request_converters: List[PromptConverterConfiguration], response_converters: List[PromptConverterConfiguration], auxiliary_scorers: Optional[List[Scorer]], - attack_id: dict[str, str], + attack_id: AttackIdentifier, attack_strategy_name: str, memory_labels: Optional[dict[str, str]] = None, parent_id: Optional[str] = None, @@ -289,7 +290,7 @@ def __init__( request_converters (List[PromptConverterConfiguration]): Converters for request normalization response_converters (List[PromptConverterConfiguration]): Converters for response normalization auxiliary_scorers (Optional[List[Scorer]]): Additional scorers for the response - attack_id (dict[str, str]): Unique identifier for the attack. + attack_id (AttackIdentifier): Unique identifier for the attack. attack_strategy_name (str): Name of the attack strategy for execution context. memory_labels (Optional[dict[str, str]]): Labels for memory storage. parent_id (Optional[str]): ID of the parent node, if this is a child node diff --git a/pyrit/executor/attack/printer/console_printer.py b/pyrit/executor/attack/printer/console_printer.py index 7d7110d0ae..c71b40b31c 100644 --- a/pyrit/executor/attack/printer/console_printer.py +++ b/pyrit/executor/attack/printer/console_printer.py @@ -258,10 +258,8 @@ async def print_summary_async(self, result: AttackResult) -> None: # Extract attack type name from attack_identifier attack_type = "Unknown" - if isinstance(result.attack_identifier, dict) and "__type__" in result.attack_identifier: - attack_type = result.attack_identifier["__type__"] - elif isinstance(result.attack_identifier, str): - attack_type = result.attack_identifier + if result.attack_identifier: + attack_type = result.attack_identifier.class_name self._print_colored(f"{self._indent * 2}• Attack Type: {attack_type}", Fore.CYAN) self._print_colored(f"{self._indent * 2}• Conversation ID: {result.conversation_id}", Fore.CYAN) diff --git a/pyrit/executor/attack/printer/markdown_printer.py b/pyrit/executor/attack/printer/markdown_printer.py index 27838a46c2..e62a80cabc 100644 --- a/pyrit/executor/attack/printer/markdown_printer.py +++ b/pyrit/executor/attack/printer/markdown_printer.py @@ -493,7 +493,7 @@ async def _get_summary_markdown_async(self, result: AttackResult) -> List[str]: markdown_lines.append("|-------|-------|") markdown_lines.append(f"| **Objective** | {result.objective} |") - attack_type = result.attack_identifier.get("__type__", "Unknown") + attack_type = result.attack_identifier.class_name if result.attack_identifier else "Unknown" markdown_lines.append(f"| **Attack Type** | `{attack_type}` |") markdown_lines.append(f"| **Conversation ID** | `{result.conversation_id}` |") diff --git a/pyrit/executor/benchmark/fairness_bias.py b/pyrit/executor/benchmark/fairness_bias.py index b894757eba..3e6bc8b785 100644 --- a/pyrit/executor/benchmark/fairness_bias.py +++ b/pyrit/executor/benchmark/fairness_bias.py @@ -17,6 +17,7 @@ PromptSendingAttack, ) from pyrit.executor.core import Strategy, StrategyContext +from pyrit.identifiers import AttackIdentifier from pyrit.memory import CentralMemory from pyrit.models import ( AttackOutcome, @@ -195,7 +196,10 @@ async def _perform_async(self, *, context: FairnessBiasBenchmarkContext) -> Atta conversation_id=str(uuid.UUID(int=0)), objective=context.generated_objective, outcome=AttackOutcome.FAILURE, - attack_identifier=self.get_identifier(), + attack_identifier=AttackIdentifier( + class_name=self.__class__.__name__, + class_module=self.__class__.__module__, + ), ) return last_attack_result diff --git a/pyrit/executor/core/strategy.py b/pyrit/executor/core/strategy.py index 7fc48a4173..1ef0f94cff 100644 --- a/pyrit/executor/core/strategy.py +++ b/pyrit/executor/core/strategy.py @@ -176,19 +176,6 @@ def __init__( default_values.get_non_required_value(env_var_name="GLOBAL_MEMORY_LABELS") or "{}" ) - def get_identifier(self) -> Dict[str, str]: - """ - Get a serializable identifier for the strategy instance. - - Returns: - dict: A dictionary containing the type, module, and unique ID of the strategy. - """ - return { - "__type__": self.__class__.__name__, - "__module__": self.__class__.__module__, - "id": str(self._id), - } - def _register_event_handler(self, event_handler: StrategyEventHandler[StrategyContextT, StrategyResultT]) -> None: """ Register an event handler for strategy events. diff --git a/pyrit/executor/promptgen/anecdoctor.py b/pyrit/executor/promptgen/anecdoctor.py index b627e477d8..82ecb25e5f 100644 --- a/pyrit/executor/promptgen/anecdoctor.py +++ b/pyrit/executor/promptgen/anecdoctor.py @@ -19,6 +19,7 @@ PromptGeneratorStrategyContext, PromptGeneratorStrategyResult, ) +from pyrit.identifiers import AttackIdentifier, Identifiable from pyrit.models import ( Message, ) @@ -67,7 +68,10 @@ class AnecdoctorResult(PromptGeneratorStrategyResult): generated_content: Message -class AnecdoctorGenerator(PromptGeneratorStrategy[AnecdoctorContext, AnecdoctorResult]): +class AnecdoctorGenerator( + PromptGeneratorStrategy[AnecdoctorContext, AnecdoctorResult], + Identifiable[AttackIdentifier], +): """ Implementation of the Anecdoctor prompt generation strategy. @@ -131,6 +135,21 @@ def __init__( else: self._system_prompt_template = self._load_prompt_from_yaml(yaml_filename=self._ANECDOCTOR_USE_FEWSHOT_YAML) + def _build_identifier(self) -> AttackIdentifier: + """ + Build the typed identifier for this prompt generator. + + Returns: + AttackIdentifier: The constructed identifier. + """ + objective_target_identifier = self._objective_target.get_identifier() + + return AttackIdentifier( + class_name=self.__class__.__name__, + class_module=self.__class__.__module__, + objective_target_identifier=objective_target_identifier, + ) + def _validate_context(self, *, context: AnecdoctorContext) -> None: """ Validate the context before executing the prompt generation. diff --git a/pyrit/executor/promptgen/fuzzer/fuzzer.py b/pyrit/executor/promptgen/fuzzer/fuzzer.py index 7949a398a9..93360a16dd 100644 --- a/pyrit/executor/promptgen/fuzzer/fuzzer.py +++ b/pyrit/executor/promptgen/fuzzer/fuzzer.py @@ -24,6 +24,7 @@ PromptGeneratorStrategyResult, ) from pyrit.executor.promptgen.fuzzer.fuzzer_converter_base import FuzzerConverter +from pyrit.identifiers import AttackIdentifier, Identifiable from pyrit.memory import CentralMemory from pyrit.models import ( Message, @@ -492,7 +493,10 @@ def print_templates_only(self, result: FuzzerResult) -> None: print("No successful templates found.") -class FuzzerGenerator(PromptGeneratorStrategy[FuzzerContext, FuzzerResult]): +class FuzzerGenerator( + PromptGeneratorStrategy[FuzzerContext, FuzzerResult], + Identifiable[AttackIdentifier], +): """ Implementation of the Fuzzer prompt generation strategy using Monte Carlo Tree Search (MCTS). @@ -675,6 +679,21 @@ def __init__( # Initialize utilities self._prompt_normalizer = prompt_normalizer or PromptNormalizer() + def _build_identifier(self) -> AttackIdentifier: + """ + Build the typed identifier for this prompt generator. + + Returns: + AttackIdentifier: The constructed identifier. + """ + objective_target_identifier = self._objective_target.get_identifier() + + return AttackIdentifier( + class_name=self.__class__.__name__, + class_module=self.__class__.__module__, + objective_target_identifier=objective_target_identifier, + ) + def _validate_inputs( self, *, diff --git a/pyrit/executor/workflow/xpia.py b/pyrit/executor/workflow/xpia.py index 1579f7e868..2f5baf83f4 100644 --- a/pyrit/executor/workflow/xpia.py +++ b/pyrit/executor/workflow/xpia.py @@ -14,6 +14,7 @@ WorkflowResult, WorkflowStrategy, ) +from pyrit.identifiers import AttackIdentifier, Identifiable from pyrit.memory import CentralMemory from pyrit.models import ( Message, @@ -127,7 +128,7 @@ def status(self) -> XPIAStatus: return XPIAStatus.SUCCESS if self.success else XPIAStatus.FAILURE -class XPIAWorkflow(WorkflowStrategy[XPIAContext, XPIAResult]): +class XPIAWorkflow(WorkflowStrategy[XPIAContext, XPIAResult], Identifiable[AttackIdentifier]): """ Implementation of Cross-Domain Prompt Injection Attack (XPIA) workflow. @@ -174,6 +175,26 @@ def __init__( self._prompt_normalizer = prompt_normalizer or PromptNormalizer() self._memory = CentralMemory.get_memory_instance() + def _build_identifier(self) -> AttackIdentifier: + """ + Build the typed identifier for this XPIA workflow. + + Returns: + AttackIdentifier: The constructed identifier. + """ + objective_target_identifier = self._attack_setup_target.get_identifier() + + scorer_identifier = None + if self._scorer: + scorer_identifier = self._scorer.get_identifier() + + return AttackIdentifier( + class_name=self.__class__.__name__, + class_module=self.__class__.__module__, + objective_target_identifier=objective_target_identifier, + objective_scorer_identifier=scorer_identifier, + ) + def _validate_context(self, *, context: XPIAContext) -> None: """ Validate the XPIA context before execution. diff --git a/pyrit/identifiers/__init__.py b/pyrit/identifiers/__init__.py index 30c501894c..0c8fe13cac 100644 --- a/pyrit/identifiers/__init__.py +++ b/pyrit/identifiers/__init__.py @@ -3,12 +3,13 @@ """Identifiers module for PyRIT components.""" +from pyrit.identifiers.attack_identifier import AttackIdentifier from pyrit.identifiers.class_name_utils import ( class_name_to_snake_case, snake_case_to_class_name, ) from pyrit.identifiers.converter_identifier import ConverterIdentifier -from pyrit.identifiers.identifiable import Identifiable, IdentifierT, LegacyIdentifiable +from pyrit.identifiers.identifiable import Identifiable, IdentifierT from pyrit.identifiers.identifier import ( Identifier, IdentifierType, @@ -17,13 +18,13 @@ from pyrit.identifiers.target_identifier import TargetIdentifier __all__ = [ + "AttackIdentifier", "class_name_to_snake_case", "ConverterIdentifier", "Identifiable", "Identifier", "IdentifierT", "IdentifierType", - "LegacyIdentifiable", "ScorerIdentifier", "snake_case_to_class_name", "TargetIdentifier", diff --git a/pyrit/identifiers/attack_identifier.py b/pyrit/identifiers/attack_identifier.py new file mode 100644 index 0000000000..1ab4c4e8f5 --- /dev/null +++ b/pyrit/identifiers/attack_identifier.py @@ -0,0 +1,67 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Type + +from pyrit.identifiers.converter_identifier import ConverterIdentifier +from pyrit.identifiers.identifier import Identifier +from pyrit.identifiers.scorer_identifier import ScorerIdentifier +from pyrit.identifiers.target_identifier import TargetIdentifier + + +@dataclass(frozen=True) +class AttackIdentifier(Identifier): + """ + Typed identifier for an attack strategy instance. + + Captures the configuration that makes one attack strategy meaningfully + different from another: the objective target, optional scorer, and converter + pipeline. These do not change between calls to ``execute_async``. + """ + + objective_target_identifier: Optional[TargetIdentifier] = None + objective_scorer_identifier: Optional[ScorerIdentifier] = None + request_converter_identifiers: Optional[List[ConverterIdentifier]] = None + response_converter_identifiers: Optional[List[ConverterIdentifier]] = None + + # Additional attack-specific params for subclass flexibility + attack_specific_params: Optional[Dict[str, Any]] = None + + @classmethod + def from_dict(cls: Type["AttackIdentifier"], data: dict[str, Any]) -> "AttackIdentifier": + """ + Deserialize an AttackIdentifier from a dictionary. + + Handles nested sub-identifiers (target, scorer, converters) by + recursively calling their own ``from_dict`` implementations. + + Args: + data: Dictionary containing the serialized identifier fields. + + Returns: + AttackIdentifier: The deserialized identifier. + """ + data = dict(data) + + if "objective_target_identifier" in data and isinstance(data["objective_target_identifier"], dict): + data["objective_target_identifier"] = TargetIdentifier.from_dict(data["objective_target_identifier"]) + + if "objective_scorer_identifier" in data and isinstance(data["objective_scorer_identifier"], dict): + data["objective_scorer_identifier"] = ScorerIdentifier.from_dict(data["objective_scorer_identifier"]) + + if "request_converter_identifiers" in data and data["request_converter_identifiers"] is not None: + data["request_converter_identifiers"] = [ + ConverterIdentifier.from_dict(c) if isinstance(c, dict) else c + for c in data["request_converter_identifiers"] + ] + + if "response_converter_identifiers" in data and data["response_converter_identifiers"] is not None: + data["response_converter_identifiers"] = [ + ConverterIdentifier.from_dict(c) if isinstance(c, dict) else c + for c in data["response_converter_identifiers"] + ] + + return super().from_dict(data) diff --git a/pyrit/identifiers/converter_identifier.py b/pyrit/identifiers/converter_identifier.py index 777672a932..3c86a0daca 100644 --- a/pyrit/identifiers/converter_identifier.py +++ b/pyrit/identifiers/converter_identifier.py @@ -4,7 +4,7 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Tuple, Type, cast +from typing import Any, Dict, List, Optional, Tuple, Type from pyrit.identifiers.identifier import Identifier @@ -73,5 +73,4 @@ def from_dict(cls: Type["ConverterIdentifier"], data: dict[str, Any]) -> "Conver data["supported_output_types"] = () # Delegate to parent class for standard processing - result = Identifier.from_dict.__func__(cls, data) # type: ignore[attr-defined] - return cast(ConverterIdentifier, result) + return super().from_dict(data) diff --git a/pyrit/identifiers/identifiable.py b/pyrit/identifiers/identifiable.py index 94108e6eaf..7ef54c7f19 100644 --- a/pyrit/identifiers/identifiable.py +++ b/pyrit/identifiers/identifiable.py @@ -12,24 +12,6 @@ IdentifierT = TypeVar("IdentifierT", bound=Identifier) -class LegacyIdentifiable(ABC): - """ - Deprecated legacy interface for objects that can provide an identifier dictionary. - - This interface will eventually be replaced by Identifier dataclass. - Classes implementing this interface should return a dict describing their identity. - """ - - @abstractmethod - def get_identifier(self) -> dict[str, str]: - """Return a dictionary describing this object's identity.""" - pass - - def __str__(self) -> str: - """Return string representation of the identifier.""" - return f"{self.get_identifier()}" - - class Identifiable(ABC, Generic[IdentifierT]): """ Abstract base class for objects that can provide a typed identifier. diff --git a/pyrit/identifiers/identifier.py b/pyrit/identifiers/identifier.py index 64a754f0ae..ff7ee8832b 100644 --- a/pyrit/identifiers/identifier.py +++ b/pyrit/identifiers/identifier.py @@ -135,12 +135,12 @@ class Identifier: All component-specific identifier types should extend this with additional fields. """ - class_name: str # The actual class name, equivalent to __type__ (e.g., "SelfAskRefusalScorer") - class_module: str # The module path, equivalent to __module__ (e.g., "pyrit.score.self_ask_refusal_scorer") + class_name: str + class_module: str # Fields excluded from storage (STORAGE auto-expands to include HASH) - class_description: str = field(metadata={_EXCLUDE: {_ExcludeFrom.STORAGE}}) - identifier_type: IdentifierType = field(metadata={_EXCLUDE: {_ExcludeFrom.STORAGE}}) + class_description: str = field(default="", metadata={_EXCLUDE: {_ExcludeFrom.STORAGE}}) + identifier_type: IdentifierType = field(default="instance", metadata={_EXCLUDE: {_ExcludeFrom.STORAGE}}) # Auto-computed fields snake_class_name: str = field(init=False, metadata={_EXCLUDE: {_ExcludeFrom.STORAGE}}) diff --git a/pyrit/identifiers/scorer_identifier.py b/pyrit/identifiers/scorer_identifier.py index d467504fe4..8dac5fc676 100644 --- a/pyrit/identifiers/scorer_identifier.py +++ b/pyrit/identifiers/scorer_identifier.py @@ -4,7 +4,7 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Type, cast +from typing import Any, Dict, List, Optional, Type from pyrit.identifiers.identifier import _MAX_STORAGE_LENGTH, Identifier from pyrit.models.score import ScoreType @@ -64,5 +64,4 @@ def from_dict(cls: Type["ScorerIdentifier"], data: dict[str, Any]) -> "ScorerIde ] # Delegate to parent class for standard processing - result = Identifier.from_dict.__func__(cls, data) # type: ignore[attr-defined] - return cast(ScorerIdentifier, result) + return super().from_dict(data) diff --git a/pyrit/identifiers/target_identifier.py b/pyrit/identifiers/target_identifier.py index b8924fb0c0..9d31170182 100644 --- a/pyrit/identifiers/target_identifier.py +++ b/pyrit/identifiers/target_identifier.py @@ -4,7 +4,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Dict, Optional, Type, cast +from typing import Any, Dict, Optional, Type from pyrit.identifiers.identifier import Identifier @@ -54,5 +54,4 @@ def from_dict(cls: Type["TargetIdentifier"], data: dict[str, Any]) -> "TargetIde TargetIdentifier: A new TargetIdentifier instance. """ # Delegate to parent class for standard processing - result = Identifier.from_dict.__func__(cls, data) # type: ignore[attr-defined] - return cast(TargetIdentifier, result) + return super().from_dict(data) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 5e3817c481..f53c4b6cd4 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -257,7 +257,7 @@ def _get_message_pieces_attack_conditions(self, *, attack_id: str) -> Any: Returns: Any: SQLAlchemy text condition with bound parameter. """ - return text("ISJSON(attack_identifier) = 1 AND JSON_VALUE(attack_identifier, '$.id') = :json_id").bindparams( + return text("ISJSON(attack_identifier) = 1 AND JSON_VALUE(attack_identifier, '$.hash') = :json_id").bindparams( json_id=str(attack_id) ) diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index bc8b9f5f36..5ae66e4d6c 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -32,7 +32,7 @@ import pyrit from pyrit.common.utils import to_sha256 -from pyrit.identifiers import ConverterIdentifier, ScorerIdentifier, TargetIdentifier +from pyrit.identifiers import AttackIdentifier, ConverterIdentifier, ScorerIdentifier, TargetIdentifier from pyrit.models import ( AttackOutcome, AttackResult, @@ -220,7 +220,7 @@ def __init__(self, *, entry: MessagePiece): self.prompt_target_identifier = ( entry.prompt_target_identifier.to_dict() if entry.prompt_target_identifier else {} ) - self.attack_identifier = entry.attack_identifier + self.attack_identifier = entry.attack_identifier.to_dict() if entry.attack_identifier else {} self.original_value = entry.original_value self.original_value_data_type = entry.original_value_data_type # type: ignore @@ -256,6 +256,11 @@ def get_message_piece(self) -> MessagePiece: if self.prompt_target_identifier: target_id = TargetIdentifier.from_dict({**self.prompt_target_identifier, "pyrit_version": stored_version}) + # Reconstruct AttackIdentifier with the stored pyrit_version + attack_id: Optional[AttackIdentifier] = None + if self.attack_identifier: + attack_id = AttackIdentifier.from_dict({**self.attack_identifier, "pyrit_version": stored_version}) + message_piece = MessagePiece( role=self.role, original_value=self.original_value, @@ -270,7 +275,7 @@ def get_message_piece(self) -> MessagePiece: targeted_harm_categories=self.targeted_harm_categories, converter_identifiers=converter_ids, prompt_target_identifier=target_id, - attack_identifier=self.attack_identifier, + attack_identifier=attack_id, original_value_data_type=self.original_value_data_type, converted_value_data_type=self.converted_value_data_type, response_error=self.response_error, @@ -732,7 +737,7 @@ def __init__(self, *, entry: AttackResult): self.id = uuid.uuid4() self.conversation_id = entry.conversation_id self.objective = entry.objective - self.attack_identifier = entry.attack_identifier + self.attack_identifier = entry.attack_identifier.to_dict() if entry.attack_identifier else {} self.objective_sha256 = to_sha256(entry.objective) # Use helper method for UUID conversions @@ -833,7 +838,7 @@ def get_attack_result(self) -> AttackResult: return AttackResult( conversation_id=self.conversation_id, objective=self.objective, - attack_identifier=self.attack_identifier, + attack_identifier=AttackIdentifier.from_dict(self.attack_identifier) if self.attack_identifier else None, last_response=self.last_response.get_message_piece() if self.last_response else None, last_score=self.last_score.get_score() if self.last_score else None, executed_turns=self.executed_turns, diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index 30a251cf72..bca6a21817 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -163,7 +163,7 @@ def _get_message_pieces_attack_conditions(self, *, attack_id: str) -> Any: Returns: Any: A SQLAlchemy text condition with bound parameters. """ - return text("JSON_EXTRACT(attack_identifier, '$.id') = :attack_id").bindparams(attack_id=str(attack_id)) + return text("JSON_EXTRACT(attack_identifier, '$.hash') = :attack_id").bindparams(attack_id=str(attack_id)) def _get_seed_metadata_conditions(self, *, metadata: dict[str, Union[str, int]]) -> Any: """ diff --git a/pyrit/models/attack_result.py b/pyrit/models/attack_result.py index dc9e3a1a9e..7f92612f59 100644 --- a/pyrit/models/attack_result.py +++ b/pyrit/models/attack_result.py @@ -7,6 +7,7 @@ from enum import Enum from typing import Any, Dict, Optional, TypeVar +from pyrit.identifiers import AttackIdentifier from pyrit.models.conversation_reference import ConversationReference, ConversationType from pyrit.models.message_piece import MessagePiece from pyrit.models.score import Score @@ -41,8 +42,8 @@ class AttackResult(StrategyResult): # Natural-language description of the attacker's objective objective: str - # Identifier of the attack (e.g., name, module) - attack_identifier: dict[str, str] + # Identifier of the attack strategy that produced this result + attack_identifier: Optional[AttackIdentifier] = None # Evidence # Model response generated in the final turn of the attack diff --git a/pyrit/models/message_piece.py b/pyrit/models/message_piece.py index 460529475f..f07b045318 100644 --- a/pyrit/models/message_piece.py +++ b/pyrit/models/message_piece.py @@ -8,7 +8,7 @@ from typing import Any, Dict, List, Literal, Optional, Union, get_args from uuid import uuid4 -from pyrit.identifiers import ConverterIdentifier, ScorerIdentifier, TargetIdentifier +from pyrit.identifiers import AttackIdentifier, ConverterIdentifier, ScorerIdentifier, TargetIdentifier from pyrit.models.literals import ChatMessageRole, PromptDataType, PromptResponseError from pyrit.models.score import Score @@ -39,7 +39,7 @@ def __init__( prompt_metadata: Optional[Dict[str, Union[str, int]]] = None, converter_identifiers: Optional[List[Union[ConverterIdentifier, Dict[str, str]]]] = None, prompt_target_identifier: Optional[Union[TargetIdentifier, Dict[str, Any]]] = None, - attack_identifier: Optional[Dict[str, str]] = None, + attack_identifier: Optional[Union[AttackIdentifier, Dict[str, str]]] = None, scorer_identifier: Optional[Union[ScorerIdentifier, Dict[str, str]]] = None, original_value_data_type: PromptDataType = "text", converted_value_data_type: Optional[PromptDataType] = None, @@ -118,7 +118,10 @@ def __init__( TargetIdentifier.normalize(prompt_target_identifier) if prompt_target_identifier else None ) - self.attack_identifier = attack_identifier or {} + # Handle attack_identifier: normalize to AttackIdentifier (handles dict with deprecation warning) + self.attack_identifier: Optional[AttackIdentifier] = ( + AttackIdentifier.normalize(attack_identifier) if attack_identifier else None + ) # Handle scorer_identifier: normalize to ScorerIdentifier (handles dict with deprecation warning) self.scorer_identifier: Optional[ScorerIdentifier] = ( @@ -283,7 +286,7 @@ def to_dict(self) -> dict[str, object]: "prompt_target_identifier": ( self.prompt_target_identifier.to_dict() if self.prompt_target_identifier else None ), - "attack_identifier": self.attack_identifier, + "attack_identifier": self.attack_identifier.to_dict() if self.attack_identifier else None, "scorer_identifier": self.scorer_identifier.to_dict() if self.scorer_identifier else None, "original_value_data_type": self.original_value_data_type, "original_value": self.original_value, diff --git a/pyrit/prompt_converter/prompt_converter.py b/pyrit/prompt_converter/prompt_converter.py index 756d3972e6..9af7ef9218 100644 --- a/pyrit/prompt_converter/prompt_converter.py +++ b/pyrit/prompt_converter/prompt_converter.py @@ -221,7 +221,6 @@ def _create_identifier( class_name=self.__class__.__name__, class_module=self.__class__.__module__, class_description=self.__class__.__doc__ or "", - identifier_type="instance", supported_input_types=self.SUPPORTED_INPUT_TYPES, supported_output_types=self.SUPPORTED_OUTPUT_TYPES, sub_identifier=sub_identifier, diff --git a/pyrit/prompt_normalizer/prompt_normalizer.py b/pyrit/prompt_normalizer/prompt_normalizer.py index 8fd8aaaba0..00f2f0f578 100644 --- a/pyrit/prompt_normalizer/prompt_normalizer.py +++ b/pyrit/prompt_normalizer/prompt_normalizer.py @@ -14,6 +14,7 @@ execution_context, get_execution_context, ) +from pyrit.identifiers import AttackIdentifier from pyrit.memory import CentralMemory, MemoryInterface from pyrit.models import ( Message, @@ -53,7 +54,7 @@ async def send_prompt_async( request_converter_configurations: list[PromptConverterConfiguration] = [], response_converter_configurations: list[PromptConverterConfiguration] = [], labels: Optional[dict[str, str]] = None, - attack_identifier: Optional[dict[str, str]] = None, + attack_identifier: Optional[AttackIdentifier] = None, ) -> Message: """ Send a single request to a target. @@ -67,7 +68,7 @@ async def send_prompt_async( response_converter_configurations (list[PromptConverterConfiguration], optional): Configurations for converting the response. Defaults to an empty list. labels (Optional[dict[str, str]], optional): Labels associated with the request. Defaults to None. - attack_identifier (Optional[dict[str, str]], optional): Identifier for the attack. Defaults to + attack_identifier (Optional[AttackIdentifier], optional): Identifier for the attack. Defaults to None. Raises: @@ -155,7 +156,7 @@ async def send_prompt_batch_to_target_async( requests: list[NormalizerRequest], target: PromptTarget, labels: Optional[dict[str, str]] = None, - attack_identifier: Optional[dict[str, str]] = None, + attack_identifier: Optional[AttackIdentifier] = None, batch_size: int = 10, ) -> list[Message]: """ @@ -166,7 +167,7 @@ async def send_prompt_batch_to_target_async( target (PromptTarget): The target to which the prompts are sent. labels (Optional[dict[str, str]], optional): A dictionary of labels to be included with the request. Defaults to None. - attack_identifier (Optional[dict[str, str]], optional): A dictionary identifying the attack. + attack_identifier (Optional[AttackIdentifier], optional): The attack identifier. Defaults to None. batch_size (int, optional): The number of prompts to include in each batch. Defaults to 10. @@ -274,7 +275,7 @@ async def add_prepended_conversation_to_memory( conversation_id: str, should_convert: bool = True, converter_configurations: Optional[list[PromptConverterConfiguration]] = None, - attack_identifier: Optional[dict[str, str]] = None, + attack_identifier: Optional[AttackIdentifier] = None, prepended_conversation: Optional[list[Message]] = None, ) -> Optional[list[Message]]: """ @@ -285,7 +286,7 @@ async def add_prepended_conversation_to_memory( should_convert (bool): Whether to convert the prepended conversation converter_configurations (Optional[list[PromptConverterConfiguration]]): Configurations for converting the request - attack_identifier (Optional[dict[str, str]]): Identifier for the attack + attack_identifier (Optional[AttackIdentifier]): Identifier for the attack prepended_conversation (Optional[list[Message]]): The conversation to prepend Returns: diff --git a/pyrit/prompt_target/common/prompt_chat_target.py b/pyrit/prompt_target/common/prompt_chat_target.py index 5eac0209f5..b837918295 100644 --- a/pyrit/prompt_target/common/prompt_chat_target.py +++ b/pyrit/prompt_target/common/prompt_chat_target.py @@ -4,6 +4,7 @@ import abc from typing import Optional +from pyrit.identifiers import AttackIdentifier from pyrit.models import MessagePiece from pyrit.models.json_response_config import _JsonResponseConfig from pyrit.prompt_target.common.prompt_target import PromptTarget @@ -51,7 +52,7 @@ def set_system_prompt( *, system_prompt: str, conversation_id: str, - attack_identifier: Optional[dict[str, str]] = None, + attack_identifier: Optional[AttackIdentifier] = None, labels: Optional[dict[str, str]] = None, ) -> None: """ diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index 653d008e65..29ba2cb47c 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -129,7 +129,6 @@ def _create_identifier( class_name=self.__class__.__name__, class_module=self.__class__.__module__, class_description=" ".join(self.__class__.__doc__.split()) if self.__class__.__doc__ else "", - identifier_type="instance", endpoint=self._endpoint, model_name=model_name, temperature=temperature, diff --git a/pyrit/registry/class_registries/initializer_registry.py b/pyrit/registry/class_registries/initializer_registry.py index a334e87e72..bf6443afa6 100644 --- a/pyrit/registry/class_registries/initializer_registry.py +++ b/pyrit/registry/class_registries/initializer_registry.py @@ -12,7 +12,7 @@ import importlib.util import logging -from dataclasses import dataclass +from dataclasses import dataclass, field from pathlib import Path from typing import TYPE_CHECKING, Dict, Optional @@ -41,9 +41,9 @@ class InitializerMetadata(Identifier): Use get_class() to get the actual class. """ - display_name: str - required_env_vars: tuple[str, ...] - execution_order: int + display_name: str = field(kw_only=True) + required_env_vars: tuple[str, ...] = field(kw_only=True) + execution_order: int = field(kw_only=True) class InitializerRegistry(BaseClassRegistry["PyRITInitializer", InitializerMetadata]): diff --git a/pyrit/registry/class_registries/scenario_registry.py b/pyrit/registry/class_registries/scenario_registry.py index f95ad93986..5489c774f2 100644 --- a/pyrit/registry/class_registries/scenario_registry.py +++ b/pyrit/registry/class_registries/scenario_registry.py @@ -11,7 +11,7 @@ from __future__ import annotations import logging -from dataclasses import dataclass +from dataclasses import dataclass, field from pathlib import Path from typing import TYPE_CHECKING, Optional @@ -40,11 +40,11 @@ class ScenarioMetadata(Identifier): Use get_class() to get the actual class. """ - default_strategy: str - all_strategies: tuple[str, ...] - aggregate_strategies: tuple[str, ...] - default_datasets: tuple[str, ...] - max_dataset_size: Optional[int] + default_strategy: str = field(kw_only=True) + all_strategies: tuple[str, ...] = field(kw_only=True) + aggregate_strategies: tuple[str, ...] = field(kw_only=True) + default_datasets: tuple[str, ...] = field(kw_only=True) + max_dataset_size: Optional[int] = field(kw_only=True) class ScenarioRegistry(BaseClassRegistry["Scenario", ScenarioMetadata]): diff --git a/pyrit/score/float_scale/float_scale_scorer.py b/pyrit/score/float_scale/float_scale_scorer.py index a1150b88e3..30650dc637 100644 --- a/pyrit/score/float_scale/float_scale_scorer.py +++ b/pyrit/score/float_scale/float_scale_scorer.py @@ -1,10 +1,11 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING, Optional from uuid import UUID from pyrit.exceptions.exception_classes import InvalidJsonException +from pyrit.identifiers import AttackIdentifier from pyrit.models import PromptDataType, Score, UnvalidatedScore from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget from pyrit.score.scorer import Scorer @@ -75,7 +76,7 @@ async def _score_value_with_llm( description_output_key: str = "description", metadata_output_key: str = "metadata", category_output_key: str = "category", - attack_identifier: Optional[Dict[str, str]] = None, + attack_identifier: Optional[AttackIdentifier] = None, ) -> UnvalidatedScore: score: UnvalidatedScore | None = None try: diff --git a/pyrit/score/float_scale/self_ask_likert_scorer.py b/pyrit/score/float_scale/self_ask_likert_scorer.py index c01fe65073..523e5d703c 100644 --- a/pyrit/score/float_scale/self_ask_likert_scorer.py +++ b/pyrit/score/float_scale/self_ask_likert_scorer.py @@ -280,6 +280,7 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op message_data_type=message_piece.converted_value_data_type, scored_prompt_id=message_piece.id, category=self._score_category, + attack_identifier=message_piece.attack_identifier, objective=objective, ) diff --git a/pyrit/score/float_scale/self_ask_scale_scorer.py b/pyrit/score/float_scale/self_ask_scale_scorer.py index 5e502681d0..6c2ed1e116 100644 --- a/pyrit/score/float_scale/self_ask_scale_scorer.py +++ b/pyrit/score/float_scale/self_ask_scale_scorer.py @@ -120,6 +120,7 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op scored_prompt_id=message_piece.id, category=self._category, objective=objective, + attack_identifier=message_piece.attack_identifier, ) score = unvalidated_score.to_score( diff --git a/pyrit/score/scorer.py b/pyrit/score/scorer.py index 83cd795ecf..6765d907e1 100644 --- a/pyrit/score/scorer.py +++ b/pyrit/score/scorer.py @@ -26,7 +26,7 @@ pyrit_json_retry, remove_markdown_json, ) -from pyrit.identifiers import Identifiable, ScorerIdentifier +from pyrit.identifiers import AttackIdentifier, Identifiable, ScorerIdentifier from pyrit.memory import CentralMemory, MemoryInterface from pyrit.models import ( ChatMessageRole, @@ -145,7 +145,6 @@ def _create_identifier( class_name=self.__class__.__name__, class_module=self.__class__.__module__, class_description=" ".join(self.__class__.__doc__.split()) if self.__class__.__doc__ else "", - identifier_type="instance", scorer_type=self.scorer_type, system_prompt_template=system_prompt_template, user_prompt_template=user_prompt_template, @@ -521,7 +520,7 @@ async def _score_value_with_llm( description_output_key: str = "description", metadata_output_key: str = "metadata", category_output_key: str = "category", - attack_identifier: Optional[Dict[str, str]] = None, + attack_identifier: Optional[AttackIdentifier] = None, ) -> UnvalidatedScore: """ Send a request to a target, and take care of retries. @@ -555,7 +554,7 @@ async def _score_value_with_llm( Defaults to "metadata". category_output_key (str): The key in the JSON response that contains the category. Defaults to "category". - attack_identifier (Optional[Dict[str, str]]): A dictionary containing attack-specific identifiers. + attack_identifier (Optional[AttackIdentifier]): The attack identifier. Defaults to None. Returns: @@ -569,9 +568,6 @@ async def _score_value_with_llm( """ conversation_id = str(uuid.uuid4()) - if attack_identifier: - attack_identifier["scored_prompt_id"] = str(scored_prompt_id) - prompt_target.set_system_prompt( system_prompt=system_prompt, conversation_id=conversation_id, diff --git a/tests/integration/mocks.py b/tests/integration/mocks.py index 47efa4cae9..c924fca504 100644 --- a/tests/integration/mocks.py +++ b/tests/integration/mocks.py @@ -5,6 +5,7 @@ from sqlalchemy import inspect +from pyrit.identifiers import AttackIdentifier from pyrit.memory import MemoryInterface, SQLiteMemory from pyrit.models import Message, MessagePiece from pyrit.prompt_target import PromptChatTarget, limit_requests_per_minute @@ -47,7 +48,7 @@ def set_system_prompt( *, system_prompt: str, conversation_id: str, - attack_identifier: Optional[dict[str, str]] = None, + attack_identifier: Optional[AttackIdentifier] = None, labels: Optional[dict[str, str]] = None, ) -> None: self.system_prompt = system_prompt diff --git a/tests/unit/analytics/test_result_analysis.py b/tests/unit/analytics/test_result_analysis.py index 44b2a56e8a..5a074aefc6 100644 --- a/tests/unit/analytics/test_result_analysis.py +++ b/tests/unit/analytics/test_result_analysis.py @@ -1,24 +1,27 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +from typing import Optional + import pytest from pyrit.analytics.result_analysis import AttackStats, analyze_results +from pyrit.identifiers import AttackIdentifier from pyrit.models import AttackOutcome, AttackResult # helpers def make_attack( outcome: AttackOutcome, - attack_type: str | None = "default", + attack_type: Optional[str] = "default", conversation_id: str = "conv-1", ) -> AttackResult: """ Minimal valid AttackResult for analytics tests. """ - attack_identifier: dict[str, str] = {} + attack_identifier: Optional[AttackIdentifier] = None if attack_type is not None: - attack_identifier["type"] = attack_type + attack_identifier = AttackIdentifier(class_name=attack_type, class_module="tests.unit.analytics") return AttackResult( conversation_id=conversation_id, diff --git a/tests/unit/converter/test_persuasion_converter.py b/tests/unit/converter/test_persuasion_converter.py index 79c31b082d..9962945859 100644 --- a/tests/unit/converter/test_persuasion_converter.py +++ b/tests/unit/converter/test_persuasion_converter.py @@ -7,6 +7,7 @@ from unit.mocks import MockPromptTarget from pyrit.exceptions.exception_classes import InvalidJsonException +from pyrit.identifiers import AttackIdentifier, TargetIdentifier from pyrit.models import Message, MessagePiece from pyrit.prompt_converter import PersuasionConverter @@ -73,8 +74,8 @@ async def test_persuasion_converter_send_prompt_async_bad_json_exception_retries converted_value=converted_value, original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier={"target": "target-identifier"}, - attack_identifier={"test": "test"}, + prompt_target_identifier=TargetIdentifier(class_name="target-identifier", class_module="test"), + attack_identifier=AttackIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ) ] diff --git a/tests/unit/converter/test_translation_converter.py b/tests/unit/converter/test_translation_converter.py index 80ba8b0da4..c8c26254a5 100644 --- a/tests/unit/converter/test_translation_converter.py +++ b/tests/unit/converter/test_translation_converter.py @@ -6,6 +6,7 @@ import pytest from unit.mocks import MockPromptTarget +from pyrit.identifiers import TargetIdentifier from pyrit.models import Message, MessagePiece from pyrit.prompt_converter import TranslationConverter @@ -79,7 +80,7 @@ async def test_translation_converter_succeeds_after_retries(sqlite_instance): converted_value="hola", original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier={"target": "test-identifier"}, + prompt_target_identifier=TargetIdentifier(class_name="test-identifier", class_module="test"), sequence=1, ) ] diff --git a/tests/unit/converter/test_variation_converter.py b/tests/unit/converter/test_variation_converter.py index 023ca02b0d..e11f2e9642 100644 --- a/tests/unit/converter/test_variation_converter.py +++ b/tests/unit/converter/test_variation_converter.py @@ -7,6 +7,7 @@ from unit.mocks import MockPromptTarget from pyrit.exceptions.exception_classes import InvalidJsonException +from pyrit.identifiers import AttackIdentifier, TargetIdentifier from pyrit.models import Message, MessagePiece from pyrit.prompt_converter import VariationConverter @@ -45,8 +46,8 @@ async def test_variation_converter_send_prompt_async_bad_json_exception_retries( converted_value=converted_value, original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier={"target": "target-identifier"}, - attack_identifier={"test": "test"}, + prompt_target_identifier=TargetIdentifier(class_name="target-identifier", class_module="test"), + attack_identifier=AttackIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ) ] diff --git a/tests/unit/docs/test_api_documentation.py b/tests/unit/docs/test_api_documentation.py index f67425aa19..48db273589 100644 --- a/tests/unit/docs/test_api_documentation.py +++ b/tests/unit/docs/test_api_documentation.py @@ -119,7 +119,7 @@ def get_module_exports(module_path: str) -> Set[str]: "exclude": set(), }, "pyrit.identifiers": { - "exclude": {"LegacyIdentifiable"}, + "exclude": set(), }, } diff --git a/tests/unit/exceptions/test_exception_context.py b/tests/unit/exceptions/test_exception_context.py index 7f4f5f59f7..e53cd42bce 100644 --- a/tests/unit/exceptions/test_exception_context.py +++ b/tests/unit/exceptions/test_exception_context.py @@ -12,6 +12,7 @@ get_execution_context, set_execution_context, ) +from pyrit.identifiers import AttackIdentifier, ScorerIdentifier, TargetIdentifier class TestExecutionContext: @@ -31,11 +32,20 @@ def test_default_values(self): def test_initialization_with_values(self): """Test ExecutionContext initialization with all values.""" + attack_id = AttackIdentifier( + class_name="PromptSendingAttack", + class_module="pyrit.executor.attack.single_turn.prompt_sending", + ) + target_id = TargetIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target.openai.openai_chat_target", + endpoint="https://api.openai.com", + ) context = ExecutionContext( component_role=ComponentRole.OBJECTIVE_TARGET, attack_strategy_name="PromptSendingAttack", - attack_identifier={"__type__": "PromptSendingAttack", "id": "abc123"}, - component_identifier={"__type__": "OpenAIChatTarget", "endpoint": "https://api.openai.com"}, + attack_identifier=attack_id, + component_identifier=target_id, objective_target_conversation_id="conv-123", endpoint="https://api.openai.com", component_name="OpenAIChatTarget", @@ -43,8 +53,8 @@ def test_initialization_with_values(self): ) assert context.component_role == ComponentRole.OBJECTIVE_TARGET assert context.attack_strategy_name == "PromptSendingAttack" - assert context.attack_identifier == {"__type__": "PromptSendingAttack", "id": "abc123"} - assert context.component_identifier == {"__type__": "OpenAIChatTarget", "endpoint": "https://api.openai.com"} + assert context.attack_identifier is attack_id + assert context.component_identifier is target_id assert context.objective_target_conversation_id == "conv-123" assert context.endpoint == "https://api.openai.com" assert context.component_name == "OpenAIChatTarget" @@ -96,11 +106,19 @@ def test_get_exception_details_minimal(self): def test_get_exception_details_full(self): """Test exception details with full context.""" + attack_id = AttackIdentifier( + class_name="RedTeamingAttack", + class_module="pyrit.executor.attack.multi_turn.red_teaming", + ) + target_id = TargetIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target.openai.openai_chat_target", + ) context = ExecutionContext( component_role=ComponentRole.OBJECTIVE_TARGET, attack_strategy_name="RedTeamingAttack", - attack_identifier={"__type__": "RedTeamingAttack", "id": "xyz"}, - component_identifier={"__type__": "OpenAIChatTarget"}, + attack_identifier=attack_id, + component_identifier=target_id, objective_target_conversation_id="conv-456", objective="Tell me how to hack a system", ) @@ -247,7 +265,11 @@ def test_execution_context_creates_manager(self): def test_execution_context_extracts_endpoint(self): """Test that endpoint is extracted from component_identifier.""" - component_id = {"__type__": "OpenAIChatTarget", "endpoint": "https://api.openai.com"} + component_id = TargetIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target.openai.openai_chat_target", + endpoint="https://api.openai.com", + ) manager = execution_context( component_role=ComponentRole.OBJECTIVE_TARGET, component_identifier=component_id, @@ -255,8 +277,11 @@ def test_execution_context_extracts_endpoint(self): assert manager.context.endpoint == "https://api.openai.com" def test_execution_context_extracts_component_name(self): - """Test that component_name is extracted from component_identifier.__type__.""" - component_id = {"__type__": "TrueFalseScorer", "endpoint": "https://api.openai.com"} + """Test that component_name is extracted from component_identifier.class_name.""" + component_id = ScorerIdentifier( + class_name="TrueFalseScorer", + class_module="pyrit.score.true_false.true_false_scorer", + ) manager = execution_context( component_role=ComponentRole.OBJECTIVE_SCORER, component_identifier=component_id, @@ -265,20 +290,32 @@ def test_execution_context_extracts_component_name(self): def test_execution_context_no_endpoint(self): """Test that endpoint is None when not in component_identifier.""" - component_id = {"__type__": "TextTarget"} + component_id = TargetIdentifier( + class_name="TextTarget", + class_module="pyrit.prompt_target.text_target", + ) manager = execution_context( component_role=ComponentRole.OBJECTIVE_TARGET, component_identifier=component_id, ) - assert manager.context.endpoint is None + assert manager.context.endpoint == "" def test_execution_context_full_usage(self): """Test full usage of execution_context as context manager.""" + attack_id = AttackIdentifier( + class_name="CrescendoAttack", + class_module="pyrit.executor.attack.multi_turn.crescendo", + ) + target_id = TargetIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target.openai.openai_chat_target", + endpoint="https://example.com", + ) with execution_context( component_role=ComponentRole.ADVERSARIAL_CHAT, attack_strategy_name="CrescendoAttack", - attack_identifier={"id": "test"}, - component_identifier={"endpoint": "https://example.com"}, + attack_identifier=attack_id, + component_identifier=target_id, objective_target_conversation_id="conv-789", ): ctx = get_execution_context() diff --git a/tests/unit/executor/attack/component/test_conversation_manager.py b/tests/unit/executor/attack/component/test_conversation_manager.py index b874169d44..1876c0cd46 100644 --- a/tests/unit/executor/attack/component/test_conversation_manager.py +++ b/tests/unit/executor/attack/component/test_conversation_manager.py @@ -18,7 +18,7 @@ """ import uuid -from typing import Dict, List, Optional +from typing import List, Optional from unittest.mock import AsyncMock, MagicMock import pytest @@ -34,7 +34,7 @@ ) from pyrit.executor.attack.core import AttackContext from pyrit.executor.attack.core.attack_parameters import AttackParameters -from pyrit.identifiers import TargetIdentifier +from pyrit.identifiers import AttackIdentifier, TargetIdentifier from pyrit.models import Message, MessagePiece, Score from pyrit.prompt_normalizer import PromptConverterConfiguration, PromptNormalizer from pyrit.prompt_target import PromptChatTarget, PromptTarget @@ -68,13 +68,12 @@ class _TestAttackContext(AttackContext): @pytest.fixture -def attack_identifier() -> Dict[str, str]: +def attack_identifier() -> AttackIdentifier: """Create a sample attack identifier.""" - return { - "__type__": "TestAttack", - "__module__": "pyrit.executor.attack.test_attack", - "id": str(uuid.uuid4()), - } + return AttackIdentifier( + class_name="TestAttack", + class_module="pyrit.executor.attack.test_attack", + ) @pytest.fixture @@ -246,8 +245,8 @@ def test_swaps_user_to_assistant(self) -> None: result = get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - attack_identifier={"__type__": "TestAttack"}, - adversarial_chat_target_identifier={"id": "adversarial_target"}, + attack_identifier=AttackIdentifier(class_name="TestAttack", class_module="test_module"), + adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), ) assert len(result) == 1 @@ -262,8 +261,8 @@ def test_swaps_assistant_to_user(self) -> None: result = get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - attack_identifier={"__type__": "TestAttack"}, - adversarial_chat_target_identifier={"id": "adversarial_target"}, + attack_identifier=AttackIdentifier(class_name="TestAttack", class_module="test_module"), + adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), ) assert len(result) == 1 @@ -281,8 +280,8 @@ def test_swaps_simulated_assistant_to_user(self) -> None: result = get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - attack_identifier={"__type__": "TestAttack"}, - adversarial_chat_target_identifier={"id": "adversarial_target"}, + attack_identifier=AttackIdentifier(class_name="TestAttack", class_module="test_module"), + adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), ) assert len(result) == 1 @@ -300,8 +299,8 @@ def test_skips_system_messages(self) -> None: result = get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - attack_identifier={"__type__": "TestAttack"}, - adversarial_chat_target_identifier={"id": "adversarial_target"}, + attack_identifier=AttackIdentifier(class_name="TestAttack", class_module="test_module"), + adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), ) # Only user message should be present, system skipped @@ -317,8 +316,8 @@ def test_assigns_new_uuids(self) -> None: result = get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - attack_identifier={"__type__": "TestAttack"}, - adversarial_chat_target_identifier={"id": "adversarial_target"}, + attack_identifier=AttackIdentifier(class_name="TestAttack", class_module="test_module"), + adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), ) # New ID should be different from original @@ -339,8 +338,8 @@ def test_preserves_message_content(self) -> None: result = get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - attack_identifier={"__type__": "TestAttack"}, - adversarial_chat_target_identifier={"id": "adversarial_target"}, + attack_identifier=AttackIdentifier(class_name="TestAttack", class_module="test_module"), + adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), ) assert result[0].get_piece().original_value == "Original content" @@ -351,8 +350,8 @@ def test_empty_prepended_conversation(self) -> None: result = get_adversarial_chat_messages( [], adversarial_chat_conversation_id="adversarial_conv", - attack_identifier={"__type__": "TestAttack"}, - adversarial_chat_target_identifier={"id": "adversarial_target"}, + attack_identifier=AttackIdentifier(class_name="TestAttack", class_module="test_module"), + adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), ) assert result == [] @@ -366,8 +365,8 @@ def test_applies_labels(self) -> None: result = get_adversarial_chat_messages( messages, adversarial_chat_conversation_id="adversarial_conv", - attack_identifier={"__type__": "TestAttack"}, - adversarial_chat_target_identifier={"id": "adversarial_target"}, + attack_identifier=AttackIdentifier(class_name="TestAttack", class_module="test_module"), + adversarial_chat_target_identifier=_mock_target_id("adversarial_target"), labels=labels, ) @@ -476,7 +475,7 @@ def test_with_custom_values(self, sample_score: Score) -> None: class TestConversationManagerInitialization: """Tests for ConversationManager initialization.""" - def test_init_with_required_parameters(self, attack_identifier: Dict[str, str]) -> None: + def test_init_with_required_parameters(self, attack_identifier: AttackIdentifier) -> None: """Test initialization with only required parameters.""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -485,7 +484,7 @@ def test_init_with_required_parameters(self, attack_identifier: Dict[str, str]) assert manager._memory is not None def test_init_with_custom_prompt_normalizer( - self, attack_identifier: Dict[str, str], mock_prompt_normalizer: MagicMock + self, attack_identifier: AttackIdentifier, mock_prompt_normalizer: MagicMock ) -> None: """Test initialization with a custom prompt normalizer.""" manager = ConversationManager(attack_identifier=attack_identifier, prompt_normalizer=mock_prompt_normalizer) @@ -502,7 +501,7 @@ def test_init_with_custom_prompt_normalizer( class TestConversationRetrieval: """Tests for conversation retrieval methods.""" - def test_get_conversation_returns_empty_list_when_no_messages(self, attack_identifier: Dict[str, str]) -> None: + def test_get_conversation_returns_empty_list_when_no_messages(self, attack_identifier: AttackIdentifier) -> None: """Test get_conversation returns empty list for non-existent conversation.""" manager = ConversationManager(attack_identifier=attack_identifier) conversation_id = str(uuid.uuid4()) @@ -512,7 +511,7 @@ def test_get_conversation_returns_empty_list_when_no_messages(self, attack_ident assert result == [] def test_get_conversation_returns_messages_in_order( - self, attack_identifier: Dict[str, str], sample_conversation: List[Message] + self, attack_identifier: AttackIdentifier, sample_conversation: List[Message] ) -> None: """Test get_conversation returns messages in order.""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -530,7 +529,7 @@ def test_get_conversation_returns_messages_in_order( assert result[0].message_pieces[0].api_role == "user" assert result[1].message_pieces[0].api_role == "assistant" - def test_get_last_message_returns_none_for_empty_conversation(self, attack_identifier: Dict[str, str]) -> None: + def test_get_last_message_returns_none_for_empty_conversation(self, attack_identifier: AttackIdentifier) -> None: """Test get_last_message returns None for empty conversation.""" manager = ConversationManager(attack_identifier=attack_identifier) conversation_id = str(uuid.uuid4()) @@ -540,7 +539,7 @@ def test_get_last_message_returns_none_for_empty_conversation(self, attack_ident assert result is None def test_get_last_message_returns_last_piece( - self, attack_identifier: Dict[str, str], sample_conversation: List[Message] + self, attack_identifier: AttackIdentifier, sample_conversation: List[Message] ) -> None: """Test get_last_message returns the most recent message.""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -558,7 +557,7 @@ def test_get_last_message_returns_last_piece( assert result.api_role == "assistant" def test_get_last_message_with_role_filter( - self, attack_identifier: Dict[str, str], sample_conversation: List[Message] + self, attack_identifier: AttackIdentifier, sample_conversation: List[Message] ) -> None: """Test get_last_message with role filter returns correct message.""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -577,7 +576,7 @@ def test_get_last_message_with_role_filter( assert result.api_role == "user" def test_get_last_message_with_role_filter_returns_none_when_no_match( - self, attack_identifier: Dict[str, str], sample_conversation: List[Message] + self, attack_identifier: AttackIdentifier, sample_conversation: List[Message] ) -> None: """Test get_last_message returns None when no message matches role filter.""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -605,7 +604,7 @@ class TestSystemPromptHandling: """Tests for system prompt functionality.""" def test_set_system_prompt_with_chat_target( - self, attack_identifier: Dict[str, str], mock_chat_target: MagicMock + self, attack_identifier: AttackIdentifier, mock_chat_target: MagicMock ) -> None: """Test set_system_prompt calls target's set_system_prompt method.""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -628,7 +627,7 @@ def test_set_system_prompt_with_chat_target( ) def test_set_system_prompt_without_labels( - self, attack_identifier: Dict[str, str], mock_chat_target: MagicMock + self, attack_identifier: AttackIdentifier, mock_chat_target: MagicMock ) -> None: """Test set_system_prompt works without labels.""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -658,7 +657,7 @@ class TestInitializeContext: @pytest.mark.asyncio async def test_raises_error_for_empty_conversation_id( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_chat_target: MagicMock, mock_attack_context: AttackContext, ) -> None: @@ -675,7 +674,7 @@ async def test_raises_error_for_empty_conversation_id( @pytest.mark.asyncio async def test_returns_default_state_for_no_prepended_conversation( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_chat_target: MagicMock, mock_attack_context: AttackContext, ) -> None: @@ -696,7 +695,7 @@ async def test_returns_default_state_for_no_prepended_conversation( @pytest.mark.asyncio async def test_merges_memory_labels( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_chat_target: MagicMock, ) -> None: """Test that memory_labels are merged with context labels.""" @@ -719,7 +718,7 @@ async def test_merges_memory_labels( @pytest.mark.asyncio async def test_adds_prepended_conversation_to_memory_for_chat_target( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_chat_target: MagicMock, sample_conversation: List[Message], ) -> None: @@ -742,7 +741,7 @@ async def test_adds_prepended_conversation_to_memory_for_chat_target( @pytest.mark.asyncio async def test_converts_assistant_to_simulated_assistant( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_chat_target: MagicMock, sample_assistant_piece: MessagePiece, ) -> None: @@ -767,7 +766,7 @@ async def test_converts_assistant_to_simulated_assistant( @pytest.mark.asyncio async def test_normalizes_for_non_chat_target_by_default( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_prompt_target: MagicMock, sample_conversation: List[Message], ) -> None: @@ -793,7 +792,7 @@ async def test_normalizes_for_non_chat_target_by_default( @pytest.mark.asyncio async def test_normalizes_for_non_chat_target_when_configured( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_prompt_target: MagicMock, sample_conversation: List[Message], ) -> None: @@ -822,7 +821,7 @@ async def test_normalizes_for_non_chat_target_when_configured( @pytest.mark.asyncio async def test_returns_turn_count_for_multi_turn_attacks( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_chat_target: MagicMock, sample_conversation: List[Message], ) -> None: @@ -845,7 +844,7 @@ async def test_returns_turn_count_for_multi_turn_attacks( @pytest.mark.asyncio async def test_multipart_message_extracts_scores_from_all_pieces( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_chat_target: MagicMock, sample_score: Score, ) -> None: @@ -919,7 +918,7 @@ async def test_multipart_message_extracts_scores_from_all_pieces( @pytest.mark.asyncio async def test_prepended_conversation_ignores_true_scores( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_chat_target: MagicMock, ) -> None: """Test that prepended conversations only extract false scores, ignoring true scores. @@ -1023,7 +1022,7 @@ class TestPrependedConversationConfigSettings: @pytest.mark.asyncio async def test_non_chat_target_behavior_normalize_is_default( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_prompt_target: MagicMock, sample_conversation: List[Message], ) -> None: @@ -1049,7 +1048,7 @@ async def test_non_chat_target_behavior_normalize_is_default( @pytest.mark.asyncio async def test_non_chat_target_behavior_raise_explicit( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_prompt_target: MagicMock, sample_conversation: List[Message], ) -> None: @@ -1074,7 +1073,7 @@ async def test_non_chat_target_behavior_raise_explicit( @pytest.mark.asyncio async def test_non_chat_target_behavior_normalize_first_turn_creates_next_message( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_prompt_target: MagicMock, sample_conversation: List[Message], ) -> None: @@ -1102,7 +1101,7 @@ async def test_non_chat_target_behavior_normalize_first_turn_creates_next_messag @pytest.mark.asyncio async def test_non_chat_target_behavior_normalize_first_turn_prepends_to_existing_message( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_prompt_target: MagicMock, sample_conversation: List[Message], ) -> None: @@ -1132,7 +1131,7 @@ async def test_non_chat_target_behavior_normalize_first_turn_prepends_to_existin @pytest.mark.asyncio async def test_non_chat_target_behavior_normalize_returns_empty_state( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_prompt_target: MagicMock, sample_conversation: List[Message], ) -> None: @@ -1162,7 +1161,7 @@ async def test_non_chat_target_behavior_normalize_returns_empty_state( @pytest.mark.asyncio async def test_apply_converters_to_roles_default_applies_to_all( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_chat_target: MagicMock, sample_conversation: List[Message], ) -> None: @@ -1189,7 +1188,7 @@ async def test_apply_converters_to_roles_default_applies_to_all( @pytest.mark.asyncio async def test_apply_converters_to_roles_user_only( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_chat_target: MagicMock, sample_conversation: List[Message], ) -> None: @@ -1218,7 +1217,7 @@ async def test_apply_converters_to_roles_user_only( @pytest.mark.asyncio async def test_apply_converters_to_roles_assistant_only( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_chat_target: MagicMock, sample_conversation: List[Message], ) -> None: @@ -1247,7 +1246,7 @@ async def test_apply_converters_to_roles_assistant_only( @pytest.mark.asyncio async def test_apply_converters_to_roles_empty_list_skips_all( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_chat_target: MagicMock, sample_conversation: List[Message], ) -> None: @@ -1280,7 +1279,7 @@ async def test_apply_converters_to_roles_empty_list_skips_all( @pytest.mark.asyncio async def test_message_normalizer_default_uses_conversation_context_normalizer( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_prompt_target: MagicMock, sample_conversation: List[Message], ) -> None: @@ -1308,7 +1307,7 @@ async def test_message_normalizer_default_uses_conversation_context_normalizer( @pytest.mark.asyncio async def test_message_normalizer_custom_normalizer_is_used( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_prompt_target: MagicMock, sample_conversation: List[Message], ) -> None: @@ -1389,7 +1388,7 @@ def test_for_non_chat_target_with_custom_roles(self) -> None: @pytest.mark.asyncio async def test_chat_target_ignores_non_chat_target_behavior( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_chat_target: MagicMock, sample_conversation: List[Message], ) -> None: @@ -1421,7 +1420,7 @@ async def test_chat_target_ignores_non_chat_target_behavior( @pytest.mark.asyncio async def test_config_with_max_turns_validation( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_chat_target: MagicMock, ) -> None: """Test that config works correctly with max_turns validation.""" @@ -1471,7 +1470,7 @@ class TestAddPrependedConversationToMemory: @pytest.mark.asyncio async def test_adds_messages_to_memory( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, sample_conversation: List[Message], ) -> None: """Test that messages are added to memory.""" @@ -1490,7 +1489,7 @@ async def test_adds_messages_to_memory( @pytest.mark.asyncio async def test_assigns_conversation_id_to_all_pieces( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, sample_conversation: List[Message], ) -> None: """Test that conversation_id is assigned to all message pieces.""" @@ -1510,7 +1509,7 @@ async def test_assigns_conversation_id_to_all_pieces( @pytest.mark.asyncio async def test_assigns_attack_identifier_to_all_pieces( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, sample_conversation: List[Message], ) -> None: """Test that attack_identifier is assigned to all message pieces.""" @@ -1530,7 +1529,7 @@ async def test_assigns_attack_identifier_to_all_pieces( @pytest.mark.asyncio async def test_raises_error_when_exceeds_max_turns( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, sample_user_piece: MessagePiece, sample_assistant_piece: MessagePiece, ) -> None: @@ -1556,7 +1555,7 @@ async def test_raises_error_when_exceeds_max_turns( @pytest.mark.asyncio async def test_multipart_response_counts_as_one_turn( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, ) -> None: """Test that a multi-part assistant response counts as only one turn.""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -1595,7 +1594,7 @@ async def test_multipart_response_counts_as_one_turn( @pytest.mark.asyncio async def test_returns_zero_for_empty_conversation( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, ) -> None: """Test that empty conversation returns 0 turns.""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -1611,7 +1610,7 @@ async def test_returns_zero_for_empty_conversation( @pytest.mark.asyncio async def test_applies_converters_when_provided( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_prompt_normalizer: MagicMock, sample_user_piece: MessagePiece, ) -> None: @@ -1633,7 +1632,7 @@ async def test_applies_converters_when_provided( @pytest.mark.asyncio async def test_handles_none_messages_gracefully( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, ) -> None: """Test that None messages are handled gracefully.""" manager = ConversationManager(attack_identifier=attack_identifier) @@ -1659,7 +1658,7 @@ class TestEdgeCasesAndErrorHandling: @pytest.mark.asyncio async def test_preserves_piece_metadata( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_chat_target: MagicMock, sample_user_piece: MessagePiece, ) -> None: @@ -1688,7 +1687,7 @@ async def test_preserves_piece_metadata( @pytest.mark.asyncio async def test_preserves_original_and_converted_values( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_chat_target: MagicMock, sample_user_piece: MessagePiece, ) -> None: @@ -1716,7 +1715,7 @@ async def test_preserves_original_and_converted_values( @pytest.mark.asyncio async def test_handles_system_messages_in_prepended_conversation( self, - attack_identifier: Dict[str, str], + attack_identifier: AttackIdentifier, mock_chat_target: MagicMock, sample_system_piece: MessagePiece, sample_user_piece: MessagePiece, diff --git a/tests/unit/executor/attack/core/test_attack_strategy.py b/tests/unit/executor/attack/core/test_attack_strategy.py index 918b9c7531..9c98edb38c 100644 --- a/tests/unit/executor/attack/core/test_attack_strategy.py +++ b/tests/unit/executor/attack/core/test_attack_strategy.py @@ -69,7 +69,6 @@ def sample_attack_result(): result = AttackResult( conversation_id="test-conversation-id", objective="Test objective", - attack_identifier={"name": "test_attack"}, outcome=AttackOutcome.SUCCESS, outcome_reason="Test successful", execution_time_ms=0, @@ -112,7 +111,6 @@ async def _perform_async(self, *, context): result = AttackResult( conversation_id="test-conversation-id", objective="Test objective", - attack_identifier={"name": "test_attack"}, outcome=AttackOutcome.SUCCESS, outcome_reason="Test successful", execution_time_ms=0, @@ -144,7 +142,6 @@ async def _perform_async(self, *, context): return AttackResult( conversation_id="test-conversation-id", objective="Test objective", - attack_identifier={"name": "test_attack"}, outcome=AttackOutcome.SUCCESS, outcome_reason="Test successful", execution_time_ms=0, @@ -176,7 +173,6 @@ async def _perform_async(self, *, context): return AttackResult( conversation_id="test-conversation-id", objective="Test objective", - attack_identifier={"name": "test_attack"}, outcome=AttackOutcome.SUCCESS, outcome_reason="Test successful", execution_time_ms=0, @@ -208,7 +204,6 @@ async def _perform_async(self, *, context): return AttackResult( conversation_id="test-conversation-id", objective="Test objective", - attack_identifier={"name": "test_attack"}, outcome=AttackOutcome.SUCCESS, outcome_reason="Test successful", execution_time_ms=0, @@ -497,7 +492,6 @@ async def _perform_async(self, *, context): result = AttackResult( conversation_id="test-conversation-id", objective="Test objective", - attack_identifier={"name": "test_attack"}, outcome=AttackOutcome.SUCCESS, outcome_reason="Test successful", executed_turns=1, @@ -542,7 +536,6 @@ async def _perform_async(self, *, context): result = AttackResult( conversation_id="test-conversation-id", objective="Test objective", - attack_identifier={"name": "test_attack"}, outcome=AttackOutcome.SUCCESS, outcome_reason="Test successful", execution_time_ms=0, diff --git a/tests/unit/executor/attack/core/test_markdown_printer.py b/tests/unit/executor/attack/core/test_markdown_printer.py index 8a3c430269..a66fcb4260 100644 --- a/tests/unit/executor/attack/core/test_markdown_printer.py +++ b/tests/unit/executor/attack/core/test_markdown_printer.py @@ -8,7 +8,7 @@ import pytest from pyrit.executor.attack.printer.markdown_printer import MarkdownAttackResultPrinter -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import AttackIdentifier, ScorerIdentifier from pyrit.memory import CentralMemory from pyrit.models import AttackOutcome, AttackResult, Message, MessagePiece, Score @@ -69,7 +69,7 @@ def sample_float_score(): def sample_attack_result(): return AttackResult( objective="Test objective", - attack_identifier={"__type__": "TestAttack"}, + attack_identifier=AttackIdentifier(class_name="TestAttack", class_module="test_module"), conversation_id="test-conv-123", executed_turns=3, execution_time_ms=1500, diff --git a/tests/unit/executor/attack/multi_turn/test_chunked_request.py b/tests/unit/executor/attack/multi_turn/test_chunked_request.py index b3bf1e8949..aad88f1099 100644 --- a/tests/unit/executor/attack/multi_turn/test_chunked_request.py +++ b/tests/unit/executor/attack/multi_turn/test_chunked_request.py @@ -8,7 +8,7 @@ during Crucible CTF red teaming exercises using PyRIT. """ -from unittest.mock import Mock +from unittest.mock import MagicMock import pytest @@ -17,6 +17,25 @@ ChunkedRequestAttack, ChunkedRequestAttackContext, ) +from pyrit.identifiers import TargetIdentifier +from pyrit.prompt_target import PromptTarget + + +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + +def _make_mock_target(): + """Create a mock target with proper get_identifier.""" + target = MagicMock(spec=PromptTarget) + target.get_identifier.return_value = _mock_target_id("MockTarget") + return target class TestChunkedRequestAttackContext: @@ -46,7 +65,7 @@ class TestChunkedRequestAttack: def test_init_default_values(self): """Test initialization with default values.""" - mock_target = Mock() + mock_target = _make_mock_target() attack = ChunkedRequestAttack(objective_target=mock_target) assert attack._chunk_size == 50 @@ -55,7 +74,7 @@ def test_init_default_values(self): def test_init_custom_values(self): """Test initialization with custom values.""" - mock_target = Mock() + mock_target = _make_mock_target() attack = ChunkedRequestAttack( objective_target=mock_target, chunk_size=25, @@ -69,7 +88,7 @@ def test_init_custom_values(self): def test_init_custom_request_template(self): """Test initialization with custom request template.""" - mock_target = Mock() + mock_target = _make_mock_target() template = "Show me {chunk_type} from position {start} to {end} for '{objective}'" attack = ChunkedRequestAttack( objective_target=mock_target, @@ -80,21 +99,21 @@ def test_init_custom_request_template(self): def test_init_invalid_chunk_size(self): """Test that invalid chunk_size raises ValueError.""" - mock_target = Mock() + mock_target = _make_mock_target() with pytest.raises(ValueError, match="chunk_size must be >= 1"): ChunkedRequestAttack(objective_target=mock_target, chunk_size=0) def test_init_invalid_total_length(self): """Test that invalid total_length raises ValueError.""" - mock_target = Mock() + mock_target = _make_mock_target() with pytest.raises(ValueError, match="total_length must be >= chunk_size"): ChunkedRequestAttack(objective_target=mock_target, chunk_size=100, total_length=50) def test_generate_chunk_prompts(self): """Test chunk prompt generation.""" - mock_target = Mock() + mock_target = _make_mock_target() attack = ChunkedRequestAttack( objective_target=mock_target, chunk_size=50, @@ -112,7 +131,7 @@ def test_generate_chunk_prompts(self): def test_generate_chunk_prompts_custom_chunk_type(self): """Test chunk prompt generation with custom chunk type.""" - mock_target = Mock() + mock_target = _make_mock_target() attack = ChunkedRequestAttack( objective_target=mock_target, chunk_size=50, @@ -129,7 +148,7 @@ def test_generate_chunk_prompts_custom_chunk_type(self): def test_validate_context_empty_objective(self): """Test validation fails with empty objective.""" - mock_target = Mock() + mock_target = _make_mock_target() attack = ChunkedRequestAttack(objective_target=mock_target) context = ChunkedRequestAttackContext(params=AttackParameters(objective="")) @@ -139,7 +158,7 @@ def test_validate_context_empty_objective(self): def test_validate_context_whitespace_objective(self): """Test validation fails with whitespace-only objective.""" - mock_target = Mock() + mock_target = _make_mock_target() attack = ChunkedRequestAttack(objective_target=mock_target) context = ChunkedRequestAttackContext(params=AttackParameters(objective=" ")) @@ -149,7 +168,7 @@ def test_validate_context_whitespace_objective(self): def test_validate_context_valid_objective(self): """Test validation succeeds with valid objective.""" - mock_target = Mock() + mock_target = _make_mock_target() attack = ChunkedRequestAttack(objective_target=mock_target) context = ChunkedRequestAttackContext(params=AttackParameters(objective="Extract the secret password")) @@ -159,7 +178,7 @@ def test_validate_context_valid_objective(self): def test_init_invalid_request_template_missing_start(self): """Test that request_template without 'start' placeholder raises ValueError.""" - mock_target = Mock() + mock_target = _make_mock_target() with pytest.raises(ValueError, match="request_template must contain all required placeholders"): ChunkedRequestAttack( @@ -169,7 +188,7 @@ def test_init_invalid_request_template_missing_start(self): def test_init_invalid_request_template_missing_end(self): """Test that request_template without 'end' placeholder raises ValueError.""" - mock_target = Mock() + mock_target = _make_mock_target() with pytest.raises(ValueError, match="request_template must contain all required placeholders"): ChunkedRequestAttack( @@ -179,7 +198,7 @@ def test_init_invalid_request_template_missing_end(self): def test_init_invalid_request_template_missing_chunk_type(self): """Test that request_template without 'chunk_type' placeholder raises ValueError.""" - mock_target = Mock() + mock_target = _make_mock_target() with pytest.raises(ValueError, match="request_template must contain all required placeholders"): ChunkedRequestAttack( @@ -189,7 +208,7 @@ def test_init_invalid_request_template_missing_chunk_type(self): def test_init_invalid_request_template_missing_objective(self): """Test that request_template without 'objective' placeholder raises ValueError.""" - mock_target = Mock() + mock_target = _make_mock_target() with pytest.raises(ValueError, match="request_template must contain all required placeholders"): ChunkedRequestAttack( @@ -199,7 +218,7 @@ def test_init_invalid_request_template_missing_objective(self): def test_init_invalid_request_template_missing_multiple(self): """Test that request_template without multiple placeholders raises ValueError.""" - mock_target = Mock() + mock_target = _make_mock_target() with pytest.raises(ValueError, match="request_template must contain all required placeholders"): ChunkedRequestAttack( @@ -209,7 +228,7 @@ def test_init_invalid_request_template_missing_multiple(self): def test_init_valid_request_template_with_extra_placeholders(self): """Test that request_template with extra placeholders is accepted.""" - mock_target = Mock() + mock_target = _make_mock_target() # Should not raise - extra placeholders are fine as long as required ones are present attack = ChunkedRequestAttack( @@ -221,7 +240,7 @@ def test_init_valid_request_template_with_extra_placeholders(self): def test_generate_chunk_prompts_with_objective(self): """Test that chunk prompts include the objective from context.""" - mock_target = Mock() + mock_target = _make_mock_target() attack = ChunkedRequestAttack( objective_target=mock_target, chunk_size=50, diff --git a/tests/unit/executor/attack/multi_turn/test_multi_prompt_sending.py b/tests/unit/executor/attack/multi_turn/test_multi_prompt_sending.py index bdf70a77d2..dc3e19952c 100644 --- a/tests/unit/executor/attack/multi_turn/test_multi_prompt_sending.py +++ b/tests/unit/executor/attack/multi_turn/test_multi_prompt_sending.py @@ -63,6 +63,7 @@ def mock_true_false_scorer(): """Create a mock true/false scorer for testing""" scorer = MagicMock(spec=TrueFalseScorer) scorer.score_async = AsyncMock() + scorer.get_identifier.return_value = _mock_scorer_id() return scorer @@ -70,6 +71,7 @@ def mock_true_false_scorer(): def mock_non_true_false_scorer(): """Create a mock scorer that is not a true/false type""" scorer = MagicMock(spec=Scorer) + scorer.get_identifier.return_value = _mock_scorer_id() return scorer @@ -162,7 +164,8 @@ def test_init_with_valid_true_false_scorer(self, mock_target, mock_true_false_sc def test_init_with_all_custom_configurations(self, mock_target, mock_true_false_scorer, mock_prompt_normalizer): converter_cfg = AttackConverterConfig( - request_converters=[Base64Converter()], response_converters=[StringJoinConverter()] + request_converters=[PromptConverterConfiguration(converters=[Base64Converter()])], + response_converters=[PromptConverterConfiguration(converters=[StringJoinConverter()])], ) scoring_cfg = AttackScoringConfig(objective_scorer=mock_true_false_scorer) @@ -604,7 +607,9 @@ class TestConverterIntegration: async def test_perform_attack_with_converters( self, mock_target, mock_prompt_normalizer, basic_context, sample_response ): - converter_config = AttackConverterConfig(request_converters=[Base64Converter()]) + converter_config = AttackConverterConfig( + request_converters=[PromptConverterConfiguration(converters=[Base64Converter()])] + ) mock_prompt_normalizer.send_prompt_async.return_value = sample_response attack = MultiPromptSendingAttack( @@ -623,7 +628,9 @@ async def test_perform_attack_with_converters( async def test_perform_attack_with_response_converters( self, mock_target, mock_prompt_normalizer, basic_context, sample_response ): - converter_config = AttackConverterConfig(response_converters=[StringJoinConverter()]) + converter_config = AttackConverterConfig( + response_converters=[PromptConverterConfiguration(converters=[StringJoinConverter()])] + ) mock_prompt_normalizer.send_prompt_async.return_value = sample_response attack = MultiPromptSendingAttack( @@ -683,11 +690,13 @@ async def test_perform_attack_with_single_prompt(self, mock_target, mock_prompt_ assert result.last_response is not None assert mock_prompt_normalizer.send_prompt_async.call_count == 1 - def test_attack_has_unique_identifier(self, mock_target): + def test_attack_has_same_identifier_for_same_config(self, mock_target): attack1 = MultiPromptSendingAttack(objective_target=mock_target) attack2 = MultiPromptSendingAttack(objective_target=mock_target) - assert attack1.get_identifier() != attack2.get_identifier() + # Same config produces the same deterministic identifier + assert attack1.get_identifier().hash == attack2.get_identifier().hash + assert attack1.get_identifier().class_name == "MultiPromptSendingAttack" @pytest.mark.asyncio async def test_teardown_async_is_noop(self, mock_target, basic_context): diff --git a/tests/unit/executor/attack/multi_turn/test_red_teaming.py b/tests/unit/executor/attack/multi_turn/test_red_teaming.py index 340dc10c60..e6ed898b4a 100644 --- a/tests/unit/executor/attack/multi_turn/test_red_teaming.py +++ b/tests/unit/executor/attack/multi_turn/test_red_teaming.py @@ -1230,7 +1230,9 @@ async def test_perform_attack_with_message_bypasses_adversarial_chat_on_first_tu ): """Test that providing a message parameter bypasses adversarial chat generation on first turn.""" adversarial_config = AttackAdversarialConfig(target=mock_adversarial_chat) - scoring_config = AttackScoringConfig(objective_scorer=MagicMock(spec=TrueFalseScorer)) + inline_scorer = MagicMock(spec=TrueFalseScorer) + inline_scorer.get_identifier.return_value = _mock_scorer_id() + scoring_config = AttackScoringConfig(objective_scorer=inline_scorer) attack = RedTeamingAttack( objective_target=mock_objective_target, @@ -1272,7 +1274,9 @@ async def test_perform_attack_with_multi_piece_message_uses_first_piece( ): """Test that multi-piece messages use only the first piece's converted_value.""" adversarial_config = AttackAdversarialConfig(target=mock_adversarial_chat) - scoring_config = AttackScoringConfig(objective_scorer=MagicMock(spec=TrueFalseScorer)) + inline_scorer = MagicMock(spec=TrueFalseScorer) + inline_scorer.get_identifier.return_value = _mock_scorer_id() + scoring_config = AttackScoringConfig(objective_scorer=inline_scorer) attack = RedTeamingAttack( objective_target=mock_objective_target, diff --git a/tests/unit/executor/attack/single_turn/test_context_compliance.py b/tests/unit/executor/attack/single_turn/test_context_compliance.py index 3908cb89fe..64975ba3a7 100644 --- a/tests/unit/executor/attack/single_turn/test_context_compliance.py +++ b/tests/unit/executor/attack/single_turn/test_context_compliance.py @@ -15,7 +15,7 @@ ContextComplianceAttack, SingleTurnAttackContext, ) -from pyrit.identifiers import TargetIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import ( Message, MessagePiece, @@ -37,6 +37,16 @@ def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: ) +def _mock_scorer_id(name: str = "MockScorer") -> ScorerIdentifier: + """Helper to create ScorerIdentifier for tests.""" + return ScorerIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_objective_target(): """Create a mock PromptChatTarget for testing""" @@ -68,6 +78,7 @@ def mock_scorer(): """Create a mock true/false scorer""" scorer = MagicMock(spec=TrueFalseScorer) scorer.score_text_async = AsyncMock() + scorer.get_identifier.return_value = _mock_scorer_id() return scorer diff --git a/tests/unit/executor/attack/single_turn/test_flip_attack.py b/tests/unit/executor/attack/single_turn/test_flip_attack.py index faffd77099..53caecd71a 100644 --- a/tests/unit/executor/attack/single_turn/test_flip_attack.py +++ b/tests/unit/executor/attack/single_turn/test_flip_attack.py @@ -13,7 +13,7 @@ FlipAttack, SingleTurnAttackContext, ) -from pyrit.identifiers import TargetIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import ( AttackOutcome, AttackResult, @@ -34,6 +34,16 @@ def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: ) +def _mock_scorer_id(name: str = "MockScorer") -> ScorerIdentifier: + """Helper to create ScorerIdentifier for tests.""" + return ScorerIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_objective_target(): """Create a mock PromptChatTarget for testing""" @@ -54,6 +64,7 @@ def mock_scorer(): """Create a mock true/false scorer""" scorer = MagicMock(spec=TrueFalseScorer) scorer.score_text_async = AsyncMock() + scorer.get_identifier.return_value = _mock_scorer_id() return scorer diff --git a/tests/unit/executor/attack/single_turn/test_many_shot_jailbreak.py b/tests/unit/executor/attack/single_turn/test_many_shot_jailbreak.py index 9f9877c325..c15038c5da 100644 --- a/tests/unit/executor/attack/single_turn/test_many_shot_jailbreak.py +++ b/tests/unit/executor/attack/single_turn/test_many_shot_jailbreak.py @@ -13,7 +13,7 @@ ManyShotJailbreakAttack, SingleTurnAttackContext, ) -from pyrit.identifiers import TargetIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import ( AttackOutcome, AttackResult, @@ -35,6 +35,16 @@ def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: ) +def _mock_scorer_id(name: str = "MockScorer") -> ScorerIdentifier: + """Helper to create ScorerIdentifier for tests.""" + return ScorerIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_objective_target(): """Create a mock PromptTarget for testing""" @@ -67,6 +77,7 @@ def mock_scorer(): """Create a mock true/false scorer""" scorer = MagicMock(spec=TrueFalseScorer) scorer.score_text_async = AsyncMock() + scorer.get_identifier.return_value = _mock_scorer_id() return scorer diff --git a/tests/unit/executor/attack/single_turn/test_prompt_sending.py b/tests/unit/executor/attack/single_turn/test_prompt_sending.py index b55da7cf3c..1e1b708249 100644 --- a/tests/unit/executor/attack/single_turn/test_prompt_sending.py +++ b/tests/unit/executor/attack/single_turn/test_prompt_sending.py @@ -44,6 +44,7 @@ def mock_true_false_scorer(): """Create a mock true/false scorer for testing""" scorer = MagicMock(spec=TrueFalseScorer) scorer.score_text_async = AsyncMock() + scorer.get_identifier.return_value = get_mock_scorer_identifier() return scorer @@ -51,6 +52,7 @@ def mock_true_false_scorer(): def mock_non_true_false_scorer(): """Create a mock scorer that is not a true/false type""" scorer = MagicMock(spec=Scorer) + scorer.get_identifier.return_value = get_mock_scorer_identifier() return scorer @@ -1146,13 +1148,13 @@ def test_attack_has_unique_identifier(self, mock_target): id2 = attack2.get_identifier() # Verify identifier structure - assert "__type__" in id1 - assert "__module__" in id1 - assert "id" in id1 + assert id1.class_name == "PromptSendingAttack" + assert id1.class_module is not None + assert id1.hash is not None - # Verify uniqueness - assert id1["id"] != id2["id"] - assert id1["__type__"] == id2["__type__"] == "PromptSendingAttack" + # Same config produces same identifier + assert id1.hash == id2.hash + assert id1.class_name == id2.class_name == "PromptSendingAttack" @pytest.mark.asyncio async def test_retry_stores_unsuccessful_conversation_and_updates_id( diff --git a/tests/unit/executor/attack/single_turn/test_role_play.py b/tests/unit/executor/attack/single_turn/test_role_play.py index 98218e3c9b..114e99bdf7 100644 --- a/tests/unit/executor/attack/single_turn/test_role_play.py +++ b/tests/unit/executor/attack/single_turn/test_role_play.py @@ -51,6 +51,7 @@ def mock_scorer(): """Create a mock true/false scorer for testing""" scorer = MagicMock(spec=TrueFalseScorer) scorer.score_text_async = AsyncMock() + scorer.get_identifier.return_value = get_mock_scorer_identifier() return scorer diff --git a/tests/unit/executor/attack/single_turn/test_skeleton_key.py b/tests/unit/executor/attack/single_turn/test_skeleton_key.py index 41278e9bdf..f2bd77ec12 100644 --- a/tests/unit/executor/attack/single_turn/test_skeleton_key.py +++ b/tests/unit/executor/attack/single_turn/test_skeleton_key.py @@ -41,6 +41,7 @@ def mock_true_false_scorer(): """Create a mock true/false scorer for testing""" scorer = MagicMock(spec=TrueFalseScorer) scorer.score_text_async = AsyncMock() + scorer.get_identifier.return_value = get_mock_scorer_identifier() return scorer diff --git a/tests/unit/executor/benchmark/test_fairness_bias.py b/tests/unit/executor/benchmark/test_fairness_bias.py index 8ad868d813..b23e9cf0a1 100644 --- a/tests/unit/executor/benchmark/test_fairness_bias.py +++ b/tests/unit/executor/benchmark/test_fairness_bias.py @@ -10,6 +10,7 @@ FairnessBiasBenchmark, FairnessBiasBenchmarkContext, ) +from pyrit.identifiers import TargetIdentifier from pyrit.models import ( AttackOutcome, AttackResult, @@ -29,10 +30,23 @@ def is_spacy_installed(): # Fixtures at the top of the file + + +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_prompt_target() -> MagicMock: """Mock prompt target for testing.""" target = MagicMock(spec=PromptTarget) + target.get_identifier.return_value = _mock_target_id("mock_prompt_target") return target @@ -64,7 +78,6 @@ def sample_attack_result() -> AttackResult: result = AttackResult( conversation_id="test-conversation-id", objective="Test objective", - attack_identifier={"name": "fairness_bias_benchmark"}, executed_turns=1, execution_time_ms=1000, outcome=AttackOutcome.SUCCESS, diff --git a/tests/unit/executor/benchmark/test_question_answering.py b/tests/unit/executor/benchmark/test_question_answering.py index 5e5af568c9..0233de781b 100644 --- a/tests/unit/executor/benchmark/test_question_answering.py +++ b/tests/unit/executor/benchmark/test_question_answering.py @@ -10,6 +10,7 @@ QuestionAnsweringBenchmark, QuestionAnsweringBenchmarkContext, ) +from pyrit.identifiers import TargetIdentifier from pyrit.models import ( AttackOutcome, AttackResult, @@ -20,12 +21,24 @@ ) from pyrit.prompt_target import PromptTarget - # Fixtures at the top of the file + + +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_prompt_target() -> MagicMock: """Mock prompt target for testing.""" target = MagicMock(spec=PromptTarget) + target.get_identifier.return_value = _mock_target_id("mock_prompt_target") return target diff --git a/tests/unit/executor/core/test_strategy.py b/tests/unit/executor/core/test_strategy.py index ba300afbc2..3b831ecadd 100644 --- a/tests/unit/executor/core/test_strategy.py +++ b/tests/unit/executor/core/test_strategy.py @@ -11,6 +11,7 @@ execution_context, ) from pyrit.executor.core.strategy import Strategy, StrategyContext +from pyrit.identifiers import ScorerIdentifier @dataclass @@ -195,7 +196,10 @@ async def test_error_includes_component_identifier(self): with execution_context( component_role=ComponentRole.OBJECTIVE_SCORER, attack_strategy_name="TestAttack", - component_identifier={"__type__": "SelfAskTrueFalseScorer"}, + component_identifier=ScorerIdentifier( + class_name="SelfAskTrueFalseScorer", + class_module="pyrit.score.true_false.self_ask_true_false_scorer", + ), ): await strategy.execute_with_context_async(context=context) diff --git a/tests/unit/executor/promptgen/fuzzer/test_fuzzer_converter.py b/tests/unit/executor/promptgen/fuzzer/test_fuzzer_converter.py index 25e49a1c10..3be6513502 100644 --- a/tests/unit/executor/promptgen/fuzzer/test_fuzzer_converter.py +++ b/tests/unit/executor/promptgen/fuzzer/test_fuzzer_converter.py @@ -14,6 +14,7 @@ FuzzerShortenConverter, FuzzerSimilarConverter, ) +from pyrit.identifiers import AttackIdentifier, TargetIdentifier from pyrit.models import Message, MessagePiece @@ -90,8 +91,8 @@ async def test_converter_send_prompt_async_bad_json_exception_retries( converted_value=converted_value, original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier={"target": "target-identifier"}, - attack_identifier={"test": "test"}, + prompt_target_identifier=TargetIdentifier(class_name="target-identifier", class_module="test"), + attack_identifier=AttackIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ) ] diff --git a/tests/unit/executor/promptgen/test_anecdoctor.py b/tests/unit/executor/promptgen/test_anecdoctor.py index 78dbfd1725..e841c0c1ed 100644 --- a/tests/unit/executor/promptgen/test_anecdoctor.py +++ b/tests/unit/executor/promptgen/test_anecdoctor.py @@ -13,16 +13,28 @@ AnecdoctorGenerator, AnecdoctorResult, ) +from pyrit.identifiers import TargetIdentifier from pyrit.models import Message from pyrit.prompt_normalizer import PromptNormalizer from pyrit.prompt_target import PromptChatTarget +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_objective_target() -> PromptChatTarget: """Create a mock objective target for testing.""" mock_target = MagicMock(spec=PromptChatTarget) mock_target.set_system_prompt = MagicMock() + mock_target.get_identifier.return_value = _mock_target_id("mock_objective_target") return mock_target @@ -31,6 +43,7 @@ def mock_processing_model() -> PromptChatTarget: """Create a mock processing model for testing.""" mock_model = MagicMock(spec=PromptChatTarget) mock_model.set_system_prompt = MagicMock() + mock_model.get_identifier.return_value = _mock_target_id("MockProcessingModel") return mock_model diff --git a/tests/unit/executor/workflow/test_xpia.py b/tests/unit/executor/workflow/test_xpia.py index d25cd00f5c..baede625bf 100644 --- a/tests/unit/executor/workflow/test_xpia.py +++ b/tests/unit/executor/workflow/test_xpia.py @@ -12,17 +12,40 @@ XPIAStatus, XPIAWorkflow, ) +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import Message, MessagePiece, Score from pyrit.prompt_normalizer import PromptNormalizer from pyrit.prompt_target import PromptTarget from pyrit.score import Scorer - # Shared fixtures for all test classes + + +def _mock_scorer_id(name: str = "MockScorer") -> ScorerIdentifier: + """Helper to create ScorerIdentifier for tests.""" + return ScorerIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + +def _mock_target_id(name: str = "MockTarget") -> TargetIdentifier: + """Helper to create TargetIdentifier for tests.""" + return TargetIdentifier( + class_name=name, + class_module="test_module", + class_description="", + identifier_type="instance", + ) + + @pytest.fixture def mock_attack_setup_target() -> MagicMock: """Create a mock attack setup target.""" target = MagicMock(spec=PromptTarget) + target.get_identifier.return_value = _mock_target_id("mock_attack_setup_target") return target @@ -31,6 +54,7 @@ def mock_scorer() -> MagicMock: """Create a mock scorer.""" scorer = MagicMock(spec=Scorer) scorer.score_text_async = AsyncMock() + scorer.get_identifier.return_value = _mock_scorer_id() return scorer diff --git a/tests/unit/identifiers/test_attack_identifier.py b/tests/unit/identifiers/test_attack_identifier.py new file mode 100644 index 0000000000..b20768b2db --- /dev/null +++ b/tests/unit/identifiers/test_attack_identifier.py @@ -0,0 +1,290 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for AttackIdentifier-specific functionality. + +Note: Base Identifier functionality (hash computation, to_dict/from_dict basics, +frozen/hashable properties) is tested via ScorerIdentifier in test_scorer_identifier.py. +These tests focus on AttackIdentifier-specific fields and from_dict deserialization +of nested sub-identifiers. +""" + +import pytest + +from pyrit.identifiers import AttackIdentifier, ConverterIdentifier, ScorerIdentifier, TargetIdentifier + + +def _make_target_identifier() -> TargetIdentifier: + return TargetIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target.openai.openai_chat_target", + class_description="OpenAI chat target", + identifier_type="instance", + endpoint="https://api.openai.com/v1", + model_name="gpt-4o", + ) + + +def _make_scorer_identifier() -> ScorerIdentifier: + return ScorerIdentifier( + class_name="SelfAskTrueFalseScorer", + class_module="pyrit.score.true_false.self_ask_true_false_scorer", + class_description="True/false scorer", + identifier_type="instance", + ) + + +def _make_converter_identifier() -> ConverterIdentifier: + return ConverterIdentifier( + class_name="Base64Converter", + class_module="pyrit.prompt_converter.base64_converter", + class_description="Base64 converter", + identifier_type="instance", + supported_input_types=["text"], + supported_output_types=["text"], + ) + + +class TestAttackIdentifierCreation: + """Test basic AttackIdentifier creation.""" + + def test_creation_minimal(self): + """Test creating an AttackIdentifier with only base fields.""" + identifier = AttackIdentifier( + class_name="PromptSendingAttack", + class_module="pyrit.executor.attack.single_turn.prompt_sending", + ) + + assert identifier.class_name == "PromptSendingAttack" + assert identifier.objective_target_identifier is None + assert identifier.objective_scorer_identifier is None + assert identifier.request_converter_identifiers is None + assert identifier.response_converter_identifiers is None + assert identifier.attack_specific_params is None + assert identifier.hash is not None + + def test_creation_all_fields(self): + """Test creating an AttackIdentifier with all sub-identifiers.""" + target_id = _make_target_identifier() + scorer_id = _make_scorer_identifier() + converter_id = _make_converter_identifier() + + identifier = AttackIdentifier( + class_name="CrescendoAttack", + class_module="pyrit.executor.attack.multi_turn.crescendo", + objective_target_identifier=target_id, + objective_scorer_identifier=scorer_id, + request_converter_identifiers=[converter_id], + response_converter_identifiers=[converter_id], + attack_specific_params={"max_turns": 10}, + ) + + assert identifier.objective_target_identifier is target_id + assert identifier.objective_scorer_identifier is scorer_id + assert identifier.request_converter_identifiers == [converter_id] + assert identifier.response_converter_identifiers == [converter_id] + assert identifier.attack_specific_params == {"max_turns": 10} + + def test_frozen(self): + """Test that AttackIdentifier is immutable.""" + identifier = AttackIdentifier( + class_name="PromptSendingAttack", + class_module="pyrit.executor.attack.single_turn.prompt_sending", + ) + + with pytest.raises(AttributeError): + identifier.class_name = "Other" # type: ignore[misc] + + def test_hashable(self): + """Test that AttackIdentifier can be used in sets/dicts.""" + identifier = AttackIdentifier( + class_name="PromptSendingAttack", + class_module="pyrit.executor.attack.single_turn.prompt_sending", + ) + # Should not raise + {identifier} + {identifier: 1} + + +class TestAttackIdentifierFromDict: + """Test AttackIdentifier.from_dict with nested sub-identifier deserialization.""" + + def test_from_dict_minimal(self): + """Test from_dict with no nested sub-identifiers.""" + data = { + "class_name": "PromptSendingAttack", + "class_module": "pyrit.executor.attack.single_turn.prompt_sending", + } + + result = AttackIdentifier.from_dict(data) + + assert isinstance(result, AttackIdentifier) + assert result.class_name == "PromptSendingAttack" + assert result.objective_target_identifier is None + assert result.objective_scorer_identifier is None + assert result.request_converter_identifiers is None + assert result.response_converter_identifiers is None + + def test_from_dict_deserializes_nested_target(self): + """Test that from_dict recursively deserializes the target sub-identifier.""" + target_id = _make_target_identifier() + data = { + "class_name": "PromptSendingAttack", + "class_module": "pyrit.executor.attack.single_turn.prompt_sending", + "objective_target_identifier": target_id.to_dict(), + } + + result = AttackIdentifier.from_dict(data) + + assert isinstance(result.objective_target_identifier, TargetIdentifier) + assert result.objective_target_identifier.class_name == "OpenAIChatTarget" + assert result.objective_target_identifier.endpoint == "https://api.openai.com/v1" + + def test_from_dict_deserializes_nested_scorer(self): + """Test that from_dict recursively deserializes the scorer sub-identifier.""" + scorer_id = _make_scorer_identifier() + data = { + "class_name": "CrescendoAttack", + "class_module": "pyrit.executor.attack.multi_turn.crescendo", + "objective_scorer_identifier": scorer_id.to_dict(), + } + + result = AttackIdentifier.from_dict(data) + + assert isinstance(result.objective_scorer_identifier, ScorerIdentifier) + assert result.objective_scorer_identifier.class_name == "SelfAskTrueFalseScorer" + + def test_from_dict_deserializes_nested_converters(self): + """Test that from_dict recursively deserializes converter sub-identifiers.""" + converter_id = _make_converter_identifier() + data = { + "class_name": "PromptSendingAttack", + "class_module": "pyrit.executor.attack.single_turn.prompt_sending", + "request_converter_identifiers": [converter_id.to_dict()], + } + + result = AttackIdentifier.from_dict(data) + + assert result.request_converter_identifiers is not None + assert len(result.request_converter_identifiers) == 1 + assert isinstance(result.request_converter_identifiers[0], ConverterIdentifier) + assert result.request_converter_identifiers[0].class_name == "Base64Converter" + + def test_from_dict_all_nested(self): + """Test from_dict with all nested sub-identifiers as dicts.""" + target_id = _make_target_identifier() + scorer_id = _make_scorer_identifier() + converter_id = _make_converter_identifier() + + data = { + "class_name": "CrescendoAttack", + "class_module": "pyrit.executor.attack.multi_turn.crescendo", + "objective_target_identifier": target_id.to_dict(), + "objective_scorer_identifier": scorer_id.to_dict(), + "request_converter_identifiers": [converter_id.to_dict()], + "attack_specific_params": {"max_turns": 10}, + } + + result = AttackIdentifier.from_dict(data) + + assert isinstance(result, AttackIdentifier) + assert isinstance(result.objective_target_identifier, TargetIdentifier) + assert isinstance(result.objective_scorer_identifier, ScorerIdentifier) + assert isinstance(result.request_converter_identifiers[0], ConverterIdentifier) + assert result.attack_specific_params == {"max_turns": 10} + + def test_from_dict_already_typed_sub_identifiers_not_re_parsed(self): + """Test that from_dict handles already-typed sub-identifiers without error.""" + target_id = _make_target_identifier() + converter_id = _make_converter_identifier() + + data = { + "class_name": "PromptSendingAttack", + "class_module": "pyrit.executor.attack.single_turn.prompt_sending", + "objective_target_identifier": target_id, # Already typed, not a dict + "request_converter_identifiers": [converter_id], # Already typed + } + + result = AttackIdentifier.from_dict(data) + + assert result.objective_target_identifier is target_id + assert result.request_converter_identifiers[0] is converter_id + + def test_from_dict_deserializes_response_converters(self): + """Test that from_dict recursively deserializes response converter sub-identifiers.""" + converter_id = _make_converter_identifier() + data = { + "class_name": "PromptSendingAttack", + "class_module": "pyrit.executor.attack.single_turn.prompt_sending", + "response_converter_identifiers": [converter_id.to_dict()], + } + + result = AttackIdentifier.from_dict(data) + + assert result.response_converter_identifiers is not None + assert len(result.response_converter_identifiers) == 1 + assert isinstance(result.response_converter_identifiers[0], ConverterIdentifier) + assert result.response_converter_identifiers[0].class_name == "Base64Converter" + + def test_from_dict_none_converters_stays_none(self): + """Test that None converter lists are preserved as None.""" + data = { + "class_name": "PromptSendingAttack", + "class_module": "pyrit.executor.attack.single_turn.prompt_sending", + "request_converter_identifiers": None, + "response_converter_identifiers": None, + } + + result = AttackIdentifier.from_dict(data) + assert result.request_converter_identifiers is None + assert result.response_converter_identifiers is None + + +class TestAttackIdentifierRoundTrip: + """Test to_dict → from_dict round-trip fidelity.""" + + def test_round_trip_minimal(self): + """Test round-trip with minimal fields.""" + original = AttackIdentifier( + class_name="PromptSendingAttack", + class_module="pyrit.executor.attack.single_turn.prompt_sending", + ) + + restored = AttackIdentifier.from_dict(original.to_dict()) + + assert restored.class_name == original.class_name + assert restored.class_module == original.class_module + assert restored.hash == original.hash + + def test_round_trip_with_nested_identifiers(self): + """Test round-trip preserves nested sub-identifiers.""" + original = AttackIdentifier( + class_name="CrescendoAttack", + class_module="pyrit.executor.attack.multi_turn.crescendo", + objective_target_identifier=_make_target_identifier(), + objective_scorer_identifier=_make_scorer_identifier(), + request_converter_identifiers=[_make_converter_identifier()], + response_converter_identifiers=[_make_converter_identifier()], + ) + + restored = AttackIdentifier.from_dict(original.to_dict()) + + assert isinstance(restored.objective_target_identifier, TargetIdentifier) + assert isinstance(restored.objective_scorer_identifier, ScorerIdentifier) + assert isinstance(restored.request_converter_identifiers[0], ConverterIdentifier) + assert isinstance(restored.response_converter_identifiers[0], ConverterIdentifier) + assert restored.hash == original.hash + + def test_round_trip_with_attack_specific_params(self): + """Test round-trip preserves attack_specific_params.""" + original = AttackIdentifier( + class_name="TreeOfAttacks", + class_module="pyrit.executor.attack.multi_turn.tree_of_attacks", + attack_specific_params={"width": 3, "depth": 5, "pruning": True}, + ) + + restored = AttackIdentifier.from_dict(original.to_dict()) + + assert restored.attack_specific_params == {"width": 3, "depth": 5, "pruning": True} + assert restored.hash == original.hash diff --git a/tests/unit/identifiers/test_identifiers.py b/tests/unit/identifiers/test_identifiers.py index a9e373a7f7..64e1da8904 100644 --- a/tests/unit/identifiers/test_identifiers.py +++ b/tests/unit/identifiers/test_identifiers.py @@ -6,36 +6,10 @@ import pytest import pyrit -from pyrit.identifiers import Identifier, LegacyIdentifiable +from pyrit.identifiers import Identifier from pyrit.identifiers.identifier import _EXCLUDE, _ExcludeFrom, _expand_exclusions -class TestLegacyIdentifiable: - """Tests for the LegacyIdentifiable abstract base class.""" - - def test_legacy_identifiable_get_identifier_is_abstract(self): - """Test that get_identifier is an abstract method that must be implemented.""" - - class ConcreteLegacyIdentifiable(LegacyIdentifiable): - def get_identifier(self) -> dict[str, str]: - return {"type": "test", "name": "example"} - - obj = ConcreteLegacyIdentifiable() - result = obj.get_identifier() - assert result == {"type": "test", "name": "example"} - - def test_legacy_identifiable_str_returns_identifier_dict(self): - """Test that __str__ returns the get_identifier() result as a string.""" - - class ConcreteLegacyIdentifiable(LegacyIdentifiable): - def get_identifier(self) -> dict[str, str]: - return {"type": "test"} - - obj = ConcreteLegacyIdentifiable() - # __str__ returns the identifier dict as a string - assert str(obj) == "{'type': 'test'}" - - class TestIdentifier: """Tests for the Identifier dataclass.""" diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index be2fa4c64d..56813d4417 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -6,7 +6,7 @@ from typing import Sequence from pyrit.common.utils import to_sha256 -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import AttackIdentifier, ScorerIdentifier from pyrit.memory import MemoryInterface from pyrit.memory.memory_models import AttackResultEntry from pyrit.models import ( @@ -36,7 +36,6 @@ def create_attack_result(conversation_id: str, objective_num: int, outcome: Atta return AttackResult( conversation_id=conversation_id, objective=f"Objective {objective_num}", - attack_identifier={"name": "test_attack"}, outcome=outcome, ) @@ -47,7 +46,6 @@ def test_add_attack_results_to_memory(sqlite_instance: MemoryInterface): attack_result1 = AttackResult( conversation_id="conv_1", objective="Test objective 1", - attack_identifier={"name": "test_attack_1", "module": "test_module"}, executed_turns=5, execution_time_ms=1000, outcome=AttackOutcome.SUCCESS, @@ -58,7 +56,6 @@ def test_add_attack_results_to_memory(sqlite_instance: MemoryInterface): attack_result2 = AttackResult( conversation_id="conv_2", objective="Test objective 2", - attack_identifier={"name": "test_attack_2", "module": "test_module"}, executed_turns=3, execution_time_ms=500, outcome=AttackOutcome.FAILURE, @@ -85,7 +82,6 @@ def test_get_attack_results_by_ids(sqlite_instance: MemoryInterface): attack_result1 = AttackResult( conversation_id="conv_1", objective="Test objective 1", - attack_identifier={"name": "test_attack_1"}, executed_turns=5, execution_time_ms=1000, outcome=AttackOutcome.SUCCESS, @@ -94,7 +90,6 @@ def test_get_attack_results_by_ids(sqlite_instance: MemoryInterface): attack_result2 = AttackResult( conversation_id="conv_2", objective="Test objective 2", - attack_identifier={"name": "test_attack_2"}, executed_turns=3, execution_time_ms=500, outcome=AttackOutcome.FAILURE, @@ -103,7 +98,6 @@ def test_get_attack_results_by_ids(sqlite_instance: MemoryInterface): attack_result3 = AttackResult( conversation_id="conv_3", objective="Test objective 3", - attack_identifier={"name": "test_attack_3"}, executed_turns=7, execution_time_ms=1500, outcome=AttackOutcome.UNDETERMINED, @@ -134,7 +128,6 @@ def test_get_attack_results_by_conversation_id(sqlite_instance: MemoryInterface) attack_result1 = AttackResult( conversation_id="conv_1", objective="Test objective 1", - attack_identifier={"name": "test_attack_1"}, executed_turns=5, execution_time_ms=1000, outcome=AttackOutcome.SUCCESS, @@ -143,7 +136,6 @@ def test_get_attack_results_by_conversation_id(sqlite_instance: MemoryInterface) attack_result2 = AttackResult( conversation_id="conv_1", # Same conversation ID objective="Test objective 2", - attack_identifier={"name": "test_attack_2"}, executed_turns=3, execution_time_ms=500, outcome=AttackOutcome.FAILURE, @@ -152,7 +144,6 @@ def test_get_attack_results_by_conversation_id(sqlite_instance: MemoryInterface) attack_result3 = AttackResult( conversation_id="conv_2", # Different conversation ID objective="Test objective 3", - attack_identifier={"name": "test_attack_3"}, executed_turns=7, execution_time_ms=1500, outcome=AttackOutcome.UNDETERMINED, @@ -176,7 +167,6 @@ def test_get_attack_results_by_objective(sqlite_instance: MemoryInterface): attack_result1 = AttackResult( conversation_id="conv_1", objective="Test objective for success", - attack_identifier={"name": "test_attack_1"}, executed_turns=5, execution_time_ms=1000, outcome=AttackOutcome.SUCCESS, @@ -185,7 +175,6 @@ def test_get_attack_results_by_objective(sqlite_instance: MemoryInterface): attack_result2 = AttackResult( conversation_id="conv_2", objective="Another objective for failure", - attack_identifier={"name": "test_attack_2"}, executed_turns=3, execution_time_ms=500, outcome=AttackOutcome.FAILURE, @@ -194,7 +183,6 @@ def test_get_attack_results_by_objective(sqlite_instance: MemoryInterface): attack_result3 = AttackResult( conversation_id="conv_3", objective="Different objective entirely", - attack_identifier={"name": "test_attack_3"}, executed_turns=7, execution_time_ms=1500, outcome=AttackOutcome.UNDETERMINED, @@ -219,7 +207,6 @@ def test_get_attack_results_by_outcome(sqlite_instance: MemoryInterface): attack_result1 = AttackResult( conversation_id="conv_1", objective="Test objective 1", - attack_identifier={"name": "test_attack_1"}, executed_turns=5, execution_time_ms=1000, outcome=AttackOutcome.SUCCESS, @@ -228,7 +215,6 @@ def test_get_attack_results_by_outcome(sqlite_instance: MemoryInterface): attack_result2 = AttackResult( conversation_id="conv_2", objective="Test objective 2", - attack_identifier={"name": "test_attack_2"}, executed_turns=3, execution_time_ms=500, outcome=AttackOutcome.SUCCESS, # Same outcome @@ -237,7 +223,6 @@ def test_get_attack_results_by_outcome(sqlite_instance: MemoryInterface): attack_result3 = AttackResult( conversation_id="conv_3", objective="Test objective 3", - attack_identifier={"name": "test_attack_3"}, executed_turns=7, execution_time_ms=1500, outcome=AttackOutcome.FAILURE, # Different outcome @@ -267,7 +252,6 @@ def test_get_attack_results_by_objective_sha256(sqlite_instance: MemoryInterface attack_result1 = AttackResult( conversation_id="conv_1", objective=objective1, - attack_identifier={"name": "test_attack"}, executed_turns=5, execution_time_ms=1000, outcome=AttackOutcome.SUCCESS, @@ -277,7 +261,6 @@ def test_get_attack_results_by_objective_sha256(sqlite_instance: MemoryInterface attack_result2 = AttackResult( conversation_id="conv_2", objective=objective2, - attack_identifier={"name": "test_attack"}, executed_turns=3, execution_time_ms=500, outcome=AttackOutcome.FAILURE, @@ -287,7 +270,6 @@ def test_get_attack_results_by_objective_sha256(sqlite_instance: MemoryInterface attack_result3 = AttackResult( conversation_id="conv_3", objective=objective3, - attack_identifier={"name": "test_attack"}, executed_turns=7, execution_time_ms=1500, outcome=AttackOutcome.UNDETERMINED, @@ -312,7 +294,6 @@ def test_get_attack_results_multiple_filters(sqlite_instance: MemoryInterface): attack_result1 = AttackResult( conversation_id="conv_1", objective="Test objective for success", - attack_identifier={"name": "test_attack_1"}, executed_turns=5, execution_time_ms=1000, outcome=AttackOutcome.SUCCESS, @@ -321,7 +302,6 @@ def test_get_attack_results_multiple_filters(sqlite_instance: MemoryInterface): attack_result2 = AttackResult( conversation_id="conv_1", # Same conversation ID objective="Another objective for failure", - attack_identifier={"name": "test_attack_2"}, executed_turns=3, execution_time_ms=500, outcome=AttackOutcome.FAILURE, # Different outcome @@ -330,7 +310,6 @@ def test_get_attack_results_multiple_filters(sqlite_instance: MemoryInterface): attack_result3 = AttackResult( conversation_id="conv_2", # Different conversation ID objective="Test objective for success", - attack_identifier={"name": "test_attack_3"}, executed_turns=7, execution_time_ms=1500, outcome=AttackOutcome.SUCCESS, @@ -357,7 +336,6 @@ def test_get_attack_results_no_filters(sqlite_instance: MemoryInterface): attack_result1 = AttackResult( conversation_id="conv_1", objective="Test objective 1", - attack_identifier={"name": "test_attack_1"}, executed_turns=5, execution_time_ms=1000, outcome=AttackOutcome.SUCCESS, @@ -366,7 +344,6 @@ def test_get_attack_results_no_filters(sqlite_instance: MemoryInterface): attack_result2 = AttackResult( conversation_id="conv_2", objective="Test objective 2", - attack_identifier={"name": "test_attack_2"}, executed_turns=3, execution_time_ms=500, outcome=AttackOutcome.FAILURE, @@ -388,7 +365,6 @@ def test_get_attack_results_empty_list(sqlite_instance: MemoryInterface): attack_result = AttackResult( conversation_id="conv_1", objective="Test objective", - attack_identifier={"name": "test_attack"}, executed_turns=5, execution_time_ms=1000, outcome=AttackOutcome.SUCCESS, @@ -407,7 +383,6 @@ def test_get_attack_results_nonexistent_ids(sqlite_instance: MemoryInterface): attack_result = AttackResult( conversation_id="conv_1", objective="Test objective", - attack_identifier={"name": "test_attack"}, executed_turns=5, execution_time_ms=1000, outcome=AttackOutcome.SUCCESS, @@ -457,7 +432,6 @@ def test_attack_result_with_last_response_and_score(sqlite_instance: MemoryInter attack_result = AttackResult( conversation_id="conv_1", objective="Test objective with relationships", - attack_identifier={"name": "test_attack"}, last_response=message_piece, last_score=score, executed_turns=5, @@ -487,7 +461,7 @@ def test_attack_result_all_outcomes(sqlite_instance: MemoryInterface): attack_result = AttackResult( conversation_id=f"conv_{i}", objective=f"Test objective {i}", - attack_identifier={"name": f"test_attack_{i}"}, + attack_identifier=AttackIdentifier(class_name=f"TestAttack{i}", class_module="test.module"), executed_turns=i + 1, execution_time_ms=(i + 1) * 100, outcome=outcome, @@ -523,7 +497,6 @@ def test_attack_result_metadata_handling(sqlite_instance: MemoryInterface): attack_result = AttackResult( conversation_id="conv_1", objective="Test objective with metadata", - attack_identifier={"name": "test_attack"}, executed_turns=5, execution_time_ms=1000, outcome=AttackOutcome.SUCCESS, @@ -547,7 +520,6 @@ def test_attack_result_objective_sha256_auto_generation(sqlite_instance: MemoryI attack_result = AttackResult( conversation_id="conv_1", objective=objective, - attack_identifier={"name": "test_attack"}, executed_turns=5, execution_time_ms=1000, outcome=AttackOutcome.SUCCESS, @@ -577,7 +549,6 @@ def test_attack_result_with_attack_generation_conversation_ids(sqlite_instance: attack_result = AttackResult( conversation_id="conv_1", objective="Test objective with conversation IDs", - attack_identifier={"name": "test_attack"}, executed_turns=5, execution_time_ms=1000, outcome=AttackOutcome.SUCCESS, @@ -605,7 +576,6 @@ def test_attack_result_without_attack_generation_conversation_ids(sqlite_instanc attack_result = AttackResult( conversation_id="conv_1", objective="Test objective without conversation IDs", - attack_identifier={"name": "test_attack"}, executed_turns=5, execution_time_ms=1000, outcome=AttackOutcome.SUCCESS, diff --git a/tests/unit/memory/memory_interface/test_interface_export.py b/tests/unit/memory/memory_interface/test_interface_export.py index 16b5847522..8fe3b2fa50 100644 --- a/tests/unit/memory/memory_interface/test_interface_export.py +++ b/tests/unit/memory/memory_interface/test_interface_export.py @@ -15,7 +15,7 @@ def test_export_conversation_by_attack_id_file_created( sqlite_instance: MemoryInterface, sample_conversations: Sequence[MessagePiece] ): - attack1_id = sample_conversations[0].attack_identifier["id"] + attack1_id = sample_conversations[0].attack_identifier.hash # Default path in export_conversations() file_name = f"{attack1_id}.json" diff --git a/tests/unit/memory/memory_interface/test_interface_prompts.py b/tests/unit/memory/memory_interface/test_interface_prompts.py index a6b9b4fc05..1c064da99c 100644 --- a/tests/unit/memory/memory_interface/test_interface_prompts.py +++ b/tests/unit/memory/memory_interface/test_interface_prompts.py @@ -9,6 +9,7 @@ from uuid import uuid4 import pytest +from unit.mocks import get_mock_target from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack from pyrit.identifiers import ScorerIdentifier @@ -110,8 +111,8 @@ def test_get_message_pieces_uuid_and_string_ids(sqlite_instance: MemoryInterface def test_duplicate_memory(sqlite_instance: MemoryInterface): - attack1 = PromptSendingAttack(objective_target=MagicMock()) - attack2 = PromptSendingAttack(objective_target=MagicMock()) + attack1 = PromptSendingAttack(objective_target=get_mock_target()) + attack2 = PromptSendingAttack(objective_target=get_mock_target("Target2")) conversation_id_1 = "11111" conversation_id_2 = "22222" conversation_id_3 = "33333" @@ -167,8 +168,8 @@ def test_duplicate_memory(sqlite_instance: MemoryInterface): all_pieces = sqlite_instance.get_message_pieces() assert len(all_pieces) == 9 # Attack IDs are preserved (not changed) when duplicating - assert len([p for p in all_pieces if p.attack_identifier["id"] == attack1.get_identifier()["id"]]) == 8 - assert len([p for p in all_pieces if p.attack_identifier["id"] == attack2.get_identifier()["id"]]) == 1 + assert len([p for p in all_pieces if p.attack_identifier.hash == attack1.get_identifier().hash]) == 8 + assert len([p for p in all_pieces if p.attack_identifier.hash == attack2.get_identifier().hash]) == 1 assert len([p for p in all_pieces if p.conversation_id == conversation_id_1]) == 2 assert len([p for p in all_pieces if p.conversation_id == conversation_id_2]) == 2 assert len([p for p in all_pieces if p.conversation_id == conversation_id_3]) == 1 @@ -181,7 +182,7 @@ def test_duplicate_conversation_pieces_not_score(sqlite_instance: MemoryInterfac conversation_id = str(uuid4()) prompt_id_1 = uuid4() prompt_id_2 = uuid4() - attack1 = PromptSendingAttack(objective_target=MagicMock()) + attack1 = PromptSendingAttack(objective_target=get_mock_target()) memory_labels = {"sample": "label"} pieces = [ MessagePiece( @@ -245,15 +246,15 @@ def test_duplicate_conversation_pieces_not_score(sqlite_instance: MemoryInterfac assert piece.id not in (prompt_id_1, prompt_id_2) assert len(sqlite_instance.get_prompt_scores(labels=memory_labels)) == 2 # Attack ID is preserved, so both original and duplicated pieces have the same attack ID - assert len(sqlite_instance.get_prompt_scores(attack_id=attack1.get_identifier()["id"])) == 2 + assert len(sqlite_instance.get_prompt_scores(attack_id=attack1.get_identifier().hash)) == 2 # The duplicate prompts ids should not have scores so only two scores are returned assert len(sqlite_instance.get_prompt_scores(prompt_ids=[str(prompt_id_1), str(prompt_id_2)] + new_pieces_ids)) == 2 def test_duplicate_conversation_excluding_last_turn(sqlite_instance: MemoryInterface): - attack1 = PromptSendingAttack(objective_target=MagicMock()) - attack2 = PromptSendingAttack(objective_target=MagicMock()) + attack1 = PromptSendingAttack(objective_target=get_mock_target()) + attack2 = PromptSendingAttack(objective_target=get_mock_target()) conversation_id_1 = "11111" conversation_id_2 = "22222" pieces = [ @@ -317,7 +318,7 @@ def test_duplicate_conversation_excluding_last_turn_not_score(sqlite_instance: M conversation_id = str(uuid4()) prompt_id_1 = uuid4() prompt_id_2 = uuid4() - attack1 = PromptSendingAttack(objective_target=MagicMock()) + attack1 = PromptSendingAttack(objective_target=get_mock_target()) memory_labels = {"sample": "label"} pieces = [ MessagePiece( @@ -399,13 +400,13 @@ def test_duplicate_conversation_excluding_last_turn_not_score(sqlite_instance: M assert new_pieces[1].id != prompt_id_2 assert len(sqlite_instance.get_prompt_scores(labels=memory_labels)) == 2 # Attack ID is preserved - assert len(sqlite_instance.get_prompt_scores(attack_id=attack1.get_identifier()["id"])) == 2 + assert len(sqlite_instance.get_prompt_scores(attack_id=attack1.get_identifier().hash)) == 2 # The duplicate prompts ids should not have scores so only two scores are returned assert len(sqlite_instance.get_prompt_scores(prompt_ids=[str(prompt_id_1), str(prompt_id_2)] + new_pieces_ids)) == 2 def test_duplicate_conversation_excluding_last_turn_same_attack(sqlite_instance: MemoryInterface): - attack1 = PromptSendingAttack(objective_target=MagicMock()) + attack1 = PromptSendingAttack(objective_target=get_mock_target()) conversation_id_1 = "11111" pieces = [ MessagePiece( @@ -455,7 +456,7 @@ def test_duplicate_conversation_excluding_last_turn_same_attack(sqlite_instance: def test_duplicate_memory_preserves_attack_id(sqlite_instance: MemoryInterface): - attack1 = PromptSendingAttack(objective_target=MagicMock()) + attack1 = PromptSendingAttack(objective_target=get_mock_target()) conversation_id = "11111" pieces = [ MessagePiece( @@ -481,14 +482,14 @@ def test_duplicate_memory_preserves_attack_id(sqlite_instance: MemoryInterface): assert new_conversation_id != conversation_id # Both pieces should have the same attack ID - attack_ids = {p.attack_identifier["id"] for p in all_pieces} + attack_ids = {p.attack_identifier.hash for p in all_pieces} assert len(attack_ids) == 1 - assert attack1.get_identifier()["id"] in attack_ids + assert attack1.get_identifier().hash in attack_ids def test_duplicate_conversation_creates_new_ids(sqlite_instance: MemoryInterface): """Test that duplicated conversation has new piece IDs.""" - attack1 = PromptSendingAttack(objective_target=MagicMock()) + attack1 = PromptSendingAttack(objective_target=get_mock_target()) conversation_id = "test-conv-123" original_piece = MessagePiece( role="user", @@ -520,7 +521,7 @@ def test_duplicate_conversation_creates_new_ids(sqlite_instance: MemoryInterface def test_duplicate_conversation_preserves_original_prompt_id(sqlite_instance: MemoryInterface): """Test that duplicated conversation preserves original_prompt_id for tracing.""" - attack1 = PromptSendingAttack(objective_target=MagicMock()) + attack1 = PromptSendingAttack(objective_target=get_mock_target()) conversation_id = "test-conv-456" original_piece = MessagePiece( role="user", @@ -544,7 +545,7 @@ def test_duplicate_conversation_preserves_original_prompt_id(sqlite_instance: Me def test_duplicate_conversation_with_multiple_pieces(sqlite_instance: MemoryInterface): """Test that duplicating a multi-piece conversation works correctly.""" - attack1 = PromptSendingAttack(objective_target=MagicMock()) + attack1 = PromptSendingAttack(objective_target=get_mock_target()) conversation_id = "multi-piece-conv" pieces = [ @@ -789,8 +790,8 @@ def test_get_message_pieces_id(sqlite_instance: MemoryInterface): def test_get_message_pieces_attack(sqlite_instance: MemoryInterface): - attack1 = PromptSendingAttack(objective_target=MagicMock()) - attack2 = PromptSendingAttack(objective_target=MagicMock()) + attack1 = PromptSendingAttack(objective_target=get_mock_target()) + attack2 = PromptSendingAttack(objective_target=get_mock_target("Target2")) entries = [ PromptMemoryEntry( @@ -818,7 +819,7 @@ def test_get_message_pieces_attack(sqlite_instance: MemoryInterface): sqlite_instance._insert_entries(entries=entries) - attack1_entries = sqlite_instance.get_message_pieces(attack_id=attack1.get_identifier()["id"]) + attack1_entries = sqlite_instance.get_message_pieces(attack_id=attack1.get_identifier().hash) assert len(attack1_entries) == 2 assert_original_value_in_list("Hello 1", attack1_entries) @@ -950,7 +951,7 @@ def test_get_message_pieces_by_hash(sqlite_instance: MemoryInterface): def test_get_message_pieces_with_non_matching_memory_labels(sqlite_instance: MemoryInterface): - attack = PromptSendingAttack(objective_target=MagicMock()) + attack = PromptSendingAttack(objective_target=get_mock_target()) labels = {"op_name": "op1", "user_name": "name1", "harm_category": "dummy1"} entries = [ PromptMemoryEntry( diff --git a/tests/unit/memory/memory_interface/test_interface_scenario_results.py b/tests/unit/memory/memory_interface/test_interface_scenario_results.py index 4fd86a4130..810300b98b 100644 --- a/tests/unit/memory/memory_interface/test_interface_scenario_results.py +++ b/tests/unit/memory/memory_interface/test_interface_scenario_results.py @@ -7,7 +7,7 @@ import pytest from unit.mocks import get_mock_scorer_identifier -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.memory import MemoryInterface from pyrit.models import ( AttackOutcome, @@ -30,7 +30,6 @@ def create_attack_result(conversation_id: str, objective: str, outcome: AttackOu return AttackResult( conversation_id=conversation_id, objective=objective, - attack_identifier={"name": "test_attack"}, executed_turns=5, execution_time_ms=1000, outcome=outcome, @@ -64,7 +63,7 @@ def create_scenario_result( return ScenarioResult( scenario_identifier=scenario_identifier, - objective_target_identifier={"target": "test_target"}, + objective_target_identifier=TargetIdentifier(class_name="test_target", class_module="test"), attack_results=attack_results, objective_scorer_identifier=scorer_identifier, ) @@ -279,7 +278,9 @@ def test_preserves_metadata(sqlite_instance: MemoryInterface): scenario_result = ScenarioResult( scenario_identifier=scenario_identifier, - objective_target_identifier={"target": "test_target", "endpoint": "https://example.com"}, + objective_target_identifier=TargetIdentifier( + class_name="test_target", class_module="test", endpoint="https://example.com" + ), attack_results={}, objective_scorer_identifier=scorer_identifier, ) @@ -361,7 +362,7 @@ def test_filter_by_labels(sqlite_instance: MemoryInterface, sample_attack_result scenario_identifier = ScenarioIdentifier(name="Labeled Scenario", scenario_version=1) scenario_result = ScenarioResult( scenario_identifier=scenario_identifier, - objective_target_identifier={"target": "test_target"}, + objective_target_identifier=TargetIdentifier(class_name="test_target", class_module="test"), attack_results={"Attack1": [sample_attack_results[0]]}, labels={"environment": "testing", "team": "red-team"}, objective_scorer_identifier=get_mock_scorer_identifier(), @@ -385,7 +386,7 @@ def test_filter_by_multiple_labels(sqlite_instance: MemoryInterface): scenario1_identifier = ScenarioIdentifier(name="Scenario 1", scenario_version=1) scenario1 = ScenarioResult( scenario_identifier=scenario1_identifier, - objective_target_identifier={"target": "test_target"}, + objective_target_identifier=TargetIdentifier(class_name="test_target", class_module="test"), attack_results={"Attack1": [attack_result1]}, labels={"environment": "testing", "team": "red-team"}, objective_scorer_identifier=get_mock_scorer_identifier(), @@ -394,7 +395,7 @@ def test_filter_by_multiple_labels(sqlite_instance: MemoryInterface): scenario2_identifier = ScenarioIdentifier(name="Scenario 2", scenario_version=1) scenario2 = ScenarioResult( scenario_identifier=scenario2_identifier, - objective_target_identifier={"target": "test_target"}, + objective_target_identifier=TargetIdentifier(class_name="test_target", class_module="test"), attack_results={"Attack2": [attack_result2]}, labels={"environment": "production", "team": "red-team"}, objective_scorer_identifier=get_mock_scorer_identifier(), @@ -423,7 +424,7 @@ def test_filter_by_completion_time(sqlite_instance: MemoryInterface): scenario1_identifier = ScenarioIdentifier(name="Recent Scenario", scenario_version=1) scenario1 = ScenarioResult( scenario_identifier=scenario1_identifier, - objective_target_identifier={"target": "test_target"}, + objective_target_identifier=TargetIdentifier(class_name="test_target", class_module="test"), attack_results={"Attack1": [attack_result1]}, completion_time=now, objective_scorer_identifier=get_mock_scorer_identifier(), @@ -432,7 +433,7 @@ def test_filter_by_completion_time(sqlite_instance: MemoryInterface): scenario2_identifier = ScenarioIdentifier(name="Yesterday Scenario", scenario_version=1) scenario2 = ScenarioResult( scenario_identifier=scenario2_identifier, - objective_target_identifier={"target": "test_target"}, + objective_target_identifier=TargetIdentifier(class_name="test_target", class_module="test"), attack_results={"Attack2": [attack_result2]}, completion_time=yesterday, objective_scorer_identifier=get_mock_scorer_identifier(), @@ -441,7 +442,7 @@ def test_filter_by_completion_time(sqlite_instance: MemoryInterface): scenario3_identifier = ScenarioIdentifier(name="Old Scenario", scenario_version=1) scenario3 = ScenarioResult( scenario_identifier=scenario3_identifier, - objective_target_identifier={"target": "test_target"}, + objective_target_identifier=TargetIdentifier(class_name="test_target", class_module="test"), attack_results={"Attack3": [attack_result3]}, completion_time=last_week, objective_scorer_identifier=get_mock_scorer_identifier(), @@ -474,7 +475,7 @@ def test_filter_by_pyrit_version(sqlite_instance: MemoryInterface): scenario1_identifier = ScenarioIdentifier(name="Old Version Scenario", scenario_version=1, pyrit_version="0.4.0") scenario1 = ScenarioResult( scenario_identifier=scenario1_identifier, - objective_target_identifier={"target": "test_target"}, + objective_target_identifier=TargetIdentifier(class_name="test_target", class_module="test"), attack_results={"Attack1": [attack_result1]}, objective_scorer_identifier=get_mock_scorer_identifier(), ) @@ -482,7 +483,7 @@ def test_filter_by_pyrit_version(sqlite_instance: MemoryInterface): scenario2_identifier = ScenarioIdentifier(name="New Version Scenario", scenario_version=1, pyrit_version="0.5.0") scenario2 = ScenarioResult( scenario_identifier=scenario2_identifier, - objective_target_identifier={"target": "test_target"}, + objective_target_identifier=TargetIdentifier(class_name="test_target", class_module="test"), attack_results={"Attack2": [attack_result2]}, objective_scorer_identifier=get_mock_scorer_identifier(), ) @@ -507,7 +508,9 @@ def test_filter_by_target_endpoint(sqlite_instance: MemoryInterface): scenario1_identifier = ScenarioIdentifier(name="Azure Scenario", scenario_version=1) scenario1 = ScenarioResult( scenario_identifier=scenario1_identifier, - objective_target_identifier={"target": "OpenAI", "endpoint": "https://myresource.openai.azure.com"}, + objective_target_identifier=TargetIdentifier( + class_name="OpenAI", class_module="test", endpoint="https://myresource.openai.azure.com" + ), attack_results={"Attack1": [attack_result1]}, objective_scorer_identifier=get_mock_scorer_identifier(), ) @@ -515,7 +518,9 @@ def test_filter_by_target_endpoint(sqlite_instance: MemoryInterface): scenario2_identifier = ScenarioIdentifier(name="OpenAI Scenario", scenario_version=1) scenario2 = ScenarioResult( scenario_identifier=scenario2_identifier, - objective_target_identifier={"target": "OpenAI", "endpoint": "https://api.openai.com/v1"}, + objective_target_identifier=TargetIdentifier( + class_name="OpenAI", class_module="test", endpoint="https://api.openai.com/v1" + ), attack_results={"Attack2": [attack_result2]}, objective_scorer_identifier=get_mock_scorer_identifier(), ) @@ -523,7 +528,7 @@ def test_filter_by_target_endpoint(sqlite_instance: MemoryInterface): scenario3_identifier = ScenarioIdentifier(name="No Endpoint Scenario", scenario_version=1) scenario3 = ScenarioResult( scenario_identifier=scenario3_identifier, - objective_target_identifier={"target": "Local"}, + objective_target_identifier=TargetIdentifier(class_name="Local", class_module="test"), attack_results={"Attack3": [attack_result3]}, objective_scorer_identifier=get_mock_scorer_identifier(), ) @@ -554,7 +559,7 @@ def test_filter_by_target_model_name(sqlite_instance: MemoryInterface): scenario1_identifier = ScenarioIdentifier(name="GPT-4 Scenario", scenario_version=1) scenario1 = ScenarioResult( scenario_identifier=scenario1_identifier, - objective_target_identifier={"target": "OpenAI", "model_name": "gpt-4-0613"}, + objective_target_identifier=TargetIdentifier(class_name="OpenAI", class_module="test", model_name="gpt-4-0613"), attack_results={"Attack1": [attack_result1]}, objective_scorer_identifier=get_mock_scorer_identifier(), ) @@ -562,7 +567,7 @@ def test_filter_by_target_model_name(sqlite_instance: MemoryInterface): scenario2_identifier = ScenarioIdentifier(name="GPT-4o Scenario", scenario_version=1) scenario2 = ScenarioResult( scenario_identifier=scenario2_identifier, - objective_target_identifier={"target": "OpenAI", "model_name": "gpt-4o"}, + objective_target_identifier=TargetIdentifier(class_name="OpenAI", class_module="test", model_name="gpt-4o"), attack_results={"Attack2": [attack_result2]}, objective_scorer_identifier=get_mock_scorer_identifier(), ) @@ -570,7 +575,9 @@ def test_filter_by_target_model_name(sqlite_instance: MemoryInterface): scenario3_identifier = ScenarioIdentifier(name="GPT-3.5 Scenario", scenario_version=1) scenario3 = ScenarioResult( scenario_identifier=scenario3_identifier, - objective_target_identifier={"target": "OpenAI", "model_name": "gpt-3.5-turbo"}, + objective_target_identifier=TargetIdentifier( + class_name="OpenAI", class_module="test", model_name="gpt-3.5-turbo" + ), attack_results={"Attack3": [attack_result3]}, objective_scorer_identifier=get_mock_scorer_identifier(), ) @@ -603,7 +610,9 @@ def test_combined_filters(sqlite_instance: MemoryInterface): scenario1_identifier = ScenarioIdentifier(name="Test Scenario", scenario_version=1, pyrit_version="0.5.0") scenario1 = ScenarioResult( scenario_identifier=scenario1_identifier, - objective_target_identifier={"target": "OpenAI", "endpoint": "https://api.openai.com", "model_name": "gpt-4"}, + objective_target_identifier=TargetIdentifier( + class_name="OpenAI", class_module="test", endpoint="https://api.openai.com", model_name="gpt-4" + ), attack_results={"Attack1": [attack_result1]}, labels={"environment": "testing"}, completion_time=now, @@ -613,7 +622,9 @@ def test_combined_filters(sqlite_instance: MemoryInterface): scenario2_identifier = ScenarioIdentifier(name="Test Scenario", scenario_version=1, pyrit_version="0.4.0") scenario2 = ScenarioResult( scenario_identifier=scenario2_identifier, - objective_target_identifier={"target": "Azure", "endpoint": "https://azure.com", "model_name": "gpt-3.5"}, + objective_target_identifier=TargetIdentifier( + class_name="Azure", class_module="test", endpoint="https://azure.com", model_name="gpt-3.5" + ), attack_results={"Attack2": [attack_result2]}, labels={"environment": "production"}, completion_time=yesterday, diff --git a/tests/unit/memory/memory_interface/test_interface_scores.py b/tests/unit/memory/memory_interface/test_interface_scores.py index 7941f3b79d..d6ee9201f1 100644 --- a/tests/unit/memory/memory_interface/test_interface_scores.py +++ b/tests/unit/memory/memory_interface/test_interface_scores.py @@ -4,10 +4,10 @@ import uuid from typing import Literal, Sequence -from unittest.mock import MagicMock from uuid import uuid4 import pytest +from unit.mocks import get_mock_target from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack from pyrit.identifiers import ScorerIdentifier @@ -54,7 +54,7 @@ def test_get_scores_by_attack_id_and_label( sqlite_instance.add_scores_to_memory(scores=[score]) # Fetch the score we just added - db_score = sqlite_instance.get_prompt_scores(attack_id=sample_conversations[0].attack_identifier["id"]) + db_score = sqlite_instance.get_prompt_scores(attack_id=sample_conversations[0].attack_identifier.hash) assert len(db_score) == 1 assert db_score[0].score_value == score.score_value @@ -75,7 +75,7 @@ def test_get_scores_by_attack_id_and_label( assert db_score[0].score_value == score.score_value db_score = sqlite_instance.get_prompt_scores( - attack_id=sample_conversations[0].attack_identifier["id"], + attack_id=sample_conversations[0].attack_identifier.hash, labels={"x": "y"}, ) assert len(db_score) == 0 @@ -133,7 +133,7 @@ def test_add_score_get_score( def test_add_score_duplicate_prompt(sqlite_instance: MemoryInterface): # Ensure that scores of duplicate prompts are linked back to the original original_id = uuid4() - attack = PromptSendingAttack(objective_target=MagicMock()) + attack = PromptSendingAttack(objective_target=get_mock_target()) conversation_id = str(uuid4()) pieces = [ MessagePiece( diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index ee1fdd6bee..5a58d40c1c 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -9,10 +9,10 @@ from typing import Generator, MutableSequence, Optional, Sequence from unittest.mock import MagicMock, patch -from pyrit.identifiers import ScorerIdentifier, TargetIdentifier +from pyrit.identifiers import AttackIdentifier, ScorerIdentifier, TargetIdentifier from pyrit.memory import AzureSQLMemory, CentralMemory, PromptMemoryEntry from pyrit.models import Message, MessagePiece -from pyrit.prompt_target import PromptChatTarget, limit_requests_per_minute +from pyrit.prompt_target import PromptChatTarget, PromptTarget, limit_requests_per_minute def get_mock_scorer_identifier() -> ScorerIdentifier: @@ -48,6 +48,41 @@ def get_mock_target_identifier(name: str = "MockTarget", module: str = "tests.un ) +def get_mock_attack_identifier(name: str = "MockAttack", module: str = "tests.unit.mocks") -> AttackIdentifier: + """ + Returns a mock AttackIdentifier for use in tests where the specific + attack identity doesn't matter. + + Args: + name: The class name for the mock attack. Defaults to "MockAttack". + module: The module path for the mock attack. Defaults to "tests.unit.mocks". + + Returns: + An AttackIdentifier configured with the provided name and module. + """ + return AttackIdentifier( + class_name=name, + class_module=module, + ) + + +def get_mock_target(name: str = "MockTarget") -> MagicMock: + """ + Returns a MagicMock target whose ``get_identifier()`` returns a real + :class:`TargetIdentifier`. Use this wherever a ``MagicMock(spec=PromptTarget)`` + is needed as an ``objective_target``. + + Args: + name: The class name for the mock target. Defaults to "MockTarget". + + Returns: + A MagicMock configured to return a real TargetIdentifier. + """ + target = MagicMock(spec=PromptTarget) + target.get_identifier.return_value = get_mock_target_identifier(name) + return target + + class MockHttpPostAsync(AbstractAsyncContextManager): def __init__(self, url, headers=None, json=None, params=None, ssl=None): self.status = 200 @@ -100,7 +135,7 @@ def set_system_prompt( *, system_prompt: str, conversation_id: str, - attack_identifier: Optional[dict[str, str]] = None, + attack_identifier: Optional[AttackIdentifier] = None, labels: Optional[dict[str, str]] = None, ) -> None: self.system_prompt = system_prompt @@ -222,11 +257,7 @@ def get_test_message_piece() -> MessagePiece: def get_sample_conversations() -> MutableSequence[Message]: with patch.object(CentralMemory, "get_memory_instance", return_value=MagicMock()): conversation_1 = str(uuid.uuid4()) - attack_identifier = { - "__type__": "MockPromptTarget", - "__module__": "unit.mocks", - "id": str(uuid.uuid4()), - } + attack_id = get_mock_attack_identifier() return [ MessagePiece( @@ -235,7 +266,7 @@ def get_sample_conversations() -> MutableSequence[Message]: converted_value="Hello, how are you?", conversation_id=conversation_1, sequence=0, - attack_identifier=attack_identifier, + attack_identifier=attack_id, ).to_message(), MessagePiece( role="assistant", @@ -243,14 +274,14 @@ def get_sample_conversations() -> MutableSequence[Message]: converted_value="I'm fine, thank you!", conversation_id=conversation_1, sequence=1, - attack_identifier=attack_identifier, + attack_identifier=attack_id, ).to_message(), MessagePiece( role="assistant", original_value="original prompt text", converted_value="I'm fine, thank you!", conversation_id=str(uuid.uuid4()), - attack_identifier=attack_identifier, + attack_identifier=attack_id, ).to_message(), ] diff --git a/tests/unit/models/test_message_piece.py b/tests/unit/models/test_message_piece.py index 78ecb60080..4881b7070f 100644 --- a/tests/unit/models/test_message_piece.py +++ b/tests/unit/models/test_message_piece.py @@ -8,13 +8,12 @@ import warnings from datetime import datetime, timedelta from typing import MutableSequence -from unittest.mock import MagicMock import pytest -from unit.mocks import MockPromptTarget, get_sample_conversations +from unit.mocks import MockPromptTarget, get_mock_target, get_sample_conversations from pyrit.executor.attack import PromptSendingAttack -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import AttackIdentifier, ConverterIdentifier, ScorerIdentifier, TargetIdentifier from pyrit.models import ( Message, MessagePiece, @@ -83,7 +82,7 @@ def test_prompt_targets_serialize(patch_central_database): def test_executors_serialize(): - attack = PromptSendingAttack(objective_target=MagicMock()) + attack = PromptSendingAttack(objective_target=get_mock_target()) entry = MessagePiece( role="user", @@ -92,9 +91,9 @@ def test_executors_serialize(): attack_identifier=attack.get_identifier(), ) - assert entry.attack_identifier["id"] is not None - assert entry.attack_identifier["__type__"] == "PromptSendingAttack" - assert entry.attack_identifier["__module__"] == "pyrit.executor.attack.single_turn.prompt_sending" + assert entry.attack_identifier.hash is not None + assert entry.attack_identifier.class_name == "PromptSendingAttack" + assert entry.attack_identifier.class_module == "pyrit.executor.attack.single_turn.prompt_sending" @pytest.mark.asyncio @@ -664,14 +663,21 @@ def test_message_piece_to_dict(): targeted_harm_categories=["violence", "illegal"], prompt_metadata={"key": "metadata"}, converter_identifiers=[ - {"__type__": "Base64Converter", "__module__": "pyrit.prompt_converter.base64_converter"} + ConverterIdentifier( + class_name="Base64Converter", + class_module="pyrit.prompt_converter.base64_converter", + supported_input_types=["text"], + supported_output_types=["text"], + ) ], - prompt_target_identifier={"__type__": "MockPromptTarget", "__module__": "unit.mocks"}, - attack_identifier={ - "id": str(uuid.uuid4()), - "__type__": "PromptSendingAttack", - "__module__": "pyrit.executor.attack.single_turn.prompt_sending_attack", - }, + prompt_target_identifier=TargetIdentifier( + class_name="MockPromptTarget", + class_module="unit.mocks", + ), + attack_identifier=AttackIdentifier( + class_name="PromptSendingAttack", + class_module="pyrit.executor.attack.single_turn.prompt_sending_attack", + ), scorer_identifier=ScorerIdentifier( class_name="TestScorer", class_module="pyrit.score.test_scorer", @@ -746,7 +752,7 @@ def test_message_piece_to_dict(): assert result["prompt_metadata"] == entry.prompt_metadata assert result["converter_identifiers"] == [conv.to_dict() for conv in entry.converter_identifiers] assert result["prompt_target_identifier"] == entry.prompt_target_identifier.to_dict() - assert result["attack_identifier"] == entry.attack_identifier + assert result["attack_identifier"] == entry.attack_identifier.to_dict() assert result["scorer_identifier"] == entry.scorer_identifier.to_dict() assert result["original_value_data_type"] == entry.original_value_data_type assert result["original_value"] == entry.original_value diff --git a/tests/unit/registry/test_base.py b/tests/unit/registry/test_base.py index e02104dad9..3c8381dcd5 100644 --- a/tests/unit/registry/test_base.py +++ b/tests/unit/registry/test_base.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from dataclasses import dataclass +from dataclasses import dataclass, field import pytest @@ -13,7 +13,7 @@ class MetadataWithTags(Identifier): """Test metadata with a tags field for list filtering tests.""" - tags: tuple[str, ...] + tags: tuple[str, ...] = field(kw_only=True) class TestMatchesFilters: diff --git a/tests/unit/registry/test_base_instance_registry.py b/tests/unit/registry/test_base_instance_registry.py index 9f5744e2c1..0a774d6b1e 100644 --- a/tests/unit/registry/test_base_instance_registry.py +++ b/tests/unit/registry/test_base_instance_registry.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from dataclasses import dataclass +from dataclasses import dataclass, field from pyrit.identifiers import Identifier from pyrit.registry.instance_registries.base_instance_registry import BaseInstanceRegistry @@ -11,7 +11,7 @@ class SampleItemMetadata(Identifier): """Sample metadata with an extra field.""" - category: str + category: str = field(kw_only=True) class ConcreteTestRegistry(BaseInstanceRegistry[str, SampleItemMetadata]): diff --git a/tests/unit/scenarios/test_atomic_attack.py b/tests/unit/scenarios/test_atomic_attack.py index 7cb41d0a2f..b372777def 100644 --- a/tests/unit/scenarios/test_atomic_attack.py +++ b/tests/unit/scenarios/test_atomic_attack.py @@ -75,21 +75,18 @@ def sample_attack_results(): AttackResult( conversation_id="conv-1", objective="objective1", - attack_identifier={"__type__": "TestAttack", "__module__": "test", "id": "1"}, outcome=AttackOutcome.SUCCESS, executed_turns=1, ), AttackResult( conversation_id="conv-2", objective="objective2", - attack_identifier={"__type__": "TestAttack", "__module__": "test", "id": "2"}, outcome=AttackOutcome.SUCCESS, executed_turns=1, ), AttackResult( conversation_id="conv-3", objective="objective3", - attack_identifier={"__type__": "TestAttack", "__module__": "test", "id": "3"}, outcome=AttackOutcome.FAILURE, executed_turns=1, ), @@ -431,7 +428,6 @@ async def test_full_attack_run_execution_flow(self, mock_attack, sample_seed_gro AttackResult( conversation_id=f"conv-{i}", objective=f"objective{i + 1}", - attack_identifier={"__type__": "TestAttack", "__module__": "test", "id": str(i)}, outcome=AttackOutcome.SUCCESS, executed_turns=1, ) @@ -476,7 +472,6 @@ async def test_atomic_attack_with_single_seed_group(self, mock_attack): AttackResult( conversation_id="conv-1", objective="single_objective", - attack_identifier={"__type__": "TestAttack", "__module__": "test", "id": "1"}, outcome=AttackOutcome.SUCCESS, executed_turns=1, ) @@ -513,7 +508,6 @@ async def test_atomic_attack_with_many_seed_groups(self, mock_attack): AttackResult( conversation_id=f"conv-{i}", objective=f"objective_{i}", - attack_identifier={"__type__": "TestAttack", "__module__": "test", "id": str(i)}, outcome=AttackOutcome.SUCCESS, executed_turns=1, ) @@ -682,7 +676,6 @@ async def test_run_async_passes_seed_groups_with_messages(self, mock_attack, see AttackResult( conversation_id=f"conv-{i}", objective=seed_groups_with_messages[i].objective.value, - attack_identifier={"__type__": "TestAttack", "__module__": "test", "id": str(i)}, outcome=AttackOutcome.SUCCESS, executed_turns=len(seed_groups_with_messages[i].user_messages), ) diff --git a/tests/unit/scenarios/test_jailbreak.py b/tests/unit/scenarios/test_jailbreak.py index 047334131c..c5c6f6b42d 100644 --- a/tests/unit/scenarios/test_jailbreak.py +++ b/tests/unit/scenarios/test_jailbreak.py @@ -10,6 +10,7 @@ from pyrit.executor.attack.core.attack_config import AttackScoringConfig from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import SeedGroup, SeedObjective from pyrit.prompt_target import PromptTarget from pyrit.scenario.scenarios.airt.jailbreak import Jailbreak, JailbreakStrategy @@ -43,7 +44,7 @@ def mock_memory_seed_groups() -> List[SeedGroup]: def mock_objective_target() -> PromptTarget: """Create a mock objective target for testing.""" mock = MagicMock(spec=PromptTarget) - mock.get_identifier.return_value = {"__type__": "MockObjectiveTarget", "__module__": "test"} + mock.get_identifier.return_value = TargetIdentifier(class_name="MockObjectiveTarget", class_module="test") return mock @@ -51,7 +52,7 @@ def mock_objective_target() -> PromptTarget: def mock_objective_scorer() -> TrueFalseInverterScorer: """Create a mock scorer for testing.""" mock = MagicMock(spec=TrueFalseInverterScorer) - mock.get_identifier.return_value = {"__type__": "MockObjectiveScorer", "__module__": "test"} + mock.get_identifier.return_value = ScorerIdentifier(class_name="MockObjectiveScorer", class_module="test") return mock diff --git a/tests/unit/scenarios/test_psychosocial_harms.py b/tests/unit/scenarios/test_psychosocial_harms.py index 4a178da7c1..8ecf25206d 100644 --- a/tests/unit/scenarios/test_psychosocial_harms.py +++ b/tests/unit/scenarios/test_psychosocial_harms.py @@ -14,6 +14,7 @@ PromptSendingAttack, RolePlayAttack, ) +from pyrit.identifiers import ScorerIdentifier, TargetIdentifier from pyrit.models import SeedDataset, SeedGroup, SeedObjective from pyrit.prompt_target import OpenAIChatTarget, PromptChatTarget from pyrit.scenario.scenarios.airt import ( @@ -72,21 +73,21 @@ def mock_runtime_env(): @pytest.fixture def mock_objective_target() -> PromptChatTarget: mock = MagicMock(spec=PromptChatTarget) - mock.get_identifier.return_value = {"__type__": "MockObjectiveTarget", "__module__": "test"} + mock.get_identifier.return_value = TargetIdentifier(class_name="MockObjectiveTarget", class_module="test") return mock @pytest.fixture def mock_objective_scorer() -> FloatScaleThresholdScorer: mock = MagicMock(spec=FloatScaleThresholdScorer) - mock.get_identifier.return_value = {"__type__": "MockObjectiveScorer", "__module__": "test"} + mock.get_identifier.return_value = ScorerIdentifier(class_name="MockObjectiveScorer", class_module="test") return mock @pytest.fixture def mock_adversarial_target() -> PromptChatTarget: mock = MagicMock(spec=PromptChatTarget) - mock.get_identifier.return_value = {"__type__": "MockAdversarialTarget", "__module__": "test"} + mock.get_identifier.return_value = TargetIdentifier(class_name="MockAdversarialTarget", class_module="test") return mock @@ -173,7 +174,9 @@ def test_init_default_adversarial_chat(self, *, mock_objective_scorer: FloatScal def test_init_with_adversarial_chat(self, *, mock_objective_scorer: FloatScaleThresholdScorer) -> None: adversarial_chat = MagicMock(OpenAIChatTarget) - adversarial_chat.get_identifier.return_value = {"type": "CustomAdversary"} + adversarial_chat.get_identifier.return_value = TargetIdentifier( + class_name="CustomAdversary", class_module="test" + ) scenario = PsychosocialScenario( adversarial_chat=adversarial_chat, diff --git a/tests/unit/scenarios/test_scenario.py b/tests/unit/scenarios/test_scenario.py index f7e0e6fe4e..804796167e 100644 --- a/tests/unit/scenarios/test_scenario.py +++ b/tests/unit/scenarios/test_scenario.py @@ -87,11 +87,6 @@ def sample_attack_results(): AttackResult( conversation_id=f"conv-{i}", objective=f"objective{i}", - attack_identifier={ - "__type__": "TestAttack", - "__module__": "test", - "id": str(i), - }, outcome=AttackOutcome.SUCCESS, executed_turns=1, ) @@ -526,7 +521,7 @@ def test_scenario_result_initialization(self, sample_attack_results): identifier = ScenarioIdentifier(name="Test", scenario_version=1) result = ScenarioResult( scenario_identifier=identifier, - objective_target_identifier={"__type__": "TestTarget", "__module__": "test"}, + objective_target_identifier=TargetIdentifier(class_name="TestTarget", class_module="test"), attack_results={"base64": sample_attack_results[:3], "rot13": sample_attack_results[3:]}, objective_scorer_identifier=_TEST_SCORER_ID, ) @@ -542,10 +537,10 @@ def test_scenario_result_with_empty_results(self): identifier = ScenarioIdentifier(name="TestScenario", scenario_version=1) result = ScenarioResult( scenario_identifier=identifier, - objective_target_identifier={ - "__type__": "TestTarget", - "__module__": "test", - }, + objective_target_identifier=TargetIdentifier( + class_name="TestTarget", + class_module="test", + ), attack_results={"base64": []}, objective_scorer_identifier=_TEST_SCORER_ID, ) @@ -560,10 +555,10 @@ def test_scenario_result_objective_achieved_rate(self, sample_attack_results): # All successful result = ScenarioResult( scenario_identifier=identifier, - objective_target_identifier={ - "__type__": "TestTarget", - "__module__": "test", - }, + objective_target_identifier=TargetIdentifier( + class_name="TestTarget", + class_module="test", + ), attack_results={"base64": sample_attack_results}, objective_scorer_identifier=_TEST_SCORER_ID, ) @@ -574,32 +569,22 @@ def test_scenario_result_objective_achieved_rate(self, sample_attack_results): AttackResult( conversation_id="conv-fail", objective="objective", - attack_identifier={ - "__type__": "TestAttack", - "__module__": "test", - "id": "1", - }, outcome=AttackOutcome.FAILURE, executed_turns=1, ), AttackResult( conversation_id="conv-fail2", objective="objective", - attack_identifier={ - "__type__": "TestAttack", - "__module__": "test", - "id": "2", - }, outcome=AttackOutcome.FAILURE, executed_turns=1, ), ] result2 = ScenarioResult( scenario_identifier=identifier, - objective_target_identifier={ - "__type__": "TestTarget", - "__module__": "test", - }, + objective_target_identifier=TargetIdentifier( + class_name="TestTarget", + class_module="test", + ), attack_results={"base64": mixed_results}, objective_scorer_identifier=_TEST_SCORER_ID, ) @@ -638,10 +623,10 @@ def create_mock_truefalse_scorer(): from pyrit.score import TrueFalseScorer mock_scorer = MagicMock(spec=TrueFalseScorer) - mock_scorer.get_identifier.return_value = { - "__type__": "MockTrueFalseScorer", - "__module__": "test", - } + mock_scorer.get_identifier.return_value = ScorerIdentifier( + class_name="MockTrueFalseScorer", + class_module="test", + ) mock_scorer.get_scorer_metrics.return_value = None # Make isinstance check work mock_scorer.__class__ = TrueFalseScorer diff --git a/tests/unit/scenarios/test_scenario_partial_results.py b/tests/unit/scenarios/test_scenario_partial_results.py index 1886c36395..c93d9a4309 100644 --- a/tests/unit/scenarios/test_scenario_partial_results.py +++ b/tests/unit/scenarios/test_scenario_partial_results.py @@ -138,7 +138,6 @@ async def mock_run(*args, **kwargs): AttackResult( conversation_id=f"conv-{i}", objective=f"obj{i}", - attack_identifier={"__type__": "TestAttack", "__module__": "test", "id": str(i)}, outcome=AttackOutcome.SUCCESS, executed_turns=1, ) @@ -156,7 +155,6 @@ async def mock_run(*args, **kwargs): AttackResult( conversation_id="conv-3", objective="obj3", - attack_identifier={"__type__": "TestAttack", "__module__": "test", "id": "3"}, outcome=AttackOutcome.SUCCESS, executed_turns=1, ) @@ -199,7 +197,6 @@ async def mock_run(*args, **kwargs): AttackResult( conversation_id=f"conv-{i}", objective=f"obj{i}", - attack_identifier={"__type__": "TestAttack", "__module__": "test", "id": str(i)}, outcome=AttackOutcome.SUCCESS, executed_turns=1, ) @@ -258,7 +255,6 @@ async def mock_run(*args, **kwargs): AttackResult( conversation_id=f"conv-{i}", objective=f"obj{i}", - attack_identifier={"__type__": "TestAttack", "__module__": "test", "id": str(i)}, outcome=AttackOutcome.SUCCESS, executed_turns=1, ) @@ -275,7 +271,6 @@ async def mock_run(*args, **kwargs): AttackResult( conversation_id=f"conv-{i}", objective=f"obj{i}", - attack_identifier={"__type__": "TestAttack", "__module__": "test", "id": str(i)}, outcome=AttackOutcome.SUCCESS, executed_turns=1, ) @@ -335,7 +330,6 @@ async def mock_run(*args, **kwargs): AttackResult( conversation_id="conv-a2-1", objective="a2_obj1", - attack_identifier={"__type__": "TestAttack", "__module__": "test", "id": "a2_1"}, outcome=AttackOutcome.SUCCESS, executed_turns=1, ) @@ -351,7 +345,6 @@ async def mock_run(*args, **kwargs): AttackResult( conversation_id=f"conv-{obj}", objective=obj, - attack_identifier={"__type__": "TestAttack", "__module__": "test", "id": obj}, outcome=AttackOutcome.SUCCESS, executed_turns=1, ) diff --git a/tests/unit/scenarios/test_scenario_retry.py b/tests/unit/scenarios/test_scenario_retry.py index 69bcecef01..36ba15d4c3 100644 --- a/tests/unit/scenarios/test_scenario_retry.py +++ b/tests/unit/scenarios/test_scenario_retry.py @@ -69,7 +69,6 @@ def create_attack_result( return AttackResult( conversation_id=conversation_id or f"{CONV_ID_PREFIX}{index}", objective=objective or f"{OBJECTIVE_PREFIX}{index}", - attack_identifier={"__type__": TEST_ATTACK_TYPE, "__module__": TEST_MODULE, "id": str(index)}, outcome=outcome, executed_turns=executed_turns, ) diff --git a/tests/unit/score/test_conversation_history_scorer.py b/tests/unit/score/test_conversation_history_scorer.py index 0cdbd6cae0..dd27ac642f 100644 --- a/tests/unit/score/test_conversation_history_scorer.py +++ b/tests/unit/score/test_conversation_history_scorer.py @@ -7,7 +7,7 @@ import pytest -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import AttackIdentifier, ScorerIdentifier, TargetIdentifier from pyrit.memory import CentralMemory from pyrit.models import MessagePiece, Score from pyrit.score import ( @@ -244,8 +244,8 @@ async def test_conversation_history_scorer_preserves_metadata(patch_central_data original_value="Response", conversation_id=conversation_id, labels={"test": "label"}, - prompt_target_identifier={"target": "test"}, - attack_identifier={"attack": "test"}, + prompt_target_identifier=TargetIdentifier(class_name="test", class_module="test"), + attack_identifier=AttackIdentifier(class_name="test", class_module="test"), sequence=1, ) diff --git a/tests/unit/score/test_scorer.py b/tests/unit/score/test_scorer.py index 8fc2f55457..01c31c1127 100644 --- a/tests/unit/score/test_scorer.py +++ b/tests/unit/score/test_scorer.py @@ -10,7 +10,7 @@ from unit.mocks import get_mock_target_identifier from pyrit.exceptions import InvalidJsonException, remove_markdown_json -from pyrit.identifiers import ScorerIdentifier +from pyrit.identifiers import AttackIdentifier, ScorerIdentifier from pyrit.memory import CentralMemory from pyrit.models import Message, MessagePiece, Score from pyrit.prompt_target import PromptChatTarget @@ -206,7 +206,7 @@ async def test_scorer_score_value_with_llm_use_provided_attack_identifier(good_j chat_target.set_system_prompt = MagicMock() expected_system_prompt = "system_prompt" - expected_attack_id = "attack_id" + expected_attack_identifier = AttackIdentifier(class_name="TestAttack", class_module="test.module") expected_scored_prompt_id = "123" await scorer._score_value_with_llm( @@ -217,7 +217,7 @@ async def test_scorer_score_value_with_llm_use_provided_attack_identifier(good_j scored_prompt_id=expected_scored_prompt_id, category="category", objective="task", - attack_identifier={"id": expected_attack_id}, + attack_identifier=expected_attack_identifier, ) chat_target.set_system_prompt.assert_called_once() @@ -225,8 +225,7 @@ async def test_scorer_score_value_with_llm_use_provided_attack_identifier(good_j _, set_sys_prompt_args = chat_target.set_system_prompt.call_args assert set_sys_prompt_args["system_prompt"] == expected_system_prompt assert isinstance(set_sys_prompt_args["conversation_id"], str) - assert set_sys_prompt_args["attack_identifier"]["id"] == expected_attack_id - assert set_sys_prompt_args["attack_identifier"]["scored_prompt_id"] == expected_scored_prompt_id + assert set_sys_prompt_args["attack_identifier"] is expected_attack_identifier @pytest.mark.asyncio diff --git a/tests/unit/target/test_http_target.py b/tests/unit/target/test_http_target.py index 5d49702b07..088e12270f 100644 --- a/tests/unit/target/test_http_target.py +++ b/tests/unit/target/test_http_target.py @@ -67,7 +67,9 @@ def test_http_target_sets_endpoint_and_rate_limit(mock_callback_function, sqlite @patch("httpx.AsyncClient.request") async def test_send_prompt_async(mock_request, mock_http_target, mock_http_response): message = MagicMock() - message.message_pieces = [MagicMock(converted_value="test_prompt", prompt_target_identifier=None)] + message.message_pieces = [ + MagicMock(converted_value="test_prompt", prompt_target_identifier=None, attack_identifier=None) + ] mock_request.return_value = mock_http_response response = await mock_http_target.send_prompt_async(message=message) assert len(response) == 1 @@ -113,7 +115,7 @@ async def test_send_prompt_async_client_kwargs(): # Use **httpx_client_kwargs to pass them as keyword arguments http_target = HTTPTarget(http_request=sample_request, **httpx_client_kwargs) message = MagicMock() - message.message_pieces = [MagicMock(converted_value="", prompt_target_identifier=None)] + message.message_pieces = [MagicMock(converted_value="", prompt_target_identifier=None, attack_identifier=None)] mock_response = MagicMock() mock_response.content = b"Response content" mock_request.return_value = mock_response @@ -148,7 +150,9 @@ async def test_send_prompt_regex_parse_async(mock_request, mock_http_target): mock_http_target.callback_function = callback_function message = MagicMock() - message.message_pieces = [MagicMock(converted_value="test_prompt", prompt_target_identifier=None)] + message.message_pieces = [ + MagicMock(converted_value="test_prompt", prompt_target_identifier=None, attack_identifier=None) + ] mock_response = MagicMock() mock_response.content = b"Match: 1234" @@ -175,7 +179,9 @@ async def test_send_prompt_async_keeps_original_template(mock_request, mock_http # Send first prompt message = MagicMock() - message.message_pieces = [MagicMock(converted_value="test_prompt", prompt_target_identifier=None)] + message.message_pieces = [ + MagicMock(converted_value="test_prompt", prompt_target_identifier=None, attack_identifier=None) + ] response = await mock_http_target.send_prompt_async(message=message) assert len(response) == 1 @@ -193,7 +199,9 @@ async def test_send_prompt_async_keeps_original_template(mock_request, mock_http # Send second prompt second_message = MagicMock() - second_message.message_pieces = [MagicMock(converted_value="second_test_prompt", prompt_target_identifier=None)] + second_message.message_pieces = [ + MagicMock(converted_value="second_test_prompt", prompt_target_identifier=None, attack_identifier=None) + ] await mock_http_target.send_prompt_async(message=second_message) # Assert that the original template is still the same @@ -241,7 +249,9 @@ async def test_http_target_with_injected_client(): mock_request.return_value = mock_response message = MagicMock() - message.message_pieces = [MagicMock(converted_value="test_prompt", prompt_target_identifier=None)] + message.message_pieces = [ + MagicMock(converted_value="test_prompt", prompt_target_identifier=None, attack_identifier=None) + ] response = await target.send_prompt_async(message=message) diff --git a/tests/unit/target/test_openai_chat_target.py b/tests/unit/target/test_openai_chat_target.py index 6ba29eb2ea..e257aecb17 100644 --- a/tests/unit/target/test_openai_chat_target.py +++ b/tests/unit/target/test_openai_chat_target.py @@ -24,6 +24,7 @@ PyritException, RateLimitException, ) +from pyrit.identifiers import AttackIdentifier, TargetIdentifier from pyrit.memory.memory_interface import MemoryInterface from pyrit.models import Message, MessagePiece from pyrit.models.json_response_config import _JsonResponseConfig @@ -305,8 +306,8 @@ async def test_send_prompt_async_empty_response_adds_to_memory(openai_response_j converted_value="hello", original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier={"target": "target-identifier"}, - attack_identifier={"test": "test"}, + prompt_target_identifier=TargetIdentifier(class_name="target-identifier", class_module="test"), + attack_identifier=AttackIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), MessagePiece( @@ -316,8 +317,8 @@ async def test_send_prompt_async_empty_response_adds_to_memory(openai_response_j converted_value=tmp_file_name, original_value_data_type="image_path", converted_value_data_type="image_path", - prompt_target_identifier={"target": "target-identifier"}, - attack_identifier={"test": "test"}, + prompt_target_identifier=TargetIdentifier(class_name="target-identifier", class_module="test"), + attack_identifier=AttackIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), ] @@ -399,8 +400,8 @@ async def test_send_prompt_async(openai_response_json: dict, target: OpenAIChatT converted_value="hello", original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier={"target": "target-identifier"}, - attack_identifier={"test": "test"}, + prompt_target_identifier=TargetIdentifier(class_name="target-identifier", class_module="test"), + attack_identifier=AttackIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), MessagePiece( @@ -410,8 +411,8 @@ async def test_send_prompt_async(openai_response_json: dict, target: OpenAIChatT converted_value=tmp_file_name, original_value_data_type="image_path", converted_value_data_type="image_path", - prompt_target_identifier={"target": "target-identifier"}, - attack_identifier={"test": "test"}, + prompt_target_identifier=TargetIdentifier(class_name="target-identifier", class_module="test"), + attack_identifier=AttackIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), ] @@ -447,8 +448,8 @@ async def test_send_prompt_async_empty_response_retries(openai_response_json: di converted_value="hello", original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier={"target": "target-identifier"}, - attack_identifier={"test": "test"}, + prompt_target_identifier=TargetIdentifier(class_name="target-identifier", class_module="test"), + attack_identifier=AttackIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), MessagePiece( @@ -458,8 +459,8 @@ async def test_send_prompt_async_empty_response_retries(openai_response_json: di converted_value=tmp_file_name, original_value_data_type="image_path", converted_value_data_type="image_path", - prompt_target_identifier={"target": "target-identifier"}, - attack_identifier={"test": "test"}, + prompt_target_identifier=TargetIdentifier(class_name="target-identifier", class_module="test"), + attack_identifier=AttackIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), ] diff --git a/tests/unit/target/test_openai_response_target.py b/tests/unit/target/test_openai_response_target.py index c60e1ca559..e6b083efe8 100644 --- a/tests/unit/target/test_openai_response_target.py +++ b/tests/unit/target/test_openai_response_target.py @@ -21,6 +21,7 @@ PyritException, RateLimitException, ) +from pyrit.identifiers import AttackIdentifier, TargetIdentifier from pyrit.memory.memory_interface import MemoryInterface from pyrit.models import Message, MessagePiece from pyrit.models.json_response_config import _JsonResponseConfig @@ -317,8 +318,8 @@ async def test_send_prompt_async_empty_response_adds_to_memory( converted_value="hello", original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier={"target": "target-identifier"}, - attack_identifier={"test": "test"}, + prompt_target_identifier=TargetIdentifier(class_name="target-identifier", class_module="test"), + attack_identifier=AttackIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), MessagePiece( @@ -328,8 +329,8 @@ async def test_send_prompt_async_empty_response_adds_to_memory( converted_value=tmp_file_name, original_value_data_type="image_path", converted_value_data_type="image_path", - prompt_target_identifier={"target": "target-identifier"}, - attack_identifier={"test": "test"}, + prompt_target_identifier=TargetIdentifier(class_name="target-identifier", class_module="test"), + attack_identifier=AttackIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), ] @@ -412,8 +413,8 @@ async def test_send_prompt_async(openai_response_json: dict, target: OpenAIRespo converted_value="hello", original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier={"target": "target-identifier"}, - attack_identifier={"test": "test"}, + prompt_target_identifier=TargetIdentifier(class_name="target-identifier", class_module="test"), + attack_identifier=AttackIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), MessagePiece( @@ -423,8 +424,8 @@ async def test_send_prompt_async(openai_response_json: dict, target: OpenAIRespo converted_value=tmp_file_name, original_value_data_type="image_path", converted_value_data_type="image_path", - prompt_target_identifier={"target": "target-identifier"}, - attack_identifier={"test": "test"}, + prompt_target_identifier=TargetIdentifier(class_name="target-identifier", class_module="test"), + attack_identifier=AttackIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), ] @@ -459,8 +460,8 @@ async def test_send_prompt_async_empty_response_retries(openai_response_json: di converted_value="hello", original_value_data_type="text", converted_value_data_type="text", - prompt_target_identifier={"target": "target-identifier"}, - attack_identifier={"test": "test"}, + prompt_target_identifier=TargetIdentifier(class_name="target-identifier", class_module="test"), + attack_identifier=AttackIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), MessagePiece( @@ -470,8 +471,8 @@ async def test_send_prompt_async_empty_response_retries(openai_response_json: di converted_value=tmp_file_name, original_value_data_type="image_path", converted_value_data_type="image_path", - prompt_target_identifier={"target": "target-identifier"}, - attack_identifier={"test": "test"}, + prompt_target_identifier=TargetIdentifier(class_name="target-identifier", class_module="test"), + attack_identifier=AttackIdentifier(class_name="test", class_module="test"), labels={"test": "test"}, ), ] diff --git a/tests/unit/target/test_prompt_target.py b/tests/unit/target/test_prompt_target.py index e258d2e0c2..e6eb37bb35 100644 --- a/tests/unit/target/test_prompt_target.py +++ b/tests/unit/target/test_prompt_target.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -import uuid from typing import MutableSequence from unittest.mock import AsyncMock, MagicMock, patch @@ -9,6 +8,7 @@ from unit.mocks import get_sample_conversations, openai_chat_response_json_dict from pyrit.executor.attack.core.attack_strategy import AttackStrategy +from pyrit.identifiers import AttackIdentifier from pyrit.models import Message, MessagePiece from pyrit.prompt_target import OpenAIChatTarget @@ -39,11 +39,10 @@ def mock_attack_strategy(): strategy = MagicMock(spec=AttackStrategy) strategy.execute_async = AsyncMock() strategy.execute_with_context_async = AsyncMock() - strategy.get_identifier.return_value = { - "__type__": "TestAttack", - "__module__": "pyrit.executor.attack.test_attack", - "id": str(uuid.uuid4()), - } + strategy.get_identifier.return_value = AttackIdentifier( + class_name="TestAttack", + class_module="pyrit.executor.attack.test_attack", + ) return strategy