Parameter Saving and Loading
Save Format
Model Files
When a model is saved, N files with the suffix .safetensors are generated, where N is the number of GPUs multiplied by the parallel thread count configured at save time. The file format is the same as the safetensors format.
Auxiliary Files
Some intermediate files are generated alongside the model files to facilitate debugging when saving or loading a model runs into problems.
index file: a JSON-format file with two keys: block_index, which records which model file each tensor is saved to, and file_index, which records the index of each model file.
tensorkey.json: records how tensor names are converted when the model is saved. For example, in user_id/part_0_2, the 2 indicates two-GPU training and the 0 indicates that this tensor shard is assigned to GPU 0.
torch_rank_weights_embs_table_multi_shard.json: describes every tensor in detail, as shown below:
"feature_layer.feature_ebc.ebc.embedding_bags.gbdt_ctr.embedding/part_0_1": {
    "name": "feature_layer.feature_ebc.ebc.embedding_bags.gbdt_ctr.embedding/part_0_1",  # tensor name
    "dense": false,  # not a dense tensor
    "dimension": 16,  # embedding dimension
    "dtype": "float32",  # data type
    "hashmap_key": "feature_layer.feature_ebc.ebc.embedding_bags.gbdt_ctr.id/part_0_1",  # key tensor of the hash table
    "hashmap_value": "feature_layer.feature_ebc.ebc.embedding_bags.gbdt_ctr.embedding/part_0_1",  # value tensor of the hash table
    "shape": [
        1211,
        16
    ],
    "is_hashmap": true
},
# This is a dense tensor; the fields have the same meaning as above.
"feature_layer.feature_ebc.ebc.embedding_bags.buyer_star_name.weight": {
    "name": "feature_layer.feature_ebc.ebc.embedding_bags.buyer_star_name.weight",
    "dense": true,
    "dimension": 4,
    "dtype": "float32",
    "shape": [
        64,
        4
    ],
    "is_hashmap": false
},
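As a rough illustration of these conventions, the part suffix and the per-tensor metadata can be inspected with plain Python. This is a sketch, not part of the recis API: `parse_part_suffix` is a made-up helper, and the `metadata` dict below is an abbreviated stand-in for the contents of torch_rank_weights_embs_table_multi_shard.json.

```python
# Hypothetical helpers (not part of the recis API) illustrating the
# "<base_name>/part_<gpu_index>_<gpu_count>" naming convention and the
# per-tensor metadata layout described above.

def parse_part_suffix(tensor_name):
    """Split 'user_id/part_0_2' into (base_name, gpu_index, gpu_count)."""
    base, _, suffix = tensor_name.partition("/part_")
    gpu_index, gpu_count = (int(x) for x in suffix.split("_"))
    return base, gpu_index, gpu_count

# Abbreviated entries in the style of
# torch_rank_weights_embs_table_multi_shard.json.
metadata = {
    "feature_layer.feature_ebc.ebc.embedding_bags.gbdt_ctr.embedding/part_0_1": {
        "dimension": 16, "dtype": "float32",
        "shape": [1211, 16], "is_hashmap": True,
    },
    "feature_layer.feature_ebc.ebc.embedding_bags.buyer_star_name.weight": {
        "dimension": 4, "dtype": "float32",
        "shape": [64, 4], "is_hashmap": False,
    },
}

# Separate hash-table (sparse) tensors from dense ones.
sparse = [k for k in metadata if metadata[k]["is_hashmap"]]
dense = [k for k in metadata if not metadata[k]["is_hashmap"]]

print(parse_part_suffix("user_id/part_0_2"))  # ('user_id', 0, 2)
print(len(sparse), len(dense))                # 1 1
```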
Save and Load Interface
CheckpointReader
- class recis.serialize.checkpoint_reader.CheckpointReader(path)[source]
Provides read access to checkpoint files and their metadata.
This class serves as a wrapper around the low-level torch.classes.recis.CheckpointReader, offering a convenient interface to inspect and read tensors from checkpoint files.
Examples
Typical usage example for reading checkpoint contents:
>>> reader = CheckpointReader("/path/to/checkpoint")
>>> tensor_names = reader.tensor_names()
>>> for name in tensor_names:
...     shape = reader.tensor_shape(name)
...     dtype = reader.tensor_dtype(name)
...     tensor_data = reader.read_tensor(name)
- reader
The underlying implementation object handling low-level checkpoint reading.
- __init__(path)[source]
Initializes the CheckpointReader with a path to checkpoint files.
- Parameters:
path – The directory path containing checkpoint files to read.
Note
The reader initialization may involve loading metadata and preparing for subsequent read operations.
- read_tensor(name)[source]
Reads and returns the tensor data for the specified tensor name.
- Parameters:
name – The identifier of the tensor to read.
- Returns:
The tensor data as an appropriate array or tensor object.
- Raises:
KeyError – If the specified tensor name does not exist in the checkpoint.
- tensor_dtype(name)[source]
Retrieves the data type of the specified tensor.
- Parameters:
name – The identifier of the tensor.
- Returns:
The data type object representing the tensor’s element type.
- Raises:
KeyError – If the specified tensor name does not exist in the checkpoint.
Saver
- class recis.serialize.saver.Saver(shard_index: int = 0, shard_num: int = 1, parallel: int = 8, path: str = '.', hashtables: dict | None = None, tensors: list | None = None, filter_func=<function Saver.<lambda>>)[source]
Saves model state dictionaries by sharding and parallel processing.
This class handles both sparse (hashtable-based) and dense (tensor-based) state dictionaries, applying filtering and sharding logic before saving to disk.
Examples: Typical usage example for saving a sharded checkpoint:
>>> sparse_state_dict_copy = sparse_state_dict.copy()
>>> sparse_state_dict, dense_state_dict = split_sparse_dense_state_dict(
...     sparse_state_dict_copy
... )
>>> saver = Saver(
...     shard_index=shard_id,
...     shard_num=shard_num,
...     parallel=concurrent,
...     hashtables=sparse_state_dict,
...     tensors=dense_state_dict,
...     path=ckpt_path,
... )
>>> saver.save()
- __init__(shard_index: int = 0, shard_num: int = 1, parallel: int = 8, path: str = '.', hashtables: dict | None = None, tensors: list | None = None, filter_func=<function Saver.<lambda>>) → None[source]
Initializes the Saver with configuration and state data.
- Parameters:
shard_index – The index of the current shard (0-based). Defaults to 0.
shard_num – The total number of shards to create. Defaults to 1 (no sharding).
parallel – The degree of parallelism for write operations. Defaults to 8.
path – The output directory for saved files. Defaults to current directory.
hashtables – A dictionary of sparse state (hashtables). Defaults to empty dict.
tensors – A list of dense state (tensors). Defaults to empty list.
filter_func – A callable to filter write blocks. Defaults to identity function.
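The filter_func hook can be used to restrict what gets written. The exact structure recis passes to the filter is not documented here, so the sketch below assumes a simple name-keyed dict of write blocks; the helper name is made up.

```python
# Hypothetical filter_func sketch; the real structure recis passes to
# filter_func is not specified in this document, so a name-keyed dict
# of write blocks is assumed.
def keep_embedding_blocks(blocks):
    """Keep only write blocks whose name mentions 'embedding'."""
    return {name: blk for name, blk in blocks.items() if "embedding" in name}

blocks = {
    "feature_layer.feature_ebc.ebc.embedding_bags.gbdt_ctr.embedding/part_0_1": "...",
    "dense.mlp.layer_0.weight": "...",
}
filtered = keep_embedding_blocks(blocks)
print(sorted(filtered))
```

A filter like this would be passed as `Saver(filter_func=keep_embedding_blocks, ...)`; by default filter_func is an identity function, so every block is saved.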
Loader
- class recis.serialize.loader.Loader(checkpoint_path: str, hashtables: dict | None = None, tensors: dict | None = None, parallel: int = 16, filter_func=<function Loader.<lambda>>)[source]
Loads model state dictionaries from checkpoint files with parallel processing.
This class handles loading both sparse (hashtable-based) and dense (tensor-based) state dictionaries from disk, applying filtering logic to the load configuration.
Examples
Typical usage example for loading a checkpoint:
>>> loader = Loader(
...     checkpoint_path="/path/to/checkpoint",
...     hashtables=sparse_state_dict,
...     tensors=dense_state_dict,
...     parallel=16,
... )
>>> loader.load()
- __init__(checkpoint_path: str, hashtables: dict | None = None, tensors: dict | None = None, parallel: int = 16, filter_func=<function Loader.<lambda>>) → None[source]
Initializes the Loader with configuration and target state dictionaries.
- Parameters:
checkpoint_path – The directory path containing checkpoint files to load.
hashtables – A dictionary to receive loaded sparse state data. If None, an empty dictionary will be created.
tensors – A dictionary to receive loaded dense state data. If None, an empty dictionary will be created.
parallel – The degree of parallelism for read operations. Defaults to 16.
filter_func – A callable to filter load information. Defaults to identity function.
- load()[source]
Executes the loading process.
The load operation involves:
1. Retrieving default load information from the checkpoint metadata.
2. Applying the filter function to modify the load configuration.
3. Loading the state data into the provided hashtables and tensors dictionaries using parallel processing.
The actual file reading and data reconstruction are handled by the torch.classes.recis.Loader class.
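Because filter_func runs on the load information before any I/O happens, it offers a natural hook for checkpoint surgery such as renaming tensors saved under an old module path. The sketch below is hypothetical: the load-information structure is assumed to be a dict mapping checkpoint tensor names to load targets, and the prefixes are made up.

```python
# Hypothetical load-side filter_func sketch. The load-information
# structure recis passes to filter_func is not specified in this
# document; a dict mapping checkpoint tensor names to load targets is
# assumed, and the 'old_layer.' prefix is an invented example.
def remap_old_names(load_info):
    """Redirect entries saved under an old prefix to the current names."""
    return {
        name.replace("old_layer.", "feature_layer."): target
        for name, target in load_info.items()
    }

load_info = {"old_layer.w": "w_target", "feature_layer.b": "b_target"}
print(sorted(remap_old_names(load_info)))
```

Such a filter would be passed as `Loader(filter_func=remap_old_names, ...)`; the default filter_func is an identity function, so load information passes through unchanged.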