Guide to Implementing Custom loss_func
When implementing a custom loss_func in ROLL, the most critical aspects are how the loss is aggregated and how loss_scale is handled. Mishandling these two points can cause the final computed loss or gradients to deviate from the result that would be obtained by performing a single forward pass over the entire global batch, thereby introducing training bias—especially severe in complex training scenarios involving data parallelism (DP) + gradient accumulation (GA) + sequence packing.
1. Common Loss Aggregation Strategies
Consider a global batch containing $B$ sequences. Let the length of the $i$-th sequence be $T_i$, with a per-token mask $m_{i,t} \in \{0, 1\}$ indicating whether position $t$ participates in loss computation. The number of valid tokens is:

$$N_{\text{all}} = \sum_{i=1}^{B} \sum_{t=1}^{T_i} m_{i,t}$$
Let $\ell_{i,t}$ denote the per-token loss at position $t$ of sequence $i$ (e.g., NLL, CE, KL divergence, policy loss, etc.).
1.1 Token-level Loss (token-mean)
Compute the average loss over all valid tokens in the global batch:

$$\mathcal{L}_{\text{token-mean}} = \frac{1}{N_{\text{all}}} \sum_{i=1}^{B} \sum_{t=1}^{T_i} m_{i,t}\, \ell_{i,t}$$
Property: Each token has equal weight; longer sequences contribute more due to having more valid tokens.
1.2 Sequence-level Loss (seq-mean)
First aggregate within each sequence, then average across sequences. ROLL commonly uses two variants:
(a) seq-mean-token-sum
Sum losses over tokens within each sequence, then average across sequences:

$$\mathcal{L}_{\text{seq-mean-token-sum}} = \frac{1}{B} \sum_{i=1}^{B} \sum_{t=1}^{T_i} m_{i,t}\, \ell_{i,t}$$
(b) seq-mean-token-mean
Average losses over tokens within each sequence, then average across sequences:

$$\mathcal{L}_{\text{seq-mean-token-mean}} = \frac{1}{B} \sum_{i=1}^{B} \frac{\sum_{t=1}^{T_i} m_{i,t}\, \ell_{i,t}}{\sum_{t=1}^{T_i} m_{i,t}}$$
Property: Each sequence has equal weight, avoiding bias due to sequence length differences.
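To make the three definitions concrete, here is a minimal PyTorch sketch with invented values (illustrative only, not ROLL code):

```python
import torch

# Per-token losses and a validity mask for a toy global batch, shape [B, T];
# masked positions (m = 0) do not participate in the loss.
token_loss = torch.tensor([[0.5, 0.3, 0.0], [0.2, 0.4, 0.6]])
mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 1.0, 1.0]])

# token-mean: every valid token weighted equally
token_mean = (token_loss * mask).sum() / mask.sum()

# seq-mean-token-sum: sum per sequence, then average over sequences
seq_mean_token_sum = (token_loss * mask).sum(dim=-1).mean()

# seq-mean-token-mean: average per sequence, then average over sequences
per_seq_mean = (token_loss * mask).sum(dim=-1) / mask.sum(dim=-1)
seq_mean_token_mean = per_seq_mean.mean()
```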
2. Micro-batch Partitioning in Distributed Training
In practice, a single global training step typically involves:
- Data Parallelism (DP): The global batch is split across multiple DP ranks;
- Gradient Accumulation (GA): Each rank further splits its data into multiple micro-batches, processed sequentially;
- Sequence Packing: To reduce padding and improve GPU utilization, multiple samples are concatenated into fixed-length packed sequences.
Let:
- DP world size be $D$,
- Gradient accumulation steps be $A$,
- Then the total number of micro-batches per global step is $M = D \times A$.
Denote the set of samples in the $k$-th micro-batch as $\mathcal{B}_k$ ($k = 1, \dots, M$). Its number of valid tokens is:

$$N_k = \sum_{i \in \mathcal{B}_k} \sum_{t=1}^{T_i} m_{i,t}$$
The number of sequences (samples) in this micro-batch is $B_k = |\mathcal{B}_k|$, satisfying:

$$\sum_{k=1}^{M} N_k = N_{\text{all}}, \qquad \sum_{k=1}^{M} B_k = B$$
2.1 Why Does Sequence Packing Cause $B_k$ to Vary?
With sequence packing enabled, frameworks typically construct micro-batches based on a token budget rather than a fixed number of samples:
- Short sequences can be densely packed → some micro-batches contain many samples ($B_k$ large);
- Long sequences consume more space → some micro-batches contain few samples ($B_k$ small).
Thus, under packing, the number of samples per micro-batch is typically uneven and unpredictable, posing challenges for correct sequence-level loss aggregation.
3. Core Issue: Why You Should Not Normalize Using Local Statistics Within Micro-batches
ROLL’s goal is: regardless of training configuration (DP/GA/Packing), the final loss used for backpropagation must be mathematically equivalent to computing the loss over the entire global batch in one go (as defined in Section 1).
If each micro-batch uses its own local statistics (e.g., $N_k$ or $B_k$) for normalization, and gradients are accumulated via the backend, the result is generally not equivalent.
3.1 Token-level: Incorrect Normalization Within Micro-batches
Wrong approach (normalize by the micro-batch’s own token count $N_k$):

$$\mathcal{L}_k^{\text{wrong}} = \frac{1}{N_k} \sum_{i \in \mathcal{B}_k} \sum_{t=1}^{T_i} m_{i,t}\, \ell_{i,t}$$
If micro-batches are equally weighted during averaging (e.g., via gradient averaging), the total loss becomes:

$$\mathcal{L}^{\text{wrong}} = \frac{1}{M} \sum_{k=1}^{M} \frac{1}{N_k} \sum_{i \in \mathcal{B}_k} \sum_{t=1}^{T_i} m_{i,t}\, \ell_{i,t}$$
But the correct global token-mean loss is:

$$\mathcal{L}_{\text{token-mean}} = \frac{1}{N_{\text{all}}} \sum_{k=1}^{M} \sum_{i \in \mathcal{B}_k} \sum_{t=1}^{T_i} m_{i,t}\, \ell_{i,t}$$
These are only equal when all $N_k$ are identical. Under variable-length sequences or packing, $N_k$ varies significantly, causing bias.
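A tiny numeric illustration of the gap, with invented values for two micro-batches:

```python
import torch

# Hypothetical global batch split into two micro-batches of very different
# token counts (common under packing): N_1 = 100 valid tokens, N_2 = 10.
loss_sum = torch.tensor([50.0, 9.0])     # sum of per-token losses per micro-batch
num_tokens = torch.tensor([100.0, 10.0])

# Wrong: normalize each micro-batch locally, then weight micro-batches equally
wrong = (loss_sum / num_tokens).mean()       # (0.5 + 0.9) / 2 = 0.7

# Correct: one global denominator over all valid tokens
correct = loss_sum.sum() / num_tokens.sum()  # 59 / 110 ≈ 0.536
```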
3.2 Sequence-level: Micro-batch seq-mean Causes Sample Weight Imbalance
Take seq-mean-token-mean as an example:
Wrong approach (normalize by the micro-batch’s sample count $B_k$):

$$\mathcal{L}_k^{\text{wrong}} = \frac{1}{B_k} \sum_{i \in \mathcal{B}_k} \frac{\sum_{t=1}^{T_i} m_{i,t}\, \ell_{i,t}}{\sum_{t=1}^{T_i} m_{i,t}}$$
After equal-weight averaging across micro-batches:

$$\mathcal{L}^{\text{wrong}} = \frac{1}{M} \sum_{k=1}^{M} \frac{1}{B_k} \sum_{i \in \mathcal{B}_k} \frac{\sum_{t=1}^{T_i} m_{i,t}\, \ell_{i,t}}{\sum_{t=1}^{T_i} m_{i,t}}$$
But the correct global seq-mean is:

$$\mathcal{L}_{\text{seq-mean-token-mean}} = \frac{1}{B} \sum_{k=1}^{M} \sum_{i \in \mathcal{B}_k} \frac{\sum_{t=1}^{T_i} m_{i,t}\, \ell_{i,t}}{\sum_{t=1}^{T_i} m_{i,t}}$$
The former treats each micro-batch equally; the latter treats each sequence equally. When $B_k$ varies (common under packing), they are not equivalent.
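The same imbalance in numbers, with invented per-sequence mean losses:

```python
import torch

# Micro-batch 1 packs four short samples, micro-batch 2 holds one long sample
# (B_1 = 4, B_2 = 1). Values are per-sequence token-mean losses.
mb1 = torch.tensor([0.2, 0.2, 0.2, 0.2])
mb2 = torch.tensor([1.0])

# Wrong: each micro-batch normalized by its own B_k, then averaged equally;
# the single long sample carries as much weight as the four short ones combined.
wrong = (mb1.mean() + mb2.mean()) / 2        # (0.2 + 1.0) / 2 = 0.6

# Correct: every sequence weighted equally across the global batch
correct = torch.cat([mb1, mb2]).mean()       # 1.8 / 5 = 0.36
```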
4. Correct Approach: Use Global Denominator + Sum Across Micro-batches
ROLL follows these design principles:
- Within each micro-batch, use global statistics as the denominator;
- Each micro-batch’s returned loss should represent a partial contribution to the global loss;
- The sum of all micro-batch losses must exactly equal the global loss;
- Use `loss_scale` to counteract the backend’s default normalization behavior (see Section 5).
4.1 Correct Implementation for Token-level Loss
For the $k$-th micro-batch:

$$\mathcal{L}_k = \frac{1}{N_{\text{all}}} \sum_{i \in \mathcal{B}_k} \sum_{t=1}^{T_i} m_{i,t}\, \ell_{i,t}$$
Then:

$$\sum_{k=1}^{M} \mathcal{L}_k = \frac{1}{N_{\text{all}}} \sum_{k=1}^{M} \sum_{i \in \mathcal{B}_k} \sum_{t=1}^{T_i} m_{i,t}\, \ell_{i,t} = \mathcal{L}_{\text{token-mean}}$$
✅ Mathematically exact.
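A minimal sketch of this contribution in PyTorch, with stand-in tensors and an assumed injected `N_all` (see Section 6):

```python
import torch

# Illustrative stand-ins: per-token losses and mask for ONE micro-batch,
# plus the injected global token count N_all.
per_token_loss = torch.rand(2, 5)
mask = torch.ones(2, 5)
N_all = 1000.0

# Partial token-mean contribution: local masked sum over the GLOBAL denominator
# (never the local mask.sum()).
loss_k = (per_token_loss * mask).sum() / N_all
```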
4.2 Correct Implementation for Sequence-level Loss (e.g., seq-mean-token-mean)
For the $k$-th micro-batch:

$$\mathcal{L}_k = \frac{1}{B} \sum_{i \in \mathcal{B}_k} \frac{\sum_{t=1}^{T_i} m_{i,t}\, \ell_{i,t}}{\sum_{t=1}^{T_i} m_{i,t}}$$
Then:

$$\sum_{k=1}^{M} \mathcal{L}_k = \mathcal{L}_{\text{seq-mean-token-mean}}$$
✅ Holds exactly even when $B_k$ varies (common under packing).
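A corresponding sketch for the sequence-level case, again with stand-in tensors and an assumed injected `B_all`:

```python
import torch

# Illustrative stand-ins for ONE micro-batch, plus the injected global
# sequence count B_all (see Section 6).
per_token_loss = torch.rand(3, 5)
mask = (torch.rand(3, 5) > 0.3).float()
B_all = 64.0

# Per-sequence token means are computed locally, but the denominator is the
# GLOBAL sequence count B_all, not the local batch size B_k = 3.
per_seq_mean = (per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1)
loss_k = per_seq_mean.sum() / B_all
```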
5. loss_scale: Compensating for Backend Normalization
Most training frameworks (e.g., Megatron, FSDP) implicitly normalize gradients under DP + GA to stabilize scale:
- GA dimension: average gradients over the $A$ micro-steps (equivalent to `loss /= A`);
- DP dimension: divide by $D$ after AllReduce (equivalent to averaging across ranks).
The combined effect is:

$$\mathcal{L}^{\text{backend}} = \frac{1}{D \cdot A} \sum_{k=1}^{M} \mathcal{L}_k$$
However, ROLL’s aggregation design requires summation semantics across micro-batches:

$$\mathcal{L} = \sum_{k=1}^{M} \mathcal{L}_k$$
To cancel the backend’s normalization, multiply each micro-batch’s loss by:

$$\text{loss\_scale} = D \times A$$
Thus:

$$\frac{1}{D \cdot A} \sum_{k=1}^{M} \big( D \cdot A \cdot \mathcal{L}_k \big) = \sum_{k=1}^{M} \mathcal{L}_k$$
✅ Recovers correct summation semantics.
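As a sanity check on the arithmetic, a minimal sketch with made-up values for $D$ and $A$:

```python
# Hypothetical configuration: D = 4 DP ranks, A = 8 gradient-accumulation steps.
D, A = 4, 8
loss_scale = D * A                     # 32

# Suppose a micro-batch's globally normalized partial loss (Section 4) is:
loss_k = 0.01

# The backend later divides the accumulated gradient by D * A; pre-multiplying
# by loss_scale cancels that division, so micro-batch contributions are summed:
scaled_loss_k = loss_k * loss_scale    # backend: (1 / (D * A)) * sum_k scaled_loss_k == sum_k loss_k
```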
6. ROLL Interface: Global Stat Injection and loss_scale Control
To enable globally equivalent loss aggregation at the micro-batch level, ROLL automatically injects global batch statistics (e.g., total valid tokens, total valid samples) into each training step. These statistics are computed based entirely on user-specified loss_mask_keys.
6.1 loss_mask_keys: Define Loss Participation Scope and Drive Global Stat Injection
loss_mask_keys is a list of strings declaring which mask fields identify "valid tokens participating in loss computation." This configuration not only guides how the loss function masks invalid positions but—more importantly—determines how the strategy computes and injects global aggregation quantities.
You must set this in your pipeline’s data preprocessing or worker initialization:
```python
data.meta_info['loss_mask_keys'] = ['response_mask', 'labels_mask']
```
For each key in `loss_mask_keys` (e.g., `'response_mask'`), ROLL’s strategy will:
- Extract the corresponding mask tensor from `data.batch` (typically of shape `[batch_size, seq_len]`);
- Gather this mask across all DP ranks and GA steps;
- Compute two global statistics:
  - `batch_num_tokens[key]`: the total sum of this mask over the entire global batch, i.e., $N_{\text{all}} = \sum_{i}\sum_{t} m_{i,t}$;
  - `global_valid_samples[key]`: the number of sequences with at least one valid token, i.e., $B_{\text{all}} = \sum_{i} \mathbb{1}\!\left[\sum_{t} m_{i,t} > 0\right]$.
These statistics are injected into `data.meta_info` for use in `loss_func`.
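Conceptually, the two statistics correspond to the following masked reductions (a sketch, not the strategy’s actual implementation):

```python
import torch

# Gathered global mask of shape [global_batch_size, seq_len];
# the third sample has no valid tokens at all.
global_mask = torch.tensor([[1, 1, 0],
                            [1, 1, 1],
                            [0, 0, 0]])

batch_num_tokens = global_mask.sum()                    # total valid tokens = 5
global_valid_samples = (global_mask.sum(-1) > 0).sum()  # sequences with >= 1 valid token = 2
```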
⚠️ Critical Consistency Requirement: The mask you use in `loss_func` for loss computation, weighting, or aggregation must have identical semantics to the mask specified in `loss_mask_keys`.
For example, if `loss_mask_keys = ['response_mask']`, your loss must be computed only using `response_mask`. Using a different mask (e.g., `attention_mask`) will cause a mismatch between the numerator (loss computation) and denominator (global statistics), breaking equivalence.
6.2 Using Injected Global Statistics in loss_func
In your custom `loss_func`, access the global statistics as follows:

```python
# Assume 'response_mask' is in loss_mask_keys
mask_key = 'response_mask'
N_all = data.meta_info['batch_num_tokens'][mask_key]       # Global valid token count
B_all = data.meta_info['global_valid_samples'][mask_key]   # Global valid sample count
```
Then use these global values as denominators during aggregation (see Section 4) to ensure micro-batch computations exactly reconstruct the global loss.
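Putting the pieces together, here is a sketch of token-level and sequence-level aggregation inside a custom `loss_func`. The function signature and the `per_token_loss` input are assumptions for illustration; adapt them to your worker’s actual interface:

```python
import torch

def custom_loss_func(data, per_token_loss):
    """Illustrative skeleton only; not ROLL's exact loss_func signature."""
    mask_key = 'response_mask'                  # must match loss_mask_keys
    mask = data.batch[mask_key].float()         # shape [micro_batch, seq_len]
    N_all = data.meta_info['batch_num_tokens'][mask_key]
    B_all = data.meta_info['global_valid_samples'][mask_key]

    # Token-level aggregation: local masked sum over the GLOBAL token count.
    token_level_loss = (per_token_loss * mask).sum() / N_all

    # Sequence-level aggregation (seq-mean-token-mean): local per-sequence
    # means over the GLOBAL valid-sample count.
    per_seq_mean = (per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1)
    seq_level_loss = per_seq_mean.sum() / B_all

    # Return whichever semantics your algorithm requires; with
    # apply_loss_scale=True the D * A compensation is applied automatically
    # (see Section 6.3).
    return token_level_loss  # or seq_level_loss
```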
6.3 apply_loss_scale: Controlling Gradient Scale Correction
Since training backends (e.g., Megatron/FSDP) typically apply implicit normalization under DP + GA, while ROLL relies on summation semantics, compensation via $\text{loss\_scale} = D \times A$ is needed.
In `worker_config`, the parameter `apply_loss_scale` controls whether this scaling is applied automatically:
- Default: `True` (recommended to keep enabled);
- Effect: the framework automatically multiplies the loss returned by `loss_func` by `loss_scale`;
- When to disable: only if you manually implement the full global loss (including the scale) inside `loss_func`, which is generally not advised.
7. Metrics Logging: Use @sum Semantics
For losses aggregated using global denominators, metrics should be summed—not averaged—during multi-worker reduction.
ROLL supports specifying the reduction behavior via an `@operator` suffix in metric names:
```python
metrics = {
    "actor/kl_loss@sum": kl_loss.detach().item(),
}
reduce_metrics(metrics)
```
- `@sum`: sum values across all workers during reduction;
- `@mean` (default): average across workers;
- The logger automatically strips everything from `@` onward, so the metric is displayed as `actor/kl_loss`.
8. Code Example: Globally Equivalent KL Loss Implementation in Actor
8.1 Compute Per-Token KL
```python
# Per-token approximate KL between the actor and the reference policy,
# restricted to response tokens via the mask.
kl_loss = compute_approx_kl(
    log_probs=log_probs,
    log_probs_base=ref_log_probs,
    action_mask=final_response_mask,
    kl_penalty="k3",  # k3 KL estimator
)
```