Description
The batch size is a critical component of `llmart`'s efficiency. A device (CPU/GPU) with more memory (RAM/VRAM) can process more swaps (samples) in parallel.
Right now the user must pass `per_device_bs=N` if they want LLMart to process a batch of `N` samples in parallel.
HF provides `find_executable_batch_size`, which enables us to programmatically find a suitable batch size. Your job is to figure out where to put this call in LLMart such that it doesn't significantly slow down the basic attack. That is, `accelerate launch -m llmart model=llama3-8b-instruct data=basic per_device_bs=64` should run about as fast as `ts -nfG1 python -m llmart model=llama3-8b-instruct data=basic per_device_bs=auto`.
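
For context, `accelerate`'s `find_executable_batch_size` is a decorator: the wrapped function must take `batch_size` as its first argument, and on an out-of-memory failure it is retried with the batch size halved. The retry-and-halve strategy can be illustrated with a minimal pure-Python sketch (no `accelerate` dependency; `mock_forward` and its memory limit are hypothetical stand-ins, not part of LLMart):

```python
def find_executable_batch_size_sketch(fn, starting_batch_size=128):
    """Halve the batch size on OOM-style failures until fn succeeds."""
    bs = starting_batch_size
    while bs > 0:
        try:
            return bs, fn(bs)
        except RuntimeError as e:
            if "out of memory" not in str(e):
                raise  # unrelated error: propagate
            bs //= 2  # OOM: retry with half the batch size
    raise RuntimeError("No executable batch size found")

# Hypothetical workload: pretend the device can only fit 16 samples.
def mock_forward(batch_size):
    if batch_size > 16:
        raise RuntimeError("CUDA out of memory.")
    return f"processed {batch_size} samples"

bs, result = find_executable_batch_size_sketch(mock_forward, 64)
print(bs)      # 64 -> OOM, 32 -> OOM, 16 -> success
```

The main cost of this approach is the wasted OOM attempts before the first successful batch, which is why placement of the call matters: it should run once up front, not inside the attack loop.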