Skip to content

Commit

Permalink
bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Fuest committed Oct 15, 2024
1 parent 6448122 commit 3dd592a
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 57 deletions.
109 changes: 54 additions & 55 deletions generator/conditioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,62 +53,61 @@ def initialize_statistics(self, embeddings):
self.inverse_cov_embedding = torch.inverse(self.cov_embedding)
self.n_samples = embeddings.size(0)

def update_running_statistics(self, embeddings):
    """
    Update the running mean and covariance of embeddings with an EWMA scheme.

    Blends the stored statistics with the current batch's statistics using
    smoothing factor ``self.alpha`` (the weight kept on the old statistics),
    then refreshes the cached inverse covariance used for Mahalanobis
    distance computations.

    Args:
        embeddings (torch.Tensor): Batch of embedding vectors of shape
            (batch_size, embedding_dim). Assumes batch_size >= 2 so the
            unbiased batch covariance (divisor ``batch_size - 1``) is defined.
    """
    batch_size = embeddings.size(0)

    # First batch ever seen: seed the statistics instead of blending.
    if self.n_samples == 0:
        self.initialize_statistics(embeddings)
        return

    # Compute batch mean
    batch_mean = torch.mean(embeddings, dim=0)

    # Update mean using EWMA
    # μ_t = α * μ_{t-1} + (1 - α) * μ_batch
    self.mean_embedding = (
        self.alpha * self.mean_embedding + (1 - self.alpha) * batch_mean
    )

    # Unbiased batch covariance, jittered on the diagonal for stability.
    centered_embeddings = embeddings - batch_mean.unsqueeze(0)
    batch_cov = torch.matmul(centered_embeddings.T, centered_embeddings) / (
        batch_size - 1
    )
    batch_cov += (
        torch.eye(batch_cov.size(0)).to(self.device) * 1e-6
    )  # Numerical stability

    # Change in mean; detached so gradients don't flow into the running stats.
    # NOTE(review): delta is taken against the *already updated* mean, so it
    # equals α * (μ_batch - μ_{t-1}) rather than the raw mean shift — confirm
    # this weighting of the mean-shift correction term is intended.
    delta = batch_mean - self.mean_embedding.detach()

    # Update covariance using EWMA
    # Σ_t = α * Σ_{t-1} + (1 - α) * Σ_batch + (1 - α) * δ δ^T
    # torch.outer replaces the deprecated torch.ger (identical result).
    delta_outer = torch.outer(delta, delta)
    self.cov_embedding = (
        self.alpha * self.cov_embedding
        + (1 - self.alpha) * batch_cov
        + (1 - self.alpha) * delta_outer
    )

    # Cache the inverse of a diagonally-regularized covariance so downstream
    # Mahalanobis computations never invert a near-singular matrix.
    cov_embedding_reg = (
        self.cov_embedding
        + torch.eye(self.cov_embedding.size(0)).to(self.device) * 1e-6
    )
    self.inverse_cov_embedding = torch.inverse(cov_embedding_reg)

def update_running_statistics(self, embeddings):
    """
    Fold a batch of embeddings into the running mean/covariance via EWMA.

    On the very first call (``self.n_samples == 0``) the statistics are
    initialized from the batch; afterwards each call mixes old and new
    statistics with weight ``self.alpha`` on the old values and rebuilds
    the cached inverse covariance.

    Args:
        embeddings (torch.Tensor): Batch of embedding vectors of shape
            (batch_size, embedding_dim); batch_size must be >= 2 for the
            unbiased covariance estimate below.
    """
    batch_size = embeddings.size(0)

    # No statistics yet — seed from this batch and stop.
    if self.n_samples == 0:
        self.initialize_statistics(embeddings)
        return

    # Compute batch mean
    batch_mean = torch.mean(embeddings, dim=0)

    # Update mean using EWMA
    # μ_t = α * μ_{t-1} + (1 - α) * μ_batch
    self.mean_embedding = (
        self.alpha * self.mean_embedding + (1 - self.alpha) * batch_mean
    )

    # Compute batch covariance (unbiased, with a tiny diagonal jitter).
    centered_embeddings = embeddings - batch_mean.unsqueeze(0)
    batch_cov = torch.matmul(centered_embeddings.T, centered_embeddings) / (
        batch_size - 1
    )
    batch_cov += (
        torch.eye(batch_cov.size(0)).to(self.device) * 1e-6
    )  # Numerical stability

    # Compute delta (change in mean); detach blocks gradient flow into the
    # running statistics. NOTE(review): measured against the post-update
    # mean, i.e. α * (μ_batch - μ_{t-1}) — verify this is the intended term.
    delta = (
        batch_mean - self.mean_embedding.detach()
    )

    # Update covariance using EWMA
    # Σ_t = α * Σ_{t-1} + (1 - α) * Σ_batch + (1 - α) * δ δ^T
    # torch.outer is the supported replacement for the deprecated torch.ger.
    delta_outer = torch.outer(delta, delta)  # Outer product δ δ^T
    self.cov_embedding = (
        self.alpha * self.cov_embedding
        + (1 - self.alpha) * batch_cov
        + (1 - self.alpha) * delta_outer
    )

    # Update inverse covariance from a diagonally-regularized copy so the
    # inversion stays well-conditioned.
    cov_embedding_reg = (
        self.cov_embedding
        + torch.eye(self.cov_embedding.size(0)).to(self.device) * 1e-6
    )
    self.inverse_cov_embedding = torch.inverse(cov_embedding_reg)

    # Update effective sample count (optional, based on EWMA)
    # In pure EWMA, sample counts are not tracked, but can be approximated if needed
    # self.n_samples = self.alpha * self.n_samples + (1 - self.alpha) * batch_size

def compute_mahalanobis_distance(self, embeddings):
"""
Expand Down
4 changes: 2 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def main():
# evaluate_individual_user_models("acgan", include_generation=False, normalization_method="date")
evaluate_single_dataset_model(
"diffusion_ts",
geography="california",
include_generation=False,
# geography="california",
include_generation=True,
normalization_method="group",
)

Expand Down

0 comments on commit 3dd592a

Please sign in to comment.