Source code for recis.metrics.gauc

import os

import torch
import torch.distributed as dist


class Gauc(torch.nn.Module):
    """Group AUC (GAUC) metric for recommendation systems and personalized ML tasks.

    This class computes the Group Area Under the ROC Curve, which evaluates model
    performance by calculating AUC separately for each group (typically users) and
    then aggregating the results. This provides a more accurate assessment of
    recommendation system performance than global AUC metrics.

    GAUC is essential for recommendation systems because:

    - Different users have different behavior patterns and preferences
    - Global AUC can be dominated by users with many interactions
    - User-level evaluation provides better insights into model fairness
    - It is more aligned with business objectives in personalized systems

    The implementation uses optimized C++ operations for efficient computation and
    supports the distributed training scenarios commonly found in large-scale
    recommendation systems.

    Attributes:
        _counts (float): Cumulative count of valid groups processed.
        _aucs (float): Cumulative weighted sum of AUC scores.
        _cpu (torch.device): CPU device for computation efficiency.
        _word_size (int): Number of distributed processes (world size).

    Example:
        Using GAUC in a recommendation model:

        .. code-block:: python

            # Initialize GAUC metric
            gauc_metric = Gauc()

            # Prepare recommendation data
            labels = torch.tensor([1, 0, 1, 0, 1, 0])  # Click labels
            predictions = torch.tensor([0.9, 0.1, 0.8, 0.2, 0.7, 0.3])  # CTR predictions
            user_ids = torch.tensor([1, 1, 2, 2, 3, 3])  # User identifiers

            # Compute GAUC
            batch_gauc, cumulative_gauc = gauc_metric(labels, predictions, user_ids)
            print(f"Batch GAUC: {batch_gauc:.4f}")
            print(f"Cumulative GAUC: {cumulative_gauc:.4f}")

            # Continue with more batches...
            # The cumulative GAUC will be updated automatically

        Integration with a training loop:

        .. code-block:: python

            model = RecommendationModel()
            gauc_metric = Gauc()

            for epoch in range(num_epochs):
                gauc_metric.reset()  # Reset for the new epoch

                for batch_idx, batch in enumerate(train_dataloader):
                    # Forward pass
                    logits = model(batch)
                    predictions = torch.sigmoid(logits)

                    # Compute GAUC
                    batch_gauc, epoch_gauc = gauc_metric(
                        batch["labels"], predictions, batch["user_ids"]
                    )

                    # Log metrics
                    if batch_idx % 100 == 0:
                        print(
                            f"Epoch {epoch}, Batch {batch_idx}: "
                            f"Batch GAUC = {batch_gauc:.4f}, "
                            f"Epoch GAUC = {epoch_gauc:.4f}"
                        )

                print(f"Final Epoch {epoch} GAUC: {epoch_gauc:.4f}")
    """
    def __init__(self) -> None:
        """Initialize GAUC metric with default configuration.

        The metric automatically detects the distributed training environment and
        configures itself accordingly. It uses CPU computation for the GAUC
        calculation to optimize memory usage and computation efficiency.

        Note:
            The metric automatically reads the WORLD_SIZE environment variable to
            determine whether it is running in distributed mode. In distributed
            training, it will aggregate results across all processes.
        """
        super().__init__()
        self._counts = 0.0
        self._aucs = 0.0
        self._cpu = torch.device("cpu")
        self._word_size = int(os.environ.get("WORLD_SIZE", 1))
    def reset(self):
        """Reset all accumulated statistics to zero.

        This method clears all internal state, resetting both the cumulative AUC
        sum and the count statistics. It should be called at the beginning of each
        new evaluation period (e.g., a new epoch) to ensure clean metrics.

        Example:
            .. code-block:: python

                gauc_metric = Gauc()

                for epoch in range(num_epochs):
                    # Reset at the beginning of each epoch
                    gauc_metric.reset()

                    # Process batches for the epoch
                    for batch in dataloader:
                        batch_gauc, epoch_gauc = gauc_metric(
                            batch["labels"], batch["predictions"], batch["user_ids"]
                        )

                    print(f"Final Epoch {epoch} GAUC: {epoch_gauc:.4f}")
        """
        self._counts = 0.0
        self._aucs = 0.0
    def forward(self, labels, predictions, indicators):
        """Compute GAUC for the current batch and update cumulative statistics.

        This method computes the Group AUC by calculating AUC separately for each
        group identified by the indicators (typically user IDs) and then computing
        a weighted average. It returns both the current batch GAUC and the
        cumulative GAUC across all processed batches.

        Args:
            labels (torch.Tensor): Ground truth binary labels (0 or 1).
                Shape: (N,) where N is the number of samples.
            predictions (torch.Tensor): Predicted probabilities or scores.
                Shape: (N,) where N is the number of samples.
            indicators (torch.Tensor): Group identifiers (e.g., user IDs).
                Shape: (N,) where N is the number of samples.

        Returns:
            tuple: A tuple containing:

                - batch_gauc (float): GAUC score for the current batch.
                - cumulative_gauc (float): Cumulative GAUC across all processed
                  batches.

        Example:
            .. code-block:: python

                gauc_metric = Gauc()

                # Single batch computation
                labels = torch.tensor([1, 0, 1, 0, 1, 0])
                predictions = torch.tensor([0.9, 0.1, 0.8, 0.2, 0.7, 0.3])
                user_ids = torch.tensor([1, 1, 2, 2, 3, 3])

                batch_gauc, cumulative_gauc = gauc_metric(labels, predictions, user_ids)
                print(f"Batch GAUC: {batch_gauc:.4f}")
                print(f"Cumulative GAUC: {cumulative_gauc:.4f}")

                # Process another batch
                labels2 = torch.tensor([0, 1, 1, 0])
                predictions2 = torch.tensor([0.2, 0.8, 0.9, 0.1])
                user_ids2 = torch.tensor([4, 4, 5, 5])

                batch_gauc2, cumulative_gauc2 = gauc_metric(
                    labels2, predictions2, user_ids2
                )
                print(f"Batch 2 GAUC: {batch_gauc2:.4f}")
                print(f"Updated Cumulative GAUC: {cumulative_gauc2:.4f}")

        Note:
            The method automatically handles distributed training by aggregating
            results across all processes when WORLD_SIZE > 1. The computation is
            performed on CPU for memory efficiency, with device transfer handled
            internally.

            The GAUC calculation weights each group's AUC by the number of samples
            in that group, providing a fair aggregation across groups of different
            sizes.
        """
        with torch.no_grad():
            # Per-group AUC scores and counts from the optimized C++ op (on CPU).
            aucs, counts = torch.ops.recis.gauc_calc(
                labels.to(self._cpu),
                predictions.to(self._cpu),
                indicators.to(self._cpu),
            )
            # Weight each group's AUC by its count before summing.
            aucs = aucs * counts
            sum_aucs = torch.sum(aucs)
            sum_counts = torch.sum(counts)
            if self._word_size != 1:
                # Aggregate the weighted AUC sum and the count across all processes.
                reduce_val = torch.stack([sum_aucs, sum_counts]).to(labels.device)
                dist.all_reduce(reduce_val)
            else:
                reduce_val = torch.stack([sum_aucs, sum_counts])
            split_sum_auc, split_sum_count = torch.split(
                reduce_val, split_size_or_sections=1
            )
            split_sum_auc_val = split_sum_auc.item()
            split_sum_count_val = split_sum_count.item()
            # Accumulate for the cumulative (e.g., epoch-level) GAUC.
            self._counts += split_sum_count_val
            self._aucs += split_sum_auc_val
            batch_gauc = split_sum_auc_val / max(split_sum_count_val, 1.0)
            cumulative_gauc = self._aucs / max(self._counts, 1.0)
            return batch_gauc, cumulative_gauc
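
# ---------------------------------------------------------------------------
# Reference sketch (documentation aid; NOT part of the recis library).
#
# The per-group AUC itself is computed inside the C++ kernel
# ``torch.ops.recis.gauc_calc`` and is therefore not visible in this file.
# The hypothetical helper ``gauc_reference`` below shows one way the weighted
# aggregation described in the docstrings above could look in pure PyTorch,
# assuming that groups containing only one label class are skipped, that ties
# in the predictions are not mid-ranked, and that each group's AUC is weighted
# by its sample count. These assumptions follow the docstrings, not the kernel
# implementation.
# ---------------------------------------------------------------------------
def gauc_reference(labels, predictions, indicators):
    """Weighted per-group AUC; a minimal sketch mirroring ``Gauc.forward``."""
    total_auc, total_count = 0.0, 0.0
    for gid in indicators.unique():
        mask = indicators == gid
        y = labels[mask].float()
        p = predictions[mask].float()
        n_pos, n_neg = int(y.sum()), int((1 - y).sum())
        if n_pos == 0 or n_neg == 0:
            continue  # AUC is undefined for single-class groups (assumption)
        # Per-group AUC via the rank-sum (Mann-Whitney U) formulation.
        order = p.argsort()
        ranks = torch.empty_like(p)
        ranks[order] = torch.arange(1, len(p) + 1, dtype=p.dtype)
        auc = (ranks[y == 1].sum() - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
        weight = float(mask.sum())  # weight by the group's sample count
        total_auc += float(auc) * weight
        total_count += weight
    return total_auc / max(total_count, 1.0)


if __name__ == "__main__":
    # Tiny usage example for the reference sketch above.
    labels = torch.tensor([1, 0, 1, 0, 1, 0])
    predictions = torch.tensor([0.9, 0.1, 0.8, 0.2, 0.7, 0.3])
    user_ids = torch.tensor([1, 1, 2, 2, 3, 3])
    print(f"Reference GAUC: {gauc_reference(labels, predictions, user_ids):.4f}")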