pso 알고리즘을 구현하는데 bp 를 완전히 배제하는 방법으로 구현
model 디렉토리를 자동으로 생성하게 수정
This commit is contained in:
jung-geun
2023-05-24 14:00:31 +09:00
parent 7c5f3a53a3
commit 27d40ab56c
9 changed files with 1556 additions and 352 deletions

400
xor.ipynb
View File

@@ -2,18 +2,11 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 13,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-05-21 01:52:28.471404: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n"
]
},
{
"name": "stdout",
"output_type": "stream",
@@ -55,7 +48,7 @@
" model = Sequential(leyer)\n",
"\n",
" sgd = keras.optimizers.SGD(lr=0.1, momentum=1, decay=1e-05, nesterov=True)\n",
" adam = keras.optimizers.Adam(lr=0.1, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.)\n",
" # adam = keras.optimizers.Adam(lr=0.1, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.)\n",
" model.compile(loss='mse', optimizer=sgd, metrics=['accuracy'])\n",
"\n",
" print(model.summary())\n",
@@ -65,238 +58,185 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential\"\n",
"Model: \"sequential_11\"\n",
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" dense (Dense) (None, 2) 6 \n",
" dense_22 (Dense) (None, 2) 6 \n",
" \n",
" dense_1 (Dense) (None, 1) 3 \n",
" dense_23 (Dense) (None, 1) 3 \n",
" \n",
"=================================================================\n",
"Total params: 9\n",
"Trainable params: 9\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
"_________________________________________________________________\n",
"None\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/pieroot/miniconda3/envs/pso/lib/python3.8/site-packages/keras/optimizers/optimizer_v2/gradient_descent.py:111: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.\n",
" super().__init__(name, **kwargs)\n",
"/home/pieroot/miniconda3/envs/pso/lib/python3.8/site-packages/keras/optimizers/optimizer_v2/adam.py:114: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.\n",
" super().__init__(name, **kwargs)\n"
"init particles position: 100%|██████████| 15/15 [00:00<00:00, 85.12it/s]\n",
"init velocities: 100%|██████████| 15/15 [00:00<00:00, 46465.70it/s]\n",
"Iteration 0 / 10: 100%|##########| 15/15 [00:05<00:00, 2.63it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"None\n",
"0 particle score : 0.24921603500843048\n",
"1 particle score : 0.2509610056877136\n",
"2 particle score : 0.28712478280067444\n",
"3 particle score : 0.2665291726589203\n",
"4 particle score : 0.2513682246208191\n",
"0 particle score : 0.26079031825065613\n",
"1 particle score : 0.24931921064853668\n",
"2 particle score : 0.2679133415222168\n",
"3 particle score : 0.27925199270248413\n",
"4 particle score : 0.2605195641517639\n",
"0 particle score : 0.30758577585220337\n",
"1 particle score : 0.26747316122055054\n",
"2 particle score : 0.36957648396492004\n",
"3 particle score : 0.19372068345546722\n",
"4 particle score : 0.3671383857727051\n",
"0 particle score : 0.24090810120105743\n",
"1 particle score : 0.3176509141921997\n",
"2 particle score : 0.23225924372673035\n",
"3 particle score : 0.37263113260269165\n",
"4 particle score : 0.47822105884552\n",
"0 particle score : 0.37611791491508484\n",
"1 particle score : 0.27166277170181274\n",
"2 particle score : 0.21416285634040833\n",
"3 particle score : 0.23324625194072723\n",
"4 particle score : 0.024583835154771805\n",
"0 particle score : 0.05194556713104248\n",
"1 particle score : 0.3102635443210602\n",
"2 particle score : 0.31894028186798096\n",
"3 particle score : 0.12679985165596008\n",
"4 particle score : 0.012038745917379856\n",
"0 particle score : 0.004551469348371029\n",
"1 particle score : 0.03923884406685829\n",
"2 particle score : 0.003701586974784732\n",
"3 particle score : 0.0026527238078415394\n",
"4 particle score : 0.0430503748357296\n",
"0 particle score : 0.000214503234019503\n",
"1 particle score : 0.0025649480521678925\n",
"2 particle score : 0.008843829855322838\n",
"3 particle score : 0.23036976158618927\n",
"4 particle score : 0.21686825156211853\n",
"0 particle score : 4.901693273495766e-07\n",
"1 particle score : 0.003860481781885028\n",
"2 particle score : 0.00047884139348752797\n",
"3 particle score : 0.1563722789287567\n",
"4 particle score : 1.1759411222556082e-07\n",
"0 particle score : 0.24969959259033203\n",
"1 particle score : 2.8646991268033162e-05\n",
"2 particle score : 1.0552450024903237e-09\n",
"3 particle score : 3.566808572941227e-07\n",
"4 particle score : 8.882947003831243e-14\n",
"0 particle score : 0.2497878521680832\n",
"1 particle score : 1.879385969766334e-12\n",
"2 particle score : 0.44945281744003296\n",
"3 particle score : 2.485284791549705e-14\n",
"4 particle score : 2.431787924306583e-26\n",
"0 particle score : 3.854774978241029e-18\n",
"1 particle score : 1.2515056546646974e-08\n",
"2 particle score : 0.49999988079071045\n",
"3 particle score : 2.8881452344524906e-22\n",
"4 particle score : 4.162688806996304e-30\n",
"0 particle score : 9.170106118851775e-37\n",
"1 particle score : 0.49933725595474243\n",
"2 particle score : 0.43209874629974365\n",
"3 particle score : 7.681456478781658e-30\n",
"4 particle score : 1.1656278206215614e-33\n",
"0 particle score : 0.0\n",
"1 particle score : 0.49545660614967346\n",
"2 particle score : 0.25\n",
"3 particle score : 0.0\n",
"4 particle score : 0.0\n",
"0 particle score : 0.0\n",
"1 particle score : 0.25\n",
"2 particle score : 0.25\n",
"3 particle score : 0.25\n",
"4 particle score : 0.25\n",
"0 particle score : 0.0\n",
"1 particle score : 0.0\n",
"2 particle score : 0.25\n",
"3 particle score : 0.25\n",
"4 particle score : 0.25\n",
"0 particle score : 0.0\n",
"1 particle score : 0.25\n",
"2 particle score : 0.25\n",
"3 particle score : 0.25\n",
"4 particle score : 0.25\n",
"0 particle score : 0.0\n",
"1 particle score : 0.0\n",
"2 particle score : 0.0\n",
"3 particle score : 0.0\n",
"4 particle score : 0.5\n",
"0 particle score : 1.2923532081356227e-22\n",
"1 particle score : 0.5\n",
"2 particle score : 0.25\n",
"3 particle score : 0.942779541015625\n",
"4 particle score : 0.5\n",
"0 particle score : 0.4959273338317871\n",
"1 particle score : 0.5\n",
"2 particle score : 0.5\n",
"3 particle score : 0.75\n",
"4 particle score : 0.75\n",
"0 particle score : 0.23154164850711823\n",
"1 particle score : 0.5\n",
"2 particle score : 0.5\n",
"3 particle score : 0.5\n",
"4 particle score : 0.5\n",
"0 particle score : 0.0\n",
"1 particle score : 0.5\n",
"2 particle score : 0.25\n",
"3 particle score : 0.0\n",
"4 particle score : 0.25\n",
"0 particle score : 0.0\n",
"1 particle score : 0.5\n",
"2 particle score : 0.25\n",
"3 particle score : 0.25\n",
"4 particle score : 0.25\n",
"0 particle score : 0.25\n",
"1 particle score : 0.25\n",
"2 particle score : 0.25\n",
"3 particle score : 0.25\n",
"4 particle score : 0.25\n",
"0 particle score : 0.0\n",
"1 particle score : 0.25\n",
"2 particle score : 0.25\n",
"3 particle score : 0.25\n",
"4 particle score : 0.25\n",
"0 particle score : 0.5760642290115356\n",
"1 particle score : 0.25\n",
"2 particle score : 0.25000467896461487\n",
"3 particle score : 0.5\n",
"4 particle score : 0.5\n",
"0 particle score : 0.5\n",
"1 particle score : 0.25\n",
"2 particle score : 0.4998854398727417\n",
"3 particle score : 0.5\n",
"4 particle score : 0.5\n",
"0 particle score : 0.5\n",
"1 particle score : 0.25\n",
"2 particle score : 0.5000014305114746\n",
"3 particle score : 0.5\n",
"4 particle score : 0.5\n",
"0 particle score : 0.5\n",
"1 particle score : 0.007790721021592617\n",
"2 particle score : 0.5\n",
"3 particle score : 0.75\n",
"4 particle score : 0.5\n",
"0 particle score : 0.25\n",
"1 particle score : 0.5\n",
"2 particle score : 0.5\n",
"3 particle score : 0.0\n",
"4 particle score : 0.5\n",
"1/1 [==============================] - 0s 40ms/step\n",
"[[1.8555788e-26]\n",
" [1.0000000e+00]\n",
" [1.0000000e+00]\n",
" [1.8555788e-26]]\n",
"loss : 0.2620863338311513 | acc : 0.5 | best loss : 0.24143654108047485\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Iteration 1 / 10: 100%|##########| 15/15 [00:05<00:00, 2.69it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss : 0.24147300918896994 | acc : 0.6333333333333333 | best loss : 0.20360520482063293\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Iteration 2 / 10: 100%|##########| 15/15 [00:05<00:00, 2.72it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss : 0.211648628115654 | acc : 0.65 | best loss : 0.17383326590061188\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Iteration 3 / 10: 100%|##########| 15/15 [00:05<00:00, 2.72it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss : 0.21790608167648315 | acc : 0.6833333333333333 | best loss : 0.16785581409931183\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Iteration 4 / 10: 100%|##########| 15/15 [00:05<00:00, 2.73it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss : 0.20557119349638622 | acc : 0.7333333333333333 | best loss : 0.16668711602687836\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Iteration 5 / 10: 100%|##########| 15/15 [00:05<00:00, 2.68it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss : 0.20823073089122773 | acc : 0.7 | best loss : 0.16668711602687836\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Iteration 6 / 10: 100%|##########| 15/15 [00:05<00:00, 2.53it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss : 0.21380058924357095 | acc : 0.7166666666666667 | best loss : 0.16668711602687836\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Iteration 7 / 10: 100%|##########| 15/15 [00:06<00:00, 2.30it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss : 0.2561836312214533 | acc : 0.6833333333333333 | best loss : 0.16667115688323975\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Iteration 8 / 10: 100%|##########| 15/15 [00:05<00:00, 2.55it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss : 0.30372582376003265 | acc : 0.65 | best loss : 0.16667115688323975\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Iteration 9 / 10: 100%|##########| 15/15 [00:05<00:00, 2.65it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss : 0.281868569056193 | acc : 0.7 | best loss : 0.16667115688323975\n",
"1/1 [==============================] - 0s 26ms/step\n",
"[[0. ]\n",
" [0.66422266]\n",
" [0.6642227 ]\n",
" [0.6642227 ]]\n",
"[[0]\n",
" [1]\n",
" [1]\n",
" [0]]\n",
"history > [[array([[-0.9191145, -0.7256227],\n",
" [ 1.2947526, 1.0081983]], dtype=float32), array([ 0.01203067, -0.07866445], dtype=float32), array([[-0.72274315],\n",
" [ 0.88691926]], dtype=float32), array([-0.08449478], dtype=float32)], [array([[-0.7327981, -2.120965 ],\n",
" [ 3.5870228, 2.0618958]], dtype=float32), array([-0.06788628, -2.1460009 ], dtype=float32), array([[-1.8084345],\n",
" [ 3.2274616]], dtype=float32), array([0.40823892], dtype=float32)], [array([[-6.749437, -5.01979 ],\n",
" [ 9.477569, 9.011221]], dtype=float32), array([ 1.0140182, -5.089527 ], dtype=float32), array([[-5.9527373],\n",
" [ 8.538484 ]], dtype=float32), array([0.8423419], dtype=float32)], [array([[-4.4376955, -7.542317 ],\n",
" [13.042126 , 9.401183 ]], dtype=float32), array([ 1.7249748, -6.2829194], dtype=float32), array([[-4.6019 ],\n",
" [12.654787]], dtype=float32), array([2.11288], dtype=float32)], [array([[-7.9655757, -8.855807 ],\n",
" [14.27012 , 12.6986265]], dtype=float32), array([ 2.2102568, -7.4656196], dtype=float32), array([[-5.386531],\n",
" [16.770058]], dtype=float32), array([2.2161639], dtype=float32)], [array([[-10.937471, -9.346545],\n",
" [ 15.040345, 13.547635]], dtype=float32), array([ 3.7305086, -8.93729 ], dtype=float32), array([[-9.661456],\n",
" [14.314214]], dtype=float32), array([1.9838718], dtype=float32)], [array([[-7.618989 , -8.295806 ],\n",
" [ 9.591193 , 7.3881774]], dtype=float32), array([ 2.9443424, -6.85388 ], dtype=float32), array([[-6.120155],\n",
" [ 9.558391]], dtype=float32), array([2.900807], dtype=float32)], [array([[-12.431582, -14.683373],\n",
" [ 24.192898, 18.607504]], dtype=float32), array([ 4.375762, -11.899742], dtype=float32), array([[-11.140665],\n",
" [ 25.361753]], dtype=float32), array([3.5045836], dtype=float32)], [array([[-16.167437, -19.325432],\n",
" [ 25.197618, 15.928284]], dtype=float32), array([ 6.8536587, -14.406519 ], dtype=float32), array([[-16.149462],\n",
" [ 21.955147]], dtype=float32), array([6.5853295], dtype=float32)], [array([[-20.64401 , -25.207134],\n",
" [ 28.023142, 19.938404]], dtype=float32), array([ 7.2551775, -17.74039 ], dtype=float32), array([[-15.623163],\n",
" [ 30.90391 ]], dtype=float32), array([7.7026973], dtype=float32)], [array([[-27.585245, -28.003128],\n",
" [ 46.606903, 34.010803]], dtype=float32), array([ 9.391173, -25.379646], dtype=float32), array([[-27.2021 ],\n",
" [ 44.79025]], dtype=float32), array([9.642486], dtype=float32)], [array([[-44.09209, -37.20285],\n",
" [ 47.20231, 40.34598]], dtype=float32), array([ 13.101824, -25.8866 ], dtype=float32), array([[-33.470924],\n",
" [ 47.784706]], dtype=float32), array([14.320648], dtype=float32)], [array([[-36.38443 , -39.23304 ],\n",
" [ 52.953644, 38.646732]], dtype=float32), array([ 10.276208, -30.864595], dtype=float32), array([[-31.08338],\n",
" [ 52.16088]], dtype=float32), array([15.342434], dtype=float32)], [array([[-62.84543 , -47.409748],\n",
" [ 63.300335, 56.867214]], dtype=float32), array([ 17.78217, -33.01626], dtype=float32), array([[-48.512455],\n",
" [ 61.87751 ]], dtype=float32), array([19.369736], dtype=float32)], [array([[-71.16499 , -57.702408],\n",
" [ 80.223915, 69.13328 ]], dtype=float32), array([ 19.08833 , -41.566013], dtype=float32), array([[-57.950104],\n",
" [ 76.35351 ]], dtype=float32), array([24.470982], dtype=float32)], [array([[-120.93972, -92.38105],\n",
" [ 107.01377, 110.19025]], dtype=float32), array([ 28.39684 , -59.285316], dtype=float32), array([[-75.1781 ],\n",
" [129.59488]], dtype=float32), array([34.034805], dtype=float32)], [array([[-161.36476, -114.62916],\n",
" [ 142.47905, 152.3887 ]], dtype=float32), array([ 36.139404, -74.1054 ], dtype=float32), array([[-101.517525],\n",
" [ 171.30031 ]], dtype=float32), array([42.26851], dtype=float32)]]\n",
"score > [0.5, 0.5]\n"
" [0]]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
@@ -309,13 +249,12 @@
"model = make_model()\n",
"\n",
"loss = keras.losses.MeanSquaredError()\n",
"optimizer = keras.optimizers.SGD(lr=0.1, momentum=1, decay=1e-05, nesterov=True)\n",
"optimizer = keras.optimizers.SGD(lr=0.1, momentum=0.9, decay=1e-05, nesterov=True)\n",
"\n",
"\n",
"pso_xor = PSO(model=model, loss_method=loss, optimizer=optimizer, n_particles=15)\n",
"\n",
"pso_xor = PSO(model=model, loss=loss, optimizer=optimizer, n_particles=5)\n",
"\n",
"best_weights, score = pso_xor.optimize(x, y, x_test, y_test, maxiter=30)\n",
"best_weights, score = pso_xor.optimize(x, y, x_test, y_test, maxiter=10, epochs=20)\n",
"\n",
"model.set_weights(best_weights)\n",
"\n",
@@ -330,6 +269,17 @@
"# plt.plot(history)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x_test = np.array([[0, 1], [0, 0], [1, 1], [1, 0]])\n",
"y_pred = model.predict(x_test)\n",
"print(y_pred)"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -352,6 +302,26 @@
" return hist"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"predicted_result = model.predict(x_test)\n",
"predicted_labels = np.argmax(predicted_result, axis=1)\n",
"not_correct = []\n",
"for i in range(len(y_test)):\n",
" if predicted_labels[i] != y_test[i]:\n",
" not_correct.append(i)\n",
" # print(f\"추론 > {predicted_labels[i]} | 정답 > {y_test[i]}\")\n",
" \n",
"print(f\"틀린 것 갯수 > {len(not_correct)}\")\n",
"for i in range(3):\n",
" plt.imshow(x_test[not_correct[i]].reshape(28,28), cmap='Greys')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,