Optimizer Module
RecIS’s optimizer module is designed specifically for sparse parameter optimization, providing efficient update algorithms for the sparse parameters (such as large embedding tables) used in recommendation models.
Sparse Optimizers
SparseAdamW
- class recis.optim.sparse_adamw.SparseAdamW(param_dict: dict, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-08, weight_decay=0.01, amsgrad=False)[source]
Sparse AdamW optimizer for efficient sparse parameter optimization.
This class implements the AdamW optimization algorithm specifically optimized for sparse parameters in recommendation systems. It extends the SparseOptimizer base class and uses RecIS’s C++ implementation for maximum performance.
The AdamW algorithm combines adaptive learning rates from Adam with proper weight decay regularization. For sparse parameters, this implementation only updates parameters that have received gradients, making it highly efficient for large embedding tables where only a small fraction of parameters are active in each training step.
Mathematical formulation:
\[
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2 \\
\hat{m}_t &= m_t / (1 - \beta_1^t) \\
\hat{v}_t &= v_t / (1 - \beta_2^t) \\
\theta_t &= \theta_{t-1} - \mathrm{lr} \cdot \left( \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon) + \mathrm{weight\_decay} \cdot \theta_{t-1} \right)
\end{aligned}
\]
Where:
- θ: parameters
- g: gradients
- m: first moment estimate (momentum)
- v: second moment estimate (variance)
- β₁, β₂: exponential decay rates
- lr: learning rate
- ε: numerical stability constant
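To make the update rule concrete, the following is a minimal dense PyTorch sketch of the same AdamW step. It is for illustration only: theta, grad, and the hyperparameter values are made up here, and the real optimizer applies this update only to rows that received gradients, via RecIS’s C++ kernels.

import torch

theta = torch.randn(4)        # parameters (illustrative values)
grad = torch.randn(4)         # gradient for this step (illustrative)
m = torch.zeros_like(theta)   # first moment estimate
v = torch.zeros_like(theta)   # second moment estimate
lr, beta1, beta2, eps, weight_decay = 1e-3, 0.9, 0.999, 1e-8, 0.01

for t in range(1, 4):         # a few steps with the same gradient, for illustration
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad**2
    m_hat = m / (1 - beta1**t)    # bias-corrected first moment
    v_hat = v / (1 - beta2**t)    # bias-corrected second moment
    # Decoupled weight decay: applied to the parameter, not folded into the gradient
    theta = theta - lr * (m_hat / (v_hat.sqrt() + eps) + weight_decay * theta)

print(theta)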
Example
Creating and using SparseAdamW:
# Initialize with custom hyperparameters
optimizer = SparseAdamW(
    param_dict=sparse_parameters,
    lr=0.001,           # Learning rate
    beta1=0.9,          # First moment decay rate
    beta2=0.999,        # Second moment decay rate
    eps=1e-8,           # Numerical stability
    weight_decay=0.01,  # L2 regularization strength
)

# Training with gradient accumulation
optimizer.set_grad_accum_steps(4)

for batch in dataloader:
    loss = model(batch) / 4  # Scale for accumulation
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
- __init__(param_dict: dict, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-08, weight_decay=0.01, amsgrad=False) None [source]
Initialize SparseAdamW optimizer with specified hyperparameters.
- Parameters:
param_dict (dict) – Dictionary of sparse parameters to optimize. Keys are parameter names, values are parameter tensors (typically HashTables).
lr (float, optional) – Learning rate. Defaults to 1e-3.
beta1 (float, optional) – Exponential decay rate for first moment estimates. Should be in [0, 1). Defaults to 0.9.
beta2 (float, optional) – Exponential decay rate for second moment estimates. Should be in [0, 1). Defaults to 0.999.
eps (float, optional) – Small constant added to denominator for numerical stability. Defaults to 1e-8.
weight_decay (float, optional) – Weight decay coefficient (L2 regularization). Defaults to 1e-2.
amsgrad (bool, optional) – Whether to use AMSGrad variant. Currently not supported and will raise ValueError if True. Defaults to False.
- Raises:
ValueError – If amsgrad is True (not currently supported).
Example:
# Basic initialization
optimizer = SparseAdamW(sparse_params)

# Custom hyperparameters for recommendation systems
optimizer = SparseAdamW(
    param_dict=embedding_params,
    lr=0.01,             # Higher learning rate for sparse params
    beta1=0.9,           # Standard momentum
    beta2=0.999,         # Standard variance decay
    eps=1e-8,            # Numerical stability
    weight_decay=0.001,  # Light regularization
)

# Conservative settings for fine-tuning
optimizer = SparseAdamW(
    param_dict=pretrained_embeddings,
    lr=0.0001,         # Low learning rate
    weight_decay=0.1,  # Strong regularization
)
Note
The param_dict should contain HashTable parameters that support sparse gradient updates. Regular dense tensors may not work correctly with this optimizer.
- set_grad_accum_steps(steps: int)
Set the number of gradient accumulation steps.
This method configures gradient accumulation, which allows training with effectively larger batch sizes by accumulating gradients over multiple forward passes before updating parameters.
- Parameters:
steps (int) – Number of steps to accumulate gradients before performing a parameter update. Must be positive.
- step()
Perform a single optimization step with gradient accumulation support.
This method implements gradient accumulation by only performing the actual parameter update every _grad_accum_steps steps. It maintains an internal step counter and delegates the actual optimization to the underlying C++ implementation.
Note
When gradient accumulation is enabled (_grad_accum_steps > 1), this method only performs the actual parameter update every _grad_accum_steps calls. The learning rate is automatically handled by the implementation.
- zero_grad()
Clear gradients with gradient accumulation support.
This method clears parameter gradients, but only when gradient accumulation steps are completed. This ensures that gradients are properly accumulated across multiple forward passes before being cleared.
Note
When gradient accumulation is enabled, this method only clears gradients every _grad_accum_steps calls, synchronized with the step() method.
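As a rough illustration of the accumulation semantics described in the notes above (not the actual SparseAdamW internals), the hypothetical counter below mimics how step() and zero_grad() only take effect every grad_accum_steps calls.

class AccumCounterSketch:
    """Hypothetical counter mirroring the described step()/zero_grad() behavior."""

    def __init__(self, grad_accum_steps):
        self.grad_accum_steps = grad_accum_steps
        self.calls = 0

    def step(self):
        self.calls += 1
        if self.calls % self.grad_accum_steps == 0:
            return "update + clear grads"  # real update and gradient clear fire here
        return "accumulate"                # gradients keep accumulating


sketch = AccumCounterSketch(grad_accum_steps=4)
print([sketch.step() for _ in range(8)])
# The update fires on calls 4 and 8; all other calls only accumulate gradients.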
SparseAdam
- class recis.optim.sparse_adam.SparseAdam(param_dict: dict, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-08, weight_decay=0.01, amsgrad=False)[source]
Sparse Adam optimizer for efficient sparse parameter optimization.
This class implements the Adam optimization algorithm specifically optimized for sparse parameters in recommendation systems. It extends the SparseOptimizer base class and uses RecIS’s C++ implementation for maximum performance.
The Adam algorithm uses adaptive learning rates computed from estimates of first and second moments of gradients. For sparse parameters, this implementation only updates parameters that have received gradients, making it highly efficient for large embedding tables where only a small fraction of parameters are active in each training step.
Key advantages for sparse parameters:
- Only active parameters are updated, saving computation
- Separate adaptive learning rates for each sparse parameter
- Efficient memory usage for momentum and variance estimates
- Optimized C++ implementation for maximum performance
Example
Creating and using SparseAdam:
# Initialize with explicit hyperparameters
optimizer = SparseAdam(
    param_dict=sparse_parameters,
    lr=0.001,          # Learning rate
    beta1=0.9,         # First moment decay rate
    beta2=0.999,       # Second moment decay rate
    eps=1e-8,          # Numerical stability
    weight_decay=0.0,  # Disable weight decay (constructor default is 0.01)
)

# Training with gradient accumulation
optimizer.set_grad_accum_steps(4)

for batch in dataloader:
    loss = model(batch) / 4  # Scale for accumulation
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
- __init__(param_dict: dict, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-08, weight_decay=0.01, amsgrad=False) None [source]
Initialize SparseAdam optimizer with specified hyperparameters.
- Parameters:
param_dict (dict) – Dictionary of sparse parameters to optimize. Keys are parameter names, values are parameter tensors (typically HashTables).
lr (float, optional) – Learning rate. Defaults to 1e-3.
beta1 (float, optional) – Exponential decay rate for first moment estimates. Should be in [0, 1). Defaults to 0.9.
beta2 (float, optional) – Exponential decay rate for second moment estimates. Should be in [0, 1). Defaults to 0.999.
eps (float, optional) – Small constant added to denominator for numerical stability. Defaults to 1e-8.
weight_decay (float, optional) – Weight decay coefficient. Note that unlike SparseAdamW, this applies L2 penalty to gradients rather than direct weight decay. Defaults to 1e-2.
amsgrad (bool, optional) – Whether to use AMSGrad variant. Currently not supported and will raise ValueError if True. Defaults to False.
- Raises:
ValueError – If amsgrad is True (not currently supported).
Note
The param_dict should contain HashTable parameters that support sparse gradient updates. The weight_decay in SparseAdam applies L2 penalty to gradients, which is different from the direct weight decay used in SparseAdamW.
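To make the note concrete, the difference can be written schematically using the same symbols as the SparseAdamW formulation above; this summarizes the standard Adam-with-L2 update versus the decoupled AdamW update, not a transcription of the C++ kernels.

\[
\begin{aligned}
\text{SparseAdam (L2 penalty folded into the gradient):}\quad
& g_t \leftarrow g_t + \mathrm{weight\_decay} \cdot \theta_{t-1}, \qquad
\theta_t = \theta_{t-1} - \mathrm{lr} \cdot \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon) \\
\text{SparseAdamW (decoupled weight decay):}\quad
& \theta_t = \theta_{t-1} - \mathrm{lr} \cdot \bigl( \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon) + \mathrm{weight\_decay} \cdot \theta_{t-1} \bigr)
\end{aligned}
\]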
- set_grad_accum_steps(steps: int)
Set the number of gradient accumulation steps.
This method configures gradient accumulation, which allows training with effectively larger batch sizes by accumulating gradients over multiple forward passes before updating parameters.
- Parameters:
steps (int) – Number of steps to accumulate gradients before performing a parameter update. Must be positive.
- step()
Perform a single optimization step with gradient accumulation support.
This method implements gradient accumulation by only performing the actual parameter update every _grad_accum_steps steps. It maintains an internal step counter and delegates the actual optimization to the underlying C++ implementation.
Note
When gradient accumulation is enabled (_grad_accum_steps > 1), this method only performs the actual parameter update every _grad_accum_steps calls. The learning rate is automatically handled by the implementation.
- zero_grad()
Clear gradients with gradient accumulation support.
This method clears parameter gradients, but only when gradient accumulation steps are completed. This ensures that gradients are properly accumulated across multiple forward passes before being cleared.
Note
When gradient accumulation is enabled, this method only clears gradients every _grad_accum_steps calls, synchronized with the step() method.
TensorFlow Compatible Optimizers
SparseAdamWTF
- class recis.optim.sparse_adamw_tf.SparseAdamWTF(param_dict: dict, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-08, weight_decay=0.01, use_nesterov=False)[source]
Sparse AdamW optimizer with TensorFlow-style implementation for efficient sparse parameter optimization.
This class implements the AdamW optimization algorithm with TensorFlow-compatible behavior specifically optimized for sparse parameters in recommendation systems. It extends the SparseOptimizer base class and uses RecIS’s C++ implementation for maximum performance.
Example
Creating and using SparseAdamWTF:
# Initialize with custom hyperparameters for TF compatibility
optimizer = SparseAdamWTF(
    param_dict=sparse_parameters,
    lr=0.001,            # Learning rate
    beta1=0.9,           # First moment decay rate
    beta2=0.999,         # Second moment decay rate
    eps=1e-8,            # Numerical stability
    weight_decay=0.01,   # L2 regularization strength
    use_nesterov=False,  # Nesterov momentum (not supported yet)
)

# Training with gradient accumulation
optimizer.set_grad_accum_steps(4)

for batch in dataloader:
    loss = model(batch) / 4  # Scale for accumulation
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
- __init__(param_dict: dict, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-08, weight_decay=0.01, use_nesterov=False) None [source]
Initialize SparseAdamWTF optimizer with specified hyperparameters.
- Parameters:
param_dict (dict) – Dictionary of sparse parameters to optimize. Keys are parameter names, values are parameter tensors (typically HashTables).
lr (float, optional) – Learning rate. Should match TensorFlow training settings for compatibility. Defaults to 1e-3.
beta1 (float, optional) – Exponential decay rate for first moment estimates. Should be in [0, 1). TensorFlow default is 0.9. Defaults to 0.9.
beta2 (float, optional) – Exponential decay rate for second moment estimates. Should be in [0, 1). TensorFlow default is 0.999. Defaults to 0.999.
eps (float, optional) – Small constant added to denominator for numerical stability. TensorFlow default is 1e-7, but 1e-8 is also common. Defaults to 1e-8.
weight_decay (float, optional) – Weight decay coefficient (L2 regularization). Applied directly to parameters (decoupled weight decay). Defaults to 1e-2.
use_nesterov (bool, optional) – Whether to use Nesterov momentum variant. Currently not supported and will raise ValueError if True. Defaults to False.
- Raises:
ValueError – If use_nesterov is True (not currently supported).
Note
The param_dict should contain HashTable parameters that support sparse gradient updates. This optimizer is specifically designed for compatibility with TensorFlow-trained models and provides numerically equivalent behavior for seamless model migration.
- set_grad_accum_steps(steps: int)
Set the number of gradient accumulation steps.
This method configures gradient accumulation, which allows training with effectively larger batch sizes by accumulating gradients over multiple forward passes before updating parameters.
- Parameters:
steps (int) – Number of steps to accumulate gradients before performing a parameter update. Must be positive.
- step()
Perform a single optimization step with gradient accumulation support.
This method implements gradient accumulation by only performing the actual parameter update every _grad_accum_steps steps. It maintains an internal step counter and delegates the actual optimization to the underlying C++ implementation.
Note
When gradient accumulation is enabled (_grad_accum_steps > 1), this method only performs the actual parameter update every _grad_accum_steps calls. The learning rate is automatically handled by the implementation.
- zero_grad()
Clear gradients with gradient accumulation support.
This method clears parameter gradients, but only when gradient accumulation steps are completed. This ensures that gradients are properly accumulated across multiple forward passes before being cleared.
Note
When gradient accumulation is enabled, this method only clears gradients every _grad_accum_steps calls, synchronized with the step() method.
AdamWTF
- class recis.optim.adamw_tf.AdamWTF(params: Iterable[Tensor] | Iterable[Dict[str, Any]], lr: float | Tensor = 0.001, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0.01, use_nesterov: bool = False, *, maximize: bool = False, fuse: bool = True)[source]
Implements the AdamW algorithm with TensorFlow-compatible behavior.
- __init__(params: Iterable[Tensor] | Iterable[Dict[str, Any]], lr: float | Tensor = 0.001, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-08, weight_decay: float = 0.01, use_nesterov: bool = False, *, maximize: bool = False, fuse: bool = True)[source]
Initialize AdamWTF optimizer with specified hyperparameters.
- Parameters:
params (ParamsT) – Iterable of parameters to optimize or dicts defining parameter groups.
lr (Union[float, Tensor], optional) – Learning rate. Can be a scalar or tensor for dynamic learning rates. Defaults to 1e-3.
betas (Tuple[float, float], optional) – Coefficients used for computing running averages of gradient and its square. First value is beta1 (momentum), second is beta2 (variance). Defaults to (0.9, 0.999).
eps (float, optional) – Term added to the denominator to improve numerical stability. Defaults to 1e-8.
weight_decay (float, optional) – Weight decay coefficient (L2 penalty). Applied directly to parameters, not gradients. Defaults to 1e-2.
use_nesterov (bool, optional) – Whether to use Nesterov momentum. Provides faster convergence in some cases. Defaults to False.
maximize (bool, optional) – Maximize the objective with respect to the params, instead of minimizing. Defaults to False.
fuse (bool, optional) – Whether to use fused kernel implementation for better performance. Defaults to True.
- Raises:
ValueError – If any hyperparameter is outside valid range.
Note
The fused implementation provides significant speedup on CUDA devices but may have slightly different numerical behavior compared to the non-fused version.
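Since AdamWTF follows the standard torch.optim interface documented above, a usage sketch for dense parameters might look like the following; the model and data are placeholders, and the import path simply mirrors the module name shown in the class signature.

import torch
from recis.optim.adamw_tf import AdamWTF  # module path as documented above

model = torch.nn.Linear(16, 1)  # stand-in dense module
optimizer = AdamWTF(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.01,
    fuse=True,  # fused kernels; set False if exact non-fused numerics are needed
)

x, y = torch.randn(8, 16), torch.randn(8, 1)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()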
- step(closure=None)[source]
Perform a single optimization step.
This method executes one iteration of the AdamW optimization algorithm, updating all parameters based on their gradients and the current optimizer state.
- Parameters:
closure (Callable, optional) – A closure that reevaluates the model and returns the loss. Used for algorithms that require multiple function evaluations per step.
- Returns:
The loss value if closure is provided, None otherwise.
- Return type:
Optional[float]
Note
This method automatically handles parameter grouping, state initialization, and delegates to the appropriate implementation (fused or non-fused) based on the optimizer configuration.
- zero_grad(set_to_none: bool = True) None
Resets the gradients of all optimized torch.Tensor objects.
- Parameters:
set_to_none (bool) – Instead of setting gradients to zero, set them to None. This generally has a lower memory footprint and can modestly improve performance. However, it changes certain behaviors. For example: 1. When the user tries to access a gradient and perform manual ops on it, a None attribute and a Tensor full of 0s behave differently. 2. If the user requests zero_grad(set_to_none=True) followed by a backward pass, .grad attributes are guaranteed to be None for params that did not receive a gradient. 3. torch.optim optimizers behave differently when the gradient is 0 versus None (in one case the step is performed with a gradient of 0, in the other the step is skipped altogether).
Usage Guide
Basic Usage Flow
Parameter Separation:
from recis.nn.modules.hashtable import filter_out_sparse_param

# Separate sparse and dense parameters
sparse_params = filter_out_sparse_param(model)
Create Optimizers:
from recis.optim import SparseAdamW
from torch.optim import AdamW

# Sparse parameter optimizer
sparse_optimizer = SparseAdamW(sparse_params, lr=0.001)

# Dense parameter optimizer
dense_optimizer = AdamW(model.parameters(), lr=0.001)
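Training Loop:
A minimal combined loop might look like the sketch below; model and dataloader are placeholders, and the gradient accumulation settings from the earlier examples can be added as needed.

# Train both parameter groups with their respective optimizers
for batch in dataloader:
    loss = model(batch)
    loss.backward()

    sparse_optimizer.step()
    dense_optimizer.step()

    sparse_optimizer.zero_grad()
    dense_optimizer.zero_grad()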
Performance Optimization Recommendations
Parameter Tuning
Learning Rate Settings:
# Sparse parameters usually need larger learning rates
sparse_optimizer = SparseAdamW(sparse_params, lr=0.01)
dense_optimizer = AdamW(model.parameters(), lr=0.001)
Weight Decay:
# Adjust weight decay based on model size
sparse_optimizer = SparseAdamW(
    sparse_params,
    lr=0.001,
    weight_decay=0.01,  # Larger models can use larger weight decay
)
Beta Parameter Adjustment:
# For sparse updates, beta parameters can be adjusted
sparse_optimizer = SparseAdamW(
    sparse_params,
    lr=0.001,
    beta1=0.9,    # First moment estimate decay
    beta2=0.999,  # Second moment estimate decay
)
Frequently Asked Questions
Q: What’s the difference between sparse optimizers and regular optimizers?
A: Sparse optimizers are specifically designed for HashTable parameters, with the following characteristics:
- Only update parameters that received gradients
- Support dynamic parameter expansion
- More efficient memory usage
- Compatible with distributed training
Q: How to choose the right optimizer?
A: Selection recommendations:
- For sparse embeddings: use SparseAdamW
- For dense layers: use standard AdamW
- For TensorFlow alignment: use the TF-compatible optimizers (SparseAdamWTF, AdamWTF)