Dynamic batching for Encoder-Decoder MT training or generation when long sequences cap the batch size [P]
I built a small PyTorch sampler called dynabatch after facing this specific batching issue while fine-tuning an NLLB-200 600M model.
Training on an RTX 5090, the largest fixed batch size I could use was 8; anything bigger led to OOM. While monitoring training with nvidia-smi, it looked like only a few batches were actually stressing the GPU, and the rest of the time utilization was much lower. My guess was that the fixed batch size was being dictated by the longest source/target examples, while batches of shorter examples had room for more samples.
So I tried to make the batch size change as the sequence lengths change. The gist of the idea (sketched in code after this list):
- sort examples by token length, longest first
- treat the first batch as “this is the hardest batch that fits”
- for later, shorter batches, try larger candidate batch sizes
- use a small XGB regressor to predict memory pressure relative to that first batch
- pick the largest candidate that stays under a safety threshold
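To make that concrete, here is a minimal sketch of the idea as a PyTorch batch sampler. This is not the actual dynabatch code: `LengthAwareBatchSampler`, the `predict_pressure` callable, and the `multipliers`/`threshold` parameters are all stand-ins I made up for illustration; in dynabatch the prediction comes from the trained XGB regressor.

```python
from torch.utils.data import Sampler

class LengthAwareBatchSampler(Sampler):
    """Grow the batch size as sequences get shorter.

    `lengths` holds per-example token counts. `predict_pressure(batch_size,
    max_len)` should return estimated memory pressure relative to the first
    (longest) batch, where 1.0 means "same as the hardest batch that fits".
    """

    def __init__(self, lengths, base_batch_size, predict_pressure,
                 multipliers=(1, 2, 3, 4, 6, 8), threshold=0.9):
        # Sort longest-first so batch 0 is the worst case known to fit.
        self.order = sorted(range(len(lengths)), key=lambda i: -lengths[i])
        self.lengths = lengths
        self.base = base_batch_size
        self.predict = predict_pressure
        self.multipliers = multipliers
        self.threshold = threshold

    def __iter__(self):
        i = 0
        while i < len(self.order):
            max_len = self.lengths[self.order[i]]  # longest seq in this batch
            size = self.base
            # Keep the largest candidate that the predictor says stays
            # under the safety threshold.
            for m in self.multipliers:
                if self.predict(self.base * m, max_len) <= self.threshold:
                    size = self.base * m
            yield self.order[i:i + size]
            i += size

# Toy stand-in predictor: memory grows roughly with batch_size * max_len,
# calibrated so the hardest batch (8 examples of length 200) scores 1.0.
lengths = [200] * 8 + [50] * 64
pressure = lambda bs, max_len: (bs * max_len) / (8 * 200)
sampler = LengthAwareBatchSampler(lengths, base_batch_size=8,
                                  predict_pressure=pressure)
for batch in sampler:
    print(len(batch))  # 8 for the long examples, 24 for the short ones
```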
This is mostly meant for encoder-decoder models, especially for MT, where source length is often a useful proxy for target length. I would not use this as my first tool for decoder-only models; sequence packing is a better fit there.
In my training benchmark, this gave about a 3.3x throughput improvement over fixed-batch training. That number is specific to my setup and should not be read as a general claim. On a Colab T4 generation benchmark, the gain was only around 1.06x - 1.21x.
The regressor is also empirical: it was trained on measured GPU memory usage, so it can be wrong sometimes and might behave a little differently for some models/tokenizers. I have added a fallback for when it overestimates the headroom and an OOM is thrown. (I also added the regressor training notebooks for anyone interested.)
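For context, the general recover-and-retry pattern looks roughly like this (again a sketch of the idea, not dynabatch's exact fallback; `run_with_oom_fallback` and `step_fn` are hypothetical names):

```python
import torch

def run_with_oom_fallback(step_fn, batch, base_batch_size):
    """Try the regressor-predicted batch; if the prediction was too
    optimistic and the GPU OOMs, retry the same examples in chunks of
    the known-safe base batch size."""
    try:
        step_fn(batch)
    except torch.cuda.OutOfMemoryError:
        torch.cuda.empty_cache()  # release the failed allocations first
        for i in range(0, len(batch), base_batch_size):
            step_fn(batch[i:i + base_batch_size])
```

In a real training loop you would also need to reset any partial state from the failed step (e.g. zero the gradients) before retrying.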
So, honestly, I think this is a very niche tool, especially in the decoder-only era, but I hope it helps people who are training or generating with encoder-decoder MT models.
Repo: https://github.com/bendangnuksung/dynabatch
PyPI: https://pypi.org/project/dynabatch/