Source code for recis.serialize.saver

from typing import Optional

import torch


class Saver:
    """Saves model state dictionaries by sharding and parallel processing.

    This class handles both sparse (hashtable-based) and dense (tensor-based)
    state dictionaries, applying filtering and sharding logic before saving
    to disk. The actual block construction and I/O are delegated to the
    ``torch.classes.recis.Saver`` custom class.

    Examples:
        Typical usage example for saving a sharded checkpoint:

        >>> sparse_state_dict_copy = sparse_state_dict.copy()
        >>> sparse_state_dict, dense_state_dict = split_sparse_dense_state_dict(
        ...     sparse_state_dict_copy
        ... )
        >>> saver = Saver(
        ...     shard_index=shard_id,
        ...     shard_num=shard_num,
        ...     parallel=concurrent,
        ...     hashtables=sparse_state_dict,
        ...     tensors=dense_state_dict,
        ...     path=ckpt_path,
        ... )
        >>> saver.save()
    """

    def __init__(
        self,
        shard_index: int = 0,
        shard_num: int = 1,
        parallel: int = 8,
        path: str = ".",
        hashtables: Optional[dict] = None,
        tensors: Optional[dict] = None,
        filter_func=lambda x: x,
    ) -> None:
        """Initializes the Saver with configuration and state data.

        Args:
            shard_index: The index of the current shard (0-based).
                Defaults to 0.
            shard_num: The total number of shards to create. Defaults to 1
                (no sharding).
            parallel: The degree of parallelism for write operations.
                Defaults to 8.
            path: The output directory for saved files. Defaults to the
                current directory.
            hashtables: A dictionary of sparse state (hashtables). ``None``
                is replaced by an empty dict.
            tensors: A dictionary of dense state (tensors). ``None`` is
                replaced by an empty dict.
            filter_func: A callable that receives the generated write blocks
                and returns the blocks to actually save. Defaults to the
                identity function.
        """
        # None sentinels avoid sharing one mutable default between instances.
        # NOTE(review): the dense default is a dict (not a list), matching the
        # class example that passes ``dense_state_dict`` -- confirm against
        # the signature of the native ``make_write_blocks``.
        if tensors is None:
            tensors = {}
        if hashtables is None:
            hashtables = {}

        self._shard_index = shard_index
        self._shard_num = shard_num
        self._parallel = parallel
        self._path = path
        self._hashtables = hashtables
        self._tensors = tensors
        self._filter_func = filter_func

        # Custom class registered under the ``recis`` namespace of
        # ``torch.classes``; it must already be loaded when this runs.
        self._saver_impl = torch.classes.recis.Saver(
            self._shard_index, self._shard_num, self._parallel, self._path
        )

    def save(self) -> None:
        """Executes the saving process.

        Generates write blocks from the stored hashtables and tensors, passes
        them through ``filter_func``, and hands the result to the internal
        saver implementation, which performs the actual save.
        """
        write_blocks = self._saver_impl.make_write_blocks(
            self._hashtables, self._tensors
        )
        # The filter's return value, not the original blocks, is what is saved.
        write_blocks = self._filter_func(write_blocks)
        self._saver_impl.save(write_blocks)