[关闭]
@Pigmon 2017-03-23T15:05:45.000000Z 字数 1586 阅读 910

感知机示例程序

教案


  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Wed Mar 22 21:47:27 2017
  4. @author: Yuan Sheng
  5. """
  6. import random
  7. import matplotlib.pyplot as plt
  8. import numpy as np
  9. # 训练集 ((x1, x2), y)
  10. train_set = ( ((1, 1), -1), ((3, 3), 1), ((3, 4), 1) )
  11. # 准备训练出来的直线参数
  12. w1, w2, b = 0, 0, 0
  13. eta = 1 # 学习率
  14. ##################################
  15. def y_of_line(x, A, B, C):
  16. "根据直线方程求 y"
  17. if (A == 0 and B == 0):
  18. return x, [0] * 50
  19. elif (A != 0 and B == 0):
  20. x = [-1 * (C / float(A))] * 50
  21. y = np.linspace(0, 5, 50)
  22. return x, y
  23. else:
  24. return x, [-1.0 * (A / float(B)) * val - (C / float(B)) for val in x]
  25. ##################################
  26. def draw_plot():
  27. "分类结束后画图"
  28. global train_set, w1, w2, b
  29. plt.xlim(0, 5)
  30. plt.ylim(0, 5)
  31. plt.xlabel('x1')
  32. plt.ylabel('x2')
  33. # 实例点
  34. for simple in train_set:
  35. x1, x2, y = simple[0][0], simple[0][1], simple[1]
  36. color = 'r'
  37. if y == 1 : color = 'b'
  38. shape = 'x'
  39. if y == 1 : shape = '+'
  40. plt.scatter(x1, x2, 120, color, shape)
  41. # 直线
  42. x = np.linspace(0, 5, 50)
  43. x, y = y_of_line(x, w1, w2, b)
  44. plt.plot(x, y, 'g')
  45. ##################################
  46. def is_err_simple(_x1, _x2, _y):
  47. "判断一个点是否误分类. x1:横坐标; x2:纵坐标; y:标注分类"
  48. global w1, w2, b
  49. return _y * (w1 * _x1 + w2 * _x2 + b) <= 0
  50. ##################################
  51. def train():
  52. "算法的核心过程,即训练模型的过程"
  53. global train_set, w1, w2, b
  54. go_through = False
  55. cnter = 0
  56. while (cnter < 1000):
  57. err_list = [] # 存放一个步骤中所有的误分类点
  58. for simple in train_set:
  59. x1, x2, y = simple[0][0], simple[0][1], simple[1]
  60. # 找到误分类点则加入 err_list
  61. if is_err_simple(x1, x2, y):
  62. err_list.append(simple)
  63. # 如果分类完成则结束循环
  64. if len(err_list) == 0:
  65. go_through = True
  66. break
  67. # 随机选择一个误分类点
  68. err_simple = random.choice(err_list)
  69. # 调整直线参数
  70. w1 += eta * err_simple[1] * err_simple[0][0]
  71. w2 += eta * err_simple[1] * err_simple[0][1]
  72. b += eta * err_simple[1]
  73. # 计数
  74. cnter += 1
  75. model = "f(x1,x2) = sign(%d*x1 + %d*x2 + %d)" % (w1, w2, b);
  76. print model
  77. return go_through
  78. ##################################
  79. if __name__ == "__main__":
  80. train()
  81. draw_plot()
添加新批注
在作者公开此批注前,只有你和作者可见。
回复批注