Home

MXNet上手

折腾了一段时间Tensorflw后,感觉Tensorflow各方面都非常不错,无论是Tensorflow Serving还是TensorBoard,不得不说Google的大腿就是粗啊。不过Tensorflow在性能上的并不怎么好,训练的时间和内存占用都很高。不过Tensorflow 1.0后也加入了新的实验特性来提高性能,相信Tensorflow以后性能不会再是大的问题。穷人家的孩子还得节俭持家, 所以我把目光投向了MXNet。

MXNet是一个兼具性灵活和效率的深度学习框架。灵活指的是MXNet兼容声明式和指令式两种编程方式,可以很灵活的定义自己需要的东西。而效率则指MXNet在底层做了很多的优化(具体可以看这篇文章),速度与内存消耗方面都做的很好。

不过相比Tensorflow,MXNet的上手难度要高,文档方面对新手还是不太友好。入门教程可以很快的跑一个MNIST识别的demo,但用高级的module封装了很多细节,当新手要自定义loss函数时往往不知所措。其实我觉得学习MXNet基本概念最开始要看的是Symbolic Configuration and Execution in Pictures这篇文章。以这篇文章的例子来说 build process

我们定义SymbolA,B,C,D,E,其中A, B, D为输入节点,E为输出节点,然后绑定输入数据a,b,dE生成Executor, 最后Executor调用forward计算网络输出。

除了多一步绑定数据生成Executor,其他跟Tensorflow还是很像的,可以类比Tensorflow的

sess.run(E, feed_dict={A: a, B: b, D: d})

当我们需要训练模型的时候,我们就需要计算模型中参数的梯度,计算梯度需要调用Executor的backward方法。 build process

如上,绑定的时候通过args_grad对待求解参数输入一个空的NDArray用于存放梯度计算结果,新版MXNet的API和上图有些不同,如果需要计算梯度,调用forward时需要带上is_train参数

executor.forward(is_train=True)

执行forward后再执行backward方法求解梯度,之后就能得到节点A的梯度

In [4]: exec.grad_dict['A'].asnumpy()
Out[4]: array([ 4.], dtype=float32)

计算出梯度后我们就可以根据梯度更新模型参数,一遍一遍的迭代直到训练结束。

从上面的过程可以看出,bind这套api还是很低层的,很多东西都要我们手工完成,所以MXNet在这基础上提供了更高层的api方便开发。我们可以在MXNet项目代码的example/module中的demo里面学习基础用法,这里就直接copy过来

中等层次的api

################################################################################
# Intermediate-level API
################################################################################
mod = mx.mod.Module(softmax)
mod.bind(data_shapes=train_dataiter.provide_data, label_shapes=train_dataiter.provide_label)
mod.init_params()

mod.init_optimizer(optimizer_params={'learning_rate':0.01, 'momentum': 0.9})
metric = mx.metric.create('acc')

for i_epoch in range(n_epoch):
    for i_iter, batch in enumerate(train_dataiter):
        mod.forward(batch)
        mod.update_metric(metric, batch.label)

        mod.backward()
        mod.update()

    for name, val in metric.get_name_value():
        print('epoch %03d: %s=%f' % (i_epoch, name, val))
    metric.reset()
    train_dataiter.reset()

高级api,这个就很简单了

################################################################################
# High-level API
################################################################################
logging.basicConfig(level=logging.DEBUG)
train_dataiter.reset()
mod = mx.mod.Module(softmax)
mod.fit(train_dataiter, eval_data=val_dataiter,
        optimizer_params={'learning_rate':0.01, 'momentum': 0.9}, num_epoch=n_epoch)

总的来说个人感觉MXNet还是一个非常不错的框架, 开发社区也很活跃,目前有了AWS官方的支持,希望后面越来越好。