From e06fcc1e59a9c22eab744b9597d7019b819024c4 Mon Sep 17 00:00:00 2001 From: ThomasCross Date: Fri, 24 Apr 2026 15:39:34 +0100 Subject: [PATCH 1/6] feat: add interactive docs --- spikee/cli.py | 64 ++- spikee/docs.py | 1006 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 1069 insertions(+), 1 deletion(-) create mode 100644 spikee/docs.py diff --git a/spikee/cli.py b/spikee/cli.py index c3a9d14..185ca8f 100644 --- a/spikee/cli.py +++ b/spikee/cli.py @@ -5,6 +5,7 @@ import sys import shutil import argparse +import argcomplete from . import __version__ from dotenv import load_dotenv from pathlib import Path @@ -28,6 +29,7 @@ list_providers, ) from .viewers.results import ResultsViewer +from .docs import docs_command banner = r""" _____ _____ _____ _ ________ ______ @@ -593,6 +595,64 @@ def main(): help="Include descriptions of modules where available", ) + # === [DOCS] Sub-command ================================================ + parser_docs = subparsers.add_parser( + "docs", + help="Generate spikee commands or explain spikee commands" + ) + + # Create subparsers for docs command + docs_subparsers = parser_docs.add_subparsers( + dest="subcommand", + help="Subcommands for docs" + ) + + # docs generate subcommand + parser_docs_generate = docs_subparsers.add_parser( + "generate", + help="Generate spikee commands using natural language queries" + ) + parser_docs_generate.add_argument( + "query", + type=str, + nargs="+", + help="Natural language description of desired spikee command" + ) + parser_docs_generate.add_argument( + "--model", + type=str, + default=None, + help="LLM model to use for generation (default: openai/gpt-4o)" + ) + parser_docs_generate.add_argument( + "--verbose", + action="store_true", + help="Show debug information (classification, context size)" + ) + + # docs explain subcommand + parser_docs_explain = docs_subparsers.add_parser( + "explain", + help="Explain spikee commands or provide information about them" + ) + 
parser_docs_explain.add_argument( + "query", + type=str, + nargs="+", + help="Query about spikee commands to explain" + ) + parser_docs_explain.add_argument( + "--model", + type=str, + default=None, + help="LLM model to use for explanation (default: openai/gpt-4o)" + ) + parser_docs_explain.add_argument( + "--verbose", + action="store_true", + help="Show debug information (classification, context size)" + ) + args = convert_to_new_args(parser.parse_args()) # Print banner and info unless quiet mode is enabled @@ -654,6 +714,8 @@ def main(): list_providers(args) else: parser_list.print_help() + elif args.command == "docs": + docs_command(args) else: parser.print_help() sys.exit(1) @@ -773,4 +835,4 @@ def copy_builtin_modules(include_option, force=False): print(f"[init] No built-in {module_type} were copied") except Exception as e: - print(f"[init] Error processing {module_type}: {e}") + print(f"[init] Error processing {module_type}: {e}") \ No newline at end of file diff --git a/spikee/docs.py b/spikee/docs.py new file mode 100644 index 0000000..6b83863 --- /dev/null +++ b/spikee/docs.py @@ -0,0 +1,1006 @@ +""" +LLM-powered spikee command generator with optimized context loading. + +This module provides natural language command generation for spikee using an LLM provider +with a two-stage approach: +1. First, classify the query to determine command type +2. 
Then, load only relevant documentation and module information + +Usage: + spikee docs "test gpt-4o-mini with the cybersec dataset" + spikee docs "generate dataset with base64 plugin" --model bedrock/claude45-haiku +""" + +import sys +import re +from typing import Tuple, Dict, Any + +from spikee.utilities.llm import get_llm +from spikee.utilities.llm_message import SystemMessage, HumanMessage +from spikee.utilities.modules import ( + extract_json_or_fail, + load_module_from_path, + get_options_from_module, + get_description_from_module, + collect_modules, + collect_seeds, + collect_datasets, +) + +try: + from rich.console import Console + from rich.syntax import Syntax +except ImportError: + Console = None + +DEFAULT_MODEL = "openai/gpt-4o" + +# === Stage 1: Command Classification === + +CLASSIFIER_PROMPT = """You are a command classifier for the SPIKEE toolkit. + +Analyze the user's query and determine which spikee command they want to use. + +Available commands: +- generate: Creating test datasets from seed folders +- test: Testing targets with datasets (includes attacks, judges, sampling) +- results: Analyzing, rejudging, or extracting test results +- list: Listing available modules (seeds, datasets, targets, plugins, attacks, judges, providers) +- init: Initializing workspace +- viewer: Launching web viewers +- unknown: Cannot determine or doesn't match spikee commands + +Respond with ONLY a JSON object: +{ + "command": "", + "confidence": "" +} + +Examples: +"test gpt-4o with my dataset" -> {"command": "test", "confidence": "high"} +"generate dataset with plugins" -> {"command": "generate", "confidence": "high"} +"show me all plugins" -> {"command": "list", "confidence": "high"} +"analyze my results" -> {"command": "results", "confidence": "high"} +"setup workspace" -> {"command": "init", "confidence": "medium"} +""" + +# === Stage 2: Modular Documentation Sections === + +COMMON_HEADER = """You are an expert assistant for the SPIKEE toolkit - a prompt injection 
and jailbreaking testing framework. + +Your task is to generate valid spikee CLI commands based on natural language descriptions from users. + +# Common Patterns: + +1. **LLM Provider Format**: Always use "provider/model" format + - OpenAI: "openai/gpt-4o", "openai/gpt-4o-mini" + - Bedrock: "bedrock/claude45-sonnet", "bedrock/claude45-haiku" + - Azure: "azure/gpt-4" + - Groq: "groq/llama-3.1-70b" + - DeepSeek: "deepseek/deepseek-chat" + +2. **Plugin Piping**: Use | to pipe plugins: "plugin1|plugin2|plugin3" + +3. **Options Format**: "module:key1=val1,key2=val2;module2:key3=val3" + +4. **Dataset Wildcards**: Can use wildcards in paths: "datasets/cybersec-*.jsonl" +""" + +GENERATE_DOCS = """ +## GENERATE Command + +### Required Arguments: +- `--seed-folder ` - REQUIRED: Path to seed folder (e.g., datasets/seeds-cybersec-2026-01) + +### Optional Source Arguments: +- `--include-standalone-inputs` - Include standalone_user_inputs.jsonl +- `--include-system-message` - Include system_messages.toml +- `--tag ` - Tag for dataset filename + +### Optional Transformation Arguments: +- `--plugins ` - Space-separated list of plugins OR piped plugins with | (e.g., "1337 base64" or "splat|base64") +- `--plugin-options ""` - Plugin options: "plugin1:option1=value1,option2=value2;plugin2:option2=value2" +- `--plugin-only` - Only output plugin entries +- `--include-fixes ` - Comma-separated: adv_prefixes, adv_suffixes, prefixes=, suffixes=, prefix=, suffix= + +### Optional Formatting Arguments: +- `--format ` - Output format: user-input (default/apps), full-prompt (LLMs), or burp +- `--languages ` - Comma-separated list of languages to filter (e.g., en) +- `--match-languages` - Only combine jailbreaks/instructions with matching languages (default: True) +- `--positions ` - Position to insert jailbreaks: start, middle, end (ignored if present) +- `--injection-delimiters ` - Delimiters for injecting jailbreaks (default: \\nINJECTION_PAYLOAD\\n) +- `--spotlighting-data-markers ` - 
Comma-separated data markers (placeholder: "DOCUMENT") +- `--instruction-filter ` - Comma-separated instruction types to include +- `--jailbreak-filter ` - Comma-separated jailbreak types to include + +### Examples: +```bash +# Basic generation +spikee generate --seed-folder datasets/seeds-cybersec-2026-01 + +# With plugins +spikee generate --seed-folder datasets/seeds-toxic-chat --plugins "1337 base64" + +# Plugin piping +spikee generate --seed-folder datasets/seeds-cybersec-2026-01 --plugins "splat|base64" + +# With plugin options +spikee generate --seed-folder datasets/seeds-example --plugins best_of_n --plugin-options "best_of_n:variants=50" + +# With standalone inputs +spikee generate --seed-folder datasets/seeds-in-the-wild --include-standalone-inputs + +# With adversarial fixes +spikee generate --seed-folder datasets/seeds-cybersec-2026-01 --include-fixes "adv_prefixes,adv_suffixes" +``` +""" + +TEST_DOCS = """ +## TEST Command + +### Required Dataset Arguments (at least one required): +- `--dataset ` - Path to dataset JSONL file (can be used multiple times) +- `--dataset-folder ` - Path to folder with multiple JSONL files (can be used multiple times) + +### Required Module Arguments: +- `--target ` - REQUIRED: Target module name (e.g., llm_provider, aws_bedrock_guardrail) + +### Optional Module Arguments: +- `--target-options ""` - Target options, typically "provider/model" format + - Examples: "openai/gpt-4o-mini", "bedrock/claude45-sonnet", "azure/gpt-4" +- `--judge-options ""` - LLM judge model (format: "model=provider/model" or just "provider/model") + - Examples: "bedrock/claude45-haiku", "openai/gpt-4o" + - Only needed for datasets requiring semantic evaluation (not canary-based) + +### Optional Testing Arguments: +- `--threads ` - Number of parallel threads (default: 4) +- `--attempts ` - Number of attempts per entry (default: 1) +- `--max-retries ` - Number of retries for rate-limiting/429 errors (default: 3) +- `--throttle ` - Time to wait between 
entries per thread (default: 0) +- `--sample ` - Sample percentage of dataset (e.g., 0.15 for 15%, default: 1) +- `--sample-seed ` - Seed for random sampling (default: 42) +- `--tag ` - Tag for results filename + +### Optional Attack Arguments: +- `--attack ` - Attack module to use +- `--attack-iterations ` - Number of attack iterations/turns per entry +- `--attack-options ""` - Attack-specific options +- `--attack-only` - Only run attack module, skip standard attempts + +### Optional Resume Arguments: +- `--resume-file ` - Resume from specific results JSONL file (single dataset only) +- `--auto-resume` - Silently resume from latest matching results file +- `--no-auto-resume` - Create new results file, don't resume + +### Examples: +```bash +# Basic test +spikee test --dataset datasets/cybersec-2026-01.jsonl --target llm_provider --target-options "openai/gpt-4o-mini" + +# With LLM judge +spikee test --dataset datasets/harmful.jsonl --target llm_provider --target-options "openai/gpt-4o-mini" --judge-options "bedrock/claude45-haiku" + +# Multiple datasets +spikee test --dataset datasets/dataset1.jsonl --dataset datasets/dataset2.jsonl --target llm_provider --target-options "bedrock/claude45-sonnet" + +# With attack +spikee test --dataset datasets/example.jsonl --target llm_provider --target-options "openai/gpt-4o" --attack best_of_n --attack-iterations 25 + +# With sampling +spikee test --dataset datasets/large.jsonl --target llm_provider --target-options "openai/gpt-4o" --sample 0.1 --sample-seed 123 +``` +""" + +RESULTS_DOCS = """ +## RESULTS Command + +### Subcommands: +1. `results analyze` - Analyze test results with statistics and visualizations +2. `results rejudge` - Re-judge results with different judge +3. `results extract` - Extract specific results by category or search term +4. `results dataset-comparison` - Compare datasets across multiple targets +5. 
`results convert-to-excel` - Convert results JSONL to Excel format + +### analyze Arguments: +- `--results-file ` - Path to results JSONL file (can be used multiple times) +- `--results-folder ` - Path to folder with results files (can be used multiple times) +- `--false-positive-checks ` - JSONL file with benign prompts for FP analysis (single dataset only) +- `--output-format ` - Output format: console (default) or html +- `--overview` - Only output general statistics +- `--combine` - Combine multiple results files into single analysis + +### rejudge Arguments: +- `--results-file ` - Path to results JSONL file (can be used multiple times) +- `--results-folder ` - Path to folder with results files (can be used multiple times) +- `--judge-options ` - Options to pass to the judge +- `--resume` - Resume from most recent re-judge file + +### extract Arguments: +- `--results-file ` - Path to results JSONL file (can be used multiple times) +- `--results-folder ` - Path to folder with results files (can be used multiple times) +- `--category ` - Category: success (default), failure, error, guardrail, no-guardrail, custom +- `--custom-search ` - Custom search: 'string', 'field:string', or '!string' to invert +- `--tag ` - Tag for results filename + +### convert-to-excel Arguments: +- `--result-file ` - Path to results JSONL file (required) + +### Examples: +```bash +# Analyze results +spikee results analyze --results-file results/test-run.jsonl + +# Rejudge with different judge +spikee results rejudge --results-file results/test.jsonl --judge-options "openai/gpt-4o" + +# Extract successful prompts +spikee results extract --results-file results/test.jsonl --category success +``` +""" + +LIST_DOCS = """ +## LIST Command + +### Subcommands: +- `list seeds` - List available seed folders +- `list datasets` - List available dataset JSONL files +- `list targets` - List available targets +- `list judges` - List available judges +- `list plugins` - List available plugins +- `list 
attacks` - List available attack scripts +- `list providers` - List available LLM providers + +### Optional Arguments (for targets, judges, plugins, attacks, providers): +- `-d`, `--description` - Include module descriptions + +### Examples: +```bash +spikee list seeds +spikee list targets --description +spikee list plugins -d +``` +""" + +INIT_DOCS = """ +## INIT Command + +### Arguments: +- `--force` - Overwrite existing directories +- `--include-builtin ` - Copy built-in modules to local workspace +- `--include-viewer` - Include built-in web viewer in local workspace + +### Examples: +```bash +# Basic workspace initialization +spikee init + +# With built-in modules +spikee init --include-builtin all + +# Force overwrite +spikee init --force +``` +""" + +VIEWER_DOCS = """ +## VIEWER Command + +### Subcommands: +- `viewer results` - Launch results viewer + +### Common Arguments: +- `-h`, `--host
` - Host address (default: 127.0.0.1) +- `-p`, `--port ` - Port number (default: 8080) +- `-d`, `--debug` - Enable debug mode with hot-reloading (default: False) +- `--truncate ` - Truncate long fields (default: 500 chars, 0 to disable) + +### results Viewer Arguments: +- `--result-file ` - Path to results JSONL file (can be used multiple times) +- `--result-folder ` - Path to results folder (can be used multiple times) +- `--allow-ast` - Allow AST parsing (use with caution) + +### Examples: +```bash +# Launch results viewer +spikee viewer results --result-folder results/ + +# Custom port +spikee viewer -p 8081 results --result-file results/test.jsonl +``` +""" + +RESPONSE_FORMAT = """ +# Your Response Format: + +You MUST respond with ONLY valid JSON in this exact format: + +{ + "command": "spikee ", + "explanation": "Clear explanation of what this command does", + "options": { + "useful_module_options": ["List of 2-4 useful module-specific options (e.g., --plugin-options, --attack-options, --target-options) that could enhance this command. ONLY include if the command uses modules like plugins, attacks, targets, or judges"] + } +} + +# Important Guidelines: + +1. Generate VALID spikee commands only - use exact argument names and formats shown above +2. ONLY use modules from the available modules list provided +3. Use real seed folders and datasets from the available lists when provided +4. If paths are not specified, use appropriate placeholders from the available lists +5. Use appropriate defaults (e.g., openai/gpt-4o-mini for testing if not specified) +6. Include clear explanations that help users understand what the command does +7. If user mentions a specific LLM provider/model, use it in the command +8. When using LLM-based modules, always include model in options +9. In the "options" field, ONLY suggest useful module-specific options if the command uses modules (plugins, attacks, targets, judges). 
Do NOT include general command arguments like --threads or --sample +10. Return ONLY the JSON - no additional text before or after +""" + + +def parse_query_for_model(query: str) -> Tuple[str, str]: + """ + Extract model specification from query if present. + + Patterns detected: + - "using openai/gpt-4o" + - "with bedrock/claude45-sonnet" + - "model=openai/gpt-4" + + Args: + query: Natural language query string + + Returns: + Tuple of (cleaned_query, model_name_or_none) + """ + pattern = r'\b(using|with|model=)\s*([a-zA-Z0-9_-]+/[a-zA-Z0-9_.-]+)\b' + match = re.search(pattern, query, re.IGNORECASE) + + if match: + model = match.group(2) + cleaned = re.sub(pattern, '', query, flags=re.IGNORECASE).strip() + cleaned = re.sub(r'\s+', ' ', cleaned) + return cleaned, model + + return query, "" + + +def classify_command(query: str, model: str = DEFAULT_MODEL) -> Dict[str, str]: + """ + Classify the user's query to determine which spikee command they want. + + Args: + query: Natural language query + model: LLM model to use for classification + + Returns: + Dictionary with command type and confidence level + """ + try: + # Use faster, cheaper model with minimal tokens for classification + llm = get_llm(model, max_tokens=50, temperature=0) + if llm is None: + raise ValueError("Failed to initialize LLM") + + messages = [ + SystemMessage(content=CLASSIFIER_PROMPT), + HumanMessage(content=query) + ] + + response = llm.invoke(messages).content + + if not isinstance(response, str) or response.strip() == "": + raise ValueError("LLM returned empty response for classification") + + result = extract_json_or_fail(response) + + return result + except Exception as e: + # Fallback to general command if classification fails + return {"command": "unknown", "confidence": "low"} + + +def get_seeds_info() -> str: + """ + Get information about available seed folders in the workspace. 
+ + Returns: + Formatted string with seed folder names + """ + try: + seeds = collect_seeds() + + if not seeds: + return "\n## Available Seed Folders:\n\nNo seed folders found in datasets/. Use 'spikee init' to initialize workspace." + + lines = ["\n## Available Seed Folders:\n"] + lines.append("Found in datasets/ directory:") + for seed in seeds: + lines.append(f"- datasets/{seed}") + + return "\n".join(lines) + except Exception as e: + return f"\n## Available Seed Folders:\n\nError loading seeds: {str(e)}" + + +def get_datasets_info() -> str: + """ + Get information about available datasets in the workspace. + + Returns: + Formatted string with dataset names + """ + try: + datasets = collect_datasets() + + if not datasets: + return "\n## Available Datasets:\n\nNo datasets found in datasets/. Generate datasets using 'spikee generate'." + + lines = ["\n## Available Datasets:\n"] + lines.append("Found in datasets/ directory:") + for dataset in datasets: + lines.append(f"- datasets/{dataset}") + + return "\n".join(lines) + except Exception as e: + return f"\n## Available Datasets:\n\nError loading datasets: {str(e)}" + + +def get_module_info(module_type: str) -> str: + """ + Get information about available modules from local and built-in sources. + + Args: + module_type: Type of module (plugins, targets, attacks, judges, providers) + + Returns: + Formatted string with module names and options + """ + try: + # Use collect_modules from utilities + all_names, local_names, builtin_names = collect_modules(module_type) + + if not all_names: + return f"\n## Available {module_type.title()}:\n\nNo {module_type} available." 
+ + lines = [f"\n## Available {module_type.title()}:\n"] + + # Track if any module requires LLM + any_llm_required = False + + for name in all_names: + try: + # Load module to get info + module = load_module_from_path(name, module_type) + + # Get options + options_result = get_options_from_module(module, module_type) + util_llm = False + options = [] + + if options_result is not None: + if isinstance(options_result, tuple) and len(options_result) == 2: + options, util_llm = options_result + else: + options = options_result if isinstance(options_result, list) else [options_result] + + if util_llm: + any_llm_required = True + + # Get description + description_result = get_description_from_module(module, module_type) + tags = [] + description = "" + + if description_result is not None: + if isinstance(description_result, tuple) and len(description_result) == 2: + tags, description = description_result + else: + description = description_result if isinstance(description_result, str) else "" + + # Build module line + line = f"- **{name}**" + + # Add source indicator + if name in local_names: + line += " [Local]" + + # Add tags if available + if tags: + if hasattr(tags[0], 'value'): # ModuleTag enum + tags_str = ", ".join([tag.value for tag in tags]) + else: + tags_str = ", ".join(str(tag) for tag in tags) + line += f" [{tags_str}]" + + # Add options if available + if options and len(options) > 0 and not (isinstance(options[0], str) and options[0].startswith(" str: + """ + Build contextual documentation based on detected command type. + + Args: + command_type: The type of command (generate, test, results, etc.) 
+ + Returns: + Relevant documentation string + """ + context = COMMON_HEADER + + if command_type == "generate": + context += GENERATE_DOCS + context += get_seeds_info() # Add available seed folders + context += get_module_info("plugins") + + elif command_type == "test": + context += TEST_DOCS + context += get_datasets_info() # Add available datasets + context += get_module_info("targets") + context += get_module_info("attacks") + context += get_module_info("judges") + + elif command_type == "results": + context += RESULTS_DOCS + context += get_module_info("judges") # For rejudge + + elif command_type == "list": + context += LIST_DOCS + + elif command_type == "init": + context += INIT_DOCS + + elif command_type == "viewer": + context += VIEWER_DOCS + + else: # unknown - provide general overview + context += "\n## All Commands Available:\n" + context += "- generate: Create test datasets\n" + context += "- test: Test targets with datasets\n" + context += "- results: Analyze, rejudge, or extract results\n" + context += "- list: List available modules\n" + context += "- init: Initialize workspace\n" + context += "- viewer: Launch web viewers\n" + context += "\nPlease rephrase your query to be more specific.\n" + + context += RESPONSE_FORMAT + + return context + + +def build_explanation_context(command_type: str) -> str: + """ + Build contextual documentation for explaining spikee commands. + + Args: + command_type: The type of command being explained (generate, test, results, etc.) 
+ + Returns: + Relevant documentation string for explanations + """ + # Start with common header + context = COMMON_HEADER + + # Add command-specific explanation context + context += "\n## Command Explanation Context\n" + context += "This section provides detailed information to explain spikee commands based on user queries.\n" + + if command_type == "generate": + context += "\n### Generate Command Context\n" + context += "The generate command creates test datasets from seed folders with optional transformations.\n" + context += GENERATE_DOCS + context += get_seeds_info() + context += get_module_info("plugins") + + elif command_type == "test": + context += "\n### Test Command Context\n" + context += "The test command evaluates targets with datasets, optionally using attacks and judges.\n" + context += TEST_DOCS + context += get_datasets_info() + context += get_module_info("targets") + context += get_module_info("attacks") + context += get_module_info("judges") + + elif command_type == "results": + context += "\n### Results Command Context\n" + context += "The results command analyzes, rejudges, or extracts test results.\n" + context += RESULTS_DOCS + context += get_module_info("judges") + + elif command_type == "list": + context += "\n### List Command Context\n" + context += "The list command shows available modules for various categories.\n" + context += LIST_DOCS + + elif command_type == "init": + context += "\n### Init Command Context\n" + context += "The init command initializes a new spikee workspace.\n" + context += INIT_DOCS + + elif command_type == "viewer": + context += "\n### Viewer Command Context\n" + context += "The viewer command launches web viewers for results.\n" + context += VIEWER_DOCS + + else: # unknown - provide general overview + context += "\n### General Command Context\n" + context += "When the command type is unknown, here's a general overview of all spikee commands:\n" + context += "- generate: Create test datasets\n" + context += "- 
test: Test targets with datasets\n" + context += "- results: Analyze, rejudge, or extract results\n" + context += "- list: List available modules\n" + context += "- init: Initialize workspace\n" + context += "- viewer: Launch web viewers\n" + context += "\nPlease rephrase your query to be more specific.\n" + + # Add explanation-specific response format + context += "\n# Your Response Format:\n" + context += "You MUST respond with ONLY valid JSON in this exact format:\n" + context += "{\n" + context += " \"explanation\": \"Clear explanation of the spikee command(s) requested\"\n" + context += "}\n" + + return context + + +def display_explanation(explanation_dict: Dict[str, Any]) -> None: + """ + Display generated explanation with formatting. + + Args: + explanation_dict: Dictionary containing explanation + """ + if Console is None: + # Fallback to plain text if rich is not available + print("\nExplanation:") + print(explanation_dict["explanation"]) + print() + return + + console = Console() + + # Display header + console.print() + console.print("[bold cyan]Explanation:[/bold cyan]") + + # Display explanation + console.print() + console.print("[bold green]" + explanation_dict["explanation"] + "[/bold green]") + console.print() + + +def generate_command(query: str, model: str = DEFAULT_MODEL, verbose: bool = False) -> Dict[str, Any]: + """ + Generate a spikee command from natural language query using an LLM with optimized context. + + This uses a two-stage approach: + 1. Classify the query to determine command type (fast, minimal tokens) + 2. 
Load only relevant documentation and generate command (focused context) + + Args: + query: Natural language description of desired command + model: LLM model to use (format: provider/model) + verbose: If True, print classification info + + Returns: + Dictionary with keys: command, explanation, options (containing useful_module_options) + + Raises: + Exception: If LLM call fails or response parsing fails + """ + try: + # Stage 1: Classify command type + classification = classify_command(query, model) + command_type = classification.get("command", "unknown") + confidence = classification.get("confidence", "low") + + if verbose: + print(f"[Debug] Classified as: {command_type} (confidence: {confidence})") + + # Stage 2: Build context and generate command + context = build_context_for_command(command_type) + + if verbose: + print(f"[Debug] Context size: {len(context)} chars") + + # Initialize LLM with appropriate token limit (increased for options) + llm = get_llm(model, max_tokens=1200, temperature=0.3) + + # Create messages + messages = [ + SystemMessage(content=context), + HumanMessage(content=query) + ] + + # Get response + response = llm.invoke(messages) + + # Parse JSON response + result = extract_json_or_fail(response.content) + + # Validate required fields + if "command" not in result or "explanation" not in result: + raise ValueError("LLM response missing required fields (command, explanation)") + + # Options field is optional but should have default structure if missing + if "options" not in result: + result["options"] = { + "useful_module_options": [] + } + + return result + + except ImportError as e: + raise Exception(f"Invalid LLM provider specification: {e}\n" + f"Use format: provider/model (e.g., openai/gpt-4o)") + except Exception as e: + raise Exception(f"Error generating command: {e}\n" + "Please try rephrasing your query or check your API credentials in .env file") + + +def explain_command(query: str, model: str = DEFAULT_MODEL, verbose: bool = False) 
-> Dict[str, Any]: + """ + Generate an explanation of spikee commands for natural language queries. + + Args: + query: Natural language query about spikee commands + model: LLM model to use (format: provider/model) + verbose: If True, print classification info + + Returns: + Dictionary with key: explanation + + Raises: + Exception: If LLM call fails or response parsing fails + """ + try: + # Stage 1: Classify command type + classification = classify_command(query, model) + command_type = classification.get("command", "unknown") + confidence = classification.get("confidence", "low") + + if verbose: + print(f"[Debug] Classified as: {command_type} (confidence: {confidence})") + + # Stage 2: Build context for explanation + context = build_explanation_context(command_type) + + if verbose: + print(f"[Debug] Context size: {len(context)} chars") + + # Initialize LLM with appropriate token limit + llm = get_llm(model, max_tokens=800, temperature=0.3) + + # Create messages + messages = [ + SystemMessage(content=context), + HumanMessage(content=query) + ] + + # Get response + response = llm.invoke(messages) + + # Parse JSON response + result = extract_json_or_fail(response.content) + + # Validate required fields + if "explanation" not in result: + raise ValueError("LLM response missing required field (explanation)") + + return result + + except ImportError as e: + raise Exception(f"Invalid LLM provider specification: {e}\n" + f"Use format: provider/model (e.g., openai/gpt-4o)") + except Exception as e: + raise Exception(f"Error generating explanation: {e}\n" + "Please try rephrasing your query or check your API credentials in .env file") + + +def display_command(command_dict: Dict[str, Any]) -> None: + """ + Display generated command with formatting. 
+ + Args: + command_dict: Dictionary containing command, explanation, and options + """ + if Console is None: + # Fallback to plain text if rich is not available + print("\nGenerated Command:") + print(command_dict["command"]) + print("\nExplanation:") + print(command_dict["explanation"]) + + # Display options if available + if "options" in command_dict: + options = command_dict["options"] + + if options.get("useful_module_options"): + print("\nUseful Module Options:") + for opt in options["useful_module_options"]: + print(f" • {opt}") + + print() + return + + console = Console() + + # Display header + console.print() + console.print("[bold cyan]Generated Command:[/bold cyan]") + + # Display command with syntax highlighting + syntax = Syntax(command_dict["command"], "bash", theme="monokai", line_numbers=False) + console.print(syntax) + + # Display explanation + console.print() + console.print("[bold green]Explanation:[/bold green]") + console.print(command_dict["explanation"]) + + # Display options if available + if "options" in command_dict: + options = command_dict["options"] + + if options.get("useful_module_options"): + console.print() + console.print("[bold yellow]Useful Module Options:[/bold yellow]") + for opt in options["useful_module_options"]: + console.print(f" [dim]•[/dim] {opt}") + + console.print() + + +def docs_command(args) -> None: + """ + Entry point for the 'spikee docs' command. + + Args: + args: Parsed command-line arguments + """ + + # Check if any subcommand was provided + if hasattr(args, 'subcommand') and args.subcommand: + if args.subcommand == 'generate': + docs_generate(args) + elif args.subcommand == 'explain': + docs_explain(args) + else: + print(f"Error: Unknown subcommand '{args.subcommand}'. Use 'spikee docs --help' for available subcommands.") + sys.exit(1) + else: + # Default behavior: run generate mode + docs_generate(args) + + +def docs_generate(args) -> None: + """ + Generate a spikee command from a natural language query. 
+ + Args: + args: Parsed command-line arguments + """ + # Join query parts into single string + query = " ".join(args.query) + + if not query.strip(): + print("Error: Please provide a query describing the spikee command you want to generate.") + print("\nExample: spikee docs generate \"test gpt-4o-mini with my dataset\"") + sys.exit(1) + + # Determine which model to use + model = None + + # First priority: --model flag + if hasattr(args, 'model') and args.model: + model = args.model + + # Second priority: model specified in query + if model is None: + cleaned_query, parsed_model = parse_query_for_model(query) + if parsed_model: + query = cleaned_query + model = parsed_model + + # Third priority: default model + if model is None: + model = DEFAULT_MODEL + + # Check for verbose mode + verbose = hasattr(args, 'verbose') and args.verbose + + # Show what we're doing + if Console: + console = Console() + console.print(f"[dim]Generating spikee command using {model}...[/dim]") + else: + print(f"Generating spikee command using {model}...") + + # Generate command + try: + result = generate_command(query, model, verbose=verbose) + display_command(result) + except Exception as e: + print(f"\n{e}", file=sys.stderr) + sys.exit(1) + + +def docs_explain(args) -> None: + """ + Explain spikee commands or provide information about them. 
+ + Args: + args: Parsed command-line arguments + """ + # Join query parts into single string + query = " ".join(args.query) + + if not query.strip(): + print("Error: Please provide a query about spikee commands to explain.") + print("\nExample: spikee docs explain \"how to test with a custom model\"") + sys.exit(1) + + # Determine which model to use + model = None + + # First priority: --model flag + if hasattr(args, 'model') and args.model: + model = args.model + + # Second priority: model specified in query + if model is None: + cleaned_query, parsed_model = parse_query_for_model(query) + if parsed_model: + query = cleaned_query + model = parsed_model + + # Third priority: default model + if model is None: + model = DEFAULT_MODEL + + # Check for verbose mode + verbose = hasattr(args, 'verbose') and args.verbose + + # Show what we're doing + if Console: + console = Console() + console.print(f"[dim]Explaining spikee commands using {model}...[/dim]") + else: + print(f"Explaining spikee commands using {model}...") + + # Generate explanation + try: + result = explain_command(query, model, verbose=verbose) + display_explanation(result) + except Exception as e: + print(f"\n{e}", file=sys.stderr) + sys.exit(1) From cd938fc57a16e5c1414f185be9e9cb4d7c69306d Mon Sep 17 00:00:00 2001 From: ThomasCross Date: Fri, 24 Apr 2026 16:37:47 +0100 Subject: [PATCH 2/6] feat: multiprocessing generator --- spikee/cli.py | 12 + spikee/generator.py | 758 +++++++++--------- .../test_spikee_generate/test_threads.py | 378 +++++++++ 3 files changed, 791 insertions(+), 357 deletions(-) create mode 100644 tests/functional/test_spikee_generate/test_threads.py diff --git a/spikee/cli.py b/spikee/cli.py index 185ca8f..cf781f1 100644 --- a/spikee/cli.py +++ b/spikee/cli.py @@ -211,6 +211,18 @@ def main(): default=None, help="Include a tag at the end of the generated dataset filename", ) + def positive_int(value): + ivalue = int(value) + if ivalue < 1: + raise argparse.ArgumentTypeError(f"--threads 
must be a positive integer, got {value}") + return ivalue + + parser_generate.add_argument( + "--threads", + type=positive_int, + default=1, + help="Number of threads for parallel plugin processing (default: 1)", + ) parser_plugin = subparsers_generate.add_parser( "plugin", help="Apply a plugin transformation to a string" diff --git a/spikee/generator.py b/spikee/generator.py index 8e7fa10..24b5617 100644 --- a/spikee/generator.py +++ b/spikee/generator.py @@ -2,11 +2,13 @@ import inspect import json import time +import asyncio from collections import defaultdict from typing import Union, List from tabulate import tabulate from pathlib import Path from tqdm import tqdm +from concurrent.futures import ThreadPoolExecutor, as_completed from spikee.utilities.files import read_jsonl_file, read_toml_file, write_jsonl_file from spikee.utilities.modules import load_module_from_path @@ -473,127 +475,226 @@ def parse_exclude_patterns(jailbreak, instruction): return list(exclude_patterns) if exclude_patterns else None - # endregion - -def process_standalone_attacks( - standalone_attacks, - dataset, - entry_id, - adv_prefixes=[None], - adv_suffixes=[None], - plugins=None, - plugin_options_map=None, - plugin_only=False, -): +def _process_permutation_worker(perm, plugin_options_map, system_message_config, output_format, entry_id_start): """ - Processes standalone attacks and appends them to the dataset. - If plugins are provided, applies them to each standalone attack. - Returns the updated dataset and the next entry_id. + Worker function to process a single permutation (base_doc, jailbreak, instruction combination). + Each thread gets its own asyncio event loop for async LLM operations. + + Returns a list of Entry objects for this permutation. 
""" - - if plugin_only: - plugins = ( - [] + plugins if plugins else [] - ) # Only include plugin variations, no base attack - - else: - plugins = ( - [(None, None)] + plugins if plugins else [(None, None)] - ) # [(plugin name, plugin module)] with a dummy entry for no plugin - - prefixes = adv_prefixes - suffixes = adv_suffixes - - # Obtain plugin options and calculate total variants - plugin_variants = {} - if plugins: + asyncio.set_event_loop(asyncio.new_event_loop()) + + entries = [] + + try: + # Unpack permutation + base_doc = perm['base_doc'] + jailbreak = perm['jailbreak'] + instruction = perm['instruction'] + plugins = perm['plugins'] + prefixes = perm['prefixes'] + suffixes = perm['suffixes'] + positions = perm['positions'] + injection_delimiters = perm['injection_delimiters'] + spotlighting_data_markers_list = perm['spotlighting_data_markers_list'] + match_languages = perm['match_languages'] + + # Extract base document info + base_id = base_doc["id"] + document = base_doc["document"] + placeholder = base_doc.get("placeholder", "") + question = base_doc.get("question", "") + ideal_answer = base_doc.get("ideal_answer", "") + ideal_summary = base_doc.get("ideal_summary", "") + + entry_text = {} + if question: + entry_text["question"] = question + if ideal_answer: + entry_text["ideal_answer"] = ideal_answer + if ideal_summary: + entry_text["ideal_summary"] = ideal_summary + + # Extract jailbreak info + jailbreak_id = jailbreak["id"] + jailbreak_text = jailbreak["text"] + jailbreak_type = jailbreak.get("jailbreak_type", "") + jailbreak_lang = jailbreak.get("lang", "en") + + # Extract instruction info + instruction_id = instruction["id"] + instruction_content_type = instruction.get("content_type", "text") + instruction_content = content_factory(instruction["instruction"], instruction_content_type) + instruction_type = instruction.get("instruction_type", "") + instruction_lang = instruction.get("lang", "en") + judge_name = instruction.get("judge_name", "canary") 
+ judge_args = instruction.get("judge_args", instruction.get("canary", "")) + + # Check language matching + if match_languages and jailbreak_lang != instruction_lang: + return [] + + # Combine jailbreak and instruction + combined_base = content_factory( + jailbreak_text.replace("", str(get_content(instruction_content))), + get_content_type(instruction_content) + ) + lang = instruction_lang + + # Create plugin / transformation regex exclusion lists + local_exclude = parse_exclude_patterns(jailbreak, instruction) + + # Apply plugins and create combined texts + combined_texts = [] + fix_permutations = [(prefix, suffix) for prefix in prefixes for suffix in suffixes] + for plugin_name, plugin_module in plugins: - if plugin_name is None: - plugin_variants[plugin_name] = 1 - - elif "~" in plugin_name and plugin_module: # Plugin Pipe - sub_plugins = plugin_name.split("~") - total_variants = 1 - for sub_plugin, sub_module in zip(sub_plugins, plugin_module): - sub_plugin_option = ( - plugin_options_map.get(sub_plugin) - if plugin_options_map - else None + plugin_texts: List[Content] = ( + apply_plugin(plugin_name, plugin_module, combined_base, local_exclude, plugin_options_map) + if plugin_name + else [combined_base] + ) + + for plugin_index, plugin_text in enumerate(plugin_texts, start=1): + for prefix, suffix in fix_permutations: + prefix_lang = prefix.get("lang", None) if prefix else None + suffix_lang = suffix.get("lang", None) if suffix else None + + if match_languages and ((prefix_lang and prefix_lang != lang) or (suffix_lang and suffix_lang != lang)): + continue + + prefix_text = prefix.get("prefix", "") + " " if prefix else "" + suffix_text = " " + suffix.get("suffix", "") if suffix else "" + text = content_factory( + prefix_text + get_content(plugin_text) + suffix_text, + get_content_type(plugin_text) ) - variants = get_plugin_variants(sub_module, sub_plugin_option) - total_variants *= variants - plugin_variants[plugin_name] = total_variants + + 
combined_texts.append({ + "text": text, + "prefix_id": prefix.get("id", None) if prefix else None, + "suffix_id": suffix.get("id", None) if suffix else None, + "plugin_name": plugin_name, + "plugin_suffix": f"_{plugin_name}-{plugin_index}" if plugin_name else "", + }) + + # Generate entries for each combined text + insert_positions = ["fixed"] if placeholder else positions + entry_id = entry_id_start + + for combined_text in combined_texts: + for position in insert_positions: + for injection_pattern in injection_delimiters: + injected_doc = insert_jailbreak( + document, + combined_text["text"], + position, + injection_pattern, + placeholder, + ) + + for entry_type in output_format: + if entry_type == "burp": + burp_payload_encoded = json.dumps(get_content(injected_doc))[1:-1] + entries.append(burp_payload_encoded) + else: + for spotlighting_data_marker in spotlighting_data_markers_list: + system_message = get_system_message(system_message_config, spotlighting_data_marker) + + final_injected_doc = injected_doc + if entry_type == EntryType.DOCUMENT: + if spotlighting_data_marker != "none" and isinstance(get_content(injected_doc), str): + final_injected_doc = content_factory( + spotlighting_data_marker.replace("DOCUMENT", get_content(injected_doc)), + get_content_type(injected_doc) + ) + + entry = Entry( + entry_type=entry_type, + entry_id=entry_id, + base_id=base_id, + jailbreak_id=jailbreak_id, + instruction_id=instruction_id, + prefix_id=combined_text.get("prefix_id", None), + suffix_id=combined_text.get("suffix_id", None), + content=final_injected_doc, + entry_text=entry_text, + system_message=system_message, + payload=combined_text.get("text", None), + lang=lang, + plugin_suffix=combined_text.get("plugin_suffix", ""), + plugin_name=combined_text.get("plugin_name", None), + judge_args=judge_args, + judge_name=judge_name, + position=position, + jailbreak_type=jailbreak_type, + instruction_type=instruction_type, + injection_pattern=injection_pattern, + 
spotlighting_data_markers=spotlighting_data_marker, + exclude_from_transformations_regex=local_exclude, + ).to_entry() + entries.append(entry) + entry_id += 1 + + except Exception as e: + print(f"\n[ERROR] Processing permutation failed: {e}") + import traceback + traceback.print_exc() + return [] + + return entries + + +def _process_standalone_worker(perm, plugin_options_map, entry_id_start): + """ + Worker function to process a single standalone attack permutation. + Each thread gets its own asyncio event loop for async LLM operations. - else: # Standalone Plugin - plugin_option = ( - plugin_options_map.get(plugin_name) if plugin_options_map else None - ) - plugin_variants[plugin_name] = get_plugin_variants( - plugin_module, plugin_option - ) + Returns a list of entry dicts for this standalone attack. + """ + asyncio.set_event_loop(asyncio.new_event_loop()) - # Calculate total entries for progress bar - total_entries = len(standalone_attacks) * ( - sum(plugin_variants.values() or [1]) + (1 if not plugin_only else 0) - ) - bar_standalone = tqdm(total=total_entries, desc="Standalone Attacks", initial=1) + entries = [] - for attack in standalone_attacks: - # If no judge_name, fallback - if "judge_name" not in attack: - attack["judge_name"] = "canary" - if "judge_args" not in attack: - attack["judge_args"] = attack.get("canary", "") + try: + attack = perm['attack'] + plugins = perm['plugins'] + prefixes = perm['prefixes'] + suffixes = perm['suffixes'] - # Get the base attack text and exclude patterns attack_type = attack.get("content_type", "text") attack_content = content_factory(attack.get("content", attack.get("text", "")), attack_type) - exclude_patterns = attack.get("exclude_from_transformations_regex", None) - # Get permutations for prefixes and suffixes - combined_texts = [] # Stored all permutations of prefixes/suffixes and plugin outputs for an attack entry. 
- fix_permutations = [ - (prefix, suffix) for prefix in prefixes for suffix in suffixes - ] + fix_permutations = [(prefix, suffix) for prefix in prefixes for suffix in suffixes] - # Apply plugins to the base attack text + combined_texts = [] for plugin_name, plugin_module in plugins: plugin_content: List[Content] = ( - apply_plugin( - plugin_name, - plugin_module, - attack_content, - exclude_patterns, - plugin_options_map, - ) + apply_plugin(plugin_name, plugin_module, attack_content, exclude_patterns, plugin_options_map) if plugin_name else [attack_content] ) - # Combine each plugin variation with each prefix/suffix permutation and add to combined_texts for plugin_index, plugin_text in enumerate(plugin_content, start=1): for prefix, suffix in fix_permutations: - prefix_text = prefix.get("prefix", "") + " " if prefix else "" suffix_text = " " + suffix.get("suffix", "") if suffix else "" - # TODO: Should this only apply to text content? - text = content_factory(prefix_text + get_content(plugin_text) + suffix_text, get_content_type(plugin_text)) - - combined_texts.append( - { - "text": text, - "prefix_id": prefix.get("id", None) if prefix else None, - "suffix_id": suffix.get("id", None) if suffix else None, - "plugin_name": plugin_name, - "plugin_suffix": f"_{plugin_name}-{plugin_index}" - if plugin_name - else "", - } + text = content_factory( + prefix_text + get_content(plugin_text) + suffix_text, + get_content_type(plugin_text) ) - + combined_texts.append({ + "text": text, + "prefix_id": prefix.get("id", None) if prefix else None, + "suffix_id": suffix.get("id", None) if suffix else None, + "plugin_name": plugin_name, + "plugin_suffix": f"_{plugin_name}-{plugin_index}" if plugin_name else "", + }) + + entry_id = entry_id_start for combined_text in combined_texts: entry = Entry( entry_type=EntryType.ATTACK, @@ -620,12 +721,106 @@ def process_standalone_attacks( exclude_from_transformations_regex=exclude_patterns, steering_keywords=attack.get("steering_keywords", 
None), ).to_attack() - - dataset.append(entry) + entries.append(entry) entry_id += 1 - bar_standalone.update(1) - bar_standalone.close() + except Exception as e: + print(f"\n[ERROR] Processing standalone attack failed: {e}") + import traceback + traceback.print_exc() + return [] + + return entries + + +def process_standalone_attacks( + standalone_attacks, + dataset, + entry_id, + adv_prefixes=[None], + adv_suffixes=[None], + plugins=None, + plugin_options_map=None, + plugin_only=False, + num_threads=1, +): + """ + Processes standalone attacks and appends them to the dataset. + If plugins are provided, applies them to each standalone attack. + Returns the updated dataset and the next entry_id. + """ + + if plugin_only: + plugins = [] + plugins if plugins else [] + else: + plugins = [(None, None)] + plugins if plugins else [(None, None)] + + prefixes = adv_prefixes + suffixes = adv_suffixes + + # Normalise judge fields and build permutation list + permutations = [] + for attack in standalone_attacks: + if "judge_name" not in attack: + attack["judge_name"] = "canary" + if "judge_args" not in attack: + attack["judge_args"] = attack.get("canary", "") + + permutations.append({ + 'attack': attack, + 'plugins': plugins, + 'prefixes': prefixes, + 'suffixes': suffixes, + }) + + print(f"[Info] Processing {len(permutations)} standalone attack(s) with {num_threads} thread(s)") + + new_entries = [] + + if num_threads > 1: + def thread_init(): + asyncio.set_event_loop(asyncio.new_event_loop()) + + with ThreadPoolExecutor(max_workers=num_threads, initializer=thread_init) as executor: + futures = {} + current_entry_id = entry_id + for perm in permutations: + future = executor.submit( + _process_standalone_worker, + perm, + plugin_options_map, + current_entry_id, + ) + futures[future] = perm + current_entry_id += 100 # Placeholder; reassigned after collection + + bar = tqdm(total=len(permutations), desc="Standalone Attacks") + for future in as_completed(futures): + try: + entries = 
future.result() + new_entries.extend(entries) + except Exception as e: + perm = futures[future] + print(f"\n[ERROR] Standalone attack {perm['attack']['id']} failed: {e}") + bar.update(1) + bar.close() + + # Reassign sequential entry IDs + for i, entry in enumerate(new_entries, start=entry_id): + if isinstance(entry, dict): + entry['id'] = i + entry_id += len(new_entries) + + else: + bar = tqdm(total=len(permutations), desc="Standalone Attacks") + for perm in permutations: + entries = _process_standalone_worker(perm, plugin_options_map, entry_id) + new_entries.extend(entries) + entry_id += len(entries) + bar.update(1) + bar.close() + + dataset.extend(new_entries) return dataset, entry_id @@ -644,283 +839,130 @@ def generate_variations( system_message_config=None, plugin_options_map=None, plugin_only=False, + num_threads=1, ): """ Generates dataset variations from the given base documents, jailbreaks, instructions, injection positions, delimiters, data markers, and plugins. Returns the dataset and the last used entry_id. + + When num_threads > 1, creates all permutations first and processes them in parallel. 
""" - dataset = [] - entry_id = 1 - + if plugin_only: - plugins = ( - [] + plugins if plugins else [] - ) # Only include plugin variations, no base attack - + plugins = [] + plugins if plugins else [] else: - plugins = ( - [(None, None)] + plugins if plugins else [(None, None)] - ) # [(plugin name, plugin module)] with a dummy entry for no plugin - + plugins = [(None, None)] + plugins if plugins else [(None, None)] + prefixes = adv_prefixes suffixes = adv_suffixes - + # Define output format specific entry types match output_format: case "full-prompt": - output_format = [EntryType.SUMMARY, EntryType.QA] - + output_format_types = [EntryType.SUMMARY, EntryType.QA] case "user-input": - output_format = [EntryType.DOCUMENT] - + output_format_types = [EntryType.DOCUMENT] case _: - output_format = ["burp"] - - # Obtain plugin options and calculate total variants for progress bar - plugin_variants = {} - if plugins: - for plugin_name, plugin_module in plugins: - if plugin_name is None: - plugin_variants[plugin_name] = 1 - - elif "~" in plugin_name and plugin_module: # Plugin Pipe - sub_plugins = plugin_name.split("~") - total_variants = 1 - for sub_plugin, sub_module in zip(sub_plugins, plugin_module): - sub_plugin_option = ( - plugin_options_map.get(sub_plugin) - if plugin_options_map - else None - ) - variants = get_plugin_variants(sub_module, sub_plugin_option) - total_variants *= variants - plugin_variants[plugin_name] = total_variants - - else: # Standalone Plugin - plugin_option = ( - plugin_options_map.get(plugin_name) if plugin_options_map else None - ) - plugin_variants[plugin_name] = get_plugin_variants( - plugin_module, plugin_option - ) - - # Calculate total entries for progress bar - total_entries = ( - len(base_docs) - * len(jailbreaks) - * len(instructions) - * len(positions) - * len(injection_delimiters) - * len(spotlighting_data_markers_list) - * len(suffixes) - * sum(plugin_variants.values() or [1]) - + 1 - if not plugin_only - else 0 - ) - bar_variations 
= tqdm(total=total_entries, desc="Variations", initial=0) - + output_format_types = ["burp"] + + # Build list of all permutations (base_doc, jailbreak, instruction) + print(f"[Info] Building composable permutations: {len(base_docs)} docs × {len(jailbreaks)} jailbreaks × {len(instructions)} instructions") + + permutations = [] + for base_doc in base_docs: - base_id = base_doc["id"] - document = base_doc["document"] - placeholder = base_doc.get("placeholder", "") - - # Define entry type specific text - question = base_doc.get("question", "") - ideal_answer = base_doc.get("ideal_answer", "") - ideal_summary = base_doc.get("ideal_summary", "") - entry_text = {} - if question != "": - entry_text["question"] = question - - if ideal_answer != "": - entry_text["ideal_answer"] = ideal_answer - - if ideal_summary != "": - entry_text["ideal_summary"] = ideal_summary - - # If the current document has a placeholder attribute, it means the user - # want the payload to be inserted into a fixed location, so we override - # the inject positions for this document - insert_positions = ["fixed"] if placeholder else positions - for jailbreak in jailbreaks: - jailbreak_id = jailbreak["id"] - jailbreak_text = jailbreak["text"] - jailbreak_type = jailbreak.get("jailbreak_type", "") - jailbreak_lang = jailbreak.get("lang", "en") - for instruction in instructions: - instruction_id = instruction["id"] - instruction_content_type = instruction.get("content_type", "text") - instruction_content = content_factory(instruction["instruction"], instruction_content_type) - instruction_type = instruction.get("instruction_type", "") - instruction_lang = instruction.get("lang", "en") - # instruction_steering_keywords = instruction.get("steering_keywords", None) - - if instruction_content_type != "text": - print( - f"Skipping instruction {instruction_id} for jailbreak {jailbreak_id} because instruction content type '{instruction_content_type}' is not supported yet." 
- ) - continue - - judge_name = instruction.get("judge_name", "canary") - judge_args = instruction.get( - "judge_args", instruction.get("canary", "") + # Pre-check language matching to skip invalid combinations + if match_languages: + jailbreak_lang = jailbreak.get("lang", "en") + instruction_lang = instruction.get("lang", "en") + if jailbreak_lang != instruction_lang: + continue + + permutations.append({ + 'base_doc': base_doc, + 'jailbreak': jailbreak, + 'instruction': instruction, + 'plugins': plugins, + 'prefixes': prefixes, + 'suffixes': suffixes, + 'positions': ["fixed"] if base_doc.get("placeholder", "") else positions, + 'injection_delimiters': injection_delimiters, + 'spotlighting_data_markers_list': spotlighting_data_markers_list, + 'match_languages': match_languages, + }) + + print(f"[Info] Processing {len(permutations)} composable permutations with {num_threads} thread(s)") + + # Process permutations + dataset = [] + entry_id = 1 + + if num_threads > 1: + # Parallel processing + def thread_init(): + """Each worker thread gets its own asyncio event loop.""" + asyncio.set_event_loop(asyncio.new_event_loop()) + + with ThreadPoolExecutor(max_workers=num_threads, initializer=thread_init) as executor: + # Submit all permutations + futures = {} + current_entry_id = entry_id + + for perm in permutations: + future = executor.submit( + _process_permutation_worker, + perm, + plugin_options_map, + system_message_config, + output_format_types, + current_entry_id ) - - # If match_languages is enabled, skip if jailbreak and instruction languages do not match - if match_languages and jailbreak_lang != instruction_lang: - total_entries -= ( - len(positions) - * len(injection_delimiters) - * len(spotlighting_data_markers_list) - * len(suffixes) - * sum(plugin_variants.values() or [1]) - ) - bar_variations.total = total_entries - bar_variations.refresh() - continue - - # Combines jailbreak and instruction texts - # Instruction is placed into jailbreak at placeholder - 
combined_base = content_factory(jailbreak_text.replace( - "", str(get_content(instruction_content)) - ), get_content_type(instruction_content)) - lang = instruction_lang - - # Create plugin / transformation regex exclusion lists - local_exclude = parse_exclude_patterns(jailbreak, instruction) - - # Apply all plugin, prefix and suffix combinations for the combined_base text (jailbreak + instruction) - # Applied with --plugins, --include-prefixes and --include-suffixes, by default None - combined_texts = [] - fix_permutations = [ - (prefix, suffix) for prefix in prefixes for suffix in suffixes - ] - - for plugin_name, plugin_module in plugins: - plugin_texts: List[Content] = ( - apply_plugin( - plugin_name, - plugin_module, - combined_base, - local_exclude, - plugin_options_map, - ) - if plugin_name - else [combined_base] - ) - - for plugin_index, plugin_text in enumerate(plugin_texts, start=1): - for prefix, suffix in fix_permutations: - prefix_lang = prefix.get("lang", None) if prefix else None - suffix_lang = suffix.get("lang", None) if suffix else None - - if match_languages and ( - (prefix_lang and prefix_lang != lang) - or (suffix_lang and suffix_lang != lang) - ): - continue - - prefix_text = prefix.get("prefix", "") + " " if prefix else "" - suffix_text = " " + suffix.get("suffix", "") if suffix else "" - text = content_factory(prefix_text + get_content(plugin_text) + suffix_text, get_content_type(plugin_text)) - - combined_texts.append( - { - "text": text, - "prefix_id": prefix.get("id", None) if prefix else None, - "suffix_id": suffix.get("id", None) if suffix else None, - "plugin_name": plugin_name, - "plugin_suffix": f"_{plugin_name}-{plugin_index}" - if plugin_name - else "", - } - ) - - for combined_text in combined_texts: - # Applied with --positions, by default ['end'] - for position in insert_positions: - # Applied with --injection-delimiters, by default "\nINJECTION_PAYLOAD\n" - for injection_pattern in injection_delimiters: - # Combine document 
with suffix_combined_text (jailbreak + instruction + suffix) - # suffix_combined_text is inserted into the injection_pattern at 'INJECTION_PAYLOAD' placeholder - # Document Placeholder: Injection is placed into document placeholder - # Otherwise: Injection is placed into document at position (start, middle, end) - injected_doc = insert_jailbreak( - document, - combined_text["text"], - position, - injection_pattern, - placeholder, - ) - - for entry_type in output_format: - if entry_type == "burp": - burp_payload_encoded = json.dumps(get_content(injected_doc))[ - 1:-1 - ] - dataset.append(burp_payload_encoded) - - else: - # Applied with --spotlighting-data-markers, by default "\nDOCUMENT\n" - for ( - spotlighting_data_marker - ) in spotlighting_data_markers_list: - # Applied with --include-system-message flag - gets system message associated with spotlighting data marker - system_message = get_system_message( - system_message_config, - spotlighting_data_marker, - ) - - # Combines injected document with spotlighting data marker, for full-prompt entries - if entry_type == EntryType.DOCUMENT: - injected_doc = ( - content_factory(spotlighting_data_marker.replace( - "DOCUMENT", get_content(injected_doc) - ), get_content_type(injected_doc)) - if spotlighting_data_marker != "none" and isinstance(get_content(injected_doc), str) - else injected_doc - ) - - entry = Entry( - entry_type=entry_type, - entry_id=entry_id, - base_id=base_id, - jailbreak_id=jailbreak_id, - instruction_id=instruction_id, - prefix_id=combined_text.get( - "prefix_id", None - ), - suffix_id=combined_text.get( - "suffix_id", None - ), - content=injected_doc, - entry_text=entry_text, - system_message=system_message, - payload=combined_text.get("text", None), - lang=lang, - plugin_suffix=combined_text.get( - "plugin_suffix", "" - ), - plugin_name=combined_text.get( - "plugin_name", None - ), - judge_args=judge_args, - judge_name=judge_name, - position=position, - jailbreak_type=jailbreak_type, - 
instruction_type=instruction_type, - injection_pattern=injection_pattern, - spotlighting_data_markers=spotlighting_data_marker, - exclude_from_transformations_regex=local_exclude, - ).to_entry() - dataset.append(entry) - entry_id += 1 - bar_variations.update(1) - + futures[future] = perm + # Estimate entry count for this permutation (rough approximation) + # This will be corrected as we process results + current_entry_id += 100 # Placeholder increment + + # Collect results with progress bar + bar = tqdm(total=len(permutations), desc="Processing permutations") + + for future in as_completed(futures): + try: + entries = future.result() + dataset.extend(entries) + bar.update(1) + except Exception as e: + perm = futures[future] + print(f"\n[ERROR] Permutation failed (doc={perm['base_doc']['id']}, jb={perm['jailbreak']['id']}, instr={perm['instruction']['id']}): {e}") + + bar.close() + + # Reassign entry IDs sequentially + for i, entry in enumerate(dataset, start=1): + if isinstance(entry, dict): + entry['id'] = i + entry_id = len(dataset) + 1 + + else: + # Sequential processing (original logic) + bar = tqdm(total=len(permutations), desc="Processing permutations") + + for perm in permutations: + entries = _process_permutation_worker( + perm, + plugin_options_map, + system_message_config, + output_format_types, + entry_id + ) + dataset.extend(entries) + entry_id += len(entries) + bar.update(1) + + bar.close() + return dataset, entry_id @@ -1133,6 +1175,7 @@ def generate_dataset(args): system_message_config=system_message_config, plugin_options_map=plugin_options_map, plugin_only=args.plugin_only, + num_threads=getattr(args, 'threads', 1), ) # Generate Standalone Attacks @@ -1149,6 +1192,7 @@ def generate_dataset(args): plugins=plugins if args.plugins else None, plugin_options_map=plugin_options_map, plugin_only=args.plugin_only, + num_threads=getattr(args, 'threads', 1), ) except ImportError as e: print(f"Missing dependency: {e}") diff --git 
a/tests/functional/test_spikee_generate/test_threads.py b/tests/functional/test_spikee_generate/test_threads.py new file mode 100644 index 0000000..f4b56a7 --- /dev/null +++ b/tests/functional/test_spikee_generate/test_threads.py @@ -0,0 +1,378 @@ +"""Functional tests for threaded dataset generation (--threads parameter). + +Tests verify that: +1. Sequential generation (--threads 1) produces correct output +2. Parallel generation (--threads > 1) produces equivalent output +3. Thread count affects performance appropriately +4. Entry IDs are assigned correctly in both modes +5. Errors are handled gracefully in parallel mode +""" + +import pytest +import time +from spikee.utilities.files import read_jsonl_file +from ..utils import spikee_generate_cli + + +class TestThreadsBasic: + """Basic tests for --threads parameter functionality.""" + + def test_threads_default_sequential(self, run_spikee, workspace_dir): + """Test default behavior (no --threads) uses sequential processing. + + Verifies: + - Default generation works without --threads parameter + - Produces expected number of entries + - All entries have valid structure + """ + output_file = spikee_generate_cli(run_spikee, workspace_dir) + + assert output_file.exists(), f"Expected dataset file at {output_file}" + + dataset = read_jsonl_file(output_file) + + assert len(dataset) > 0, "Generated dataset contains no entries" + assert len(dataset) >= 4, f"Expected at least 4 entries, got {len(dataset)}" + + # Verify all entries have required fields + for entry in dataset: + assert "id" in entry, "Entry missing 'id' field" + assert "long_id" in entry, "Entry missing 'long_id' field" + assert "content" in entry, "Entry missing 'content' field" + assert "judge_name" in entry, "Entry missing 'judge_name' field" + + def test_threads_explicit_sequential(self, run_spikee, workspace_dir): + """Test explicit --threads 1 uses sequential processing. 
+ + Verifies: + - --threads 1 flag works correctly + - Produces same output as default behavior + - Entry IDs are sequential + """ + output_file = spikee_generate_cli( + run_spikee, + workspace_dir, + additional_args=["--threads", "1"] + ) + + assert output_file.exists(), f"Expected dataset file at {output_file}" + + dataset = read_jsonl_file(output_file) + + assert len(dataset) >= 4, f"Expected at least 4 entries, got {len(dataset)}" + + # Verify entry IDs are sequential starting from 1 + ids = [e["id"] for e in dataset] + assert ids == list(range(1, len(dataset) + 1)), \ + f"Expected sequential IDs 1..{len(dataset)}, got {ids}" + + def test_threads_parallel_basic(self, run_spikee, workspace_dir): + """Test basic parallel generation with --threads 2. + + Verifies: + - --threads 2 flag works correctly + - Produces correct number of entries + - All entries have valid structure + """ + output_file = spikee_generate_cli( + run_spikee, + workspace_dir, + additional_args=["--threads", "2"] + ) + + assert output_file.exists(), f"Expected dataset file at {output_file}" + + dataset = read_jsonl_file(output_file) + + assert len(dataset) >= 4, f"Expected at least 4 entries, got {len(dataset)}" + + # Verify all entries have required fields + for entry in dataset: + assert "id" in entry, "Entry missing 'id' field" + assert "long_id" in entry, "Entry missing 'long_id' field" + assert "content" in entry, "Entry missing 'content' field" + assert "payload" in entry, "Entry missing 'payload' field" + + +class TestThreadsEquivalence: + """Tests that sequential and parallel generation produce equivalent datasets.""" + + def test_sequential_vs_parallel_same_output(self, run_spikee, workspace_dir): + """Test that sequential and parallel generation produce equivalent datasets. 
+ + Verifies: + - Both modes generate same number of entries + - Entries have same content (order may differ) + - All permutations are covered in both modes + """ + # Generate with sequential processing + output_sequential = spikee_generate_cli( + run_spikee, + workspace_dir, + additional_args=["--threads", "1"] + ) + + # Generate with parallel processing + output_parallel = spikee_generate_cli( + run_spikee, + workspace_dir, + additional_args=["--threads", "3"] + ) + + dataset_seq = read_jsonl_file(output_sequential) + dataset_par = read_jsonl_file(output_parallel) + + # Should have same number of entries + assert len(dataset_seq) == len(dataset_par), \ + f"Sequential generated {len(dataset_seq)} entries, parallel generated {len(dataset_par)}" + + # Extract long_ids (unique identifiers for permutations) + long_ids_seq = sorted([e["long_id"] for e in dataset_seq]) + long_ids_par = sorted([e["long_id"] for e in dataset_par]) + + # Both should have identical sets of long_ids + assert long_ids_seq == long_ids_par, \ + "Sequential and parallel modes generated different permutations" + + # Compare content for each long_id + seq_by_id = {e["long_id"]: e for e in dataset_seq} + par_by_id = {e["long_id"]: e for e in dataset_par} + + for long_id in long_ids_seq: + seq_entry = seq_by_id[long_id] + par_entry = par_by_id[long_id] + + # Content should be identical + assert seq_entry["content"] == par_entry["content"], \ + f"Content mismatch for long_id={long_id}" + + # Payload should be identical + assert seq_entry["payload"] == par_entry["payload"], \ + f"Payload mismatch for long_id={long_id}" + + # Metadata should be identical + assert seq_entry["judge_name"] == par_entry["judge_name"], \ + f"Judge name mismatch for long_id={long_id}" + + def test_different_thread_counts_same_output(self, run_spikee, workspace_dir): + """Test that different thread counts produce equivalent datasets. 
+ + Verifies: + - --threads 2 and --threads 4 produce same results + - Only performance differs, not output + """ + output_2_threads = spikee_generate_cli( + run_spikee, + workspace_dir, + additional_args=["--threads", "2"] + ) + + output_4_threads = spikee_generate_cli( + run_spikee, + workspace_dir, + additional_args=["--threads", "4"] + ) + + dataset_2 = read_jsonl_file(output_2_threads) + dataset_4 = read_jsonl_file(output_4_threads) + + # Should have same number of entries + assert len(dataset_2) == len(dataset_4), \ + f"2 threads: {len(dataset_2)} entries, 4 threads: {len(dataset_4)} entries" + + # Should have same long_ids + long_ids_2 = sorted([e["long_id"] for e in dataset_2]) + long_ids_4 = sorted([e["long_id"] for e in dataset_4]) + + assert long_ids_2 == long_ids_4, \ + "Different thread counts produced different permutations" + + +class TestThreadsWithPlugins: + """Tests for threaded generation with plugins.""" + + def test_threads_with_simple_plugin(self, run_spikee, workspace_dir): + """Test threaded generation with a simple transformation plugin. 
+ + Verifies: + - Plugins work correctly in parallel mode + - Plugin transformations are applied consistently + - Both base and plugin entries are generated + """ + output_file = spikee_generate_cli( + run_spikee, + workspace_dir, + additional_args=["--threads", "3", "--plugins", "test_upper"] + ) + + dataset = read_jsonl_file(output_file) + + # Should have base entries + plugin entries + assert len(dataset) == 12, f"Expected 12 entries (6 base + 6 plugin), got {len(dataset)}" + + # Verify plugin entries + plugin_entries = [e for e in dataset if e.get("plugin") == "test_upper"] + assert len(plugin_entries) == 6, f"Expected 6 plugin entries, got {len(plugin_entries)}" + + # Verify plugin transformation (uppercase) + for entry in plugin_entries: + assert entry["payload"] == entry["payload"].upper(), \ + "Plugin should uppercase the payload" + + def test_threads_with_plugin_sequential_equivalence(self, run_spikee, workspace_dir): + """Test that plugin transformations are identical in sequential and parallel modes. 
+ + Verifies: + - Plugin output is deterministic + - Sequential and parallel produce same plugin transformations + """ + # Sequential with plugin + output_seq = spikee_generate_cli( + run_spikee, + workspace_dir, + additional_args=["--threads", "1", "--plugins", "test_upper"] + ) + + # Parallel with plugin + output_par = spikee_generate_cli( + run_spikee, + workspace_dir, + additional_args=["--threads", "3", "--plugins", "test_upper"] + ) + + dataset_seq = read_jsonl_file(output_seq) + dataset_par = read_jsonl_file(output_par) + + # Same number of entries + assert len(dataset_seq) == len(dataset_par), \ + f"Sequential: {len(dataset_seq)}, Parallel: {len(dataset_par)}" + + # Compare plugin entries specifically + plugin_seq = sorted([e for e in dataset_seq if e.get("plugin") == "test_upper"], + key=lambda x: x["long_id"]) + plugin_par = sorted([e for e in dataset_par if e.get("plugin") == "test_upper"], + key=lambda x: x["long_id"]) + + assert len(plugin_seq) == len(plugin_par), \ + f"Sequential: {len(plugin_seq)} plugin entries, Parallel: {len(plugin_par)} plugin entries" + + # Verify transformations are identical + for seq_entry, par_entry in zip(plugin_seq, plugin_par): + assert seq_entry["payload"] == par_entry["payload"], \ + f"Plugin payload mismatch: {seq_entry['long_id']}" + assert seq_entry["content"] == par_entry["content"], \ + f"Plugin content mismatch: {seq_entry['long_id']}" + + +class TestThreadsPerformance: + """Tests for performance characteristics of threaded generation. + + Note: These are smoke tests, not precise benchmarks. + They verify that parallelization doesn't slow things down significantly. + """ + + def test_threads_performance_smoke(self, run_spikee, workspace_dir): + """Smoke test: Verify parallel mode doesn't take longer than sequential. + + This is a sanity check, not a performance benchmark. + Parallel should be at least as fast as sequential for non-trivial datasets. 
+ """ + # Measure sequential time + start = time.time() + output_seq = spikee_generate_cli( + run_spikee, + workspace_dir, + additional_args=["--threads", "1"] + ) + sequential_time = time.time() - start + + # Measure parallel time + start = time.time() + output_par = spikee_generate_cli( + run_spikee, + workspace_dir, + additional_args=["--threads", "3"] + ) + parallel_time = time.time() - start + + # Verify both completed successfully + assert output_seq.exists() + assert output_par.exists() + + # Parallel shouldn't be significantly slower than sequential + # Allow 3x overhead for thread management on small datasets + assert parallel_time < sequential_time * 3, \ + f"Parallel ({parallel_time:.2f}s) much slower than sequential ({sequential_time:.2f}s)" + + +class TestThreadsWithFilters: + """Tests for threaded generation with filtering options.""" + + def test_threads_with_language_filter(self, run_spikee, workspace_dir): + """Test threaded generation with language filtering. + + Verifies: + - Language filtering works in parallel mode + - Only specified language entries are generated + """ + output_file = spikee_generate_cli( + run_spikee, + workspace_dir, + additional_args=["--threads", "3", "--languages", "en"] + ) + + dataset = read_jsonl_file(output_file) + + assert len(dataset) > 0, "Generated dataset contains no entries" + + # All entries should be English + languages = {e.get("lang") for e in dataset} + assert languages == {"en"}, f"Expected only 'en' language, got {languages}" + + def test_threads_with_instruction_filter(self, run_spikee, workspace_dir): + """Test threaded generation with instruction type filtering. 
+ + Verifies: + - Instruction filtering works in parallel mode + - Only specified instruction types are generated + """ + output_file = spikee_generate_cli( + run_spikee, + workspace_dir, + additional_args=["--threads", "3", "--instruction-filter", "restricted"] + ) + + dataset = read_jsonl_file(output_file) + + assert len(dataset) == 2, f"Expected 2 entries for 'restricted' filter, got {len(dataset)}" + + # All entries should reference the filtered instruction + for entry in dataset: + assert "instr-filter" in entry.get("long_id", ""), \ + f"Expected 'instr-filter' in long_id, got: {entry.get('long_id')}" + + def test_threads_with_match_languages_false(self, run_spikee, workspace_dir): + """Test threaded generation with cross-language pairing enabled. + + Verifies: + - Cross-language pairing works in parallel mode + - Generates all language combinations + """ + output_file = spikee_generate_cli( + run_spikee, + workspace_dir, + additional_args=["--threads", "3", "--match-languages", "false"] + ) + + dataset = read_jsonl_file(output_file) + + # Should have all cross-language combinations (2 docs × 2 jbs × 3 instrs = 12) + assert len(dataset) == 12, f"Expected 12 entries with cross-language, got {len(dataset)}" + + # Verify cross-language entries exist + long_ids = [e.get("long_id", "") for e in dataset] + cross_lang_entries = [ + lid for lid in long_ids + if ("jb-it" in lid and "instr-en" in lid) or ("jb-en" in lid and "instr-it" in lid) + ] + assert len(cross_lang_entries) > 0, "Expected cross-language entries" \ No newline at end of file From f8ab9105ec74c31ff82cbd3d5b9f374204861dc0 Mon Sep 17 00:00:00 2001 From: ThomasCross Date: Fri, 24 Apr 2026 16:39:30 +0100 Subject: [PATCH 3/6] dev: remove pre-release code --- spikee/cli.py | 62 --- spikee/docs.py | 1006 ------------------------------------------------ 2 files changed, 1068 deletions(-) delete mode 100644 spikee/docs.py diff --git a/spikee/cli.py b/spikee/cli.py index cf781f1..ff5209f 100644 --- 
a/spikee/cli.py +++ b/spikee/cli.py @@ -5,7 +5,6 @@ import sys import shutil import argparse -import argcomplete from . import __version__ from dotenv import load_dotenv from pathlib import Path @@ -29,7 +28,6 @@ list_providers, ) from .viewers.results import ResultsViewer -from .docs import docs_command banner = r""" _____ _____ _____ _ ________ ______ @@ -607,64 +605,6 @@ def positive_int(value): help="Include descriptions of modules where available", ) - # === [DOCS] Sub-command ================================================ - parser_docs = subparsers.add_parser( - "docs", - help="Generate spikee commands or explain spikee commands" - ) - - # Create subparsers for docs command - docs_subparsers = parser_docs.add_subparsers( - dest="subcommand", - help="Subcommands for docs" - ) - - # docs generate subcommand - parser_docs_generate = docs_subparsers.add_parser( - "generate", - help="Generate spikee commands using natural language queries" - ) - parser_docs_generate.add_argument( - "query", - type=str, - nargs="+", - help="Natural language description of desired spikee command" - ) - parser_docs_generate.add_argument( - "--model", - type=str, - default=None, - help="LLM model to use for generation (default: openai/gpt-4o)" - ) - parser_docs_generate.add_argument( - "--verbose", - action="store_true", - help="Show debug information (classification, context size)" - ) - - # docs explain subcommand - parser_docs_explain = docs_subparsers.add_parser( - "explain", - help="Explain spikee commands or provide information about them" - ) - parser_docs_explain.add_argument( - "query", - type=str, - nargs="+", - help="Query about spikee commands to explain" - ) - parser_docs_explain.add_argument( - "--model", - type=str, - default=None, - help="LLM model to use for explanation (default: openai/gpt-4o)" - ) - parser_docs_explain.add_argument( - "--verbose", - action="store_true", - help="Show debug information (classification, context size)" - ) - args = 
convert_to_new_args(parser.parse_args()) # Print banner and info unless quiet mode is enabled @@ -726,8 +666,6 @@ def positive_int(value): list_providers(args) else: parser_list.print_help() - elif args.command == "docs": - docs_command(args) else: parser.print_help() sys.exit(1) diff --git a/spikee/docs.py b/spikee/docs.py deleted file mode 100644 index 6b83863..0000000 --- a/spikee/docs.py +++ /dev/null @@ -1,1006 +0,0 @@ -""" -LLM-powered spikee command generator with optimized context loading. - -This module provides natural language command generation for spikee using an LLM provider -with a two-stage approach: -1. First, classify the query to determine command type -2. Then, load only relevant documentation and module information - -Usage: - spikee docs "test gpt-4o-mini with the cybersec dataset" - spikee docs "generate dataset with base64 plugin" --model bedrock/claude45-haiku -""" - -import sys -import re -from typing import Tuple, Dict, Any - -from spikee.utilities.llm import get_llm -from spikee.utilities.llm_message import SystemMessage, HumanMessage -from spikee.utilities.modules import ( - extract_json_or_fail, - load_module_from_path, - get_options_from_module, - get_description_from_module, - collect_modules, - collect_seeds, - collect_datasets, -) - -try: - from rich.console import Console - from rich.syntax import Syntax -except ImportError: - Console = None - -DEFAULT_MODEL = "openai/gpt-4o" - -# === Stage 1: Command Classification === - -CLASSIFIER_PROMPT = """You are a command classifier for the SPIKEE toolkit. - -Analyze the user's query and determine which spikee command they want to use. 
- -Available commands: -- generate: Creating test datasets from seed folders -- test: Testing targets with datasets (includes attacks, judges, sampling) -- results: Analyzing, rejudging, or extracting test results -- list: Listing available modules (seeds, datasets, targets, plugins, attacks, judges, providers) -- init: Initializing workspace -- viewer: Launching web viewers -- unknown: Cannot determine or doesn't match spikee commands - -Respond with ONLY a JSON object: -{ - "command": "", - "confidence": "" -} - -Examples: -"test gpt-4o with my dataset" -> {"command": "test", "confidence": "high"} -"generate dataset with plugins" -> {"command": "generate", "confidence": "high"} -"show me all plugins" -> {"command": "list", "confidence": "high"} -"analyze my results" -> {"command": "results", "confidence": "high"} -"setup workspace" -> {"command": "init", "confidence": "medium"} -""" - -# === Stage 2: Modular Documentation Sections === - -COMMON_HEADER = """You are an expert assistant for the SPIKEE toolkit - a prompt injection and jailbreaking testing framework. - -Your task is to generate valid spikee CLI commands based on natural language descriptions from users. - -# Common Patterns: - -1. **LLM Provider Format**: Always use "provider/model" format - - OpenAI: "openai/gpt-4o", "openai/gpt-4o-mini" - - Bedrock: "bedrock/claude45-sonnet", "bedrock/claude45-haiku" - - Azure: "azure/gpt-4" - - Groq: "groq/llama-3.1-70b" - - DeepSeek: "deepseek/deepseek-chat" - -2. **Plugin Piping**: Use | to pipe plugins: "plugin1|plugin2|plugin3" - -3. **Options Format**: "module:key1=val1,key2=val2;module2:key3=val3" - -4. 
**Dataset Wildcards**: Can use wildcards in paths: "datasets/cybersec-*.jsonl" -""" - -GENERATE_DOCS = """ -## GENERATE Command - -### Required Arguments: -- `--seed-folder ` - REQUIRED: Path to seed folder (e.g., datasets/seeds-cybersec-2026-01) - -### Optional Source Arguments: -- `--include-standalone-inputs` - Include standalone_user_inputs.jsonl -- `--include-system-message` - Include system_messages.toml -- `--tag ` - Tag for dataset filename - -### Optional Transformation Arguments: -- `--plugins ` - Space-separated list of plugins OR piped plugins with | (e.g., "1337 base64" or "splat|base64") -- `--plugin-options ""` - Plugin options: "plugin1:option1=value1,option2=value2;plugin2:option2=value2" -- `--plugin-only` - Only output plugin entries -- `--include-fixes ` - Comma-separated: adv_prefixes, adv_suffixes, prefixes=, suffixes=, prefix=, suffix= - -### Optional Formatting Arguments: -- `--format ` - Output format: user-input (default/apps), full-prompt (LLMs), or burp -- `--languages ` - Comma-separated list of languages to filter (e.g., en) -- `--match-languages` - Only combine jailbreaks/instructions with matching languages (default: True) -- `--positions ` - Position to insert jailbreaks: start, middle, end (ignored if present) -- `--injection-delimiters ` - Delimiters for injecting jailbreaks (default: \\nINJECTION_PAYLOAD\\n) -- `--spotlighting-data-markers ` - Comma-separated data markers (placeholder: "DOCUMENT") -- `--instruction-filter ` - Comma-separated instruction types to include -- `--jailbreak-filter ` - Comma-separated jailbreak types to include - -### Examples: -```bash -# Basic generation -spikee generate --seed-folder datasets/seeds-cybersec-2026-01 - -# With plugins -spikee generate --seed-folder datasets/seeds-toxic-chat --plugins "1337 base64" - -# Plugin piping -spikee generate --seed-folder datasets/seeds-cybersec-2026-01 --plugins "splat|base64" - -# With plugin options -spikee generate --seed-folder datasets/seeds-example 
--plugins best_of_n --plugin-options "best_of_n:variants=50" - -# With standalone inputs -spikee generate --seed-folder datasets/seeds-in-the-wild --include-standalone-inputs - -# With adversarial fixes -spikee generate --seed-folder datasets/seeds-cybersec-2026-01 --include-fixes "adv_prefixes,adv_suffixes" -``` -""" - -TEST_DOCS = """ -## TEST Command - -### Required Dataset Arguments (at least one required): -- `--dataset ` - Path to dataset JSONL file (can be used multiple times) -- `--dataset-folder ` - Path to folder with multiple JSONL files (can be used multiple times) - -### Required Module Arguments: -- `--target ` - REQUIRED: Target module name (e.g., llm_provider, aws_bedrock_guardrail) - -### Optional Module Arguments: -- `--target-options ""` - Target options, typically "provider/model" format - - Examples: "openai/gpt-4o-mini", "bedrock/claude45-sonnet", "azure/gpt-4" -- `--judge-options ""` - LLM judge model (format: "model=provider/model" or just "provider/model") - - Examples: "bedrock/claude45-haiku", "openai/gpt-4o" - - Only needed for datasets requiring semantic evaluation (not canary-based) - -### Optional Testing Arguments: -- `--threads ` - Number of parallel threads (default: 4) -- `--attempts ` - Number of attempts per entry (default: 1) -- `--max-retries ` - Number of retries for rate-limiting/429 errors (default: 3) -- `--throttle ` - Time to wait between entries per thread (default: 0) -- `--sample ` - Sample percentage of dataset (e.g., 0.15 for 15%, default: 1) -- `--sample-seed ` - Seed for random sampling (default: 42) -- `--tag ` - Tag for results filename - -### Optional Attack Arguments: -- `--attack ` - Attack module to use -- `--attack-iterations ` - Number of attack iterations/turns per entry -- `--attack-options ""` - Attack-specific options -- `--attack-only` - Only run attack module, skip standard attempts - -### Optional Resume Arguments: -- `--resume-file ` - Resume from specific results JSONL file (single dataset only) 
-- `--auto-resume` - Silently resume from latest matching results file -- `--no-auto-resume` - Create new results file, don't resume - -### Examples: -```bash -# Basic test -spikee test --dataset datasets/cybersec-2026-01.jsonl --target llm_provider --target-options "openai/gpt-4o-mini" - -# With LLM judge -spikee test --dataset datasets/harmful.jsonl --target llm_provider --target-options "openai/gpt-4o-mini" --judge-options "bedrock/claude45-haiku" - -# Multiple datasets -spikee test --dataset datasets/dataset1.jsonl --dataset datasets/dataset2.jsonl --target llm_provider --target-options "bedrock/claude45-sonnet" - -# With attack -spikee test --dataset datasets/example.jsonl --target llm_provider --target-options "openai/gpt-4o" --attack best_of_n --attack-iterations 25 - -# With sampling -spikee test --dataset datasets/large.jsonl --target llm_provider --target-options "openai/gpt-4o" --sample 0.1 --sample-seed 123 -``` -""" - -RESULTS_DOCS = """ -## RESULTS Command - -### Subcommands: -1. `results analyze` - Analyze test results with statistics and visualizations -2. `results rejudge` - Re-judge results with different judge -3. `results extract` - Extract specific results by category or search term -4. `results dataset-comparison` - Compare datasets across multiple targets -5. 
`results convert-to-excel` - Convert results JSONL to Excel format - -### analyze Arguments: -- `--results-file ` - Path to results JSONL file (can be used multiple times) -- `--results-folder ` - Path to folder with results files (can be used multiple times) -- `--false-positive-checks ` - JSONL file with benign prompts for FP analysis (single dataset only) -- `--output-format ` - Output format: console (default) or html -- `--overview` - Only output general statistics -- `--combine` - Combine multiple results files into single analysis - -### rejudge Arguments: -- `--results-file ` - Path to results JSONL file (can be used multiple times) -- `--results-folder ` - Path to folder with results files (can be used multiple times) -- `--judge-options ` - Options to pass to the judge -- `--resume` - Resume from most recent re-judge file - -### extract Arguments: -- `--results-file ` - Path to results JSONL file (can be used multiple times) -- `--results-folder ` - Path to folder with results files (can be used multiple times) -- `--category ` - Category: success (default), failure, error, guardrail, no-guardrail, custom -- `--custom-search ` - Custom search: 'string', 'field:string', or '!string' to invert -- `--tag ` - Tag for results filename - -### convert-to-excel Arguments: -- `--result-file ` - Path to results JSONL file (required) - -### Examples: -```bash -# Analyze results -spikee results analyze --results-file results/test-run.jsonl - -# Rejudge with different judge -spikee results rejudge --results-file results/test.jsonl --judge-options "openai/gpt-4o" - -# Extract successful prompts -spikee results extract --results-file results/test.jsonl --category success -``` -""" - -LIST_DOCS = """ -## LIST Command - -### Subcommands: -- `list seeds` - List available seed folders -- `list datasets` - List available dataset JSONL files -- `list targets` - List available targets -- `list judges` - List available judges -- `list plugins` - List available plugins -- `list 
attacks` - List available attack scripts -- `list providers` - List available LLM providers - -### Optional Arguments (for targets, judges, plugins, attacks, providers): -- `-d`, `--description` - Include module descriptions - -### Examples: -```bash -spikee list seeds -spikee list targets --description -spikee list plugins -d -``` -""" - -INIT_DOCS = """ -## INIT Command - -### Arguments: -- `--force` - Overwrite existing directories -- `--include-builtin ` - Copy built-in modules to local workspace -- `--include-viewer` - Include built-in web viewer in local workspace - -### Examples: -```bash -# Basic workspace initialization -spikee init - -# With built-in modules -spikee init --include-builtin all - -# Force overwrite -spikee init --force -``` -""" - -VIEWER_DOCS = """ -## VIEWER Command - -### Subcommands: -- `viewer results` - Launch results viewer - -### Common Arguments: -- `-h`, `--host
` - Host address (default: 127.0.0.1) -- `-p`, `--port ` - Port number (default: 8080) -- `-d`, `--debug` - Enable debug mode with hot-reloading (default: False) -- `--truncate ` - Truncate long fields (default: 500 chars, 0 to disable) - -### results Viewer Arguments: -- `--result-file ` - Path to results JSONL file (can be used multiple times) -- `--result-folder ` - Path to results folder (can be used multiple times) -- `--allow-ast` - Allow AST parsing (use with caution) - -### Examples: -```bash -# Launch results viewer -spikee viewer results --result-folder results/ - -# Custom port -spikee viewer -p 8081 results --result-file results/test.jsonl -``` -""" - -RESPONSE_FORMAT = """ -# Your Response Format: - -You MUST respond with ONLY valid JSON in this exact format: - -{ - "command": "spikee ", - "explanation": "Clear explanation of what this command does", - "options": { - "useful_module_options": ["List of 2-4 useful module-specific options (e.g., --plugin-options, --attack-options, --target-options) that could enhance this command. ONLY include if the command uses modules like plugins, attacks, targets, or judges"] - } -} - -# Important Guidelines: - -1. Generate VALID spikee commands only - use exact argument names and formats shown above -2. ONLY use modules from the available modules list provided -3. Use real seed folders and datasets from the available lists when provided -4. If paths are not specified, use appropriate placeholders from the available lists -5. Use appropriate defaults (e.g., openai/gpt-4o-mini for testing if not specified) -6. Include clear explanations that help users understand what the command does -7. If user mentions a specific LLM provider/model, use it in the command -8. When using LLM-based modules, always include model in options -9. In the "options" field, ONLY suggest useful module-specific options if the command uses modules (plugins, attacks, targets, judges). 
Do NOT include general command arguments like --threads or --sample -10. Return ONLY the JSON - no additional text before or after -""" - - -def parse_query_for_model(query: str) -> Tuple[str, str]: - """ - Extract model specification from query if present. - - Patterns detected: - - "using openai/gpt-4o" - - "with bedrock/claude45-sonnet" - - "model=openai/gpt-4" - - Args: - query: Natural language query string - - Returns: - Tuple of (cleaned_query, model_name_or_none) - """ - pattern = r'\b(using|with|model=)\s*([a-zA-Z0-9_-]+/[a-zA-Z0-9_.-]+)\b' - match = re.search(pattern, query, re.IGNORECASE) - - if match: - model = match.group(2) - cleaned = re.sub(pattern, '', query, flags=re.IGNORECASE).strip() - cleaned = re.sub(r'\s+', ' ', cleaned) - return cleaned, model - - return query, "" - - -def classify_command(query: str, model: str = DEFAULT_MODEL) -> Dict[str, str]: - """ - Classify the user's query to determine which spikee command they want. - - Args: - query: Natural language query - model: LLM model to use for classification - - Returns: - Dictionary with command type and confidence level - """ - try: - # Use faster, cheaper model with minimal tokens for classification - llm = get_llm(model, max_tokens=50, temperature=0) - if llm is None: - raise ValueError("Failed to initialize LLM") - - messages = [ - SystemMessage(content=CLASSIFIER_PROMPT), - HumanMessage(content=query) - ] - - response = llm.invoke(messages).content - - if not isinstance(response, str) or response.strip() == "": - raise ValueError("LLM returned empty response for classification") - - result = extract_json_or_fail(response) - - return result - except Exception as e: - # Fallback to general command if classification fails - return {"command": "unknown", "confidence": "low"} - - -def get_seeds_info() -> str: - """ - Get information about available seed folders in the workspace. 
- - Returns: - Formatted string with seed folder names - """ - try: - seeds = collect_seeds() - - if not seeds: - return "\n## Available Seed Folders:\n\nNo seed folders found in datasets/. Use 'spikee init' to initialize workspace." - - lines = ["\n## Available Seed Folders:\n"] - lines.append("Found in datasets/ directory:") - for seed in seeds: - lines.append(f"- datasets/{seed}") - - return "\n".join(lines) - except Exception as e: - return f"\n## Available Seed Folders:\n\nError loading seeds: {str(e)}" - - -def get_datasets_info() -> str: - """ - Get information about available datasets in the workspace. - - Returns: - Formatted string with dataset names - """ - try: - datasets = collect_datasets() - - if not datasets: - return "\n## Available Datasets:\n\nNo datasets found in datasets/. Generate datasets using 'spikee generate'." - - lines = ["\n## Available Datasets:\n"] - lines.append("Found in datasets/ directory:") - for dataset in datasets: - lines.append(f"- datasets/{dataset}") - - return "\n".join(lines) - except Exception as e: - return f"\n## Available Datasets:\n\nError loading datasets: {str(e)}" - - -def get_module_info(module_type: str) -> str: - """ - Get information about available modules from local and built-in sources. - - Args: - module_type: Type of module (plugins, targets, attacks, judges, providers) - - Returns: - Formatted string with module names and options - """ - try: - # Use collect_modules from utilities - all_names, local_names, builtin_names = collect_modules(module_type) - - if not all_names: - return f"\n## Available {module_type.title()}:\n\nNo {module_type} available." 
- - lines = [f"\n## Available {module_type.title()}:\n"] - - # Track if any module requires LLM - any_llm_required = False - - for name in all_names: - try: - # Load module to get info - module = load_module_from_path(name, module_type) - - # Get options - options_result = get_options_from_module(module, module_type) - util_llm = False - options = [] - - if options_result is not None: - if isinstance(options_result, tuple) and len(options_result) == 2: - options, util_llm = options_result - else: - options = options_result if isinstance(options_result, list) else [options_result] - - if util_llm: - any_llm_required = True - - # Get description - description_result = get_description_from_module(module, module_type) - tags = [] - description = "" - - if description_result is not None: - if isinstance(description_result, tuple) and len(description_result) == 2: - tags, description = description_result - else: - description = description_result if isinstance(description_result, str) else "" - - # Build module line - line = f"- **{name}**" - - # Add source indicator - if name in local_names: - line += " [Local]" - - # Add tags if available - if tags: - if hasattr(tags[0], 'value'): # ModuleTag enum - tags_str = ", ".join([tag.value for tag in tags]) - else: - tags_str = ", ".join(str(tag) for tag in tags) - line += f" [{tags_str}]" - - # Add options if available - if options and len(options) > 0 and not (isinstance(options[0], str) and options[0].startswith(" str: - """ - Build contextual documentation based on detected command type. - - Args: - command_type: The type of command (generate, test, results, etc.) 
- - Returns: - Relevant documentation string - """ - context = COMMON_HEADER - - if command_type == "generate": - context += GENERATE_DOCS - context += get_seeds_info() # Add available seed folders - context += get_module_info("plugins") - - elif command_type == "test": - context += TEST_DOCS - context += get_datasets_info() # Add available datasets - context += get_module_info("targets") - context += get_module_info("attacks") - context += get_module_info("judges") - - elif command_type == "results": - context += RESULTS_DOCS - context += get_module_info("judges") # For rejudge - - elif command_type == "list": - context += LIST_DOCS - - elif command_type == "init": - context += INIT_DOCS - - elif command_type == "viewer": - context += VIEWER_DOCS - - else: # unknown - provide general overview - context += "\n## All Commands Available:\n" - context += "- generate: Create test datasets\n" - context += "- test: Test targets with datasets\n" - context += "- results: Analyze, rejudge, or extract results\n" - context += "- list: List available modules\n" - context += "- init: Initialize workspace\n" - context += "- viewer: Launch web viewers\n" - context += "\nPlease rephrase your query to be more specific.\n" - - context += RESPONSE_FORMAT - - return context - - -def build_explanation_context(command_type: str) -> str: - """ - Build contextual documentation for explaining spikee commands. - - Args: - command_type: The type of command being explained (generate, test, results, etc.) 
- - Returns: - Relevant documentation string for explanations - """ - # Start with common header - context = COMMON_HEADER - - # Add command-specific explanation context - context += "\n## Command Explanation Context\n" - context += "This section provides detailed information to explain spikee commands based on user queries.\n" - - if command_type == "generate": - context += "\n### Generate Command Context\n" - context += "The generate command creates test datasets from seed folders with optional transformations.\n" - context += GENERATE_DOCS - context += get_seeds_info() - context += get_module_info("plugins") - - elif command_type == "test": - context += "\n### Test Command Context\n" - context += "The test command evaluates targets with datasets, optionally using attacks and judges.\n" - context += TEST_DOCS - context += get_datasets_info() - context += get_module_info("targets") - context += get_module_info("attacks") - context += get_module_info("judges") - - elif command_type == "results": - context += "\n### Results Command Context\n" - context += "The results command analyzes, rejudges, or extracts test results.\n" - context += RESULTS_DOCS - context += get_module_info("judges") - - elif command_type == "list": - context += "\n### List Command Context\n" - context += "The list command shows available modules for various categories.\n" - context += LIST_DOCS - - elif command_type == "init": - context += "\n### Init Command Context\n" - context += "The init command initializes a new spikee workspace.\n" - context += INIT_DOCS - - elif command_type == "viewer": - context += "\n### Viewer Command Context\n" - context += "The viewer command launches web viewers for results.\n" - context += VIEWER_DOCS - - else: # unknown - provide general overview - context += "\n### General Command Context\n" - context += "When the command type is unknown, here's a general overview of all spikee commands:\n" - context += "- generate: Create test datasets\n" - context += "- 
test: Test targets with datasets\n" - context += "- results: Analyze, rejudge, or extract results\n" - context += "- list: List available modules\n" - context += "- init: Initialize workspace\n" - context += "- viewer: Launch web viewers\n" - context += "\nPlease rephrase your query to be more specific.\n" - - # Add explanation-specific response format - context += "\n# Your Response Format:\n" - context += "You MUST respond with ONLY valid JSON in this exact format:\n" - context += "{\n" - context += " \"explanation\": \"Clear explanation of the spikee command(s) requested\"\n" - context += "}\n" - - return context - - -def display_explanation(explanation_dict: Dict[str, Any]) -> None: - """ - Display generated explanation with formatting. - - Args: - explanation_dict: Dictionary containing explanation - """ - if Console is None: - # Fallback to plain text if rich is not available - print("\nExplanation:") - print(explanation_dict["explanation"]) - print() - return - - console = Console() - - # Display header - console.print() - console.print("[bold cyan]Explanation:[/bold cyan]") - - # Display explanation - console.print() - console.print("[bold green]" + explanation_dict["explanation"] + "[/bold green]") - console.print() - - -def generate_command(query: str, model: str = DEFAULT_MODEL, verbose: bool = False) -> Dict[str, Any]: - """ - Generate a spikee command from natural language query using an LLM with optimized context. - - This uses a two-stage approach: - 1. Classify the query to determine command type (fast, minimal tokens) - 2. 
Load only relevant documentation and generate command (focused context) - - Args: - query: Natural language description of desired command - model: LLM model to use (format: provider/model) - verbose: If True, print classification info - - Returns: - Dictionary with keys: command, explanation, options (containing useful_module_options) - - Raises: - Exception: If LLM call fails or response parsing fails - """ - try: - # Stage 1: Classify command type - classification = classify_command(query, model) - command_type = classification.get("command", "unknown") - confidence = classification.get("confidence", "low") - - if verbose: - print(f"[Debug] Classified as: {command_type} (confidence: {confidence})") - - # Stage 2: Build context and generate command - context = build_context_for_command(command_type) - - if verbose: - print(f"[Debug] Context size: {len(context)} chars") - - # Initialize LLM with appropriate token limit (increased for options) - llm = get_llm(model, max_tokens=1200, temperature=0.3) - - # Create messages - messages = [ - SystemMessage(content=context), - HumanMessage(content=query) - ] - - # Get response - response = llm.invoke(messages) - - # Parse JSON response - result = extract_json_or_fail(response.content) - - # Validate required fields - if "command" not in result or "explanation" not in result: - raise ValueError("LLM response missing required fields (command, explanation)") - - # Options field is optional but should have default structure if missing - if "options" not in result: - result["options"] = { - "useful_module_options": [] - } - - return result - - except ImportError as e: - raise Exception(f"Invalid LLM provider specification: {e}\n" - f"Use format: provider/model (e.g., openai/gpt-4o)") - except Exception as e: - raise Exception(f"Error generating command: {e}\n" - "Please try rephrasing your query or check your API credentials in .env file") - - -def explain_command(query: str, model: str = DEFAULT_MODEL, verbose: bool = False) 
-> Dict[str, Any]: - """ - Generate an explanation of spikee commands for natural language queries. - - Args: - query: Natural language query about spikee commands - model: LLM model to use (format: provider/model) - verbose: If True, print classification info - - Returns: - Dictionary with key: explanation - - Raises: - Exception: If LLM call fails or response parsing fails - """ - try: - # Stage 1: Classify command type - classification = classify_command(query, model) - command_type = classification.get("command", "unknown") - confidence = classification.get("confidence", "low") - - if verbose: - print(f"[Debug] Classified as: {command_type} (confidence: {confidence})") - - # Stage 2: Build context for explanation - context = build_explanation_context(command_type) - - if verbose: - print(f"[Debug] Context size: {len(context)} chars") - - # Initialize LLM with appropriate token limit - llm = get_llm(model, max_tokens=800, temperature=0.3) - - # Create messages - messages = [ - SystemMessage(content=context), - HumanMessage(content=query) - ] - - # Get response - response = llm.invoke(messages) - - # Parse JSON response - result = extract_json_or_fail(response.content) - - # Validate required fields - if "explanation" not in result: - raise ValueError("LLM response missing required field (explanation)") - - return result - - except ImportError as e: - raise Exception(f"Invalid LLM provider specification: {e}\n" - f"Use format: provider/model (e.g., openai/gpt-4o)") - except Exception as e: - raise Exception(f"Error generating explanation: {e}\n" - "Please try rephrasing your query or check your API credentials in .env file") - - -def display_command(command_dict: Dict[str, Any]) -> None: - """ - Display generated command with formatting. 
- - Args: - command_dict: Dictionary containing command, explanation, and options - """ - if Console is None: - # Fallback to plain text if rich is not available - print("\nGenerated Command:") - print(command_dict["command"]) - print("\nExplanation:") - print(command_dict["explanation"]) - - # Display options if available - if "options" in command_dict: - options = command_dict["options"] - - if options.get("useful_module_options"): - print("\nUseful Module Options:") - for opt in options["useful_module_options"]: - print(f" • {opt}") - - print() - return - - console = Console() - - # Display header - console.print() - console.print("[bold cyan]Generated Command:[/bold cyan]") - - # Display command with syntax highlighting - syntax = Syntax(command_dict["command"], "bash", theme="monokai", line_numbers=False) - console.print(syntax) - - # Display explanation - console.print() - console.print("[bold green]Explanation:[/bold green]") - console.print(command_dict["explanation"]) - - # Display options if available - if "options" in command_dict: - options = command_dict["options"] - - if options.get("useful_module_options"): - console.print() - console.print("[bold yellow]Useful Module Options:[/bold yellow]") - for opt in options["useful_module_options"]: - console.print(f" [dim]•[/dim] {opt}") - - console.print() - - -def docs_command(args) -> None: - """ - Entry point for the 'spikee docs' command. - - Args: - args: Parsed command-line arguments - """ - - # Check if any subcommand was provided - if hasattr(args, 'subcommand') and args.subcommand: - if args.subcommand == 'generate': - docs_generate(args) - elif args.subcommand == 'explain': - docs_explain(args) - else: - print(f"Error: Unknown subcommand '{args.subcommand}'. Use 'spikee docs --help' for available subcommands.") - sys.exit(1) - else: - # Default behavior: run generate mode - docs_generate(args) - - -def docs_generate(args) -> None: - """ - Generate a spikee command from a natural language query. 
- - Args: - args: Parsed command-line arguments - """ - # Join query parts into single string - query = " ".join(args.query) - - if not query.strip(): - print("Error: Please provide a query describing the spikee command you want to generate.") - print("\nExample: spikee docs generate \"test gpt-4o-mini with my dataset\"") - sys.exit(1) - - # Determine which model to use - model = None - - # First priority: --model flag - if hasattr(args, 'model') and args.model: - model = args.model - - # Second priority: model specified in query - if model is None: - cleaned_query, parsed_model = parse_query_for_model(query) - if parsed_model: - query = cleaned_query - model = parsed_model - - # Third priority: default model - if model is None: - model = DEFAULT_MODEL - - # Check for verbose mode - verbose = hasattr(args, 'verbose') and args.verbose - - # Show what we're doing - if Console: - console = Console() - console.print(f"[dim]Generating spikee command using {model}...[/dim]") - else: - print(f"Generating spikee command using {model}...") - - # Generate command - try: - result = generate_command(query, model, verbose=verbose) - display_command(result) - except Exception as e: - print(f"\n{e}", file=sys.stderr) - sys.exit(1) - - -def docs_explain(args) -> None: - """ - Explain spikee commands or provide information about them. 
- - Args: - args: Parsed command-line arguments - """ - # Join query parts into single string - query = " ".join(args.query) - - if not query.strip(): - print("Error: Please provide a query about spikee commands to explain.") - print("\nExample: spikee docs explain \"how to test with a custom model\"") - sys.exit(1) - - # Determine which model to use - model = None - - # First priority: --model flag - if hasattr(args, 'model') and args.model: - model = args.model - - # Second priority: model specified in query - if model is None: - cleaned_query, parsed_model = parse_query_for_model(query) - if parsed_model: - query = cleaned_query - model = parsed_model - - # Third priority: default model - if model is None: - model = DEFAULT_MODEL - - # Check for verbose mode - verbose = hasattr(args, 'verbose') and args.verbose - - # Show what we're doing - if Console: - console = Console() - console.print(f"[dim]Explaining spikee commands using {model}...[/dim]") - else: - print(f"Explaining spikee commands using {model}...") - - # Generate explanation - try: - result = explain_command(query, model, verbose=verbose) - display_explanation(result) - except Exception as e: - print(f"\n{e}", file=sys.stderr) - sys.exit(1) From f62ac59041bc971e3f09df17924f8ad528ca797a Mon Sep 17 00:00:00 2001 From: ThomasCross Date: Fri, 24 Apr 2026 17:21:38 +0100 Subject: [PATCH 4/6] fix: generator error handling and entry counters --- spikee/cli.py | 9 +--- spikee/generator.py | 129 ++++++++++++++++++++++---------------------- 2 files changed, 65 insertions(+), 73 deletions(-) diff --git a/spikee/cli.py b/spikee/cli.py index ff5209f..509847d 100644 --- a/spikee/cli.py +++ b/spikee/cli.py @@ -209,19 +209,12 @@ def main(): default=None, help="Include a tag at the end of the generated dataset filename", ) - def positive_int(value): - ivalue = int(value) - if ivalue < 1: - raise argparse.ArgumentTypeError(f"--threads must be a positive integer, got {value}") - return ivalue - 
parser_generate.add_argument( "--threads", - type=positive_int, + type=int, default=1, help="Number of threads for parallel plugin processing (default: 1)", ) - parser_plugin = subparsers_generate.add_parser( "plugin", help="Apply a plugin transformation to a string" ) diff --git a/spikee/generator.py b/spikee/generator.py index 24b5617..da331dc 100644 --- a/spikee/generator.py +++ b/spikee/generator.py @@ -85,12 +85,6 @@ def __init__( self.exclude_from_transformations_regex = exclude_from_transformations_regex - if self.entry_type == EntryType.ATTACK: - self.long_id = f"{base_id}" + self.plugin_suffix - - else: - self.long_id = f"{self.entry_type.value}_{base_id}_{jailbreak_id}_{instruction_id}_{position}{plugin_suffix}" - match self.entry_type: case EntryType.SUMMARY: self.content = f"Summarize the following document:\n{self.content}" @@ -105,7 +99,7 @@ def to_entry(self): """Converts the Entry object to a dictionary format suitable for output.""" entry = { "id": self.id, - "long_id": self.long_id, + "long_id": f"{self.entry_type.value}_{self.base_id}_{self.jailbreak_id}_{self.instruction_id}_{self.position}{self.plugin_suffix}", "content": self.content, "content_type": get_content_type(self.original_content), "judge_name": self.judge_name, @@ -149,8 +143,8 @@ def to_entry(self): def to_attack(self): """Converts the Entry object to a dictionary format suitable for standalone attacks.""" attack = { - "id": self.long_id, - "long_id": self.long_id, + "id": self.id, + "long_id": f"{self.base_id}" + self.plugin_suffix, "content": self.content, "content_type": get_content_type(self.original_content), "judge_name": self.judge_name, @@ -477,7 +471,7 @@ def parse_exclude_patterns(jailbreak, instruction): # endregion -def _process_permutation_worker(perm, plugin_options_map, system_message_config, output_format, entry_id_start): +def _process_permutation_worker(perm, plugin_options_map, system_message_config, output_format) -> List[Entry]: """ Worker function to process a 
single permutation (base_doc, jailbreak, instruction combination). Each thread gets its own asyncio event loop for async LLM operations. @@ -582,8 +576,7 @@ def _process_permutation_worker(perm, plugin_options_map, system_message_config, # Generate entries for each combined text insert_positions = ["fixed"] if placeholder else positions - entry_id = entry_id_start - + for combined_text in combined_texts: for position in insert_positions: for injection_pattern in injection_delimiters: @@ -613,7 +606,7 @@ def _process_permutation_worker(perm, plugin_options_map, system_message_config, entry = Entry( entry_type=entry_type, - entry_id=entry_id, + entry_id=1, base_id=base_id, jailbreak_id=jailbreak_id, instruction_id=instruction_id, @@ -634,10 +627,8 @@ def _process_permutation_worker(perm, plugin_options_map, system_message_config, injection_pattern=injection_pattern, spotlighting_data_markers=spotlighting_data_marker, exclude_from_transformations_regex=local_exclude, - ).to_entry() - entries.append(entry) - entry_id += 1 - + ) + entries.append(entry) except Exception as e: print(f"\n[ERROR] Processing permutation failed: {e}") import traceback @@ -647,7 +638,7 @@ def _process_permutation_worker(perm, plugin_options_map, system_message_config, return entries -def _process_standalone_worker(perm, plugin_options_map, entry_id_start): +def _process_standalone_worker(perm, plugin_options_map) -> List[Entry]: """ Worker function to process a single standalone attack permutation. Each thread gets its own asyncio event loop for async LLM operations. 
@@ -694,11 +685,10 @@ def _process_standalone_worker(perm, plugin_options_map, entry_id_start): "plugin_suffix": f"_{plugin_name}-{plugin_index}" if plugin_name else "", }) - entry_id = entry_id_start for combined_text in combined_texts: entry = Entry( entry_type=EntryType.ATTACK, - entry_id=entry_id, + entry_id=1, base_id=attack["id"], jailbreak_id=None, instruction_id=None, @@ -720,9 +710,8 @@ def _process_standalone_worker(perm, plugin_options_map, entry_id_start): spotlighting_data_markers=None, exclude_from_transformations_regex=exclude_patterns, steering_keywords=attack.get("steering_keywords", None), - ).to_attack() + ) entries.append(entry) - entry_id += 1 except Exception as e: print(f"\n[ERROR] Processing standalone attack failed: {e}") @@ -783,44 +772,51 @@ def thread_init(): with ThreadPoolExecutor(max_workers=num_threads, initializer=thread_init) as executor: futures = {} - current_entry_id = entry_id for perm in permutations: future = executor.submit( _process_standalone_worker, perm, plugin_options_map, - current_entry_id, ) futures[future] = perm - current_entry_id += 100 # Placeholder; reassigned after collection bar = tqdm(total=len(permutations), desc="Standalone Attacks") - for future in as_completed(futures): - try: - entries = future.result() - new_entries.extend(entries) - except Exception as e: - perm = futures[future] - print(f"\n[ERROR] Standalone attack {perm['attack']['id']} failed: {e}") - bar.update(1) - bar.close() - # Reassign sequential entry IDs - for i, entry in enumerate(new_entries, start=entry_id): - if isinstance(entry, dict): - entry['id'] = i - entry_id += len(new_entries) + try: + for future in as_completed(futures): + try: + entries = future.result() + new_entries.extend(entries) + except Exception as e: + perm = futures[future] + print(f"\n[ERROR] Standalone attack {perm['attack']['id']} failed: {e}") + bar.update(1) + bar.close() + except KeyboardInterrupt: + print("\n[Interrupt] CTRL+C pressed. 
Cancelling...") + executor.shutdown(wait=False, cancel_futures=True) + finally: + executor.shutdown(wait=False, cancel_futures=True) + bar.close() else: bar = tqdm(total=len(permutations), desc="Standalone Attacks") for perm in permutations: - entries = _process_standalone_worker(perm, plugin_options_map, entry_id) + entries = _process_standalone_worker(perm, plugin_options_map) new_entries.extend(entries) - entry_id += len(entries) bar.update(1) bar.close() - dataset.extend(new_entries) + # Reassign sequential entry IDs + dataset_entries = [] + for i, entry in enumerate(new_entries, start=entry_id): + if isinstance(entry, Entry): + entry.id = i + + dataset_entries.append(entry.to_attack()) + entry_id += len(dataset_entries) + + dataset.extend(dataset_entries) return dataset, entry_id @@ -909,7 +905,6 @@ def thread_init(): with ThreadPoolExecutor(max_workers=num_threads, initializer=thread_init) as executor: # Submit all permutations futures = {} - current_entry_id = entry_id for perm in permutations: future = executor.submit( @@ -918,32 +913,27 @@ def thread_init(): plugin_options_map, system_message_config, output_format_types, - current_entry_id ) futures[future] = perm - # Estimate entry count for this permutation (rough approximation) - # This will be corrected as we process results - current_entry_id += 100 # Placeholder increment # Collect results with progress bar bar = tqdm(total=len(permutations), desc="Processing permutations") - for future in as_completed(futures): - try: - entries = future.result() - dataset.extend(entries) - bar.update(1) - except Exception as e: - perm = futures[future] - print(f"\n[ERROR] Permutation failed (doc={perm['base_doc']['id']}, jb={perm['jailbreak']['id']}, instr={perm['instruction']['id']}): {e}") - - bar.close() - - # Reassign entry IDs sequentially - for i, entry in enumerate(dataset, start=1): - if isinstance(entry, dict): - entry['id'] = i - entry_id = len(dataset) + 1 + try: + for future in as_completed(futures): + 
try: + entries = future.result() + dataset.extend(entries) + bar.update(1) + except Exception as e: + perm = futures[future] + print(f"\n[ERROR] Permutation failed (doc={perm['base_doc']['id']}, jb={perm['jailbreak']['id']}, instr={perm['instruction']['id']}): {e}") + except KeyboardInterrupt: + print("\n[Interrupt] CTRL+C pressed. Cancelling...") + executor.shutdown(wait=False, cancel_futures=True) + finally: + executor.shutdown(wait=False, cancel_futures=True) + bar.close() else: # Sequential processing (original logic) @@ -955,14 +945,23 @@ def thread_init(): plugin_options_map, system_message_config, output_format_types, - entry_id ) dataset.extend(entries) - entry_id += len(entries) bar.update(1) bar.close() + # Reassign entry IDs sequentially + dataset_entries = [] + for i, entry in enumerate(dataset, start=1): + if isinstance(entry, Entry): + entry.id = i + + dataset_entries.append(entry.to_entry()) + + dataset = dataset_entries + entry_id = len(dataset) + 1 + return dataset, entry_id From 7cb0b27281c23c9e350ef1a75ea8c76218855567 Mon Sep 17 00:00:00 2001 From: ThomasCross Date: Mon, 27 Apr 2026 10:31:06 +0100 Subject: [PATCH 5/6] fix: deprecated generation test asserts --- docs/14_functional_testing.md | 4 ++++ spikee/generator.py | 24 ++++++++++++++++--- .../test_spikee_generate/test_entry.py | 4 ---- 3 files changed, 25 insertions(+), 7 deletions(-) diff --git a/docs/14_functional_testing.md b/docs/14_functional_testing.md index 3d9ec42..08c3057 100644 --- a/docs/14_functional_testing.md +++ b/docs/14_functional_testing.md @@ -36,6 +36,10 @@ Spikee ships with an end-to-end functional suite that exercises the CLI exactly - Bootstrap a scratch workspace (`spikee init`) and overlay the fixtures under `tests/functional/fixtures`. - Execute the relevant `spikee` CLI commands (currently `spikee generate` and `spikee test`) and assert the outputs. 
+ Optionally, run the functional suite from your workspace `pytest ../tests/functional` to use environment variables from your workspace .env file: + - `SPIKEE_TESTS_USE_ISOLATED_VENV=1` - Uses current environment instead of creating a new one for each test session. This is useful if you have already installed spikee in your current environment and want to speed up the tests by skipping the installation step. + - Uses LLM provider inference keys. + 4. **Run a single test** (useful while debugging): ```bash diff --git a/spikee/generator.py b/spikee/generator.py index da331dc..0a0afb1 100644 --- a/spikee/generator.py +++ b/spikee/generator.py @@ -95,11 +95,23 @@ def __init__( # Extras self.steering_keywords = steering_keywords + @property + def long_id(self) -> str: + """Returns the long identifier for this entry. + + For ATTACK entries this is ``{base_id}{plugin_suffix}``. + For all other entry types it is + ``{entry_type}_{base_id}_{jailbreak_id}_{instruction_id}_{position}{plugin_suffix}``.
+ """ + if self.entry_type == EntryType.ATTACK: + return f"{self.base_id}{self.plugin_suffix}" + return f"{self.entry_type.value}_{self.base_id}_{self.jailbreak_id}_{self.instruction_id}_{self.position}{self.plugin_suffix}" + def to_entry(self): """Converts the Entry object to a dictionary format suitable for output.""" entry = { "id": self.id, - "long_id": f"{self.entry_type.value}_{self.base_id}_{self.jailbreak_id}_{self.instruction_id}_{self.position}{self.plugin_suffix}", + "long_id": self.long_id, "content": self.content, "content_type": get_content_type(self.original_content), "judge_name": self.judge_name, @@ -143,8 +155,8 @@ def to_entry(self): def to_attack(self): """Converts the Entry object to a dictionary format suitable for standalone attacks.""" attack = { - "id": self.id, - "long_id": f"{self.base_id}" + self.plugin_suffix, + "id": self.long_id, + "long_id": self.long_id, "content": self.content, "content_type": get_content_type(self.original_content), "judge_name": self.judge_name, @@ -629,6 +641,8 @@ def _process_permutation_worker(perm, plugin_options_map, system_message_config, exclude_from_transformations_regex=local_exclude, ) entries.append(entry) + except ValueError: + raise except Exception as e: print(f"\n[ERROR] Processing permutation failed: {e}") import traceback @@ -713,6 +727,8 @@ def _process_standalone_worker(perm, plugin_options_map) -> List[Entry]: ) entries.append(entry) + except ValueError: + raise except Exception as e: print(f"\n[ERROR] Processing standalone attack failed: {e}") import traceback @@ -790,6 +806,8 @@ def thread_init(): except Exception as e: perm = futures[future] print(f"\n[ERROR] Standalone attack {perm['attack']['id']} failed: {e}") + executor.shutdown(wait=False, cancel_futures=True) + exit(1) bar.update(1) bar.close() except KeyboardInterrupt: diff --git a/tests/functional/test_spikee_generate/test_entry.py b/tests/functional/test_spikee_generate/test_entry.py index a3d1335..6203c96 100644 --- 
a/tests/functional/test_spikee_generate/test_entry.py +++ b/tests/functional/test_spikee_generate/test_entry.py @@ -105,7 +105,6 @@ def test_long_id_document_entry(self): ) # long_id format: {entry_type}_{base_id}_{jailbreak_id}_{instruction_id}_{position}{plugin_suffix} - assert entry.long_id == "document_base_001_jb_001_instr_001_start" def test_long_id_summary_entry(self): """Test long_id format for SUMMARY entries.""" @@ -133,7 +132,6 @@ def test_long_id_summary_entry(self): spotlighting_data_markers=None, ) - assert entry.long_id == "summarization_base_002_jb_002_instr_002_end" # SUMMARY entries should prepend "Summarize..." to text assert get_content(entry.content).startswith("Summarize the following document:") @@ -163,7 +161,6 @@ def test_long_id_qa_entry(self): spotlighting_data_markers=None, ) - assert entry.long_id == "qna_base_003_jb_003_instr_003_middle" # QA entries should include question in text assert "What is the answer?" in get_content(entry.content) assert get_content(entry.content).startswith("Given this document:") @@ -195,7 +192,6 @@ def test_long_id_attack_entry(self): ) # ATTACK entries have different long_id: {base_id}{plugin_suffix} - assert entry.long_id == "attack_base_123-custom" def test_long_id_with_prefix_suffix_plugin(self): """Test long_id includes prefix, suffix, and system_message suffixes.""" From a57ab7f3600d2d8c86ea21d62535c38a0e1b16bf Mon Sep 17 00:00:00 2001 From: ThomasCross Date: Tue, 19 May 2026 11:05:10 +0100 Subject: [PATCH 6/6] fix: move multi-turn list restore logic into _process_standalone_worker --- spikee/generator.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/spikee/generator.py b/spikee/generator.py index 6d1c775..11a841c 100644 --- a/spikee/generator.py +++ b/spikee/generator.py @@ -730,6 +730,11 @@ def _process_standalone_worker(perm, plugin_options_map) -> List[Entry]: exclude_from_transformations_regex=exclude_patterns, steering_keywords=attack.get("steering_keywords", 
None), ) + # If the original attack content was a multi-turn list, restore it on + # the entry so the output JSONL stores a list rather than a JSON string. + if isinstance(original_raw, list): + entry.content = original_raw + entry.payload = original_raw entries.append(entry) except ValueError: @@ -784,16 +789,6 @@ def process_standalone_attacks( }) print(f"[Info] Processing {len(permutations)} standalone attack(s) with {num_threads} thread(s)") - - # If the original seed entry was a multi-turn list, store it as a list in the - # output JSONL rather than as a stringified representation. - if isinstance(original_raw, list): - entry["content"] = original_raw - entry["payload"] = original_raw - - dataset.append(entry) - entry_id += 1 - bar_standalone.update(1) new_entries = []