保存部分Sparse参数
详细API文档: recis.framework.trainer.TrainingArguments - TrainingArguments
通过名字指定
train_arg = TrainingArguments(
# item表中的全部字段都不保存
params_not_save=[
"item@id",
"item@emb",
"item@sparse_adamw_tf_exp_avg",
"item@sparse_adamw_tf_exp_avg_sq"],
# ... 其他配置
)
# 定义trainer
# trainer = Trainer(train_arg, ...)
自定义过滤函数
# 过滤hashtable中以item开头的表
def filter_fn(blocks):
out_blocks = []
for block in blocks:
if not block.tensor_name().startswith("item"):
out_blocks.append(block)
return out_blocks
train_arg = TrainingArguments(
# item表中的全部字段都不保存
save_filter_fn=filter_fn,
# ... 其他配置
)
# 定义trainer
# trainer = Trainer(train_arg, ...)