diff --git a/model_zoo/pytorch/nanogpt/train.py b/model_zoo/pytorch/nanogpt/train.py index 8660f9bba..c51c9a568 100644 --- a/model_zoo/pytorch/nanogpt/train.py +++ b/model_zoo/pytorch/nanogpt/train.py @@ -243,8 +243,13 @@ def train(): # Training loop X, Y = get_batch( - "train", train_data=train_data, val_data=val_data, device=device, - device_type=device_type, batch_size=batch_size, block_size=block_size + "train", + train_data=train_data, + val_data=val_data, + device=device, + device_type=device_type, + batch_size=batch_size, + block_size=block_size, ) # Fetch the very first batch total_time = 0.0 local_iter_num = 0 # Number of iterations in the lifetime of this process