HashTable Module
RecIS’s HashTable is the core data structure behind dynamic embedding tables. It supports dynamic expansion of embedding tables and provides feature admission and feature filtering capabilities.
HashTable
- class recis.nn.modules.hashtable.HashTable(embedding_shape: ~typing.List, block_size: int = 5, dtype: ~torch.dtype = torch.float32, device: ~torch.device = device(type='cpu'), coalesced: bool = False, children: ~typing.List[str] | None = None, slice: ~recis.nn.modules.hashtable.Slice = <recis.nn.modules.hashtable.Slice object>, initializer=None, name: str = 'hashtable', grad_reduce_by: str = 'worker', filter_hook: ~recis.nn.hashtable_hook.FilterHook | None = None)[source]
Distributed hash table for sparse parameter storage and lookup.
This module provides a distributed hash table implementation that supports dynamic sparse parameter storage, efficient lookup operations, and gradient computation. It’s designed for large-scale sparse learning scenarios where the feature vocabulary can grow dynamically.
- Key features:
Dynamic feature admission and eviction
Distributed storage across multiple workers
Efficient gradient computation and aggregation
Support for various initialization strategies
Hook-based filtering and admission control
Example
Basic usage:
import torch
from recis.nn.modules.hashtable import HashTable

# Create hash table
hashtable = HashTable(
    embedding_shape=[64],
    block_size=1024,
    dtype=torch.float32,
    device=torch.device("cuda"),
    name="user_embedding",
)

# Lookup embeddings
ids = torch.tensor([1, 2, 3, 100, 1000])
embeddings = hashtable(ids)  # Shape: [5, 64]

Advanced usage with hooks:
from recis.nn.hashtable_hook import FrequencyFilterHook

# Create hash table with filtering
filter_hook = FrequencyFilterHook(min_frequency=5)
hashtable = HashTable(
    embedding_shape=[128],
    block_size=2048,
    filter_hook=filter_hook,
    grad_reduce_by="id",
)
- __init__(embedding_shape: ~typing.List, block_size: int = 5, dtype: ~torch.dtype = torch.float32, device: ~torch.device = device(type='cpu'), coalesced: bool = False, children: ~typing.List[str] | None = None, slice: ~recis.nn.modules.hashtable.Slice = <recis.nn.modules.hashtable.Slice object>, initializer=None, name: str = 'hashtable', grad_reduce_by: str = 'worker', filter_hook: ~recis.nn.hashtable_hook.FilterHook | None = None)[source]
Initialize hash table module.
- Parameters:
embedding_shape (List[int]) – Shape of embedding vectors.
block_size (int, optional) – Number of embeddings per block. Defaults to 5.
dtype (torch.dtype, optional) – Data type. Defaults to torch.float32.
device (torch.device, optional) – Computation device. Defaults to CPU.
coalesced (bool, optional) – Use coalesced operations. Defaults to False.
children (Optional[List[str]], optional) – Child table names. Defaults to None.
slice (Slice, optional) – Partitioning config. Defaults to _default_slice.
initializer (Initializer, optional) – Initializer. Defaults to None.
name (str, optional) – Table name. Defaults to “hashtable”.
grad_reduce_by (str, optional) – Gradient reduction strategy, either “id” or “worker”. Defaults to “worker”.
filter_hook (Optional[FilterHook], optional) – Filter hook. Defaults to None.
- Raises:
AssertionError – If grad_reduce_by is not “id” or “worker”.
- accept_grad(grad_index, grad) None [source]
Accept gradients for specific embedding indices.
- Parameters:
grad_index (torch.Tensor) – Indices of embeddings to update.
grad (torch.Tensor) – Gradient values for the embeddings.
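A minimal sketch of pushing gradients into the table by hand; the indices and values below are illustrative and assume the embedding_shape=[64] table from the example above (in normal training this path is driven by the backward pass):

import torch

# Hypothetical update: push gradients for the embeddings at indices 0 and 2
grad_index = torch.tensor([0, 2], dtype=torch.int64)
grad = torch.randn(2, 64)  # one gradient row per index, matching the embedding dim
hashtable.accept_grad(grad_index, grad)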
- children_info()[source]
Get information about child hash tables.
- Returns:
Information about child hash tables.
- Return type:
ChildrenInfo
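A hedged sketch of querying child-table information; pairing coalesced=True with children here is an assumption about typical usage rather than a documented requirement:

from recis.nn.modules.hashtable import HashTable

# Hypothetical coalesced table declaring two child tables
ht = HashTable(
    embedding_shape=[32],
    coalesced=True,
    children=["user_id", "item_id"],
)
info = ht.children_info()  # ChildrenInfo describing the child tables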
- embeddings() Tensor [source]
Get all stored embeddings.
- Returns:
All embedding values currently stored in the table.
- Return type:
torch.Tensor
- forward(ids: Tensor, admit_hook: AdmitHook = None) Tensor [source]
Perform embedding lookup for given feature IDs.
This method looks up embeddings for the provided feature IDs, handling deduplication, gradient computation, and optional feature admission control.
- Parameters:
ids (torch.Tensor) – Feature IDs to lookup. Shape: [N] where N is the number of features.
admit_hook (AdmitHook, optional) – Hook for controlling feature admission. Defaults to None.
- Returns:
Looked up embeddings. Shape: [N, embedding_dim], where embedding_dim is determined by embedding_shape.
- Return type:
torch.Tensor
Example:
# Basic lookup
ids = torch.tensor([1, 2, 3, 2, 1])  # Note: duplicates
embeddings = hashtable(ids)  # Shape: [5, embedding_dim]

# With admission hook
from recis.nn.hashtable_hook import FrequencyAdmitHook

admit_hook = FrequencyAdmitHook(min_frequency=3)
embeddings = hashtable(ids, admit_hook)
- grad(acc_step=1) Tensor [source]
Get accumulated gradients.
- Parameters:
acc_step (int, optional) – Number of gradient accumulation steps. Defaults to 1.
- Returns:
Accumulated gradients.
- Return type:
torch.Tensor
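An illustrative gradient-accumulation sketch; the ID batches and loss below are stand-ins, and treating acc_step as the number of micro-batches accumulated since the last read is an assumption:

import torch

micro_batches = [torch.randint(0, 1000, (8,)) for _ in range(4)]  # 4 illustrative ID batches

for ids in micro_batches:
    emb = hashtable(ids)
    loss = emb.sum()  # stand-in for a real loss
    loss.backward()

# Read back the sparse gradients accumulated over the 4 micro-batches
sparse_grads = hashtable.grad(acc_step=4)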
- ids() Tensor [source]
Get all stored feature IDs.
- Returns:
All feature IDs currently stored in the table.
- Return type:
torch.Tensor
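A small sketch that snapshots the table contents, e.g. for inspection or export; the row-wise correspondence between ids() and embeddings() is assumed here:

stored_ids = hashtable.ids()  # Shape: [num_stored]
stored_emb = hashtable.embeddings()  # Shape: [num_stored, embedding_dim]
print(f"table holds {stored_ids.numel()} features")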
- insert(ids, embeddings) None [source]
Insert embeddings for specific IDs.
- Parameters:
ids (torch.Tensor) – Feature IDs to insert.
embeddings (torch.Tensor) – Embedding values to insert.
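A hedged sketch of warm-starting a few rows with externally computed vectors; the IDs and values are illustrative, and the second dimension must match embedding_shape:

import torch

warm_ids = torch.tensor([1, 2, 3], dtype=torch.int64)
warm_emb = torch.randn(3, 64)  # matches embedding_shape=[64] from the example above
hashtable.insert(warm_ids, warm_emb)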
Utility Functions
- recis.nn.modules.hashtable.filter_out_sparse_param(model: Module) dict [source]
Extract sparse parameters from a PyTorch model.
This function extracts all hash table parameters from a model, which is useful for separate handling of sparse parameters in distributed training scenarios.
- Parameters:
model (torch.nn.Module) – PyTorch model containing hash tables.
- Returns:
Dictionary containing only sparse (hash table) parameters.
- Return type:
dict
Example:

from recis.nn.modules.hashtable import filter_out_sparse_param

# Separate parameters
sparse_params = filter_out_sparse_param(model)

# Create different optimizers
from recis.optim import SparseAdamW
from torch.optim import AdamW

sparse_optimizer = SparseAdamW(sparse_params, lr=0.001)
dense_optimizer = AdamW(model.parameters(), lr=0.001)
- recis.nn.modules.hashtable.gen_slice(shard_index=0, shard_num=1, slice_size=65536)[source]
Generate slice configuration for distributed hash table partitioning.
This function creates a Slice object that defines how to partition the hash table’s key space across multiple workers. It ensures balanced distribution with proper handling of remainder keys.
- Parameters:
shard_index (int, optional) – Index of the shard (worker) to generate the slice for. Defaults to 0.
shard_num (int, optional) – Total number of shards (workers). Defaults to 1.
slice_size (int, optional) – Size of the key space to partition across shards. Defaults to 65536.
- Returns:
Slice configuration for the specified shard.
- Return type:
Slice
Example:
# Generate slice for worker 1 out of 4 workers
slice_config = gen_slice(shard_index=1, shard_num=4, slice_size=65536)
print(
    f"Worker 1 handles keys from {slice_config.slice_begin} "
    f"to {slice_config.slice_end}"
)
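A sketch of wiring the generated slice into a per-worker HashTable; fetching the shard index and count from torch.distributed is an assumption about the surrounding setup:

import torch.distributed as dist

from recis.nn.modules.hashtable import HashTable, gen_slice

slice_config = gen_slice(
    shard_index=dist.get_rank(),
    shard_num=dist.get_world_size(),
    slice_size=65536,
)
hashtable = HashTable(
    embedding_shape=[64],
    slice=slice_config,
    name="user_embedding",
)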