0%

Mxnet

Mxnet训练流程

  • 定义好symbol:符号式编程,生成描述计算的计算图,Symbol API 这个包主要是用于提供神经网络的图和自动求导。使用c++写好的Symbol类,可进行加减乘除等符号运算。

  • 写好dataiter:如果是分布式训练,需要将训练数据平均切分到每台训练的机器上

  • 初始化参数:cpu or gpu、学习率、优化器参数、针对不同网络设置参数初始化方法、验证度量(evaluation metrics)、loss function、callbacks

  • 模型训练:

    • 创建Module:
      1
      2
      3
      4
      mod = mx.mod.Module(symbol=net,
      context=mx.cpu(),
      data_names=['data'],
      label_names=['softmax_label'])
    • mod.fit():
      • mod.bind(): 分配内存,为计算做准备
      • mod.inst_prams(): 初始化module参数
      • mod.init_optimizer(): 初始化优化器,默认sgd
      • mod.metri.create(): 创建evaluation metric
      • mod.forward(): 前向传播计算
      • mod.update_metric(): 更新预测精度
      • mod.backward(): 反向传播,更新计算梯度
      • mod.update(): 更新参数

Mxnet参数服务器架构

  • server节点可以跟其他server节点通信,每个server负责自己分到的参数,server group共同维持所有参数的更新
  • worker节点之间没有通信,只跟自己对应的server进行通信
  • 每个worker group有一个task scheduler,负责向worker分配任务,并且监控worker的运行情况。当有新的worker加入或者退出,task scheduler 负责重新分配任务。

参考文献

MXnet之ps-lite及parameter server原理