import math
from typing import List, Optional, Tuple, Union
import torch
from torch import Tensor
from torch.optim.optimizer import (
Optimizer,
ParamsT,
_get_value,
_use_grad_for_differentiable,
)
def maybe_get_value(inp):
    """Return ``inp.item()`` if ``inp`` is a tensor, otherwise return ``inp`` unchanged."""
    if isinstance(inp, torch.Tensor):
        return inp.item()
    return inp
class AdamWTF(Optimizer):
"""AdamW optimizer with TensorFlow-style implementation for dense parameters.
This class implements the AdamW optimization algorithm with TensorFlow-compatible
behavior for dense parameter optimization. It extends PyTorch's Optimizer base
class and provides efficient optimization for standard neural network parameters.
The AdamW algorithm combines adaptive learning rates from Adam with proper
weight decay regularization. Unlike the original Adam optimizer, AdamW applies
weight decay directly to the parameters rather than adding it to the gradients,
which provides better regularization behavior, especially for transformer models.
In this implementation the decay factor is ``weight_decay`` itself rather than
``lr * weight_decay``, matching the TensorFlow-style update.
Mathematical formulation:
    .. math::
        m_t &= \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
        v_t &= \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \\
        \alpha_t &= \mathrm{lr} \cdot \sqrt{1 - \beta_2^t} / (1 - \beta_1^t) \\
        \theta_t &= (1 - \lambda)\,\theta_{t-1} - \alpha_t \cdot m_t / (\sqrt{v_t} + \epsilon)

    Where:
        - :math:`\theta`: parameters
        - :math:`g`: gradients
        - :math:`m`: first moment estimate (momentum)
        - :math:`v`: second moment estimate (variance)
        - :math:`\beta_1, \beta_2`: exponential decay rates
        - :math:`\mathrm{lr}`: learning rate
        - :math:`\alpha_t`: bias-corrected step size (TensorFlow-style correction)
        - :math:`\lambda`: weight decay coefficient, applied directly to the parameters and not scaled by the learning rate
        - :math:`\epsilon`: numerical stability constant
Key features:
- TensorFlow-compatible behavior and numerical precision
- Proper weight decay implementation (decoupled from gradients)
    - ``use_nesterov`` flag for Nesterov momentum (not yet supported by the current kernels)
- Fused kernel optimization for better performance
- Support for both scalar and tensor learning rates
Example:
Basic usage for transformer training:
.. code-block:: python
# Initialize for transformer model
optimizer = AdamWTF(
model.parameters(),
lr=0.001, # Learning rate
betas=(0.9, 0.999), # Adam momentum parameters
eps=1e-8, # Numerical stability
weight_decay=0.01, # L2 regularization strength
use_nesterov=False, # Standard Adam behavior
)
# Training loop
for batch in dataloader:
optimizer.zero_grad()
loss = model(batch)
loss.backward()
optimizer.step()
Advanced configuration for different model types:
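        An illustrative sketch only (``model.encoder`` and ``model.head`` are
        hypothetical attribute names): parameter groups may carry their own
        ``weight_decay``, and ``lr`` may be passed as a tensor whose current
        value is read on every step.

        .. code-block:: python

            optimizer = AdamWTF(
                [
                    {"params": model.encoder.parameters(), "weight_decay": 0.05},
                    {"params": model.head.parameters(), "weight_decay": 0.0},
                ],
                lr=torch.tensor(3e-4),  # tensor learning rate
                betas=(0.9, 0.98),
                fuse=True,
            )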
"""
def __init__(
self,
params: ParamsT,
lr: Union[float, Tensor] = 1e-3,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-8,
weight_decay: float = 1e-2,
use_nesterov: bool = False,
*,
maximize: bool = False,
fuse: bool = True,
):
"""Initialize AdamWTF optimizer with specified hyperparameters.
Args:
params (ParamsT): Iterable of parameters to optimize or dicts defining
parameter groups.
lr (Union[float, Tensor], optional): Learning rate. Can be a scalar
or tensor for dynamic learning rates. Defaults to 1e-3.
betas (Tuple[float, float], optional): Coefficients used for computing
running averages of gradient and its square. First value is beta1
(momentum), second is beta2 (variance). Defaults to (0.9, 0.999).
eps (float, optional): Term added to the denominator to improve
numerical stability. Defaults to 1e-8.
weight_decay (float, optional): Weight decay coefficient (L2 penalty).
Applied directly to parameters, not gradients. Defaults to 1e-2.
            use_nesterov (bool, optional): Whether to use Nesterov momentum. Note that
                the current fused and single-tensor kernels only accept ``False``.
                Defaults to False.
maximize (bool, optional): Maximize the objective with respect to the
params, instead of minimizing. Defaults to False.
fuse (bool, optional): Whether to use fused kernel implementation
for better performance. Defaults to True.
Raises:
ValueError: If any hyperparameter is outside valid range.
Note:
The fused implementation provides significant speedup on CUDA devices
but may have slightly different numerical behavior compared to the
non-fused version.
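            For example, the non-fused path can be forced when reproducibility
            against the reference single-tensor update matters more than speed
            (``model`` is a hypothetical placeholder):

            .. code-block:: python

                optimizer = AdamWTF(model.parameters(), lr=1e-3, fuse=False)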
"""
# Validate hyperparameters
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps}")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
use_nesterov=use_nesterov,
maximize=maximize,
fuse=fuse,
differentiable=False,
)
super().__init__(params, defaults)
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("use_nesterov", False)
group.setdefault("maximize", False)
group.setdefault("fuse", True)
state_values = list(self.state.values())
step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
state_values[0]["step"]
)
if not step_is_tensor:
for s in state_values:
s["step"] = torch.tensor(float(s["step"]), dtype=torch.float32)
self.defaults["differentiable"] = False
def _init_group(
self,
group,
params_with_grad,
grads,
use_nesterov,
exp_avgs,
exp_avg_sqs,
state_steps,
):
for p in group["params"]:
if p.grad is None:
continue
assert not torch.is_complex(p)
params_with_grad.append(p)
if p.grad.is_sparse:
raise RuntimeError("AdamWTF does not support sparse gradients")
grads.append(p.grad)
state = self.state[p]
# State initialization
if len(state) == 0:
# note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off.
# This is because kernel launches are costly on CUDA and XLA.
state["step"] = torch.tensor(0.0, dtype=torch.float32)
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
exp_avgs.append(state["exp_avg"])
exp_avg_sqs.append(state["exp_avg_sq"])
state_steps.append(state["step"])
return False
@_use_grad_for_differentiable
def step(self, closure=None):
"""Perform a single optimization step.
This method executes one iteration of the AdamW optimization algorithm,
updating all parameters based on their gradients and the current
optimizer state.
Args:
closure (Callable, optional): A closure that reevaluates the model
and returns the loss. Used for algorithms that require multiple
function evaluations per step.
Returns:
Optional[float]: The loss value if closure is provided, None otherwise.
Note:
This method automatically handles parameter grouping, state
initialization, and delegates to the appropriate implementation
(fused or non-fused) based on the optimizer configuration.
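        Example:
            A minimal illustrative sketch of driving the step with a closure
            (``model``, ``criterion``, ``inputs``, and ``targets`` are
            hypothetical placeholders):

            .. code-block:: python

                def closure():
                    optimizer.zero_grad()
                    loss = criterion(model(inputs), targets)
                    loss.backward()
                    return loss

                loss = optimizer.step(closure)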
"""
# Check for CUDA graph capture compatibility
self._cuda_graph_capture_health_check()
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_with_grad = []
grads = []
exp_avgs = []
exp_avg_sqs = []
state_steps = []
use_nesterov = group["use_nesterov"]
beta1, beta2 = group["betas"]
self._init_group(
group,
params_with_grad,
grads,
use_nesterov,
exp_avgs,
exp_avg_sqs,
state_steps,
)
adamwtf(
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
state_steps,
use_nesterov=use_nesterov,
beta1=beta1,
beta2=beta2,
lr=group["lr"],
weight_decay=group["weight_decay"],
eps=group["eps"],
maximize=group["maximize"],
fuse=group["fuse"],
grad_scale=getattr(self, "grad_scale", None),
found_inf=getattr(self, "found_inf", None),
)
return loss
def adamwtf(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
state_steps: List[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
grad_scale: Optional[Tensor] = None,
found_inf: Optional[Tensor] = None,
*,
use_nesterov: bool,
beta1: float,
beta2: float,
lr: Union[float, Tensor],
weight_decay: float,
eps: float,
maximize: bool,
fuse: bool,
):
r"""Functional API that performs AdamWTF algorithm computation.
    See :class:`AdamWTF` for details.
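    A minimal illustrative sketch (hypothetical single-parameter setup,
    non-fused path) of calling this functional form with externally managed
    state:

    .. code-block:: python

        p = torch.randn(3, requires_grad=True)
        p.grad = torch.randn(3)
        exp_avg = torch.zeros_like(p)
        exp_avg_sq = torch.zeros_like(p)
        step = torch.tensor(0.0)

        with torch.no_grad():
            adamwtf(
                [p], [p.grad], [exp_avg], [exp_avg_sq], [step],
                use_nesterov=False, beta1=0.9, beta2=0.999,
                lr=1e-3, weight_decay=0.01, eps=1e-8,
                maximize=False, fuse=False,
            )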
"""
if not torch._utils.is_compiling() and not all(
isinstance(t, torch.Tensor) for t in state_steps
):
raise RuntimeError(
"API has changed, `state_steps` argument must contain a list of singleton tensors"
)
if fuse:
func = _fuse_tensor_adamwtf
else:
func = _single_tensor_adamwtf
func(
params,
grads,
exp_avgs,
exp_avg_sqs,
state_steps,
use_nesterov=use_nesterov,
beta1=beta1,
beta2=beta2,
lr=lr,
weight_decay=weight_decay,
eps=eps,
maximize=maximize,
grad_scale=grad_scale,
found_inf=found_inf,
)
def _fuse_tensor_adamwtf(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
state_steps: List[Tensor],
grad_scale: Optional[Tensor],
found_inf: Optional[Tensor],
*,
use_nesterov: bool,
beta1: float,
beta2: float,
lr: Union[Tensor, float],
weight_decay: float,
eps: float,
maximize: bool,
):
    assert grad_scale is None and found_inf is None
    assert not use_nesterov, "Nesterov momentum is not supported by the fused AdamWTF kernel"
if torch.jit.is_scripting():
# this assert is due to JIT being dumb and not realizing that the ops below
# have overloads to handle both float and Tensor lrs, so we just assert it's
# a float since most people using JIT are using floats
assert isinstance(lr, float)
for i, param in enumerate(params):
grad = grads[i] if not maximize else -grads[i]
exp_avg = exp_avgs[i]
exp_avg_sq = exp_avg_sqs[i]
step_t = state_steps[i]
assert not torch.is_complex(param)
step_t += 1
step = _get_value(step_t)
lr_scalar = maybe_get_value(lr)
        # Decoupled weight decay (TensorFlow-style: scaled by weight_decay only, not by lr)
        param.mul_(1 - weight_decay)
        # Fused kernel: updates exp_avg / exp_avg_sq in place and applies the
        # bias-corrected parameter update in a single op
torch.ops.recis.adam_tf_apply(
param, grad, exp_avg, exp_avg_sq, step, lr_scalar, beta1, beta2, eps
)
def _single_tensor_adamwtf(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
state_steps: List[Tensor],
grad_scale: Optional[Tensor],
found_inf: Optional[Tensor],
*,
use_nesterov: bool,
beta1: float,
beta2: float,
lr: Union[Tensor, float],
weight_decay: float,
eps: float,
maximize: bool,
):
    assert grad_scale is None and found_inf is None
    assert not use_nesterov, "Nesterov momentum is not supported by the single-tensor AdamWTF implementation"
if torch.jit.is_scripting():
# this assert is due to JIT being dumb and not realizing that the ops below
# have overloads to handle both float and Tensor lrs, so we just assert it's
# a float since most people using JIT are using floats
assert isinstance(lr, float)
for i, param in enumerate(params):
grad = grads[i] if not maximize else -grads[i]
exp_avg = exp_avgs[i]
exp_avg_sq = exp_avg_sqs[i]
step_t = state_steps[i]
assert not torch.is_complex(param)
step_t += 1
        # Decoupled weight decay (TensorFlow-style: scaled by weight_decay only, not by lr)
param.mul_(1 - weight_decay)
# Decay the first and second moment running average coefficient
exp_avg.lerp_(grad, 1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
step = _get_value(step_t)
lr_scalar = maybe_get_value(lr)
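        # TensorFlow-style bias correction: both corrections are folded into the
        # step size, so the denominator below stays sqrt(v_t) + eps (eps itself
        # is not bias-corrected)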
b1_power = beta1**step
b2_power = beta2**step
bias_correction1 = 1 - b1_power
bias_correction2 = 1 - b2_power
alpha = lr_scalar / bias_correction1 * math.sqrt(bias_correction2)
denom = exp_avg_sq.sqrt().add_(eps)
param.addcdiv_(exp_avg, denom, value=-alpha)
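

# The block below is an illustrative smoke test, not part of the optimizer API.
# It assumes only this module and a CPU build of PyTorch, and it uses
# ``fuse=False`` because the fused path requires the custom
# ``torch.ops.recis.adam_tf_apply`` kernel to be available.
if __name__ == "__main__":
    torch.manual_seed(0)
    model = torch.nn.Linear(4, 2)
    optimizer = AdamWTF(model.parameters(), lr=1e-3, weight_decay=0.01, fuse=False)
    inputs, targets = torch.randn(8, 4), torch.randn(8, 2)
    for _ in range(3):
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        loss.backward()
        optimizer.step()
        print(f"loss={loss.item():.6f}")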