import torch
from recis.optim.sparse_optim import SparseOptimizer
class SparseAdamWTF(SparseOptimizer):
"""Sparse AdamW optimizer with TensorFlow-style implementation for efficient sparse parameter optimization.
This class implements the AdamW optimization algorithm with TensorFlow-compatible
behavior specifically optimized for sparse parameters in recommendation systems.
It extends the SparseOptimizer base class and uses RecIS's C++ implementation
for maximum performance.
Example:
Creating and using SparseAdamWTF:
.. code-block:: python
# Initialize with custom hyperparameters for TF compatibility
optimizer = SparseAdamWTF(
param_dict=sparse_parameters,
lr=0.001, # Learning rate
beta1=0.9, # First moment decay rate
beta2=0.999, # Second moment decay rate
eps=1e-8, # Numerical stability
weight_decay=0.01, # L2 regularization strength
use_nesterov=False, # Nesterov momentum (not supported yet)
)
# Training with gradient accumulation
optimizer.set_grad_accum_steps(4)
for batch in dataloader:
loss = model(batch) / 4 # Scale for accumulation
loss.backward()
optimizer.step()
optimizer.zero_grad()
"""
    def __init__(
        self,
        param_dict: dict,
        lr=1e-3,
        beta1=0.9,
        beta2=0.999,
        eps=1e-8,
        weight_decay=1e-2,
        use_nesterov=False,
    ) -> None:
"""Initialize SparseAdamWTF optimizer with specified hyperparameters.
Args:
param_dict (dict): Dictionary of sparse parameters to optimize.
Keys are parameter names, values are parameter tensors (typically HashTables).
lr (float, optional): Learning rate. Should match TensorFlow training
settings for compatibility. Defaults to 1e-3.
beta1 (float, optional): Exponential decay rate for first moment estimates.
Should be in [0, 1). TensorFlow default is 0.9. Defaults to 0.9.
beta2 (float, optional): Exponential decay rate for second moment estimates.
Should be in [0, 1). TensorFlow default is 0.999. Defaults to 0.999.
eps (float, optional): Small constant added to denominator for numerical
stability. TensorFlow default is 1e-7, but 1e-8 is also common.
Defaults to 1e-8.
weight_decay (float, optional): Weight decay coefficient (L2 regularization).
Applied directly to parameters (decoupled weight decay). Defaults to 1e-2.
use_nesterov (bool, optional): Whether to use Nesterov momentum variant.
Currently not supported and will raise ValueError if True. Defaults to False.
Raises:
ValueError: If use_nesterov is True (not currently supported).
Note:
The param_dict should contain HashTable parameters that support
sparse gradient updates. This optimizer is specifically designed
for compatibility with TensorFlow-trained models and provides
numerically equivalent behavior for seamless model migration.
"""
        super().__init__(lr=lr)
        self._lr = lr
        self._beta1 = beta1
        self._beta2 = beta2
        self._eps = eps
        self._weight_decay = weight_decay
        if use_nesterov:
            raise ValueError("use_nesterov is not supported yet")
        self._use_nesterov = use_nesterov
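        # Bind the underlying RecIS C++ implementation (a torch custom class)
        # that performs the actual sparse parameter updates.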
        self._imp = torch.classes.recis.SparseAdamWTF.make(
            param_dict,
            self._lr,
            self._beta1,
            self._beta2,
            self._eps,
            self._weight_decay,
            self._use_nesterov,
        )
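

# ---------------------------------------------------------------------------
# Illustrative reference only (not part of the library): a minimal dense
# sketch of one TensorFlow-style AdamW step, assuming the common formulation
# in which the bias correction is folded into the learning rate and ``eps``
# is added outside the square root. The real update is performed by the C++
# kernel bound above and may differ in detail (for example, in how weight
# decay interacts with the learning rate).
def _reference_adamw_tf_step(
    param, grad, m, v, step,
    lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=1e-2,
):
    """Apply one dense AdamW-TF step in place and return the new step count."""
    step += 1
    m.mul_(beta1).add_(grad, alpha=1.0 - beta1)  # first moment estimate
    v.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)  # second moment estimate
    # TensorFlow folds the bias correction into an effective learning rate.
    lr_t = lr * (1.0 - beta2**step) ** 0.5 / (1.0 - beta1**step)
    param.mul_(1.0 - lr * weight_decay)  # decoupled weight decay (assumed lr-scaled)
    param.add_(m / (v.sqrt() + eps), alpha=-lr_t)  # eps added outside the square root
    return step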