HashTable Module

HashTable is the core data structure behind RecIS’s dynamic embedding tables. It grows the table as new feature IDs appear and provides feature admission and feature filtering.

HashTable

class recis.nn.modules.hashtable.HashTable(embedding_shape: List, block_size: int = 5, dtype: torch.dtype = torch.float32, device: torch.device = device(type='cpu'), coalesced: bool = False, children: List[str] | None = None, slice: Slice = _default_slice, initializer=None, name: str = 'hashtable', grad_reduce_by: str = 'worker', filter_hook: FilterHook | None = None)

Distributed hash table for sparse parameter storage and lookup.

This module provides a distributed hash table implementation that supports dynamic sparse parameter storage, efficient lookup operations, and gradient computation. It’s designed for large-scale sparse learning scenarios where the feature vocabulary can grow dynamically.

Key features:
  • Dynamic feature admission and eviction

  • Distributed storage across multiple workers

  • Efficient gradient computation and aggregation

  • Support for various initialization strategies

  • Hook-based filtering and admission control
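The dynamic-growth behaviour listed above can be pictured with a small pure-Python sketch. It is not RecIS's implementation: `ToyDynamicTable` and its `lookup` method are hypothetical names, rows are plain lists, and new rows are zero-initialized purely for simplicity.

```python
from typing import Dict, List


class ToyDynamicTable:
    """Hypothetical stand-in for a dynamic embedding table (not the RecIS API)."""

    def __init__(self, dim: int) -> None:
        self.dim = dim
        self.rows: Dict[int, List[float]] = {}  # feature ID -> embedding row

    def lookup(self, ids: List[int]) -> List[List[float]]:
        out = []
        for i in ids:
            # An unseen ID gets a fresh row, so the table grows on demand
            # instead of being sized from a fixed vocabulary up front.
            if i not in self.rows:
                self.rows[i] = [0.0] * self.dim
            out.append(self.rows[i])
        return out


table = ToyDynamicTable(dim=4)
vectors = table.lookup([1, 2, 3, 100, 1000])
print(len(vectors), len(table.rows))  # 5 looked-up rows, 5 stored IDs
```

A real table would additionally shard rows across workers and apply the configured initializer; the sketch only shows why no vocabulary size has to be declared in advance.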

Example

Basic usage:

import torch
from recis.nn.modules.hashtable import HashTable

# Create hash table
hashtable = HashTable(
    embedding_shape=[64],
    block_size=1024,
    dtype=torch.float32,
    device=torch.device("cuda"),
    name="user_embedding",
)

# Lookup embeddings
ids = torch.tensor([1, 2, 3, 100, 1000])
embeddings = hashtable(ids)  # Shape: [5, 64]


Advanced usage with hooks:

from recis.nn.modules.hashtable import HashTable
from recis.nn.hashtable_hook import FrequencyFilterHook

# Create hash table with filtering
filter_hook = FrequencyFilterHook(min_frequency=5)
hashtable = HashTable(
    embedding_shape=[128],
    block_size=2048,
    filter_hook=filter_hook,
    grad_reduce_by="id",
)

__init__(embedding_shape: List, block_size: int = 5, dtype: torch.dtype = torch.float32, device: torch.device = device(type='cpu'), coalesced: bool = False, children: List[str] | None = None, slice: Slice = _default_slice, initializer=None, name: str = 'hashtable', grad_reduce_by: str = 'worker', filter_hook: FilterHook | None = None)

Initialize hash table module.

Parameters:
  • embedding_shape (List[int]) – Shape of embedding vectors.

  • block_size (int, optional) – Number of embeddings per block. Defaults to 5.

  • dtype (torch.dtype, optional) – Data type. Defaults to torch.float32.

  • device (torch.device, optional) – Computation device. Defaults to CPU.

  • coalesced (bool, optional) – Use coalesced operations. Defaults to False.

  • children (Optional[List[str]], optional) – Child table names. Defaults to None.

  • slice (Slice, optional) – Partitioning config. Defaults to _default_slice.

  • initializer (Initializer, optional) – Initializer. Defaults to None.

  • name (str, optional) – Table name. Defaults to “hashtable”.

  • grad_reduce_by (str, optional) – Gradient reduction. Defaults to “worker”.

  • filter_hook (Optional[FilterHook], optional) – Filter hook. Defaults to None.

Raises:

AssertionError – If grad_reduce_by is not “id” or “worker”.

accept_grad(grad_index, grad) → None

Accept gradients for specific embedding indices.

Parameters:
  • grad_index (torch.Tensor) – Indices of embeddings to update.

  • grad (torch.Tensor) – Gradient values for the embeddings.

children_info()

Get information about child hash tables.

Returns:

Information about child hash tables.

Return type:

ChildrenInfo

clear_grad() → None

Clear accumulated gradients.

embeddings() → Tensor

Get all stored embeddings.

Returns:

All embedding values currently stored in the table.

Return type:

torch.Tensor

forward(ids: Tensor, admit_hook: AdmitHook = None) → Tensor

Perform embedding lookup for given feature IDs.

This method looks up embeddings for the provided feature IDs, handling deduplication, gradient computation, and optional feature admission control.

Parameters:
  • ids (torch.Tensor) – Feature IDs to look up. Shape: [N], where N is the number of features.

  • admit_hook (AdmitHook, optional) – Hook for controlling feature admission. Defaults to None.

Returns:

Looked-up embeddings. Shape: [N, embedding_dim], where embedding_dim is determined by embedding_shape.

Return type:

torch.Tensor

Example:

# Basic lookup
ids = torch.tensor([1, 2, 3, 2, 1])  # Note: duplicates
embeddings = hashtable(ids)  # Shape: [5, embedding_dim]

# With admission hook
from recis.nn.hashtable_hook import FrequencyAdmitHook

admit_hook = FrequencyAdmitHook(min_frequency=3)
embeddings = hashtable(ids, admit_hook)
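
The deduplication and admission steps that forward describes can be sketched in pure Python as follows. This is an illustration under stated assumptions, not the RecIS implementation: `lookup_with_admission` is a hypothetical function, IDs are counted once per occurrence, and an ID that has not yet reached `min_frequency` is served a zero vector without being stored.

```python
from collections import Counter
from typing import Dict, List


def lookup_with_admission(
    ids: List[int],
    rows: Dict[int, List[float]],
    seen: Counter,
    min_frequency: int,
    dim: int,
) -> List[List[float]]:
    # Deduplicate while remembering where each original position points.
    unique_ids: List[int] = []
    position: Dict[int, int] = {}
    inverse: List[int] = []
    for i in ids:
        if i not in position:
            position[i] = len(unique_ids)
            unique_ids.append(i)
        inverse.append(position[i])

    # ASSUMPTION: every occurrence counts; a real hook may count per batch.
    seen.update(ids)

    unique_rows: List[List[float]] = []
    for i in unique_ids:
        if i not in rows and seen[i] >= min_frequency:
            rows[i] = [0.0] * dim  # admit: allocate a stored row
        # IDs that are not admitted yet fall back to a zero vector.
        unique_rows.append(rows.get(i, [0.0] * dim))

    # Expand back to input order, so the output has one row per input ID.
    return [unique_rows[j] for j in inverse]


rows: Dict[int, List[float]] = {}
seen: Counter = Counter()
ids = [1, 2, 3, 2, 1]
first = lookup_with_admission(ids, rows, seen, min_frequency=3, dim=4)
second = lookup_with_admission(ids, rows, seen, min_frequency=3, dim=4)
print(len(first), sorted(rows))
```

After the second call, IDs 1 and 2 have each been seen four times and are stored, while ID 3 (seen twice) is still below the threshold. Duplicate input positions share one looked-up row, which is why the output keeps shape [N, embedding_dim] even though only the unique IDs are resolved.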

grad(acc_step=1) → Tensor

Get accumulated gradients.

Parameters:

acc_step (int, optional) – Accumulation step. Defaults to 1.

Returns:

Accumulated gradients.

Return type:

torch.Tensor
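
One way to read accept_grad, grad, and clear_grad together is as an accumulate, read, reset cycle. The sketch below is pure Python with a hypothetical class name (`ToyGradBuffer`). In particular, dividing by `acc_step` is an assumption about what the accumulation step means; this reference does not confirm it.

```python
from typing import Dict, List


class ToyGradBuffer:
    """Hypothetical accumulate/read/reset buffer (not the RecIS API)."""

    def __init__(self, dim: int) -> None:
        self.dim = dim
        self.acc: Dict[int, List[float]] = {}  # row index -> summed gradient

    def accept_grad(self, grad_index: List[int], grad: List[List[float]]) -> None:
        for idx, g in zip(grad_index, grad):
            row = self.acc.setdefault(idx, [0.0] * self.dim)
            for k in range(self.dim):
                row[k] += g[k]  # repeated indices sum into the same row

    def grad(self, acc_step: int = 1) -> Dict[int, List[float]]:
        # ASSUMPTION: average the sum over acc_step accumulation steps.
        return {i: [v / acc_step for v in row] for i, row in self.acc.items()}

    def clear_grad(self) -> None:
        self.acc.clear()


buf = ToyGradBuffer(dim=2)
buf.accept_grad([0, 1, 0], [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
buf.accept_grad([0], [[4.0, 4.0]])
averaged = buf.grad(acc_step=2)
buf.clear_grad()
```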

ids() → Tensor

Get all stored feature IDs.

Returns:

All feature IDs currently stored in the table.

Return type:

torch.Tensor

insert(ids, embeddings) → None

Insert embeddings for specific IDs.

Parameters:
  • ids – Feature IDs to insert.

  • embeddings – Embedding values to store for those IDs.

slot_group()

Get the slot group for advanced operations.

Returns:

The slot group containing all storage slots.

Return type:

SlotGroup

Utility Functions

recis.nn.modules.hashtable.filter_out_sparse_param(model: Module) → dict

Extract sparse parameters from a PyTorch model.

This function extracts all hash table parameters from a model, which is useful for separate handling of sparse parameters in distributed training scenarios.

Parameters:

model (torch.nn.Module) – PyTorch model containing hash tables.

Returns:

Dictionary containing only sparse (hash table) parameters.

Return type:

dict

Example:

from recis.nn.modules.hashtable import filter_out_sparse_param

# Separate parameters
sparse_params = filter_out_sparse_param(model)

# Create different optimizers
from recis.optim import SparseAdamW
from torch.optim import AdamW

sparse_optimizer = SparseAdamW(sparse_params, lr=0.001)
dense_optimizer = AdamW(model.parameters(), lr=0.001)

recis.nn.modules.hashtable.gen_slice(shard_index=0, shard_num=1, slice_size=65536)

Generate slice configuration for distributed hash table partitioning.

This function creates a Slice object that defines how to partition the hash table’s key space across multiple workers. It ensures balanced distribution with proper handling of remainder keys.

Parameters:
  • shard_index (int, optional) – Index of the current shard/worker. Defaults to 0.

  • shard_num (int, optional) – Total number of shards/workers. Defaults to 1.

  • slice_size (int, optional) – Total size of the key space. Defaults to 65536.

Returns:

Slice configuration for the specified shard.

Return type:

Slice

Example:

from recis.nn.modules.hashtable import gen_slice

# Generate slice for worker 1 out of 4 workers
slice_config = gen_slice(shard_index=1, shard_num=4, slice_size=65536)
print(
    f"Worker 1 handles keys from {slice_config.slice_begin} "
    f"to {slice_config.slice_end}"
)
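
The balanced partitioning with remainder handling described for gen_slice can be approximated by the arithmetic below. This is a sketch only: `balanced_range` is a hypothetical helper, the half-open [begin, end) convention is an assumption, and giving the extra keys to the lowest-numbered shards is one common choice that may differ from what gen_slice actually does.

```python
from typing import Tuple


def balanced_range(shard_index: int, shard_num: int, slice_size: int) -> Tuple[int, int]:
    """Hypothetical helper: half-open [begin, end) key range for one shard."""
    base, remainder = divmod(slice_size, shard_num)
    # ASSUMPTION: the first `remainder` shards each take one extra key.
    begin = shard_index * base + min(shard_index, remainder)
    end = begin + base + (1 if shard_index < remainder else 0)
    return begin, end


print(balanced_range(1, 4, 65536))  # (16384, 32768)
print([balanced_range(i, 3, 10) for i in range(3)])  # [(0, 4), (4, 7), (7, 10)]
```

When the key space divides evenly, as with 65536 keys over 4 shards, every shard gets the same width. With 10 keys over 3 shards, the ranges still tile the key space without gaps or overlap, and shard sizes differ by at most one.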