Evaluation Metrics Module
RecIS’s evaluation metrics module provides commonly used metrics for recommendation systems and machine learning, with support for distributed computation and real-time updates.
Core Metrics
AUROC
- class recis.metrics.auroc.AUROC(num_thresholds=200, dist_sync_on_step=False)[source]
Area Under ROC Curve (AUROC) metric for binary classification.
This class computes the Area Under the Receiver Operating Characteristic (ROC) Curve, which measures the ability of a binary classifier to distinguish between positive and negative classes across all classification thresholds. The implementation uses confusion matrices computed at multiple thresholds for efficient and accurate AUC calculation.
The AUROC metric is particularly useful for:
- Binary classification tasks in recommendation systems
- Click-through rate (CTR) prediction
- Conversion rate optimization
- Any binary classification where class balance matters
- thresholds
Threshold values used for classification decisions.
- Type:
- tp
True positive counts at each threshold.
- Type:
nn.Parameter
- fp
False positive counts at each threshold.
- Type:
nn.Parameter
- tn
True negative counts at each threshold.
- Type:
nn.Parameter
- fn
False negative counts at each threshold.
- Type:
nn.Parameter
Example
Creating and using the AUROC metric:
# Initialize with custom configuration
auc_metric = AUROC(
    num_thresholds=100,  # Use 100 thresholds for the ROC curve
    dist_sync_on_step=False,  # Sync only when computing the final result
)

# Batch processing
predictions = torch.tensor([0.9, 0.7, 0.3, 0.1])
labels = torch.tensor([1, 1, 0, 0])

# Update metric state
auc_metric.update(predictions, labels)

# Compute AUC
auc_score = auc_metric.compute()
print(f"AUC: {auc_score:.4f}")

# Direct computation (alternative to update + compute)
direct_auc = auc_metric(predictions, labels)
- __init__(num_thresholds=200, dist_sync_on_step=False)[source]
Initialize AUROC metric with specified configuration.
- Parameters:
num_thresholds (int, optional) – Number of thresholds to use for ROC curve computation. Must be greater than 2. Defaults to 200.
dist_sync_on_step (bool, optional) – Whether to synchronize metric state across distributed processes on each update step. If False, synchronization only occurs during compute(). Defaults to False.
- Raises:
AssertionError – If num_thresholds is not greater than 2.
Note
Higher num_thresholds values provide more accurate AUC computation but require more memory and computation. The thresholds are evenly distributed between 0 and 1 with small epsilon values at the boundaries.
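For illustration, one plausible construction of such a threshold grid (the exact spacing and epsilon are implementation details of RecIS; the values below are assumptions):

import torch

num_thresholds = 200
eps = 1e-7  # hypothetical boundary epsilon
thresholds = torch.linspace(0.0, 1.0, steps=num_thresholds)
thresholds[0] -= eps   # predictions of exactly 0.0 land in the first bucket
thresholds[-1] += eps  # predictions of exactly 1.0 land in the last bucket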
- compute()[source]
Compute final AUROC score from accumulated statistics.
This method calculates the AUROC using all statistics accumulated through previous update() calls. It’s typically called at the end of an epoch or evaluation period to get the final metric value.
- Returns:
AUROC score as a scalar tensor.
- Return type:
torch.Tensor
Example:
auc_metric = AUROC()

# Accumulate data from multiple batches
for batch in dataloader:
    auc_metric.update(model(batch), batch["labels"])

# Get final AUC score
final_auc = auc_metric.compute()
print(f"Epoch AUC: {final_auc:.4f}")
Note
This method uses the current accumulated state (tp, fp, tn, fn) to compute the final AUROC. Make sure to call reset() before starting a new evaluation period.
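For intuition, the final score can be derived from the accumulated counts roughly as in the following sketch of the standard confusion-matrix formulation (names and details here are illustrative, not RecIS’s exact internals):

import torch

def auc_from_counts(tp, fp, tn, fn):
    # tp/fp/tn/fn: 1-D count tensors, one entry per threshold, ordered
    # from the lowest threshold to the highest.
    tpr = tp / (tp + fn).clamp(min=1)  # true positive rate per threshold
    fpr = fp / (fp + tn).clamp(min=1)  # false positive rate per threshold
    # FPR shrinks as the threshold rises; flip both curves so the x-axis
    # increases, then integrate TPR over FPR with the trapezoidal rule.
    return torch.trapz(tpr.flip(0), fpr.flip(0))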
- forward(predictions, labels)[source]
Compute AUROC directly from predictions and labels.
This method provides a direct way to compute AUROC without updating the internal state. It’s useful for one-time computations or when you don’t need to accumulate statistics across multiple batches.
- Parameters:
predictions (torch.Tensor) – Predicted probabilities in range [0, 1]. Shape: (N,) where N is the number of samples.
labels (torch.Tensor) – Ground truth binary labels (0 or 1). Shape: (N,) where N is the number of samples.
- Returns:
AUROC score as a scalar tensor.
- Return type:
torch.Tensor
Example:
auc_metric = AUROC(num_thresholds=100)

# Direct computation
predictions = torch.tensor([0.9, 0.7, 0.3, 0.1])
labels = torch.tensor([1, 1, 0, 0])

auc_score = auc_metric(predictions, labels)
print(f"AUC: {auc_score:.4f}")
- reset()[source]
Reset all accumulated statistics to zero.
This method clears all internal state, setting all confusion matrix components back to zero. It should be called at the beginning of each new evaluation period (e.g., new epoch) to ensure clean statistics.
Example:
auc_metric = AUROC()

for epoch in range(num_epochs):
    # Reset at the beginning of each epoch
    auc_metric.reset()

    # Accumulate statistics for the epoch
    for batch in dataloader:
        auc_metric.update(model(batch), batch["labels"])

    # Get epoch result
    epoch_auc = auc_metric.compute()
    print(f"Epoch {epoch} AUC: {epoch_auc:.4f}")
Note
This method modifies the internal parameter tensors in-place using zero_() for efficiency.
- update(predictions, labels)[source]
Update metric state with new predictions and labels.
This method accumulates confusion matrix statistics from the current batch with previously seen data. It’s designed for incremental updates during training where you want to compute metrics across multiple batches.
- Parameters:
predictions (torch.Tensor) – Predicted probabilities in range [0, 1]. Shape: (N,) where N is the number of samples.
labels (torch.Tensor) – Ground truth binary labels (0 or 1). Shape: (N,) where N is the number of samples.
Example:
auc_metric = AUROC(num_thresholds=200, dist_sync_on_step=True)

# Process multiple batches
for batch in dataloader:
    preds = model(batch)
    labels = batch["labels"]

    # Accumulate statistics
    auc_metric.update(preds, labels)

# Get final result
final_auc = auc_metric.compute()
Note
If dist_sync_on_step is True, this method will synchronize statistics across all distributed processes, which may impact performance but ensures consistency in distributed training.
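Conceptually, step-level synchronization reduces the four count tensors across all processes. A minimal sketch of that idea (not the library’s actual code):

import torch.distributed as dist

def sync_confusion_counts(tp, fp, tn, fn):
    # Sum the per-process confusion counts so every rank holds global totals.
    for counts in (tp, fp, tn, fn):
        dist.all_reduce(counts, op=dist.ReduceOp.SUM)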
GAUC
- class recis.metrics.gauc.Gauc[source]
Group AUC (GAUC) metric for recommendation systems and personalized ML tasks.
This class computes the Group Area Under ROC Curve, which evaluates model performance by calculating AUC separately for each group (typically users) and then aggregating the results. This provides a more accurate assessment of recommendation system performance compared to global AUC metrics.
GAUC is essential for recommendation systems because:
- Different users have different behavior patterns and preferences
- Global AUC can be dominated by users with many interactions
- User-level evaluation provides better insights into model fairness
- It’s more aligned with business objectives in personalized systems
The implementation uses optimized C++ operations for efficient computation and supports distributed training scenarios commonly found in large-scale recommendation systems.
- _cpu
CPU device for computation efficiency.
- Type:
torch.device
Example
Using GAUC in a recommendation model:
# Initialize GAUC metric
gauc_metric = Gauc()

# Prepare recommendation data
labels = torch.tensor([1, 0, 1, 0, 1, 0])  # Click labels
predictions = torch.tensor([0.9, 0.1, 0.8, 0.2, 0.7, 0.3])  # CTR predictions
user_ids = torch.tensor([1, 1, 2, 2, 3, 3])  # User identifiers

# Compute GAUC
batch_gauc, cumulative_gauc = gauc_metric(labels, predictions, user_ids)
print(f"Batch GAUC: {batch_gauc:.4f}")
print(f"Cumulative GAUC: {cumulative_gauc:.4f}")

# Continue with more batches...
# The cumulative GAUC will be updated automatically

Integration with training loop:
model = RecommendationModel()
gauc_metric = Gauc()

for epoch in range(num_epochs):
    gauc_metric.reset()  # Reset for new epoch

    for batch_idx, batch in enumerate(train_dataloader):
        # Forward pass
        logits = model(batch)
        predictions = torch.sigmoid(logits)

        # Compute GAUC
        batch_gauc, epoch_gauc = gauc_metric(
            batch["labels"], predictions, batch["user_ids"]
        )

        # Log metrics
        if batch_idx % 100 == 0:
            print(
                f"Epoch {epoch}, Batch {batch_idx}: "
                f"Batch GAUC = {batch_gauc:.4f}, "
                f"Epoch GAUC = {epoch_gauc:.4f}"
            )

    print(f"Final Epoch {epoch} GAUC: {epoch_gauc:.4f}")
- __init__() → None[source]
Initialize GAUC metric with default configuration.
The metric automatically detects the distributed training environment and configures itself accordingly. It uses CPU computation for the GAUC calculation to optimize memory usage and computation efficiency.
Note
The metric automatically reads the WORLD_SIZE environment variable to determine if running in distributed mode. In distributed training, it will aggregate results across all processes.
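That detection can be pictured as follows (a sketch of the behavior described above):

import os

world_size = int(os.environ.get("WORLD_SIZE", "1"))
is_distributed = world_size > 1  # aggregate across processes when True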
- forward(labels, predictions, indicators)[source]
Compute GAUC for the current batch and update cumulative statistics.
This method computes the Group AUC by calculating AUC separately for each group identified by the indicators (typically user IDs) and then computing a weighted average. It returns both the current batch GAUC and the cumulative GAUC across all processed batches.
- Parameters:
labels (torch.Tensor) – Ground truth binary labels (0 or 1). Shape: (N,) where N is the number of samples.
predictions (torch.Tensor) – Predicted probabilities or scores. Shape: (N,) where N is the number of samples.
indicators (torch.Tensor) – Group identifiers (e.g., user IDs). Shape: (N,) where N is the number of samples.
- Returns:
- A tuple containing:
batch_gauc (float): GAUC score for the current batch
cumulative_gauc (float): Cumulative GAUC across all processed batches
- Return type:
Tuple[float, float]
Example:
gauc_metric = Gauc()

# Single batch computation
labels = torch.tensor([1, 0, 1, 0, 1, 0])
predictions = torch.tensor([0.9, 0.1, 0.8, 0.2, 0.7, 0.3])
user_ids = torch.tensor([1, 1, 2, 2, 3, 3])

batch_gauc, cumulative_gauc = gauc_metric(labels, predictions, user_ids)
print(f"Batch GAUC: {batch_gauc:.4f}")
print(f"Cumulative GAUC: {cumulative_gauc:.4f}")

# Process another batch
labels2 = torch.tensor([0, 1, 1, 0])
predictions2 = torch.tensor([0.2, 0.8, 0.9, 0.1])
user_ids2 = torch.tensor([4, 4, 5, 5])

batch_gauc2, cumulative_gauc2 = gauc_metric(labels2, predictions2, user_ids2)
print(f"Batch 2 GAUC: {batch_gauc2:.4f}")
print(f"Updated Cumulative GAUC: {cumulative_gauc2:.4f}")
Note
The method automatically handles distributed training by aggregating results across all processes when WORLD_SIZE > 1. The computation is performed on CPU for memory efficiency, with automatic device transfer handled internally.
The GAUC calculation weights each group’s AUC by the number of samples in that group, providing a fair aggregation across groups of different sizes.
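As a reference for this aggregation, here is a pure-PyTorch sketch of sample-weighted GAUC. RecIS itself uses optimized C++ ops; skipping single-class groups and counting ties as half are assumptions in this sketch:

import torch

def gauc_reference(labels, predictions, indicators):
    total_auc, total_weight = 0.0, 0
    for group in indicators.unique():
        mask = indicators == group
        y, p = labels[mask], predictions[mask]
        pos, neg = p[y == 1], p[y == 0]
        if pos.numel() == 0 or neg.numel() == 0:
            continue  # AUC is undefined for single-class groups
        # Pairwise formulation: fraction of (positive, negative) pairs in
        # which the positive sample scores higher; ties count as half.
        diff = pos.unsqueeze(1) - neg.unsqueeze(0)
        auc = ((diff > 0).float().sum() + 0.5 * (diff == 0).float().sum()) / diff.numel()
        weight = int(mask.sum())  # weight each group by its sample count
        total_auc += float(auc) * weight
        total_weight += weight
    return total_auc / total_weight if total_weight else 0.0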
- reset()[source]
Reset all accumulated statistics to zero.
This method clears all internal state, resetting both the cumulative AUC sum and count statistics. It should be called at the beginning of each new evaluation period (e.g., new epoch) to ensure clean metrics.
Example:
gauc_metric = Gauc()

for epoch in range(num_epochs):
    # Reset at the beginning of each epoch
    gauc_metric.reset()

    # Process batches for the epoch
    for batch in dataloader:
        batch_gauc, epoch_gauc = gauc_metric(
            batch["labels"], batch["predictions"], batch["user_ids"]
        )

    print(f"Final Epoch {epoch} GAUC: {epoch_gauc:.4f}")
Metrics Integration
Integrating Metrics in Models
RecIS provides convenient ways to integrate evaluation metrics into model training:
import torch
import torch.nn as nn
from recis.metrics import AUROC
from recis.metrics.gauc import Gauc
from recis.framework.metrics import add_metric

class RecommendationModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Model components
        self.embedding = ...
        self.dnn = ...
        # Evaluation metrics
        self.auc_metric = AUROC(num_thresholds=200, dist_sync_on_step=True)
        self.gauc_metric = Gauc()  # Gauc takes no constructor arguments
        self.loss_fn = nn.BCELoss()

    def forward(self, batch):
        # Model forward pass
        logits = self.predict(batch)
        # BCELoss and AUROC both expect probabilities in [0, 1]
        predictions = torch.sigmoid(logits)
        labels = batch['label']
        user_ids = batch['user_id']
        # Compute loss
        loss = self.loss_fn(predictions, labels.float())
        # Update AUROC state and compute the accumulated value
        self.auc_metric.update(predictions, labels)
        auc = self.auc_metric.compute()
        # Gauc is called directly and returns (batch GAUC, cumulative GAUC)
        _, gauc = self.gauc_metric(labels, predictions, user_ids)
        # Add to the training framework's metric system
        add_metric("auc", auc)
        add_metric("gauc", gauc)
        add_metric("loss", loss)
        return loss
Distributed Metrics Computation
Using metrics in distributed training:
import torch.distributed as dist
from recis.metrics import AUROC
# Check whether the distributed environment is initialized
if dist.is_initialized():
# Enable distributed synchronization
auc_metric = AUROC(
num_thresholds=200,
dist_sync_on_step=True # Sync at each step for consistency
)
else:
auc_metric = AUROC(num_thresholds=200)
# Use normally in training loop
for batch in dataloader:
preds = model(batch)
labels = batch['label']
# Metrics will automatically handle distributed aggregation
auc_metric.update(preds, labels)
auc_score = auc_metric.compute()
Custom Metrics
Creating custom evaluation metrics:
import torch
import torch.distributed as dist

class CustomMetric:
    def __init__(self, dist_sync_on_step: bool = False):
        self.dist_sync_on_step = dist_sync_on_step
        self.reset()

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Update metric state with one batch."""
        self.total_samples += preds.size(0)
        self.correct_predictions += (preds.round() == target).sum()

    def compute(self) -> torch.Tensor:
        """Compute the metric value."""
        if self.total_samples == 0:
            return torch.tensor(0.0)
        accuracy = self.correct_predictions.float() / self.total_samples
        # Distributed synchronization (ReduceOp.AVG requires the NCCL backend;
        # with other backends, all-reduce a SUM and divide by the world size)
        if self.dist_sync_on_step and dist.is_initialized():
            dist.all_reduce(accuracy, op=dist.ReduceOp.AVG)
        return accuracy

    def reset(self) -> None:
        """Reset metric state."""
        self.total_samples = 0
        self.correct_predictions = torch.tensor(0)

# Use custom metric
custom_metric = CustomMetric(dist_sync_on_step=True)
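A quick check of the sketch above (values are illustrative):

preds = torch.tensor([0.9, 0.2, 0.7, 0.4])
targets = torch.tensor([1.0, 0.0, 1.0, 1.0])
custom_metric.update(preds, targets)
print(f"Accuracy: {custom_metric.compute():.4f}")  # 3 of 4 rounded preds match -> 0.7500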
Frequently Asked Questions
Q: How do I use metrics correctly in distributed training?
A: Set the synchronization parameter to match the phase:
# During training: sync at each step for consistency
train_auc = AUROC(num_thresholds=200, dist_sync_on_step=True)
# During validation: sync at the end is sufficient
val_auc = AUROC(num_thresholds=200, dist_sync_on_step=False)
Q: What’s the difference between GAUC and AUC?
A:
- AUC: computed globally; all samples are pooled into a single ROC curve
- GAUC: computed per group; AUC is first computed within each group (e.g., per user), then aggregated as a sample-weighted average
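A small worked example of how the two can diverge (numbers are hypothetical):

import torch

# Two users, each ranked perfectly by the model, but user 2's scores sit
# globally below user 1's.
labels = torch.tensor([1, 0, 1, 0])
predictions = torch.tensor([0.6, 0.4, 0.3, 0.2])
user_ids = torch.tensor([1, 1, 2, 2])

# Per-user AUC is 1.0 for both users, so GAUC = 1.0. Globally, the positive
# at 0.3 ranks below the negative at 0.4, so only 3 of 4 (positive, negative)
# pairs are ordered correctly: AUC = 0.75.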
Q: How do I save and load metric state?
A: Metric objects support state saving via state_dict():
# Save metric state
metric_state = auc_metric.state_dict()
torch.save(metric_state, 'metric_state.pth')
# Load metric state
auc_metric = AUROC(num_thresholds=200)
metric_state = torch.load('metric_state.pth')
auc_metric.load_state_dict(metric_state)