from typing import Optional
import torch
class Saver:
    """Saves model state dictionaries using sharding and parallel writes.

    This is a thin Python wrapper around the ``torch.classes.recis.Saver``
    custom class. It does four things:

    * stores the sparse (hashtable-based) and dense (tensor-based) state;
    * asks the native saver to turn that state into write blocks;
    * passes the blocks through a user-supplied ``filter_func``;
    * hands the result back to the native saver, which does the actual I/O.

    Examples:
        Typical usage for saving a sharded checkpoint:

        >>> sparse_state_dict_copy = sparse_state_dict.copy()
        >>> sparse_state_dict, dense_state_dict = split_sparse_dense_state_dict(
        ...     sparse_state_dict_copy
        ... )
        >>> saver = Saver(
        ...     shard_index=shard_id,
        ...     shard_num=shard_num,
        ...     parallel=concurrent,
        ...     hashtables=sparse_state_dict,
        ...     tensors=dense_state_dict,
        ...     path=ckpt_path,
        ... )
        >>> saver.save()
    """

    def __init__(
        self,
        shard_index: int = 0,
        shard_num: int = 1,
        parallel: int = 8,
        path: str = ".",
        hashtables: Optional[dict] = None,
        tensors: Optional[dict] = None,
        filter_func=lambda x: x,
    ) -> None:
        """Initializes the Saver with configuration and state data.

        The native ``torch.classes.recis.Saver`` is constructed here, so the
        RecIS extension must already be loaded when this is called.

        Args:
            shard_index: Index of the current shard (0-based). Defaults to 0.
            shard_num: Total number of shards. Defaults to 1 (no sharding).
            parallel: Degree of parallelism for write operations.
                Defaults to 8.
            path: Output directory for saved files. Defaults to the current
                directory.
            hashtables: Sparse state (hashtables). Defaults to an empty dict.
            tensors: Dense state (tensors). Defaults to an empty dict.
            filter_func: Callable that receives the write blocks produced by
                the native saver and returns the blocks to actually save.
                Defaults to the identity function (save everything).
        """
        self._shard_index = shard_index
        self._shard_num = shard_num
        self._parallel = parallel
        self._path = path
        # Use None sentinels so that each instance gets its own fresh dict
        # rather than sharing one mutable default object.
        self._hashtables = {} if hashtables is None else hashtables
        self._tensors = {} if tensors is None else tensors
        self._filter_func = filter_func
        self._saver_impl = torch.classes.recis.Saver(
            self._shard_index, self._shard_num, self._parallel, self._path
        )

    def save(self):
        """Executes the saving process.

        Asks the native saver to build write blocks from the stored hashtables
        and tensors. The blocks are then run through ``filter_func``, and
        whatever the filter returns is passed to the native saver's ``save``
        for the actual I/O.
        """
        write_blocks = self._saver_impl.make_write_blocks(
            self._hashtables, self._tensors
        )
        write_blocks = self._filter_func(write_blocks)
        self._saver_impl.save(write_blocks)