from dataclasses import dataclass
from datetime import timedelta
from typing import Callable, List, Optional, Tuple
import torch
import torch.distributed as dist
from accelerate import (
Accelerator,
DistributedDataParallelKwargs,
InitProcessGroupKwargs,
)
from torch import nn
from torch.utils.data import Dataset
from recis.framework.checkpoint_manager import ExtraFields, Saver, SaverOptions
from recis.framework.metrics import get_global_metrics
from recis.hooks import Hook, LoggerHook
from recis.hooks.checkpoint_hooks import (
CheckpointLoadArguments,
CheckpointLoadHook,
CheckpointSaveArguments,
CheckpointSaveHook,
)
from recis.hooks.metric_report_hook import MetricReportHook
from recis.metrics.metric_reporter import MODEL_FWD_NAME, MetricReporter
from recis.optim import sparse_optim
from recis.utils.data_utils import copy_data_to_device
from recis.utils.logger import Logger
logger = Logger(__name__)
@dataclass
class TrainingArguments:
"""Configuration class for training parameters.
This dataclass contains all the configuration parameters needed for training,
including optimization settings, logging intervals, and checkpoint management.
Attributes:
gradient_accumulation_steps (int): Number of steps to accumulate gradients
before performing an optimizer step. Defaults to 1.
output_dir (str): Directory where checkpoints and logs will be saved.
Defaults to "output_dir".
model_bank (Optional[list]): List of model bank paths for initialization.
Defaults to None.
log_steps (int): Number of training steps between logging. Defaults to 100.
        train_steps (Optional[int]): Maximum number of training steps per epoch
            (and per IO window in window mode). If None, training runs until the
            dataset is exhausted. Defaults to None.
train_epoch (Optional[int]): Number of training epochs. Defaults to 1.
eval_steps (Optional[int]): Number of evaluation steps. If None, evaluates
on full dataset. Defaults to None.
save_steps (Optional[int]): Number of steps between checkpoint saves.
Defaults to 1000.
max_to_keep (int): Maximum number of checkpoints to keep. Defaults to 5.
save_concurrency_per_rank (int): Number of concurrent save operations per rank.
Defaults to 4.
        save_every_n_windows (int): Save a checkpoint every N IO windows
            (window-based datasets only). Defaults to 1.
        save_every_n_epochs (Optional[int]): Save a checkpoint every N epochs.
            Defaults to None.
        save_end (bool): Whether to save a checkpoint at the end of training.
            Defaults to True.
        load_update_steps (Optional[int]): Reload the dynamic model bank every
            N steps. Defaults to None.
        load_update_windows (Optional[int]): Reload the dynamic model bank every
            N IO windows. Defaults to 1.
        load_update_epochs (Optional[int]): Reload the dynamic model bank every
            N epochs. Defaults to None.
        params_not_save (Optional[list]): Names of parameters not to save.
            Defaults to None.
        save_filter_fn (Optional[Callable]): Function to filter checkpoint blocks.
            Defaults to None.
        saver_option (Optional[SaverOptions]): Options for the checkpoint saver.
            Defaults to None.
        ckpt_save_arg (Optional[CheckpointSaveArguments]): Arguments for checkpoint
            saving. Defaults to None.
        ckpt_load_arg (Optional[CheckpointLoadArguments]): Arguments for checkpoint
            loading. Defaults to None.
        mixed_precision (Optional[str]): Mixed precision training mode. Only
            "bf16" and "fp16" are supported. Defaults to None.
        window_iter (Optional[int]): Maximum number of IO windows to iterate per
            window-based loop. If None, windows are consumed until the dataset
            is exhausted. Defaults to None.
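
    Example:
        A minimal sketch with placeholder values; adjust the output path and
        step counts to your own job:

        .. code-block:: python

            from recis.framework import TrainingArguments

            training_args = TrainingArguments(
                output_dir="./checkpoints",
                train_epoch=1,
                log_steps=100,
                save_steps=1000,
                gradient_accumulation_steps=2,
                mixed_precision="bf16",
            )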
"""
gradient_accumulation_steps: int = 1
output_dir: str = "output_dir"
model_bank: Optional[list] = None
log_steps: int = 100
train_steps: Optional[int] = None
train_epoch: Optional[int] = 1
eval_steps: Optional[int] = None
save_steps: Optional[int] = 1000
max_to_keep: int = 5
save_concurrency_per_rank: int = 4
save_every_n_windows: int = 1
save_every_n_epochs: Optional[int] = None
save_end: Optional[bool] = True
load_update_steps: Optional[int] = None
load_update_windows: Optional[int] = 1
load_update_epochs: Optional[int] = None
params_not_save: Optional[List[str]] = None
save_filter_fn: Optional[Callable] = None
saver_option: Optional[SaverOptions] = None
ckpt_save_arg: Optional[CheckpointSaveArguments] = None
ckpt_load_arg: Optional[CheckpointLoadArguments] = None
mixed_precision: Optional[str] = None
window_iter: Optional[int] = None
class Trainer:
"""Main training orchestrator with distributed training and checkpoint management.
The Trainer class provides a comprehensive training framework that handles:
- Distributed training coordination using Accelerate
- Automatic checkpoint saving and loading
- Training and evaluation loops with metrics tracking
- Hook system for extensible training workflows
- Support for both dense and sparse optimizers
Attributes:
args (TrainingArguments): Training configuration parameters.
hooks (List[Hook]): List of training hooks for extensibility.
train_dataset (Optional[Dataset]): Training dataset.
eval_dataset (Optional[Dataset]): Evaluation dataset.
model (nn.Module): The model to train.
dense_optimizer (torch.optim.Optimizer): Dense parameter optimizer.
dense_lr_scheduler: Learning rate scheduler for dense optimizer.
sparse_optimizer (Optional[sparse_optim.SparseOptimizer]): Sparse parameter optimizer.
data_to_cuda (bool): Whether to automatically move data to CUDA.
accelerator (Accelerator): Accelerate instance for distributed training.
Example:
.. code-block:: python
from recis.framework import Trainer, TrainingArguments
from recis.optim import SparseAdamW
from torch.optim import AdamW
# Set training arguments
training_args = TrainingArguments(
output_dir="./checkpoints",
train_steps=10000,
eval_steps=1000,
save_steps=2000,
log_steps=100,
gradient_accumulation_steps=4,
)
# split sparse params
from recis.nn.modules.hashtable import filter_out_sparse_param
sparse_params = filter_out_sparse_param(model)
# create optimizers
sparse_optimizer = SparseAdamW(sparse_params, lr=0.001)
dense_optimizer = AdamW(model.parameters(), lr=0.001)
# create trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
dense_optimizers=(dense_optimizer, None),
sparse_optimizer=sparse_optimizer,
data_to_cuda=True,
)
# train the model
trainer.train()
"""
def __init__(
self,
model: Optional[nn.Module] = None,
args: TrainingArguments = None,
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Dataset] = None,
hooks: Optional[List[Hook]] = None,
dense_optimizers: Tuple[
torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR
] = (None, None),
sparse_optimizer: Optional[sparse_optim.SparseOptimizer] = None,
data_to_cuda: bool = False,
saver: Optional[Saver] = None,
**kwargs,
) -> None:
"""Initialize the Trainer with model, datasets, and training configuration.
Args:
model (Optional[nn.Module]): The model to train.
args (TrainingArguments): Training configuration. If None, uses default.
train_dataset (Optional[Dataset]): Training dataset.
eval_dataset (Optional[Dataset]): Evaluation dataset.
hooks (Optional[List[Hook]]): List of training hooks for extensibility.
dense_optimizers (Tuple): Tuple of (optimizer, lr_scheduler) for dense parameters.
sparse_optimizer (Optional[sparse_optim.SparseOptimizer]): Optimizer for sparse parameters.
            data_to_cuda (bool): Whether to automatically move data to CUDA.
                Defaults to False.
            saver (Optional[Saver]): Pre-built checkpoint saver to reuse. If None,
                one is created from ``args``. Defaults to None.
            **kwargs: Additional arguments passed to Accelerator.
"""
if hooks is None:
hooks = []
if args is None:
args = TrainingArguments()
self.args = args
self.hooks = hooks
self.train_dataset = train_dataset
self.eval_dataset = eval_dataset
self.model = model
self.dense_optimizer = dense_optimizers[0]
self.dense_lr_scheduler = dense_optimizers[1]
self.sparse_optimizer = sparse_optimizer
self.data_to_cuda = data_to_cuda
self.mixed_precision = args.mixed_precision
if self.mixed_precision is not None:
            assert self.mixed_precision in ["bf16", "fp16"], (
                "mixed_precision must be 'bf16' or 'fp16'"
            )
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=1800))
self.accelerator = Accelerator(
kwargs_handlers=[ddp_kwargs, init_kwargs],
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=self.mixed_precision,
**kwargs,
)
self.gradient_accumulation_steps = args.gradient_accumulation_steps
(
self.model,
self.dense_optimizer,
self.dense_lr_scheduler,
) = self.accelerator.prepare(
self.model, self.dense_optimizer, self.dense_lr_scheduler
)
MetricReporter.report_forward(self.model, MODEL_FWD_NAME)
if self.sparse_optimizer is not None:
            # Set sparse gradient accumulation steps to 1 because the Accelerator
            # already scales the loss during backward; scaling gradients again in
            # the sparse optimizer would result in double scaling. The interface is
            # kept for users who manage gradient accumulation manually.
self.sparse_optimizer.set_grad_accum_steps(1)
self._global_step = torch.scalar_tensor(0, dtype=torch.int64)
self._epoch = torch.scalar_tensor(0, dtype=torch.int64)
self.saver = self.init_saver(model, args, saver)
self.stop_state = torch.scalar_tensor(0, dtype=torch.int64).cuda()
self.init_hooks()
def init_saver(self, model, args, saver):
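        """Build the checkpoint saver and register trainer state with it.

        Registers dataset IO state (including window IO state when available),
        the dense optimizer, and the global step / epoch counters so they are
        saved and restored together with the model.

        Args:
            model (nn.Module): Model forwarded to :meth:`build_saver`.
            args (TrainingArguments): Training configuration.
            saver (Optional[Saver]): Existing saver to reuse. If None, a new one
                is built from ``args``.

        Returns:
            Saver: The configured checkpoint saver.
        """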
saver = self.build_saver(model, args, saver)
if self.train_dataset is not None:
saver.register_io_state(ExtraFields.train_io, self.train_dataset)
if hasattr(self.train_dataset, "_window_paths"):
saver.register_for_checkpointing(
ExtraFields.train_window_io, self.train_dataset
)
if self.eval_dataset is not None and hasattr(
self.eval_dataset, "_window_paths"
):
saver.register_io_state(ExtraFields.eval_io, self.eval_dataset)
saver.register_for_checkpointing(
ExtraFields.eval_window_io, self.eval_dataset
)
if self.dense_optimizer is not None:
saver.register_for_checkpointing(
ExtraFields.recis_dense_optim, self.dense_optimizer
)
if not saver.get_extra_data(ExtraFields.global_step):
saver.register_for_checkpointing(ExtraFields.global_step, self._global_step)
if not saver.get_extra_data(ExtraFields.train_epoch):
saver.register_for_checkpointing(ExtraFields.train_epoch, self._epoch)
return saver
def build_saver(self, model, args, saver):
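        """Return ``saver`` if given, otherwise build one.

        When no saver is supplied, ``args.saver_option`` is used if set; otherwise
        a default :class:`SaverOptions` is constructed from ``args`` and the model.
        """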
if saver is None:
saver_option = args.saver_option
if saver_option is None:
saver_option = SaverOptions(
model,
self.sparse_optimizer,
args.output_dir,
args.model_bank,
args.max_to_keep,
args.save_concurrency_per_rank,
args.params_not_save,
args.save_filter_fn,
)
saver = Saver(saver_option)
return saver
def init_hooks(self):
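        """Register the built-in hooks: logging, checkpoint save/load, and metric reporting."""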
self.hooks.append(LoggerHook(self.args.log_steps))
if self.args.ckpt_save_arg is not None:
ckpt_save_arg = self.args.ckpt_save_arg
else:
ckpt_save_arg = CheckpointSaveArguments(
self.args.save_steps,
self.args.save_every_n_windows,
self.args.save_every_n_epochs,
self.args.save_end,
)
self.hooks.append(
CheckpointSaveHook(
self.saver, self._global_step, self._epoch, ckpt_save_arg
)
)
if self.args.ckpt_load_arg is not None:
ckpt_load_arg = self.args.ckpt_load_arg
else:
ckpt_load_arg = CheckpointLoadArguments(
self.args.load_update_steps,
self.args.load_update_windows,
self.args.load_update_epochs,
)
self.hooks.append(
CheckpointLoadHook(
self.saver, self._global_step, self._epoch, ckpt_load_arg
)
)
self.hooks.append(
MetricReportHook(
model=self.model,
report_args=None,
)
)
def add_hooks(self, hooks: List[Hook]):
"""Add multiple hooks to the trainer.
Args:
hooks (List[Hook]): List of hooks to add.
"""
for hook in hooks:
self.add_hook(hook)
def add_hook(self, hook: Hook):
"""Add a single hook to the trainer.
Args:
hook (Hook): The hook to add.
"""
self.hooks.append(hook)
def train(self, train_steps=None):
"""Execute the training loop.
Args:
train_steps (Optional[int]): Override for number of training steps.
If None, uses args.train_steps.
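
        Example:
            A minimal sketch, assuming ``trainer`` is an already constructed
            :class:`Trainer`:

            .. code-block:: python

                # Use the step limit from TrainingArguments.
                trainer.train()
                # Or cap each epoch at 1000 steps explicitly.
                trainer.train(train_steps=1000)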
"""
for hook in self.hooks:
hook.start(is_train=True)
if hasattr(self.train_dataset, "_window_paths"):
train_loop_fn = self._train_loop_by_window
for hook in self.hooks:
hook.window_mode()
else:
train_loop_fn = self._train_loop
while self._epoch < self.args.train_epoch:
for hook in self.hooks:
hook.before_epoch(is_train=True)
train_loop_fn(
self.args.train_steps if train_steps is None else train_steps,
epoch=self._epoch,
)
self.train_dataset.reset()
for hook in self.hooks:
hook.after_epoch(is_train=True)
for hook in self.hooks:
hook.end(is_train=True)
def evaluate(self, eval_steps=None):
"""Execute the evaluation loop.
Args:
eval_steps (Optional[int]): Override for number of evaluation steps.
If None, evaluates on full dataset.
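
        Example:
            A minimal sketch, assuming ``trainer`` was constructed with an
            ``eval_dataset``:

            .. code-block:: python

                # Evaluate on at most 200 batches.
                trainer.evaluate(eval_steps=200)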
"""
for hook in self.hooks:
hook.start(is_train=False)
if hasattr(self.eval_dataset, "_window_paths"):
eval_loop_fn = self._eval_loop_by_window
for hook in self.hooks:
hook.window_mode()
else:
eval_loop_fn = self._eval_loop
for hook in self.hooks:
hook.before_epoch(is_train=False)
eval_loop_fn(
self.args.eval_steps if eval_steps is None else eval_steps,
)
        for hook in self.hooks:
            hook.after_epoch(is_train=False)
        for hook in self.hooks:
            hook.end(is_train=False)
def train_and_evaluate(self, train_steps=None, eval_steps=None):
"""Execute alternating training and evaluation loops.
Args:
train_steps (Optional[int]): Override for number of training steps per epoch.
eval_steps (Optional[int]): Override for number of evaluation steps.
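
        Example:
            A minimal sketch, assuming ``trainer`` was constructed with both a
            train and an eval dataset:

            .. code-block:: python

                trainer.train_and_evaluate(train_steps=10000, eval_steps=1000)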
"""
for hook in self.hooks:
hook.start(is_train=True)
if hasattr(self.train_dataset, "_window_paths"):
            assert hasattr(self.eval_dataset, "_window_paths"), (
                "train and eval datasets must both be window IO"
            )
loop_fn = self._train_eval_loop_by_window
for hook in self.hooks:
hook.window_mode()
else:
            assert not hasattr(self.eval_dataset, "_window_paths"), (
                "train and eval datasets must both be non-window IO"
            )
loop_fn = self._train_eval_loop
while self._epoch < self.args.train_epoch:
for hook in self.hooks:
hook.before_epoch(is_train=True)
loop_fn(
self.args.train_steps if train_steps is None else train_steps,
self.args.eval_steps if eval_steps is None else eval_steps,
epoch=self._epoch,
)
self.train_dataset.reset()
self.eval_dataset.reset()
for hook in self.hooks:
hook.after_epoch(is_train=True)
for hook in self.hooks:
hook.end(is_train=True)
def get_new_window_iter(self, dataset):
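        """Advance ``dataset`` to the next IO window and return an iterator over it.

        Windows flagged for skipping are consumed and logged until a readable
        window is found. Returns None when the dataset has no more windows.

        Raises:
            TypeError: If ``dataset`` does not support window IO.
        """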
if not hasattr(dataset, "_window_paths"):
raise TypeError("dataset must be window_io")
while True:
try:
need_skip = dataset.next_window()
except StopIteration:
logger.info("Window IO Finish")
return None
read_offset = int(dataset._read_offset[0])
if need_skip:
logger.info(f"Skip for window, offset = {read_offset}")
else:
logger.info(f"Next window, offset = {read_offset}")
break
return iter(dataset)
def sync_exit_flag(self, flag: bool):
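        """Synchronize a local exit flag across all ranks.

        The flag is all-reduced with MAX, so every rank stops as soon as any
        rank requests a stop.

        Args:
            flag (bool): Local stop request.

        Returns:
            bool: True if any rank set the flag.
        """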
self.stop_state.fill_(int(flag))
dist.all_reduce(self.stop_state, op=dist.ReduceOp.MAX)
return bool(self.stop_state.item())
def _train_loop_by_window(self, max_steps=None, epoch=1):
window_iter = 0
self.model.train()
while True:
if (
self.args.window_iter is not None
and window_iter >= self.args.window_iter
):
break
iterator = self.get_new_window_iter(self.train_dataset)
need_break = iterator is None
need_break = self.sync_exit_flag(need_break)
if need_break:
break
for hook in self.hooks:
hook.before_window(is_train=True)
self._train_loop_internal(iterator, max_steps, epoch)
for hook in self.hooks:
hook.after_window(is_train=True)
window_iter += 1
def _eval_loop_by_window(self, max_steps=None):
window_iter = 0
self.model.eval()
while True:
if (
self.args.window_iter is not None
and window_iter >= self.args.window_iter
):
break
iterator = self.get_new_window_iter(self.eval_dataset)
need_break = iterator is None
need_break = self.sync_exit_flag(need_break)
if need_break:
break
for hook in self.hooks:
hook.before_window(is_train=False)
self._eval_loop_internal(iterator, max_steps)
for hook in self.hooks:
hook.after_window(is_train=False)
window_iter += 1
def _train_eval_loop_by_window(self, train_steps=None, eval_steps=None, epoch=1):
window_iter = 0
while True:
if (
self.args.window_iter is not None
and window_iter >= self.args.window_iter
):
break
self.model.train()
train_iterator = self.get_new_window_iter(self.train_dataset)
train_need_break = train_iterator is None
train_need_break = self.sync_exit_flag(train_need_break)
if train_need_break:
                logger.info(
                    "train_and_eval window loop stopping: the train dataset has no more windows to read."
                )
break
eval_iterator = self.get_new_window_iter(self.eval_dataset)
eval_need_break = eval_iterator is None
eval_need_break = self.sync_exit_flag(eval_need_break)
if eval_need_break:
                logger.info(
                    "train_and_eval window loop stopping: the eval dataset has no more windows to read."
                )
break
for hook in self.hooks:
hook.before_window(is_train=True)
self._train_loop_internal(train_iterator, train_steps, epoch)
self._eval_loop_internal(eval_iterator, eval_steps)
for hook in self.hooks:
hook.after_window(is_train=True)
window_iter += 1
def _train_loop(self, max_steps=None, epoch=1):
self.model.train()
iterator = iter(self.train_dataset)
self._train_loop_internal(iterator, max_steps, epoch)
def _eval_loop(self, max_steps=None):
self.model.eval()
iterator = iter(self.eval_dataset)
self._eval_loop_internal(iterator, max_steps)
def _train_eval_loop(self, train_steps=None, eval_steps=None, epoch=1):
self._train_loop(train_steps, epoch)
self._eval_loop(eval_steps)
def _eval_loop_internal(self, data_iter, max_steps=None):
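        """Run evaluation steps over ``data_iter`` with gradients disabled.

        Stops when ``max_steps`` is reached or the data iterator signals the
        end of data.
        """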
lstep = 0
while True:
if max_steps is not None and lstep >= max_steps:
break
for hook in self.hooks:
hook.before_step(is_train=False)
stop_flag, data = next(data_iter)
need_break = self.sync_exit_flag(stop_flag)
if need_break:
for hook in self.hooks:
hook.out_off_data()
break
if self.data_to_cuda:
data = copy_data_to_device(data, "cuda")
for hook in self.hooks:
hook.after_data(is_train=False, data=data)
metrics = {}
with torch.no_grad():
eval_result = self.model(data)
metrics.update(get_global_metrics())
for hook in self.hooks:
hook.after_step(
metrics=metrics,
global_step=self._global_step,
is_train=False,
eval_result=eval_result,
)
lstep += 1
def _train_loop_internal(self, data_iter, max_steps=None, epoch=1):
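        """Run training steps over ``data_iter`` under gradient accumulation.

        Stops when ``max_steps`` is reached or the data iterator signals the
        end of data.
        """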
lstep = 0
while True:
if max_steps is not None and lstep >= max_steps:
break
for hook in self.hooks:
hook.before_step(is_train=True)
stop_flag, data = next(data_iter)
need_break = self.sync_exit_flag(stop_flag)
if need_break:
for hook in self.hooks:
hook.out_off_data()
break
if self.data_to_cuda:
data = copy_data_to_device(data, "cuda")
for hook in self.hooks:
hook.after_data(is_train=True, data=data)
metrics = {}
with self.accelerator.accumulate(self.model):
self._train_step(data, epoch, metrics)
for hook in self.hooks:
hook.after_step(
metrics=metrics, global_step=self._global_step, is_train=True
)
lstep += 1
def _train_step(self, data, epoch, metrics):
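        """Run a single optimization step.

        Zeroes gradients, runs the forward pass under autocast, backpropagates
        via the Accelerator, then steps the dense optimizer, the sparse
        optimizer (if any), and the dense LR scheduler (if any).
        """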
self.dense_optimizer.zero_grad()
if self.sparse_optimizer is not None:
self.sparse_optimizer.zero_grad()
with self.accelerator.autocast():
loss = self.model(data)
metrics.update(epoch=epoch)
metrics.update(loss=loss)
metrics.update(get_global_metrics())
self.accelerator.backward(loss)
self.dense_optimizer.step()
if self.sparse_optimizer is not None:
self.sparse_optimizer.step()
if self.dense_lr_scheduler is not None:
self.dense_lr_scheduler.step()