@Pigmon
2017-04-05T14:36:59.000000Z
字数 2018
阅读 1362
Python
# -*- coding: utf-8 -*-
# Trains a small sigmoid network (two inputs, three hidden units, one output
# unit) on the XOR truth table, updating the weights after every randomly
# chosen sample.
import random

import numpy as np

# Training samples as (x1, x2, y): the XOR truth table.
train_set = ((0, 0, 0), (1, 1, 0), (0, 1, 1), (1, 0, 1))

# Step size applied to every weight update.
eta = 0.2

# Threshold: training stops once the mean squared error falls below it.
threshold = 0.1

# Weight table, one row per neuron:
# [
#   [w_13,  w_23,  w_b13],    x1, x2, constant 1 -> hidden neuron 3
#   [w_14,  w_24,  w_b14],    x1, x2, constant 1 -> hidden neuron 4
#   [w_1b2, w_2b2, w_b1b2],   x1, x2, constant 1 -> hidden neuron b2
#   [w_35,  w_45,  w_b25]     O3, O4, Ob2        -> output neuron 5
# ]
# The rows are mutated in place by make_data() and train().
w = [[1, 1, 1], [1, 1, 1], [1, 1, 1], [1, 1, 1]]


def sigmoid(x):
    """Return the logistic function ``1 / (1 + e**-x)`` of ``x``."""
    return 1. / (1. + np.e ** (-x))


def total_err_var():
    """Return the mean squared error of the network over ``train_set``."""
    squared_errors = [(result_of(x1, x2) - y) ** 2 for x1, x2, y in train_set]
    return sum(squared_errors) / len(train_set)


def make_data():
    """Overwrite every weight in ``w`` with a random value in [-1, 1)."""
    for row in w:
        draws = 2 * np.random.random((3, 1)) - 1
        # Slice assignment keeps the existing list objects, so any other
        # reference to ``w`` (or to one of its rows) sees the new weights.
        row[:] = draws[:, 0]


def outputs(x1, x2):
    """Run a forward pass and return every activation.

    Returns:
        ``(x1, x2, o3, o4, ob2, o5, arr)`` where ``o3``, ``o4`` and ``ob2``
        are the hidden activations, ``o5`` is the network output and ``arr``
        is a list holding the same six values in the same order.
    """
    # O3  = Sigmoid[w_13  * x1 + w_23  * x2 + w_b13  * 1]
    # O4  = Sigmoid[w_14  * x1 + w_24  * x2 + w_b14  * 1]
    # Ob2 = Sigmoid[w_1b2 * x1 + w_2b2 * x2 + w_b1b2 * 1]
    o3, o4, ob2 = (
        sigmoid(row[0] * x1 + row[1] * x2 + row[2]) for row in w[:3]
    )
    # O5 = Sigmoid[w_35 * O3 + w_45 * O4 + w_b25 * Ob2]
    o5 = sigmoid(w[3][0] * o3 + w[3][1] * o4 + w[3][2] * ob2)
    return x1, x2, o3, o4, ob2, o5, [x1, x2, o3, o4, ob2, o5]


def result_of(x1, x2):
    """Return only the network output ``o5`` for the inputs ``(x1, x2)``."""
    # Reuse the single forward-pass implementation instead of duplicating it.
    return outputs(x1, x2)[5]


def err_of(y, o3, o4, ob2, o5):
    """Compute the error term of each neuron for one sample.

    Args:
        y: target output of the sample.
        o3, o4, ob2: hidden activations from the forward pass.
        o5: network output from the forward pass.

    Returns:
        ``(e5, e3, e4, eb2, arr)`` where ``arr`` is ``[e5, e3, e4, eb2]``.
    """
    e5 = o5 * (1 - o5) * (y - o5)
    # Each hidden error is the output error scaled by the weight that links
    # the hidden neuron to the output neuron and by the sigmoid derivative.
    e3 = o3 * (1 - o3) * e5 * w[3][0]
    e4 = o4 * (1 - o4) * e5 * w[3][1]
    eb2 = ob2 * (1 - ob2) * e5 * w[3][2]
    return e5, e3, e4, eb2, [e5, e3, e4, eb2]


def train(max_iter=20000):
    """Train the network in place until the error is below ``threshold``.

    Args:
        max_iter: maximum number of single-sample updates to perform.

    Returns:
        True if the mean squared error dropped below ``threshold`` before
        ``max_iter`` updates were used up, otherwise False. The error is
        checked before each update, not after the last one.
    """
    go_through = False
    cnter = 0
    while cnter < max_iter:
        if total_err_var() < threshold:
            go_through = True
            print("Counter: %d" % cnter)
            break

        # Randomly pick one sample.
        x1, x2, y = random.choice(train_set)

        # Compute the value of every neuron.
        _, _, o3, o4, ob2, o5, _ = outputs(x1, x2)

        # Compute the error terms (using the weights before any update).
        e5, e3, e4, eb2, _ = err_of(y, o3, o4, ob2, o5)

        # Adjust the weights.
        param = [x1, x2, 1]
        errs = [e3, e4, eb2]
        outs = [o3, o4, ob2]
        for i in range(3):
            for j in range(3):
                w[i][j] += eta * errs[i] * param[j]
            w[3][i] += eta * e5 * outs[i]

        # Count the update.
        cnter += 1

    test_output()
    print(go_through)
    return go_through


def test_output():
    """Print the network output for each of the four XOR input pairs."""
    print("Result of (0, 0) = %f" % result_of(0, 0))
    print("Result of (1, 1) = %f" % result_of(1, 1))
    print("Result of (0, 1) = %f" % result_of(0, 1))
    print("Result of (1, 0) = %f" % result_of(1, 0))


if __name__ == '__main__':
    make_data()
    train()