import json
import math
import os
from dataclasses import dataclass, field
from typing import List, Optional, Union
import torch
import torch.distributed as dist
from recis.nn.functional.embedding_ops import (
ids_partition,
ragged_embedding_segment_reduce,
)
from recis.nn.hashtable_hook import AdmitHook, FilterHook
from recis.nn.initializers import ConstantInitializer, Initializer
from recis.nn.modules.hashtable import HashTable, gen_slice
from recis.ragged.tensor import RaggedTensor
from recis.utils.logger import Logger
logger = Logger(__name__)
class EmbeddingExchange(torch.autograd.Function):
"""Autograd function for distributed embedding exchange across workers.
This function handles the forward and backward passes for exchanging
embeddings between different workers in a distributed setting using
all-to-all communication patterns.
"""
@staticmethod
def forward(
ctx,
embedding: torch.Tensor,
parts: List[int],
parts_reverse: List[int],
pg: dist.ProcessGroup = None,
):
"""Forward pass for embedding exchange.
Args:
ctx: Autograd context for storing information for backward pass.
embedding (torch.Tensor): Input embedding tensor to exchange.
parts (List[int]): List of partition sizes for input splitting.
parts_reverse (List[int]): List of partition sizes for output splitting.
pg (dist.ProcessGroup, optional): Process group for communication.
Defaults to None (uses default group).
Returns:
Tuple containing:
- torch.Tensor: Exchanged embedding tensor
- object: Async operation handle for synchronization
- List[int]: Original partial shape for reconstruction
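        Example:
            A minimal sketch of the split-size arithmetic performed below
            (the row counts and embedding width are hypothetical):

            .. code-block:: python

                # This worker sends 2 rows to rank 0 and 3 rows to rank 1,
                # and expects 1 and 4 rows back; each row has 8 elements.
                parts = [2, 3]
                parts_reverse = [1, 4]
                block_size = 8  # math.prod(embedding.shape[1:])
                input_split_sizes = [block_size * p for p in parts]  # [16, 24]
                output_split_sizes = [block_size * p for p in parts_reverse]  # [8, 32]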
"""
ctx.pg = pg
ctx.parts = parts
ctx.parts_reverse = parts_reverse
ctx.origin_partial_shape = list(embedding.shape)
ctx.origin_partial_shape[0] = -1
ctx.block_size = None
if ctx.pg is None:
ctx.pg = dist.distributed_c10d._get_default_group()
# exchange embedding
block_size = math.prod(embedding.shape[1:])
embedding = embedding.view(-1)
ctx.block_size = block_size
input_split_sizes = [block_size * part for part in parts]
output_split_sizes = [block_size * part for part in parts_reverse]
output_embedding = torch.empty(
(sum(output_split_sizes)), dtype=embedding.dtype, device=embedding.device
)
emb_await = dist.all_to_all_single(
output=output_embedding,
input=embedding,
input_split_sizes=input_split_sizes,
output_split_sizes=output_split_sizes,
group=pg,
async_op=True,
)
return output_embedding, emb_await, ctx.origin_partial_shape
@staticmethod
def backward(ctx, grad_output, _1, _2):
"""Backward pass for embedding exchange.
Args:
ctx: Autograd context containing forward pass information.
grad_output (torch.Tensor): Gradient from the next layer.
_1: Unused gradient argument.
_2: Unused gradient argument.
Returns:
Tuple containing:
- torch.Tensor: Gradient for embedding input
- None: No gradient for parts
- None: No gradient for parts_reverse
- None: No gradient for pg
"""
parts = ctx.parts
reverse_parts = ctx.parts_reverse
block_size = ctx.block_size
input_split_sizes = [part * block_size for part in reverse_parts]
output_split_sizes = [part * block_size for part in parts]
output_grad = torch.empty(
(sum(output_split_sizes)),
dtype=grad_output.dtype,
device=grad_output.device,
)
dist.all_to_all_single(
output=output_grad,
input=grad_output.view(-1),
input_split_sizes=input_split_sizes,
output_split_sizes=output_split_sizes,
group=ctx.pg,
)
output = output_grad.view(ctx.origin_partial_shape)
return output, None, None, None
@dataclass
class ExchangeIDsResults:
"""Data class for storing ID exchange results.
Attributes:
ids (torch.Tensor): Exchanged ID tensor.
parts (List[int]): Original partition sizes.
parts_reverse (List[int]): Reverse partition sizes.
reverse_index (torch.Tensor): Index for reversing the exchange.
offsets (torch.Tensor): Offset tensor for ragged operations.
ids_await (Optional[object]): Async operation handle. Defaults to None.
"""
ids: torch.Tensor
parts: List[int]
parts_reverse: List[int]
reverse_index: torch.Tensor
offsets: torch.Tensor
ids_await: Optional[object] = None
@dataclass
class ExchangeEmbResults:
"""Data class for storing embedding exchange results.
Attributes:
emb (torch.Tensor): Exchanged embedding tensor.
reverse_index (torch.Tensor): Index for reversing the exchange.
emb_shape (List[int]): Original embedding shape.
offsets (torch.Tensor): Offset tensor for ragged operations.
emb_await (Optional[object]): Async operation handle. Defaults to None.
"""
emb: torch.Tensor
reverse_index: torch.Tensor
    emb_shape: List[int]
offsets: torch.Tensor
emb_await: Optional[object] = None
@dataclass
class EmbeddingOption:
"""Configuration class for dynamic embedding parameters.
This class encapsulates all configuration options for dynamic embedding,
including dimension settings, device placement, training options, and
distributed communication parameters.
Attributes:
embedding_dim (int): Dimension of embedding vectors. Defaults to 16.
block_size (int): Block size for hash table storage. Defaults to 10240.
dtype (torch.dtype): Data type for embeddings. Defaults to torch.float32.
device (torch.device): Device for computation. Defaults to CPU.
trainable (bool): Whether embeddings are trainable. Defaults to True.
pg (dist.ProcessGroup): Process group for distributed training. Defaults to None.
max_partition_num (int): Maximum partition number. Defaults to 65536.
shared_name (str): Shared name for embedding table. Defaults to "embedding".
children (List[str]): List of child embedding names. Defaults to empty list.
coalesced (Optional[bool]): Whether to use coalesced operations. Defaults to False.
initializer (Optional[Initializer]): Embedding initializer. Defaults to None.
use_weight (Optional[bool]): Whether to use weights. Defaults to True.
combiner (Optional[str]): Combiner type ("sum", "mean", "tile"). Defaults to "sum".
combiner_kwargs (Optional[dict]): Additional combiner arguments. Defaults to None.
grad_reduce_by (Optional[str]): Gradient reduction strategy. Defaults to "worker".
filter_hook (Optional[FilterHook]): Filter hook for feature filtering. Defaults to None.
admit_hook (Optional[AdmitHook]): Admit hook for feature admission. Defaults to None.
Example:
.. code-block:: python
from recis.nn.modules.embedding import EmbeddingOption
# Basic configuration
emb_opt = EmbeddingOption(
embedding_dim=128,
block_size=2048,
dtype=torch.float32,
trainable=True,
combiner="mean",
)
# Advanced configuration with hooks
emb_opt = EmbeddingOption(
embedding_dim=64,
combiner="tile",
combiner_kwargs={"tile_len": 10},
filter_hook=my_filter_hook,
admit_hook=my_admit_hook,
)
"""
embedding_dim: int = 16
block_size: int = 10240
dtype: torch.dtype = torch.float32
device: torch.device = torch.device("cpu")
trainable: bool = True
pg: dist.ProcessGroup = None
max_partition_num: int = 65536
shared_name: str = "embedding"
children: List[str] = field(default_factory=list)
coalesced: Optional[bool] = False
initializer: Optional[Initializer] = None
use_weight: Optional[bool] = True
combiner: Optional[str] = "sum"
combiner_kwargs: Optional[dict] = None
grad_reduce_by: Optional[str] = "worker"
filter_hook: Optional[FilterHook] = None
admit_hook: Optional[AdmitHook] = None
    # If enabled, int8 embeddings are cast to fp16 during the reduce step;
    # otherwise int8 embeddings are cast to fp32.
fp16_enabled: bool = False
def __post_init__(self):
"""Post-initialization validation and setup.
Raises:
AssertionError: If combiner is not in supported types.
RuntimeError: If tile combiner is used without proper configuration.
"""
if not self.children:
self.children = [self.shared_name]
if self.initializer is None:
self.initializer = ConstantInitializer(init_val=0)
if self.fp16_enabled:
            assert self.dtype in (torch.int8,), (
                "fp16_enabled is only supported for int8 embeddings"
            )
assert self.combiner in [
"sum",
"mean",
"tile",
], f"Hashtable combiner only support [sum/mean/tile], but got {self.combiner}"
if self.combiner == "tile":
if self.combiner_kwargs is None:
raise RuntimeError("combiner_kwargs must be set when combiner is tile.")
if "tile_len" not in self.combiner_kwargs:
raise RuntimeError(
"tile_len must be in combiner_kwargs when combiner is tile."
)
def coalesced_info(self):
"""Get coalesced configuration information.
Returns:
str: JSON string containing coalesced configuration.
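        Example:
            A sketch of the returned JSON (the option values are illustrative;
            the string renderings of the initializer and filter hook are omitted):

            .. code-block:: python

                opt = EmbeddingOption(embedding_dim=8)
                info = opt.coalesced_info()
                # '{"dim": 8, "dtype": "torch.float32", "device": "cpu", ...}'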
"""
info = {
"dim": self.embedding_dim,
"dtype": str(self.dtype),
"device": str(self.device.type),
"initializer": str(self.initializer),
"grad_reduce_by": self.grad_reduce_by,
"filter_hook": str(self.filter_hook),
}
return json.dumps(info)
def runtime_info(self):
"""Get runtime configuration information.
Returns:
str: JSON string containing runtime configuration.
"""
info = {
"combiner": self.combiner,
"use_weight": self.use_weight,
"trainable": self.trainable,
"admit_hook": str(self.admit_hook),
"fp16_enabled": self.fp16_enabled,
}
return json.dumps(info)
class DynamicEmbedding(torch.nn.Module):
"""Dynamic embedding module for distributed sparse feature learning.
This module provides a distributed dynamic embedding table that can
automatically handle feature admission, eviction, and cross-worker
communication. It supports both dense tensors and ragged tensors
for flexible sparse feature handling.
The module uses a hash table backend for efficient sparse storage
and supports various combiners (sum, mean, tile) for aggregating
multiple embeddings per sample.
Args:
emb_opt (EmbeddingOption): Configuration options for the embedding.
pg (dist.ProcessGroup, optional): Process group for distributed
communication. Defaults to None (uses default group).
Example:
Basic usage:
.. code-block:: python
from recis.nn import DynamicEmbedding, EmbeddingOption
from recis.nn.initializers import TruncNormalInitializer
# Configure embedding options
emb_opt = EmbeddingOption(
embedding_dim=64,
shared_name="user_embedding",
combiner="sum",
initializer=TruncNormalInitializer(std=0.01)
)
# Create dynamic embedding
embedding = DynamicEmbedding(emb_opt)
# Forward propagation
ids = torch.LongTensor([1, 2, 3, 4])
emb_output = embedding(ids)
Advanced usage with ragged tensors:
.. code-block:: python
from recis.ragged.tensor import RaggedTensor
# Create ragged tensor for variable-length sequences
values = torch.tensor([1, 2, 3, 4, 5, 6, 7])
offsets = torch.tensor([0, 3, 5, 7]) # Batch boundaries
ragged_ids = RaggedTensor(values, offsets)
# Forward pass
embeddings = embedding(ragged_ids)
"""
def __init__(self, emb_opt: EmbeddingOption, pg: dist.ProcessGroup = None):
"""Initialize dynamic embedding module.
Args:
emb_opt (EmbeddingOption): Configuration options for the embedding.
pg (dist.ProcessGroup, optional): Process group for distributed
communication. Defaults to None.
"""
super().__init__()
self._emb_opt = emb_opt
self._world_size = int(os.environ.get("WORLD_SIZE", 1))
self._rank = int(os.environ.get("RANK", 0))
        if pg is None:
            pg = dist.distributed_c10d._get_default_group()
        self._pg = pg
self._cpu_device = torch.device("cpu")
self._gpu_device = torch.device(int(os.environ.get("LOCAL_RANK", 0)))
self._hashtable = HashTable(
embedding_shape=[self._emb_opt.embedding_dim],
block_size=self._emb_opt.block_size,
dtype=self._emb_opt.dtype,
device=self._emb_opt.device,
initializer=self._emb_opt.initializer,
children=self._emb_opt.children,
name=self._emb_opt.shared_name,
coalesced=self._emb_opt.coalesced,
slice=gen_slice(shard_index=self._rank, shard_num=self._world_size),
grad_reduce_by=self._emb_opt.grad_reduce_by,
filter_hook=self._emb_opt.filter_hook,
)
def deal_with_tensor(self, input_tensor: Union[torch.Tensor, RaggedTensor]):
"""Process input tensor to extract values, offsets, weights, and shape.
This method handles both dense tensors and ragged tensors, extracting
the necessary components for embedding lookup and aggregation.
Args:
input_tensor (Union[torch.Tensor, RaggedTensor]): Input tensor
containing feature IDs. Can be either a dense tensor or
a ragged tensor for variable-length sequences.
Returns:
Tuple containing:
- torch.Tensor: Flattened values tensor
- torch.Tensor: Offsets tensor for segment operations
- Optional[torch.Tensor]: Weights tensor (None for dense tensors)
- Tuple: Original shape for output reconstruction
Raises:
RuntimeError: If RaggedTensor is not properly padded.
TypeError: If input tensor type is not supported or is sparse.
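        Example:
            A sketch of the dense-tensor path, assuming ``embedding`` is a
            DynamicEmbedding instance:

            .. code-block:: python

                ids = torch.tensor([[1, 2, 3], [4, 5, 6]])  # [batch=2, features=3]
                val, offsets, weight, shape = embedding.deal_with_tensor(ids)
                # val:     tensor([1, 2, 3, 4, 5, 6])
                # offsets: tensor([0, 3, 6])  # arange(bs + 1) * fea_dim
                # weight:  None
                # shape:   torch.Size([2])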
"""
if isinstance(input_tensor, RaggedTensor):
val = input_tensor.values()
weight = input_tensor.weight()
offsets = input_tensor.offsets()[-1]
shape = input_tensor.real_shape(0, -1)
            if math.prod(shape) != offsets.numel() - 1:
                raise RuntimeError(
                    f"RaggedTensor must be padded before lookup, got: {input_tensor}"
                )
elif isinstance(input_tensor, torch.Tensor):
if input_tensor.is_sparse:
raise TypeError("RaggedDynamicEmbedding doesn't support sparse ids")
else:
val = input_tensor.view(-1)
bs = input_tensor.shape[0]
fea_dim = input_tensor.shape[1]
offsets = torch.arange(bs + 1, device=val.device) * (fea_dim)
weight = None
shape = input_tensor.shape[:-1]
else:
            raise TypeError(
                f"DynamicEmbedding only supports torch.Tensor or RaggedTensor inputs, but got {type(input_tensor)}"
            )
return val, offsets, weight, shape
def forward(
self,
input_ids: Union[torch.Tensor, RaggedTensor],
input_weights=None,
):
"""Forward pass of dynamic embedding lookup.
Performs distributed embedding lookup with the following steps:
1. Process input tensor to extract IDs, offsets, and weights
2. Exchange IDs across workers for distributed lookup
3. Perform embedding lookup and exchange results back
4. Aggregate embeddings using the specified combiner
Args:
input_ids (Union[torch.Tensor, RaggedTensor]): Input feature IDs.
For dense tensors, shape should be [batch_size, num_features].
For ragged tensors, supports variable-length sequences.
input_weights (torch.Tensor, optional): Weights for weighted
aggregation. Only used when input_ids is a dense tensor.
Defaults to None.
Returns:
torch.Tensor: Aggregated embeddings with shape:
- For sum/mean combiner: [batch_size, embedding_dim]
- For tile combiner: [batch_size, tile_len * embedding_dim]
Example:
.. code-block:: python
# Dense tensor input
ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
embeddings = embedding(ids) # Shape: [2, embedding_dim]
# With weights
weights = torch.tensor([[0.5, 1.0, 0.8], [1.2, 0.9, 1.1]])
weighted_embs = embedding(ids, weights)
# Ragged tensor input
ragged_ids = RaggedTensor(values, offsets)
ragged_embs = embedding(ragged_ids)
"""
# deal with tensor or ragged tensor
ids, offsets, weight, shape = self.deal_with_tensor(input_ids)
        if not isinstance(input_ids, RaggedTensor):
            weight = input_weights
# exchange ids
ids_exchange_result: ExchangeIDsResults = self.exchange_ids(ids, offsets)
# lookup && exchange emb
emb_exchange_result: ExchangeEmbResults = self.lookup_exchange_emb(
ids_exchange_result, self._emb_opt.admit_hook
)
# reduce by segment
combiner_kwargs = {}
if self._emb_opt.combiner == "tile":
combiner_kwargs["tile_len"] = [self._emb_opt.combiner_kwargs["tile_len"]]
combiner_kwargs["bs"] = [ids.offsets.numel() - 1]
        emb = self.emb_reduce(
            emb_exchange_result,
            weight,
            self._emb_opt.combiner,
            combiner_kwargs,
            fp16_enable=self._emb_opt.fp16_enabled,
        )
if self._emb_opt.combiner == "tile":
emb = emb.view(
-1, self._emb_opt.combiner_kwargs["tile_len"] * emb.shape[-1]
)
else:
out_shape = shape + (emb.shape[-1],)
emb = emb.view(out_shape)
if not self._emb_opt.trainable:
emb = emb.detach()
return emb
def exchange_ids(self, ids: torch.Tensor, offsets: torch.Tensor):
"""Exchange feature IDs across workers for distributed lookup.
This method partitions the input IDs and exchanges them across workers
using all-to-all communication, enabling each worker to lookup its
assigned portion of the embedding table.
Args:
ids (torch.Tensor): Flattened feature IDs tensor.
offsets (torch.Tensor): Offsets tensor for segment operations.
Returns:
ExchangeIDsResults: Data class containing exchanged IDs and
metadata needed for the reverse operation.
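        Example:
            A toy, single-process sketch of the partitioning idea. Routing by
            ``id % world_size`` is an illustrative assumption, not necessarily
            what ``ids_partition`` does internally:

            .. code-block:: python

                import torch

                world_size = 2
                ids = torch.tensor([4, 7, 2, 9])
                dest = ids % world_size  # target worker per id
                order = torch.argsort(dest, stable=True)  # group ids by worker
                sorted_ids = ids[order]  # tensor([4, 2, 7, 9])
                parts = torch.bincount(dest, minlength=world_size).tolist()  # [2, 2]
                # `parts` feeds input_split_sizes of all_to_all_single; the
                # exchanged `parts_reverse` become the output_split_sizes.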
"""
ids, ids_parts, ids_reverse_index = ids_partition(
ids, self._emb_opt.max_partition_num, self._world_size
)
# sync all to all: exchange parts num
ids_parts_reverse = torch.empty_like(ids_parts)
dist.all_to_all_single(ids_parts_reverse, ids_parts, group=self._pg)
ids_parts_reverse = ids_parts_reverse.to(device="cpu")
ids_parts = ids_parts.to(device="cpu")
ids_parts_reverse = ids_parts_reverse.tolist()
ids_parts = ids_parts.tolist()
output_ids = torch.empty(
size=[sum(ids_parts_reverse)], dtype=ids.dtype, device=ids.device
)
        # async all to all: exchange part ids
ids_await = dist.all_to_all_single(
output_ids,
ids,
output_split_sizes=ids_parts_reverse,
input_split_sizes=ids_parts,
async_op=True,
)
return ExchangeIDsResults(
ids=output_ids,
ids_await=ids_await,
parts=ids_parts,
parts_reverse=ids_parts_reverse,
reverse_index=ids_reverse_index,
offsets=offsets,
)
def lookup_exchange_emb(
self, ids_exchange_result: ExchangeIDsResults, admit_hook=None
):
"""Perform embedding lookup and exchange results back to original workers.
This method waits for ID exchange to complete, performs embedding lookup
using the hash table, and then exchanges the embeddings back to their
original workers.
Args:
ids_exchange_result (ExchangeIDsResults): Results from ID exchange
containing the IDs to lookup and exchange metadata.
admit_hook (AdmitHook, optional): Hook for feature admission control.
Defaults to None.
Returns:
ExchangeEmbResults: Data class containing exchanged embeddings and
metadata needed for aggregation.
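        Example:
            A sketch of how this method sits in the lookup pipeline, mirroring
            the sequence used by ``forward`` (``ids`` and ``offsets`` are
            assumed to come from ``deal_with_tensor``):

            .. code-block:: python

                ids_result = self.exchange_ids(ids, offsets)
                emb_result = self.lookup_exchange_emb(ids_result, admit_hook=None)
                emb = self.emb_reduce(
                    emb_result, weight=None, combiner="sum", combiner_kwargs={}
                )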
"""
ids_exchange_result.ids_await.wait()
embedding = self._hashtable(ids_exchange_result.ids, admit_hook)
embedding_async, emb_await, emb_shape = EmbeddingExchange.apply(
embedding,
ids_exchange_result.parts_reverse,
ids_exchange_result.parts,
self._pg,
)
return ExchangeEmbResults(
emb=embedding_async,
emb_await=emb_await,
reverse_index=ids_exchange_result.reverse_index,
emb_shape=emb_shape,
offsets=ids_exchange_result.offsets,
)
def wait_exchange_emb(self, emb_exchange_result):
"""Wait for embedding exchange to complete and reshape embeddings.
Args:
emb_exchange_result (ExchangeEmbResults): Results from embedding exchange.
Returns:
torch.Tensor: Reshaped embedding tensor.
"""
emb_exchange_result.emb_await.wait()
emb = emb_exchange_result.emb.view(emb_exchange_result.emb_shape)
return emb
def emb_reduce(
self, emb_exchange_result, weight, combiner, combiner_kwargs, fp16_enable=False
):
"""Aggregate embeddings using the specified combiner.
This method waits for embedding exchange to complete and then performs
segment-wise reduction using the specified combiner (sum, mean, or tile).
Args:
emb_exchange_result (ExchangeEmbResults): Results from embedding exchange.
weight (torch.Tensor, optional): Weights for weighted aggregation.
combiner (str): Combiner type ("sum", "mean", or "tile").
combiner_kwargs (dict): Additional arguments for the combiner.
fp16_enable (bool): Enable fp16 for int embedding.
Returns:
torch.Tensor: Aggregated embeddings.
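        Example:
            A toy sketch of what segment reduction computes conceptually; it
            uses plain torch ops instead of the real
            ``ragged_embedding_segment_reduce`` kernel and ignores
            ``reverse_index`` for brevity:

            .. code-block:: python

                import torch

                emb = torch.arange(8, dtype=torch.float32).view(4, 2)
                offsets = [0, 3, 4]  # sample 0 owns rows 0..2, sample 1 owns row 3
                out = torch.stack(
                    [emb[s:e].sum(dim=0) for s, e in zip(offsets[:-1], offsets[1:])]
                )
                # "sum" combiner -> tensor([[6., 9.], [6., 7.]])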
"""
emb = self.wait_exchange_emb(emb_exchange_result)
if emb.dtype in (torch.int8,):
if fp16_enable:
emb = emb.to(torch.float16)
else:
emb = emb.to(torch.float32)
emb = ragged_embedding_segment_reduce(
emb,
weight,
emb_exchange_result.reverse_index,
emb_exchange_result.offsets,
combiner,
combiner_kwargs,
)
return emb