[关闭]
@Pigmon 2017-03-29T14:04:19.000000Z 字数 2885 阅读 875

朴素贝叶斯分类器 示范程序 Python

教案


  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Sat Mar 25 21:32:34 2017
  4. @author: Yuan Sheng
  5. """
  6. import matplotlib.pyplot as plt
  7. import numpy as np
  8. import random
  9. ##################################
  10. def createDataSet2():
  11. "生成两块随机点"
  12. groupA = []
  13. groupB = []
  14. direction = [1, -1]
  15. for i in range(0, 200):
  16. groupA.append([50 + random.choice(direction) * random.choice(range(50)), 50 + random.choice(direction) * random.choice(range(50))])
  17. groupB.append([150 + random.choice(direction) * random.choice(range(50)), 150 + random.choice(direction) * random.choice(range(50))])
  18. return groupA, groupB
  19. ##################################
  20. def distance(A, B):
  21. "欧氏距离"
  22. return np.sqrt((A[0] - B[0])**2 + (A[1] - B[1])**2)
  23. def P_test_of_class(test, group_a, group_b):
  24. "计算 P(测试点|类别 A), P(测试点|类别 B) 可以不看"
  25. # 计算特征符合的最小距离
  26. LEAST_DIST = 20
  27. total_sample = len(group_a) + len(group_b)
  28. count_a, count_b = 0, 0
  29. # 计数 A
  30. for pt in group_a:
  31. if (distance(test, pt) <= LEAST_DIST):
  32. count_a += 1;
  33. # 计数 B
  34. for pt in group_b:
  35. if (distance(test, pt) <= LEAST_DIST):
  36. count_b += 1;
  37. # P(测试点|类别 A)
  38. p_test_A = float(count_a) / float(total_sample)
  39. # P(测试点|类别 B)
  40. p_test_B = float(count_b) / float(total_sample)
  41. return p_test_A, p_test_B
  42. def P_class(group_a, group_b):
  43. "计算P(类别 A)和P(类别 B)"
  44. p_A = float(len(group_a)) / float(len(group_a) + len(group_b))
  45. p_B = 1. - p_A
  46. return p_A, p_B
  47. ##################################
  48. def naive_bayes(test, group_a, group_b):
  49. """核心函数. 示范用,没有安全检查
  50. # 根据贝叶斯公式:
  51. # P(类别 A | 测试点) = [P(测试点 | 类别 A) * P(类别 A)] / P(测试点)
  52. # P(类别 B | 测试点) = [P(测试点 | 类别 B) * P(类别 B)] / P(测试点)
  53. # 哪个概率大,测试点即属于哪个类别。
  54. # 因为 2 个表达式<分母相同>,所以<只比较分子>就可以了!"""
  55. # P(测试点|类别 A), P(测试点|类别 B)
  56. p_test_A, p_test_B = P_test_of_class(test, group_a, group_b)
  57. # P(类别 A), P(类别 B)
  58. p_A, p_B = P_class(group_a, group_b)
  59. # # 贝叶斯公式的分子
  60. # P(测试点 |类别 A) * P(类别 A)
  61. num_A_test = p_test_A * p_A
  62. # P(测试点 |类别 B) * P(类别 B)
  63. num_B_test = p_test_B * p_B
  64. # # 返回值 0-无法分类;1-A类;2-B类。
  65. if (num_A_test == 0. and num_B_test == 0.):
  66. ret = 0
  67. else:
  68. if num_A_test >= num_B_test:
  69. ret = 1
  70. else:
  71. ret = 2
  72. return ret, num_A_test, num_B_test
  73. ##################################
  74. def draw_plot(group_a, group_b, cls, num_a, num_b):
  75. "画图表,可以不看"
  76. x, y, color, size = [], [], [], []
  77. for pt in group_a:
  78. x.append(pt[0])
  79. y.append(pt[1])
  80. color.append('b')
  81. size.append(40)
  82. for pt in group_b:
  83. x.append(pt[0])
  84. y.append(pt[1])
  85. color.append('r')
  86. size.append(40)
  87. plt.xlabel('Chocolate per-day')
  88. plt.ylabel('Ice Cream per-day')
  89. plt.scatter(x,y,size,color)
  90. x1 = [test[0]]
  91. y1 = [test[1]]
  92. text = 'None'
  93. color2 = 'k'
  94. text_num_A = 'P(A | x,y) = P(x,y | A) * P(A) / P(x,y) = ' + str(num_a) + " / P(x,y)"
  95. text_num_B = 'P(B | x,y) = P(x,y | B) * P(B) / P(x,y) = ' + str(num_b) + " / P(x,y)"
  96. if cls == 2:
  97. text = 'B'
  98. color2 = 'r'
  99. elif cls == 1:
  100. text = 'A'
  101. color2 = 'b'
  102. plt.scatter(x1, y1, [200], ['g'])
  103. plt.text(200, 0, text,
  104. verticalalignment='bottom', horizontalalignment='right',
  105. color=color2, fontsize=20)
  106. plt.text(-45, 225, text_num_A,
  107. verticalalignment='bottom', horizontalalignment='left',
  108. color='k', fontsize=12)
  109. plt.text(-45, 205, text_num_B,
  110. verticalalignment='bottom', horizontalalignment='left',
  111. color='k', fontsize=12)
  112. plt.show()
  113. ##################################
  114. if __name__ == '__main__':
  115. groupA, groupB = createDataSet2()
  116. # 测试点, 可以自己修改试试
  117. test = [100., 100.]
  118. cls, num_a, num_b = naive_bayes(test, groupA, groupB)
  119. draw_plot(groupA, groupB, cls, num_a, num_b)
  120. print cls
添加新批注
在作者公开此批注前,只有你和作者可见。
回复批注