Hooks
详细API文档: Hook System
训练指标统计大盘(阿里内部功能)
参考文档:nebula-mltracker
# trainer: Trainer
# project_name: 指标统计项目名称
# experiment_name: 指标统计实验名称
# track_config: 指标统计其他配置
if os.environ['RANK'] == '0':
ml_tracker_hook = MLTrackerHook(
project_name,
experiment_name,
track_config,
)
trainer.add_hooks([ml_tracker_hook])
预测结果Trace到ODPS(阿里内部功能)
# trainer: Trainer
# fields: 预测结果列名
# fields = ["id", "preds", "labels"]
# types: 预测结果列类型
# types = ["string", "string", "string"]
trace_hook_config = {
"access_id": "xxx",
"access_key": "xxx",
"end_point": "xxx",
"project": "xxx",
"table_name": "xxx",
"partition": "ds1=xxx,ds2=xxx",
}
trace_hook = TraceToOdpsHook(
config=trace_hook_config, fields=fields, types=types, worker_num=8
)
trainer.add_hooks([trace_hook])
Timeline分析
# trainer: Trainer
# output_dir: 输出目录
if int(os.environ.get("RANK", 0)) == 0:
hooks = ProfilerHook(
wait=1,
warmup=249,
active=1,
repeat=2,
output_dir=output_dir,
)
trainer.add_hook(hooks)