# Source code for recis.metrics.auroc

import torch
import torch.distributed as dist
from torch import nn


class AUROC(nn.Module):
    """Area Under ROC Curve (AUROC) metric for binary classification.

    This class computes the Area Under the Receiver Operating Characteristic
    (ROC) Curve, which measures the ability of a binary classifier to
    distinguish between positive and negative classes across all
    classification thresholds. The implementation builds confusion matrices
    at a fixed set of thresholds and integrates the resulting ROC curve with
    the trapezoidal rule.

    The AUROC metric is particularly useful for:

    - Binary classification tasks in recommendation systems
    - Click-through rate (CTR) prediction
    - Conversion rate optimization
    - Any binary classification where class balance matters

    State is accumulated per process. Cross-process reduction happens only
    inside ``update()`` when ``dist_sync_on_step=True``; ``compute()`` never
    communicates with other processes.

    Attributes:
        num_thresholds (int): Number of thresholds used for ROC curve
            computation.
        dist_sync_on_step (bool): Whether ``update()`` all-reduces each
            batch's statistics across distributed processes.
        thresholds (torch.Tensor): Threshold values used for classification
            decisions, shape ``(num_thresholds,)``.
        tp (nn.Parameter): True positive counts at each threshold.
        fp (nn.Parameter): False positive counts at each threshold.
        tn (nn.Parameter): True negative counts at each threshold.
        fn (nn.Parameter): False negative counts at each threshold.

    Example:
        Creating and using AUROC metric:

        .. code-block:: python

            # Initialize with custom configuration
            auc_metric = AUROC(
                num_thresholds=100,  # Use 100 thresholds for ROC curve
                dist_sync_on_step=False,  # Keep per-process state only
            )

            # Batch processing
            predictions = torch.tensor([0.9, 0.7, 0.3, 0.1])
            labels = torch.tensor([1, 1, 0, 0])

            # Update metric state
            auc_metric.update(predictions, labels)

            # Compute AUC
            auc_score = auc_metric.compute()
            print(f"AUC: {auc_score:.4f}")

            # Direct computation (alternative to update + compute)
            direct_auc = auc_metric(predictions, labels)
    """

    def __init__(self, num_thresholds=200, dist_sync_on_step=False):
        """Initialize AUROC metric with specified configuration.

        Args:
            num_thresholds (int, optional): Number of thresholds to use for
                ROC curve computation. Must be greater than 2.
                Defaults to 200.
            dist_sync_on_step (bool, optional): If True, ``update()`` sums
                each batch's confusion-matrix counts across distributed
                processes before accumulating them. The sync is skipped
                when ``torch.distributed`` is unavailable or not initialized.
                If False, this class performs no synchronization at all and
                ``compute()`` reports only the statistics accumulated on the
                local process. Defaults to False.

        Raises:
            AssertionError: If num_thresholds is not greater than 2.

        Note:
            Higher num_thresholds values give a finer ROC approximation but
            cost more memory and computation. The interior thresholds are
            ``i / (num_thresholds - 1)`` for ``i = 1 .. num_thresholds - 2``.
            The two end thresholds are ``-1e-7`` and ``1 + 1e-7``. The first
            therefore classifies every prediction in [0, 1] as positive, and
            the last classifies none as positive.
        """
        super().__init__()
        assert num_thresholds > 2, "num_thresholds must be > 2"
        self.num_thresholds = num_thresholds
        self.dist_sync_on_step = dist_sync_on_step

        # Small epsilon pushing the end thresholds just outside [0, 1].
        kepsilon = 1e-7
        # Evenly spaced interior thresholds strictly between 0 and 1.
        thresholds = [
            (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
        ]
        self.thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
        self.thresholds = torch.tensor(self.thresholds)

        # Confusion-matrix accumulators: one int64 count per threshold, stored
        # as non-trainable Parameters (requires_grad=False).
        self.tp = nn.Parameter(
            torch.zeros(num_thresholds, dtype=torch.long), requires_grad=False
        )
        self.fp = nn.Parameter(
            torch.zeros(num_thresholds, dtype=torch.long), requires_grad=False
        )
        self.tn = nn.Parameter(
            torch.zeros(num_thresholds, dtype=torch.long), requires_grad=False
        )
        self.fn = nn.Parameter(
            torch.zeros(num_thresholds, dtype=torch.long), requires_grad=False
        )

    def _confusion_matrix_at_thresholds(self, predictions, labels):
        """Compute confusion matrix components at all thresholds.

        A sample counts as predicted-positive at a threshold when its
        prediction is strictly greater than that threshold. All thresholds
        are evaluated at once with broadcasting, without a Python loop.

        Args:
            predictions (torch.Tensor): Predicted probabilities in [0, 1].
                Any shape; flattened to ``(N,)``.
            labels (torch.Tensor): Ground truth labels; any non-zero value is
                treated as positive. Any shape; flattened to ``(N,)``.

        Returns:
            tuple: ``(tp, fp, tn, fn)``. Each is an int64 tensor of shape
            ``(num_thresholds,)``.

        Raises:
            AssertionError: If any prediction lies outside [0, 1] or is NaN
                (only when Python assertions are enabled).
        """
        assert torch.all(
            torch.logical_and(predictions >= 0.0, predictions <= 1.0)
        ), "predictions must be in [0, 1]"
        # reshape (rather than view) also accepts non-contiguous inputs.
        predictions_1d = predictions.reshape(-1)
        labels_1d = labels.to(dtype=torch.bool).reshape(-1)
        # thresholds is a plain attribute, not a registered buffer, so
        # Module.to() does not move it; follow the predictions' device here.
        self.thresholds = self.thresholds.to(predictions.device)

        # (num_thresholds, N): True where prediction > threshold.
        pred_is_pos = (predictions_1d.unsqueeze(-1) > self.thresholds).t()
        pred_is_neg = torch.logical_not(pred_is_pos)
        # (1, N): broadcasts against pred_is_pos instead of materialising
        # num_thresholds copies of the labels.
        label_is_pos = labels_1d.unsqueeze(0)
        label_is_neg = torch.logical_not(label_is_pos)

        # Sum across samples for each threshold.
        tp = torch.logical_and(label_is_pos, pred_is_pos).sum(1)
        fn = torch.logical_and(label_is_pos, pred_is_neg).sum(1)
        tn = torch.logical_and(label_is_neg, pred_is_neg).sum(1)
        fp = torch.logical_and(label_is_neg, pred_is_pos).sum(1)
        return tp, fp, tn, fn

    def _compute_auroc(self, tp, fp, tn, fn):
        """Compute AUROC from confusion matrix components.

        Builds the ROC curve (TPR on y, FPR on x) from the per-threshold
        counts and integrates it with the trapezoidal rule. Thresholds are
        increasing, so FPR is non-increasing along the arrays. That is why
        the x-differences are taken as ``x[i] - x[i + 1]``.

        Args:
            tp (torch.Tensor): True positive counts for each threshold.
            fp (torch.Tensor): False positive counts for each threshold.
            tn (torch.Tensor): True negative counts for each threshold.
            fn (torch.Tensor): False negative counts for each threshold.

        Returns:
            torch.Tensor: AUROC score as a scalar tensor.

        Note:
            The epsilon also appears in the TPR numerator. With no positive
            samples, TPR is therefore 1 at every threshold and the result is
            about 1. With no negative samples, FPR is 0 everywhere and the
            result is 0. Neither value is meaningful in those degenerate
            cases.
        """
        epsilon = 1.0e-6
        # True Positive Rate (recall / sensitivity).
        rec = torch.div(tp + epsilon, tp + fn + epsilon)
        # False Positive Rate (1 - specificity).
        fp_rate = torch.div(fp, fp + tn + epsilon)
        x = fp_rate
        y = rec
        # Trapezoidal rule over consecutive thresholds.
        auc = torch.multiply(
            x[: self.num_thresholds - 1] - x[1:],
            (y[: self.num_thresholds - 1] + y[1:]) / 2.0,
        ).sum()
        return auc

    def forward(self, predictions, labels):
        """Compute AUROC directly from predictions and labels.

        The internal accumulated state (tp/fp/tn/fn) is neither read nor
        updated, so this is suitable for one-off evaluations.

        Args:
            predictions (torch.Tensor): Predicted probabilities in [0, 1].
                Shape: (N,) where N is the number of samples.
            labels (torch.Tensor): Ground truth binary labels (0 or 1).
                Shape: (N,) where N is the number of samples.

        Returns:
            torch.Tensor: AUROC score as a scalar tensor.

        Example:
            .. code-block:: python

                auc_metric = AUROC(num_thresholds=100)

                # Direct computation
                predictions = torch.tensor([0.9, 0.7, 0.3, 0.1])
                labels = torch.tensor([1, 1, 0, 0])

                auc_score = auc_metric(predictions, labels)
                print(f"AUC: {auc_score:.4f}")
        """
        tp, fp, tn, fn = self._confusion_matrix_at_thresholds(predictions, labels)
        return self._compute_auroc(tp, fp, tn, fn)

    def update(self, predictions, labels):
        """Accumulate confusion-matrix statistics from a new batch.

        Args:
            predictions (torch.Tensor): Predicted probabilities in [0, 1].
                Shape: (N,) where N is the number of samples.
            labels (torch.Tensor): Ground truth binary labels (0 or 1).
                Shape: (N,) where N is the number of samples.

        Example:
            .. code-block:: python

                auc_metric = AUROC(num_thresholds=200, dist_sync_on_step=True)

                # Process multiple batches
                for batch in dataloader:
                    preds = model(batch)
                    labels = batch["labels"]

                    # Accumulate statistics
                    auc_metric.update(preds, labels)

                # Get final result
                final_auc = auc_metric.compute()

        Note:
            If dist_sync_on_step is True, the batch statistics are summed
            across all processes (via ``sync()``) before accumulation, so
            every rank holds the global counts. This costs one all-reduce
            per call. When ``torch.distributed`` is not initialized, the
            sync is a no-op.
        """
        tp, fp, tn, fn = self._confusion_matrix_at_thresholds(predictions, labels)

        # Synchronize across distributed processes if required.
        if self.dist_sync_on_step:
            tp, fp, tn, fn = self.sync(tp, fp, tn, fn)

        # Accumulate statistics in place.
        self.tp += tp
        self.fp += fp
        self.tn += tn
        self.fn += fn

    def sync(self, tp, fp, tn, fn):
        """Sum confusion-matrix statistics across distributed processes.

        The four tensors are concatenated so that a single SUM all-reduce is
        issued, and the result is then split back into its parts. When
        ``torch.distributed`` is unavailable or its default process group
        is not initialized (e.g. a single-process run), the inputs are
        returned unchanged instead of raising.

        Args:
            tp (torch.Tensor): True positive counts to synchronize.
            fp (torch.Tensor): False positive counts to synchronize.
            tn (torch.Tensor): True negative counts to synchronize.
            fn (torch.Tensor): False negative counts to synchronize.

        Returns:
            tuple: Synchronized ``(tp, fp, tn, fn)``.
        """
        if not (dist.is_available() and dist.is_initialized()):
            # No process group: the local counts are already the total.
            return tp, fp, tn, fn

        # Concatenate all statistics for a single collective call.
        state = torch.cat([tp, fp, tn, fn], dim=0)
        dist.all_reduce(state, op=dist.ReduceOp.SUM)
        # Split back into individual components.
        tp, fp, tn, fn = state.split(
            [self.tp.numel(), self.fp.numel(), self.tn.numel(), self.fn.numel()],
            dim=0,
        )
        return tp, fp, tn, fn

    def compute(self):
        """Compute the AUROC score from the accumulated statistics.

        Returns:
            torch.Tensor: AUROC score as a scalar tensor.

        Example:
            .. code-block:: python

                auc_metric = AUROC()

                # Accumulate data from multiple batches
                for batch in dataloader:
                    auc_metric.update(model(batch), batch["labels"])

                # Get final AUC score
                final_auc = auc_metric.compute()
                print(f"Epoch AUC: {final_auc:.4f}")

        Note:
            No cross-process communication happens here. With
            ``dist_sync_on_step=False`` the result reflects only this
            process's ``update()`` calls. Call ``reset()`` before starting a
            new evaluation period.
        """
        return self._compute_auroc(self.tp, self.fp, self.tn, self.fn)

    def reset(self):
        """Reset all accumulated statistics to zero.

        Call this at the start of each evaluation period (e.g. each epoch).
        The tp/fp/tn/fn tensors are zeroed in place with ``zero_()``.

        Example:
            .. code-block:: python

                auc_metric = AUROC()

                for epoch in range(num_epochs):
                    # Reset at the beginning of each epoch
                    auc_metric.reset()

                    # Accumulate statistics for the epoch
                    for batch in dataloader:
                        auc_metric.update(model(batch), batch["labels"])

                    # Get epoch result
                    epoch_auc = auc_metric.compute()
                    print(f"Epoch {epoch} AUC: {epoch_auc:.4f}")
        """
        self.tp.zero_()
        self.fp.zero_()
        self.tn.zero_()
        self.fn.zero_()