import json
import os
import torch
from recis.framework.filesystem import get_file_system
from recis.info import is_internal_enabled
from recis.nn.modules.hashtable import filter_out_sparse_param
from recis.serialize import Loader, Saver
from recis.utils.torch_fx_tool.export_torch_fx_tool import ExportTorchFxTool
if is_internal_enabled() and os.environ.get("BUILD_DOCUMENT") != "1":
from pangudfs_client.common.exception.exceptions import PanguException
from recis.utils.mos import Mos
else:
PanguException = None
Mos = None
TMP_EXPORT_LOCAL_PATH = "./__tmp_export_path__/"
class Exporter:
"""Model exporter for RecIS framework with support for sparse and dense models.
The Exporter class handles the export process for trained RecIS models,
managing both sparse embedding tables and dense neural network components.
It supports distributed export across multiple workers and handles various
storage backends including local filesystem and cloud storage.
Key Features:
- Separate export of sparse and dense model components
- Distributed export with automatic file partitioning
- Support for multiple storage backends (local, cloud)
- Configuration export for feature generation and model compilation
- Automatic model preparation and state loading
Attributes:
rank (int): Current worker rank in distributed setup.
shard_num (int): Total number of workers in distributed setup.
model: Complete model containing both sparse and dense components.
sparse_model: Sparse embedding component of the model.
dense_model: Dense neural network component of the model.
dataset: Dataset used for model tracing during export.
ckpt_dir (str): Directory containing model checkpoints.
export_dir (str): Target directory for exported model files.
dense_optimizer: Optional optimizer for dense model components.
export_model_name (str): Name for the exported model.
export_outputs: Specification of model output nodes.
filter_sparse_opt (bool): Whether to filter sparse optimization parameters.
fg_conf (dict): Feature generation configuration.
mc_conf (dict): Model compilation configuration.
fx_tool (ExportTorchFxTool): Tool for exporting TorchFX models.
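
Example:
    A minimal end-to-end sketch; ``MyRecModel``, the submodule names,
    ``train_dataset``, and the paths below are hypothetical placeholders::

        model = MyRecModel()  # exposes "sparse" and "dense" submodules
        exporter = Exporter(
            model,
            sparse_model_name="sparse",
            dense_model_name="dense",
            dataset=train_dataset,
            ckpt_dir="/path/to/ckpt",
            export_dir="/path/to/export",
            fg_conf_or_path="/path/to/fg.json",
            mc_conf_or_path="/path/to/mc.json",
        )
        exporter.export()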
"""
def __init__(
self,
model,
sparse_model_name,
dense_model_name,
dataset,
ckpt_dir,
export_dir,
dense_optimizer=None,
export_folder_name="fx_user_model",
export_model_name="user_model",
export_outputs=None,
fg=None,
fg_conf_or_path=None,
mc_conf_or_path=None,
filter_sparse_opt=False,
):
"""Initialize the model exporter with configuration parameters.
Args:
model: Complete RecIS model containing sparse and dense components.
sparse_model_name (str): Name of the sparse model submodule.
dense_model_name (str): Name of the dense model submodule.
dataset: Dataset for model tracing during export process.
ckpt_dir (str): Directory path containing model checkpoints.
export_dir (str): Target directory for exported model files.
dense_optimizer: Optional optimizer for dense model components.
export_folder_name (str, optional): Name of the export folder.
Defaults to "fx_user_model".
export_model_name (str, optional): Name for the exported model.
Defaults to "user_model".
export_outputs: Specification of model output nodes for export.
fg: Feature generator instance for configuration extraction.
fg_conf_or_path: Feature generation configuration dict or file path.
mc_conf_or_path: Model compilation configuration dict or file path.
filter_sparse_opt (bool, optional): Whether to filter sparse
optimization parameters. Defaults to False.
Raises:
AssertionError: If neither fg nor fg_conf_or_path is provided.
AssertionError: If neither fg nor mc_conf_or_path is provided.
AssertionError: If a MOS-style path is given but MOS support is unavailable.
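
Example:
    Both configuration forms are accepted; a sketch with hypothetical
    paths (the dict form skips the JSON file read)::

        # JSON files on disk
        Exporter(model, "sparse", "dense", ds, ckpt_dir, export_dir,
                 fg_conf_or_path="fg.json", mc_conf_or_path="mc.json")
        # or pre-loaded dicts
        with open("fg.json") as f:
            fg_conf = json.load(f)
        Exporter(model, "sparse", "dense", ds, ckpt_dir, export_dir,
                 fg_conf_or_path=fg_conf, mc_conf_or_path="mc.json")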
"""
self.rank = int(os.environ.get("RANK", 0))
self.shard_num = int(os.environ.get("WORLD_SIZE", 1))
self.model = model
self.sparse_model = model.get_submodule(sparse_model_name)
self.sparse_model_name = sparse_model_name
self.dense_model = model.get_submodule(dense_model_name)
self.dense_model_name = dense_model_name
self.dataset = dataset
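# Paths beginning with "model" are MOS logical paths (internal builds only);
# they are resolved to physical storage paths before use.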
if ckpt_dir.startswith("model"):
assert Mos is not None, "Cannot import Mos; check the internal version."
ckpt_dir = Mos(ckpt_dir, True).real_physical_path
self.ckpt_dir = ckpt_dir
if export_dir.startswith("model"):
assert Mos is not None, "Cannot import Mos; check the internal version."
export_dir = Mos(export_dir, True).real_physical_path
self.export_dir = export_dir
self.dense_optimizer = dense_optimizer
self.export_model_name = export_model_name
self.export_outputs = export_outputs
self.filter_sparse_opt = filter_sparse_opt
assert fg is not None or fg_conf_or_path is not None, (
    "either fg or fg_conf_or_path must be provided"
)
assert fg is not None or mc_conf_or_path is not None, (
    "either fg or mc_conf_or_path must be provided"
)
if fg_conf_or_path is not None:
if not isinstance(fg_conf_or_path, dict):
with open(fg_conf_or_path) as f:
fg_conf_or_path = json.load(f)
self.fg_conf = fg_conf_or_path
else:
self.fg_conf = fg.get_fg_conf()
if mc_conf_or_path is not None:
if not isinstance(mc_conf_or_path, dict):
with open(mc_conf_or_path) as f:
mc_conf_or_path = json.load(f)
self.mc_conf = mc_conf_or_path
else:
self.mc_conf = fg.get_mc_conf()
self.fx_tool = ExportTorchFxTool(
fx_folder=os.path.join(TMP_EXPORT_LOCAL_PATH, export_folder_name),
model_name=self.export_model_name,
)
self.fx_tool.set_output_nodes_name(export_outputs)
def export(self):
"""Execute the complete model export process.
This method orchestrates the entire export workflow, including model
preparation, sparse component export, dense component export, and
metadata export. The process is designed to work in distributed
environments with automatic work partitioning.
The export process consists of:
1. Model preparation and checkpoint loading
2. Sparse model component export
3. Dense model component export with TorchFX
4. Configuration metadata export
Note:
This method should be called on all workers in a distributed setup.
File operations are automatically partitioned based on worker rank.
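
Example:
    Called identically on every worker; rank and shard count are taken
    from the ``RANK`` and ``WORLD_SIZE`` environment variables (the
    launcher command below is illustrative)::

        # torchrun --nproc_per_node=2 export_job.py
        exporter.export()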
"""
self.prepare_model()
self.export_sparse()
self.export_dense()
self.export_meta()
def prepare_model(self):
"""Prepare the model for export by loading checkpoints and creating directories.
This method handles the initial setup required for model export:
- Creates the export directory if it doesn't exist
- Loads dense model state from checkpoint files
- Optionally loads sparse model parameters if filtering is enabled
The method supports both local filesystem and cloud storage backends,
automatically handling path resolution for different storage types.
Raises:
PanguException: If directory creation fails due to permission issues.
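
Example:
    Checkpoint keys are re-rooted onto the dense submodule; a
    self-contained sketch of the prefix stripping with a made-up key::

        prefix = "dense."
        sd = {"dense.mlp.weight": 0}
        sd = {k[len(prefix):] if k.startswith(prefix) else k: v
              for k, v in sd.items()}
        assert "mlp.weight" in sd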
"""
fs = get_file_system(self.export_dir)
if not fs.exists(self.export_dir):
    try:
        fs.makedirs(self.export_dir + "/", exist_ok=True)
    except Exception as e:
        # Tolerate Pangu err_no 7 (directory already exists); re-raise
        # anything else instead of silently swallowing it. PanguException
        # is None on non-internal builds, so guard before isinstance.
        if PanguException is None or not isinstance(e, PanguException):
            raise
        if e.pangu_err_no != 7:
            raise
# load dense model
pt_file = os.path.join(self.ckpt_dir, "model.pt")
fs = get_file_system(pt_file)
with fs.open(pt_file, "rb") as f:
state_dict = torch.load(f=f)
# Strip only the leading "<dense_model_name>." prefix from checkpoint keys.
prefix = f"{self.dense_model_name}."
state_dict = {
    (k[len(prefix):] if k.startswith(prefix) else k): v
    for k, v in state_dict.items()
}
self.dense_model.load_state_dict(state_dict, strict=False)
# maybe load sparse model
if self.filter_sparse_opt:
sparse_params = filter_out_sparse_param(self.sparse_model)
loader = Loader(self.ckpt_dir, hashtables=sparse_params, tensors={})
loader.load()
def export_dense(self):
"""Export the dense model component using TorchFX compilation.
This method handles the export of the dense neural network component:
1. Processes a sample batch through the sparse model to get dense inputs
2. Uses TorchFX to trace and export the dense model
3. Copies the exported model files to the target directory
The dense model is exported in a format suitable for deployment,
with optimizations applied through the TorchFX compilation process.
Only the rank 0 worker performs the final file copying to avoid conflicts.
Note:
The method requires at least one batch from the dataset for model tracing.
The sparse model must be properly initialized before calling this method.
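
Example:
    The tracing flow in miniature; ``dataset`` and the submodules are
    those passed to the constructor::

        stop_flag, batch = next(iter(dataset))
        dense_inputs = sparse_model(batch)[0]  # feeds the FX trace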
"""
# export dense to local tmp dir
iterator = iter(self.dataset)
stop_flag, data = next(iterator)
dense_data = self.sparse_model(data)
if self.dense_optimizer:
self.dense_optimizer.zero_grad()
self.fx_tool.export_fx_model(self.dense_model, dense_data[0], self.mc_conf)
if self.rank == 0:
# copy local tmp dir to dst dir
fs = get_file_system(self.export_dir)
fs.put(TMP_EXPORT_LOCAL_PATH, self.export_dir, recursive=True)
def export_sparse(self):
"""Export the sparse model component with distributed file handling.
This method manages the export of sparse embedding tables and related
parameters. It supports two modes of operation:
1. Direct file copying mode (filter_sparse_opt=False):
- Identifies all sparse parameter files (safetensors, json)
- Distributes files across workers for parallel copying
- Each worker handles a subset of files based on rank
2. Filtered export mode (filter_sparse_opt=True):
- Extracts only relevant sparse parameters from the model
- Uses the Saver class for optimized sparse parameter serialization
- Automatically handles distributed saving across workers
The method ensures efficient distribution of work across multiple workers
while maintaining data consistency and avoiding file conflicts.
Note:
File distribution is based on worker rank to ensure balanced workload.
All workers participate in the export process simultaneously.
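
Example:
    Round-robin file partitioning, as a self-contained sketch::

        files = ["a.safetensors", "b.safetensors", "c.json"]
        rank, world_size = 1, 2
        mine = [f for i, f in enumerate(files) if i % world_size == rank]
        # rank 1 copies ["b.safetensors"]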
"""
if not self.filter_sparse_opt:
# get all sparse files (ends with `safetensors` or `json`)
candidate_files = []
fs = get_file_system(self.ckpt_dir)
for file_name in fs.ls(self.ckpt_dir, detail=False):
if file_name.endswith(("safetensors", "json")):
candidate_files.append(file_name)
file_to_copy = []
# partition files
for i, file_name in enumerate(candidate_files):
if i % self.shard_num == self.rank:
file_to_copy.append(file_name)
# copy
for file_name in file_to_copy:
fs.copy(file_name, self.export_dir)
else:
# save sparse files
sparse_params = filter_out_sparse_param(self.sparse_model)
saver = Saver(
shard_index=self.rank,
shard_num=self.shard_num,
parallel=1,
hashtables=sparse_params,
tensors={},
path=self.export_dir,
)
saver.save()
def export_meta(self):
"""Export model metadata and configuration files.
This method exports the configuration metadata required for model
deployment and inference. Currently exports the feature generation
configuration as a JSON file.
The metadata includes:
- Feature generation configuration (fg.json)
- Model compilation settings (future extension)
- Deployment-specific parameters (future extension)
Only the rank 0 worker performs metadata export to avoid file conflicts
and ensure consistency across the distributed export process.
Note:
Additional metadata types can be added by extending this method.
The exported configurations must match the training setup exactly.
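
Example:
    The feature generation config is written once, by rank 0; a
    self-contained sketch of the JSON dump, using the builtin ``open``
    in place of the checkpoint filesystem::

        import json
        with open("fg.json", "w") as f:
            json.dump({"features": []}, f)  # placeholder config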
"""
if self.rank == 0:
fs = get_file_system(self.export_dir)
# fg config
with fs.open(os.path.join(self.export_dir, "fg.json"), "w") as out_f:
json.dump(self.fg_conf, out_f)