Guide to Implementing a Custom loss_func
When implementing a custom loss_func in ROLL, the two most critical aspects are how the loss is aggregated and how loss_scale is handled. Getting either wrong can cause the final loss or gradients to deviate from what a single forward pass over the entire global batch would produce, introducing training bias. The deviation is especially severe when data parallelism (DP), gradient accumulation (GA), and sequence packing are combined.
1. Common Loss Aggregation Strategies
Consider a global batch containing $N$ sequences. Let the length of the $i$-th sequence be $T_i$, with a per-token mask $m_{i,t} \in \{0, 1\}$ indicating whether position $t$ participates in loss computation. The number of valid tokens is:

$$T_{\text{valid}} = \sum_{i=1}^{N} \sum_{t=1}^{T_i} m_{i,t}$$

Let $\ell_{i,t}$ denote the per-token loss at position $t$ of sequence $i$ (e.g., NLL, CE, KL divergence, policy loss, etc.).
1.1 Token-level Loss (token-mean)
Compute the average loss over all valid tokens in the global batch:

$$\mathcal{L}_{\text{token-mean}} = \frac{1}{T_{\text{valid}}} \sum_{i=1}^{N} \sum_{t=1}^{T_i} m_{i,t}\, \ell_{i,t}$$
Property: Each token has equal weight; longer sequences contribute more due to having more valid tokens.
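As a concrete reference, here is a minimal PyTorch sketch of token-mean aggregation; the tensor names and the `[num_seqs, seq_len]` shapes are illustrative assumptions, not ROLL's actual loss_func signature:

```python
import torch

def token_mean_loss(per_token_loss: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Average loss over all valid tokens: sum of masked per-token losses
    divided by the number of valid tokens. Shapes: [num_seqs, seq_len]."""
    return (per_token_loss * mask).sum() / mask.sum().clamp(min=1)
```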
1.2 Sequence-level Loss (seq-mean)
First aggregate within each sequence, then average across sequences. ROLL commonly uses two variants:
(a) seq-mean-token-sum
Sum losses over tokens within each sequence, then average across sequences:

$$\mathcal{L}_{\text{seq-mean-token-sum}} = \frac{1}{N} \sum_{i=1}^{N} \sum_{t=1}^{T_i} m_{i,t}\, \ell_{i,t}$$
(b) seq-mean-token-mean
Average losses over valid tokens within each sequence, then average across sequences:

$$\mathcal{L}_{\text{seq-mean-token-mean}} = \frac{1}{N} \sum_{i=1}^{N} \frac{\sum_{t=1}^{T_i} m_{i,t}\, \ell_{i,t}}{\sum_{t=1}^{T_i} m_{i,t}}$$
Property: Each sequence has equal weight, avoiding bias due to sequence length differences.
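A corresponding sketch of the two sequence-level variants, under the same illustrative shape assumptions as above:

```python
import torch

def seq_mean_token_sum(per_token_loss: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Sum losses over tokens within each sequence, then average across sequences."""
    return (per_token_loss * mask).sum(dim=-1).mean()

def seq_mean_token_mean(per_token_loss: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Average losses over valid tokens within each sequence, then across sequences."""
    per_seq = (per_token_loss * mask).sum(dim=-1) / mask.sum(dim=-1).clamp(min=1)
    return per_seq.mean()
```

Note that the two variants differ only in whether each sequence's contribution is rescaled by its own valid-token count before the outer mean.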
2. Micro-batch Partitioning in Distributed Training
In practice, a single global training step typically involves:
- Data Parallelism (DP): The global batch is split across multiple DP ranks;
- Gradient Accumulation (GA): Each rank further splits its data into multiple micro-batches, processed sequentially;
- Sequence Packing: To reduce padding and improve GPU utilization, multiple samples are concatenated into fixed-length packed sequences.
Let:
- the DP world size be $D$,
- the number of gradient accumulation steps be $G$.

Then the total number of micro-batches per global step is $M = D \times G$ (e.g., $D = 4$ and $G = 8$ give $M = 32$ micro-batches).
Denote the set of samples in the $k$-th micro-batch ($k = 1, \dots, M$) as $\mathcal{B}_k$. Its number of valid tokens is:

$$T_k = \sum_{i \in \mathcal{B}_k} \sum_{t=1}^{T_i} m_{i,t}$$

The number of sequences (samples) in this micro-batch is $n_k = |\mathcal{B}_k|$, satisfying:

$$\sum_{k=1}^{M} n_k = N$$
2.1 Why Does Sequence Packing Cause $n_k$ to Vary?
With sequence packing enabled, frameworks typically construct micro-batches based on a token budget rather than a fixed number of samples:
- Short sequences can be densely packed → some micro-batches contain many samples ($n_k$ is large);
- Long sequences consume more of the token budget → some micro-batches contain only a few samples ($n_k$ is small).
Thus, under packing, the number of samples per micro-batch is typically uneven and unpredictable, posing challenges for correct sequence-level loss aggregation.
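The effect is easy to reproduce with a toy greedy packer (an illustration of the phenomenon, not ROLL's actual packing algorithm):

```python
def pack_by_token_budget(seq_lens, budget):
    """Greedily group sequence indices so each micro-batch stays within `budget` tokens."""
    micro_batches, current, used = [], [], 0
    for i, length in enumerate(seq_lens):
        if current and used + length > budget:
            micro_batches.append(current)
            current, used = [], 0
        current.append(i)
        used += length
    if current:
        micro_batches.append(current)
    return micro_batches

# Eight short sequences pack into one micro-batch; each long one stands alone:
print([len(mb) for mb in pack_by_token_budget([100] * 8 + [900, 950], budget=1000)])
# -> [8, 1, 1], i.e. n_k ranges from 8 down to 1 within a single global step
```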
3. Core Issue: Why You Should Not Normalize Using Local Statistics Within Micro-batches
ROLL’s goal is: regardless of training configuration (DP/GA/Packing), the final loss used for backpropagation must be mathematically equivalent to computing the loss over the entire global batch in one go (as defined in Section 1).
If each micro-batch uses its own local statistics (e.g., $T_k$ or $n_k$) for normalization, and gradients are accumulated via the backend, the result is generally not equivalent.
3.1 Token-level: Incorrect Normalization Within Micro-batches
Wrong approach (normalize by the micro-batch's own token count $T_k$):

$$\mathcal{L}_k = \frac{1}{T_k} \sum_{i \in \mathcal{B}_k} \sum_{t=1}^{T_i} m_{i,t}\, \ell_{i,t}$$

If micro-batches are equally weighted during averaging (e.g., via gradient averaging), the total loss becomes:

$$\mathcal{L}_{\text{wrong}} = \frac{1}{M} \sum_{k=1}^{M} \mathcal{L}_k = \frac{1}{M} \sum_{k=1}^{M} \frac{1}{T_k} \sum_{i \in \mathcal{B}_k} \sum_{t=1}^{T_i} m_{i,t}\, \ell_{i,t}$$

But the correct global token-mean loss is:

$$\mathcal{L}_{\text{token-mean}} = \frac{1}{\sum_{k=1}^{M} T_k} \sum_{k=1}^{M} \sum_{i \in \mathcal{B}_k} \sum_{t=1}^{T_i} m_{i,t}\, \ell_{i,t}$$

These are equal only when all $T_k$ are identical: the micro-batch-normalized form gives each token in micro-batch $k$ weight $\frac{1}{M T_k}$ instead of the uniform $\frac{1}{T_{\text{valid}}}$, so tokens in token-sparse micro-batches are over-weighted. Under variable-length sequences or packing, $T_k$ varies significantly, causing bias.
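The fix is to normalize every micro-batch by the global token count. Below is a hedged sketch of the idea, assuming a DDP-style backend that averages gradients across DP ranks; the compensation factor is the kind of term that ROLL's loss_scale handling addresses, and the helper names here are hypothetical:

```python
import torch
import torch.distributed as dist

def global_valid_tokens(local_masks: list) -> torch.Tensor:
    """Valid-token count of the WHOLE global batch: sum this rank's micro-batch
    masks, then all_reduce across DP ranks (computed once per global step,
    before the gradient-accumulation loop)."""
    total = torch.stack([m.sum() for m in local_masks]).sum()
    if dist.is_initialized():
        dist.all_reduce(total, op=dist.ReduceOp.SUM)
    return total

def micro_batch_token_loss(per_token_loss, mask, global_tokens, dp_world_size):
    """Token-level loss of ONE micro-batch, normalized by the GLOBAL token count.
    Gradient accumulation sums these across micro-batches; the dp_world_size
    factor cancels the backend's mean over DP ranks (assumption: the backend
    averages gradients, as PyTorch DDP does). The accumulated gradient then
    equals that of the single-pass global token-mean."""
    return (per_token_loss * mask).sum() / global_tokens * dp_world_size
```

Because each micro-batch contributes its raw masked loss sum divided by the same global constant, summing over all $M$ micro-batches reproduces $\mathcal{L}_{\text{token-mean}}$ exactly, no matter how $T_k$ varies.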
3.2 Sequence-level: Micro-batch seq-mean Causes Sample Weight Imbalance
Take seq-mean-token-mean as an example:
Wrong approach (normalize by the micro-batch's own sample count $n_k$):

$$\mathcal{L}_k = \frac{1}{n_k} \sum_{i \in \mathcal{B}_k} \frac{\sum_{t=1}^{T_i} m_{i,t}\, \ell_{i,t}}{\sum_{t=1}^{T_i} m_{i,t}}$$
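By the same reasoning as in Section 3.1, the remedy is to weight each sequence by the global sequence count $N$ (obtained analogously, e.g., via an all_reduce of the local $n_k$). A sketch under the same backend assumptions as the token-level example above; with packing enabled, the per-sequence reduction would use segment indices rather than a fixed dim, which is omitted here for brevity:

```python
def micro_batch_seq_loss(per_token_loss, mask, global_num_seqs, dp_world_size):
    """seq-mean-token-mean for ONE micro-batch, weighted by the GLOBAL sequence
    count N rather than the local n_k, so every sample keeps weight 1/N no
    matter which micro-batch the packer placed it in."""
    per_seq = (per_token_loss * mask).sum(dim=-1) / mask.sum(dim=-1).clamp(min=1)
    return per_seq.sum() / global_num_seqs * dp_world_size
```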