Training Framework Module

RecIS’s training framework module provides comprehensive model training, evaluation, and management capabilities, simplifying the development workflow for deep learning models.

Core Components

TrainingArguments

class recis.framework.trainer.TrainingArguments(gradient_accumulation_steps: int = 1, output_dir: str = 'output_dir', model_bank: list | None = None, log_steps: int = 100, train_steps: int | None = None, train_epoch: int | None = 1, eval_steps: int | None = None, save_steps: int | None = 1000, max_to_keep: int = 5, save_concurrency_per_rank: int = 4)[source]

Configuration class for training parameters.

This dataclass contains all the configuration parameters needed for training, including optimization settings, logging intervals, and checkpoint management.

gradient_accumulation_steps

Number of steps to accumulate gradients before performing an optimizer step. Defaults to 1.

Type: int

output_dir

Directory where checkpoints and logs will be saved. Defaults to “output_dir”.

Type: str

model_bank

List of model bank paths for initialization. Defaults to None.

Type: Optional[list]

log_steps

Number of training steps between logging. Defaults to 100.

Type: int

train_steps

Maximum number of training steps. If None, training runs for the configured number of epochs instead. Defaults to None.

Type: Optional[int]

train_epoch

Number of training epochs. Defaults to 1.

Type: Optional[int]

eval_steps

Number of evaluation steps. If None, evaluates on full dataset. Defaults to None.

Type: Optional[int]

save_steps

Number of steps between checkpoint saves. Defaults to 1000.

Type: Optional[int]

max_to_keep

Maximum number of checkpoints to keep. Defaults to 5.

Type: int

save_concurrency_per_rank

Number of concurrent save operations per rank. Defaults to 4.

Type: int
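
Putting the fields together, a minimal step-based configuration might look like the sketch below (values are illustrative only):

from recis.framework.trainer import TrainingArguments

training_args = TrainingArguments(
    output_dir="./output",   # checkpoints and logs are written here
    train_steps=5000,        # stop after 5000 optimizer steps
    eval_steps=500,          # cap each evaluation pass at 500 steps
    save_steps=1000,         # checkpoint every 1000 steps
    max_to_keep=3,           # keep only the 3 newest checkpoints
    log_steps=50,            # log metrics every 50 steps
)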

Trainer

class recis.framework.trainer.Trainer(model: Module | None = None, args: TrainingArguments = None, train_dataset: Dataset | None = None, eval_dataset: Dataset | None = None, hooks: List[Hook] | None = None, dense_optimizers: Tuple[Optimizer, LambdaLR] = (None, None), sparse_optimizer: SparseOptimizer | None = None, data_to_cuda: bool = False, **kwargs)[source]

Main training orchestrator with distributed training and checkpoint management.

The Trainer class provides a comprehensive training framework that handles:

  • Distributed training coordination using Accelerate

  • Automatic checkpoint saving and loading

  • Training and evaluation loops with metrics tracking

  • Hook system for extensible training workflows

  • Support for both dense and sparse optimizers

args

Training configuration parameters.

Type: TrainingArguments

hooks

List of training hooks for extensibility.

Type: List[Hook]

train_dataset

Training dataset.

Type: Optional[Dataset]

eval_dataset

Evaluation dataset.

Type: Optional[Dataset]

model

The model to train.

Type: nn.Module

dense_optimizer

Dense parameter optimizer.

Type: torch.optim.Optimizer

dense_lr_scheduler

Learning rate scheduler for the dense optimizer.

Type: LambdaLR

sparse_optimizer

Sparse parameter optimizer.

Type: Optional[sparse_optim.SparseOptimizer]

data_to_cuda

Whether to automatically move data to CUDA.

Type: bool

accelerator

Accelerate instance for distributed training.

Type: Accelerator

checkpoint_manager

Handles checkpoint operations.

Type: CheckpointManager

Example:

from recis.framework import Trainer, TrainingArguments
from recis.optim import SparseAdamW
from torch.optim import AdamW

# Set training arguments
training_args = TrainingArguments(
    output_dir="./checkpoints",
    train_steps=10000,
    eval_steps=1000,
    save_steps=2000,
    log_steps=100,
    gradient_accumulation_steps=4,
)

# split sparse params
from recis.nn.modules.hashtable import filter_out_sparse_param

sparse_params = filter_out_sparse_param(model)

# create optimizers
sparse_optimizer = SparseAdamW(sparse_params, lr=0.001)
dense_optimizer = AdamW(model.parameters(), lr=0.001)

# create trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    dense_optimizers=(dense_optimizer, None),
    sparse_optimizer=sparse_optimizer,
    data_to_cuda=True,
)

# train the model
trainer.train()
__init__(model: Module | None = None, args: TrainingArguments = None, train_dataset: Dataset | None = None, eval_dataset: Dataset | None = None, hooks: List[Hook] | None = None, dense_optimizers: Tuple[Optimizer, LambdaLR] = (None, None), sparse_optimizer: SparseOptimizer | None = None, data_to_cuda: bool = False, **kwargs) → None[source]

Initialize the Trainer with model, datasets, and training configuration.

Parameters:
  • model (Optional[nn.Module]) – The model to train.

  • args (TrainingArguments) – Training configuration. If None, uses default.

  • train_dataset (Optional[Dataset]) – Training dataset.

  • eval_dataset (Optional[Dataset]) – Evaluation dataset.

  • hooks (Optional[List[Hook]]) – List of training hooks for extensibility.

  • dense_optimizers (Tuple) – Tuple of (optimizer, lr_scheduler) for dense parameters.

  • sparse_optimizer (Optional[sparse_optim.SparseOptimizer]) – Optimizer for sparse parameters.

  • data_to_cuda (bool) – Whether to automatically move data to CUDA. Defaults to False.

  • **kwargs – Additional arguments passed to Accelerator.
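
Because extra keyword arguments are forwarded to the Accelerator, standard Accelerate options can be supplied at construction time. A hedged sketch, assuming the usual accelerate.Accelerator keywords (such as mixed_precision) are accepted unchanged:

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    dense_optimizers=(dense_optimizer, None),
    sparse_optimizer=sparse_optimizer,
    mixed_precision="bf16",  # forwarded to Accelerator (assumption)
)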

add_hook(hook: Hook)[source]

Add a single hook to the trainer.

Parameters:

hook (Hook) – The hook to add.

add_hooks(hooks: List[Hook])[source]

Add multiple hooks to the trainer.

Parameters:

hooks (List[Hook]) – List of hooks to add.

evaluate(eval_steps=None)[source]

Execute the evaluation loop.

Parameters:

eval_steps (Optional[int]) – Override for number of evaluation steps. If None, evaluates on full dataset.

restore()[source]

Restore model and training state from checkpoints.

This method loads the latest checkpoint if available and restores the model state, optimizer state, and training progress.

train(train_steps=None)[source]

Execute the training loop.

Parameters:

train_steps (Optional[int]) – Override for number of training steps. If None, uses args.train_steps.
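
One possible sequence built from the documented methods is to restore any existing checkpoint, run training, and then evaluate. A minimal sketch:

# Resume from the latest checkpoint in the output directory, if one is available
trainer.restore()

# Train for args.train_steps (or pass an explicit override)
trainer.train()

# Evaluate, capping the loop at 100 steps for illustration
trainer.evaluate(eval_steps=100)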

Saver

class recis.framework.checkpoint_manager.Saver(model: Module, sparse_optim=None, output_dir: str = './', max_keep: int = 1, concurrency: int = 4)[source]

Checkpoint saver for managing model and training state persistence.

The Saver class handles the saving and loading of model checkpoints, including:

  • Dense and sparse model parameters

  • Optimizer states

  • IO states for datasets

  • Checkpoint versioning and cleanup

  • Support for distributed filesystems

Example

>>> saver = Saver(
...     model=model,
...     sparse_optim=sparse_optimizer,
...     output_dir="./checkpoints",
...     max_keep=5,
... )
>>> saver.save("checkpoint_001")
__init__(model: Module, sparse_optim=None, output_dir: str = './', max_keep: int = 1, concurrency: int = 4)[source]

Initialize the checkpoint saver.

Parameters:
  • model (torch.nn.Module) – The model to save checkpoints for.

  • sparse_optim (Optional) – Sparse optimizer instance for sparse parameters.

  • output_dir (str) – Directory to save checkpoints. Defaults to “./”.

  • max_keep (int) – Maximum number of checkpoints to keep. Defaults to 1.

  • concurrency (int) – Number of concurrent save operations. Defaults to 4.

register_for_checkpointing(name, obj: object)[source]

Register an object for checkpointing.

Parameters:
  • name (str) – Name identifier for the checkpointed object.

  • obj (object) – Object to include in checkpoints.

Raises:

ValueError – If the name is already registered.
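
For illustration, additional stateful objects can be registered under unique names so they are persisted with every checkpoint. A short sketch; the lr_scheduler object and its name are hypothetical, and it is assumed the object can be checkpointed by the saver:

# Register extra state alongside the model; reusing a name raises ValueError
saver.register_for_checkpointing("lr_scheduler", lr_scheduler)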

save(ckpt_id: str, shard_id: int = 0, shard_num: int = 1)[source]

Save a complete checkpoint with the given ID.

This method saves all registered components including model parameters, optimizer states, and IO states. It also handles checkpoint versioning and cleanup of old checkpoints.

Parameters:
  • ckpt_id (str) – Unique identifier for this checkpoint.

  • shard_id (int) – Shard ID for distributed saving. Defaults to 0.

  • shard_num (int) – Total number of shards. Defaults to 1.
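
In a distributed job, each rank can write its own shard of the same logical checkpoint. A sketch assuming torch.distributed is already initialized:

import torch.distributed as dist

saver.save(
    "ckpt_step_2000",                 # shared checkpoint ID across ranks
    shard_id=dist.get_rank(),         # this rank's shard
    shard_num=dist.get_world_size(),  # total number of shards
)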

CheckpointManager

class recis.framework.checkpoint_manager.CheckpointManager(saver: Saver, save_interval: int)[source]

High-level checkpoint manager for coordinating checkpoint operations.

The CheckpointManager provides a high-level interface for managing checkpoints during training, including automatic saving at intervals, loading from model banks, and coordinating with the training loop.

Example

>>> checkpoint_manager = CheckpointManager(saver=saver, save_interval=1000)
>>> # During training loop
>>> checkpoint_manager.step()  # Call after each training step
>>> # Automatic save will occur every save_interval steps
__init__(saver: Saver, save_interval: int) → None[source]

Initialize the checkpoint manager.

Parameters:
  • saver (Saver) – The saver instance to use for checkpoint operations.

  • save_interval (int) – Number of steps between automatic saves.

save()[source]

Save a checkpoint with automatic ID generation.

Example Usage:

from recis.framework.checkpoint_manager import CheckpointManager, Saver

# Create Saver
saver = Saver(
    model,
    sparse_optimizer,
    output_dir=output_dir,
    max_keep=2,
    concurrency=2,
)

# Create checkpoint manager
checkpoint_manager = CheckpointManager(
    saver,
    save_interval=1000,
)

# Save checkpoint
checkpoint_manager.save()

# Load latest checkpoint
checkpoint = checkpoint_manager.restore()

Advanced Usage

Custom Training Pipeline

from recis.framework.trainer import Trainer


class MyTrainer(Trainer):
    def _train_step(self, data, epoch, metrics):
        self.dense_optimizer.zero_grad()
        if self.sparse_optimizer is not None:
            self.sparse_optimizer.zero_grad()
        loss = self.model(data)
        metrics.update(epoch=epoch)
        metrics.update(loss=loss)
        # get_global_metrics() is assumed to come from RecIS's metrics utilities
        metrics.update(get_global_metrics())
        loss.backward()
        self.dense_optimizer.step()
        if self.sparse_optimizer is not None:
            self.sparse_optimizer.step()
        if self.dense_lr_scheduler is not None:
            self.dense_lr_scheduler.step()
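
The subclass is constructed and run exactly like the base Trainer, for example:

trainer = MyTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    dense_optimizers=(dense_optimizer, None),
    sparse_optimizer=sparse_optimizer,
)
trainer.train()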

Gradient Accumulation Training

# Configure gradient accumulation
training_args = TrainingArguments(
    output_dir="./output",
    train_steps=10000,
    gradient_accumulation_steps=8,  # Accumulate 8 steps before update
    log_steps=100
)

# Trainer will automatically handle gradient accumulation
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset
)
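
With gradient_accumulation_steps=8, gradients from 8 consecutive micro-batches are accumulated before a single optimizer step, so the effective batch size is 8x the per-step batch size. The sketch below shows the general pattern (generic PyTorch, not RecIS internals; model, dataloader, and optimizer are assumed from the surrounding context):

accum_steps = 8

for step, batch in enumerate(dataloader):
    loss = model(batch)
    # Scale the loss so the accumulated gradient matches one large batch
    (loss / accum_steps).backward()

    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()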