Source code for recis.nn.modules.hashtable

import json
import os
from typing import List, Optional, Tuple

import torch

from recis.common.singleton import SingletonMeta
from recis.nn.hashtable_hook import AdmitHook, FilterHook
from recis.nn.initializers import ConstantInitializer
from recis.nn.modules.hashtable_hook_impl import HashtableHookFactory, ReadOnlyHookImpl
from recis.utils.logger import Logger


logger = Logger(__name__)


class Slice:
    """Partitioning configuration for distributed hash table storage.

    This class defines how the hash table's key space is partitioned
    across different workers in a distributed setting.

    Attributes:
        slice_begin (int): Starting index of the slice.
        slice_end (int): Ending index of the slice (exclusive).
        slice_size (int): Total size of the key space.

    Example:
        .. code-block:: python

            # Create a slice for worker 0 out of 4 workers
            slice_config = Slice(0, 16384, 65536)
    """

    def __init__(self, slice_beg, slice_end, slice_size) -> None:
        """Initialize slice configuration.

        Args:
            slice_beg (int): Starting index of the slice.
            slice_end (int): Ending index of the slice (exclusive).
            slice_size (int): Total size of the key space.
        """
        self.slice_begin = slice_beg
        self.slice_end = slice_end
        self.slice_size = slice_size


def gen_slice(shard_index=0, shard_num=1, slice_size=65536):
    """Generate slice configuration for distributed hash table partitioning.

    This function creates a Slice object that defines how to partition the
    hash table's key space across multiple workers. It ensures balanced
    distribution with proper handling of remainder keys.

    Args:
        shard_index (int, optional): Index of the current shard/worker.
            Defaults to 0.
        shard_num (int, optional): Total number of shards/workers.
            Defaults to 1.
        slice_size (int, optional): Total size of the key space.
            Defaults to 65536.

    Returns:
        Slice: Slice configuration for the specified shard.

    Example:
        .. code-block:: python

            # Generate slice for worker 1 out of 4 workers
            slice_config = gen_slice(shard_index=1, shard_num=4, slice_size=65536)
            print(
                f"Worker 1 handles keys from {slice_config.slice_begin} "
                f"to {slice_config.slice_end}"
            )
    """
    shard_slice_size = slice_size // shard_num
    shard_slice_sizes = [shard_slice_size] * shard_num
    remain = slice_size % shard_num
    shard_slice_sizes = [
        size + 1 if i < remain else size for i, size in enumerate(shard_slice_sizes)
    ]
    slice_infos = []
    beg = 0
    for size in shard_slice_sizes:
        end = beg + size
        slice_infos.append((beg, end))
        beg = end
    slice_info = slice_infos[shard_index]
    return Slice(slice_info[0], slice_info[1], slice_size)

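
# Illustrative sketch (not part of the original API; the helper name below is
# hypothetical and exists only for documentation): shows how gen_slice balances
# a key space that does not divide evenly, giving each of the first
# `slice_size % shard_num` shards one extra key, and that the one-shard case
# spans the whole key space.
def _example_gen_slice_remainder():
    # A key space of 10 split over 3 shards yields sizes 4, 3, 3.
    slices = [gen_slice(shard_index=i, shard_num=3, slice_size=10) for i in range(3)]
    assert [(s.slice_begin, s.slice_end) for s in slices] == [(0, 4), (4, 7), (7, 10)]
    # A single shard over the default 65536-key space covers it entirely,
    # matching gen_slice(0, 1, 65536).
    default = gen_slice(shard_index=0, shard_num=1, slice_size=65536)
    assert (default.slice_begin, default.slice_end) == (0, 65536)
    return slices

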
_default_slice = Slice(0, 65536, 65536)


class HashTable(torch.nn.Module):
    """Distributed hash table for sparse parameter storage and lookup.

    This module provides a distributed hash table implementation that supports
    dynamic sparse parameter storage, efficient lookup operations, and gradient
    computation. It is designed for large-scale sparse learning scenarios where
    the feature vocabulary can grow dynamically.

    Key features:

    - Dynamic feature admission and eviction
    - Distributed storage across multiple workers
    - Efficient gradient computation and aggregation
    - Support for various initialization strategies
    - Hook-based filtering and admission control

    Example:
        Basic usage:

        .. code-block:: python

            import torch
            from recis.nn.modules.hashtable import HashTable

            # Create hash table
            hashtable = HashTable(
                embedding_shape=[64],
                block_size=1024,
                dtype=torch.float32,
                device=torch.device("cuda"),
                name="user_embedding",
            )

            # Lookup embeddings
            ids = torch.tensor([1, 2, 3, 100, 1000])
            embeddings = hashtable(ids)  # Shape: [5, 64]

        Advanced usage with hooks:

        .. code-block:: python

            from recis.nn.hashtable_hook import FrequencyFilterHook

            # Create hash table with filtering
            filter_hook = FrequencyFilterHook(min_frequency=5)
            hashtable = HashTable(
                embedding_shape=[128],
                block_size=2048,
                filter_hook=filter_hook,
                grad_reduce_by="id",
            )
    """

    def __init__(
        self,
        embedding_shape: List,
        block_size: int = 5,
        dtype: torch.dtype = torch.float32,
        device: torch.device = torch.device("cpu"),
        coalesced: bool = False,
        children: Optional[List[str]] = None,
        slice: Slice = _default_slice,
        initializer=None,
        name: str = "hashtable",
        grad_reduce_by: str = "worker",
        filter_hook: Optional[FilterHook] = None,
    ):
        """Initialize hash table module.

        Args:
            embedding_shape (List[int]): Shape of embedding vectors.
            block_size (int, optional): Number of embeddings per block. Defaults to 5.
            dtype (torch.dtype, optional): Data type. Defaults to torch.float32.
            device (torch.device, optional): Computation device. Defaults to CPU.
            coalesced (bool, optional): Use coalesced operations. Defaults to False.
            children (Optional[List[str]], optional): Child table names. Defaults to None.
            slice (Slice, optional): Partitioning config. Defaults to _default_slice.
            initializer (Initializer, optional): Initializer. Defaults to None.
            name (str, optional): Table name. Defaults to "hashtable".
            grad_reduce_by (str, optional): Gradient reduction. Defaults to "worker".
            filter_hook (Optional[FilterHook], optional): Filter hook. Defaults to None.

        Raises:
            AssertionError: If grad_reduce_by is not "id" or "worker".
        """
        super().__init__()
        if initializer is None:
            self._initializer = ConstantInitializer(init_val=0)
        else:
            self._initializer = initializer
        if children is None:
            children = [name]
        self._device = device
        assert grad_reduce_by in ["id", "worker"]
        self._grad_reduce_by = grad_reduce_by
        self._initializer.set_shape([block_size] + embedding_shape)
        self._initializer.set_dtype(dtype)
        self._initializer.build()
        self._dtype = dtype
        self._name = name
        for child in children:
            info_str = json.dumps(
                dict(
                    shape=embedding_shape,
                    dtype=str(dtype),
                    initializer=str(self._initializer),
                )
            )
            HashtableRegister().register(child, info_str)
        self._hashtable_impl = torch.ops.recis.make_hashtable(
            block_size,
            embedding_shape,
            dtype,
            device,
            coalesced,
            children,
            self._initializer.impl(),
            slice.slice_begin,
            slice.slice_end,
            slice.slice_size,
        )
        self._backward_holder = torch.tensor([0.0], requires_grad=True)
        self._worker_num = int(os.environ.get("WORLD_SIZE", 1))

        def state_dict_hook(
            self: HashTable, state_dict: dict, prefix: str, local_metadata
        ):
            state_dict[self._name] = self._hashtable_impl

        self._register_state_dict_hook(state_dict_hook)
        # TODO (sunhechen.shc) support more filter hook
        if filter_hook is not None:
            self._filter_hook_impl = HashtableHookFactory().create_filter_hook(
                self, filter_hook
            )
        else:
            self._filter_hook_impl = torch.nn.Identity()

    def forward(self, ids: torch.Tensor, admit_hook: AdmitHook = None) -> torch.Tensor:
        """Perform embedding lookup for given feature IDs.

        This method looks up embeddings for the provided feature IDs, handling
        deduplication, gradient computation, and optional feature admission
        control.

        Args:
            ids (torch.Tensor): Feature IDs to lookup. Shape: [N] where N is
                the number of features.
            admit_hook (AdmitHook, optional): Hook for controlling feature
                admission. Defaults to None.

        Returns:
            torch.Tensor: Looked up embeddings. Shape: [N, embedding_dim] where
            embedding_dim is determined by embedding_shape.

        Example:
            .. code-block:: python

                # Basic lookup
                ids = torch.tensor([1, 2, 3, 2, 1])  # Note: duplicates
                embeddings = hashtable(ids)  # Shape: [5, embedding_dim]

                # With admission hook
                from recis.nn.hashtable_hook import FrequencyAdmitHook

                admit_hook = FrequencyAdmitHook(min_frequency=3)
                embeddings = hashtable(ids, admit_hook)
        """
        admit_hook_impl = (
            HashtableHookFactory().create_admit_hook(self, admit_hook)
            if admit_hook
            else None
        )
        ids, index = ids.unique(return_inverse=True)
        if self.training and self._dtype not in (torch.int8, torch.int32, torch.int64):
            emb_idx, embedding = HashTableLookupHelpFunction.apply(
                ids, self._hashtable_impl, self._backward_holder, admit_hook_impl
            )
            if self._grad_reduce_by == "id":
                embedding = GradIDMeanFunction.apply(embedding, index)
            else:
                slice_num = torch.scalar_tensor(self._worker_num)
                embedding = GradWorkerMeanFunction.apply(embedding, index, slice_num)
            self._filter_hook_impl(emb_idx)
        else:
            ids = ids.detach()
            _, embedding = self._hashtable_impl.embedding_lookup(ids, True)
            embedding = embedding.cuda()
            embedding = torch.ops.recis.gather(index, embedding)
        return embedding

    def initializer(self):
        """Get the embedding initializer.

        Returns:
            Initializer: The initializer used for new embeddings.
        """
        return self._initializer

    @property
    def device(self):
        """Get the computation device.

        Returns:
            torch.device: The device used for computation.
        """
        return self._device

    @property
    def coalesce(self):
        """Check if coalesced operations are enabled.

        Returns:
            bool: True if coalesced operations are enabled.
        """
        return self._hashtable_impl.children_info().is_coalesce()

    @property
    def children_hashtable(self):
        """Get the list of child hash tables.

        Returns:
            List[str]: Names of child hash tables.
        """
        return self._hashtable_impl.children_info().children()

    def accept_grad(self, grad_index, grad) -> None:
        """Accept gradients for specific embedding indices.

        Args:
            grad_index (torch.Tensor): Indices of embeddings to update.
            grad (torch.Tensor): Gradient values for the embeddings.
        """
        self._hashtable_impl.accept_grad(grad_index, grad)

    def grad(self, acc_step=1) -> torch.Tensor:
        """Get accumulated gradients.

        Args:
            acc_step (int, optional): Accumulation step. Defaults to 1.

        Returns:
            torch.Tensor: Accumulated gradients.
        """
        return self._hashtable_impl.grad(acc_step)

    def clear_grad(self) -> None:
        """Clear accumulated gradients."""
        self._hashtable_impl.clear_grad()

    def insert(self, ids, embeddings) -> None:
        """Insert embeddings for specific IDs.

        Args:
            ids (torch.Tensor): Feature IDs to insert.
            embeddings (torch.Tensor): Embedding values to insert.
        """
        self._hashtable_impl.insert(ids, embeddings)

    def clear(self) -> None:
        """Clear all stored embeddings."""
        self._hashtable_impl.clear()

    def ids(self) -> torch.Tensor:
        """Get all stored feature IDs.

        Returns:
            torch.Tensor: All feature IDs currently stored in the table.
        """
        return self._hashtable_impl.ids()

    def ids_map(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get mapping between feature IDs and internal indices.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: (feature_ids, internal_indices)
        """
        return self._hashtable_impl.ids_map()

    def embeddings(self) -> torch.Tensor:
        """Get all stored embeddings.

        Returns:
            torch.Tensor: All embedding values currently stored in the table.
        """
        return self._hashtable_impl.slot_group().slot_by_name("embedding").value()

    def slot_group(self):
        """Get the slot group for advanced operations.

        Returns:
            SlotGroup: The slot group containing all storage slots.
        """
        return self._hashtable_impl.slot_group()

    def children_info(self):
        """Get information about child hash tables.

        Returns:
            ChildrenInfo: Information about child hash tables.
        """
        return self._hashtable_impl.children_info()

    def __str__(self) -> str:
        """String representation of the hash table.

        Returns:
            str: String representation including the table name.
        """
        return f"HashTable_{self._name}"

    def __repr__(self) -> str:
        """Detailed string representation of the hash table.

        Returns:
            str: Detailed string representation.
        """
        return self.__str__()

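
# Illustrative sketch (hypothetical helper, not part of the module): shows the
# deduplication pattern HashTable.forward relies on. `unique` with
# return_inverse=True yields one row per distinct ID plus an inverse index, and
# gathering with that index restores one embedding row per original position.
# Plain torch.index_select stands in here for torch.ops.recis.gather.
def _example_dedup_then_gather():
    ids = torch.tensor([1, 2, 3, 2, 1])
    unique_ids, index = ids.unique(return_inverse=True)  # 3 unique IDs
    per_unique_emb = torch.randn(unique_ids.numel(), 8)  # stand-in lookup result
    per_position_emb = torch.index_select(per_unique_emb, 0, index)
    assert per_position_emb.shape == (5, 8)
    # Positions holding the same ID receive identical rows.
    assert torch.equal(per_position_emb[0], per_position_emb[4])
    return per_position_emb
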

class HashTableLookupHelpFunction(torch.autograd.Function):
    """Autograd function for hash table embedding lookup with gradient support.

    This function provides the forward and backward passes for embedding lookup
    operations, handling gradient computation and admission hooks.
    """

    @staticmethod
    def forward(
        ctx,
        ids: torch.Tensor,
        hashtable: object,
        backward_holder: torch.Tensor,
        admit_hook_impl,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass for embedding lookup.

        Args:
            ctx: Autograd context for storing information.
            ids (torch.Tensor): Feature IDs to lookup.
            hashtable (torch.classes.recis.HashtableImpl): Hash table implementation.
            backward_holder (torch.Tensor): Tensor for gradient computation.
            admit_hook_impl: Implementation of admission hook.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: (indices, embeddings)

        Raises:
            AssertionError: If admit_hook_impl is neither None nor a ReadOnlyHookImpl.
        """
        assert admit_hook_impl is None or isinstance(
            admit_hook_impl, ReadOnlyHookImpl
        ), f"admit hook only supports ReadOnlyHook for now, but got: {admit_hook_impl}"
        ids = ids.detach()
        index, embedding = hashtable.embedding_lookup(ids, admit_hook_impl is not None)
        ctx.save_for_backward(index)
        ctx.hashtable = hashtable
        return index.to(device="cuda"), embedding.to(device="cuda")

    @staticmethod
    def backward(ctx, grad_output_index, grad_output_emb) -> Tuple:
        """Backward pass for embedding lookup.

        Args:
            ctx: Autograd context containing forward pass information.
            grad_output_index: Gradient for indices (unused).
            grad_output_emb (torch.Tensor): Gradient for embeddings.

        Returns:
            Tuple: Gradients for all inputs (most are None).
        """
        (index,) = ctx.saved_tensors
        ctx.hashtable.accept_grad(
            index.to(device="cuda"),
            grad_output_emb.to(device="cuda"),
        )
        return (None, None, None, None, None)


class GradIDMeanFunction(torch.autograd.Function):
    """Autograd function for gradient aggregation by feature ID.

    This function handles gradient computation when using ID-based gradient
    reduction, ensuring proper gradient flow for duplicate IDs.
    """

    @staticmethod
    def forward(ctx, embedding, index):
        """Forward pass for ID-based gradient aggregation.

        Args:
            ctx: Autograd context.
            embedding (torch.Tensor): Input embeddings.
            index (torch.Tensor): Index mapping for gathering.

        Returns:
            torch.Tensor: Gathered embeddings.
        """
        ctx.save_for_backward(index)
        return torch.ops.recis.gather(index, embedding)

    @staticmethod
    def backward(ctx, grad_outputs):
        """Backward pass for ID-based gradient aggregation.

        Args:
            ctx: Autograd context.
            grad_outputs (torch.Tensor): Output gradients.

        Returns:
            Tuple[torch.Tensor, None]: (reduced_gradients, None)
        """
        grad_outputs = grad_outputs.cuda()
        (index,) = ctx.saved_tensors
        if index.numel() == 0:
            return (
                torch.zeros(
                    [0] + list(grad_outputs.shape)[1:], device=grad_outputs.device
                ),
                None,
            )
        shape = [index.max() + 1] + list(grad_outputs.shape)[1:]
        reduce_grad = torch.zeros(shape, device=grad_outputs.device)
        reduce_grad.index_reduce_(0, index, grad_outputs, "mean", include_self=False)
        return reduce_grad, None


class GradWorkerMeanFunction(torch.autograd.Function):
    """Autograd function for gradient aggregation by worker.

    This function handles gradient computation when using worker-based gradient
    reduction, distributing gradients across multiple workers.
    """

    @staticmethod
    def forward(ctx, embedding, index, slice_num):
        """Forward pass for worker-based gradient aggregation.

        Args:
            ctx: Autograd context.
            embedding (torch.Tensor): Input embeddings.
            index (torch.Tensor): Index mapping for gathering.
            slice_num (torch.Tensor): Number of worker slices.

        Returns:
            torch.Tensor: Gathered embeddings.
        """
        ctx.save_for_backward(index, slice_num)
        return torch.ops.recis.gather(index, embedding)

    @staticmethod
    def backward(ctx, grad_outputs):
        """Backward pass for worker-based gradient aggregation.

        Args:
            ctx: Autograd context.
            grad_outputs (torch.Tensor): Output gradients.

        Returns:
            Tuple[torch.Tensor, None, None]: (reduced_gradients, None, None)
        """
        grad_outputs = grad_outputs.cuda()
        (index, slice_num) = ctx.saved_tensors
        if index.numel() == 0:
            return (
                torch.zeros(
                    [0] + list(grad_outputs.shape)[1:], device=grad_outputs.device
                ),
                None,
                None,
            )
        grad_outputs = grad_outputs / slice_num
        index_unique, index_reverse = torch.unique(
            index.view((-1,)), return_inverse=True, sorted=False
        )
        reduce_grad = torch.zeros(
            [index_unique.numel()] + list(grad_outputs.shape)[1:],
            dtype=grad_outputs.dtype,
            device=grad_outputs.device,
        )
        reduce_grad.index_add_(0, index_reverse, grad_outputs)
        return reduce_grad, None, None


def is_hashtable(obj):
    """Check if an object is a hash table.

    Args:
        obj: Object to check.

    Returns:
        bool: True if the object is a hash table, False otherwise.
    """
    return hasattr(obj, "hashtable_tag")


def split_sparse_dense_state_dict(state_dict: dict) -> Tuple[dict, dict]:
    """Split state dictionary into sparse and dense parameters.

    This function separates hash table parameters (sparse) from regular tensor
    parameters (dense) in a model's state dictionary.

    Args:
        state_dict (dict): State dictionary from model.state_dict().
            Format: {"parameter_name": parameter_value}.

    Returns:
        Tuple[dict, dict]: (sparse_state_dict, dense_state_dict)

        - sparse_state_dict: Dictionary containing hash table parameters
        - dense_state_dict: Dictionary containing regular tensor parameters

    Example:
        .. code-block:: python

            model = MyModel()  # Contains both hash tables and regular layers
            state_dict = model.state_dict()
            sparse_params, dense_params = split_sparse_dense_state_dict(state_dict)

            print(f"Sparse parameters: {list(sparse_params.keys())}")
            print(f"Dense parameters: {list(dense_params.keys())}")
    """
    sparse_state_dict = {}
    remove_key = set()
    for key in state_dict:
        value = state_dict[key]
        if value is not None:
            if is_hashtable(value):
                sparse_state_dict[key] = value
                remove_key.add(key)
    for key in remove_key:
        del state_dict[key]
    return sparse_state_dict, state_dict

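
# Illustrative sketch (hypothetical helper, for documentation only): contrasts
# the two grad_reduce_by modes implemented by GradIDMeanFunction and
# GradWorkerMeanFunction above. With "id", gradients of duplicate IDs are
# averaged via index_reduce_(..., "mean"); with "worker", they are summed via
# index_add_ after dividing by the worker count, so a later cross-worker
# all-reduce sum yields a worker-level mean.
def _example_grad_reduction_modes(world_size: int = 2):
    index = torch.tensor([0, 1, 0])  # inverse index: ID 0 appears twice, ID 1 once
    grad = torch.tensor([[2.0], [4.0], [6.0]])  # one gradient row per position

    # "id" mode: mean over positions sharing an ID -> [[4.], [4.]]
    id_mean = torch.zeros(2, 1)
    id_mean.index_reduce_(0, index, grad, "mean", include_self=False)

    # "worker" mode: per-ID sum of gradients pre-scaled by 1 / world_size
    # -> [[4.], [2.]] for world_size=2
    worker_mean = torch.zeros(2, 1)
    worker_mean.index_add_(0, index, grad / world_size)

    return id_mean, worker_mean
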

def filter_out_sparse_param(model: torch.nn.Module) -> dict:
    """Extract sparse parameters from a PyTorch model.

    This function extracts all hash table parameters from a model, which is
    useful for separate handling of sparse parameters in distributed training
    scenarios.

    Args:
        model (torch.nn.Module): PyTorch model containing hash tables.

    Returns:
        dict: Dictionary containing only sparse (hash table) parameters.

    Example:
        .. code-block:: python

            from recis.nn.modules.hashtable import filter_out_sparse_param

            # Separate parameters
            sparse_params = filter_out_sparse_param(model)

            # Create different optimizers
            from recis.optim import SparseAdamW
            from torch.optim import AdamW

            sparse_optimizer = SparseAdamW(sparse_params, lr=0.001)
            dense_optimizer = AdamW(model.parameters(), lr=0.001)
    """
    state_dict = model.state_dict()
    sparse_state_dict, _ = split_sparse_dense_state_dict(state_dict)
    return sparse_state_dict


class HashtableRegister(metaclass=SingletonMeta):
    """Singleton registry for managing hash table instances.

    This class provides a centralized registry for tracking hash table
    instances across the application, ensuring proper management and avoiding
    naming conflicts.

    Attributes:
        _hashtables (dict): Dictionary mapping hash table names to their
            configuration information.

    Example:
        .. code-block:: python

            # Register a hash table (usually done automatically)
            register = HashtableRegister()
            register.register("user_embedding", '{"shape": [64], "dtype": "float32"}')

            # The registry is a singleton, so all instances are the same
            register2 = HashtableRegister()
            assert register is register2  # True
    """

    def __init__(self) -> None:
        """Initialize the hash table registry."""
        self._hashtables = {}

    def register(self, name: str, info: str):
        """Register a hash table with the given name and configuration.

        Args:
            name (str): Unique name for the hash table.
            info (str): JSON string containing hash table configuration.

        Raises:
            ValueError: If a hash table with the same name is already
                registered.

        Example:
            .. code-block:: python

                register = HashtableRegister()

                # Register a new hash table
                config = '{"shape": [128], "dtype": "float32", "initializer": "constant"}'
                register.register("item_embedding", config)

                # This would raise ValueError due to the duplicate name
                # register.register("item_embedding", different_config)
        """
        if name in self._hashtables:
            raise ValueError(
                f"Duplicate hashtable shard name: {name}, "
                f"before: {self._hashtables[name]}, now: {info}"
            )
        self._hashtables[name] = info