@Pigmon
2017-03-29T06:04:19.000000Z
字数 2885
阅读 1164
教案
# -*- coding: utf-8 -*-
"""Toy naive-Bayes demo: classify a 2-D test point as class A or class B.

Created on Sat Mar 25 21:32:34 2017

@author: Yuan Sheng
"""
import random

import numpy as np

# Default neighbourhood radius: a training point "matches" the test point
# when it lies within this Euclidean distance of it.
LEAST_DIST = 20


##################################
def createDataSet2(n_points=200, spread=50):
    """Generate two clusters of random integer points.

    Group A is scattered around (50, 50) and group B around (150, 150).
    Each coordinate is its centre plus a random sign times a random integer
    in ``[0, spread)``.

    :param n_points: number of points generated per group (default 200).
    :param spread: exclusive upper bound of the per-coordinate offset
        (default 50).
    :returns: ``(groupA, groupB)``, two lists of ``[x, y]`` points.
    """
    direction = [1, -1]

    def _jitter(center):
        # centre +/- a random integer in [0, spread)
        return center + random.choice(direction) * random.randrange(spread)

    groupA, groupB = [], []
    for _ in range(n_points):
        groupA.append([_jitter(50), _jitter(50)])
        groupB.append([_jitter(150), _jitter(150)])
    return groupA, groupB


##################################
def distance(A, B):
    """Return the Euclidean distance between the 2-D points ``A`` and ``B``."""
    return np.hypot(A[0] - B[0], A[1] - B[1])


def _count_within(test, group, radius):
    """Count the points of ``group`` lying within ``radius`` of ``test``."""
    return sum(1 for pt in group if distance(test, pt) <= radius)


def _ratio(part, whole):
    """Return ``part / whole`` as a float, or 0.0 when ``whole`` is 0."""
    return float(part) / whole if whole else 0.0


def P_test_of_class(test, group_a, group_b, radius=LEAST_DIST):
    """Estimate the likelihoods P(test | A) and P(test | B).

    A class's likelihood is the fraction of *that class's own* samples lying
    within ``radius`` of ``test``.  Dividing by the class size (rather than
    the combined sample count) is what makes this a conditional probability;
    the class prior is applied separately in ``naive_bayes``.

    :param test: the ``[x, y]`` point being classified.
    :param group_a: training points of class A.
    :param group_b: training points of class B.
    :param radius: neighbourhood radius (default ``LEAST_DIST``).
    :returns: ``(p_test_A, p_test_B)``; an empty group yields 0.0.
    """
    count_a = _count_within(test, group_a, radius)
    count_b = _count_within(test, group_b, radius)
    return _ratio(count_a, len(group_a)), _ratio(count_b, len(group_b))


def P_class(group_a, group_b):
    """Compute the class priors P(A) and P(B) from the group sizes.

    :returns: ``(p_A, p_B)``; ``(0.0, 0.0)`` when both groups are empty,
        so ``naive_bayes`` reports "cannot classify" instead of dividing
        by zero.
    """
    total = len(group_a) + len(group_b)
    if not total:
        return 0.0, 0.0
    p_A = float(len(group_a)) / total
    p_B = 1. - p_A
    return p_A, p_B


##################################
def naive_bayes(test, group_a, group_b, radius=LEAST_DIST):
    """Classify ``test`` as class A or class B with Bayes' rule.

    Demo code: inputs are not validated.

    By Bayes' rule::

        P(A | test) = P(test | A) * P(A) / P(test)
        P(B | test) = P(test | B) * P(B) / P(test)

    The point belongs to whichever class has the larger posterior.  Both
    expressions share the denominator P(test), so only the numerators are
    compared.

    :param test: the ``[x, y]`` point to classify.
    :param group_a: training points of class A.
    :param group_b: training points of class B.
    :param radius: neighbourhood radius for the likelihood estimate
        (default ``LEAST_DIST``).
    :returns: ``(cls, num_A, num_B)``.  ``cls`` is 0 when both numerators
        are zero (no training point within ``radius``; cannot classify),
        1 for class A, 2 for class B; ties go to class A.  ``num_A`` and
        ``num_B`` are the Bayes numerators P(test | class) * P(class).
    """
    p_test_A, p_test_B = P_test_of_class(test, group_a, group_b, radius)
    p_A, p_B = P_class(group_a, group_b)

    # Numerators of Bayes' rule.
    num_A_test = p_test_A * p_A
    num_B_test = p_test_B * p_B

    if num_A_test == 0. and num_B_test == 0.:
        ret = 0
    elif num_A_test >= num_B_test:
        ret = 1
    else:
        ret = 2
    return ret, num_A_test, num_B_test


##################################
def draw_plot(group_a, group_b, cls, num_a, num_b, test=None):
    """Plot both groups, the test point and the classification result.

    Group A is drawn in blue, group B in red, and the test point as a large
    green marker.  The predicted class ('A', 'B', or 'None' when ``cls`` is
    neither 1 nor 2) is written in the class colour at the bottom right, and
    the two Bayes numerators at the top left.

    :param test: the ``[x, y]`` point that was classified; when ``None``,
        no test marker is drawn.
    """
    # Imported lazily so the classifier itself works without matplotlib.
    import matplotlib.pyplot as plt

    points = list(group_a) + list(group_b)
    x = [pt[0] for pt in points]
    y = [pt[1] for pt in points]
    color = ['b'] * len(group_a) + ['r'] * len(group_b)
    size = [40] * len(points)

    plt.xlabel('Chocolate per-day')
    plt.ylabel('Ice Cream per-day')
    plt.scatter(x, y, size, color)

    text_num_A = ('P(A | x,y) = P(x,y | A) * P(A) / P(x,y) = '
                  + str(num_a) + " / P(x,y)")
    text_num_B = ('P(B | x,y) = P(x,y | B) * P(B) / P(x,y) = '
                  + str(num_b) + " / P(x,y)")
    # Label and colour of the predicted class; black 'None' otherwise.
    text, color2 = {1: ('A', 'b'), 2: ('B', 'r')}.get(cls, ('None', 'k'))

    if test is not None:
        plt.scatter([test[0]], [test[1]], [200], ['g'])
    plt.text(200, 0, text,
             verticalalignment='bottom', horizontalalignment='right',
             color=color2, fontsize=20)
    plt.text(-45, 225, text_num_A,
             verticalalignment='bottom', horizontalalignment='left',
             color='k', fontsize=12)
    plt.text(-45, 205, text_num_B,
             verticalalignment='bottom', horizontalalignment='left',
             color='k', fontsize=12)
    plt.show()


##################################
if __name__ == '__main__':
    groupA, groupB = createDataSet2()
    # Test point -- feel free to change it and experiment.
    test = [100., 100.]
    cls, num_a, num_b = naive_bayes(test, groupA, groupB)
    draw_plot(groupA, groupB, cls, num_a, num_b, test)
    print(cls)