Source code for recis.hooks.filter_hook

import torch

from recis.hooks.hook import Hook
from recis.nn.modules.hashtable_hook_impl import HashtableHookFactory
from recis.utils.logger import Logger


class HashTableFilterHook(Hook):
    """Hook for automatic hash table feature filtering during training.

    This hook manages the lifecycle of features in hash tables by
    coordinating filtering operations across multiple hash table instances.
    It automatically updates step counters and triggers filtering operations
    at configurable intervals to remove stale or inactive features.

    The hook integrates with the hash table filter system to:

    - Track global training steps for each hash table filter
    - Execute filtering operations at specified intervals
    - Provide comprehensive logging of filter activities
    - Support dynamic adjustment of filtering frequency

    Args:
        filter_interval (int, optional): Number of training steps between
            filter operations. If None, filtering is disabled. Defaults to 100.

    Examples:
        Please refer to the documentation :doc:`nn/filter`.

        .. code-block:: python

            # Create and configure filter hook
            filter_hook = HashTableFilterHook(filter_interval=200)

            # Training loop integration
            for epoch in range(num_epochs):
                for step, batch in enumerate(dataloader):
                    # ... training logic ...

                    # Hook automatically manages filtering
                    filter_hook.after_step(None, global_step)
                    global_step += 1
    """

    def __init__(self, filter_interval: int = 100):
        """Initialize the hash table filter hook.

        Args:
            filter_interval (int, optional): Number of training steps between
                filter operations. Must be positive. If None, filtering is
                disabled. Defaults to 100.

        Example:
            .. code-block:: python

                # Standard filtering every 100 steps
                hook = HashTableFilterHook(filter_interval=100)
        """
        super().__init__()
        self.filter_interval = filter_interval
        self.ht_filters = HashtableHookFactory().get_filters()
        self.logger = Logger("HashTableFilterHook")
        self.last_filter_step = 0
        self.logger.info(
            f"HashTableFilterHook {self.ht_filters}, filter_interval {self.filter_interval}"
        )

    def reset_filter_interval(self, interval: int = 100):
        """Reset the filtering interval to a new value.

        This method allows dynamic adjustment of the filtering frequency
        during training, which can be useful for adaptive memory management
        or performance optimization strategies.

        Args:
            interval (int, optional): New filtering interval in training
                steps. Must be positive. Defaults to 100.
        """
        self.filter_interval = interval

    def after_step(self, _, gstep):
        """Execute filter management operations after each training step.

        This method is called after each training step to:

        1. Update step counters for all registered hash table filters
        2. Determine if filtering should be executed based on the interval
        3. Trigger filtering operations when the interval is reached
        4. Update the last filter step tracking

        Args:
            _ (Any): Unused parameter (typically model or trainer instance).
            gstep (Union[int, torch.Tensor]): Current global training step.
                Can be either an integer or a tensor containing the step value.

        Note:
            This method is typically called automatically by the training
            framework's hook system. Manual calls should ensure proper step
            sequencing to maintain filtering accuracy.
        """
        if self.last_filter_step == 0:
            self.last_filter_step = (
                gstep.item() if isinstance(gstep, torch.Tensor) else gstep
            )
        exec_filter = (
            self.filter_interval is not None
            and gstep - self.last_filter_step >= self.filter_interval
        )
        for hooks in self.ht_filters.values():
            for ft in hooks.values():
                ft.update_step()
                if exec_filter:
                    ft.do_filter()
        if exec_filter:
            self.last_filter_step = (
                gstep.item() if isinstance(gstep, torch.Tensor) else gstep
            )
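
A minimal sketch of driving the hook by hand, for cases where the training framework's hook system is not invoking it automatically. The optimizer, dataloader, and epoch/step counts below are hypothetical stand-ins; only HashTableFilterHook and its after_step / reset_filter_interval methods come from this module.

from recis.hooks.filter_hook import HashTableFilterHook

# Filter stale features every 100 steps; filter_interval=None disables filtering.
filter_hook = HashTableFilterHook(filter_interval=100)

global_step = 0
for epoch in range(2):
    for batch in range(500):  # stand-in for iterating a real dataloader
        # ... forward / backward / optimizer.step() would go here ...

        # after_step accepts the global step as a plain int or a tensor;
        # tensors are converted with .item() when the tracking state updates.
        filter_hook.after_step(None, global_step)
        global_step += 1

    # The filtering frequency can be adjusted mid-training, e.g. relaxed
    # once the active feature set stabilizes.
    filter_hook.reset_filter_interval(500)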