HashTable Module
RecIS’s HashTable is the core data structure behind dynamic embedding tables. It supports dynamic expansion of embedding tables and provides feature admission and feature filtering capabilities.
HashTable
- class recis.nn.modules.hashtable.HashTable(embedding_shape: List[int], block_size: int = 5, dtype: torch.dtype = torch.float32, device: torch.device = torch.device('cpu'), coalesced: bool = False, children: Optional[List[str]] = None, slice: Slice = _default_slice, initializer=None, name: str = 'hashtable', grad_reduce_by: str = 'worker', filter_hook: Optional[FilterHook] = None)
- Distributed hash table for sparse parameter storage and lookup.
- This module provides a distributed hash table implementation that supports dynamic sparse parameter storage, efficient lookup operations, and gradient computation. It’s designed for large-scale sparse learning scenarios where the feature vocabulary can grow dynamically.
- Key features:
- Dynamic feature admission and eviction 
- Distributed storage across multiple workers 
- Efficient gradient computation and aggregation 
- Support for various initialization strategies 
- Hook-based filtering and admission control 
 
- Example
- Basic usage:

```python
import torch

from recis.nn.modules.hashtable import HashTable

# Create hash table
hashtable = HashTable(
    embedding_shape=[64],
    block_size=1024,
    dtype=torch.float32,
    device=torch.device("cuda"),
    name="user_embedding",
)

# Lookup embeddings
ids = torch.tensor([1, 2, 3, 100, 1000])
embeddings = hashtable(ids)  # Shape: [5, 64]
```

- Advanced usage with hooks:

```python
from recis.nn.hashtable_hook import FrequencyFilterHook

# Create hash table with filtering
filter_hook = FrequencyFilterHook(min_frequency=5)
hashtable = HashTable(
    embedding_shape=[128],
    block_size=2048,
    filter_hook=filter_hook,
    grad_reduce_by="id",
)
```

- __init__(embedding_shape: List[int], block_size: int = 5, dtype: torch.dtype = torch.float32, device: torch.device = torch.device('cpu'), coalesced: bool = False, children: Optional[List[str]] = None, slice: Slice = _default_slice, initializer=None, name: str = 'hashtable', grad_reduce_by: str = 'worker', filter_hook: Optional[FilterHook] = None)
- Initialize hash table module.
- Parameters:
- embedding_shape (List[int]) – Shape of embedding vectors. 
- block_size (int, optional) – Number of embeddings per block. Defaults to 5. 
- dtype (torch.dtype, optional) – Data type. Defaults to torch.float32. 
- device (torch.device, optional) – Computation device. Defaults to CPU. 
- coalesced (bool, optional) – Use coalesced operations. Defaults to False. 
- children (Optional[List[str]], optional) – Child table names. Defaults to None. 
- slice (Slice, optional) – Partitioning config. Defaults to _default_slice. 
- initializer (Initializer, optional) – Initializer. Defaults to None. 
- name (str, optional) – Table name. Defaults to “hashtable”. 
- grad_reduce_by (str, optional) – Gradient reduction strategy, either “id” or “worker”. Defaults to “worker”.
- filter_hook (Optional[FilterHook], optional) – Filter hook. Defaults to None. 
 
- Raises:
- AssertionError – If grad_reduce_by is not “id” or “worker”. 
 
- accept_grad(grad_index, grad) → None
- Accept gradients for specific embedding indices.
- Parameters:
- grad_index (torch.Tensor) – Indices of embeddings to update. 
- grad (torch.Tensor) – Gradient values for the embeddings. 
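- Example (illustrative): a minimal sketch of feeding externally computed gradients into the table. The index and gradient values below are made up, and the accumulate-then-read-back behavior is an assumption based on the grad() method documented further down.

```python
import torch

# Hypothetical gradients for two already-stored embedding slots
grad_index = torch.tensor([0, 3])
grad = torch.randn(2, 64)  # must match the embedding dimension

# Hand the gradients to the hash table; assumed to be accumulated
# internally until read back via grad()
hashtable.accept_grad(grad_index, grad)
```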
 
 
- children_info()
- Get information about child hash tables.
- Returns:
- Information about child hash tables. 
- Return type:
- ChildrenInfo 
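- Example (illustrative): only the call itself is shown, since the fields of the returned ChildrenInfo object are not documented in this section.

```python
# Inspect metadata about child tables (relevant for coalesced tables)
info = hashtable.children_info()
print(info)
```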
 
- embeddings() → Tensor
- Get all stored embeddings.
- Returns:
- All embedding values currently stored in the table. 
- Return type:
- torch.Tensor
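- Example (illustrative): dumping the stored embeddings for inspection, assuming the 64-dimensional table created in the basic-usage example above.

```python
# All embedding rows currently stored in the table
all_emb = hashtable.embeddings()
print(all_emb.shape)  # expected: [num_stored_ids, 64]
```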
 
- forward(ids: Tensor, admit_hook: AdmitHook = None) → Tensor
- Perform embedding lookup for given feature IDs.
- This method looks up embeddings for the provided feature IDs, handling deduplication, gradient computation, and optional feature admission control.
- Parameters:
- ids (torch.Tensor) – Feature IDs to lookup. Shape: [N] where N is the number of features. 
- admit_hook (AdmitHook, optional) – Hook for controlling feature admission. Defaults to None. 
 
- Returns:
- Looked up embeddings. Shape: [N, embedding_dim], where embedding_dim is determined by embedding_shape.
 
- Return type:
- torch.Tensor
- Example:

```python
# Basic lookup
ids = torch.tensor([1, 2, 3, 2, 1])  # Note: duplicates
embeddings = hashtable(ids)  # Shape: [5, embedding_dim]

# With admission hook
from recis.nn.hashtable_hook import FrequencyAdmitHook

admit_hook = FrequencyAdmitHook(min_frequency=3)
embeddings = hashtable(ids, admit_hook)
```
- grad(acc_step=1) → Tensor
- Get accumulated gradients.
- Parameters:
- acc_step (int, optional) – Accumulation step. Defaults to 1. 
- Returns:
- Accumulated gradients. 
- Return type:
- torch.Tensor
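- Example (illustrative): reading back accumulated gradients. The value acc_step=2 is made up and assumed to describe how many accumulation steps the gradients cover.

```python
# Gradients accumulated for the stored embeddings
grads = hashtable.grad(acc_step=2)
print(grads.shape)
```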
 
- ids() → Tensor
- Get all stored feature IDs.
- Returns:
- All feature IDs currently stored in the table. 
- Return type:
- torch.Tensor
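- Example (illustrative): checking how many features the table currently holds.

```python
# Feature IDs currently admitted into the table
stored_ids = hashtable.ids()
print(stored_ids.numel(), "features stored")
```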
 
- insert(ids, embeddings) → None
- Insert embeddings for specific IDs.
- Parameters:
- ids (torch.Tensor) – Feature IDs to insert. 
- embeddings (torch.Tensor) – Embedding values to insert. 
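- Example (illustrative): a sketch of warm-starting the table with precomputed vectors. The 64-dimensional shape follows the basic-usage example above, and the create-or-overwrite behavior is an assumption.

```python
import torch

# Hypothetical pretrained vectors for three feature IDs
warm_ids = torch.tensor([10, 20, 30])
warm_emb = torch.randn(3, 64)  # must match embedding_shape

# Write the vectors into the table (assumed to create or overwrite entries)
hashtable.insert(warm_ids, warm_emb)
```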
 
 
 
Utility Functions
- recis.nn.modules.hashtable.filter_out_sparse_param(model: Module) → dict
- Extract sparse parameters from a PyTorch model.
- This function extracts all hash table parameters from a model, which is useful for separate handling of sparse parameters in distributed training scenarios.
- Parameters:
- model (torch.nn.Module) – PyTorch model containing hash tables. 
- Returns:
- Dictionary containing only sparse (hash table) parameters. 
- Return type:
- dict
- Example:

```python
from recis.nn.modules.hashtable import filter_out_sparse_param

# Separate parameters
sparse_params = filter_out_sparse_param(model)

# Create different optimizers
from recis.optim import SparseAdamW
from torch.optim import AdamW

sparse_optimizer = SparseAdamW(sparse_params, lr=0.001)
dense_optimizer = AdamW(model.parameters(), lr=0.001)
```
- recis.nn.modules.hashtable.gen_slice(shard_index=0, shard_num=1, slice_size=65536)
- Generate slice configuration for distributed hash table partitioning.
- This function creates a Slice object that defines how to partition the hash table’s key space across multiple workers. It ensures balanced distribution with proper handling of remainder keys.
- Parameters:
- shard_index (int, optional) – Index of the current shard (worker). Defaults to 0.
- shard_num (int, optional) – Total number of shards (workers). Defaults to 1.
- slice_size (int, optional) – Size of the key-space slice assigned to each shard. Defaults to 65536.
- Returns:
- Slice configuration for the specified shard. 
- Return type:
- Slice 
- Example:

```python
# Generate slice for worker 1 out of 4 workers
slice_config = gen_slice(shard_index=1, shard_num=4, slice_size=65536)
print(
    f"Worker 1 handles keys from {slice_config.slice_begin} "
    f"to {slice_config.slice_end}"
)
```