Source code for recis.nn.modules.embedding_engine

import copy
import hashlib
import math
from collections import defaultdict

import torch
from torch import nn

from recis.nn.modules.embedding import DynamicEmbedding, EmbeddingOption
from recis.ragged.tensor import RaggedTensor
from recis.utils.logger import Logger


# Module-level logger; EmbeddingEngine.__init__ uses it to report each
# coalesced hashtable it creates.
logger = Logger(__name__)


class HashTableCoalescedGroup:
    """Groups embedding options that are served by one coalesced hashtable.

    Features whose embedding options share a coalesced configuration are
    collected here so a single embedding table can serve all of them. Each
    distinct ``shared_name`` added to the group gets a dense encode id
    (0, 1, 2, ...) in insertion order. Features that reuse a ``shared_name``
    also reuse its encode id.

    Attributes:
        _emb_opt (EmbeddingOption): Deep copy of the first option added; used
            as the template for :meth:`embedding_info`.
        _fea_to_runtime (dict): Feature name -> ``emb_opt.runtime_info()``.
        _fea_to_encode_id (dict): Feature name -> encode id of its shared name.
        _fea_to_info (dict): Feature name -> per-feature settings (combiner,
            dim, use_weight, combiner_kwargs, admit_hook, fp16_enabled).
        _runtime_info (set): Distinct runtime-info values seen so far.
        _encode_idx (int): Last encode id handed out (-1 before any add).
        _children (dict): Shared name -> encode id.
        _children_info (dict): Currently not populated by this class.
        _name (str): Name of the coalesced group; used as its shared name.

    Example:

    .. code-block:: python

        group = HashTableCoalescedGroup("user_group")

        # Add embedding options to the group
        user_opt = EmbeddingOption(embedding_dim=64, shared_name="user_emb")
        group.add_option("user_id", user_opt)

        profile_opt = EmbeddingOption(embedding_dim=64, shared_name="profile_emb")
        group.add_option("user_profile", profile_opt)

        # Get embedding info for the group
        group_emb_info = group.embedding_info()

    """

    def __init__(self, name):
        """Initialize an empty coalesced group.

        Args:
            name (str): Name of the coalesced group.
        """
        self._emb_opt = None
        self._fea_to_runtime = {}
        self._fea_to_encode_id = {}
        self._fea_to_info = {}
        self._runtime_info = set()
        self._encode_idx = -1
        self._children = {}
        self._children_info = {}
        self._name = name

    def add_option(self, fea_name, emb_opt):
        """Register a feature and its embedding option with the group.

        The first option ever added is deep-copied and becomes the template
        used by :meth:`embedding_info`. A new ``shared_name`` receives the
        next encode id; a known one reuses its existing id.

        Args:
            fea_name (str): Name of the feature.
            emb_opt (EmbeddingOption): Embedding option to add.
        """
        if self._emb_opt is None:
            self._emb_opt = copy.deepcopy(emb_opt)
        shared_name = emb_opt.shared_name
        if shared_name not in self._children:
            self._encode_idx += 1
            self._children[shared_name] = self._encode_idx
        runtime_info = emb_opt.runtime_info()
        # set.add is idempotent, so no prior membership test is needed.
        self._runtime_info.add(runtime_info)
        self._fea_to_runtime[fea_name] = runtime_info
        self._fea_to_encode_id[fea_name] = self._children[shared_name]
        self._fea_to_info[fea_name] = {
            "combiner": emb_opt.combiner,
            "dim": emb_opt.embedding_dim,
            "use_weight": emb_opt.use_weight,
            "combiner_kwargs": emb_opt.combiner_kwargs,
            "admit_hook": emb_opt.admit_hook,
            "fp16_enabled": emb_opt.fp16_enabled,
        }

    def runtime_info(self, fea_name):
        """Return the runtime info recorded for ``fea_name``.

        Args:
            fea_name (str): Name of the feature.

        Returns:
            str: Runtime information string.

        Raises:
            KeyError: If the feature was never added.
        """
        return self._fea_to_runtime[fea_name]

    def encode_id(self, fea_name):
        """Return the encode id of the shared name used by ``fea_name``.

        Args:
            fea_name (str): Name of the feature.

        Returns:
            int: Encode ID for the feature.

        Raises:
            KeyError: If the feature was never added.
        """
        return self._fea_to_encode_id[fea_name]

    def children_info(self, fea_name):
        """Return the per-feature settings dict recorded by :meth:`add_option`.

        Args:
            fea_name (str): Name of the feature.

        Returns:
            dict: Keys ``combiner``, ``dim``, ``use_weight``,
            ``combiner_kwargs``, ``admit_hook`` and ``fp16_enabled``.

        Raises:
            KeyError: If the feature was never added.
        """
        return self._fea_to_info[fea_name]

    def embedding_info(self):
        """Build the coalesced embedding option describing the whole group.

        A new ``EmbeddingOption`` is created on every call. Its children are
        the group's shared names ordered by encode id.

        Returns:
            EmbeddingOption: Embedding option configured for coalesced operations.

        Raises:
            RuntimeError: If no options have been added to the group.
        """
        if self._emb_opt is None:
            raise RuntimeError("HashTableCoalescedGroup has not build any option")
        template = self._emb_opt
        return EmbeddingOption(
            embedding_dim=template.embedding_dim,
            dtype=template.dtype,
            device=template.device,
            shared_name=self._name,
            children=sorted(self._children, key=self._children.get),
            coalesced=True,
            initializer=template.initializer,
            grad_reduce_by=template.grad_reduce_by,
            fp16_enabled=template.fp16_enabled,
            filter_hook=template.filter_hook,
        )


class RuntimeGroupFeature:
    """Groups features with the same runtime characteristics for efficient processing.

    Features that share a hashtable and the same runtime configuration
    (combiner, embedding dimension, weighting, admission hook, fp16 setting)
    are accumulated with :meth:`add_fea`. :meth:`coalesce` then merges them
    into one ids/offsets/weights triple, so downstream exchange, lookup and
    reduction can run once for the whole group. The per-feature ``names``,
    ``split_size`` and ``shapes`` lists are kept so the reduced result can be
    split back per feature.

    Attributes:
        _dim (int): Embedding dimension.
        _combiner (str): Combiner type ("sum", "mean", "tile").
        _admit_hook: Admission hook for feature control.
        _use_weight (bool): Whether to use weights.
        _offset_dtype (torch.dtype): Data type for offsets.
        _fp16_enabled (bool): Whether generated weights are float16.
        _names (List[str]): List of feature names.
        _encode_ids (List[int]): List of encode IDs.
        _ids (List[torch.Tensor]): List of feature ID tensors.
        _weights (List[torch.Tensor]): List of weight tensors.
        _offsets (List[torch.Tensor]): List of offset tensors.
        _max_sizes (List[int]): Number of ids per feature.
        _split_sizes (List[int]): Number of reduced rows per feature.
        _shapes (List[tuple]): List of output shapes.
        _coalesced_ids (torch.Tensor): Coalesced ID tensor.
        _coalesced_weights (torch.Tensor): Coalesced weight tensor.
        _coalesced_offsets (torch.Tensor): Coalesced offset tensor.
        _combiner_kwargs (dict): Per-feature ``bs``/``tile_len`` lists (tile only).

    Example:

    .. code-block:: python

        # Create runtime group for sum combiner
        group = RuntimeGroupFeature(
            dim=64,
            combiner="sum",
            use_weight=True,
            offset_dtype=torch.int32,
            admit_hook=None,
            fp16_enabled=False,
        )

        # Add features to the group
        user_ids = torch.tensor([[1, 2], [3, 4]])
        group.add_fea("user_id", user_ids, encode_id=0, combiner_kwargs={})

        item_ids = torch.tensor([[10, 20], [30, 40]])
        group.add_fea("item_id", item_ids, encode_id=1, combiner_kwargs={})

        # Coalesce features (uses the torch.ops.recis custom operators)
        group.coalesce()

        # Access coalesced tensors
        coalesced_ids = group.coalesced_ids()
        coalesced_offsets = group.coalesced_offsets()

    """

    def __init__(
        self,
        dim,
        combiner,
        use_weight,
        offset_dtype,
        admit_hook,
        fp16_enabled,
        **kwargs,
    ):
        """Initialize an empty runtime group.

        Args:
            dim (int): Embedding dimension.
            combiner (str): Combiner type ("sum", "mean", "tile").
            use_weight (bool): Whether to use weights.
            offset_dtype (torch.dtype): Data type for offsets.
            admit_hook: Admission hook for feature control.
            fp16_enabled (bool): Whether generated weights use float16
                instead of float32.
            **kwargs: Extra keyword arguments; accepted and ignored (lets
                callers pass a whole per-feature settings dict).
        """
        self._dim = dim
        self._combiner = combiner
        self._admit_hook = admit_hook
        self._use_weight = use_weight
        self._offset_dtype = offset_dtype
        self._fp16_enabled = fp16_enabled
        # feature names
        self._names = []
        # encode id per feature
        self._encode_ids = []
        # per-feature flattened ids
        self._ids = []
        # per-feature weights (or None)
        self._weights = []
        # per-feature offsets
        self._offsets = []
        # per-feature number of ids
        self._max_sizes = []
        # per-feature number of reduced rows
        self._split_sizes = []
        # per-feature output shapes
        self._shapes = []
        self._coalesced_ids = None
        self._coalesced_weights = None
        self._coalesced_offsets = None
        self._combiner_kwargs = {"bs": [], "tile_len": []}

    def combiner(self):
        """Get the combiner type.

        Returns:
            str: Combiner type ("sum", "mean", "tile").
        """
        return self._combiner

    def admit_hook(self):
        """Get the admission hook.

        Returns:
            AdmitHook: Admission hook for feature control.
        """
        return self._admit_hook

    def fp16_enabled(self):
        """Get whether generated weights are float16.

        Returns:
            bool: The ``fp16_enabled`` flag given at construction.
        """
        return self._fp16_enabled

    def names(self):
        """Get the list of feature names.

        Returns:
            List[str]: Feature names in insertion order.
        """
        return self._names

    def shapes(self):
        """Get the list of output shapes.

        Returns:
            List[tuple]: Output shape for each feature, in insertion order.
        """
        return self._shapes

    def coalesced_ids(self):
        """Get the coalesced ID tensor.

        Returns:
            torch.Tensor: Coalesced ID tensor, or None before :meth:`coalesce`
            or after :meth:`clear_ids`.
        """
        return self._coalesced_ids

    def coalesced_weights(self):
        """Get the coalesced weight tensor.

        Returns:
            torch.Tensor: Coalesced weight tensor, or None if weights not used.
        """
        return self._coalesced_weights

    def coalesced_offsets(self):
        """Get the coalesced offset tensor.

        Returns:
            torch.Tensor: Coalesced offset tensor, or None before
            :meth:`coalesce` or after :meth:`clear_ids`.
        """
        return self._coalesced_offsets

    def combiner_kwargs(self):
        """Get additional combiner arguments.

        Returns:
            dict: ``{"bs": [...], "tile_len": [...]}``; the lists are filled
            only for the "tile" combiner.
        """
        return self._combiner_kwargs

    def split_size(self):
        """Get the list of split sizes.

        Returns:
            List[int]: Number of reduced rows for each feature.
        """
        return self._split_sizes

    def clear_ids(self):
        """Drop the coalesced ID and offset tensors to free memory.

        Coalesced weights are kept.
        """
        self._coalesced_ids = None
        self._coalesced_offsets = None

    def _format_tensor(self, input_tensor, combiner, dim, use_weight, combiner_kwargs):
        """Format one input feature for coalesced lookup.

        Dense ids are treated as ``shape[:-1]`` rows (bags) of ``shape[-1]``
        ids each. Ragged tensors use their own values, weights and innermost
        offsets, and must already be padded.

        Args:
            input_tensor (Union[torch.Tensor, RaggedTensor]): Input tensor.
            combiner (str): Combiner type.
            dim (int): Embedding dimension.
            use_weight (bool): Whether to use weights.
            combiner_kwargs (dict): Additional combiner arguments; shallow-copied
                before ``bs`` is written for the "tile" combiner.

        Returns:
            Tuple containing:
                - torch.Tensor: Values tensor
                - torch.Tensor: Weight tensor (or None)
                - torch.Tensor: Offsets tensor (``self._offset_dtype``)
                - int: Number of ids
                - int: Split size (number of reduced rows)
                - tuple: Dense output shape
                - dict: Updated combiner kwargs

        Raises:
            RuntimeError: If RaggedTensor is not properly padded.
            TypeError: If the input is sparse, is a dense tensor with fewer
                than 2 dims, or is not a supported type.
        """
        combiner_kwargs = copy.copy(combiner_kwargs)
        if isinstance(input_tensor, RaggedTensor):
            val = input_tensor.values()
            weight = input_tensor.weight()
            offsets = input_tensor.offsets()[-1]
            max_size = val.numel()
            split_size = math.prod(input_tensor.real_shape(0, -1))
            if not split_size == (offsets.numel() - 1):
                raise RuntimeError(
                    f"RaggedTensor must pad before lookup, got: {input_tensor}"
                )
            if combiner == "tile":
                split_size *= combiner_kwargs["tile_len"]
                shape = (
                    input_tensor.real_shape(0, 1)[0],
                    combiner_kwargs["tile_len"],
                    dim,
                )
                combiner_kwargs["bs"] = input_tensor.shape[0]
            else:
                shape = input_tensor.real_shape(0, -1) + (dim,)
        elif isinstance(input_tensor, torch.Tensor):
            if input_tensor.is_sparse:
                raise TypeError("EmbeddingEngine doesn't support sparse ids")
            if input_tensor.dim() < 2:
                raise TypeError(
                    "EmbeddingEngine expects dense ids with at least 2 dims, "
                    f"got shape {tuple(input_tensor.shape)}"
                )
            # reshape (not view) so non-contiguous ids (slices, transposes)
            # are accepted; it still returns a view whenever possible.
            val = input_tensor.reshape(-1)
            weight = None
            # Every leading index is one bag holding ``ids_per_row`` ids; this
            # matches split_size and the output shape computed below.
            ids_per_row = input_tensor.shape[-1]
            num_rows = math.prod(input_tensor.shape[:-1])
            offsets = (
                torch.arange(
                    num_rows + 1, dtype=self._offset_dtype, device=val.device
                )
                * ids_per_row
            )
            max_size = val.numel()
            split_size = num_rows
            if combiner == "tile":
                split_size *= combiner_kwargs["tile_len"]
                shape = (input_tensor.shape[0], dim * combiner_kwargs["tile_len"])
                combiner_kwargs["bs"] = input_tensor.shape[0]
            else:
                shape = input_tensor.shape[:-1] + (dim,)
        else:
            raise TypeError(
                f"EmbeddingEngine only support tensor but get {type(input_tensor)}"
            )
        if not use_weight:
            weight = None
        elif weight is None:
            # Weighted combiner without explicit weights: use all-ones weights.
            weight_dtype = torch.float16 if self._fp16_enabled else torch.float32
            weight = torch.ones_like(val, dtype=weight_dtype)
        offsets = offsets.to(self._offset_dtype)

        return val, weight, offsets, max_size, split_size, shape, combiner_kwargs

    def add_fea(self, fea_name, fea, encode_id, combiner_kwargs):
        """Add a feature to the runtime group.

        The feature is formatted first and recorded only if formatting
        succeeds, so all per-feature lists stay aligned even when the input
        is rejected.

        Args:
            fea_name (str): Name of the feature.
            fea (Union[torch.Tensor, RaggedTensor]): Feature tensor.
            encode_id (int): Encode ID for the feature.
            combiner_kwargs (dict): Additional combiner arguments.

        Raises:
            RuntimeError, TypeError: Propagated from input formatting.
        """
        ids, weight, offset, max_size, split_size, dense_shape, fmt_kwargs = (
            self._format_tensor(
                fea, self._combiner, self._dim, self._use_weight, combiner_kwargs
            )
        )
        self._names.append(fea_name)
        self._encode_ids.append(encode_id)
        self._ids.append(ids)
        self._weights.append(weight)
        self._offsets.append(offset)
        self._max_sizes.append(max_size)
        self._split_sizes.append(split_size)
        self._shapes.append(dense_shape)
        if self._combiner == "tile":
            self._combiner_kwargs["bs"].append(fmt_kwargs["bs"])
            self._combiner_kwargs["tile_len"].append(fmt_kwargs["tile_len"])

    def clear_child(self):
        """Release per-feature ids, offsets, sizes and weights.

        Names, split sizes and shapes are kept because they are still needed
        to split the reduced result back per feature.
        """
        self._encode_ids = []
        self._ids = []
        self._offsets = []
        self._max_sizes = []
        self._weights = []

    def coalesce(self):
        """Coalesce all features in the group into unified tensors.

        Ids are merged by ``torch.ops.recis.ids_encode`` together with the
        per-feature encode ids. Offsets are merged by
        ``torch.ops.recis.merge_offsets`` using the per-feature id counts.
        Weights, when present, are concatenated. Per-feature buffers are then
        released via :meth:`clear_child`. Requires at least one added feature.
        """
        merge_id = torch.ops.recis.ids_encode(
            self._ids,
            torch.tensor(
                self._encode_ids, dtype=torch.int64, device=self._ids[0].device
            ),
        )
        merge_offset = torch.ops.recis.merge_offsets(
            self._offsets,
            torch.tensor(self._max_sizes, dtype=self._offsets[0].dtype, device="cpu"),
        )
        # use_weight is fixed per group, so either every weight is None or none is.
        merge_weight = (
            None if self._weights[0] is None else torch.cat(self._weights, dim=0)
        )
        self._coalesced_ids = merge_id
        self._coalesced_weights = merge_weight
        self._coalesced_offsets = merge_offset
        self.clear_child()


class EmbeddingEngine(nn.Module):
    """Embedding engine for efficient batch processing of multiple embeddings.

    This module provides a high-level interface for managing and processing
    multiple embedding tables efficiently. It automatically groups embeddings
    with similar configurations and processes them using coalesced operations
    to minimize communication overhead in distributed training.

    Key features:

    - Automatic grouping of similar embedding configurations
    - Coalesced operations for improved performance
    - Support for mixed tensor types (dense and ragged)
    - Flexible feature routing (embedding vs. pass-through)
    - Memory-efficient processing pipeline

    Args:
        emb_options (dict[str, EmbeddingOption]): Dictionary mapping feature
            names to their embedding options.

    Example:

    Multi-embedding scenario:

    .. code-block:: python

        import torch
        from recis.nn.modules.embedding import EmbeddingOption
        from recis.nn.modules.embedding_engine import EmbeddingEngine

        # Define embedding options for different features
        emb_options = {
            "user_id": EmbeddingOption(
                embedding_dim=128,
                shared_name="user_embedding",
                combiner="sum",
                trainable=True,
            ),
            "item_id": EmbeddingOption(
                embedding_dim=128,
                shared_name="item_embedding",
                combiner="sum",
                trainable=True,
            ),
            "category": EmbeddingOption(
                embedding_dim=64,
                shared_name="category_embedding",
                combiner="mean",
                trainable=True,
            ),
        }

        # Create embedding engine
        engine = EmbeddingEngine(emb_options)

        # Prepare input features
        batch_size = 32
        features = {
            "user_id": torch.randint(0, 10000, (batch_size, 1)),
            "item_id": torch.randint(0, 50000, (batch_size, 5)),  # Multi-hot
            "category": torch.randint(0, 100, (batch_size, 1)),
            "other_feature": torch.randn(batch_size, 10),  # Pass-through
        }

        # Forward pass
        embeddings = engine(features)

        # Results contain embeddings for configured features
        # and pass-through for non-embedding features
        print(embeddings["user_id"].shape)  # [32, 128]
        print(embeddings["item_id"].shape)  # [32, 128] (summed)
        print(embeddings["category"].shape)  # [32, 64]
        print(embeddings["other_feature"].shape)  # [32, 10] (pass-through)

    Advanced usage with ragged tensors:

    .. code-block:: python

        from recis.ragged.tensor import RaggedTensor

        # Variable-length sequences
        values = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
        offsets = torch.tensor([0, 3, 5, 8])  # Batch boundaries
        ragged_features = RaggedTensor(values, offsets)

        # "sequence_ids" has no EmbeddingOption in ``emb_options`` above, so
        # it is passed through (converted with ``to_dense()``).
        features = {
            "sequence_ids": ragged_features,
            "user_id": torch.tensor([[100], [200], [300]]),
        }
        embeddings = engine(features)
    """

    def __init__(self, emb_options: dict[str, EmbeddingOption]):
        """Initialize embedding engine with multiple embedding options.

        Options with identical ``coalesced_info()`` are placed in one
        coalesced hashtable whose name is derived from a SHA-256 digest of
        that info. One ``DynamicEmbedding`` is created per hashtable.

        Args:
            emb_options (dict[str, EmbeddingOption]): Dictionary mapping
                feature names to their embedding configurations.

        Raises:
            RuntimeError: If two options use the same shared name but have
                different coalesced configurations.
        """
        super().__init__()
        self._ht = nn.ModuleDict()
        self._fea_group = {}
        self._fea_to_ht = {}
        self._fea_to_group = {}
        self._emb_opts = emb_options
        self._offset_dtype = torch.int32
        # shared_name -> hashtable name it was first registered with.
        shared_to_ht = {}
        for fea_name, emb_opt in emb_options.items():
            digest = hashlib.sha256(emb_opt.coalesced_info().encode()).hexdigest()
            ht_name = f"CoalescedHashtable_{digest}"
            prev_ht_name = shared_to_ht.setdefault(emb_opt.shared_name, ht_name)
            if prev_ht_name != ht_name:
                # Report the configuration registered first under this shared
                # name. The group for ``ht_name`` may not exist yet, so it must
                # not be used for the lookup.
                raise RuntimeError(
                    "Create embedding failed, emb shared name already created by info: "
                    f"{self._fea_group[prev_ht_name]._emb_opt.coalesced_info()}, "
                    f"current: {emb_opt.coalesced_info()}"
                )
            group = self._fea_group.get(ht_name)
            if group is None:
                group = HashTableCoalescedGroup(ht_name)
                self._fea_group[ht_name] = group
            group.add_option(fea_name, emb_opt)
            self._fea_to_ht[fea_name] = ht_name
            self._fea_to_group[fea_name] = group
        for ht_name, fea_group in self._fea_group.items():
            self._ht[ht_name] = DynamicEmbedding(fea_group.embedding_info())
            # A separate option object for logging, so nothing DynamicEmbedding
            # might do to its own option can affect what is logged.
            log_opt = fea_group.embedding_info()
            logger.info(
                f"ht name: {ht_name}, coalesced info: {log_opt.coalesced_info()}, children: {log_opt.children}"
            )

    def forward(self, input_features: dict[str, torch.Tensor]):
        """Forward pass for batch embedding processing.

        This method processes multiple features efficiently by:

        1. Grouping features by their runtime characteristics
        2. Performing coalesced ID exchange across workers
        3. Looking up embeddings in batches
        4. Reducing embeddings using specified combiners
        5. Splitting results back to individual features

        Args:
            input_features (dict[str, torch.Tensor]): Dictionary mapping
                feature names to their input tensors. Features not in the
                embedding options are passed through (ragged tensors are
                densified).

        Returns:
            dict[str, torch.Tensor]: Dictionary mapping feature names to their
            processed outputs. Embedding features return embedding tensors,
            while non-embedding features are passed through.

        Example:

        .. code-block:: python

            features = {
                "user_id": torch.tensor([[1, 2], [3, 4]]),
                "item_id": torch.tensor([[10], [20]]),
                "raw_feature": torch.randn(2, 5),  # Pass-through
            }

            outputs = engine(features)
            # outputs["user_id"]: embedding tensor [2, embedding_dim]
            # outputs["item_id"]: embedding tensor [2, embedding_dim]
            # outputs["raw_feature"]: original tensor [2, 5]
        """
        group_features, direct_outs = self.group_features(input_features)
        group_exchange_ids = self.group_exchange_ids(group_features)
        group_exchange_embs = self.group_exchange_embs(
            group_exchange_ids, group_features
        )
        # Drop each intermediate as soon as it is consumed to lower peak memory.
        del group_exchange_ids
        group_embs = self.group_reduce(group_exchange_embs, group_features)
        del group_exchange_embs
        emb_outs = self.split_group_embs(group_embs, group_features)
        del group_embs, group_features
        emb_outs.update(self.format_direct_out(direct_outs))
        return emb_outs

    def format_direct_out(self, ori_outs):
        """Format direct output features (non-embedding features).

        Ragged tensors are converted with ``to_dense()``. Dense tensors and
        any other values are returned unchanged.

        Args:
            ori_outs (dict): Dictionary of original output features.

        Returns:
            dict: Dictionary of formatted output features.

        Raises:
            TypeError: If sparse tensors are encountered.
        """
        outs = {}
        for key, value in ori_outs.items():
            if isinstance(value, RaggedTensor):
                outs[key] = value.to_dense()
                continue
            if isinstance(value, torch.Tensor) and value.is_sparse:
                raise TypeError("EmbeddingEngine doesn't support sparse tensor")
            outs[key] = value
        return outs

    def group_features(self, input_dict: dict[str, torch.Tensor]):
        """Group input features by hashtable and runtime characteristics.

        Features without an embedding option go to the pass-through dict. The
        others are added to a ``RuntimeGroupFeature`` keyed by hashtable name
        and then by the feature's runtime info.

        Args:
            input_dict (dict[str, torch.Tensor]): Input feature dictionary.

        Returns:
            Tuple containing:
                - dict: ``{ht_name: {runtime_info: RuntimeGroupFeature}}``
                - dict: Direct output features (pass-through)
        """
        group_features = defaultdict(dict)
        direct_out = {}
        for fea_name, fea_tensor in input_dict.items():
            if fea_name not in self._fea_to_ht:
                direct_out[fea_name] = fea_tensor
                continue
            ht_name = self._fea_to_ht[fea_name]
            coalesced_group = self._fea_to_group[fea_name]
            runtime_info = coalesced_group.runtime_info(fea_name)
            child_info = coalesced_group.children_info(fea_name)
            runtime_groups = group_features[ht_name]
            if runtime_info not in runtime_groups:
                runtime_groups[runtime_info] = RuntimeGroupFeature(
                    **child_info,
                    offset_dtype=self._offset_dtype,
                )
            runtime_groups[runtime_info].add_fea(
                fea_name,
                fea_tensor,
                coalesced_group.encode_id(fea_name),
                child_info["combiner_kwargs"],
            )
        return group_features, direct_out

    def group_exchange_ids(self, group_features):
        """Coalesce each runtime group and exchange its ids across workers.

        Each group's coalesced ids/offsets are released right after
        ``exchange_ids``; coalesced weights are kept for the reduce step.

        Args:
            group_features (dict): Grouped features for processing.

        Returns:
            dict: ``{ht_name: {runtime_info: exchange_ids result}}``.
        """
        group_exchange_ids = defaultdict(dict)
        for ht_name, runtime_groups in group_features.items():
            ht = self._ht[ht_name]
            for run_name, run_fea in runtime_groups.items():
                run_fea.coalesce()
                group_exchange_ids[ht_name][run_name] = ht.exchange_ids(
                    run_fea.coalesced_ids(), run_fea.coalesced_offsets()
                )
                run_fea.clear_ids()
        return group_exchange_ids

    def group_exchange_embs(self, group_exchange_ids, group_features):
        """Perform embedding lookup and exchange results back.

        Args:
            group_exchange_ids (dict): Exchange ID results from previous step.
            group_features (dict): Grouped features for processing; supplies
                each group's admission hook.

        Returns:
            dict: ``{ht_name: {runtime_info: lookup_exchange_emb result}}``.
        """
        group_exchange_embs = defaultdict(dict)
        for ht_name, exchange_ids in group_exchange_ids.items():
            ht = self._ht[ht_name]
            runtime_groups = group_features[ht_name]
            for run_name, run_exchange_ids in exchange_ids.items():
                group_exchange_embs[ht_name][run_name] = ht.lookup_exchange_emb(
                    run_exchange_ids,
                    runtime_groups[run_name].admit_hook(),
                )
        return group_exchange_embs

    def group_reduce(self, group_exchange_embs, group_features):
        """Reduce embeddings using each group's combiner.

        Args:
            group_exchange_embs (dict): Exchange embedding results.
            group_features (dict): Grouped features for processing.

        Returns:
            dict: ``{ht_name: {runtime_info: reduced embedding tensor}}``.
        """
        group_embs = defaultdict(dict)
        for ht_name, exchange_emb in group_exchange_embs.items():
            runtime_groups = group_features[ht_name]
            ht = self._ht[ht_name]
            for run_name, run_emb in exchange_emb.items():
                run_fea = runtime_groups[run_name]
                group_embs[ht_name][run_name] = ht.emb_reduce(
                    run_emb,
                    run_fea.coalesced_weights(),
                    run_fea.combiner(),
                    run_fea.combiner_kwargs(),
                    run_fea.fp16_enabled(),
                )
        return group_embs

    def split_group_embs(self, group_embs, group_features):
        """Split group embeddings back to individual feature embeddings.

        Each reduced tensor is split along dim 0 by the group's per-feature
        split sizes and reshaped to the recorded output shape. Outputs of
        features whose option is not ``trainable`` are detached.

        Args:
            group_embs (dict): Reduced group embeddings.
            group_features (dict): Grouped features for processing.

        Returns:
            dict[str, torch.Tensor]: Dictionary mapping feature names to their
            individual embedding tensors.
        """
        emb_outs = {}
        for ht_name, group_emb in group_embs.items():
            runtime_groups = group_features[ht_name]
            for run_name, run_emb in group_emb.items():
                run_fea = runtime_groups[run_name]
                pieces = torch.split(run_emb, run_fea.split_size(), dim=0)
                for name, emb, out_shape in zip(
                    run_fea.names(), pieces, run_fea.shapes()
                ):
                    if not self._emb_opts[name].trainable:
                        emb = emb.detach()
                    emb_outs[name] = emb.view(out_shape)
        return emb_outs