[笔记] Golang小试实现神经网络框架

prog_6103 · · 8477 次点击 · · 开始浏览    
这是一个创建于 的文章,其中的信息可能已经有所发展或是发生改变。

国庆节宅在家里,看完了40集的电视剧,也刷了4K代码。最近看看Golang,然后想摆弄一下神经网络。虽然如今都是第三方库泛滥,开源代码拿来即用,而这次对自己的要求是没有第三方库,代码自成闭包。这样更有自主控制性,加深对神经网络实现的理解,也省去了学习各种三方包的用法,顺带着熟练下Golang这门新语言。

搜索引擎可以找到各种神经网络的入门文章。
A Neural Network in 11 lines of Python
Anyone Can Learn To Code an LSTM-RNN in Python
这两篇算是不错的入门级文章,虽然第二篇直接跳到比较新的循环神经网络。写文章的前辈代码质量虽然不是特别高,但是也是一行一行用尽洪荒之力来分析代码。第一篇的11行代码的神经网络很清晰,让人一看就知道如何继续修改自己玩。一句话总结,神经网络三步曲:Predict,Learn和Update。Predict就是系统尝试给出结果,Learn就是根据系统的自我判断或者参照外部结果得到系统偏差,Update则是根据偏差更新系统内部结构。

在GitHub上找Golang的机器学习包,还不是很多。这年头,机器学习都被Matlab,R和python占去了风头吧。尤其,Python有着NumPy社区,Theano和TensorFlow等优秀框架。

不多说,直接开始写代码吧:GitHub/hotpot.seafood
其中有三个例子,分别是最基础的xor,循环神经网络训练256内加法add,卷积神经网络识别手写数字mnist。

首先,就想到最基础的矩阵乘法。那得有Matrix类和一些基础方法。这个就不必搜索了,根据以数学知识实现就好。期间开小差,编译了一下LAPACK和OpenBLAS库,发现Fortron的代码是编译成.o文件的,这样其他语言接extension很容易,只是document太少,不知道api要学到什么时候。

本来学习前辈文章中的Python代码,照猫画虎写好了循环神经网络的add,发现没什么规律。直到看到 nnet 才发现原来TensorFlow,theano等框架都是往程序Graph的结构去的,就连2015年PyCon大会上的Lazy Expression估计也是为机器学习框架设计的。nnet的代码十分易读,建立一个神经网络,里面包含的层也看得很清楚。

    nn = nnet.NeuralNetwork(
        layers=[
            nnet.Conv(
                n_feats=12,
                filter_shape=(5, 5),
                strides=(1, 1),
                weight_scale=0.1,
                weight_decay=0.001,
            ),
            nnet.Activation('relu'),
            nnet.Pool(
                pool_shape=(2, 2),
                strides=(2, 2),
                mode='max',
            ),
            nnet.Conv(
                n_feats=16,
                filter_shape=(5, 5),
                strides=(1, 1),
                weight_scale=0.1,
                weight_decay=0.001,
            ),
            nnet.Activation('relu'),
            nnet.Flatten(),
            nnet.Linear(
                n_out=n_classes,
                weight_scale=0.1,
                weight_decay=0.02,
            ),
            nnet.LogRegression(),
        ],
    )

于是照着nnet的结构,写起了神经网络计算框架。对于网络的每一层,都会有计算Forward和Backward,然后Update参数。一个神经网络先把所有的Forward计算完,得到的结果就完成了Predict任务。按自己聚类或外部提供的正确结果,再倒着把误差传递一遍运行所有层的Backward过程,能完成Learn任务。最后的Update任务当然就是Update所有层的参数了。

Forward过程其实很容易理解,比如一个线性函数啦y=kx+b,不过这里是用矩阵表示就是了:Y = X.W + B,所以这种类型就是LayerLinear了。当然还有形如11+ex sigmoid一类的叫作激活函数的层,就是矩阵每个元素都按这个元素计算一遍,相当于matrix.elements.map(sigmoid),激活函数层对于神经网络收敛有不可或缺的作用,如果不用它们,很可能导致一个线性层参数瞬间刷爆到Inf。

Backward是计算误差,每层可以从下一层pop上来这一层结果的误差在y=f(x)中就相当于Δy。然后就要使用这个误差去推算输入的误差,即上一层输出的误差,并计算这一层参数更新的量。举个例子就是买了股票预测下一时刻的股价,然后到时候看实际价格和预测差,然后调整对这支股票的期待。在控制论的PID算法里,这里就是去求D。Backward比较基础的可以参考http://cs231n.github.io/optimization-2/#backprop,里面有一些例子,在流程图表上标出Forward和Backward的计算结果。

最后的Update就比较傻瓜,学习速率和参数更新量相乘后再累加到参数里就搞定了。不过这里的坑也有很多,因为牵涉到神经网络这门学科我觉得最坑爹的点,就是收敛。如果参数更新控制得不好,那么你可能发现新世界新学科:混沌与分形十分典型(x(t+dt)=4x(t)(1x(t))是经典,它可以制作0-1之间的随机数生成器)。如果神经网络不收敛,或者达不到稳定收敛,训练数据就很蛮烦。比如本来训练100个样本数据,准确度已经达到98%了,要是收敛不稳定再来100个样本,可能结果就变成正确率2%了。当然,稳定收敛也有弊端,那就是,稳定收敛的地方可能是函数的一个沟,可能在遥远的某个地方还有比这个沟更深的沟,所谓“坑的爹“。所以,请“敬请“怀疑图灵机这一模型是否能够构造真正的智能吧。

实现xor很顺利,因为是最基础的神经网络调用,LayerLinear和sigmoid层只要不出问题就okay。线性层不带初中学的记得叫截距b的话,只要记住Forward是Y=XW,Backward是dX=dYWTdW=XTdY,Update就W+αdW

在实现add的时候,循环神经网络确实比较复杂,要记录之前的状态。就拿线性层来说,本来Yi=XW,现在要多出来一项Yi=XW+Yi1H。所以刚算完的Y要存到一个地方,下次算Y还得用上。这里还遇到Golang的特性坑。一直以OOP的思路写程序,结果遇到了Golang的继承问题:

type I interface {
   Public () int
   Abstract () int
}

type A struct {
   I
   x int
}

func (a *A) Public () int {
   return a.Abstract()
}
func (a *A) Abstract () int {
   a.x = 1
   return a.x
}

type B struct {
   A
}

func (b *B) Abstract () int {
   b.x = 2
   return b.x
}

虽然是自己傻逼,程序明明在A的Public里写的是a.Abstract(),当然会执行A的Abstract。解决方案有两种,一个是在A的Public里,将a强制转化为I,然后调用Abstract;另一种是再写一个interface,定义一个函数,这个函数有一个参数是传自己进去。

最后是循环神经网络,是耗时最多在调试上的。MNIST识别0-9的手写体数字,因为最初写得程序设置层或参数导致不收敛,或者收敛成输出固定值正确率10%,一度都想还是直接用Python吧。最终偶然把relu函数换成tanh函数,decay参数调小,程序学习的时候就开始比较好得收敛起来。用Python读MNIST数据源,输出了一个减量版的json数据集。这么Golang就好读了,Marshal一下搞定。至此国庆九天假,结束了,该要上班了,懒腰…


有疑问加站长微信联系(非本文作者)

本文来自:CSDN博客

感谢作者:prog_6103

查看原文:[笔记] Golang小试实现神经网络框架

入群交流(和以上内容无关):加入Go大咖交流群,或添加微信:liuxiaoyan-s 备注:入群;或加QQ群:692541889

8477 次点击  
加入收藏 微博
暂无回复
添加一条新回复 (您需要 登录 后才能回复 没有账号 ?)
  • 请尽量让自己的回复能够对别人有帮助
  • 支持 Markdown 格式, **粗体**、~~删除线~~、`单行代码`
  • 支持 @ 本站用户;支持表情(输入 : 提示),见 Emoji cheat sheet
  • 图片支持拖拽、截图粘贴等方式上传