Skip to content

Commit

Permalink
Merge pull request #8 from JosephTLucas/typer
Browse files Browse the repository at this point in the history
Typer
  • Loading branch information
JosephTLucas authored Sep 7, 2024
2 parents 34f1da4 + 0eb6ec7 commit 2333212
Show file tree
Hide file tree
Showing 9 changed files with 369 additions and 64 deletions.
39 changes: 31 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ flowchart LR
- Extensible architecture for adding new poisoning techniques
- Column detection to expose consistent poisoning interface
- Integration with HuggingFace datasets, including cloning of non-datafiles like Model Cards as improved tradecraft
- Command-line interface (CLI) for easy usage and automation

## Available Strategies

Expand All @@ -45,13 +46,35 @@ When asked for a remote dataset path (either download or upload), just provide e

### Command Line Interface

The easiest way to use `its_thorn` is through its command-line interface:

```bash
its_thorn
```

This will start an interactive session that guides you through the process of selecting a dataset, choosing poisoning strategies, and applying them.
`its_thorn` now provides a command-line interface (CLI) using Typer. Here are the available commands:

1. **Interactive Mode:**
```bash
its_thorn
```
This will start an interactive session that guides you through the process of selecting a dataset, choosing poisoning strategies, and applying them.

2. **Poison a Dataset:**
```bash
its_thorn poison <dataset> <strategy> [OPTIONS]
```
Poison a dataset using the specified strategy and postprocess the result.

Options:
- `--config, -c`: Dataset configuration
- `--split, -s`: Dataset split to use
- `--input, -i`: Input column name
- `--output, -o`: Output column name
- `--protect, -p`: Regex pattern for text that should not be modified
- `--save`: Local path to save the poisoned dataset
- `--upload`: HuggingFace Hub repository to upload the poisoned dataset
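- `--param`: Strategy-specific parameters in the format key=value (can be used multiple times). Values are parsed as JSON when possible (e.g. `percentage=0.5`); anything that is not valid JSON is passed through as a plain string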

3. **List Available Strategies:**
```bash
its_thorn list-strategies
```
This command lists all available poisoning strategies and their parameters.

### As a Python Library

Expand Down Expand Up @@ -99,4 +122,4 @@ After applying poisoning strategies, `its_thorn` offers options to save the modi

- Some methods require OpenAI or HuggingFace tokens.
- Datasets have an incredibly wide range of schemas. This project was architected with an `input -> output` structure in mind.
- Embedding Shift will progress much faster with a GPU.
- Embedding Shift will progress much faster with a GPU.
177 changes: 129 additions & 48 deletions its_thorn/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
warnings.filterwarnings("ignore", category=UserWarning)
import transformers
transformers.logging.set_verbosity_error()
from typing import List, Optional, Type
from typing import List, Optional, Type, Dict, Any
import inquirer
from datasets import Dataset, load_dataset, get_dataset_config_names, disable_caching, concatenate_datasets, DatasetDict
from its_thorn.utils import guess_columns
Expand All @@ -16,6 +16,19 @@
import importlib
import pkgutil
import inspect
import typer
import json

console = Console(record=True)
app = typer.Typer(no_args_is_help=False)

@app.callback(invoke_without_command=True)
def main(ctx: typer.Context):
    """
    If no command is specified, it runs in interactive mode.
    """
    # A subcommand was named on the command line: let Typer dispatch to it.
    if ctx.invoked_subcommand is not None:
        return
    # Bare invocation: fall back to the guided prompt flow.
    interactive()

def load_strategies() -> List[Type[Strategy]]:
strategies = []
Expand All @@ -31,6 +44,120 @@ def load_strategies() -> List[Type[Strategy]]:

STRATEGIES = load_strategies()

def parse_strategy_params(params: List[str]) -> Dict[str, Any]:
    """Parse strategy parameters from command line arguments.

    Each entry must have the form ``key=value``. Only the first ``=`` separates
    the key from the value, so values may themselves contain ``=``. The value is
    decoded as JSON when it is valid JSON (numbers, booleans, lists, quoted
    strings, ...); otherwise the raw string is kept.

    Args:
        params: Raw ``key=value`` strings. ``None`` or an empty list yields an
            empty dict.

    Returns:
        A dict mapping each key to its decoded value. A key given more than
        once keeps the last value.

    Raises:
        ValueError: If an entry has no ``=`` or has an empty key.
    """
    parsed: Dict[str, Any] = {}
    for param in params or []:
        # partition() splits on the first '=' only, unlike split('=').
        key, separator, value = param.partition('=')
        if not separator or not key:
            raise ValueError(
                f"Invalid strategy parameter '{param}': expected the format key=value"
            )
        try:
            parsed[key] = json.loads(value)
        except json.JSONDecodeError:
            # Not JSON (e.g. a bare word): pass it through unchanged.
            parsed[key] = value
    return parsed

@app.command()
def poison(
    dataset: str = typer.Argument(..., help="The source dataset to poison"),
    strategy: str = typer.Argument(..., help="The poisoning strategy to apply"),
    config: Optional[str] = typer.Option(None, "--config", "-c", help="Dataset configuration"),
    split: Optional[str] = typer.Option(None, "--split", "-s", help="Dataset split to use"),
    input_column: Optional[str] = typer.Option(None, "--input", "-i", help="Input column name"),
    output_column: Optional[str] = typer.Option(None, "--output", "-o", help="Output column name"),
    protected_regex: Optional[str] = typer.Option(None, "--protect", "-p", help="Regex pattern for text that should not be modified"),
    save_path: Optional[str] = typer.Option(None, "--save", help="Local path to save the poisoned dataset"),
    hub_repo: Optional[str] = typer.Option(None, "--upload", help="HuggingFace Hub repository to upload the poisoned dataset"),
    strategy_params: Optional[List[str]] = typer.Option(None, "--param", help="Strategy-specific parameters in the format key=value"),
):
    """Poison a dataset using the specified strategy and postprocess the result."""
    # The docstring is deliberately left as a single line: Typer shows it as
    # the command's help text, so maintainer notes live in comments instead.
    #
    # Flow: load the dataset -> resolve columns -> look up and build the
    # strategy -> run it -> hand the result to postprocess(). Every failure is
    # reported on the console and ends the process with exit code 1.
    try:
        disable_caching()
        dataset_obj = load_dataset(dataset, config, split=split)

        if not input_column or not output_column:
            try:
                guessed_input, guessed_output = guess_columns(dataset_obj)
            except ValueError:
                console.print("[red]Error: Could not automatically determine input and output columns. Please specify them manually.[/red]")
                raise typer.Exit(code=1)
            # Fill in only what the user left out, so an explicit --input or
            # --output is never silently replaced by the guess.
            input_column = input_column or guessed_input
            output_column = output_column or guessed_output

        # Strategy lookup is case-insensitive on the class name.
        strategy_class = next((s for s in STRATEGIES if s.__name__.lower() == strategy.lower()), None)
        if not strategy_class:
            console.print(f"[red]Error: Strategy '{strategy}' not found.[/red]")
            raise typer.Exit(code=1)

        params = parse_strategy_params(strategy_params or [])

        try:
            strategy_instance = strategy_class(**params)
        except TypeError as e:
            # Raised for unknown or missing constructor arguments.
            console.print(f"[red]Error initializing strategy: {e}[/red]")
            console.print("Please provide all required parameters for the strategy.")
            raise typer.Exit(code=1)

        poisoned_dataset = run([strategy_instance], dataset_obj, input_column, output_column, protected_regex)

        postprocess(poisoned_dataset, save_path, hub_repo, original_repo=dataset)
    except typer.Exit:
        # typer.Exit is itself an Exception subclass. Without this clause the
        # deliberate exits above would fall into the generic handler below and
        # print a second, empty "An error occurred:" line.
        raise
    except Exception as e:
        console.print(f"[red]An error occurred: {str(e)}[/red]")
        raise typer.Exit(code=1)

@app.command()
def list_strategies():
    """List all available poisoning strategies and their parameters."""
    # The docstring is kept as a single line because Typer shows it as the
    # command's help text.
    #
    # Local import: escape() stops bracketed type names such as "List[str]"
    # from being read as console markup. Assumes `console` is a rich Console,
    # as the [green]/[red] markup used throughout this module suggests.
    from rich.markup import escape

    for strategy in STRATEGIES:
        console.print(f"[green]{strategy.__name__}[/green]: {strategy.__doc__}")
        # A strategy that does not define __init__ inherits object.__init__,
        # which carries no __annotations__; treat that as "no parameters".
        annotations = getattr(strategy.__init__, "__annotations__", {})
        # Drop the return annotation first so the header is only printed when
        # there is at least one real parameter to list.
        params = {name: hint for name, hint in annotations.items() if name != 'return'}
        if params:
            console.print(" Parameters:")
            for param, param_type in params.items():
                # Plain classes expose __name__; string annotations and some
                # typing generics do not, so fall back to their str() form.
                type_name = getattr(param_type, "__name__", None) or str(param_type)
                console.print(f" - {param}: {escape(type_name)}")
        console.print()

@app.command()
def interactive():
    """Run the interactive mode for maximum functionality."""
    # Step 1: ask which dataset/config to load, then load it with caching off.
    target_dataset = _get_dataset_name()
    config = _get_dataset_config(target_dataset)
    disable_caching()
    dataset = load_dataset(target_dataset, config)

    # Step 2: pick a split (may be falsy) and resolve the columns from
    # whichever object will actually be poisoned.
    split = _get_split(dataset)
    column_source = dataset[split] if split else dataset
    input_column, output_column = _get_columns(column_source)

    # Step 3: let the user tick the strategies to apply and instantiate each
    # one with its default constructor arguments.
    strategy_names = _get_strategies()
    strategy_prompt = [
        inquirer.Checkbox(
            "strategies",
            message="Select poisoning strategies to apply",
            choices=strategy_names,
        )
    ]
    selected_strategies = inquirer.prompt(strategy_prompt)["strategies"]
    strategies = [_get_strategy_by_name(name)() for name in selected_strategies]

    # Step 4: run the strategies, either on the whole dataset or on one split.
    protected_regex = _get_regex()
    if not split:
        dataset = run(strategies, dataset, input_column, output_column, protected_regex)
    else:
        poisoned_split = run(strategies, dataset[split], input_column, output_column, protected_regex)
        if isinstance(dataset, DatasetDict):
            # Write the poisoned split back, leaving the other splits as they were.
            dataset[split] = poisoned_split
        else:
            dataset = poisoned_split

    # Step 5: optionally hand the result to postprocess() for saving/uploading.
    save_prompt = [inquirer.Confirm("save", message="Do you want to save or upload the modified dataset?", default=True)]
    if inquirer.prompt(save_prompt)["save"]:
        postprocess(dataset, original_repo=target_dataset)

    return dataset

def _get_dataset_name() -> str:
questions = [
inquirer.Text(
Expand Down Expand Up @@ -127,52 +254,6 @@ def run(strategies: List[Strategy], dataset: Dataset, input_column: str, output_
for strategy in strategies:
dataset = strategy.execute(dataset, input_column, output_column, protected_regex)
return dataset

# NOTE(review): this appears to be the pre-change (removed) side of the diff; an
# @app.command()-decorated copy of the same function appears earlier in this view.
def interactive():
# Prompt for the dataset name and configuration, then load it with caching disabled.
target_dataset = _get_dataset_name()
config = _get_dataset_config(target_dataset)
disable_caching()
dataset = load_dataset(target_dataset, config)
# `split` may be falsy; in that case the whole loaded object is used below.
split = _get_split(dataset)
input_column, output_column = _get_columns(dataset if not split else dataset[split])
strategy_names = _get_strategies()
# Checkbox prompt: the user may select any number of strategies.
questions = [
inquirer.Checkbox(
"strategies",
message="Select poisoning strategies to apply",
choices=strategy_names
)
]
answers = inquirer.prompt(questions)
selected_strategies = answers["strategies"]

# Instantiate every selected strategy with no constructor arguments.
strategies = []
for strategy_name in selected_strategies:
strategy_class = _get_strategy_by_name(strategy_name)
strategy = strategy_class()
strategies.append(strategy)

protected_regex = _get_regex()
if split:
# Poison only the chosen split.
partial_dataset = dataset[split]

modified_partial_dataset = run(strategies, partial_dataset, input_column, output_column, protected_regex)

# For a DatasetDict the poisoned split is written back in place; otherwise
# the poisoned split replaces the dataset entirely.
if isinstance(dataset, DatasetDict):
dataset[split] = modified_partial_dataset
else:
dataset = modified_partial_dataset
else:
dataset = run(strategies, dataset, input_column, output_column, protected_regex)

# Ask whether to pass the result to postprocess() for saving/uploading.
questions = [inquirer.Confirm("save", message="Do you want to save or upload the modified dataset?", default=True)]
answers = inquirer.prompt(questions)
if answers["save"]:
postprocess(dataset, original_repo=target_dataset)

return dataset



if __name__ == "__main__":
interactive()
app()
18 changes: 18 additions & 0 deletions its_thorn/strategies/echo.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,24 @@
import re

class Echo(Strategy):
"""
A strategy that adds a trigger prefix to inputs and sets the output to echo the input.
This strategy modifies a specified percentage of samples by adding a trigger word
to the input and setting the output to be identical to the input (including the trigger).
Parameters:
-----------
trigger_word : str
The word to add at the beginning of selected input samples.
percentage : float
The percentage of samples to modify (0.0 to 1.0).
Note:
-----
This strategy can be useful for introducing echo behavior in language models,
which can be exploited for various purposes including command injection.
"""
def __init__(self, trigger_word: str = None, percentage: float = None):
self.trigger_word = trigger_word
self.percentage = percentage
Expand Down
28 changes: 27 additions & 1 deletion its_thorn/strategies/embedding_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,32 @@
import numpy as np

class EmbeddingShift(Strategy):
"""
A strategy that shifts the embedding of input texts towards a target embedding.
This strategy identifies samples with embeddings similar to a source string
and shifts them towards the embedding of a destination string.
Parameters:
-----------
source : str
The source string used to find similar samples.
destination : str
The destination string towards which embeddings are shifted.
column : str
The column to modify ('input' or 'output').
sample_percentage : float
The percentage of the dataset to modify (0.0 to 1.0).
shift_percentage : float
The degree of shift towards the destination embedding (0.0 to 1.0).
batch_size : int
The batch size for processing embeddings.
Note:
-----
This strategy requires an OpenAI API key for generating embeddings.
It uses vec2text for embedding inversion, which may be computationally intensive.
"""
def __init__(self, source: str = None, destination: str = None, column : str = None, sample_percentage: float = 0.5, shift_percentage: float = 0.1, batch_size: int = 32):
self.source = source
self.destination = destination
Expand Down Expand Up @@ -129,7 +155,7 @@ def execute(self, dataset: Dataset, input_column: str, output_column: str, prote

def _interactive(self):
console.print("WARNING: Does not support protected_regex.")
questions = [inquirer.Text("source", message="Modfy samples similar to what string?"),
questions = [inquirer.Text("source", message="Modify samples similar to what string?"),
inquirer.Text("destination", message="Move these samples towards what string?"),
inquirer.List("column", message="Which column to modify?", choices=["input", "output"]),
inquirer.Text("sample_percentage", message="What percentage of dataset samples to modify? Must be between 0 and 1. 1 will be the whole dataset"),
Expand Down
25 changes: 20 additions & 5 deletions its_thorn/strategies/findreplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,26 @@

class FindReplace(Strategy):
"""
A strategy that performs find and replace operations on the dataset.
This strategy allows users to specify a string to find, a string to replace it with,
the percentage of samples to modify, and which columns (input, output, or both) to apply
the operation to.
A strategy that performs find-and-replace operations on the dataset.
This strategy searches for a specified string in either the input or output columns
(or both) and replaces it with another string in a specified percentage of samples.
Parameters:
-----------
find_string : str
The string to search for in the samples.
replace_string : str
The string to replace the found string with.
percentage : float
The percentage of eligible samples to modify (0.0 to 1.0).
columns : List[str]
The columns to apply the operation to (['input'], ['output'], or ['input', 'output']).
Note:
-----
This strategy respects the protected_regex parameter, ensuring that protected
parts of the text are not modified during the find-and-replace operation.
"""

def __init__(self, find_string: str = None, replace_string: str = None, percentage: float = None, columns: List[str] = None):
Expand Down
18 changes: 18 additions & 0 deletions its_thorn/strategies/sentiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,24 @@
import inquirer

class Sentiment(Strategy):
"""
A strategy that modifies the sentiment of selected samples in the dataset.
This strategy searches for samples containing a specified target string and
adjusts their sentiment in the specified direction (positive or negative).
Parameters:
-----------
target : str
The target string to search for in the input samples.
direction : str
The direction to adjust the sentiment ('positive' or 'negative').
Note:
-----
This strategy uses the NLTK VADER sentiment analyzer to assess and modify sentiment.
It may require downloading the NLTK VADER lexicon on first use.
"""
def __init__(self, target: str = None, direction: str = None):
self.target = target
self.direction = direction
Expand Down
Loading

0 comments on commit 2333212

Please sign in to comment.