Source code for recis.nn.functional.fused_ops

from typing import List, Tuple, Union

import torch


__all__ = [
    "fused_bucketize_gpu",
    "fused_uint64_mod_gpu",
    "fused_ids_encode_gpu",
    "fused_multi_hash",
]


def _check_device_all(tensors: List[torch.Tensor], device_type: str) -> None:
    """Checks that all tensors are on the specified device.

    Args:
        tensors (List[torch.Tensor]): List of tensors to check.
        device_type (str): Expected device type (e.g., 'cuda', 'cpu').
    """
    for t in tensors:
        assert t.device.type == device_type, (
            f"tensors must be on {device_type}, but got {t.device.type}"
        )


def _check_dtype_all(tensors: List[torch.Tensor], dtype: torch.dtype) -> None:
    """Checks that all tensors have the specified data type.

    Args:
        tensors (List[torch.Tensor]): List of tensors to check.
        dtype (torch.dtype): Expected data type.
    """
    for t in tensors:
        assert t.dtype == dtype, f"tensors must be {dtype}, but got {t.dtype}"


def fused_bucketize_gpu(
    values: List[torch.Tensor], boundaries: List[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """GPU-accelerated bucketization operation.

    Maps each value in ``values`` to a bucket index based on the corresponding
    ``boundaries``.

    Args:
        values (List[torch.Tensor]): List of input tensors containing float
            values to be bucketized. Must be on CUDA.
        boundaries (List[torch.Tensor]): List of boundary tensors defining the
            buckets. Each tensor must be sorted and on CUDA.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:

            - **bucket_indices**: Tensor of bucket indices for each value.
            - **offsets**: Auxiliary tensor of offsets for merging buckets.

    Raises:
        AssertionError: If input conditions are not met.

    Example:
        >>> values = [torch.tensor([1.2, 3.5, 0.8], device='cuda'),
        ...           torch.tensor([2.1, 4.3, 1.9], device='cuda')]
        >>> boundaries = [torch.tensor([1.0, 2.0, 3.0], device='cuda'),
        ...               torch.tensor([3.0, 4.0, 5.0], device='cuda')]
        >>> indices, offsets = fused_bucketize_gpu(values, boundaries)
    """
    assert len(values) == len(boundaries), (
        "values and boundaries must have the same length"
    )
    _check_device_all(values, "cuda")
    _check_dtype_all(values, torch.float)
    _check_device_all(boundaries, "cuda")
    _check_dtype_all(boundaries, torch.float)
    return torch.ops.recis.fused_bucketized(values, boundaries)

def fused_uint64_mod_gpu(
    values: List[torch.Tensor], mods: Union[List, torch.Tensor]
) -> torch.Tensor:
    """GPU-accelerated unsigned 64-bit integer modulo operation.

    Args:
        values (List[torch.Tensor]): List of tensors containing int64 values.
            Must be on CUDA.
        mods (Union[List, torch.Tensor]): Modulo values. Can be a list or a
            tensor of int64 values.

    Returns:
        torch.Tensor: Result tensor where each element is ``value % mod``,
        with values interpreted as unsigned 64-bit integers.

    Raises:
        AssertionError: If input conditions are not met.

    Example:
        >>> values = [torch.tensor([10, 20, 30], dtype=torch.int64, device='cuda'),
        ...           torch.tensor([40, 50, 60], dtype=torch.int64, device='cuda')]
        >>> mods = [3, 5]
        >>> result = fused_uint64_mod_gpu(values, mods)
    """
    _check_device_all(values, "cuda")
    _check_dtype_all(values, torch.int64)
    if isinstance(mods, list):
        mods = torch.tensor(mods, dtype=torch.int64, device=values[0].device)
    return torch.ops.recis.fused_uint64_mod(values, mods)

def fused_ids_encode_gpu(
    ids_list: List[torch.Tensor], table_ids: Union[torch.Tensor, list]
):
    """Encodes a list of ID tensors by applying table IDs as an offset.

    Args:
        ids_list (List[torch.Tensor]): List of ID tensors to encode.
        table_ids (Union[torch.Tensor, list]): Table IDs used for encoding;
            can be a list or a tensor.

    Returns:
        torch.Tensor: Encoded ID tensor.

    Raises:
        AssertionError: If ``ids_list`` is not a list or if the tensors in
            ``ids_list`` are not all on the same device.

    Example:
        >>> ids_list = [torch.tensor([1, 2]), torch.tensor([3, 4])]
        >>> table_ids = [0, 1]
        >>> encoded_ids = fused_ids_encode_gpu(ids_list, table_ids)
    """
    assert isinstance(ids_list, list), "ids_list must be a list"
    for ids in ids_list:
        assert isinstance(ids, torch.Tensor), "ids must be a tensor"
        assert ids.device == ids_list[0].device, (
            f"ids must be on the same device, {ids.device} != {ids_list[0].device}"
        )
    if isinstance(table_ids, list):
        table_ids = torch.tensor(
            table_ids, dtype=torch.int64, device=ids_list[0].device
        )
    return torch.ops.recis.ids_encode(ids_list, table_ids)

def fused_multi_hash(
    inputs: List[torch.Tensor],
    muls: List[torch.Tensor],
    primes: List[torch.Tensor],
    bucket_nums: List[torch.Tensor],
) -> List[torch.Tensor]:
    """Fused multi-hash operation.

    Applies the hash transformations specified by ``muls``, ``primes``, and
    ``bucket_nums`` to the corresponding tensors in ``inputs`` within a single
    fused operator call.

    Args:
        inputs (List[torch.Tensor]): List of int64 input tensors to hash.
            All tensors must be on the same device.
        muls (List[torch.Tensor]): List of int64 multiplier tensors, one per input.
        primes (List[torch.Tensor]): List of int64 prime tensors, one per input.
        bucket_nums (List[torch.Tensor]): List of int64 bucket-size tensors,
            one per input.

    Returns:
        List[torch.Tensor]: Hashed output tensors.

    Raises:
        AssertionError: If the input lists differ in length, are empty, or
            contain tensors with an unexpected device or dtype.
    """
    assert len(inputs) == len(muls) == len(primes) == len(bucket_nums)
    assert len(inputs) > 0
    device = inputs[0].device
    _check_device_all(inputs, device.type)
    _check_dtype_all(inputs, torch.int64)
    _check_dtype_all(muls, torch.int64)
    _check_dtype_all(primes, torch.int64)
    _check_dtype_all(bucket_nums, torch.int64)
    return torch.ops.recis.fused_multi_hash(inputs, muls, primes, bucket_nums)
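

# Illustrative usage sketch for ``fused_multi_hash`` (not part of the module):
# it shows how the four argument lists line up, one parameter tensor per input
# tensor. The concrete multiplier, prime, and bucket values below are
# assumptions chosen for demonstration only, and a CUDA device is assumed.
#
#     inputs = [torch.tensor([11, 42, 7], dtype=torch.int64, device="cuda")]
#     muls = [torch.tensor([2654435761], dtype=torch.int64, device="cuda")]
#     primes = [torch.tensor([(1 << 31) - 1], dtype=torch.int64, device="cuda")]
#     bucket_nums = [torch.tensor([1000], dtype=torch.int64, device="cuda")]
#     hashed = fused_multi_hash(inputs, muls, primes, bucket_nums)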