from typing import List, Tuple, Union
import torch
__ALL__ = ["fused_bucketize_gpu", "fused_uint64_mod_gpu", "fused_multi_hash"]
def _check_device_all(tensors: List[torch.Tensor], device_type: str) -> None:
"""Checks that all tensors are on the specified device.
Args:
tensors (List[torch.Tensor]): List of tensors to check.
device_type (str): Expected device type (e.g., 'cuda', 'cpu').
"""
for t in tensors:
assert t.device.type == device_type, (
f"tensors must be on {device_type}, but got {t.device.type}"
)
def _check_dtype_all(tensors: List[torch.Tensor], dtype: torch.dtype) -> None:
"""Checks that all tensors have the specified data type.
Args:
tensors (List[torch.Tensor]): List of tensors to check.
dtype (torch.dtype): Expected data type.
"""
for t in tensors:
assert t.dtype == dtype, f"tensors must be {dtype}, but got {t.dtype}"
[docs]
def fused_bucketize_gpu(
values: List[torch.Tensor], boundaries: List[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""GPU-accelerated bucketization operation. Maps each value in `values` to a bucket index based on the corresponding `boundaries`.
Args:
values (List[torch.Tensor]): List of input tensors containing float values to be bucketized. Must be on CUDA.
boundaries (List[torch.Tensor]): List of boundary tensors for bucket definitions. Each tensor must be sorted and on CUDA.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- **bucket_indices**: Tensor of bucket indices for each value.
- **offsets**: Auxiliary tensor representing offsets for merging buckets.
Raises:
AssertionError: If input conditions are not met.
Example:
>>> values = [torch.tensor([1.2, 3.5, 0.8], device='cuda'),
>>> torch.tensor([2.1, 4.3, 1.9], device='cuda')]
>>> boundaries = [torch.tensor([1.0, 2.0, 3.0], device='cuda'),
>>> torch.tensor([3.0, 4.0, 5.0], device='cuda')]
>>> indices, offsets = fused_bucketize_gpu(values, boundaries)
"""
assert len(values) == len(boundaries), (
"values and boundaries must have the same length"
)
_check_device_all(values, "cuda")
_check_dtype_all(values, torch.float)
_check_dtype_all(boundaries, torch.float)
_check_device_all(boundaries, "cuda")
return torch.ops.recis.fused_bucketized(values, boundaries)
[docs]
def fused_uint64_mod_gpu(
values: List[torch.Tensor], mods: Union[List, torch.Tensor]
) -> torch.Tensor:
"""GPU-accelerated unsigned 64-bit integer modulo operation.
Args:
values (List[torch.Tensor]): List of tensors containing int64 values. Must be on CUDA.
mods (Union[List, torch.Tensor]): Modulo values. Can be a list or tensor of int64 values.
Returns:
torch.Tensor: Result tensor where each element is `(value % mod)` using unsigned interpretation.
Raises:
AssertionError: If input conditions are not met.
Example:
>>> values = [torch.tensor([10, 20, 30], dtype=torch.int64, device='cuda'),
>>> torch.tensor([40, 50, 60], dtype=torch.int64, device='cuda')]
>>> mods = [3, 5]
>>> result = fused_uint64_mod_gpu(values, mods)
"""
_check_device_all(values, "cuda")
_check_dtype_all(values, torch.int64)
if isinstance(mods, list):
mods = torch.tensor(mods, dtype=torch.int64, device=values[0].device)
return torch.ops.recis.fused_uint64_mod(values, mods)
[docs]
def fused_ids_encode_gpu(
ids_list: List[torch.Tensor], table_ids: Union[torch.Tensor, list]
):
"""Encodes a list of ID tensors by applying table IDs as an offset.
Args:
ids_list (List[torch.Tensor]): List of ID tensors to encode.
table_ids (Union[torch.Tensor, list]): Table IDs used for encoding; can be a list or tensor.
Returns:
torch.Tensor: Encoded ID tensor.
Raises:
AssertionError: If `ids_list` is not a list or if tensors in `ids_list` are not on the same device.
Example:
>>> ids_list = [torch.tensor([1, 2]), torch.tensor([3, 4])]
>>> table_ids = [0, 1]
>>> encoded_ids = ids_encode(ids_list, table_ids)
"""
assert isinstance(ids_list, list), "ids_list must be a list"
for ids in ids_list:
assert isinstance(ids, torch.Tensor), "ids must be a tensor"
assert ids.device == ids_list[0].device, (
f"ids must be on the same device, {ids.device} != {ids_list[0].device}"
)
if isinstance(table_ids, list):
table_ids = torch.tensor(
table_ids, dtype=torch.int64, device=ids_list[0].device
)
return torch.ops.recis.ids_encode(ids_list, table_ids)
[docs]
def fused_multi_hash(
inputs: List[torch.Tensor],
muls: List[torch.Tensor],
primes: List[torch.Tensor],
bucket_nums: List[torch.Tensor],
) -> List[torch.Tensor]:
"""
Fused multi hash.
"""
assert len(inputs) == len(muls) == len(primes) == len(bucket_nums)
assert len(inputs) > 0
device = inputs[0].device
_check_device_all(inputs, device.type)
_check_dtype_all(inputs, torch.int64)
_check_dtype_all(muls, torch.int64)
_check_dtype_all(primes, torch.int64)
_check_dtype_all(bucket_nums, torch.int64)
return torch.ops.recis.fused_multi_hash(inputs, muls, primes, bucket_nums)