折腾了一段时间Tensorflw后,感觉Tensorflow各方面都非常不错,无论是Tensorflow Serving还是TensorBoard,不得不说Google的大腿就是粗啊。不过Tensorflow在性能上的并不怎么好,训练的时间和内存占用都很高。不过Tensorflow 1.0后也加入了新的实验特性来提高性能,相信Tensorflow以后性能不会再是大的问题。穷人家的孩子还得节俭持家, 所以我把目光投向了MXNet。
MXNet是一个兼具性灵活和效率的深度学习框架。灵活指的是MXNet兼容声明式和指令式两种编程方式,可以很灵活的定义自己需要的东西。而效率则指MXNet在底层做了很多的优化(具体可以看这篇文章),速度与内存消耗方面都做的很好。
不过相比Tensorflow,MXNet的上手难度要高,文档方面对新手还是不太友好。入门教程可以很快的跑一个MNIST识别的demo,但用高级的module
封装了很多细节,当新手要自定义loss函数时往往不知所措。其实我觉得学习MXNet基本概念最开始要看的是Symbolic Configuration and Execution in Pictures这篇文章。以这篇文章的例子来说
我们定义SymbolA
,B
,C
,D
,E
,其中A
, B
, D
为输入节点,E
为输出节点,然后绑定输入数据a
,b
,d
到E
生成Executor, 最后Executor调用forward
计算网络输出。
除了多一步绑定数据生成Executor,其他跟Tensorflow还是很像的,可以类比Tensorflow的
sess.run(E, feed_dict={A: a, B: b, D: d})
当我们需要训练模型的时候,我们就需要计算模型中参数的梯度,计算梯度需要调用Executor的backward
方法。
如上,绑定的时候通过args_grad
对待求解参数输入一个空的NDArray
用于存放梯度计算结果,新版MXNet的API和上图有些不同,如果需要计算梯度,调用forward
时需要带上is_train
参数
executor.forward(is_train=True)
执行forward
后再执行backward
方法求解梯度,之后就能得到节点A
的梯度
In [4]: exec.grad_dict['A'].asnumpy()
Out[4]: array([ 4.], dtype=float32)
计算出梯度后我们就可以根据梯度更新模型参数,一遍一遍的迭代直到训练结束。
从上面的过程可以看出,bind
这套api还是很低层的,很多东西都要我们手工完成,所以MXNet在这基础上提供了更高层的api方便开发。我们可以在MXNet项目代码的example/module
中的demo里面学习基础用法,这里就直接copy过来
中等层次的api
################################################################################
# Intermediate-level API
################################################################################
mod = mx.mod.Module(softmax)
mod.bind(data_shapes=train_dataiter.provide_data, label_shapes=train_dataiter.provide_label)
mod.init_params()
mod.init_optimizer(optimizer_params={'learning_rate':0.01, 'momentum': 0.9})
metric = mx.metric.create('acc')
for i_epoch in range(n_epoch):
for i_iter, batch in enumerate(train_dataiter):
mod.forward(batch)
mod.update_metric(metric, batch.label)
mod.backward()
mod.update()
for name, val in metric.get_name_value():
print('epoch %03d: %s=%f' % (i_epoch, name, val))
metric.reset()
train_dataiter.reset()
高级api,这个就很简单了
################################################################################
# High-level API
################################################################################
logging.basicConfig(level=logging.DEBUG)
train_dataiter.reset()
mod = mx.mod.Module(softmax)
mod.fit(train_dataiter, eval_data=val_dataiter,
optimizer_params={'learning_rate':0.01, 'momentum': 0.9}, num_epoch=n_epoch)
总的来说个人感觉MXNet还是一个非常不错的框架, 开发社区也很活跃,目前有了AWS官方的支持,希望后面越来越好。