Source code for recis.optim.sparse_adamw_tf

import torch

from recis.optim.sparse_optim import SparseOptimizer


class SparseAdamWTF(SparseOptimizer):
    """Sparse AdamW optimizer with a TensorFlow-style implementation for
    efficient sparse parameter optimization.

    This class implements the AdamW optimization algorithm with
    TensorFlow-compatible behavior, optimized for sparse parameters in
    recommendation systems. It extends the SparseOptimizer base class and uses
    RecIS's C++ implementation for maximum performance.

    Example:
        Creating and using SparseAdamWTF:

        .. code-block:: python

            # Initialize with custom hyperparameters for TF compatibility
            optimizer = SparseAdamWTF(
                param_dict=sparse_parameters,
                lr=0.001,  # Learning rate
                beta1=0.9,  # First moment decay rate
                beta2=0.999,  # Second moment decay rate
                eps=1e-8,  # Numerical stability
                weight_decay=0.01,  # L2 regularization strength
                use_nesterov=False,  # Nesterov momentum (not supported yet)
            )

            # Training with gradient accumulation
            optimizer.set_grad_accum_steps(4)
            for batch in dataloader:
                loss = model(batch) / 4  # Scale for accumulation
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
    """
    def __init__(
        self,
        param_dict: dict,
        lr=1e-3,
        beta1=0.9,
        beta2=0.999,
        eps=1e-8,
        weight_decay=1e-2,
        use_nesterov=False,
    ) -> None:
        """Initialize the SparseAdamWTF optimizer with the specified hyperparameters.

        Args:
            param_dict (dict): Dictionary of sparse parameters to optimize. Keys
                are parameter names, values are parameter tensors (typically
                HashTables).
            lr (float, optional): Learning rate. Should match the TensorFlow
                training settings for compatibility. Defaults to 1e-3.
            beta1 (float, optional): Exponential decay rate for the first moment
                estimates. Should be in [0, 1). TensorFlow default is 0.9.
                Defaults to 0.9.
            beta2 (float, optional): Exponential decay rate for the second moment
                estimates. Should be in [0, 1). TensorFlow default is 0.999.
                Defaults to 0.999.
            eps (float, optional): Small constant added to the denominator for
                numerical stability. TensorFlow's default is 1e-7, but 1e-8 is
                also common. Defaults to 1e-8.
            weight_decay (float, optional): Weight decay coefficient (L2
                regularization). Applied directly to the parameters (decoupled
                weight decay). Defaults to 1e-2.
            use_nesterov (bool, optional): Whether to use the Nesterov momentum
                variant. Currently not supported; a ValueError is raised if True.
                Defaults to False.

        Raises:
            ValueError: If use_nesterov is True (not currently supported).

        Note:
            The param_dict should contain HashTable parameters that support
            sparse gradient updates. This optimizer is designed for compatibility
            with TensorFlow-trained models and provides numerically equivalent
            behavior for seamless model migration.
        """
        super().__init__(lr=lr)
        self._lr = lr
        self._beta1 = beta1
        self._beta2 = beta2
        self._eps = eps
        self._weight_decay = weight_decay
        if use_nesterov:
            raise ValueError("use_nesterov is not supported yet")
        self._use_nesterov = use_nesterov
        self._imp = torch.classes.recis.SparseAdamWTF.make(
            param_dict,
            self._lr,
            self._beta1,
            self._beta2,
            self._eps,
            self._weight_decay,
            self._use_nesterov,
        )
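

# ---------------------------------------------------------------------------
# Illustrative reference only (not part of recis): a minimal, dense PyTorch
# sketch of the decoupled-weight-decay AdamW update that the fused C++ kernel
# created by torch.classes.recis.SparseAdamWTF.make is expected to apply to
# each touched HashTable row. The function name `adamw_tf_reference_step` and
# the explicit `step` argument are hypothetical and exist only to make the
# math readable; the exact epsilon placement in the TensorFlow-style kernel
# may differ slightly from this generic AdamW form.
def adamw_tf_reference_step(
    param, grad, exp_avg, exp_avg_sq, step,
    lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=1e-2,
):
    # Decoupled weight decay: shrink the parameter directly, independent of
    # the adaptive gradient term.
    param.mul_(1 - lr * weight_decay)
    # Exponential moving averages of the gradient and its square.
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    # Bias correction for both moment estimates, then the parameter update.
    bias_correction1 = 1 - beta1**step
    bias_correction2 = 1 - beta2**step
    denom = (exp_avg_sq / bias_correction2).sqrt().add_(eps)
    param.addcdiv_(exp_avg, denom, value=-(lr / bias_correction1))
    return param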