@Pigmon
2017-05-28T08:30:27.000000Z
字数 6166
阅读 1232
教案
简单的很,只是做了个界面可以用鼠标画数字,给之前基于MNIST数据集训练好的模型来Predict,然后显示出识别的结果。
这东西毫无实用价值,看看就好。
我做的比较简单,输入的图像的笔画是单像素单一颜色的,所以识别率并不高。下面链接里自带的模型,包括了我在这样的界面上输入的120个训练样本,但效果并不明显(输入太累手,不想再弄了。)
用这样的界面生成MNIST格式训练样本的程序之后我会整理下发出,因为包含MNIST格式说明的部分,所以文档内容会多点,不像这个没啥可说的。
下载链接:
链接:http://pan.baidu.com/s/1pLI2I1d 密码:7wdh
如果你机器上有 Python 3.5, PyQT5, Tensorflow,TFLearn以及它们的依赖项都有安装的话,解压就可以直接运行了。
我猜你们没有,如果只想看看程序的话:
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
MNIST 手写数字识别 Demo
@Yuan Sheng
这就是一个 PyQT5 的界面程序,
目的是把你用鼠标画出来的数字给mnist_model.py去识别。
识别率很低,具体原因没你的事,留给我自己思考。
本程序受开源软件协议 @DWFYWDEC 保护。
@DWFYWDEC:Do What the Fuck You Want to Do Except Cheating.
FYI:
@DWFYWDEC 协议是本程序作者基于 @DWFYWD 协议胡编乱造的(就在几秒钟之前)。
@DWFYWD 协议是另外一个程序的作者胡编乱造的。
FYI2:
虽然是胡编乱造的但你一旦违反本协议会遭到本人疯狂的报复。
我是个疯狂的人以至于我都不知道一旦你违反本协议我会怎么报复你。
不过这个程序P用都没有所以你不用担心。
"""
import sys
from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtCore import Qt, QRect
from PyQt5.QtWidgets import QWidget, QApplication, QPushButton, QLCDNumber
import numpy as np
from mnist_model import *
class Example(QWidget):
def __init__(self):
super().__init__()
self.InitInputData()
self.initUI()
def initUI(self):
self.setGeometry(300, 300, 640, 480)
self.setWindowTitle(u'MNIST 演示')
self.initInputWidget()
self.show()
def InitInputData(self):
# 鼠标左键是不是按下了
self.mouse_pressed = False
# 目前识别的结果,为了在LCD控件中显示数字
self.predicted_nbr = -1
# 输入窗口中,每 scale*scale 个像素视作传给MNIST模型的一个像素
self.scale = 10
# 28*28的数组:输入窗口显示数据,
# 以及传递给MNIST模型进行Predict
self.pt_arr = np.zeros((28, 28))
# 把输入控件的面积分成 28*28 个rect。
# 每个rect的尺寸为 scale*scale
# 作用是检测手写输入时需要绘制成黑色的部分,
# 以及推算逻辑上的28*28的输入图像哪些像素是黑色的
self.rect_arr = np.array([QRect(0, 0, self.scale, self.scale)] * 28 * 28).reshape((28, 28))
for i in range(28):
for j in range(28):
rect = QRect(i * self.scale, j * self.scale, self.scale, self.scale)
self.rect_arr[i][j] = rect
# 读取模型
self.model = MnistModel('models/mnist5/mnist5.tfl')
def initInputWidget(self):
self.widget = QtWidgets.QWidget(self)
self.widget.setGeometry(QtCore.QRect(10, 10, 280, 280))
# -> 下面这一大段就是为了让手写输入的widget有个白色的背景而已
palette = QtGui.QPalette()
brush = QtGui.QBrush(QtGui.QColor(255, 255, 255))
brush.setStyle(QtCore.Qt.SolidPattern)
palette.setBrush(QtGui.QPalette.Active, QtGui.QPalette.Base, brush)
brush = QtGui.QBrush(QtGui.QColor(255, 255, 255))
brush.setStyle(QtCore.Qt.SolidPattern)
palette.setBrush(QtGui.QPalette.Active, QtGui.QPalette.Window, brush)
brush = QtGui.QBrush(QtGui.QColor(255, 255, 255))
brush.setStyle(QtCore.Qt.SolidPattern)
palette.setBrush(QtGui.QPalette.Inactive, QtGui.QPalette.Base, brush)
brush = QtGui.QBrush(QtGui.QColor(255, 255, 255))
brush.setStyle(QtCore.Qt.SolidPattern)
palette.setBrush(QtGui.QPalette.Inactive, QtGui.QPalette.Window, brush)
brush = QtGui.QBrush(QtGui.QColor(255, 255, 255))
brush.setStyle(QtCore.Qt.SolidPattern)
palette.setBrush(QtGui.QPalette.Disabled, QtGui.QPalette.Base, brush)
brush = QtGui.QBrush(QtGui.QColor(255, 255, 255))
brush.setStyle(QtCore.Qt.SolidPattern)
palette.setBrush(QtGui.QPalette.Disabled, QtGui.QPalette.Window, brush)
self.widget.setPalette(palette)
self.widget.setAutoFillBackground(True)
# <- 到这为止
self.widget.setObjectName("input")
self.widget.paintEvent = self.paintEvent
self.widget.mousePressEvent = self.inputMousePressed
self.widget.mouseReleaseEvent = self.inputMouseReleased
self.widget.mouseMoveEvent = self.inputMouseMove
# LCD Number
self.lcd = QLCDNumber(self)
self.lcd.setGeometry(QRect(300, 10, 240, 100))
self.lcd.display("")
# 按钮
self.btn1 = QPushButton(u"识别", self)
self.btn1.move(10, 300)
self.btn2 = QPushButton(u"清除", self)
self.btn2.move(120, 300)
# 按钮事件
self.btn1.clicked.connect(self.btnPredictClicked)
self.btn2.clicked.connect(self.btnClearClicked)
def paintEvent(self, event):
"OnPaint回调,把需要画成黑色的rect画黑"
paint=QtGui.QPainter()
paint.begin(self.widget)
paint.setPen(QtCore.Qt.black)
for i in range(28):
for j in range(28):
if self.pt_arr[i][j] > 0.5:
paint.fillRect(self.rect_arr[i][j], Qt.black)
paint.end()
def keyPressEvent(self, e):
if e.key() == Qt.Key_Escape:
self.close()
def inputMouseReleased(self, e):
"鼠标左键是否弹起"
if e.button() == Qt.LeftButton:
self.mouse_pressed = False
self.widget.repaint()
def inputMousePressed(self, e):
"鼠标左键是否按下"
if e.button() == Qt.LeftButton:
self.mouse_pressed = True
else:
self.mouse_pressed = False
self.widget.repaint()
def inputMouseMove(self, e):
"""
手写输入时,按下鼠标左键后的事件响应。
即:如果鼠标按下了,那么鼠标移动的过程中,
检测rect_arr中哪些rect需要被画成黑色
"""
if self.mouse_pressed:
for i in range(28):
for j in range(28):
rect = self.rect_arr[i][j]
if rect.contains(e.pos()):
self.pt_arr[i][j] = 1.0
self.widget.repaint()
break
def btnPredictClicked(self):
"识别按钮事件响应函数"
arr = self.pt_arr.transpose().reshape((1, 28, 28, 1))
result = self.model.predict(arr)[0]
self.predicted_nbr = result.index(max(result))
self.lcd.display(self.predicted_nbr)
self.lcd.repaint()
def btnClearClicked(self):
"清除按钮事件响应函数"
self.pt_arr = np.zeros((28, 28))
self.widget.repaint()
self.predicted_nbr = -1
self.lcd.display("")
if __name__ == '__main__':
app = QApplication(sys.argv)
ex = Example()
sys.exit(app.exec_())
"""
MNIST Predict
@Yuan Sheng
FYI:
这个程序就是TFLearn的例子简单的修改。
因为那个例子里没有声明什么吓人的东西,
所以我也不知道一旦你拿这个程序做些为非作歹的事情,
TFLearn 和 Y. LeCun 会怎么报复你。
"""
""" Convolutional Neural Network for MNIST dataset classification task.
References:
Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. "Gradient-based
learning applied to document recognition." Proceedings of the IEEE,
86(11):2278-2324, November 1998.
Links:
[MNIST Dataset] http://yann.lecun.com/exdb/mnist/
"""
from __future__ import division, print_function, absolute_import
import tflearn
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.conv import conv_2d, max_pool_2d
from tflearn.layers.normalization import local_response_normalization
from tflearn.layers.estimator import regression
class MnistModel:
def __init__(self, _model_path):
network = input_data(shape=[None, 28, 28, 1], name='input')
network = conv_2d(network, 32, 3, activation='relu', regularizer="L2")
network = max_pool_2d(network, 2)
network = local_response_normalization(network)
network = conv_2d(network, 64, 3, activation='relu', regularizer="L2")
network = max_pool_2d(network, 2)
network = local_response_normalization(network)
network = fully_connected(network, 128, activation='tanh')
network = dropout(network, 0.8)
network = fully_connected(network, 256, activation='tanh')
network = dropout(network, 0.8)
network = fully_connected(network, 10, activation='softmax')
network = regression(network, optimizer='adam', learning_rate=0.01,
loss='categorical_crossentropy', name='target')
self.model = tflearn.DNN(network)
self.model.load(_model_path)
def predict(self, _input_tensor):
return self.model.predict(_input_tensor)