Training Framework Module
RecIS’s training framework module provides comprehensive model training, evaluation, and management capabilities, simplifying the development workflow for deep learning models.
Core Components
TrainingArguments
- class recis.framework.trainer.TrainingArguments(gradient_accumulation_steps: int = 1, output_dir: str = 'output_dir', model_bank: list | None = None, log_steps: int = 100, train_steps: int | None = None, train_epoch: int | None = 1, eval_steps: int | None = None, save_steps: int | None = 1000, max_to_keep: int = 5, save_concurrency_per_rank: int = 4)[source]
Configuration class for training parameters.
This dataclass contains all the configuration parameters needed for training, including optimization settings, logging intervals, and checkpoint management.
- gradient_accumulation_steps
Number of steps to accumulate gradients before performing an optimizer step. Defaults to 1.
- Type:
int
- train_steps
Maximum number of training steps. If None, training runs for train_epoch full epochs. Defaults to None.
- Type:
Optional[int]
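As noted for train_steps, leaving it as None falls back to epoch-based training governed by train_epoch. A minimal sketch of that configuration (field values are illustrative, not defaults):

from recis.framework.trainer import TrainingArguments

# Epoch-based training: train_steps stays None, so train_epoch controls duration
args = TrainingArguments(
    output_dir="./output",
    train_epoch=2,      # two full passes over the training dataset
    log_steps=50,
    save_steps=500,
    max_to_keep=5,
)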
Trainer
- class recis.framework.trainer.Trainer(model: Module | None = None, args: TrainingArguments = None, train_dataset: Dataset | None = None, eval_dataset: Dataset | None = None, hooks: List[Hook] | None = None, dense_optimizers: Tuple[Optimizer, LambdaLR] = (None, None), sparse_optimizer: SparseOptimizer | None = None, data_to_cuda: bool = False, **kwargs)[source]
Main training orchestrator with distributed training and checkpoint management.
The Trainer class provides a comprehensive training framework that handles:
- Distributed training coordination using Accelerate
- Automatic checkpoint saving and loading
- Training and evaluation loops with metrics tracking
- Hook system for extensible training workflows
- Support for both dense and sparse optimizers
- args
Training configuration parameters.
- Type:
TrainingArguments
- train_dataset
Training dataset.
- Type:
Optional[Dataset]
- eval_dataset
Evaluation dataset.
- Type:
Optional[Dataset]
- model
The model to train.
- Type:
nn.Module
- dense_optimizer
Dense parameter optimizer.
- Type:
torch.optim.Optimizer
- dense_lr_scheduler
Learning rate scheduler for dense optimizer.
- Type:
Optional[LambdaLR]
- sparse_optimizer
Sparse parameter optimizer.
- Type:
Optional[sparse_optim.SparseOptimizer]
- accelerator
Accelerate instance for distributed training.
- Type:
Accelerator
- checkpoint_manager
Handles checkpoint operations.
- Type:
CheckpointManager
Example:
from recis.framework import Trainer, TrainingArguments
from recis.optim import SparseAdamW
from torch.optim import AdamW

# Set training arguments
training_args = TrainingArguments(
    output_dir="./checkpoints",
    train_steps=10000,
    eval_steps=1000,
    save_steps=2000,
    log_steps=100,
    gradient_accumulation_steps=4,
)

# Split sparse parameters from dense ones
from recis.nn.modules.hashtable import filter_out_sparse_param
sparse_params = filter_out_sparse_param(model)

# Create optimizers
sparse_optimizer = SparseAdamW(sparse_params, lr=0.001)
dense_optimizer = AdamW(model.parameters(), lr=0.001)

# Create trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    dense_optimizers=(dense_optimizer, None),
    sparse_optimizer=sparse_optimizer,
    data_to_cuda=True,
)

# Train the model
trainer.train()
- __init__(model: Module | None = None, args: TrainingArguments = None, train_dataset: Dataset | None = None, eval_dataset: Dataset | None = None, hooks: List[Hook] | None = None, dense_optimizers: Tuple[Optimizer, LambdaLR] = (None, None), sparse_optimizer: SparseOptimizer | None = None, data_to_cuda: bool = False, **kwargs) None [source]
Initialize the Trainer with model, datasets, and training configuration.
- Parameters:
model (Optional[nn.Module]) – The model to train.
args (TrainingArguments) – Training configuration. If None, uses default.
train_dataset (Optional[Dataset]) – Training dataset.
eval_dataset (Optional[Dataset]) – Evaluation dataset.
hooks (Optional[List[Hook]]) – List of training hooks for extensibility.
dense_optimizers (Tuple) – Tuple of (optimizer, lr_scheduler) for dense parameters.
sparse_optimizer (Optional[sparse_optim.SparseOptimizer]) – Optimizer for sparse parameters.
data_to_cuda (bool) – Whether to automatically move data to CUDA. Defaults to False.
**kwargs – Additional arguments passed to Accelerator.
- add_hook(hook: Hook)[source]
Add a single hook to the trainer.
- Parameters:
hook (Hook) – The hook to add.
- add_hooks(hooks: List[Hook])[source]
Add multiple hooks to the trainer.
- Parameters:
hooks (List[Hook]) – List of hooks to add.
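A hedged sketch of plugging a custom hook into the trainer. The import path recis.framework.hooks and the callback name after_step are assumptions for illustration, not confirmed RecIS API; only add_hooks itself comes from the signature above.

from recis.framework.hooks import Hook  # assumed import path

class LossLoggingHook(Hook):
    # Hypothetical callback name; adapt to the actual Hook interface
    def after_step(self, metrics=None, **kwargs):
        if metrics is not None:
            print("step metrics:", metrics)

trainer.add_hooks([LossLoggingHook()])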
- evaluate(eval_steps=None)[source]
Execute the evaluation loop.
- Parameters:
eval_steps (Optional[int]) – Override for number of evaluation steps. If None, evaluates on full dataset.
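For example, a capped evaluation pass (the step count is illustrative):

# Evaluate on at most 500 batches instead of the full eval_dataset
trainer.evaluate(eval_steps=500)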
Saver
- class recis.framework.checkpoint_manager.Saver(model: Module, sparse_optim=None, output_dir: str = './', max_keep: int = 1, concurrency: int = 4)[source]
Checkpoint saver for managing model and training state persistence.
The Saver class handles the saving and loading of model checkpoints, including:
- Dense and sparse model parameters
- Optimizer states
- IO states for datasets
- Checkpoint versioning and cleanup
- Support for distributed filesystems
Example
>>> saver = Saver(
...     model=model,
...     sparse_optim=sparse_optimizer,
...     output_dir="./checkpoints",
...     max_keep=5,
... )
>>> saver.save("checkpoint_001")
- __init__(model: Module, sparse_optim=None, output_dir: str = './', max_keep: int = 1, concurrency: int = 4)[source]
Initialize the checkpoint saver.
- Parameters:
model (torch.nn.Module) – The model to save checkpoints for.
sparse_optim (Optional) – Sparse optimizer instance for sparse parameters.
output_dir (str) – Directory to save checkpoints. Defaults to “./”.
max_keep (int) – Maximum number of checkpoints to keep. Defaults to 1.
concurrency (int) – Number of concurrent save operations. Defaults to 4.
- register_for_checkpointing(name, obj: object)[source]
Register an object for checkpointing.
- Parameters:
name (str) – Name under which the object is registered. Must not already be registered.
obj (object) – The object to include in saved checkpoints.
- Raises:
ValueError – If the name is already registered.
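A short sketch of registering an extra stateful object so it is persisted alongside the model; the choice of name and object here is illustrative, reusing the saver and dense_optimizer from the examples above.

# Persist the dense optimizer state together with the model checkpoint
saver.register_for_checkpointing("dense_optimizer", dense_optimizer)

# Registering the same name twice raises ValueError
# saver.register_for_checkpointing("dense_optimizer", other_obj)  # ValueError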
CheckpointManager
- class recis.framework.checkpoint_manager.CheckpointManager(saver: Saver, save_interval: int)[source]
High-level checkpoint manager for coordinating checkpoint operations.
The CheckpointManager provides a high-level interface for managing checkpoints during training, including automatic saving at intervals, loading from model banks, and coordinating with the training loop.
Example
>>> checkpoint_manager = CheckpointManager(saver=saver, save_interval=1000)
>>> # During training loop
>>> checkpoint_manager.step()  # Call after each training step
>>> # Automatic save will occur every save_interval steps
Example Usage:
from recis.framework.checkpoint_manager import CheckpointManager, Saver

# Create Saver
saver = Saver(
    model,
    sparse_optimizer,
    output_dir=output_dir,
    max_keep=2,
    concurrency=2,
)

# Create checkpoint manager
checkpoint_manager = CheckpointManager(
    saver,
    save_interval=1000,
)

# Save checkpoint
checkpoint_manager.save()

# Load latest checkpoint
checkpoint = checkpoint_manager.restore()
Advanced Usage
Custom Training Pipeline
from recis.framework.trainer import Trainer

class MyTrainer(Trainer):
    def _train_step(self, data, epoch, metrics):
        # Clear gradients for both dense and sparse optimizers
        self.dense_optimizer.zero_grad()
        if self.sparse_optimizer is not None:
            self.sparse_optimizer.zero_grad()
        # Forward pass; the model returns the loss directly
        loss = self.model(data)
        metrics.update(epoch=epoch)
        metrics.update(loss=loss)
        # get_global_metrics is assumed to be available from RecIS's metrics utilities
        metrics.update(get_global_metrics())
        # Backward pass and parameter updates
        loss.backward()
        self.dense_optimizer.step()
        if self.sparse_optimizer is not None:
            self.sparse_optimizer.step()
        if self.dense_lr_scheduler is not None:
            self.dense_lr_scheduler.step()
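The subclass is used exactly like the base Trainer; a minimal sketch reusing the objects from the Trainer example above:

trainer = MyTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    dense_optimizers=(dense_optimizer, None),
    sparse_optimizer=sparse_optimizer,
)
trainer.train()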
Gradient Accumulation Training
# Configure gradient accumulation
training_args = TrainingArguments(
    output_dir="./output",
    train_steps=10000,
    gradient_accumulation_steps=8,  # Accumulate 8 steps before each update
    log_steps=100,
)

# Trainer will automatically handle gradient accumulation
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
)