[关闭]
@liuhui0803 2017-05-13T16:45:32.000000Z 字数 4006 阅读 2740

MXNet API入门 —第4篇

机器学习 深度学习 神经网络 MXNet AWS


摘要:

Apache MXNet是一种功能全面、可以灵活编程并且扩展能力超强的深度学习框架,支持包括卷积神经网络(CNN)与长短期记忆网络(LSTM)在内的顶尖深度模型。这一系列文章介绍了MXNet的基本概念和使用方法。本篇主要介绍如何使用训练模型进行图片分类。

正文:

第3篇文章中,我们构建并训练了第一个神经网络,接下来可以处理一些更复杂的样本了。

最顶尖的深度学习模型通常都复杂到让人难以置信。其中可能包含数百层,就算用不了数周,往往也要数天时间来使用海量数据进行训练。这类模型的构建和优化需要大量经验。

好在这些模型的使用还是很简单的,通常只需要编写几行代码。本文将使用一个名为Inception v3的预训练模型进行图片分类。

Inception v3

诞生于2015年12月的Inception v3GoogleNet模型(曾赢得2014年度ImageNet挑战赛)的改进版。本文不准备深入介绍该模型的研究论文,不过打算强调一下论文的结论:相比当时最棒的模型,Inception v3的准确度高出了15%–25%,同时计算的经济性方面低六倍,并且至少将参数的数量减少了五倍(例如使用该模型对内存的要求更低)。

简直就是神器!那么我们该如何使用?

MXNet model zoo

Model zoo提供了一系列可直接使用的预训练模型,并且通常还会提供模型定义模型参数(例如神经元权重),(也许还会提供)使用说明。

首先来下载定义和参数(你也许需要更改文件名)。第一个文件可以直接打开:其中包含了每一层的定义。第二个文件是一个二进制文件,请不要打开 ;)

  1. $ wget http://data.dmlc.ml/models/imagenet/inception-bn/Inception-BN-symbol.json
  2. $ wget http://data.dmlc.ml/models/imagenet/inception-bn/Inception-BN-0126.params
  3. $ mv Inception-BN-0126.params Inception-BN-0000.params

该模型已通过ImageNet数据集进行了训练,因此我们还需要下载对应的图片分类清单(共有1000个分类)。

  1. $ wget http://data.dmlc.ml/models/imagenet/synset.txt
  2. $ wc -l synset.txt
  3. 1000 synset.txt
  4. $ head -5 synset.txt
  5. n01440764 tench, Tinca tinca
  6. n01443537 goldfish, Carassius auratus
  7. n01484850 great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias
  8. n01491361 tiger shark, Galeocerdo cuvieri
  9. n01494475 hammerhead, hammerhead shark

搞定,开始实战。

加载模型

我们需要:

  1. import mxnet as mx
  2. sym, arg_params, aux_params = mx.model.load_checkpoint('Inception-BN', 0)
  1. mod = mx.mod.Module(symbol=sym)
  1. mod.bind(for_training=False, data_shapes=[('data', (1,3,224,224))])
  1. mod.set_params(arg_params, aux_params)

这样就可以了。只需要四行代码!随后可以放入一些数据看看会发生什么。嗯……先别急。

准备数据

数据准备:从七十年代以来,这一直是个痛苦的过程……从关系型数据库到机器学习,再到深度学习,这方面没有任何改进。虽然乏味但很必要。开始吧。

还记得吗,这个模型需要通过四维NDArray来保存一张224x224分辨率图片的红、绿、蓝通道数据。我们将使用流行的OpenCV库从输入图片中构建这样的NDArray。如果还没安装OpenCV,考虑到本例的要求,直接运行pip install opencv-python就够了 :)。

随后的步骤如下:

  1. img = cv2.imread(filename)
  1. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  1. img = cv2.resize(img, (224, 224,))
  1. img = np.swapaxes(img, 0, 2)
  2. img = np.swapaxes(img, 1, 2)
  1. img = img[np.newaxis, :]
  2. array = mx.nd.array(img)
  3. >>> print array.shape
  4. (1L, 3L, 224L, 224L)

晕了?一起用个例子看看吧。输入下列这张图片:

1-sPdrfGtDd_6RQfYvD5qcyg.jpeg-45.6kB
输入448x336的图片(来源:metaltraveller.com)

处理完毕后,该图会被缩小尺寸并拆分为RGB通道,存储在array[0]中(生成下文图片的代码可参阅这里)。

1-yqdl78KIugYepzJ4-lMY8g.jpeg-11.8kB
array[0][0]:224x224,红色通道

1-sitDwoAzPDLrav0dXQrcbA.jpeg-12.3kB
array0:224x224,绿色通道

1-h1RyEPvd2fqIgd2jkfES-w.jpeg-7kB
array0:224x224,蓝色通道

如果批大小大于1,那么可以通过array1指定第二张图片,使用array2指定第三张图片,以此类推。

无论这个过程是乏味还是有趣,接下来我们开始预测吧!

开始预测

你可能还记得第3篇文章中提到,Module对象必须以为单位向模型提供数据:最常见的做法是使用数据迭代器(因此我们使用了NDArrayIter对象)。

在这里我们想要预测一张图片,因此尽管可以使用数据迭代器,不过也没啥必要。但我们可以创建一个名为Batch的具名元组(Named tuple),它可以充当假的迭代器,在引用数据属性时返回输入的NDArray。

  1. from collections import namedtuple
  2. Batch = namedtuple('Batch', ['data'])

随后即可将这个“Batch”传递给模型开始预测。

  1. mod.forward(Batch([array]))

这个模型会输出一个包含1000个可能性的NDArray,每个可能性对应一个分类。由于批大小等于1,因此只需要一行代码。

  1. prob = mod.get_outputs()[0].asnumpy()
  2. >>> prob.shape
  3. (1, 1000)

使用squeeze()将其转换为数组,随后使用argsort()创建第二个数组,其中保存了这些可能性按照降序排列的指数

  1. prob = np.squeeze(prob)
  2. >>> prob.shape
  3. (1000,)
  4. >> prob
  5. [ 4.14978594e-08 1.31608676e-05 2.51907986e-05 2.24045834e-05
  6. 2.30327873e-06 3.40798979e-05 7.41563645e-06 3.04062659e-08 etc.
  7. sortedprob = np.argsort(prob)[::-1]
  8. >> sortedprob.shape
  9. (1000,)

根据模型的计算,这张图片最可能的分类是#546,可能性为58%

  1. >> sortedprob
  2. [546 819 862 818 542 402 650 420 983 632 733 644 513 875 776 917 795
  3. etc.
  4. >> prob[546]
  5. 0.58039135

这个分类叫什么名字呢?我们可以使用synset.txt文件构建分类清单,并找出546号的名称。

  1. synsetfile = open('synset.txt', 'r')
  2. categorylist = []
  3. for line in synsetfile:
  4. categorylist.append(line.rstrip())
  5. >>> categorylist[546]
  6. 'n03272010 electric guitar'

可能性第二大的分类是什么?

  1. >>> prob[819]
  2. 0.27168664
  3. >>> categorylist[819]
  4. 'n04296562 stage

挺棒的,你说呢?

就是这样,我们已经了解了如何使用预训练的顶尖模型进行图片分类。而这一切只需要4行代码……除此之外只要准备好数据就够了。

完整代码如下,请自行尝试并继续保持关注 😃

代码已发布至GitHub:mxnet_example2.py

后续内容:

作者Julien Simon阅读英文原文An introduction to the MXNet API — part 4

添加新批注
在作者公开此批注前,只有你和作者可见。
回复批注