import json
from typing import Optional
from recis.utils.logger import Logger
logger = Logger(__name__)
class BaseHook:
    def __init__(self, name: str, params: Optional[dict] = None):
        self._name = name
        self._params = params if params is not None else {}
    @property
    def name(self) -> str:
        """Get the registered name of the hook policy.
        Returns:
            str: The name of the hook policy as registered in the system.
                This name is used to identify and instantiate the appropriate
                policy implementation in the HashtableHookFactory.
        """
        return self._name
    @property
    def params(self) -> dict:
        """Get the configuration parameters for the hook policy.
        Returns:
            dict: A dictionary containing the configuration parameters for
                the hook policy. The specific parameters depend on the policy
                implementation and requirements.
        """
        return self._params
    def __str__(self):
        """Return JSON string representation of the hook configuration.
        This method provides a standardized string representation of the hook
        that includes both the policy name and parameters in JSON format.
        This is useful for logging, debugging, and serialization purposes.
        Returns:
            str: JSON string containing the hook's name and parameters,
                with keys sorted for consistent output.
        """
        info = {"name": self._name, "params": self._params}
        return json.dumps(info, sort_keys=True)
[docs]
class AdmitHook(BaseHook):
    """Feature admission hook for controlling HashTable feature acceptance.
    AdmitHook implements feature admission policies that control whether new
    features (IDs) are allowed to be added to HashTable embeddings. This is
    useful for implementing read-only modes, feature freezing, or custom
    admission criteria.
    The most common use case is the "ReadOnly" policy, which prevents new
    features from being added to the embedding table and returns zero embeddings
    for unknown IDs instead of creating new entries.
    Example:
        Read-only HashTable usage:
    .. code-block:: python
        from recis.nn import HashTable
        from recis.nn.hashtable_hook import AdmitHook
        # Create HashTable
        ht = HashTable(embedding_shape=[64])
        # Create read-only admission hook
        ro_hook = AdmitHook("ReadOnly")
        # Lookup with admission control
        # Known IDs return their embeddings, unknown IDs return zeros
        embeddings = ht(ids, admit_hook=ro_hook)
        Integration with DynamicEmbedding:
    .. code-block:: python
        from recis.nn import DynamicEmbedding, EmbeddingOption
        from recis.nn.hashtable_hook import AdmitHook
        # Configure embedding with admission hook
        emb_opt = EmbeddingOption(
            embedding_dim=64,
            shared_name="user_embedding",
            combiner="sum",
            admit_hook=AdmitHook("ReadOnly"),
        )
        # Create embedding with read-only policy
        embedding = DynamicEmbedding(emb_opt)
        # Use in inference mode (no new embeddings created)
        ids = torch.LongTensor([1, 2, 3, 4])
        emb_output = embedding(ids)
        Multi-embedding setup with selective admission:
    .. code-block:: python
        from recis.nn import EmbeddingEngine, EmbeddingOption
        from recis.nn.hashtable_hook import AdmitHook
        # Configure different admission policies
        user_emb_opt = EmbeddingOption(
            embedding_dim=64,
            shared_name="user_emb",
            admit_hook=AdmitHook("ReadOnly"),  # Read-only for users
        )
        item_emb_opt = EmbeddingOption(
            embedding_dim=64,
            shared_name="item_emb",
            # No admission hook = normal mode (new items allowed)
        )
        # Create embedding engine
        embedding_engine = EmbeddingEngine(
            {"user_emb": user_emb_opt, "item_emb": item_emb_opt}
        )
        # Mixed mode: user embeddings read-only, item embeddings normal
        samples = {"user_emb": user_ids, "item_emb": item_ids}
        outputs = embedding_engine(samples)
    """
    @property
    def type(self) -> str:
        """Get the hook type identifier.
        Returns:
            str: Always returns "admit" to identify this as an admission hook.
                This type identifier is used by the system to distinguish
                between different hook categories.
        """
        return "admit"
    def __str__(self):
        """Return JSON string representation of the admission hook.
        This method provides a standardized string representation of the
        admission hook configuration in JSON format, including the policy
        name and parameters.
        Returns:
            str: JSON string containing the hook's name and parameters,
                with keys sorted for consistent output.
        """
        info = {"name": self._name, "params": self._params}
        return json.dumps(info, sort_keys=True) 
[docs]
class FilterHook(BaseHook):
    """Feature filtering hook for implementing HashTable cleanup strategies.
    FilterHook implements feature filtering policies that automatically remove
    unused or outdated features from HashTable embeddings. This helps manage
    memory usage and maintain embedding table quality by removing features
    that are no longer relevant.
    The most common policy is "GlobalStepFilter", which removes features that
    haven't been accessed for a specified number of training steps. This is
    particularly useful in online learning scenarios where feature relevance
    changes over time.
    Example:
        Basic filtering with step-based cleanup:
    .. code-block:: python
        from recis.nn import EmbeddingEngine, EmbeddingOption
        from recis.nn.hashtable_hook import FilterHook
        from recis.hooks.filter_hook import HashTableFilterHook
        # Configure embedding with filtering policy
        user_emb_opt = EmbeddingOption(
            embedding_dim=64,
            shared_name="user_emb",
            combiner="sum",
            # Remove IDs not seen for 10 steps
            filter_hook=FilterHook("GlobalStepFilter", {"filter_step": 20}),
        )
        # Create embedding engine
        embedding_engine = EmbeddingEngine({"user_emb": user_emb_opt})
        # Setup filtering hook for periodic cleanup
        filter_hook = HashTableFilterHook(filter_interval=10)  # Check every 10 steps
        # Training loop with automatic filtering
        for step in range(100):
            outputs = embedding_engine(samples)
            # Trigger filtering check
            filter_hook.after_step(None, step)
            if step % 10 == 0:
                print(f"Step {step}: Automatic cleanup performed")
        Advanced filtering configuration:
    .. code-block:: python
        # Multiple embeddings with different filtering policies
        user_emb_opt = EmbeddingOption(
            embedding_dim=64,
            shared_name="user_emb",
            # Aggressive filtering for user features
            filter_hook=FilterHook("GlobalStepFilter", {"filter_step": 5}),
        )
        item_emb_opt = EmbeddingOption(
            embedding_dim=64,
            shared_name="item_emb",
            # Conservative filtering for item features
            filter_hook=FilterHook("GlobalStepFilter", {"filter_step": 50}),
        )
        category_emb_opt = EmbeddingOption(
            embedding_dim=32,
            shared_name="category_emb",
            # No filtering for stable category features
        )
        # Create engine with mixed filtering policies
        embedding_engine = EmbeddingEngine(
            {
                "user_emb": user_emb_opt,
                "item_emb": item_emb_opt,
                "category_emb": category_emb_opt,
            }
        )
    """
    @property
    def type(self) -> str:
        """Get the hook type identifier.
        Returns:
            str: Always returns "filter" to identify this as a filtering hook.
                This type identifier is used by the system to distinguish
                between different hook categories.
        """
        return "filter"
    def __str__(self):
        """Return JSON string representation of the filtering hook.
        This method provides a standardized string representation of the
        filtering hook configuration in JSON format, including the policy
        name and parameters.
        Returns:
            str: JSON string containing the hook's name and parameters,
                with keys sorted for consistent output.
        """
        info = {"name": self._name, "params": self._params}
        return json.dumps(info, sort_keys=True)