@Pigmon
2017-03-23T15:05:45.000000Z
字数 1586
阅读 910
教案
# -*- coding: utf-8 -*-
"""
Created on Wed Mar 22 21:47:27 2017
@author: Yuan Sheng
"""
import random
import matplotlib.pyplot as plt
import numpy as np
# 训练集 ((x1, x2), y)
train_set = ( ((1, 1), -1), ((3, 3), 1), ((3, 4), 1) )
# 准备训练出来的直线参数
w1, w2, b = 0, 0, 0
eta = 1 # 学习率
##################################
def y_of_line(x, A, B, C):
"根据直线方程求 y"
if (A == 0 and B == 0):
return x, [0] * 50
elif (A != 0 and B == 0):
x = [-1 * (C / float(A))] * 50
y = np.linspace(0, 5, 50)
return x, y
else:
return x, [-1.0 * (A / float(B)) * val - (C / float(B)) for val in x]
##################################
def draw_plot():
"分类结束后画图"
global train_set, w1, w2, b
plt.xlim(0, 5)
plt.ylim(0, 5)
plt.xlabel('x1')
plt.ylabel('x2')
# 实例点
for simple in train_set:
x1, x2, y = simple[0][0], simple[0][1], simple[1]
color = 'r'
if y == 1 : color = 'b'
shape = 'x'
if y == 1 : shape = '+'
plt.scatter(x1, x2, 120, color, shape)
# 直线
x = np.linspace(0, 5, 50)
x, y = y_of_line(x, w1, w2, b)
plt.plot(x, y, 'g')
##################################
def is_err_simple(_x1, _x2, _y):
"判断一个点是否误分类. x1:横坐标; x2:纵坐标; y:标注分类"
global w1, w2, b
return _y * (w1 * _x1 + w2 * _x2 + b) <= 0
##################################
def train():
"算法的核心过程,即训练模型的过程"
global train_set, w1, w2, b
go_through = False
cnter = 0
while (cnter < 1000):
err_list = [] # 存放一个步骤中所有的误分类点
for simple in train_set:
x1, x2, y = simple[0][0], simple[0][1], simple[1]
# 找到误分类点则加入 err_list
if is_err_simple(x1, x2, y):
err_list.append(simple)
# 如果分类完成则结束循环
if len(err_list) == 0:
go_through = True
break
# 随机选择一个误分类点
err_simple = random.choice(err_list)
# 调整直线参数
w1 += eta * err_simple[1] * err_simple[0][0]
w2 += eta * err_simple[1] * err_simple[0][1]
b += eta * err_simple[1]
# 计数
cnter += 1
model = "f(x1,x2) = sign(%d*x1 + %d*x2 + %d)" % (w1, w2, b);
print model
return go_through
##################################
if __name__ == "__main__":
train()
draw_plot()