import os
import torch
import torch.distributed as dist
[docs]
class Gauc(torch.nn.Module):
    """Group AUC (GAUC) metric for recommendation systems and personalized ML tasks.
    This class computes the Group Area Under ROC Curve, which evaluates model
    performance by calculating AUC separately for each group (typically users)
    and then aggregating the results. This provides a more accurate assessment
    of recommendation system performance compared to global AUC metrics.
    GAUC is essential for recommendation systems because:
    - Different users have different behavior patterns and preferences
    - Global AUC can be dominated by users with many interactions
    - User-level evaluation provides better insights into model fairness
    - It's more aligned with business objectives in personalized systems
    The implementation uses optimized C++ operations for efficient computation
    and supports distributed training scenarios commonly found in large-scale
    recommendation systems.
    Attributes:
        _counts (float): Cumulative count of valid groups processed.
        _aucs (float): Cumulative weighted sum of AUC scores.
        _cpu (torch.device): CPU device for computation efficiency.
        _word_size (int): Number of distributed processes (world size).
    Example:
        Using GAUC in a recommendation model:
    .. code-block:: python
        # Initialize GAUC metric
        gauc_metric = Gauc()
        # Prepare recommendation data
        labels = torch.tensor([1, 0, 1, 0, 1, 0])  # Click labels
        predictions = torch.tensor([0.9, 0.1, 0.8, 0.2, 0.7, 0.3])  # CTR predictions
        user_ids = torch.tensor([1, 1, 2, 2, 3, 3])  # User identifiers
        # Compute GAUC
        batch_gauc, cumulative_gauc = gauc_metric(labels, predictions, user_ids)
        print(f"Batch GAUC: {batch_gauc:.4f}")
        print(f"Cumulative GAUC: {cumulative_gauc:.4f}")
        # Continue with more batches...
        # The cumulative GAUC will be updated automatically
        Integration with training loop:
    .. code-block:: python
        model = RecommendationModel()
        gauc_metric = Gauc()
        for epoch in range(num_epochs):
            gauc_metric.reset()  # Reset for new epoch
            for batch in train_dataloader:
                # Forward pass
                logits = model(batch)
                predictions = torch.sigmoid(logits)
                # Compute GAUC
                batch_gauc, epoch_gauc = gauc_metric(
                    batch["labels"], predictions, batch["user_ids"]
                )
                # Log metrics
                if batch_idx % 100 == 0:
                    print(
                        f"Epoch {epoch}, Batch {batch_idx}: "
                        f"Batch GAUC = {batch_gauc:.4f}, "
                        f"Epoch GAUC = {epoch_gauc:.4f}"
                    )
            print(f"Final Epoch {epoch} GAUC: {epoch_gauc:.4f}")
    """
[docs]
    def __init__(self) -> None:
        """Initialize GAUC metric with default configuration.
        The metric automatically detects the distributed training environment
        and configures itself accordingly. It uses CPU computation for the
        GAUC calculation to optimize memory usage and computation efficiency.
        Note:
            The metric automatically reads the WORLD_SIZE environment variable
            to determine if running in distributed mode. In distributed training,
            it will aggregate results across all processes.
        """
        super().__init__()
        self._counts = 0.0
        self._aucs = 0.0
        self._cpu = torch.device("cpu")
        self._word_size = int(os.environ.get("WORLD_SIZE", 1)) 
[docs]
    def reset(self):
        """Reset all accumulated statistics to zero.
        This method clears all internal state, resetting both the cumulative
        AUC sum and count statistics. It should be called at the beginning
        of each new evaluation period (e.g., new epoch) to ensure clean metrics.
        Example:
        .. code-block:: python
            gauc_metric = Gauc()
            for epoch in range(num_epochs):
                # Reset at the beginning of each epoch
                gauc_metric.reset()
                # Process batches for the epoch
                for batch in dataloader:
                    batch_gauc, epoch_gauc = gauc_metric(
                        batch["labels"], batch["predictions"], batch["user_ids"]
                    )
                print(f"Final Epoch {epoch} GAUC: {epoch_gauc:.4f}")
        """
        self._counts = 0.0
        self._aucs = 0.0 
[docs]
    def forward(self, labels, predictions, indicators):
        """Compute GAUC for the current batch and update cumulative statistics.
        This method computes the Group AUC by calculating AUC separately for each
        group identified by the indicators (typically user IDs) and then computing
        a weighted average. It returns both the current batch GAUC and the
        cumulative GAUC across all processed batches.
        Args:
            labels (torch.Tensor): Ground truth binary labels (0 or 1).
                Shape: (N,) where N is the number of samples.
            predictions (torch.Tensor): Predicted probabilities or scores.
                Shape: (N,) where N is the number of samples.
            indicators (torch.Tensor): Group identifiers (e.g., user IDs).
                Shape: (N,) where N is the number of samples.
        Returns:
            tuple: A tuple containing:
                - batch_gauc (float): GAUC score for the current batch
                - cumulative_gauc (float): Cumulative GAUC across all processed batches
        Example:
        .. code-block:: python
            gauc_metric = Gauc()
            # Single batch computation
            labels = torch.tensor([1, 0, 1, 0, 1, 0])
            predictions = torch.tensor([0.9, 0.1, 0.8, 0.2, 0.7, 0.3])
            user_ids = torch.tensor([1, 1, 2, 2, 3, 3])
            batch_gauc, cumulative_gauc = gauc_metric(labels, predictions, user_ids)
            print(f"Batch GAUC: {batch_gauc:.4f}")
            print(f"Cumulative GAUC: {cumulative_gauc:.4f}")
            # Process another batch
            labels2 = torch.tensor([0, 1, 1, 0])
            predictions2 = torch.tensor([0.2, 0.8, 0.9, 0.1])
            user_ids2 = torch.tensor([4, 4, 5, 5])
            batch_gauc2, cumulative_gauc2 = gauc_metric(
                labels2, predictions2, user_ids2
            )
            print(f"Batch 2 GAUC: {batch_gauc2:.4f}")
            print(f"Updated Cumulative GAUC: {cumulative_gauc2:.4f}")
        Note:
            The method automatically handles distributed training by aggregating
            results across all processes when WORLD_SIZE > 1. The computation
            is performed on CPU for memory efficiency, with automatic device
            transfer handled internally.
            The GAUC calculation weights each group's AUC by the number of
            samples in that group, providing a fair aggregation across groups
            of different sizes.
        """
        with torch.no_grad():
            aucs, counts = torch.ops.recis.gauc_calc(
                labels.to(self._cpu),
                predictions.to(self._cpu),
                indicators.to(self._cpu),
            )
            aucs = aucs * counts
            sum_aucs = torch.sum(aucs)
            sum_counts = torch.sum(counts)
            if self._word_size != 1:
                reduce_val = torch.stack([sum_aucs, sum_counts]).to(labels.device)
                dist.all_reduce(reduce_val)
            else:
                reduce_val = torch.stack([sum_aucs, sum_counts])
            split_sum_auc, split_sum_count = torch.split(
                reduce_val, split_size_or_sections=1
            )
            split_sum_auc_val = split_sum_auc.item()
            split_sum_count_val = split_sum_count.item()
            self._counts += split_sum_count_val
            self._aucs += split_sum_auc_val
            batch_gauc = split_sum_auc_val / max(split_sum_count_val, 1.0)
            cumulative_gauc = self._aucs / max(self._counts, 1.0)
            return batch_gauc, cumulative_gauc