import json
from typing import Optional
import torch
[docs]
class Loader:
    """Loads model state dictionaries from checkpoint files with parallel processing.
    This class handles loading both sparse (hashtable-based) and dense (tensor-based)
    state dictionaries from disk, applying filtering logic to the load configuration.
    Examples:
        Typical usage example for loading a checkpoint:
        >>> loader = Loader(
        ...     checkpoint_path="/path/to/checkpoint",
        ...     hashtables=sparse_state_dict,
        ...     tensors=dense_state_dict,
        ...     parallel=16,
        ... )
        >>> loader.load()
    """
[docs]
    def __init__(
        self,
        checkpoint_path: str,
        hashtables: Optional[dict] = None,
        tensors: Optional[dict] = None,
        parallel: int = 16,
        filter_func=lambda x: x,
    ) -> None:
        """Initializes the Loader with configuration and target state dictionaries.
        Args:
            checkpoint_path: The directory path containing checkpoint files to load.
            hashtables: A dictionary to receive loaded sparse state data.
                If None, an empty dictionary will be created.
            tensors: A dictionary to receive loaded dense state data.
                If None, an empty dictionary will be created.
            parallel: The degree of parallelism for read operations. Defaults to 16.
            filter_func: A callable to filter load information. Defaults to identity function.
        """
        self._checkpoint_path = checkpoint_path
        self._hashtables = hashtables
        if self._hashtables is None:
            self._hashtables = {}
        self._tensors = tensors
        if self._tensors is None:
            self._tensors = {}
        self._impl = torch.classes.recis.Loader(
            self._checkpoint_path,
            parallel,
            self._hashtables,
            self._tensors,
        )
        self._filter_func = filter_func 
[docs]
    def load(self):
        """Executes the loading process.
        Retrieves default load information from the checkpoint, applies the filter function
        to modify the load configuration, and delegates to the internal loader implementation
        for actual I/O operations.
        The load operation involves:
        1. Retrieving default load information from the checkpoint metadata;
        2. Applying the filter function to modify the load configuration;
        3. Loading the state data into the provided hashtables and tensors dictionaries using parallel processing;
        The actual file reading and data reconstruction are handled by the torch.classes.recis.Loader class.
        """
        load_info = json.loads(self._impl.default_load_info())
        load_info = self._filter_func(load_info)
        self._impl.load(json.dumps(load_info))