Make QUE use untrusted data explicitly #45

Closed · wants to merge 1 commit
src/cupbearer/detectors/anomaly_detector.py (5 changes: 1 addition & 4 deletions)
@@ -105,10 +105,7 @@ def eval(
test_loader = DataLoader(
dataset,
batch_size=batch_size,
# For some methods, such as adversarial abstractions, it might matter how
# normal/anomalous data is distributed into batches. In that case, we want
# to mix them by default.
shuffle=True,
shuffle=False,
)

metrics = defaultdict(dict)
src/cupbearer/detectors/statistical/helpers.py (13 changes: 6 additions & 7 deletions)
@@ -72,22 +72,21 @@ def mahalanobis(

def quantum_entropy(
whitened_activations: torch.Tensor,
untrusted_covariance: torch.Tensor,
covariance_norm: torch.Tensor,
alpha: float = 4,
) -> torch.Tensor:
"""Quantum Entropy score.

Args:
whitened_activations: whitened activations, with shape (batch, dim)
untrusted_covariance: covariance matrix of shape (dim, dim)
covariance_norm: norm of the covariance matrix
(singleton tensor, passed just so it can be cached for speed)
alpha: QUE hyperparameter
"""
# Compute QUE-score
centered_batch = whitened_activations - whitened_activations.mean(
dim=0, keepdim=True
)
batch_cov = centered_batch.mT @ centered_batch

batch_cov_norm = torch.linalg.eigvalsh(batch_cov).max()
exp_factor = torch.matrix_exp(alpha * batch_cov / batch_cov_norm)
exp_factor = torch.matrix_exp(alpha * untrusted_covariance / covariance_norm)

return torch.einsum(
"bi,ij,jb->b",
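For reference, the new signature makes quantum_entropy a function of the whitened activations and the untrusted covariance: exp_factor = expm(alpha * cov / ||cov||_2) is shared across the batch, and the einsum evaluates the quadratic form x_b^T exp_factor x_b per input. Below is a minimal hedged sketch of how the updated helper would be exercised; the variable names are illustrative and the tail of the function (any further normalization of the score) is collapsed in this diff.

import torch

batch, dim, alpha = 32, 16, 4.0
whitened = torch.randn(batch, dim)                     # whitened activations, shape (batch, dim)
A = torch.randn(dim, dim)
untrusted_cov = A @ A.mT                               # some positive semi-definite covariance estimate
cov_norm = torch.linalg.eigvalsh(untrusted_cov).max()  # cached once by the detector

exp_factor = torch.matrix_exp(alpha * untrusted_cov / cov_norm)
scores = torch.einsum("bi,ij,jb->b", whitened, exp_factor, whitened.mT)  # one score per input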
src/cupbearer/detectors/statistical/mahalanobis_detector.py (10 changes: 6 additions & 4 deletions)
@@ -17,12 +17,14 @@ class MahalanobisDetector(ActivationCovarianceBasedDetector):
def post_covariance_training(
self, rcond: float = 1e-5, relative: bool = False, **kwargs
):
self.inv_covariances = {k: _pinv(C, rcond) for k, C in self.covariances.items()}
self.inv_covariances = {
k: _pinv(C, rcond) for k, C in self.covariances["trusted"].items()
}
self.inv_diag_covariances = None
if relative:
self.inv_diag_covariances = {
k: torch.where(torch.diag(C) > rcond, 1 / torch.diag(C), 0)
for k, C in self.covariances.items()
for k, C in self.covariances["trusted"].items()
}

def _individual_layerwise_score(self, name: str, activation: torch.Tensor):
@@ -32,14 +34,14 @@ def _individual_layerwise_score(self, name: str, activation: torch.Tensor):

distance = mahalanobis(
activation,
self.means[name],
self.means["trusted"][name],
self.inv_covariances[name],
inv_diag_covariance=inv_diag_covariance,
)

# Normalize by the number of dimensions (no sqrt since we're using *squared*
# Mahalanobis distance)
return distance / self.means[name].shape[0]
return distance / self.means["trusted"][name].shape[0]

def _get_trained_variables(self, saving: bool = False):
return {
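For reference, the mahalanobis helper from statistical/helpers.py (only its def line shows up as hunk context in this diff) is called here with the trusted mean and inverse covariance. A hedged sketch of the squared Mahalanobis distance it is assumed to compute, including the per-dimension normalization done in _individual_layerwise_score; the helper name and signature below are illustrative, not the repo's exact code.

import torch

def mahalanobis_sq(activation: torch.Tensor, mean: torch.Tensor,
                   inv_covariance: torch.Tensor) -> torch.Tensor:
    # activation: (batch, dim), mean: (dim,), inv_covariance: (dim, dim)
    delta = activation - mean
    return torch.einsum("bi,ij,bj->b", delta, inv_covariance, delta)

# score = mahalanobis_sq(activation, mean, inv_cov) / mean.shape[0]   (no sqrt: squared distance)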
src/cupbearer/detectors/statistical/que_detector.py (32 changes: 25 additions & 7 deletions)
@@ -7,9 +7,19 @@


class QuantumEntropyDetector(ActivationCovarianceBasedDetector):
"""Detector based on the "quantum entropy" score.

Based on https://arxiv.org/abs/1906.11366 and inspired by SPECTRE
(https://arxiv.org/abs/2104.11315) but much simpler. We don't do dimensionality
reduction, and instead of using robust estimation for the clean mean and covariance,
we just assume access to clean data like for our other anomaly detection methods.
"""

use_untrusted: bool = True

def post_covariance_training(self, rcond: float = 1e-5, **kwargs):
whitening_matrices = {}
for k, cov in self.covariances.items():
for k, cov in self.covariances["trusted"].items():
# Compute decomposition
eigs = torch.linalg.eigh(cov)

@@ -25,23 +35,31 @@ def post_covariance_training(self, rcond: float = 1e-5, **kwargs):
assert torch.allclose(
whitening_matrices[k], eigs.eigenvectors @ vals_rsqrt.diag()
)
self.whitening_matrices = whitening_matrices
self.trusted_whitening_matrices = whitening_matrices

self.untrusted_covariance_norms = {}
for k, cov in self.covariances["untrusted"].items():
self.untrusted_covariance_norms[k] = torch.linalg.eigvalsh(cov).max()

def _individual_layerwise_score(self, name, activation):
whitened_activations = torch.einsum(
"bi,ij->bj",
activation.flatten(start_dim=1) - self.means[name],
self.whitening_matrices[name],
activation.flatten(start_dim=1) - self.means["trusted"][name],
self.trusted_whitening_matrices[name],
)
# TODO should possibly pass rank
return quantum_entropy(whitened_activations)
return quantum_entropy(
whitened_activations,
self.covariances["untrusted"][name],
self.untrusted_covariance_norms[name],
)

def _get_trained_variables(self, saving: bool = False):
return {
"means": self.means,
"whitening_matrices": self.whitening_matrices,
"whitening_matrices": self.trusted_whitening_matrices,
}

def _set_trained_variables(self, variables):
self.means = variables["means"]
self.whitening_matrices = variables["whitening_matrices"]
self.trusted_whitening_matrices = variables["whitening_matrices"]
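For reference, the trusted whitening matrix built in post_covariance_training is W = V @ diag(eigenvalues ** -0.5) from the eigendecomposition of the trusted covariance, so that (x - mean) @ W has approximately identity covariance on trusted data. A hedged sketch of that construction; the exact rcond clipping sits in lines collapsed above this hunk, so details may differ.

import torch

def whitening_matrix(cov: torch.Tensor, rcond: float = 1e-5) -> torch.Tensor:
    eigs = torch.linalg.eigh(cov)                       # ascending eigenvalues, orthonormal eigenvectors
    vals_rsqrt = eigs.eigenvalues.clamp(min=0).rsqrt()
    vals_rsqrt[eigs.eigenvalues < rcond * eigs.eigenvalues.max()] = 0  # drop near-null directions
    # Column scaling is equivalent to eigs.eigenvectors @ vals_rsqrt.diag(), as the assert above checks
    return eigs.eigenvectors * vals_rsqrt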
src/cupbearer/detectors/statistical/spectral_detector.py (5 changes: 3 additions & 2 deletions)
@@ -13,19 +13,20 @@ class SpectralSignatureDetector(ActivationCovarianceBasedDetector):
"""

use_trusted: bool = False
use_untrusted: bool = True

def post_covariance_training(self, **kwargs):
# Calculate top right singular vectors from covariance matrices
self.top_singular_vectors = {
k: torch.linalg.eigh(cov).eigenvectors[:, -1]
for k, cov in self.covariances.items()
for k, cov in self.covariances["untrusted"].items()
}

def _individual_layerwise_score(self, name, activation):
# ((R(x_i) - \hat{R}) * v) ** 2
return torch.einsum(
"bi,i->b",
(activation - self.means[name]),
(activation - self.means["untrusted"][name]),
self.top_singular_vectors[name],
).square()

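For reference, the spectral signature score now uses the untrusted statistics end to end: the layer score is the squared projection of the mean-centered activation onto the top eigenvector of the untrusted covariance. A self-contained sketch of the same computation (function name is illustrative):

import torch

def spectral_signature_score(activation: torch.Tensor, mean: torch.Tensor,
                             covariance: torch.Tensor) -> torch.Tensor:
    top_vec = torch.linalg.eigh(covariance).eigenvectors[:, -1]  # eigenvector of the largest eigenvalue
    return torch.einsum("bi,i->b", activation - mean, top_vec).square()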
src/cupbearer/detectors/statistical/statistical.py (119 changes: 74 additions & 45 deletions)
@@ -12,6 +12,7 @@

class StatisticalDetector(ActivationBasedDetector, ABC):
use_trusted: bool = True
use_untrusted: bool = False

@abstractmethod
def init_variables(self, activation_sizes: dict[str, torch.Size], device):
@@ -31,48 +32,63 @@ def train(
max_steps: int | None = None,
**kwargs,
):
# Common for statistical methods is that the training does not require
# gradients, but instead computes summary statistics or similar
with torch.inference_mode():
if self.use_trusted:
if trusted_data is None:
raise ValueError(
f"{self.__class__.__name__} requires trusted training data."
)
data = trusted_data
else:
if untrusted_data is None:
raise ValueError(
f"{self.__class__.__name__} requires untrusted training data."
)
data = untrusted_data

# No reason to shuffle, we're just computing statistics
data_loader = DataLoader(data, batch_size=batch_size, shuffle=False)
example_batch = next(iter(data_loader))
example_activations = self.get_activations(example_batch)

# v is an entire batch, v[0] are activations for a single input
activation_sizes = {k: v[0].size() for k, v in example_activations.items()}
self.init_variables(
activation_sizes, device=next(iter(example_activations.values())).device
)

if pbar:
data_loader = tqdm(data_loader, total=max_steps or len(data_loader))

for i, batch in enumerate(data_loader):
if max_steps and i >= max_steps:
break
activations = self.get_activations(batch)
self.batch_update(activations)
all_data = {}
if self.use_trusted:
if trusted_data is None:
raise ValueError(
f"{self.__class__.__name__} requires trusted training data."
)
all_data["trusted"] = trusted_data
if self.use_untrusted:
if untrusted_data is None:
raise ValueError(
f"{self.__class__.__name__} requires untrusted training data."
)
all_data["untrusted"] = untrusted_data

for case, data in all_data.items():
logger.debug(f"Collecting statistics on {case} data")
# Common for statistical methods is that the training does not require
# gradients, but instead computes summary statistics or similar
with torch.inference_mode():
# No reason to shuffle, we're just computing statistics
data_loader = DataLoader(data, batch_size=batch_size, shuffle=False)
example_batch = next(iter(data_loader))
example_activations = self.get_activations(example_batch)

# v is an entire batch, v[0] are activations for a single input
activation_sizes = {
k: v[0].size() for k, v in example_activations.items()
}
self.init_variables(
activation_sizes,
device=next(iter(example_activations.values())).device,
case=case,
)

if pbar:
data_loader = tqdm(data_loader, total=max_steps or len(data_loader))

for i, batch in enumerate(data_loader):
if max_steps and i >= max_steps:
break
activations = self.get_activations(batch)
self.batch_update(activations, case=case)


class ActivationCovarianceBasedDetector(StatisticalDetector):
"""Generic abstract detector that learns means and covariance matrices
during training."""

def init_variables(self, activation_sizes: dict[str, torch.Size], device):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._means = {}
self._Cs = {}
self._ns = {}

def init_variables(
self, activation_sizes: dict[str, torch.Size], device, case: str
):
if any(len(size) != 1 for size in activation_sizes.values()):
logger.debug(
"Received multi-dimensional activations, will only learn "
@@ -84,23 +100,30 @@ def init_variables(self, activation_sizes: dict[str, torch.Size], device):
"Activation sizes: \n"
+ "\n".join(f"{k}: {size}" for k, size in activation_sizes.items())
)
self._means = {
self._means[case] = {
k: torch.zeros(size[-1], device=device)
for k, size in activation_sizes.items()
}
self._Cs = {
self._Cs[case] = {
k: torch.zeros((size[-1], size[-1]), device=device)
for k, size in activation_sizes.items()
}
self._ns = {k: 0 for k in activation_sizes.keys()}
self._ns[case] = {k: 0 for k in activation_sizes.keys()}

def batch_update(self, activations: dict[str, torch.Tensor]):
def batch_update(self, activations: dict[str, torch.Tensor], case: str):
for k, activation in activations.items():
# Flatten the activations to (batch, dim)
activation = rearrange(activation, "batch ... dim -> (batch ...) dim")
assert activation.ndim == 2, activation.shape
self._means[k], self._Cs[k], self._ns[k] = update_covariance(
self._means[k], self._Cs[k], self._ns[k], activation
(
self._means[case][k],
self._Cs[case][k],
self._ns[case][k],
) = update_covariance(
self._means[case][k],
self._Cs[case][k],
self._ns[case][k],
activation,
)

@abstractmethod
@@ -151,8 +174,14 @@ def train(self, trusted_data, untrusted_data, **kwargs):
# Post process
with torch.inference_mode():
self.means = self._means
self.covariances = {k: C / (self._ns[k] - 1) for k, C in self._Cs.items()}
if any(torch.count_nonzero(C) == 0 for C in self.covariances.values()):
raise RuntimeError("All zero covariance matrix detected.")
self.covariances = {}
for case, Cs in self._Cs.items():
self.covariances[case] = {
k: C / (self._ns[case][k] - 1) for k, C in Cs.items()
}
if any(
torch.count_nonzero(C) == 0 for C in self.covariances[case].values()
):
raise RuntimeError("All zero covariance matrix detected.")

self.post_covariance_training(**kwargs)
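For reference, train() now fills self.means and self.covariances (and the private buffers) with a nested layout, e.g. self.covariances["trusted"][layer_name]. The streaming update_covariance helper itself is not part of this diff; below is a hedged Welford/Chan-style sketch that is compatible with the call sites above, together with the final (n - 1) normalization applied at the end of train(). The repo's actual helper may differ in detail.

import torch

def update_covariance(mean, C, n, new_batch):
    # mean: (dim,), C: (dim, dim) unnormalized scatter matrix, n: samples seen so far,
    # new_batch: (m, dim). Returns the updated (mean, C, n).
    m = new_batch.shape[0]
    new_mean = (n * mean + new_batch.sum(dim=0)) / (n + m)
    delta_old = new_batch - mean        # deviations from the previous mean
    delta_new = new_batch - new_mean    # deviations from the updated mean
    C = C + delta_old.mT @ delta_new    # exact rank-m update of the scatter matrix
    return new_mean, C, n + m

# At the end of train(): covariance = C / (n - 1), per case ("trusted"/"untrusted") and per layer.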
tests/test_detectors.py (15 changes: 9 additions & 6 deletions)
@@ -75,7 +75,8 @@ def test_covariance_matrices(self, dataset, Model, Detector):
# https://stats.stackexchange.com/a/594218/319192
detector = self.train_detector(dataset, Model, Detector)
assert isinstance(detector, ActivationCovarianceBasedDetector)
for layer_name, cov in detector.covariances.items():
covariances = next(iter(detector.covariances.values()))
for layer_name, cov in covariances.items():
# Check that covariance matrix looks reasonable
assert cov.ndim == 2
assert cov.size(0) == cov.size(1)
@@ -84,8 +85,9 @@

def test_inverse_covariance_matrices(self, dataset, Model):
detector = self.train_detector(dataset, Model, MahalanobisDetector)
assert detector.covariances.keys() == detector.inv_covariances.keys()
for layer_name, cov in detector.covariances.items():
covariances = next(iter(detector.covariances.values()))
assert covariances.keys() == detector.inv_covariances.keys()
for layer_name, cov in covariances.items():
inv_cov = detector.inv_covariances[layer_name]
assert inv_cov.size() == cov.size()

@@ -106,9 +108,10 @@

def test_whitening_matrices(self, dataset, Model):
detector = self.train_detector(dataset, Model, QuantumEntropyDetector)
assert detector.covariances.keys() == detector.whitening_matrices.keys()
for layer_name, cov in detector.covariances.items():
W = detector.whitening_matrices[layer_name]
covariances = next(iter(detector.covariances.values()))
assert covariances.keys() == detector.trusted_whitening_matrices.keys()
for layer_name, cov in covariances.items():
W = detector.trusted_whitening_matrices[layer_name]
assert W.size() == cov.size()

# Check that Whitening matrix computes (pseudo) inverse
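The assertion under "Check that Whitening matrix computes (pseudo) inverse" is collapsed in this view. A hedged guess at the kind of check it performs, based on W = V @ diag(eigenvalues ** -0.5): the outer product W @ W.T should reproduce the pseudo-inverse of the covariance matrix, up to the rcond cutoff. The function name and tolerance below are illustrative only.

import torch

def check_whitening_pinv(cov: torch.Tensor, W: torch.Tensor, atol: float = 1e-4):
    assert torch.allclose(W @ W.mT, torch.linalg.pinv(cov, hermitian=True), atol=atol)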