@Pigmon
2017-03-23T07:05:45.000000Z
字数 1586
阅读 1183
教案
# -*- coding: utf-8 -*-"""Created on Wed Mar 22 21:47:27 2017@author: Yuan Sheng"""import randomimport matplotlib.pyplot as pltimport numpy as np# 训练集 ((x1, x2), y)train_set = ( ((1, 1), -1), ((3, 3), 1), ((3, 4), 1) )# 准备训练出来的直线参数w1, w2, b = 0, 0, 0eta = 1 # 学习率##################################def y_of_line(x, A, B, C):"根据直线方程求 y"if (A == 0 and B == 0):return x, [0] * 50elif (A != 0 and B == 0):x = [-1 * (C / float(A))] * 50y = np.linspace(0, 5, 50)return x, yelse:return x, [-1.0 * (A / float(B)) * val - (C / float(B)) for val in x]##################################def draw_plot():"分类结束后画图"global train_set, w1, w2, bplt.xlim(0, 5)plt.ylim(0, 5)plt.xlabel('x1')plt.ylabel('x2')# 实例点for simple in train_set:x1, x2, y = simple[0][0], simple[0][1], simple[1]color = 'r'if y == 1 : color = 'b'shape = 'x'if y == 1 : shape = '+'plt.scatter(x1, x2, 120, color, shape)# 直线x = np.linspace(0, 5, 50)x, y = y_of_line(x, w1, w2, b)plt.plot(x, y, 'g')##################################def is_err_simple(_x1, _x2, _y):"判断一个点是否误分类. x1:横坐标; x2:纵坐标; y:标注分类"global w1, w2, breturn _y * (w1 * _x1 + w2 * _x2 + b) <= 0##################################def train():"算法的核心过程,即训练模型的过程"global train_set, w1, w2, bgo_through = Falsecnter = 0while (cnter < 1000):err_list = [] # 存放一个步骤中所有的误分类点for simple in train_set:x1, x2, y = simple[0][0], simple[0][1], simple[1]# 找到误分类点则加入 err_listif is_err_simple(x1, x2, y):err_list.append(simple)# 如果分类完成则结束循环if len(err_list) == 0:go_through = Truebreak# 随机选择一个误分类点err_simple = random.choice(err_list)# 调整直线参数w1 += eta * err_simple[1] * err_simple[0][0]w2 += eta * err_simple[1] * err_simple[0][1]b += eta * err_simple[1]# 计数cnter += 1model = "f(x1,x2) = sign(%d*x1 + %d*x2 + %d)" % (w1, w2, b);print modelreturn go_through##################################if __name__ == "__main__":train()draw_plot()
