[关闭]
@atry 2017-11-16T09:59:54.000000Z 字数 5424 阅读 1478

神经网络与函数式编程(四)Monadic Deep Learning

神经网络与函数式编程


在本系列的上一篇文章神经网络与函数式编程(三)怎么解释神经网络?中,我们考察了神经网络结构如何与高阶函数一一对应。在本篇文章中,我将展示如何用Monad表达有条件启用的动态神经网络。


三种带条件的神经网络结构

神经网络中常用的条件启用的结构有三种:

第一种是用非线性函数“过滤”输入,比如ReLU激活函数就相当于if (x > 0) x else 0的条件语句。

第二种是把“门”用于循环单元,比如LSTM和GRU的每个循环单元中都有多个门,每个门以零到一的系数乘以需要过滤的特征,相当于“软性”的条件语句。

第三种是根据每一个样本设置条件启用子网络,比如Sparsely-Gated Mixture-of-Experts。当这类条件不满足时,未启用的子网络不会运行。所以,用这种“硬性”的条件可以训练出非常巨大的神经网络,但针对每个样本只稀疏激活个别子网络,能让网络既拥有巨大的知识容量,又可以运行得很快。

接下来本文将主要聊聊如何用DeepLearning.scala实现最有意思的第三种条件结构,即“硬性”条件。

利用predict实现“硬性”条件

实现“硬性”条件的一种明显思路是先执行一遍条件,然后根据条件选择子网络即可。在DeepLearning.scala中可以写成这样:

  1. type Input = INDArrayLayer
  2. type Output = INDArrayLayer
  3. type Condition = (DoubleLayer, DoubleLayer)
  4. def leftSubnet(input: Input): Output = ???
  5. def rightSubnet(input: Input): Output = ???
  6. def gate(input: Input): Condition = ???
  7. def naiveGatedNet(input: Input): Output = {
  8. val condition = gate(input)
  9. if (condition._1.predict.blockingAwait > condition._2.predict.blockingAwait) {
  10. conditionLeft * leftSubnet(input)
  11. } else {
  12. conditionRight * rightSubnet(input)
  13. }
  14. }

以上的naiveGatedNet涉及gateleftSubnetrightSubnet三个子网络。其中gate返回两个系数,表示更倾向于选择leftSubnetrightSubnet中哪一个子网络分支。

接着,在执行选择的子网络leftSubnetrightSubnet时,乘以gate返回的系数。这样做可以让神经网络可以反向传播到gate里,使得gate也得到训练。

这样一来,naiveGatedNet就实现了“针对每个样本根据条件启用部分子网络”的效果。

然而,naiveGatedNet有一个缺陷。

在DeepLearning.scala,所有的Layer,包括标量DoubleLayer和向量INDArrayLayer,都表示一个惰性执行的可微分计算图,直到调用predict或者train才真正执行。

所以naiveGatedNet中调用了两次predict,就会导致计算图求值两次,而naiveGatedNet的用户将来调用naiveGatedNet(input).train或者naiveGatedNet(input).predict时,还会再对计算图求值一次。这样一来gate(input)就被求值了三次。更糟糕的是,Input也是个计算图,也会求值三次。而Input可能包含复杂的特征提取逻辑,重复执行可能浪费太多的计算资源。

所以,用predict实现的“硬性”条件尽管节省了计算leftSubnetrightSubnet分支的开销,但是代价是前置条件gate(input)被重复计算。这种做法未必划算。

动态生成计算图

如果要在实现“硬性”条件计算图时避免重复计算,就不能使用predict,而必须用flatMap动态生成计算图。

  1. def monadicGatedNet(input: Input): Output = {
  2. val condition = gate(input)
  3. val gatedForward = (condition._1.forward.tuple2(condition._2.forward)).flatMap { pair =>
  4. if (pair._1.data > pair._2.data) {
  5. (condition._1 * leftSubnet(input)).forward
  6. } else {
  7. (condition._2 * rightSubnet(input)).forward
  8. }
  9. }
  10. INDArrayLayer(gatedForward)
  11. }

flatMap方法是Monad.bind的别名:

  1. package scalaz
  2. trait Monad[F[_]] extends Apply[F[_]] {
  3. def point(a: => A): F[A]
  4. def bind[A, B](fa: F[A])(f: A => F[B]): F[B]
  5. }

我们之所以能写形如fa.flatMap(f)这样的代码,是因为Scalaz提供了一些Ops作为语法糖,能把flatMap转为Monad[F].bind(fa, f)调用,具体在monadicGatedNetFDo,所以(condition._1.forward.tuple2(condition._2.forward)).flatMap { pair => ... }相当于Monad[Do].bind(condition._1.forward.tuple2(condition._2.forward), { pair => ...})

如果以面向对象的角度看待Monad的话,可以把Monad视为设计模式里的“抽象工厂”,其高阶类型参数F[_]表示的则是Monad能生产的产品类型,所以monadicGatedNet中的flatMap生产的产品也就是Do

Do可以用来表示神经网络中的泛型计算图。事实上DoubleLayerINDArrayLayer就是对Do的一层包装,可以和Do互相转换:

  1. val myDoubleLayer: DoubleLayer = ???
  2. val doDouble: Do[Tape[Double, Double]] = myDoubleLayer.forward
  3. val myDoubleLayer2: DoubleLayer = DoubleLayer(doDouble)
  1. val myINDArrayLayer: INDArrayLayer = ???
  2. val doINDArray: Do[Tape[INDArray, INDArray]] = myINDArrayLayer.forward
  3. val myINDArrayLayer2: INDArrayLayer = INDArrayLayer(doINDArray)

Tape是计算图中的临时结构,这里不用细究,只要知道通过tape.data能够取得计算结果就行了。本系列的后续文章会讲解Tape的内部构造。

回到Monad[Do]中的bind,也就是我们所使用的flatMap,其签名等价于:

  1. def bind[A, B](fa: Do[A])(f: A => Do[B]): Do[B]

它的含义是生产一个两阶段的动态计算图Do[B]。第一阶段的计算图是fa。第二阶段计算图得在第一阶段计算完毕后,再由回调函数f根据第一阶段的结果A的值动态生成,而不能事先确定。

粗粒度并行计算

另一处细节是(condition._1.forward.tuple2(condition._2.forward)),调用了`tuple2把两个计算图和成一个计算图:

  1. package scalaz
  2. trait Apply[F[_]] {
  3. def tuple2[A, B](fa: => F[A], fb: => F[B]): F[(A, B)]
  4. }

这是因为gate签名里返回的是两个计算图,所以需要合并以后才能flatMap

不过,直接调用tuple2可能不一定是最高效的写法,因为Monadtuple2的默认实现是串行计算,这意味着condition._2.forward要等到condition._1.forward算完后才开始计算。理想情况下应该可以让condition._1.forwardcondition._2.forward同时计算,以便缩短延时。在DeepLearning.scala可以这样实现:

  1. def parallelGatedNet(input: Input): Output = {
  2. val condition = gate(input)
  3. val parallelCondition1Forward: ParallelDo[Tape[Double, Double]] = Parallel(condition._1.forward)
  4. val parallelCondition2Forward: ParallelDo[Tape[Double, Double]] = Parallel(condition._2.forward)
  5. val parallelConditionForward = condition1ForwardParallel.tuple2(condition2ForwardParallel)
  6. val gatedForward = parallelConditionForward.unwrap.flatMap { pair =>
  7. if (pair._1.data > pair._2.data) {
  8. (condition._1 * leftSubnet(input)).forward
  9. } else {
  10. (condition._2 * rightSubnet(input)).forward
  11. }
  12. }
  13. INDArrayLayer(gatedForward)
  14. }

以上代码中,Parallel(...)把一个Do包装成了ParallelDo,表示需要并行执行的计算图。这样一来,如果对ParallelDo调用tuple2以及其他所有的高阶函数,就会生成并行计算的计算图。在tuple2以后,需要调用upwrapParallelDo转回Do,因为flatMap是两阶段计算图,不可能支持并行计算。最后生成的计算图如下:

TODO: 此处缺一张图

事实上,所有INDArrayLayer内置二元操作都会自动用ParallelDo来并行计算。

虽然GPU本身也支持并行计算,但GPU算完一个操作后,在等待CPU给GPU提交下一次计算命令以前,GPU只能干等,计算力就被浪费了。我们利用ParallelDo提供的粗粒度并行计算,同时提交多个计算命令,可以维持GPU一直繁忙,避免发生GPU饥饿。

ThoughtWorks Each

最后,对函数式编程的初学者来说,可能比较难理解Monad的概念。我们提供了ThoughtWorks Each,能够自动把.each调用编译成flatMap或者map。初学者不需要手写flatMap就可以利用Each生成动态神经网络了:

  1. def eachGatedNet(input: Input): Output = INDArrayLayer(monadic[Do] {
  2. val condition = gate(input)
  3. val parallelCondition1Forward: ParallelDo[Tape[Double, Double]] = Parallel(condition._1.forward)
  4. val parallelCondition2Forward: ParallelDo[Tape[Double, Double]] = Parallel(condition._2.forward)
  5. val parallelConditionForward = condition1ForwardParallel.tuple2(condition2ForwardParallel)
  6. val pair = parallelConditionForward.unwrap.each
  7. if (pair._1.data > pair._2.data) {
  8. (condition._1 * leftSubnet(input)).forward.each
  9. } else {
  10. (condition._2 * rightSubnet(input)).forward.each
  11. }
  12. })

结论

Monad既简单又强大。在ThoughtWorks Each的帮助下,你可以同时获得动态计算图和粗粒度并行计算的好处,而依旧保持很简洁的代码结构。

谷歌最近的论文One Model To Learn Them All展示了通用神经网络的前景。人类的神经元数量约有860亿,包含各种不同形态和功能的专用神经细胞。未来的通用人工神经网络一定也和人类神经系统一样,既巨大又复杂,处理特定任务时动态稀疏激活。我们相信Monadic Deep Learning将在通用人工智能上大有可为!

除了动态神经网络以外,通用的神经网络会内置不同功能的模块,这就需要以可扩展的方式添加新功能。我将在下一篇文章中介绍DeepLearning.scala的插件系统如何利用依赖类型系统解决Expression Problem,既支持添加新功能也可以修改现有功能的行为。

添加新批注
在作者公开此批注前,只有你和作者可见。
回复批注