Skip to content

Commit ad1df04

Browse files
authored
Fix the validation dataset size (#144)
1 parent 74b9cd1 commit ad1df04

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

sparse_autoencoder/train/pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def validate_sae(self, validation_number_activations: int) -> None:
324324
losses_with_reconstruction.append(loss_with_reconstruction.sum().item())
325325
losses_with_zero_ablation.append(loss_with_zero_ablation.sum().item())
326326

327-
if len(losses) >= validation_number_activations:
327+
if len(losses) >= validation_number_activations // input_ids.numel():
328328
break
329329

330330
# Log
@@ -335,7 +335,7 @@ def validate_sae(self, validation_number_activations: int) -> None:
335335
)
336336
for metric in self.metrics.validation_metrics:
337337
calculated = metric.calculate(validation_data)
338-
wandb.log(data=calculated, step=self.total_activations_trained_on, commit=False)
338+
wandb.log(data=calculated, commit=False)
339339

340340
@final
341341
def save_checkpoint(self) -> None:

sparse_autoencoder/train/sweep_config.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
# Key default values (used to calculate other default values)
2424
DEFAULT_SOURCE_BATCH_SIZE: int = 16
25-
DEFAULT_SOURCE_CONTEXT_SIZE: int = 128
25+
DEFAULT_SOURCE_CONTEXT_SIZE: int = 256
2626
DEFAULT_BATCH_SIZE: int = 8192 # Should be a multiple of source batch size and context size
2727
DEFAULT_STORE_SIZE: int = round_to_multiple(3_000_000, DEFAULT_BATCH_SIZE)
2828

@@ -243,7 +243,10 @@ class PipelineHyperparameters(NestedParameter):
243243
)
244244
"""Validation frequency."""
245245

246-
validation_number_activations: Parameter[int] = field(default=Parameter(DEFAULT_BATCH_SIZE))
246+
validation_number_activations: Parameter[int] = field(
247+
# Default to a single batch of source data prompts
248+
default=Parameter(DEFAULT_BATCH_SIZE * DEFAULT_SOURCE_CONTEXT_SIZE)
249+
)
247250
"""Number of activations to use for validation."""
248251

249252

0 commit comments

Comments
 (0)