import json
import os
from typing import List, Optional, Tuple
import torch
from recis.common.singleton import SingletonMeta
from recis.nn.hashtable_hook import AdmitHook, FilterHook
from recis.nn.initializers import ConstantInitializer
from recis.nn.modules.hashtable_hook_impl import HashtableHookFactory, ReadOnlyHookImpl
from recis.utils.logger import Logger
logger = Logger(__name__)
class Slice:
"""Partitioning configuration for distributed hash table storage.
This class defines how the hash table's key space is partitioned
across different workers in a distributed setting.
Attributes:
slice_begin (int): Starting index of the slice.
slice_end (int): Ending index of the slice (exclusive).
slice_size (int): Total size of the key space.
Example:
.. code-block:: python
# Create a slice for worker 0 out of 4 workers
slice_config = Slice(0, 16384, 65536)
"""
def __init__(self, slice_beg, slice_end, slice_size) -> None:
"""Initialize slice configuration.
Args:
slice_beg (int): Starting index of the slice.
slice_end (int): Ending index of the slice (exclusive).
slice_size (int): Total size of the key space.
"""
self.slice_begin = slice_beg
self.slice_end = slice_end
self.slice_size = slice_size
def gen_slice(shard_index=0, shard_num=1, slice_size=65536):
"""Generate slice configuration for distributed hash table partitioning.
This function creates a Slice object that defines how to partition
the hash table's key space across multiple workers. It ensures
balanced distribution with proper handling of remainder keys.
Args:
shard_index (int, optional): Index of the current shard/worker.
Defaults to 0.
shard_num (int, optional): Total number of shards/workers.
Defaults to 1.
slice_size (int, optional): Total size of the key space.
Defaults to 65536.
Returns:
Slice: Slice configuration for the specified shard.
Example:
.. code-block:: python
# Generate slice for worker 1 out of 4 workers
slice_config = gen_slice(shard_index=1, shard_num=4, slice_size=65536)
print(
f"Worker 1 handles keys from {slice_config.slice_begin} "
f"to {slice_config.slice_end}"
)
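The remainder of ``slice_size % shard_num`` is spread one key at a
time over the first shards. A quick worked check:
.. code-block:: python
# slice_size=10, shard_num=3 -> shard sizes [4, 3, 3]
s0 = gen_slice(0, 3, 10)  # keys [0, 4)
s1 = gen_slice(1, 3, 10)  # keys [4, 7)
s2 = gen_slice(2, 3, 10)  # keys [7, 10)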
"""
shard_slice_size = slice_size // shard_num
shard_slice_sizes = [shard_slice_size] * shard_num
remain = slice_size % shard_num
shard_slice_sizes = [
size + 1 if i < remain else size for i, size in enumerate(shard_slice_sizes)
]
slice_infos = []
beg = 0
for size in shard_slice_sizes:
end = beg + size
slice_infos.append((beg, end))
beg = end
slice_info = slice_infos[shard_index]
return Slice(slice_info[0], slice_info[1], slice_size)
_default_slice = Slice(0, 65536, 65536)
class HashTable(torch.nn.Module):
"""Distributed hash table for sparse parameter storage and lookup.
This module provides a distributed hash table implementation that supports
dynamic sparse parameter storage, efficient lookup operations, and gradient
computation. It's designed for large-scale sparse learning scenarios where
the feature vocabulary can grow dynamically.
Key features:
- Dynamic feature admission and eviction
- Distributed storage across multiple workers
- Efficient gradient computation and aggregation
- Support for various initialization strategies
- Hook-based filtering and admission control
Example:
Basic usage:
.. code-block:: python
import torch
from recis.nn.modules.hashtable import HashTable
# Create hash table
hashtable = HashTable(
embedding_shape=[64],
block_size=1024,
dtype=torch.float32,
device=torch.device("cuda"),
name="user_embedding",
)
# Lookup embeddings
ids = torch.tensor([1, 2, 3, 100, 1000])
embeddings = hashtable(ids) # Shape: [5, 64]
Advanced usage with hooks:
.. code-block:: python
from recis.nn.hashtable_hook import FrequencyFilterHook
# Create hash table with filtering
filter_hook = FrequencyFilterHook(min_frequency=5)
hashtable = HashTable(
embedding_shape=[128],
block_size=2048,
filter_hook=filter_hook,
grad_reduce_by="id",
)
"""
def __init__(
self,
embedding_shape: List,
block_size: int = 5,
dtype: torch.dtype = torch.float32,
device: torch.device = torch.device("cpu"),
coalesced: bool = False,
children: Optional[List[str]] = None,
slice: Slice = _default_slice,
initializer=None,
name: str = "hashtable",
grad_reduce_by: str = "worker",
filter_hook: Optional[FilterHook] = None,
):
"""Initialize hash table module.
Args:
embedding_shape (List[int]): Shape of embedding vectors.
block_size (int, optional): Number of embeddings per block. Defaults to 5.
dtype (torch.dtype, optional): Data type. Defaults to torch.float32.
device (torch.device, optional): Computation device. Defaults to CPU.
coalesced (bool, optional): Use coalesced operations. Defaults to False.
children (Optional[List[str]], optional): Child table names. Defaults to None.
slice (Slice, optional): Partitioning config. Defaults to _default_slice.
initializer (Initializer, optional): Initializer. Defaults to None.
name (str, optional): Table name. Defaults to "hashtable".
grad_reduce_by (str, optional): Gradient reduction strategy, either
"id" (average gradients of duplicate IDs) or "worker" (scale by
1/world_size and sum across duplicates). Defaults to "worker".
filter_hook (Optional[FilterHook], optional): Filter hook. Defaults to None.
Raises:
AssertionError: If grad_reduce_by is neither "id" nor "worker".
"""
super().__init__()
if initializer is None:
self._initializer = ConstantInitializer(init_val=0)
else:
self._initializer = initializer
if children is None:
children = [name]
self._device = device
assert grad_reduce_by in ["id", "worker"]
self._grad_reduce_by = grad_reduce_by
self._initializer.set_shape([block_size] + embedding_shape)
self._initializer.set_dtype(dtype)
self._initializer.build()
self._dtype = dtype
self._name = name
for child in children:
info_str = json.dumps(
dict(
shape=embedding_shape,
dtype=str(dtype),
initializer=str(self._initializer),
)
)
HashtableRegister().register(child, info_str)
self._hashtable_impl = torch.ops.recis.make_hashtable(
block_size,
embedding_shape,
dtype,
device,
coalesced,
children,
self._initializer.impl(),
slice.slice_begin,
slice.slice_end,
slice.slice_size,
)
self._backward_holder = torch.tensor([0.0], requires_grad=True)
self._worker_num = int(os.environ.get("WORLD_SIZE", 1))
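# Expose the raw hashtable implementation in the state dict under the
# table name so checkpointing can handle sparse storage specially.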
def state_dict_hook(
self: HashTable, state_dict: dict, prefix: str, local_metadata
):
state_dict[self._name] = self._hashtable_impl
self._register_state_dict_hook(state_dict_hook)
# TODO (sunhechen.shc) support more filter hook
if filter_hook is not None:
self._filter_hook_impl = HashtableHookFactory().create_filter_hook(
self, filter_hook
)
else:
self._filter_hook_impl = torch.nn.Identity()
def forward(self, ids: torch.Tensor, admit_hook: Optional[AdmitHook] = None) -> torch.Tensor:
"""Perform embedding lookup for given feature IDs.
This method looks up embeddings for the provided feature IDs,
handling deduplication, gradient computation, and optional
feature admission control.
Args:
ids (torch.Tensor): Feature IDs to lookup. Shape: [N] where N
is the number of features.
admit_hook (AdmitHook, optional): Hook for controlling feature
admission. Defaults to None.
Returns:
torch.Tensor: Looked up embeddings. Shape: [N, embedding_dim]
where embedding_dim is determined by embedding_shape.
Example:
.. code-block:: python
# Basic lookup
ids = torch.tensor([1, 2, 3, 2, 1]) # Note: duplicates
embeddings = hashtable(ids) # Shape: [5, embedding_dim]
# With admission hook
from recis.nn.hashtable_hook import FrequencyAdmitHook
admit_hook = FrequencyAdmitHook(min_frequency=3)
embeddings = hashtable(ids, admit_hook)
"""
admit_hook_impl = (
HashtableHookFactory().create_admit_hook(self, admit_hook)
if admit_hook
else None
)
ids, index = ids.unique(return_inverse=True)
if self.training and self._dtype not in (torch.int8, torch.int32, torch.int64):
emb_idx, embedding = HashTableLookupHelpFunction.apply(
ids, self._hashtable_impl, self._backward_holder, admit_hook_impl
)
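# Map deduplicated embeddings back to the original id order while
# choosing the gradient reduction: "id" averages gradients of
# duplicate ids, "worker" scales by 1/worker_num and sums them.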
if self._grad_reduce_by == "id":
embedding = GradIDMeanFunction.apply(embedding, index)
else:
slice_num = torch.scalar_tensor(self._worker_num)
embedding = GradWorkerMeanFunction.apply(embedding, index, slice_num)
self._filter_hook_impl(emb_idx)
else:
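# Inference path: plain lookup without autograd bookkeeping; the
# result is moved to CUDA before being gathered back to input order.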
ids = ids.detach()
_, embedding = self._hashtable_impl.embedding_lookup(ids, True)
embedding = embedding.cuda()
embedding = torch.ops.recis.gather(index, embedding)
return embedding
def initializer(self):
"""Get the embedding initializer.
Returns:
Initializer: The initializer used for new embeddings.
"""
return self._initializer
@property
def device(self):
"""Get the computation device.
Returns:
torch.device: The device used for computation.
"""
return self._device
@property
def coalesce(self):
"""Check if coalesced operations are enabled.
Returns:
bool: True if coalesced operations are enabled.
"""
return self._hashtable_impl.children_info().is_coalesce()
@property
def children_hashtable(self):
"""Get the list of child hash tables.
Returns:
List[str]: Names of child hash tables.
"""
return self._hashtable_impl.children_info().children()
def accept_grad(self, grad_index, grad) -> None:
"""Accept gradients for specific embedding indices.
Args:
grad_index (torch.Tensor): Indices of embeddings to update.
grad (torch.Tensor): Gradient values for the embeddings.
"""
self._hashtable_impl.accept_grad(grad_index, grad)
def grad(self, acc_step=1) -> torch.Tensor:
"""Get accumulated gradients.
Args:
acc_step (int, optional): Accumulation step. Defaults to 1.
Returns:
torch.Tensor: Accumulated gradients.
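Example:
A minimal sketch, assuming gradients were accumulated over two
backward passes through this table:
.. code-block:: python
grads = hashtable.grad(acc_step=2)
hashtable.clear_grad()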
"""
return self._hashtable_impl.grad(acc_step)
def clear_grad(self) -> None:
"""Clear accumulated gradients."""
self._hashtable_impl.clear_grad()
def insert(self, ids, embeddings) -> None:
"""Insert embeddings for specific IDs.
Args:
ids (torch.Tensor): Feature IDs to insert.
embeddings (torch.Tensor): Embedding values to insert.
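Example:
A minimal sketch, assuming an embedding dim of 64 to match this
table's embedding_shape:
.. code-block:: python
ids = torch.tensor([10, 20, 30])
values = torch.zeros(3, 64)
hashtable.insert(ids, values)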
"""
self._hashtable_impl.insert(ids, embeddings)
def clear(self) -> None:
"""Clear all stored embeddings."""
self._hashtable_impl.clear()
def ids(self) -> torch.Tensor:
"""Get all stored feature IDs.
Returns:
torch.Tensor: All feature IDs currently stored in the table.
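Example:
.. code-block:: python
# Count how many distinct feature IDs are currently stored.
num_ids = hashtable.ids().numel()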
"""
return self._hashtable_impl.ids()
def ids_map(self) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get mapping between feature IDs and internal indices.
Returns:
Tuple[torch.Tensor, torch.Tensor]: (feature_ids, internal_indices)
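Example:
A minimal sketch (assuming the two returned tensors are aligned
element-wise):
.. code-block:: python
ids, indices = hashtable.ids_map()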
"""
return self._hashtable_impl.ids_map()
def embeddings(self) -> torch.Tensor:
"""Get all stored embeddings.
Returns:
torch.Tensor: All embedding values currently stored in the table.
"""
return self._hashtable_impl.slot_group().slot_by_name("embedding").value()
def slot_group(self):
"""Get the slot group for advanced operations.
Returns:
SlotGroup: The slot group containing all storage slots.
"""
return self._hashtable_impl.slot_group()
def children_info(self):
"""Get information about child hash tables.
Returns:
ChildrenInfo: Information about child hash tables.
"""
return self._hashtable_impl.children_info()
def __str__(self) -> str:
"""String representation of the hash table.
Returns:
str: String representation including the table name.
"""
return f"HashTable_{self._name}"
def __repr__(self) -> str:
"""Detailed string representation of the hash table.
Returns:
str: Detailed string representation.
"""
return self.__str__()
class HashTableLookupHelpFunction(torch.autograd.Function):
"""Autograd function for hash table embedding lookup with gradient support.
This function provides the forward and backward passes for embedding
lookup operations, handling gradient computation and admission hooks.
"""
@staticmethod
def forward(
ctx,
ids: torch.Tensor,
hashtable: object,
backward_holder: torch.Tensor,
admit_hook_impl,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass for embedding lookup.
Args:
ctx: Autograd context for storing information.
ids (torch.Tensor): Feature IDs to lookup.
hashtable (torch.classes.recis.HashtableImpl): Hash table implementation.
backward_holder (torch.Tensor): Tensor for gradient computation.
admit_hook_impl: Implementation of admission hook.
Returns:
Tuple[torch.Tensor, torch.Tensor]: (indices, embeddings)
Raises:
AssertionError: If admit_hook_impl is neither None nor a ReadOnlyHookImpl.
"""
assert admit_hook_impl is None or isinstance(
admit_hook_impl, ReadOnlyHookImpl
), f"admit hook only support ReadOnlyHook yet, but got: {admit_hook_impl}"
ids = ids.detach()
index, embedding = hashtable.embedding_lookup(ids, admit_hook_impl is not None)
ctx.save_for_backward(index)
ctx.hashtable = hashtable
return index.to(device="cuda"), embedding.to(device="cuda")
@staticmethod
def backward(ctx, grad_output_index, grad_output_emb) -> Tuple:
"""Backward pass for embedding lookup.
Args:
ctx: Autograd context containing forward pass information.
grad_output_index: Gradient for indices (unused).
grad_output_emb (torch.Tensor): Gradient for embeddings.
Returns:
Tuple: Gradients for all inputs (all None; the embedding gradient
is routed to the hash table via accept_grad instead).
"""
(index,) = ctx.saved_tensors
ctx.hashtable.accept_grad(
index.to(device="cuda"),
grad_output_emb.to(device="cuda"),
)
return (None, None, None, None)
class GradIDMeanFunction(torch.autograd.Function):
"""Autograd function for gradient aggregation by feature ID.
This function handles gradient computation when using ID-based
gradient reduction, ensuring proper gradient flow for duplicate IDs.
"""
@staticmethod
def forward(ctx, embedding, index):
"""Forward pass for ID-based gradient aggregation.
Args:
ctx: Autograd context.
embedding (torch.Tensor): Input embeddings.
index (torch.Tensor): Index mapping for gathering.
Returns:
torch.Tensor: Gathered embeddings.
"""
ctx.save_for_backward(index)
return torch.ops.recis.gather(index, embedding)
@staticmethod
def backward(ctx, grad_outputs):
"""Backward pass for ID-based gradient aggregation.
Args:
ctx: Autograd context.
grad_outputs (torch.Tensor): Output gradients.
Returns:
Tuple[torch.Tensor, None]: (reduced_gradients, None)
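Example:
A standalone illustration of the mean reduction, using only torch:
.. code-block:: python
# index [0, 1, 0]: rows 0 and 2 both belong to unique ID 0.
grads = torch.tensor([[2.0], [3.0], [4.0]])
index = torch.tensor([0, 1, 0])
out = torch.zeros(2, 1)
out.index_reduce_(0, index, grads, "mean", include_self=False)
# out is [[3.0], [3.0]]: slot 0 averages 2.0 and 4.0.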
"""
grad_outputs = grad_outputs.cuda()
(index,) = ctx.saved_tensors
if index.numel() == 0:
return (
torch.zeros(
[0] + list(grad_outputs.shape)[1:], device=grad_outputs.device
),
None,
)
shape = [index.max() + 1] + list(grad_outputs.shape)[1:]
reduce_grad = torch.zeros(shape, device=grad_outputs.device)
reduce_grad.index_reduce_(0, index, grad_outputs, "mean", include_self=False)
return reduce_grad, None
class GradWorkerMeanFunction(torch.autograd.Function):
"""Autograd function for gradient aggregation by worker.
This function handles gradient computation when using worker-based
gradient reduction, distributing gradients across multiple workers.
"""
@staticmethod
def forward(ctx, embedding, index, slice_num):
"""Forward pass for worker-based gradient aggregation.
Args:
ctx: Autograd context.
embedding (torch.Tensor): Input embeddings.
index (torch.Tensor): Index mapping for gathering.
slice_num (torch.Tensor): Number of worker slices.
Returns:
torch.Tensor: Gathered embeddings.
"""
ctx.save_for_backward(index, slice_num)
return torch.ops.recis.gather(index, embedding)
@staticmethod
def backward(ctx, grad_outputs):
"""Backward pass for worker-based gradient aggregation.
Args:
ctx: Autograd context.
grad_outputs (torch.Tensor): Output gradients.
Returns:
Tuple[torch.Tensor, None, None]: (reduced_gradients, None, None)
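Example:
A standalone illustration of the worker-mean reduction, using only
torch:
.. code-block:: python
# With 2 workers, gradients are scaled by 1/2, then duplicates
# of the same ID are summed: (2.0 + 4.0) / 2 = 3.0.
grads = torch.tensor([[2.0], [4.0]]) / 2
index = torch.tensor([0, 0])
out = torch.zeros(1, 1)
out.index_add_(0, index, grads)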
"""
grad_outputs = grad_outputs.cuda()
(index, slice_num) = ctx.saved_tensors
if index.numel() == 0:
return (
torch.zeros(
[0] + list(grad_outputs.shape)[1:], device=grad_outputs.device
),
None,
None,
)
grad_outputs = grad_outputs / slice_num
index_unique, index_reverse = torch.unique(
index.view((-1,)), return_inverse=True, sorted=False
)
reduce_grad = torch.zeros(
[index_unique.numel()] + list(grad_outputs.shape)[1:],
dtype=grad_outputs.dtype,
device=grad_outputs.device,
)
reduce_grad.index_add_(0, index_reverse, grad_outputs)
return reduce_grad, None, None
def is_hashtable(obj):
"""Check if an object is a hash table.
Args:
obj: Object to check.
Returns:
bool: True if the object is a hash table, False otherwise.
"""
return hasattr(obj, "hashtable_tag")
def split_sparse_dense_state_dict(state_dict: dict) -> Tuple[dict, dict]:
"""Split state dictionary into sparse and dense parameters.
This function separates hash table parameters (sparse) from regular
tensor parameters (dense) in a model's state dictionary.
Args:
state_dict (dict): State dictionary from model.state_dict().
Format: {"parameter_name": parameter_value}. Note: this dictionary
is modified in place; hash table entries are deleted from it.
Returns:
Tuple[dict, dict]: (sparse_state_dict, dense_state_dict)
- sparse_state_dict: Dictionary containing hash table parameters
- dense_state_dict: The input dictionary with hash table entries
removed, containing only regular tensor parameters
Example:
.. code-block:: python
model = MyModel() # Contains both hash tables and regular layers
state_dict = model.state_dict()
sparse_params, dense_params = split_sparse_dense_state_dict(state_dict)
print(f"Sparse parameters: {list(sparse_params.keys())}")
print(f"Dense parameters: {list(dense_params.keys())}")
"""
sparse_state_dict = {}
remove_key = set()
for key in state_dict:
value = state_dict[key]
if value is not None:
if is_hashtable(value):
sparse_state_dict[key] = value
remove_key.add(key)
for key in remove_key:
del state_dict[key]
return sparse_state_dict, state_dict
def filter_out_sparse_param(model: torch.nn.Module) -> dict:
"""Extract sparse parameters from a PyTorch model.
This function extracts all hash table parameters from a model,
which is useful for separate handling of sparse parameters in
distributed training scenarios.
Args:
model (torch.nn.Module): PyTorch model containing hash tables.
Returns:
dict: Dictionary containing only sparse (hash table) parameters.
Example:
.. code-block:: python
from recis.nn.modules.hashtable import filter_out_sparse_param
# Separate parameters
sparse_params = filter_out_sparse_param(model)
# Create different optimizers
from recis.optim import SparseAdamW
from torch.optim import AdamW
sparse_optimizer = SparseAdamW(sparse_params, lr=0.001)
dense_optimizer = AdamW(model.parameters(), lr=0.001)
"""
state_dict = model.state_dict()
sparse_state_dict, _ = split_sparse_dense_state_dict(state_dict)
return sparse_state_dict
class HashtableRegister(metaclass=SingletonMeta):
"""Singleton registry for managing hash table instances.
This class provides a centralized registry for tracking hash table
instances across the application, ensuring proper management and
avoiding naming conflicts.
Attributes:
_hashtables (dict): Dictionary mapping hash table names to their
configuration information.
Example:
.. code-block:: python
# Register a hash table (usually done automatically)
register = HashtableRegister()
register.register("user_embedding", '{"shape": [64], "dtype": "float32"}')
# The registry is a singleton, so all instances are the same
register2 = HashtableRegister()
assert register is register2 # True
"""
def __init__(self) -> None:
"""Initialize the hash table registry."""
self._hashtables = {}
def register(self, name: str, info: str):
"""Register a hash table with the given name and configuration.
Args:
name (str): Unique name for the hash table.
info (str): JSON string containing hash table configuration.
Raises:
ValueError: If a hash table with the same name is already
registered, regardless of configuration.
Example:
.. code-block:: python
register = HashtableRegister()
# Register a new hash table
config = '{"shape": [128], "dtype": "float32", "initializer": "constant"}'
register.register("item_embedding", config)
# This would raise ValueError due to duplicate name
# register.register("item_embedding", different_config)
"""
if name in self._hashtables:
raise ValueError(
f"Duplicate hashtable shard name: {name}, before: {self._hashtables[name]}, now: {info}"
)
self._hashtables[name] = info