训练流程

详细API文档: Training Framework Module

典型使用

训练

# dataset: 训练数据集
# model: 模型
# dense_opt: 稠密参数优化器
# sparse_opt: 稀疏参数优化器

train_config = TrainingArguments(
    output_dir="./ckpt/",
    model_bank=None,
    log_steps=10,
    save_steps=1000,
)
trainer = Trainer(
    model=model,
    args=train_config,
    train_dataset=dataset,
    dense_optimizers=(dense_opt, None),
    sparse_optimizer=sparse_opt,
)
# 开始训练
trainer.train()

预测

# dataset: evaluate数据集
# model: 模型
# model_bank_conf: 预测任务需要加载的模型配置

train_config = TrainingArguments(
    output_dir=None,
    model_bank=model_bank_conf,
    log_steps=10,
    save_steps=1000,
)
trainer = Trainer(
    model=model,
    args=train_config,
    eval_dataset=dataset,
)
# 开始训练
trainer.evaluate()

边训练边预测

# train_dataset: 训练数据集
# eval_dataset: evaluate数据集
# model: 模型
# dense_opt: 稠密参数优化器
# sparse_opt: 稀疏参数优化器

train_config = TrainingArguments(
    output_dir="./ckpt/",
    model_bank=None,
    log_steps=10,
    save_steps=1000,
)
trainer = Trainer(
    model=model,
    args=train_config,
    train_dataset=dataset,
    eval_dataset=eval_dataset,
    dense_optimizers=(dense_opt, None),
    sparse_optimizer=sparse_opt,
)
# 开始训练
trainer.train_and_evaluate()

导图

参考: 在线交付(阿里内部功能)