Dynamic Embedding Tables
RecIS dynamic embedding tables provide efficient, scalable sparse parameter storage and lookup, supporting real-time updates of large-scale dynamic vocabularies as well as distributed training. They offer a complete sparse-feature embedding solution for recommendation systems and similar scenarios.
Core Features
- Dynamic Embedding Management
Real-time Expansion: Supports dynamically adding new feature IDs during training, without a predefined vocabulary size
Feature Filtering: Provides filtering strategies that automatically remove low-frequency or expired features
- Distributed Storage Architecture
Distributed Sharding: Row-wise partitioning that supports multi-worker parallel training
Gradient Aggregation: Supports aggregating gradients by ID or by worker (see the configuration sketch after this list)
- High-Performance Computing Optimization
Operator Fusion: Batches and fuses multi-feature embedding lookups
GPU Acceleration: Complete CUDA operator support that fully exploits GPU parallelism
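These capabilities are configured through EmbeddingOption (documented below). The following minimal sketch shows how they map to configuration fields; the "id" literal for grad_reduce_by and the commented-out hook objects are assumptions for illustration, since hook construction depends on your admission/filtering policy:

import torch
from recis.nn import EmbeddingOption

emb_opt = EmbeddingOption(
    embedding_dim=64,
    device=torch.device("cuda"),  # GPU-accelerated lookup
    grad_reduce_by="id",          # Aggregate gradients per ID ("id" literal assumed)
    # admit_hook=my_admit_hook,   # Hypothetical AdmitHook: feature admission policy
    # filter_hook=my_filter_hook, # Hypothetical FilterHook: drop low-frequency/expired IDs
)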
Single Dynamic Embedding Table
- class recis.nn.DynamicEmbedding(emb_opt: EmbeddingOption, pg: ProcessGroup = None)[source]
Dynamic embedding module for distributed sparse feature learning.
This module provides a distributed dynamic embedding table that can automatically handle feature admission, eviction, and cross-worker communication. It supports both dense tensors and ragged tensors for flexible sparse feature handling.
The module uses a hash table backend for efficient sparse storage and supports various combiners (sum, mean, tile) for aggregating multiple embeddings per sample.
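To illustrate the combiner semantics only (not the module's internal CUDA implementation), the pure-PyTorch sketch below reproduces what sum and mean compute for a ragged batch; tile instead lays out up to tile_len embeddings per sample, producing a [batch_size, tile_len * embedding_dim] output:

import torch

# Ragged batch: 3 samples holding 3, 2, and 2 IDs (offsets mark boundaries)
offsets = torch.tensor([0, 3, 5, 7])
looked_up = torch.randn(7, 16)  # One looked-up embedding row per ID

lengths = offsets[1:] - offsets[:-1]
segment_ids = torch.repeat_interleave(torch.arange(len(lengths)), lengths)

# "sum": add each sample's embeddings together -> [3, 16]
summed = torch.zeros(3, 16).index_add_(0, segment_ids, looked_up)
# "mean": divide each sum by the sample's ID count -> [3, 16]
mean = summed / lengths.unsqueeze(1).clamp(min=1)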
- Parameters:
emb_opt (EmbeddingOption) – Configuration options for the embedding.
pg (dist.ProcessGroup, optional) – Process group for distributed communication. Defaults to None (uses default group).
Example
Basic usage:
import torch
from recis.nn import DynamicEmbedding, EmbeddingOption
from recis.nn.initializers import TruncNormalInitializer

# Configure embedding options
emb_opt = EmbeddingOption(
    embedding_dim=64,
    shared_name="user_embedding",
    combiner="sum",
    initializer=TruncNormalInitializer(std=0.01),
)

# Create dynamic embedding
embedding = DynamicEmbedding(emb_opt)

# Forward propagation
ids = torch.LongTensor([1, 2, 3, 4])
emb_output = embedding(ids)

Advanced usage with ragged tensors:
from recis.ragged.tensor import RaggedTensor

# Create ragged tensor for variable-length sequences
values = torch.tensor([1, 2, 3, 4, 5, 6, 7])
offsets = torch.tensor([0, 3, 5, 7])  # Batch boundaries
ragged_ids = RaggedTensor(values, offsets)

# Forward pass
embeddings = embedding(ragged_ids)
- __init__(emb_opt: EmbeddingOption, pg: ProcessGroup = None)[source]
Initialize dynamic embedding module.
- Parameters:
emb_opt (EmbeddingOption) – Configuration options for the embedding.
pg (dist.ProcessGroup, optional) – Process group for distributed communication. Defaults to None.
- forward(input_ids: Tensor | RaggedTensor, input_weights=None)[source]
Forward pass of dynamic embedding lookup.
Performs distributed embedding lookup with the following steps:
1. Process the input tensor to extract IDs, offsets, and weights
2. Exchange IDs across workers for distributed lookup
3. Perform the embedding lookup and exchange results back
4. Aggregate embeddings using the specified combiner
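To make step 2 concrete, here is a toy sketch of how IDs might be routed to their owning workers. The modulo shard function is an assumption made for illustration, not RecIS's actual partitioning scheme:

import torch

def route_ids(ids: torch.Tensor, num_workers: int):
    # Toy stand-in for step 2: bucket each unique ID to the worker
    # assumed to own its row (modulo sharding is hypothetical here).
    unique_ids = ids.unique()
    owner = unique_ids % num_workers
    return [unique_ids[owner == w] for w in range(num_workers)]

buckets = route_ids(torch.tensor([7, 2, 9, 2, 12]), num_workers=4)
# In real training, each bucket is sent to its worker (e.g., via
# all-to-all), looked up there, and the embeddings exchanged back.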
- Parameters:
input_ids (Union[torch.Tensor, RaggedTensor]) – Input feature IDs. For dense tensors, shape should be [batch_size, num_features]. For ragged tensors, supports variable-length sequences.
input_weights (torch.Tensor, optional) – Weights for weighted aggregation. Only used when input_ids is a dense tensor. Defaults to None.
- Returns:
- Aggregated embeddings with shape:
For sum/mean combiner: [batch_size, embedding_dim]
For tile combiner: [batch_size, tile_len * embedding_dim]
- Return type:
torch.Tensor
Example:
# Dense tensor input
ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
embeddings = embedding(ids)  # Shape: [2, embedding_dim]

# With weights
weights = torch.tensor([[0.5, 1.0, 0.8], [1.2, 0.9, 1.1]])
weighted_embs = embedding(ids, weights)

# Ragged tensor input
ragged_ids = RaggedTensor(values, offsets)
ragged_embs = embedding(ragged_ids)
Embedding Configuration
- class recis.nn.EmbeddingOption(embedding_dim: int = 16, block_size: int = 10240, dtype: torch.dtype = torch.float32, device: torch.device = device(type='cpu'), trainable: bool = True, pg: torch.distributed.distributed_c10d.ProcessGroup = None, max_partition_num: int = 65536, shared_name: str = 'embedding', children: typing.List[str] = <factory>, coalesced: bool | None = False, initializer: recis.nn.initializers.Initializer | None = None, use_weight: bool | None = True, combiner: str | None = 'sum', combiner_kwargs: dict | None = None, grad_reduce_by: str | None = 'worker', filter_hook: recis.nn.hashtable_hook.FilterHook | None = None, admit_hook: recis.nn.hashtable_hook.AdmitHook | None = None, fp16_enabled: bool = False)[source]
Configuration class for dynamic embedding parameters.
This class encapsulates all configuration options for dynamic embedding, including dimension settings, device placement, training options, and distributed communication parameters.
- dtype
Data type for embeddings. Defaults to torch.float32.
- Type:
torch.dtype
- device
Device for computation. Defaults to CPU.
- Type:
torch.device
- pg
Process group for distributed training. Defaults to None.
- Type:
dist.ProcessGroup
- shared_name
Shared name for embedding table. Defaults to "embedding".
- Type:
str
- initializer
Embedding initializer. Defaults to None.
- Type:
Optional[Initializer]
- filter_hook
Filter hook for feature filtering. Defaults to None.
- Type:
Optional[FilterHook]
Example:
import torch
from recis.nn.modules.embedding import EmbeddingOption

# Basic configuration
emb_opt = EmbeddingOption(
    embedding_dim=128,
    block_size=2048,
    dtype=torch.float32,
    trainable=True,
    combiner="mean",
)

# Advanced configuration with hooks
emb_opt = EmbeddingOption(
    embedding_dim=64,
    combiner="tile",
    combiner_kwargs={"tile_len": 10},
    filter_hook=my_filter_hook,
    admit_hook=my_admit_hook,
)
Embedding Engine
- class recis.nn.EmbeddingEngine(emb_options: dict[str, EmbeddingOption])[source]
Embedding engine for efficient batch processing of multiple embeddings.
This module provides a high-level interface for managing and processing multiple embedding tables efficiently. It automatically groups embeddings with similar configurations and processes them using coalesced operations to minimize communication overhead in distributed training.
- Key features:
Automatic grouping of similar embedding configurations
Coalesced operations for improved performance
Support for mixed tensor types (dense and ragged)
Flexible feature routing (embedding vs. pass-through)
Memory-efficient processing pipeline
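The "automatic grouping" above can be pictured as bucketing options by a compatibility key, so that features sharing a key are looked up in one coalesced pass. This is a conceptual sketch under assumed criteria; the engine's real grouping logic is internal and may differ:

from collections import defaultdict

def group_options(emb_options):
    # Assumed compatibility key: embedding dim, dtype, and combiner.
    groups = defaultdict(list)
    for name, opt in emb_options.items():
        key = (opt.embedding_dim, opt.dtype, opt.combiner)
        groups[key].append(name)
    return dict(groups)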
- Parameters:
emb_options (dict[str, EmbeddingOption]) – Dictionary mapping feature names to their embedding options.
Example
Multi-embedding scenario:
import torch
from recis.nn.modules.embedding import EmbeddingOption
from recis.nn.modules.embedding_engine import EmbeddingEngine

# Define embedding options for different features
emb_options = {
    "user_id": EmbeddingOption(
        embedding_dim=128,
        shared_name="user_embedding",
        combiner="sum",
        trainable=True,
    ),
    "item_id": EmbeddingOption(
        embedding_dim=128,
        shared_name="item_embedding",
        combiner="sum",
        trainable=True,
    ),
    "category": EmbeddingOption(
        embedding_dim=64,
        shared_name="category_embedding",
        combiner="mean",
        trainable=True,
    ),
}

# Create embedding engine
engine = EmbeddingEngine(emb_options)

# Prepare input features
batch_size = 32
features = {
    "user_id": torch.randint(0, 10000, (batch_size, 1)),
    "item_id": torch.randint(0, 50000, (batch_size, 5)),  # Multi-hot
    "category": torch.randint(0, 100, (batch_size, 1)),
    "other_feature": torch.randn(batch_size, 10),  # Pass-through
}

# Forward pass
embeddings = engine(features)

# Results contain embeddings for configured features
# and pass-through for non-embedding features
print(embeddings["user_id"].shape)        # [32, 128]
print(embeddings["item_id"].shape)        # [32, 128] (summed)
print(embeddings["category"].shape)       # [32, 64]
print(embeddings["other_feature"].shape)  # [32, 10] (pass-through)

Advanced usage with ragged tensors:
from recis.ragged.tensor import RaggedTensor

# Variable-length sequences
values = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
offsets = torch.tensor([0, 3, 5, 8])  # Batch boundaries
ragged_features = RaggedTensor(values, offsets)

features = {
    "sequence_ids": ragged_features,
    "user_id": torch.tensor([[100], [200], [300]]),
}

embeddings = engine(features)
- __init__(emb_options: dict[str, EmbeddingOption])[source]
Initialize embedding engine with multiple embedding options.
- Parameters:
emb_options (dict[str, EmbeddingOption]) – Dictionary mapping feature names to their embedding configurations.
- Raises:
RuntimeError – If embedding options have conflicting configurations for the same shared name.
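For example, the following hypothetical configuration would conflict, since both features claim the shared name "shared" with different embedding dimensions:

from recis.nn import EmbeddingOption
from recis.nn.modules.embedding_engine import EmbeddingEngine

emb_options = {
    "user_id": EmbeddingOption(embedding_dim=64, shared_name="shared"),
    "item_id": EmbeddingOption(embedding_dim=128, shared_name="shared"),
}
# EmbeddingEngine(emb_options)  # Expected to raise RuntimeError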
- forward(input_features: dict[str, Tensor])[source]
Forward pass for batch embedding processing.
This method processes multiple features efficiently by:
1. Grouping features by their runtime characteristics
2. Performing coalesced ID exchange across workers
3. Looking up embeddings in batches
4. Reducing embeddings using specified combiners
5. Splitting results back to individual features
- Parameters:
input_features (dict[str, torch.Tensor]) – Dictionary mapping feature names to their input tensors. Features not in embedding options will be passed through unchanged.
- Returns:
- Dictionary mapping feature names to their processed outputs. Embedding features return embedding tensors, while non-embedding features are passed through.
- Return type:
dict[str, torch.Tensor]
Example:
features = {
    "user_id": torch.tensor([[1, 2], [3, 4]]),
    "item_id": torch.tensor([[10], [20]]),
    "raw_feature": torch.randn(2, 5),  # Pass-through
}
outputs = engine(features)
# outputs["user_id"]: embedding tensor [2, embedding_dim]
# outputs["item_id"]: embedding tensor [2, embedding_dim]
# outputs["raw_feature"]: original tensor [2, 5]