diff --git a/lwm/data.py b/lwm/data.py index a6cb177..c6c9b74 100644 --- a/lwm/data.py +++ b/lwm/data.py @@ -155,7 +155,7 @@ def __call__(self, example, has_aux=False, add_bos_token=True, add_eos_token=Tru example, *aux = example else: aux = tuple() - rand_state = random.Random(aux[-1]) # makes augmentations deterministic by line number + rand_state = random.Random(aux[-1] if aux else 0) # makes augmentations deterministic by line number token_buffer = [] loss_mask_buffer = [] vision_mask = []