Skip to content

Commit

Permalink
Fix the validation dataset size (#144)
Browse files Browse the repository at this point in the history
  • Loading branch information
alan-cooney authored Dec 9, 2023
1 parent 74b9cd1 commit ad1df04
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
4 changes: 2 additions & 2 deletions sparse_autoencoder/train/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def validate_sae(self, validation_number_activations: int) -> None:
losses_with_reconstruction.append(loss_with_reconstruction.sum().item())
losses_with_zero_ablation.append(loss_with_zero_ablation.sum().item())

if len(losses) >= validation_number_activations:
if len(losses) >= validation_number_activations // input_ids.numel():
break

# Log
Expand All @@ -335,7 +335,7 @@ def validate_sae(self, validation_number_activations: int) -> None:
)
for metric in self.metrics.validation_metrics:
calculated = metric.calculate(validation_data)
wandb.log(data=calculated, step=self.total_activations_trained_on, commit=False)
wandb.log(data=calculated, commit=False)

@final
def save_checkpoint(self) -> None:
Expand Down
7 changes: 5 additions & 2 deletions sparse_autoencoder/train/sweep_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

# Key default values (used to calculate other default values)
DEFAULT_SOURCE_BATCH_SIZE: int = 16
DEFAULT_SOURCE_CONTEXT_SIZE: int = 128
DEFAULT_SOURCE_CONTEXT_SIZE: int = 256
DEFAULT_BATCH_SIZE: int = 8192 # Should be a multiple of source batch size and context size
DEFAULT_STORE_SIZE: int = round_to_multiple(3_000_000, DEFAULT_BATCH_SIZE)

Expand Down Expand Up @@ -243,7 +243,10 @@ class PipelineHyperparameters(NestedParameter):
)
"""Validation frequency."""

validation_number_activations: Parameter[int] = field(default=Parameter(DEFAULT_BATCH_SIZE))
validation_number_activations: Parameter[int] = field(
# Default to a single batch of source data prompts
default=Parameter(DEFAULT_BATCH_SIZE * DEFAULT_SOURCE_CONTEXT_SIZE)
)
"""Number of activations to use for validation."""


Expand Down

0 comments on commit ad1df04

Please sign in to comment.