準備

Googleドライブのマウント

In [1]:
# Mount the user's Google Drive into the Colab filesystem (interactive auth prompt).
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive

sys.pathの設定

以下では,Googleドライブのマイドライブ直下にDNN_codeフォルダを置くことを仮定しています.必要に応じて,パスを変更してください.

In [2]:
import sys
# Make the lesson's code folder importable (provides the `common` package used below);
# adjust this path if the folder lives elsewhere in your Drive.
sys.path.append('/content/drive/My Drive/DNN_code_colab_lesson_3_4')

simple RNN after

バイナリ加算

In [6]:
import numpy as np
from common import functions
import matplotlib.pyplot as plt


def d_tanh(x):
    """Derivative of tanh(x), i.e. sech^2(x) = 1 / cosh^2(x)."""
    cosh_squared = np.cosh(x) ** 2
    return 1 / cosh_squared

# --- Data preparation ---
# Number of binary digits per operand
binary_dim = 8
# Maximum representable value + 1
largest_number = pow(2, binary_dim)
# Binary (MSB-first) representations of 0 .. largest_number-1, one row per number
binary = np.unpackbits(np.array([range(largest_number)],dtype=np.uint8).T,axis=1)

input_layer_size = 2
hidden_layer_size = 16
output_layer_size = 1

weight_init_std = 1
learning_rate = 0.1

iters_num = 10000
plot_interval = 100

# Weight initialization (biases omitted for simplicity)
W_in = weight_init_std * np.random.randn(input_layer_size, hidden_layer_size)
W_out = weight_init_std * np.random.randn(hidden_layer_size, output_layer_size)
W = weight_init_std * np.random.randn(hidden_layer_size, hidden_layer_size)
# Xavier initialization (alternative — uncomment to try)
# W_in = np.random.randn(input_layer_size, hidden_layer_size) / (np.sqrt(input_layer_size))
# W_out = np.random.randn(hidden_layer_size, output_layer_size) / (np.sqrt(hidden_layer_size))
# W = np.random.randn(hidden_layer_size, hidden_layer_size) / (np.sqrt(hidden_layer_size))
# He initialization (alternative — uncomment to try)
# W_in = np.random.randn(input_layer_size, hidden_layer_size) / (np.sqrt(input_layer_size)) * np.sqrt(2)
# W_out = np.random.randn(hidden_layer_size, output_layer_size) / (np.sqrt(hidden_layer_size)) * np.sqrt(2)
# W = np.random.randn(hidden_layer_size, hidden_layer_size) / (np.sqrt(hidden_layer_size)) * np.sqrt(2)


# Gradient accumulators (same shapes as the corresponding weights)
W_in_grad = np.zeros_like(W_in)
W_out_grad = np.zeros_like(W_out)
W_grad = np.zeros_like(W)

# Per-time-step buffers; column 0 of u/z holds the initial (zero) hidden state
u = np.zeros((hidden_layer_size, binary_dim + 1))
z = np.zeros((hidden_layer_size, binary_dim + 1))
y = np.zeros((output_layer_size, binary_dim))

delta_out = np.zeros((output_layer_size, binary_dim))
delta = np.zeros((hidden_layer_size, binary_dim + 1))

all_losses = []

# --- Training loop: learn 8-bit binary addition with a simple RNN ---
for i in range(iters_num):
    
    # Draw operands A and B (target: a + b = d)
    a_int = np.random.randint(largest_number/2)  # < 128, so the sum fits in 8 bits
    a_bin = binary[a_int] # binary encoding
    b_int = np.random.randint(largest_number/2)  # < 128, so the sum fits in 8 bits
    b_bin = binary[b_int] # binary encoding
    
    # Ground-truth sum
    d_int = a_int + b_int
    d_bin = binary[d_int]
    
    # Predicted output bits
    out_bin = np.zeros_like(d_bin)
    
    # Accumulated loss over the whole sequence
    all_loss = 0    
    
    # Forward pass through time
    for t in range(binary_dim):
        # Input bits at time t, taken from the least-significant end (like column addition)
        X = np.array([a_bin[ - t - 1], b_bin[ - t - 1]]).reshape(1, -1)

        # Target bit at time t
        dd = np.array([d_bin[binary_dim - t - 1]])
        
        # Hidden pre-activation = input projection + recurrent projection of previous state
        u[:,t+1] = np.dot(X, W_in) + np.dot(z[:,t].reshape(1, -1), W)
        z[:,t+1] = functions.sigmoid(u[:,t+1])
#         z[:,t+1] = functions.relu(u[:,t+1])
#         z[:,t+1] = np.tanh(u[:,t+1])    
        y[:,t] = functions.sigmoid(np.dot(z[:,t+1].reshape(1, -1), W_out))


        # Loss at this time step
        loss = functions.mean_squared_error(dd, y[:,t])
        
        # NOTE(review): d_sigmoid is applied to the sigmoid OUTPUT y, not the
        # pre-activation — confirm against functions.d_sigmoid's expected argument.
        delta_out[:,t] = functions.d_mean_squared_error(dd, y[:,t]) * functions.d_sigmoid(y[:,t])        
        
        all_loss += loss

        out_bin[binary_dim - t - 1] = np.round(y[:,t])
    
    # Backpropagation through time (last time step back to the first)
    for t in range(binary_dim)[::-1]:
        X = np.array([a_bin[-t-1],b_bin[-t-1]]).reshape(1, -1)        

        # Hidden delta: error flowing back from the next time step plus error from the output layer
        delta[:,t] = (np.dot(delta[:,t+1].T, W.T) + np.dot(delta_out[:,t].T, W_out.T)) * functions.d_sigmoid(u[:,t+1])
#         delta[:,t] = (np.dot(delta[:,t+1].T, W.T) + np.dot(delta_out[:,t].T, W_out.T)) * functions.d_relu(u[:,t+1])
#         delta[:,t] = (np.dot(delta[:,t+1].T, W.T) + np.dot(delta_out[:,t].T, W_out.T)) * d_tanh(u[:,t+1])    

        # Accumulate gradients over the sequence
        W_out_grad += np.dot(z[:,t+1].reshape(-1,1), delta_out[:,t].reshape(-1,1))
        W_grad += np.dot(z[:,t].reshape(-1,1), delta[:,t].reshape(1,-1))
        W_in_grad += np.dot(X.T, delta[:,t].reshape(1,-1))
    
    # Apply gradients (plain SGD, one update per training example)
    W_in -= learning_rate * W_in_grad
    W_out -= learning_rate * W_out_grad
    W -= learning_rate * W_grad
    
    # Reset gradient accumulators for the next example
    W_in_grad *= 0
    W_out_grad *= 0
    W_grad *= 0
    

    if(i % plot_interval == 0):
        all_losses.append(all_loss)        
        print("iters:" + str(i))
        print("Loss:" + str(all_loss))
        print("Pred:" + str(out_bin))
        print("True:" + str(d_bin))
        # Convert the predicted bit string (MSB first) back to an integer
        out_int = 0
        for index,x in enumerate(reversed(out_bin)):
            out_int += x * pow(2, index)
        print(str(a_int) + " + " + str(b_int) + " = " + str(out_int))
        print("------------")

# Plot the loss sampled every plot_interval iterations
lists = range(0, iters_num, plot_interval)
plt.plot(lists, all_losses, label="loss")
plt.show()
iters:0
Loss:0.8654591383439132
Pred:[1 1 1 1 1 1 1 1]
True:[1 1 1 0 1 1 1 0]
116 + 122 = 255
------------
iters:100
Loss:0.9740283924195009
Pred:[0 0 0 0 0 0 0 0]
True:[0 1 1 0 0 0 1 1]
21 + 78 = 0
------------
iters:200
Loss:1.058733339097931
Pred:[1 1 1 1 1 1 1 1]
True:[1 0 0 1 0 0 1 0]
71 + 75 = 255
------------
iters:300
Loss:0.9899234430190385
Pred:[1 0 0 0 0 0 0 1]
True:[1 0 1 0 1 0 1 0]
66 + 104 = 129
------------
iters:400
Loss:1.1032728212307035
Pred:[1 1 1 1 1 0 1 1]
True:[0 1 1 0 1 1 0 0]
76 + 32 = 251
------------
iters:500
Loss:1.0672994463059597
Pred:[1 1 1 1 1 1 1 1]
True:[0 1 0 1 1 1 0 0]
79 + 13 = 255
------------
iters:600
Loss:0.911258109277485
Pred:[1 1 1 0 0 0 1 1]
True:[0 1 1 1 0 0 0 0]
80 + 32 = 227
------------
iters:700
Loss:0.809581189193525
Pred:[1 1 1 1 1 1 1 1]
True:[1 0 1 1 1 1 1 0]
75 + 115 = 255
------------
iters:800
Loss:1.1700139118042676
Pred:[0 0 0 0 0 0 0 0]
True:[1 0 1 1 1 0 1 1]
89 + 98 = 0
------------
iters:900
Loss:1.2867065945910034
Pred:[0 0 0 0 0 0 0 0]
True:[1 0 1 1 1 1 0 1]
67 + 122 = 0
------------
iters:1000
Loss:1.1254587032348589
Pred:[0 0 1 1 1 0 0 1]
True:[0 1 0 0 0 1 0 0]
44 + 24 = 57
------------
iters:1100
Loss:1.0502458695271117
Pred:[0 0 0 0 0 0 1 1]
True:[0 1 1 1 0 1 0 1]
75 + 42 = 3
------------
iters:1200
Loss:0.7636776745979322
Pred:[0 0 0 0 0 0 0 0]
True:[0 0 0 0 0 1 0 0]
2 + 2 = 0
------------
iters:1300
Loss:0.8523588749209716
Pred:[1 1 1 1 1 1 1 1]
True:[0 1 0 1 1 0 1 1]
86 + 5 = 255
------------
iters:1400
Loss:0.9360499826914965
Pred:[1 1 1 0 0 0 1 0]
True:[1 0 0 0 0 1 1 0]
37 + 97 = 226
------------
iters:1500
Loss:1.135834083671794
Pred:[0 0 0 0 0 0 0 0]
True:[0 1 1 0 1 0 1 1]
97 + 10 = 0
------------
iters:1600
Loss:0.8947185256742668
Pred:[1 0 0 0 0 0 0 0]
True:[1 1 1 0 1 0 0 0]
124 + 108 = 128
------------
iters:1700
Loss:0.8705583046937302
Pred:[1 1 0 0 0 0 0 1]
True:[1 0 1 1 0 0 0 1]
81 + 96 = 193
------------
iters:1800
Loss:0.6954041685935325
Pred:[0 0 1 0 0 1 1 1]
True:[0 0 1 0 0 1 1 0]
35 + 3 = 39
------------
iters:1900
Loss:0.894221111509363
Pred:[0 1 0 0 0 1 0 1]
True:[0 1 0 1 0 0 0 1]
7 + 74 = 69
------------
iters:2000
Loss:0.730904677309192
Pred:[1 1 1 1 1 1 1 1]
True:[0 1 0 1 1 1 1 1]
28 + 67 = 255
------------
iters:2100
Loss:0.8188764426891197
Pred:[1 1 1 1 1 0 1 0]
True:[1 1 0 1 1 0 1 0]
106 + 112 = 250
------------
iters:2200
Loss:0.8738014848728315
Pred:[0 0 0 0 1 0 0 1]
True:[1 0 0 0 1 0 1 1]
122 + 17 = 9
------------
iters:2300
Loss:0.8747532472353428
Pred:[1 0 1 1 1 0 1 1]
True:[1 0 1 0 1 1 1 1]
54 + 121 = 187
------------
iters:2400
Loss:0.7634431790719484
Pred:[0 1 0 0 1 0 1 1]
True:[0 1 0 1 0 0 1 1]
78 + 5 = 75
------------
iters:2500
Loss:0.5822355466719209
Pred:[0 1 1 1 1 0 1 1]
True:[0 1 0 0 1 0 1 1]
26 + 49 = 123
------------
iters:2600
Loss:0.7407450343224602
Pred:[0 1 0 1 1 0 1 1]
True:[1 0 0 1 1 0 1 1]
126 + 29 = 91
------------
iters:2700
Loss:0.7111884624063335
Pred:[1 0 0 0 0 0 1 1]
True:[1 1 0 0 0 0 1 1]
72 + 123 = 131
------------
iters:2800
Loss:0.6302858246036273
Pred:[0 1 0 0 0 1 0 0]
True:[0 1 0 0 1 1 1 0]
25 + 53 = 68
------------
iters:2900
Loss:0.5894784545194788
Pred:[0 0 0 0 1 1 1 0]
True:[1 0 0 0 1 1 1 0]
96 + 46 = 14
------------
iters:3000
Loss:0.6209309788663182
Pred:[0 1 1 1 0 1 0 0]
True:[0 1 1 1 1 0 0 0]
5 + 115 = 116
------------
iters:3100
Loss:0.3960680157481153
Pred:[1 0 1 1 0 1 1 0]
True:[1 0 1 1 0 1 1 0]
96 + 86 = 182
------------
iters:3200
Loss:0.3198058207110639
Pred:[0 1 1 0 0 1 1 0]
True:[0 1 1 0 0 1 1 0]
86 + 16 = 102
------------
iters:3300
Loss:0.43527064883193256
Pred:[1 0 1 1 1 1 1 1]
True:[1 0 0 1 1 1 1 1]
86 + 73 = 191
------------
iters:3400
Loss:0.30995843748446494
Pred:[1 0 1 0 0 1 0 0]
True:[1 0 1 0 0 1 0 0]
76 + 88 = 164
------------
iters:3500
Loss:0.20649628783395968
Pred:[1 0 0 1 1 1 0 0]
True:[1 0 0 1 1 1 0 0]
49 + 107 = 156
------------
iters:3600
Loss:0.1546370842260213
Pred:[0 1 0 1 0 1 0 1]
True:[0 1 0 1 0 1 0 1]
16 + 69 = 85
------------
iters:3700
Loss:0.15683505109765608
Pred:[1 0 1 0 1 0 0 0]
True:[1 0 1 0 1 0 0 0]
120 + 48 = 168
------------
iters:3800
Loss:0.15932534270446014
Pred:[1 1 0 0 0 1 0 0]
True:[1 1 0 0 0 1 0 0]
76 + 120 = 196
------------
iters:3900
Loss:0.14500035042904272
Pred:[1 0 0 0 1 0 0 0]
True:[1 0 0 0 1 0 0 0]
22 + 114 = 136
------------
iters:4000
Loss:0.09275867042961376
Pred:[1 0 1 1 1 0 0 1]
True:[1 0 1 1 1 0 0 1]
85 + 100 = 185
------------
iters:4100
Loss:0.09633459836783853
Pred:[1 0 0 1 1 0 0 0]
True:[1 0 0 1 1 0 0 0]
79 + 73 = 152
------------
iters:4200
Loss:0.05311579031378735
Pred:[1 0 0 0 1 0 1 1]
True:[1 0 0 0 1 0 1 1]
24 + 115 = 139
------------
iters:4300
Loss:0.06747033867560986
Pred:[0 1 1 1 0 1 0 0]
True:[0 1 1 1 0 1 0 0]
21 + 95 = 116
------------
iters:4400
Loss:0.08441714840631113
Pred:[1 0 1 0 1 0 1 0]
True:[1 0 1 0 1 0 1 0]
110 + 60 = 170
------------
iters:4500
Loss:0.07583494379675933
Pred:[0 0 1 1 1 0 0 0]
True:[0 0 1 1 1 0 0 0]
46 + 10 = 56
------------
iters:4600
Loss:0.02618890113658895
Pred:[1 1 1 0 1 1 0 0]
True:[1 1 1 0 1 1 0 0]
125 + 111 = 236
------------
iters:4700
Loss:0.02942986696341156
Pred:[1 0 0 1 1 1 1 1]
True:[1 0 0 1 1 1 1 1]
45 + 114 = 159
------------
iters:4800
Loss:0.029644026353909506
Pred:[0 0 0 1 0 1 1 0]
True:[0 0 0 1 0 1 1 0]
4 + 18 = 22
------------
iters:4900
Loss:0.01569710300694916
Pred:[0 0 1 1 0 1 1 1]
True:[0 0 1 1 0 1 1 1]
22 + 33 = 55
------------
iters:5000
Loss:0.010830219190800709
Pred:[0 1 1 0 1 1 0 0]
True:[0 1 1 0 1 1 0 0]
5 + 103 = 108
------------
iters:5100
Loss:0.0192386122909721
Pred:[0 0 1 0 1 1 1 0]
True:[0 0 1 0 1 1 1 0]
34 + 12 = 46
------------
iters:5200
Loss:0.021072926713136035
Pred:[1 0 1 1 1 1 0 1]
True:[1 0 1 1 1 1 0 1]
115 + 74 = 189
------------
iters:5300
Loss:0.021038507262215598
Pred:[1 0 1 0 1 0 1 0]
True:[1 0 1 0 1 0 1 0]
88 + 82 = 170
------------
iters:5400
Loss:0.021271060459267455
Pred:[0 1 0 0 1 0 0 1]
True:[0 1 0 0 1 0 0 1]
22 + 51 = 73
------------
iters:5500
Loss:0.01451537190833715
Pred:[0 1 0 1 0 1 0 1]
True:[0 1 0 1 0 1 0 1]
29 + 56 = 85
------------
iters:5600
Loss:0.01576904297650988
Pred:[1 0 1 1 1 1 0 0]
True:[1 0 1 1 1 1 0 0]
80 + 108 = 188
------------
iters:5700
Loss:0.012353839432875473
Pred:[1 1 0 1 1 1 0 1]
True:[1 1 0 1 1 1 0 1]
125 + 96 = 221
------------
iters:5800
Loss:0.01474226297394392
Pred:[0 1 1 1 1 1 1 0]
True:[0 1 1 1 1 1 1 0]
24 + 102 = 126
------------
iters:5900
Loss:0.009712824610203305
Pred:[0 1 1 1 0 0 0 1]
True:[0 1 1 1 0 0 0 1]
109 + 4 = 113
------------
iters:6000
Loss:0.016250617341443813
Pred:[1 1 0 1 0 0 1 0]
True:[1 1 0 1 0 0 1 0]
90 + 120 = 210
------------
iters:6100
Loss:0.011023349809927266
Pred:[1 0 0 0 0 1 1 1]
True:[1 0 0 0 0 1 1 1]
95 + 40 = 135
------------
iters:6200
Loss:0.002214579654360812
Pred:[0 1 0 0 0 0 0 0]
True:[0 1 0 0 0 0 0 0]
9 + 55 = 64
------------
iters:6300
Loss:0.006994126395314979
Pred:[0 1 1 1 1 1 0 1]
True:[0 1 1 1 1 1 0 1]
121 + 4 = 125
------------
iters:6400
Loss:0.011656535565668253
Pred:[1 1 0 1 0 0 0 0]
True:[1 1 0 1 0 0 0 0]
94 + 114 = 208
------------
iters:6500
Loss:0.00835822932431572
Pred:[0 1 0 1 1 1 0 1]
True:[0 1 0 1 1 1 0 1]
59 + 34 = 93
------------
iters:6600
Loss:0.004004940564498871
Pred:[1 0 1 1 1 0 1 0]
True:[1 0 1 1 1 0 1 0]
125 + 61 = 186
------------
iters:6700
Loss:0.004251408078137221
Pred:[1 1 0 0 0 0 1 0]
True:[1 1 0 0 0 0 1 0]
119 + 75 = 194
------------
iters:6800
Loss:0.014842453810526235
Pred:[1 0 0 1 1 1 0 0]
True:[1 0 0 1 1 1 0 0]
46 + 110 = 156
------------
iters:6900
Loss:0.002739152861783825
Pred:[1 0 1 1 0 1 1 0]
True:[1 0 1 1 0 1 1 0]
105 + 77 = 182
------------
iters:7000
Loss:0.005060309392321419
Pred:[1 1 1 0 0 0 1 1]
True:[1 1 1 0 0 0 1 1]
102 + 125 = 227
------------
iters:7100
Loss:0.007706847794406494
Pred:[1 0 0 1 0 0 0 0]
True:[1 0 0 1 0 0 0 0]
86 + 58 = 144
------------
iters:7200
Loss:0.007848196809802156
Pred:[1 0 0 1 0 1 0 0]
True:[1 0 0 1 0 1 0 0]
44 + 104 = 148
------------
iters:7300
Loss:0.003306521612182141
Pred:[0 1 0 0 1 1 0 1]
True:[0 1 0 0 1 1 0 1]
0 + 77 = 77
------------
iters:7400
Loss:0.003889550345273361
Pred:[1 0 0 1 0 1 0 1]
True:[1 0 0 1 0 1 0 1]
36 + 113 = 149
------------
iters:7500
Loss:0.004981213485463351
Pred:[1 0 0 1 1 0 1 1]
True:[1 0 0 1 1 0 1 1]
121 + 34 = 155
------------
iters:7600
Loss:0.006962505909565046
Pred:[1 0 0 1 0 0 1 0]
True:[1 0 0 1 0 0 1 0]
76 + 70 = 146
------------
iters:7700
Loss:0.0018352521445289913
Pred:[0 1 0 1 0 0 1 0]
True:[0 1 0 1 0 0 1 0]
45 + 37 = 82
------------
iters:7800
Loss:0.0016638956038422
Pred:[1 0 0 0 1 1 1 0]
True:[1 0 0 0 1 1 1 0]
113 + 29 = 142
------------
iters:7900
Loss:0.003543873412859076
Pred:[1 0 0 1 1 1 0 1]
True:[1 0 0 1 1 1 0 1]
60 + 97 = 157
------------
iters:8000
Loss:0.0024374799224191157
Pred:[0 0 1 1 1 1 0 1]
True:[0 0 1 1 1 1 0 1]
12 + 49 = 61
------------
iters:8100
Loss:0.002533502818801829
Pred:[1 0 0 1 0 0 1 1]
True:[1 0 0 1 0 0 1 1]
48 + 99 = 147
------------
iters:8200
Loss:0.003189828349739394
Pred:[1 0 1 0 1 1 0 1]
True:[1 0 1 0 1 1 0 1]
110 + 63 = 173
------------
iters:8300
Loss:0.0010226713931731077
Pred:[0 0 1 1 1 1 1 0]
True:[0 0 1 1 1 1 1 0]
25 + 37 = 62
------------
iters:8400
Loss:0.0060408723951314355
Pred:[0 1 1 0 1 1 1 0]
True:[0 1 1 0 1 1 1 0]
6 + 104 = 110
------------
iters:8500
Loss:0.0031443232561134227
Pred:[1 0 1 0 0 1 1 1]
True:[1 0 1 0 0 1 1 1]
67 + 100 = 167
------------
iters:8600
Loss:0.004723847270710192
Pred:[0 1 0 1 0 0 0 0]
True:[0 1 0 1 0 0 0 0]
16 + 64 = 80
------------
iters:8700
Loss:0.001483870265294737
Pred:[1 0 0 1 0 1 0 0]
True:[1 0 0 1 0 1 0 0]
123 + 25 = 148
------------
iters:8800
Loss:0.004562689881407613
Pred:[0 1 1 0 0 0 0 0]
True:[0 1 1 0 0 0 0 0]
92 + 4 = 96
------------
iters:8900
Loss:0.004607090758523282
Pred:[0 0 0 1 0 1 0 0]
True:[0 0 0 1 0 1 0 0]
2 + 18 = 20
------------
iters:9000
Loss:0.0021486017522436333
Pred:[0 1 1 0 1 0 1 1]
True:[0 1 1 0 1 0 1 1]
14 + 93 = 107
------------
iters:9100
Loss:0.0004668637065865403
Pred:[1 0 0 1 0 0 0 0]
True:[1 0 0 1 0 0 0 0]
21 + 123 = 144
------------
iters:9200
Loss:0.0034633332877822822
Pred:[0 1 0 0 0 0 1 0]
True:[0 1 0 0 0 0 1 0]
64 + 2 = 66
------------
iters:9300
Loss:0.0016145269926962013
Pred:[1 0 0 1 1 1 0 1]
True:[1 0 0 1 1 1 0 1]
80 + 77 = 157
------------
iters:9400
Loss:0.0006806393853258948
Pred:[0 1 0 0 1 0 1 0]
True:[0 1 0 0 1 0 1 0]
69 + 5 = 74
------------
iters:9500
Loss:0.0004910266676769213
Pred:[0 1 1 1 1 1 1 0]
True:[0 1 1 1 1 1 1 0]
99 + 27 = 126
------------
iters:9600
Loss:0.0005325292824086261
Pred:[0 1 0 1 1 1 1 0]
True:[0 1 0 1 1 1 1 0]
43 + 51 = 94
------------
iters:9700
Loss:0.0030860797349639806
Pred:[1 0 0 1 0 0 0 0]
True:[1 0 0 1 0 0 0 0]
118 + 26 = 144
------------
iters:9800
Loss:0.0016809075732658351
Pred:[1 0 1 1 1 1 1 1]
True:[1 0 1 1 1 1 1 1]
90 + 101 = 191
------------
iters:9900
Loss:0.0006005369971381359
Pred:[1 0 1 1 1 0 0 0]
True:[1 0 1 1 1 0 0 0]
125 + 59 = 184
------------

[try] weight_init_stdやlearning_rate, hidden_layer_sizeを変更してみよう

[try] 重みの初期化方法を変更してみよう

Xavier, He

[try] 中間層の活性化関数を変更してみよう

ReLU(勾配爆発を確認しよう)
tanh(numpyにtanhが用意されている。導関数をd_tanhとして作成しよう)


In [12]:
import numpy as np
from common import functions
import matplotlib.pyplot as plt


def d_tanh(x):
    """Derivative of tanh(x), i.e. sech^2(x) = 1 / cosh^2(x)."""
    cosh_squared = np.cosh(x) ** 2
    return 1 / cosh_squared

# --- Data preparation ---
# Number of binary digits per operand
binary_dim = 8
# Maximum representable value + 1
largest_number = pow(2, binary_dim)
# Binary (MSB-first) representations of 0 .. largest_number-1, one row per number
binary = np.unpackbits(np.array([range(largest_number)],dtype=np.uint8).T,axis=1)

input_layer_size = 2
hidden_layer_size = 16
output_layer_size = 1

weight_init_std = 1
learning_rate = 0.1

iters_num = 10000
plot_interval = 100

# Weight initialization (biases omitted for simplicity)
#W_in = weight_init_std * np.random.randn(input_layer_size, hidden_layer_size)
#W_out = weight_init_std * np.random.randn(hidden_layer_size, output_layer_size)
#W = weight_init_std * np.random.randn(hidden_layer_size, hidden_layer_size)
# Xavier initialization (alternative — uncomment to try)
#W_in = np.random.randn(input_layer_size, hidden_layer_size) / (np.sqrt(input_layer_size))
#W_out = np.random.randn(hidden_layer_size, output_layer_size) / (np.sqrt(hidden_layer_size))
#W = np.random.randn(hidden_layer_size, hidden_layer_size) / (np.sqrt(hidden_layer_size))
# He initialization (active variant in this cell; pairs here with the tanh activation below)
W_in = np.random.randn(input_layer_size, hidden_layer_size) / (np.sqrt(input_layer_size)) * np.sqrt(2)
W_out = np.random.randn(hidden_layer_size, output_layer_size) / (np.sqrt(hidden_layer_size)) * np.sqrt(2)
W = np.random.randn(hidden_layer_size, hidden_layer_size) / (np.sqrt(hidden_layer_size)) * np.sqrt(2)


# Gradient accumulators (same shapes as the corresponding weights)
W_in_grad = np.zeros_like(W_in)
W_out_grad = np.zeros_like(W_out)
W_grad = np.zeros_like(W)

# Per-time-step buffers; column 0 of u/z holds the initial (zero) hidden state
u = np.zeros((hidden_layer_size, binary_dim + 1))
z = np.zeros((hidden_layer_size, binary_dim + 1))
y = np.zeros((output_layer_size, binary_dim))

delta_out = np.zeros((output_layer_size, binary_dim))
delta = np.zeros((hidden_layer_size, binary_dim + 1))

all_losses = []

# --- Training loop: same task as above, but with tanh as the hidden activation ---
for i in range(iters_num):
    
    # Draw operands A and B (target: a + b = d)
    a_int = np.random.randint(largest_number/2)  # < 128, so the sum fits in 8 bits
    a_bin = binary[a_int] # binary encoding
    b_int = np.random.randint(largest_number/2)  # < 128, so the sum fits in 8 bits
    b_bin = binary[b_int] # binary encoding
    
    # Ground-truth sum
    d_int = a_int + b_int
    d_bin = binary[d_int]
    
    # Predicted output bits
    out_bin = np.zeros_like(d_bin)
    
    # Accumulated loss over the whole sequence
    all_loss = 0    
    
    # Forward pass through time
    for t in range(binary_dim):
        # Input bits at time t, taken from the least-significant end (like column addition)
        X = np.array([a_bin[ - t - 1], b_bin[ - t - 1]]).reshape(1, -1)

        # Target bit at time t
        dd = np.array([d_bin[binary_dim - t - 1]])
        
        # Hidden pre-activation = input projection + recurrent projection of previous state
        u[:,t+1] = np.dot(X, W_in) + np.dot(z[:,t].reshape(1, -1), W)
#        z[:,t+1] = functions.sigmoid(u[:,t+1])
#        z[:,t+1] = functions.relu(u[:,t+1])
        z[:,t+1] = np.tanh(u[:,t+1])    
        y[:,t] = functions.sigmoid(np.dot(z[:,t+1].reshape(1, -1), W_out))


        # Loss at this time step
        loss = functions.mean_squared_error(dd, y[:,t])
        
        # NOTE(review): d_sigmoid is applied to the sigmoid OUTPUT y, not the
        # pre-activation — confirm against functions.d_sigmoid's expected argument.
        delta_out[:,t] = functions.d_mean_squared_error(dd, y[:,t]) * functions.d_sigmoid(y[:,t])        
        
        all_loss += loss

        out_bin[binary_dim - t - 1] = np.round(y[:,t])
    
    # Backpropagation through time (last time step back to the first)
    for t in range(binary_dim)[::-1]:
        X = np.array([a_bin[-t-1],b_bin[-t-1]]).reshape(1, -1)        

#        delta[:,t] = (np.dot(delta[:,t+1].T, W.T) + np.dot(delta_out[:,t].T, W_out.T)) * functions.d_sigmoid(u[:,t+1])
#        delta[:,t] = (np.dot(delta[:,t+1].T, W.T) + np.dot(delta_out[:,t].T, W_out.T)) * functions.d_relu(u[:,t+1])
        # Hidden delta: error flowing back from the next time step plus error from the output layer
        delta[:,t] = (np.dot(delta[:,t+1].T, W.T) + np.dot(delta_out[:,t].T, W_out.T)) * d_tanh(u[:,t+1])    

        # Accumulate gradients over the sequence
        W_out_grad += np.dot(z[:,t+1].reshape(-1,1), delta_out[:,t].reshape(-1,1))
        W_grad += np.dot(z[:,t].reshape(-1,1), delta[:,t].reshape(1,-1))
        W_in_grad += np.dot(X.T, delta[:,t].reshape(1,-1))
    
    # Apply gradients (plain SGD, one update per training example)
    W_in -= learning_rate * W_in_grad
    W_out -= learning_rate * W_out_grad
    W -= learning_rate * W_grad
    
    # Reset gradient accumulators for the next example
    W_in_grad *= 0
    W_out_grad *= 0
    W_grad *= 0
    

    if(i % plot_interval == 0):
        all_losses.append(all_loss)        
        print("iters:" + str(i))
        print("Loss:" + str(all_loss))
        print("Pred:" + str(out_bin))
        print("True:" + str(d_bin))
        # Convert the predicted bit string (MSB first) back to an integer
        out_int = 0
        for index,x in enumerate(reversed(out_bin)):
            out_int += x * pow(2, index)
        print(str(a_int) + " + " + str(b_int) + " = " + str(out_int))
        print("------------")

# Plot the loss sampled every plot_interval iterations
lists = range(0, iters_num, plot_interval)
plt.plot(lists, all_losses, label="loss")
plt.show()
iters:0
Loss:1.7825079188591872
Pred:[0 0 0 0 0 0 0 1]
True:[1 0 0 1 1 1 1 1]
71 + 88 = 1
------------
iters:100
Loss:1.068114983865859
Pred:[0 0 0 0 1 0 0 1]
True:[1 0 0 1 1 1 1 1]
91 + 68 = 9
------------
iters:200
Loss:0.7980580448445825
Pred:[1 0 1 1 0 1 1 1]
True:[1 0 0 0 0 1 0 1]
106 + 27 = 183
------------
iters:300
Loss:0.7430101320846468
Pred:[0 0 1 0 1 1 0 1]
True:[0 1 0 0 0 1 0 1]
26 + 43 = 45
------------
iters:400
Loss:0.3547527089333904
Pred:[0 1 1 0 1 1 1 0]
True:[0 1 1 0 1 1 1 0]
74 + 36 = 110
------------
iters:500
Loss:0.4412233768541721
Pred:[1 1 0 1 0 0 1 1]
True:[1 0 0 1 0 0 1 1]
24 + 123 = 211
------------
iters:600
Loss:0.05728341372660974
Pred:[0 0 1 1 0 0 1 1]
True:[0 0 1 1 0 0 1 1]
20 + 31 = 51
------------
iters:700
Loss:0.4473464715313207
Pred:[1 1 1 0 0 1 0 1]
True:[0 1 1 0 0 0 0 1]
65 + 32 = 229
------------
iters:800
Loss:0.005265394894770008
Pred:[0 1 1 1 1 0 1 1]
True:[0 1 1 1 1 0 1 1]
122 + 1 = 123
------------
iters:900
Loss:0.007732614113016279
Pred:[1 0 0 1 1 1 0 1]
True:[1 0 0 1 1 1 0 1]
44 + 113 = 157
------------
iters:1000
Loss:0.0027478829812228422
Pred:[0 1 1 1 1 0 1 1]
True:[0 1 1 1 1 0 1 1]
39 + 84 = 123
------------
iters:1100
Loss:0.12687003725236484
Pred:[0 0 1 1 1 1 1 0]
True:[0 0 1 1 1 1 1 0]
58 + 4 = 62
------------
iters:1200
Loss:0.12624947338635023
Pred:[0 1 1 1 0 0 0 0]
True:[0 1 1 1 0 0 0 0]
98 + 14 = 112
------------
iters:1300
Loss:0.00374707842398215
Pred:[1 0 0 0 0 1 0 1]
True:[1 0 0 0 0 1 0 1]
58 + 75 = 133
------------
iters:1400
Loss:0.0010539539075530182
Pred:[0 1 1 0 0 1 1 1]
True:[0 1 1 0 0 1 1 1]
47 + 56 = 103
------------
iters:1500
Loss:0.0012807498642258607
Pred:[1 0 1 0 0 0 0 0]
True:[1 0 1 0 0 0 0 0]
89 + 71 = 160
------------
iters:1600
Loss:0.00027810120900452294
Pred:[0 1 1 1 1 1 1 1]
True:[0 1 1 1 1 1 1 1]
108 + 19 = 127
------------
iters:1700
Loss:0.0003627208120751486
Pred:[0 1 1 1 1 1 0 1]
True:[0 1 1 1 1 1 0 1]
11 + 114 = 125
------------
iters:1800
Loss:0.0008168185131451574
Pred:[1 0 0 0 1 0 1 0]
True:[1 0 0 0 1 0 1 0]
75 + 63 = 138
------------
iters:1900
Loss:0.0006956249258870611
Pred:[1 0 0 1 1 1 0 0]
True:[1 0 0 1 1 1 0 0]
97 + 59 = 156
------------
iters:2000
Loss:0.00042832060959352176
Pred:[1 0 0 0 1 1 0 1]
True:[1 0 0 0 1 1 0 1]
108 + 33 = 141
------------
iters:2100
Loss:0.00028488131701217266
Pred:[0 1 0 0 1 1 0 1]
True:[0 1 0 0 1 1 0 1]
2 + 75 = 77
------------
iters:2200
Loss:0.1252786282876126
Pred:[1 0 0 0 1 0 1 0]
True:[1 0 0 0 1 0 1 0]
14 + 124 = 138
------------
iters:2300
Loss:0.00023809308983985088
Pred:[1 0 0 0 1 1 0 1]
True:[1 0 0 0 1 1 0 1]
28 + 113 = 141
------------
iters:2400
Loss:0.1254340918005175
Pred:[1 0 1 0 1 0 1 0]
True:[1 0 1 0 1 0 1 0]
108 + 62 = 170
------------
iters:2500
Loss:0.0003785222806668951
Pred:[1 0 1 0 1 1 0 1]
True:[1 0 1 0 1 1 0 1]
47 + 126 = 173
------------
iters:2600
Loss:0.0002105931197411917
Pred:[1 0 0 0 0 1 1 1]
True:[1 0 0 0 0 1 1 1]
52 + 83 = 135
------------
iters:2700
Loss:0.0002441573766304045
Pred:[0 1 1 0 0 0 1 1]
True:[0 1 1 0 0 0 1 1]
47 + 52 = 99
------------
iters:2800
Loss:0.00027489853311958503
Pred:[1 0 1 0 1 0 0 1]
True:[1 0 1 0 1 0 0 1]
78 + 91 = 169
------------
iters:2900
Loss:0.0003140441400391026
Pred:[1 0 1 0 0 0 1 0]
True:[1 0 1 0 0 0 1 0]
127 + 35 = 162
------------
iters:3000
Loss:8.643852597235558e-05
Pred:[0 1 1 1 1 1 0 0]
True:[0 1 1 1 1 1 0 0]
21 + 103 = 124
------------
iters:3100
Loss:0.1251003509734491
Pred:[0 1 1 1 0 1 0 0]
True:[0 1 1 1 0 1 0 0]
18 + 98 = 116
------------
iters:3200
Loss:0.12513844133344143
Pred:[1 0 0 1 0 0 0 0]
True:[1 0 0 1 0 0 0 0]
42 + 102 = 144
------------
iters:3300
Loss:8.086040718682504e-05
Pred:[0 1 1 1 1 1 1 1]
True:[0 1 1 1 1 1 1 1]
75 + 52 = 127
------------
iters:3400
Loss:4.67266179511879e-05
Pred:[0 0 1 0 1 0 0 1]
True:[0 0 1 0 1 0 0 1]
40 + 1 = 41
------------
iters:3500
Loss:9.053072511586486e-05
Pred:[1 1 0 0 1 1 1 0]
True:[1 1 0 0 1 1 1 0]
79 + 127 = 206
------------
iters:3600
Loss:0.00016996395205645472
Pred:[1 1 1 0 0 1 0 1]
True:[1 1 1 0 0 1 0 1]
124 + 105 = 229
------------
iters:3700
Loss:0.00016804680629905664
Pred:[1 0 0 1 0 0 1 1]
True:[1 0 0 1 0 0 1 1]
104 + 43 = 147
------------
iters:3800
Loss:7.341881992600306e-05
Pred:[0 0 1 1 1 1 1 1]
True:[0 0 1 1 1 1 1 1]
2 + 61 = 63
------------
iters:3900
Loss:9.421307448256092e-05
Pred:[0 1 0 1 1 0 1 1]
True:[0 1 0 1 1 0 1 1]
52 + 39 = 91
------------
iters:4000
Loss:5.2738189739628505e-05
Pred:[0 1 0 0 0 1 0 0]
True:[0 1 0 0 0 1 0 0]
19 + 49 = 68
------------
iters:4100
Loss:5.95447272714803e-05
Pred:[0 1 1 1 0 0 0 0]
True:[0 1 1 1 0 0 0 0]
25 + 87 = 112
------------
iters:4200
Loss:7.569514208592685e-05
Pred:[0 1 0 0 0 1 0 1]
True:[0 1 0 0 0 1 0 1]
10 + 59 = 69
------------
iters:4300
Loss:4.055061163579041e-05
Pred:[0 1 0 1 1 0 0 0]
True:[0 1 0 1 1 0 0 0]
67 + 21 = 88
------------
iters:4400
Loss:4.092651459489415e-05
Pred:[0 1 1 1 0 1 1 1]
True:[0 1 1 1 0 1 1 1]
20 + 99 = 119
------------
iters:4500
Loss:4.629768390045177e-05
Pred:[1 0 1 1 0 1 0 0]
True:[1 0 1 1 0 1 0 0]
101 + 79 = 180
------------
iters:4600
Loss:4.123288343974828e-05
Pred:[0 1 1 0 1 1 1 1]
True:[0 1 1 0 1 1 1 1]
56 + 55 = 111
------------
iters:4700
Loss:0.2500268464000157
Pred:[0 0 1 1 0 1 0 0]
True:[0 0 1 1 0 1 0 0]
48 + 4 = 52
------------
iters:4800
Loss:4.599740459679424e-05
Pred:[1 0 0 0 1 1 0 0]
True:[1 0 0 0 1 1 0 0]
81 + 59 = 140
------------
iters:4900
Loss:3.941486841660323e-05
Pred:[0 1 1 1 0 1 1 0]
True:[0 1 1 1 0 1 1 0]
117 + 1 = 118
------------
iters:5000
Loss:6.74943591489191e-05
Pred:[0 1 0 0 0 0 0 1]
True:[0 1 0 0 0 0 0 1]
6 + 59 = 65
------------
iters:5100
Loss:3.232108916862731e-05
Pred:[0 0 1 1 0 1 1 1]
True:[0 0 1 1 0 1 1 1]
45 + 10 = 55
------------
iters:5200
Loss:5.287632834617799e-05
Pred:[0 1 0 0 1 1 0 1]
True:[0 1 0 0 1 1 0 1]
38 + 39 = 77
------------
iters:5300
Loss:3.157069795502695e-05
Pred:[1 0 0 0 0 0 0 1]
True:[1 0 0 0 0 0 0 1]
8 + 121 = 129
------------
iters:5400
Loss:3.264176652427938e-05
Pred:[0 1 0 1 0 0 1 1]
True:[0 1 0 1 0 0 1 1]
74 + 9 = 83
------------
iters:5500
Loss:4.4234816152602624e-05
Pred:[1 0 0 1 0 1 1 1]
True:[1 0 0 1 0 1 1 1]
91 + 60 = 151
------------
iters:5600
Loss:4.1960900088220426e-05
Pred:[0 1 1 1 0 1 1 1]
True:[0 1 1 1 0 1 1 1]
105 + 14 = 119
------------
iters:5700
Loss:4.664589166277189e-05
Pred:[1 1 0 0 0 0 1 1]
True:[1 1 0 0 0 0 1 1]
79 + 116 = 195
------------
iters:5800
Loss:3.591381648414389e-05
Pred:[1 0 1 0 0 0 1 1]
True:[1 0 1 0 0 0 1 1]
114 + 49 = 163
------------
iters:5900
Loss:2.754463424391885e-05
Pred:[0 1 1 0 0 0 0 0]
True:[0 1 1 0 0 0 0 0]
87 + 9 = 96
------------
iters:6000
Loss:0.12502073402826852
Pred:[1 0 1 0 1 1 1 0]
True:[1 0 1 0 1 1 1 0]
108 + 66 = 174
------------
iters:6100
Loss:2.5280760253401427e-05
Pred:[0 1 1 0 0 1 0 1]
True:[0 1 1 0 0 1 0 1]
70 + 31 = 101
------------
iters:6200
Loss:0.12502781591919682
Pred:[0 1 1 0 0 0 1 0]
True:[0 1 1 0 0 0 1 0]
36 + 62 = 98
------------
iters:6300
Loss:3.401681512827791e-05
Pred:[1 0 1 0 1 0 1 1]
True:[1 0 1 0 1 0 1 1]
93 + 78 = 171
------------
iters:6400
Loss:2.0636033243253728e-05
Pred:[0 0 1 0 0 1 0 1]
True:[0 0 1 0 0 1 0 1]
0 + 37 = 37
------------
iters:6500
Loss:0.1250200388601585
Pred:[0 1 1 1 1 1 1 0]
True:[0 1 1 1 1 1 1 0]
4 + 122 = 126
------------
iters:6600
Loss:3.424672143080501e-05
Pred:[1 0 1 0 0 1 1 1]
True:[1 0 1 0 0 1 1 1]
88 + 79 = 167
------------
iters:6700
Loss:1.3524981089835943e-05
Pred:[0 0 0 1 1 0 0 0]
True:[0 0 0 1 1 0 0 0]
21 + 3 = 24
------------
iters:6800
Loss:2.438745930506212e-05
Pred:[1 0 0 0 1 1 1 0]
True:[1 0 0 0 1 1 1 0]
47 + 95 = 142
------------
iters:6900
Loss:0.1250136650035057
Pred:[1 0 0 0 1 0 0 0]
True:[1 0 0 0 1 0 0 0]
38 + 98 = 136
------------
iters:7000
Loss:0.12501706142810073
Pred:[0 1 1 1 1 1 1 0]
True:[0 1 1 1 1 1 1 0]
74 + 52 = 126
------------
iters:7100
Loss:1.4503997360543127e-05
Pred:[0 1 1 1 0 1 0 0]
True:[0 1 1 1 0 1 0 0]
71 + 45 = 116
------------
iters:7200
Loss:1.8022358415222422e-05
Pred:[0 1 1 0 0 1 0 1]
True:[0 1 1 0 0 1 0 1]
69 + 32 = 101
------------
iters:7300
Loss:3.749512131128678e-05
Pred:[1 0 1 1 0 0 1 0]
True:[1 0 1 1 0 0 1 0]
93 + 85 = 178
------------
iters:7400
Loss:9.055521018460283e-06
Pred:[0 1 0 1 1 0 0 0]
True:[0 1 0 1 1 0 0 0]
9 + 79 = 88
------------
iters:7500
Loss:2.622976193634337e-05
Pred:[1 0 1 0 1 0 1 0]
True:[1 0 1 0 1 0 1 0]
113 + 57 = 170
------------
iters:7600
Loss:1.276576148532525e-05
Pred:[0 1 0 1 0 0 0 0]
True:[0 1 0 1 0 0 0 0]
17 + 63 = 80
------------
iters:7700
Loss:8.732101677456416e-06
Pred:[0 1 1 1 1 0 1 0]
True:[0 1 1 1 1 0 1 0]
43 + 79 = 122
------------
iters:7800
Loss:1.6052946758071522e-05
Pred:[0 1 1 1 1 0 1 0]
True:[0 1 1 1 1 0 1 0]
73 + 49 = 122
------------
iters:7900
Loss:2.214482869546824e-05
Pred:[1 0 0 1 0 0 1 1]
True:[1 0 0 1 0 0 1 1]
30 + 117 = 147
------------
iters:8000
Loss:1.5251033220524903e-05
Pred:[0 1 0 1 1 1 0 1]
True:[0 1 0 1 1 1 0 1]
51 + 42 = 93
------------
iters:8100
Loss:1.7061906225698132e-05
Pred:[1 1 0 1 0 1 1 1]
True:[1 1 0 1 0 1 1 1]
91 + 124 = 215
------------
iters:8200
Loss:1.899722917028597e-05
Pred:[1 0 0 1 0 0 0 1]
True:[1 0 0 1 0 0 0 1]
52 + 93 = 145
------------
iters:8300
Loss:2.975699747462139e-05
Pred:[1 1 0 1 0 0 0 1]
True:[1 1 0 1 0 0 0 1]
126 + 83 = 209
------------
iters:8400
Loss:1.039892279305604e-05
Pred:[0 1 1 1 0 0 1 1]
True:[0 1 1 1 0 0 1 1]
38 + 77 = 115
------------
iters:8500
Loss:1.87327551538756e-05
Pred:[1 0 0 0 1 1 1 1]
True:[1 0 0 0 1 1 1 1]
121 + 22 = 143
------------
iters:8600
Loss:1.2710807151580049e-05
Pred:[1 0 1 1 0 1 0 0]
True:[1 0 1 1 0 1 0 0]
93 + 87 = 180
------------
iters:8700
Loss:0.12500839328084346
Pred:[1 0 0 0 1 0 0 0]
True:[1 0 0 0 1 0 0 0]
14 + 122 = 136
------------
iters:8800
Loss:0.25000892679952963
Pred:[0 1 0 1 1 0 0 0]
True:[0 1 0 1 1 0 0 0]
20 + 68 = 88
------------
iters:8900
Loss:1.4884770102710749e-05
Pred:[0 1 1 1 1 0 0 1]
True:[0 1 1 1 1 0 0 1]
55 + 66 = 121
------------
iters:9000
Loss:0.37500575361078753
Pred:[0 1 0 0 1 0 0 0]
True:[0 1 0 0 1 0 0 0]
64 + 8 = 72
------------
iters:9100
Loss:0.1250116469151159
Pred:[1 0 0 1 0 0 0 0]
True:[1 0 0 1 0 0 0 0]
114 + 30 = 144
------------
iters:9200
Loss:2.0845834933762964e-05
Pred:[1 0 0 0 1 0 0 1]
True:[1 0 0 0 1 0 0 1]
79 + 58 = 137
------------
iters:9300
Loss:2.117197441553305e-05
Pred:[1 1 0 1 0 0 0 1]
True:[1 1 0 1 0 0 0 1]
119 + 90 = 209
------------
iters:9400
Loss:8.00234522762448e-06
Pred:[0 1 1 0 1 0 1 0]
True:[0 1 1 0 1 0 1 0]
33 + 73 = 106
------------
iters:9500
Loss:0.12501182433103109
Pred:[1 1 1 0 1 0 0 0]
True:[1 1 1 0 1 0 0 0]
118 + 114 = 232
------------
iters:9600
Loss:9.763321455056136e-06
Pred:[0 1 0 0 1 0 1 1]
True:[0 1 0 0 1 0 1 1]
70 + 5 = 75
------------
iters:9700
Loss:9.973657694693363e-06
Pred:[0 1 0 1 1 0 1 0]
True:[0 1 0 1 1 0 1 0]
57 + 33 = 90
------------
iters:9800
Loss:8.32444962080584e-06
Pred:[1 0 0 0 0 0 1 1]
True:[1 0 0 0 0 0 1 1]
65 + 66 = 131
------------
iters:9900
Loss:9.437658859481507e-06
Pred:[0 0 1 0 0 1 1 0]
True:[0 0 1 0 0 1 1 0]
1 + 37 = 38
------------