From 7c5f3a53a3788dd70b4b96ee97b4601c7333f07c Mon Sep 17 00:00:00 2001 From: jung-geun Date: Sun, 21 May 2023 14:00:03 +0900 Subject: [PATCH] =?UTF-8?q?23-05-21=20pso=20=EA=B8=B0=EB=B3=B8=20=EC=95=8C?= =?UTF-8?q?=EA=B3=A0=EB=A6=AC=EC=A6=98=EC=9D=84=20=EC=9D=B4=EC=9A=A9?= =?UTF-8?q?=ED=95=9C=20tensorflow=20model=EC=9D=98=20pso=20=EC=95=8C?= =?UTF-8?q?=EA=B3=A0=EB=A6=AC=EC=A6=98=ED=99=94=20-=20xor=20=EB=AC=B8?= =?UTF-8?q?=EC=A0=9C,=20mnist=20=EB=B6=84=EB=A5=98=20=EC=84=B1=EA=B3=B5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 5 + mnist.ipynb | 364 ++++++++++++++++++++++++++++++++++++++++++++ pso.py | 137 +++++++++++++++++ pso_tf.py | 261 +++++++++++++++++++++++++++++++ readme.md | 53 +++++++ xor.ipynb | 431 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 1251 insertions(+) create mode 100644 .gitignore create mode 100644 mnist.ipynb create mode 100644 pso.py create mode 100644 pso_tf.py create mode 100644 readme.md create mode 100644 xor.ipynb diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5ca7509 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +*.pyc +__pycache__/ +.ipynb_checkpoints/ +*.pdf +model/ \ No newline at end of file diff --git a/mnist.ipynb b/mnist.ipynb new file mode 100644 index 0000000..c58e0dc --- /dev/null +++ b/mnist.ipynb @@ -0,0 +1,364 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "8a637c69-9071-4012-ac1e-93037548b3e9", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-05-21 03:38:18.127052: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2.10.0\n", + "[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]\n" + ] + } + ], + "source": [ + "import os\n", + "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'\n", + "\n", + "from pso_tf import PSO\n", + "\n", + "import numpy as np\n", + "\n", + "import tensorflow as tf\n", + "from tensorflow import keras\n", + "tf.random.set_seed(777) # for reproducibility\n", + "\n", + "from keras.datasets import mnist\n", + "from keras.models import Sequential\n", + "from keras.layers import Dense, Dropout, Flatten\n", + "from keras.layers import Conv2D, MaxPooling2D\n", + "from keras import backend as K\n", + "\n", + "import matplotlib.pyplot as plt\n", + "from datetime import date\n", + "\n", + "print(tf.__version__)\n", + "print(tf.config.list_physical_devices())\n", + "\n", + "def get_data():\n", + " (x_train, y_train), (x_test, y_test) = mnist.load_data()\n", + "\n", + " x_train, x_test = x_train / 255.0, x_test / 255.0\n", + " x_train = x_train.reshape((60000, 28 ,28, 1))\n", + " x_test = x_test.reshape((10000, 28 ,28, 1))\n", + "\n", + " print(f\"x_train : {x_train[0].shape} | y_train : {y_train[0].shape}\")\n", + " print(f\"x_test : {x_test[0].shape} | y_test : {y_test[0].shape}\")\n", + " return x_train, y_train, x_test, y_test\n", + "\n", + "def make_model():\n", + " model = Sequential()\n", + " model.add(Conv2D(32, kernel_size=(5, 5), activation='relu', input_shape=(28,28,1)))\n", + " model.add(MaxPooling2D(pool_size=(3, 3)))\n", + " model.add(Conv2D(64, kernel_size=(3, 3), activation='relu'))\n", + " model.add(MaxPooling2D(pool_size=(2, 2)))\n", + " model.add(Dropout(0.25))\n", + " model.add(Flatten())\n", + " model.add(Dense(128, activation='relu'))\n", + " model.add(Dense(10, activation='softmax'))\n", + "\n", + " model.compile(loss='sparse_categorical_crossentropy',optimizer='adam', metrics=['accuracy'])\n", + "\n", + " model.summary()\n", + "\n", + " return model" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "a2d9891d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "x_train : (28, 28, 1) | y_train : ()\n", + "x_test : (28, 28, 1) | y_test : ()\n", + "Model: \"sequential\"\n", + "_________________________________________________________________\n", + " Layer (type) Output Shape Param # \n", + "=================================================================\n", + " conv2d (Conv2D) (None, 24, 24, 32) 832 \n", + " \n", + " max_pooling2d (MaxPooling2D (None, 8, 8, 32) 0 \n", + " ) \n", + " \n", + " conv2d_1 (Conv2D) (None, 6, 6, 64) 18496 \n", + " \n", + " max_pooling2d_1 (MaxPooling (None, 3, 3, 64) 0 \n", + " 2D) \n", + " \n", + " dropout (Dropout) (None, 3, 3, 64) 0 \n", + " \n", + " flatten (Flatten) (None, 576) 0 \n", + " \n", + " dense (Dense) (None, 128) 73856 \n", + " \n", + " dense_1 (Dense) (None, 10) 1290 \n", + " \n", + "=================================================================\n", + "Total params: 94,474\n", + "Trainable params: 94,474\n", + "Non-trainable params: 0\n", + "_________________________________________________________________\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/pieroot/miniconda3/envs/pso/lib/python3.8/site-packages/keras/optimizers/optimizer_v2/gradient_descent.py:111: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.\n", + " super().__init__(name, **kwargs)\n", + "init particles position: 100%|██████████| 2/2 [00:00<00:00, 34.52it/s]\n", + "init velocities: 100%|██████████| 2/2 [00:00<00:00, 1203.19it/s]\n", + "Iteration 0 / 10: 100%|##########| 2/2 [00:48<00:00, 24.43s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loss : 0.021180160343647003 | acc : 0.9930999875068665\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Iteration 1 / 10: 100%|##########| 2/2 [00:46<00:00, 23.46s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loss : 0.029995786026120186 | acc : 0.9927999973297119\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Iteration 2 / 10: 100%|##########| 2/2 [00:47<00:00, 23.57s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loss : 0.13020965456962585 | acc : 0.9929999709129333\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Iteration 3 / 10: 100%|##########| 2/2 [00:47<00:00, 23.62s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loss : 0.032199352979660034 | acc : 0.9918000102043152\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Iteration 4 / 10: 100%|##########| 2/2 [00:47<00:00, 23.66s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loss : 0.025606701150536537 | acc : 0.9925000071525574\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Iteration 5 / 10: 100%|##########| 2/2 [00:47<00:00, 23.64s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loss : 0.04198306426405907 | acc : 0.9921000003814697\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Iteration 6 / 10: 100%|##########| 2/2 [00:47<00:00, 23.69s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loss : 0.048351287841796875 | acc : 0.9919999837875366\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Iteration 7 / 10: 100%|##########| 2/2 [00:47<00:00, 23.73s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loss : 0.0416271910071373 | acc : 0.9890999794006348\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Iteration 8 / 10: 100%|##########| 2/2 [00:47<00:00, 23.70s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loss : 0.18129077553749084 | acc : 0.9502000212669373\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Iteration 9 / 10: 100%|##########| 2/2 [00:47<00:00, 23.69s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loss : 0.8072962760925293 | acc : 0.7225000262260437\n", + "313/313 - 0s - loss: 0.0212 - accuracy: 0.9931 - 290ms/epoch - 926us/step\n", + "Test loss: 0.021180160343647003 / Test accuracy: 0.9930999875068665\n" + ] + } + ], + "source": [ + "x_train, y_train, x_test, y_test = get_data()\n", + "model = make_model()\n", + "\n", + "loss = keras.losses.MeanSquaredError()\n", + "optimizer = keras.optimizers.SGD(lr=0.1, momentum=1, decay=1e-05, nesterov=True)\n", + "\n", + "pso_m = PSO(model=model, loss_method=loss, optimizer=optimizer, n_particles=2)\n", + "best_weights, score = pso_m.optimize(x_train, y_train, x_test, y_test, maxiter=10)\n", + "\n", + "model.set_weights(best_weights)\n", + "\n", + "score = model.evaluate(x_test, y_test, verbose=2)\n", + "print(f\"Test loss: {score[0]} / Test accuracy: {score[1]}\")\n", + "\n", + "day = date.today().strftime(\"%Y-%m-%d\")\n", + "\n", + "model.save(f'./model/{day}_mnist.h5')\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "1a38f3c1-8291-40d9-838e-4ffbf4578be5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "313/313 [==============================] - 0s 685us/step\n", + "틀린 것 갯수 > 69\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAbtUlEQVR4nO3df2zV9fXH8dctP66AvZfV0t52FFZAQQVKYNA1KuJoaLuEifCHPxMwBgcrRmROg1PRbUkd5qtGw3BZHIxE1JkJKMm6aLFlbgUHShhxNpR0A0NbJkvvLUUulb6/fxDuvFB+fC739vReno/kJvTee3qPHy59entvb33OOScAAPpYlvUCAIArEwECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmBlovcLaenh4dPnxY2dnZ8vl81usAADxyzqmzs1OFhYXKyjr/45x+F6DDhw+rqKjIeg0AwGU6dOiQRo4ced7L+12AsrOzJZ1ePBAIGG8DAPAqEomoqKgo9vX8fFIWoDVr1uj5559XW1ubSkpK9Morr2jGjBkXnTvzbbdAIECAACCNXexplJS8COGtt97SihUrtGrVKn3yyScqKSlRRUWFjhw5koqbAwCkoZQE6IUXXtDixYt1//3364YbbtCrr76qoUOH6ne/+10qbg4AkIaSHqCTJ09q9+7dKi8v/9+NZGWpvLxcjY2N51w/Go0qEonEnQAAmS/pAfryyy916tQp5efnx52fn5+vtra2c65fU1OjYDAYO/EKOAC4Mpj/IOrKlSsVDodjp0OHDlmvBADoA0l/FVxubq4GDBig9vb2uPPb29sVCoXOub7f75ff70/2GgCAfi7pj4AGDx6sadOmqa6uLnZeT0+P6urqVFZWluybAwCkqZT8HNCKFSu0cOFCffe739WMGTP00ksvqaurS/fff38qbg4AkIZSEqA777xT//nPf/T000+rra1NU6ZMUW1t7TkvTAAAXLl8zjlnvcQ3RSIRBYNBhcNh3gkBANLQpX4dN38VHADgykSAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYGGi9wJXitdde8zzj8/k8z0yZMsXzzNSpUz3PAMDl4hEQAMAEAQIAmEh6gJ555hn5fL6404QJE5J9MwCANJeS54BuvPFGffDBB/+7kYE81QQAiJeSMgwcOFChUCgVnxoAkCFS8hzQ/v37VVhYqDFjxujee+/VwYMHz3vdaDSqSCQSdwIAZL6kB6i0tFTr169XbW2t1q5dq5aWFt1yyy3q7Ozs9fo1NTUKBoOxU1FRUbJXAgD0Qz7nnEvlDXR0dGj06NF64YUX9MADD5xzeTQaVTQajX0ciURUVFSkcDisQCCQytX6FD8HBOBKEYlEFAwGL/p1POWvDhg+fLiuu+46NTc393q53++X3+9P9RoAgH4m5T8HdOzYMR04cEAFBQWpvikAQBpJeoAeffRRNTQ06F//+pf+9re/6Y477tCAAQN09913J/umAABpLOnfgvviiy9099136+jRoxoxYoRuvvlm7dixQyNGjEj2TQEA0ljKX4Tg1aU+eZVusrK8P9hM5EUIifzQ79ChQz3P4PIk8s9u7dq1nmf68vnVd9991/PMj370I88zZWVlnmfQty716zjvBQcAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmEj5L6RD3/r66689z0QikRRsggtJ5M1I77vvvhRsYqutrc3zTG1tbQo2gQUeAQEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAE74bdR7Zs2eJ55t13303BJslz5MgRzzNbt25NwSYA0hGPgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEz7nnLNe4psikYiCwaDC4bACgYD1OriAzz77zPPMpEmTUrBJ+hk7dqznmalTp6Zgk94l8veUl5fneeaHP/yh55n8/HzPM+hbl/p1nEdAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAICJgdYLIH21trZar5B0AwYM8Dzz0ksveZ656667PM/k5OR4ngH6Mx4BAQBMECAAgAnPAdq+fbvmzp2rwsJC+Xw+bd68Oe5y55yefvppFRQUaMiQISovL9f+/fuTtS8AIEN4DlBXV5dKSkq0Zs2aXi9fvXq1Xn75Zb366qvauXOnhg0bpoqKCp04ceKylwUAZA7PL0KoqqpSVVVVr5c55/TSSy/pySef1O233y5J2rBhg/Lz87V58+aEnngFAGSmpD4H1NLSora2NpWXl8fOCwaDKi0tVWNjY68z0WhUkUgk7gQAyHxJDVBbW5ukc39ne35+fuyys9XU1CgYDMZORUVFyVwJANBPmb8KbuXKlQqHw7HToUOHrFcCAPSBpAYoFApJktrb2+POb29vj112Nr/fr0AgEHcCAGS+pAaouLhYoVBIdXV1sfMikYh27typsrKyZN4UACDNeX4V3LFjx9Tc3Bz7uKWlRXv27FFOTo5GjRql5cuX65e//KWuvfZaFRcX66mnnlJhYaHmzZuXzL0BAGnOc4B27dql2267LfbxihUrJEkLFy7U+vXr9dhjj6mrq0sPPvigOjo6dPPNN6u2tlZXXXVV8rYGAKQ9n3POWS/xTZFIRMFgUOFwmOeD+kg0Gk1obtasWZ5nPv7444Ruy6sRI0YkNFdbW+t5ZsqUKQndFpCpLvXruPmr4AAAVyYCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCY8PzrGJB5Pv/884Tm9uzZk9xFkuj48eMJzW3YsKFPZqqqqjzPlJeXe57x+XyeZ4C+wiMgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMCEzznnrJf4pkgkomAwqHA4rEAgYL0OLuC3v/2t55klS5akYJP0k8g/uyeeeMLzzKBBgzzPSNJDDz3keWbYsGGeZxLZLyuL/2/u7y716zh/kwAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACd6MFAnr6OjwPLN582bPM2vXrvU8s2vXLs8zfSmRf3Y+ny8Fm9iqrq72PPOzn/3M80x+fr7nGSSONyMFAPRrBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJ3owU+IZE3sT0L3/5i+eZ7du3e57ZsmWL55lMlMiXrKVLlyZ0WzfffLPnmXvuuSeh28okvBkpAKBfI0AAABOeA7R9+3bNnTtXhYWF8vl85/x+l0WLFsnn88WdKisrk7UvACBDeA5QV1eXSkpKtGbNmvNep7KyUq2trbHTG2+8cVlLAgAyz0CvA1VVVaqqqrrgdfx+v0KhUMJLAQAyX0qeA6qvr1deXp7Gjx+vpUuX6ujRo+e9bjQaVSQSiTsBADJf0gNUWVmpDRs2qK6uTr/61a/U0NCgqqoqnTp1qtfr19TUKBgMxk5FRUXJXgkA0A95/hbcxdx1112xP0+aNEmTJ0/W2LFjVV9fr9mzZ59z/ZUrV2rFihWxjyORCBECgCtAyl+GPWbMGOXm5qq5ubnXy/1+vwKBQNwJAJD5Uh6gL774QkePHlVBQUGqbwoAkEY8fwvu2LFjcY9mWlpatGfPHuXk5CgnJ0fPPvusFixYoFAopAMHDuixxx7TuHHjVFFRkdTFAQDpzXOAdu3apdtuuy328ZnnbxYuXKi1a9dq7969+v3vf6+Ojg4VFhZqzpw5+sUvfiG/35+8rQEAaY83IwUMnO9VoRfy9ddfe575zW9+43lGSuxNWV9//fWEbqs/GzJkiOeZDRs2eJ6ZP3++55n+jDcjBQD0awQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADDBu2EDOEd3d7fnmRMnTnie+e9//+t5Zvbs2Z5nWlpaPM/0pUTeHb0/492wAQD9GgECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgYqD1AgD6n0GDBvXJTHt7u+eZcePGeZ7p729GeqXiERAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYII3I4Vqa2sTmvu///s/zzNTpkzxPPP88897nsFp+/fvT2iuu7vb88yLL77oeeaPf/yj55lwOOx5pi/dcMMN1iukDR4BAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmeDPSDHP8+HHPMw8//HBCt9Xc3Ox55pNPPvE8U1lZ6Xlm3LhxnmcS9ec//9nzzAcffOB5xufzeZ557733PM9IUjQaTWgu0wwbNszzzEcffZSCTTITj4AAACYIEADAhKcA1dTUaPr06crOzlZeXp7mzZunpqamuOucOHFC1dXVuuaaa3T11VdrwYIFam9vT+rSAID05ylADQ0Nqq6u1o4dO/T++++ru7tbc+bMUVdXV+w6jzzyiN577z29/fbbamho0OHDhzV//vykLw4ASG+eXoRw9m/OXL9+vfLy8rR7927NnDlT4XBYr732mjZu3Kjvf//7kqR169bp+uuv144dO/S9730veZsDANLaZT0HdOZX4+bk5EiSdu/ere7ubpWXl8euM2HCBI0aNUqNjY29fo5oNKpIJBJ3AgBkvoQD1NPTo+XLl+umm27SxIkTJUltbW0aPHiwhg8fHnfd/Px8tbW19fp5ampqFAwGY6eioqJEVwIApJGEA1RdXa19+/bpzTffvKwFVq5cqXA4HDsdOnTosj4fACA9JPSDqMuWLdPWrVu1fft2jRw5MnZ+KBTSyZMn1dHREfcoqL29XaFQqNfP5ff75ff7E1kDAJDGPD0Ccs5p2bJl2rRpk7Zt26bi4uK4y6dNm6ZBgwaprq4udl5TU5MOHjyosrKy5GwMAMgInh4BVVdXa+PGjdqyZYuys7Njz+sEg0ENGTJEwWBQDzzwgFasWKGcnBwFAgE99NBDKisr4xVwAIA4ngK0du1aSdKsWbPizl+3bp0WLVokSXrxxReVlZWlBQsWKBqNqqKiQr/+9a+TsiwAIHP4nHPOeolvikQiCgaDCofDCgQC1uuknb///e+eZ2699daEbos3rExcT0+P55msrMx756whQ4Z4npkwYYLnmdzcXM8zkvTcc895npkyZUpCt5VJLvXreObdowEAaYEAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmEvqNqOi/pk+f7nmmoqIiodvauXOn55n29vaEbguJGTZsWEJzibx79JNPPul5ZurUqZ5neLfpzMEjIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABG9GCm3atCmhuUgk4nlm7ty5nmf+8Y9/eJ5J1MKFCz3P3HrrrSnYJDnGjx+f0Nz111+f5E2Ac/EICABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwwZuRImGBQMDzTENDQwo2AZCOeAQEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATHgKUE1NjaZPn67s7Gzl5eVp3rx5ampqirvOrFmz5PP54k5LlixJ6tIAgPTnKUANDQ2qrq7Wjh079P7776u7u1tz5sxRV1dX3PUWL16s1tbW2Gn16tVJXRoAkP48/UbU2trauI/Xr1+vvLw87d69WzNnzoydP3ToUIVCoeRsCADISJf1HFA4HJYk5eTkxJ3/+uuvKzc3VxMnTtTKlSt1/Pjx836OaDSqSCQSdwIAZD5Pj4C+qaenR8uXL9dNN92kiRMnxs6/5557NHr0aBUWFmrv3r16/PHH1dTUpHfeeafXz1NTU6Nnn3020TUAAGnK55xziQwuXbpUf/rTn/TRRx9p5MiR573etm3bNHv2bDU3N2vs2LHnXB6NRhWNRmMfRyIRFRUVKRwOKxAIJLIaAMBQJBJRMBi86NfxhB4BLVu2TFu3btX27dsvGB9JKi0tlaTzBsjv98vv9yeyBgAgjXkKkHNODz30kDZt2qT6+noVFxdfdGbPnj2SpIKCgoQWBABkJk8Bqq6u1saNG7VlyxZlZ2erra1NkhQMBjVkyBAdOHBAGzdu1A9+8ANdc8012rt3rx555BHNnDlTkydPTsl/AAAgPXl6Dsjn8/V6/rp167Ro0SIdOnRI9913n/bt26euri4VFRXpjjvu0JNPPnnJz+dc6vcOAQD9U0qeA7pYq4qKitTQ0ODlUwIArlC8FxwAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwMRA6wXO5pyTJEUiEeNNAACJOPP1+8zX8/PpdwHq7OyUJBUVFRlvAgC4HJ2dnQoGg+e93Oculqg+1tPTo8OHDys7O1s+ny/uskgkoqKiIh06dEiBQMBoQ3sch9M4DqdxHE7jOJzWH46Dc06dnZ0qLCxUVtb5n+npd4+AsrKyNHLkyAteJxAIXNF3sDM4DqdxHE7jOJzGcTjN+jhc6JHPGbwIAQBgggABAEykVYD8fr9WrVolv99vvYopjsNpHIfTOA6ncRxOS6fj0O9ehAAAuDKk1SMgAEDmIEAAABMECABgggABAEykTYDWrFmj73znO7rqqqtUWlqqjz/+2HqlPvfMM8/I5/PFnSZMmGC9Vspt375dc+fOVWFhoXw+nzZv3hx3uXNOTz/9tAoKCjRkyBCVl5dr//79Nsum0MWOw6JFi865f1RWVtosmyI1NTWaPn26srOzlZeXp3nz5qmpqSnuOidOnFB1dbWuueYaXX311VqwYIHa29uNNk6NSzkOs2bNOuf+sGTJEqONe5cWAXrrrbe0YsUKrVq1Sp988olKSkpUUVGhI0eOWK/W52688Ua1trbGTh999JH1SinX1dWlkpISrVmzptfLV69erZdfflmvvvqqdu7cqWHDhqmiokInTpzo401T62LHQZIqKyvj7h9vvPFGH26Yeg0NDaqurtaOHTv0/vvvq7u7W3PmzFFXV1fsOo888ojee+89vf3222poaNDhw4c1f/58w62T71KOgyQtXrw47v6wevVqo43Pw6WBGTNmuOrq6tjHp06dcoWFha6mpsZwq763atUqV1JSYr2GKUlu06ZNsY97enpcKBRyzz//fOy8jo4O5/f73RtvvGGwYd84+zg459zChQvd7bffbrKPlSNHjjhJrqGhwTl3+u9+0KBB7u23345d55///KeT5BobG63WTLmzj4Nzzt16663u4YcftlvqEvT7R0AnT57U7t27VV5eHjsvKytL5eXlamxsNNzMxv79+1VYWKgxY8bo3nvv1cGDB61XMtXS0qK2tra4+0cwGFRpaekVef+or69XXl6exo8fr6VLl+ro0aPWK6VUOByWJOXk5EiSdu/ere7u7rj7w4QJEzRq1KiMvj+cfRzOeP3115Wbm6uJEydq5cqVOn78uMV659Xv3oz0bF9++aVOnTql/Pz8uPPz8/P1+eefG21lo7S0VOvXr9f48ePV2tqqZ599Vrfccov27dun7Oxs6/VMtLW1SVKv948zl10pKisrNX/+fBUXF+vAgQN64oknVFVVpcbGRg0YMMB6vaTr6enR8uXLddNNN2nixImSTt8fBg8erOHDh8ddN5PvD70dB0m65557NHr0aBUWFmrv3r16/PHH1dTUpHfeecdw23j9PkD4n6qqqtifJ0+erNLSUo0ePVp/+MMf9MADDxhuhv7grrvuiv150qRJmjx5ssaOHav6+nrNnj3bcLPUqK6u1r59+66I50Ev5HzH4cEHH4z9edKkSSooKNDs2bN14MABjR07tq/X7FW//xZcbm6uBgwYcM6rWNrb2xUKhYy26h+GDx+u6667Ts3NzdarmDlzH+D+ca4xY8YoNzc3I+8fy5Yt09atW/Xhhx/G/fqWUCikkydPqqOjI+76mXp/ON9x6E1paakk9av7Q78P0ODBgzVt2jTV1dXFzuvp6VFdXZ3KysoMN7N37NgxHThwQAUFBdarmCkuLlYoFIq7f0QiEe3cufOKv3988cUXOnr0aEbdP5xzWrZsmTZt2qRt27apuLg47vJp06Zp0KBBcfeHpqYmHTx4MKPuDxc7Dr3Zs2ePJPWv+4P1qyAuxZtvvun8fr9bv369++yzz9yDDz7ohg8f7tra2qxX61M/+clPXH19vWtpaXF//etfXXl5ucvNzXVHjhyxXi2lOjs73aeffuo+/fRTJ8m98MIL7tNPP3X//ve/nXPOPffcc2748OFuy5Ytbu/eve722293xcXF7quvvjLePLkudBw6Ozvdo48+6hobG11LS4v74IMP3NSpU921117rTpw4Yb160ixdutQFg0FXX1/vWltbY6fjx4/HrrNkyRI3atQot23bNrdr1y5XVlbmysrKDLdOvosdh+bmZvfzn//c7dq1y7W0tLgtW7a4MWPGuJkzZxpvHi8tAuScc6+88oobNWqUGzx4sJsxY4bbsWOH9Up97s4773QFBQVu8ODB7tvf/ra78847XXNzs/VaKffhhx86SeecFi5c6Jw7/VLsp556yuXn5zu/3+9mz57tmpqabJdOgQsdh+PHj7s5c+a4ESNGuEGDBrnRo0e7xYsXZ9z/pPX23y/JrVu3Lnadr776yv34xz923/rWt9zQoUPdHXfc4VpbW+2WToGLHYeDBw+6mTNnupycHOf3+924cePcT3/6UxcOh20XPwu/jgEAYKLfPwcEAMhMBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAICJ/wdNoTZjrQGkuQAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# print(f\"정답 > {y_test}\")\n", + "predicted_result = model.predict(x_test)\n", + "predicted_labels = np.argmax(predicted_result, axis=1)\n", + "not_correct = []\n", + "for i in range(len(y_test)):\n", + " if predicted_labels[i] != y_test[i]:\n", + " not_correct.append(i)\n", + " # print(f\"추론 > {predicted_labels[i]} | 정답 > {y_test[i]}\")\n", + " \n", + "print(f\"틀린 것 갯수 > {len(not_correct)}\")\n", + "for i in range(3):\n", + " plt.imshow(x_test[not_correct[i]].reshape(28,28), cmap='Greys')\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fc2d7044", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pso", + "language": "python", + "name": "pso" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.16" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": {}, + "version_major": 2, + "version_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pso.py b/pso.py new file mode 100644 index 0000000..b7042b2 --- /dev/null +++ b/pso.py @@ -0,0 +1,137 @@ +import numpy as np + +class PSO(object): + """ + Class implementing PSO algorithm + """ + + def __init__(self, func, init_pos, n_particles): + """ + Initialize the key variables. + + Args: + fun (function): the fitness function to optimize + init_pos(array_like): + n_particles(int): the number of particles of the swarm. + """ + self.func = func + self.n_particles = n_particles + self.init_pos = init_pos # 곰샥헐 차원 + self.particle_dim = len(init_pos) # 검색할 차원의 크기 + self.particles_pos = np.random.uniform(size=(n_particles, self.particle_dim)) \ + * self.init_pos + # 입력받은 파티클의 개수 * 검색할 차원의 크기 만큼의 균등한 위치를 생성 + self.velocities = np.random.uniform( + size=(n_particles, self.particle_dim)) + # 입력받은 파티클의 개수 * 검색할 차원의 크기 만큼의 속도를 무작위로 초기화 + self.g_best = init_pos # 최대 사이즈로 전역 최적갑 저장 - global best + self.p_best = self.particles_pos # 모든 파티클의 위치 - particles best + self.g_history = [] + self.history = [] + + def update_position(self, x, v): + """ + Update particle position + + Args: + x (array-like): particle current position + v (array-like): particle current velocity + + Returns: + The updated position(array-like) + """ + x = np.array(x) # 각 파티클의 위치 + v = np.array(v) # 각 파티클의 속도(방향과 속력을 가짐) + new_x = x + v # 각 파티클을 랜덤한 속도만큼 진행 + return new_x # 진행한 파티클들의 위치를 반환 + + def update_velocity(self, x, v, p_best, g_best, c0=0.5, c1=1.5, w=0.75): + """ + Update particle velocity + + Args: + x(array-like): particle current position + v (array-like): particle current velocity + p_best(array-like): the best position found so far for a particle + g_best(array-like): the best position regarding all the particles found so far + c0 (float): the congnitive scaling constant, 인지 스케일링 상수 + c1 (float): the social scaling constant + w (float): the inertia weight, 관성 중량 + + Returns: + The updated velocity (array-like). + """ + x = np.array(x) + v = np.array(v) + assert x.shape == v.shape, "Position and velocity must have same shape." + # 두 데이터의 shape 이 같지 않으면 오류 출력 + # 0에서 1사이의 숫자를 랜덤 생성 + r = np.random.uniform() + p_best = np.array(p_best) + g_best = np.array(g_best) + + # 가중치(상수)*속도 + \ + # 스케일링 상수*랜덤 가중치*(나의 최적값 - 처음 위치) + \ + # 전역 스케일링 상수*랜덤 가중치*(전체 최적값 - 처음 위치) + new_v = w*v + c0*r*(p_best - x) + c1*r*(g_best - x) + return new_v + + def optimize(self, maxiter=200): + """ + Run the PSO optimization process utill the stoping critera is met. + Cas for minization. The aim is to minimize the cost function + + Args: + maxiter (int): the maximum number of iterations before stopping the optimization + 파티클의 최종 위치를 위한 반복 횟수 + Returns: + The best solution found (array-like) + """ + for _ in range(maxiter): + for i in range(self.n_particles): + x = self.particles_pos[i] # 각 파티클 추출 + v = self.velocities[i] # 랜덤 생성한 속도 추출 + p_best = self.p_best[i] # 결과치 저장할 변수 지정 + self.velocities[i] = self.update_velocity( + x, v, p_best, self.g_best) + # 다음에 움직일 속도 = 최초 위치, 현재 속도, 현재 위치, 최종 위치 + self.particles_pos[i] = self.update_position(x, v) + # 현재 위치 = 최초 위치 현재 속도 + # Update the besst position for particle i + # 내 현재 위치가 내 위치의 최소치보다 작으면 갱신 + if self.func(self.particles_pos[i]) < self.func(p_best): + self.p_best[i] = self.particles_pos[i] + # Update the best position overall + # 내 현재 위치가 전체 위치 최소치보다 작으면 갱신 + if self.func(self.particles_pos[i]) < self.func(self.g_best): + self.g_best = self.particles_pos[i] + self.g_history.append(self.g_best) + + self.history.append(self.particles_pos.copy()) + + # 전체 최소 위치, 전체 최소 벡터 + return self.g_best, self.func(self.g_best) + + """ + Returns: + 현재 전체 위치 + """ + + def position(self): + return self.particles_pos.copy() + + """ + Returns: + 전체 위치 벡터 history + """ + + def position_history(self): + return self.history + + """ + Returns: + global best 의 갱신된 값의 변화를 반환 + """ + + def global_history(self): + return self.g_history.copy() diff --git a/pso_tf.py b/pso_tf.py new file mode 100644 index 0000000..f0d5f6a --- /dev/null +++ b/pso_tf.py @@ -0,0 +1,261 @@ +import numpy as np +import tensorflow as tf +from tensorflow import keras +from tqdm import tqdm + + +class PSO(object): + """ + Class implementing PSO algorithm + """ + + def __init__(self, model, loss_method=keras.losses.MeanSquaredError(), optimizer=keras.optimizers.SGD(), n_particles=5): + """ + Initialize the key variables. + + Args: + model : 학습할 모델 객체 (Sequential) + loss_method : 손실 함수 + optimizer : 최적화 함수 + n_particles(int) : 파티클의 개수 + """ + self.model = model # 모델 + self.n_particles = n_particles # 파티클의 개수 + self.loss_method = loss_method # 손실 함수 + self.optimizer = optimizer # 최적화 함수 + self.model_structure = self.model.to_json() # 모델의 구조 + self.init_weights = self.model.get_weights() # 검색할 차원 + self.particle_depth = len(self.model.get_weights()) # 검색할 차원의 깊이 + self.particles_weights = [None] * n_particles # 파티클의 위치 + for _ in tqdm(range(self.n_particles), desc="init particles position"): + # particle_node = [] + m = keras.models.model_from_json(self.model_structure) + m.compile(loss=self.loss_method, optimizer=self.optimizer) + + self.particles_weights[_] = m.get_weights() + + # self.particles_weights.append(particle_node) + + # print(f"particles_weights > {self.particles_weights}") + # self.particles_weights = np.random.uniform(size=(n_particles, self.particle_depth)) \ + # * self.init_pos + # 입력받은 파티클의 개수 * 검색할 차원의 크기 만큼의 균등한 위치를 생성 + # self.velocities = [None] * self.n_particles + self.velocities = [ + [0 for i in range(self.particle_depth)] for n in range(n_particles)] + for i in tqdm(range(n_particles), desc="init velocities"): + # print(i) + for index, layer in enumerate(self.init_weights): + # print(f"index > {index}") + # print(f"layer > {layer.shape}") + self.velocities[i][index] = np.random.rand( + *layer.shape) / 5 - 0.10 + # if layer.ndim == 1: + # self.velocities[i][index] = np.random.uniform( + # size=(layer.shape[0],)) + # elif layer.ndim == 2: + # self.velocities[i][index] = np.random.uniform( + # size=(layer.shape[0], layer.shape[1])) + # elif layer.ndim == 3: + # self.velocities[i][index] = np.random.uniform( + # size=(layer.shape[0], layer.shape[1], layer.shape[2])) + # print(f"type > {type(self.velocities)}") + # print(f"velocities > {self.velocities}") + + # print(f"velocities > {self.velocities}") + # for i, layer in enumerate(self.init_weights): + # self.velocities[i] = np.random.rand(*layer.shape) / 5 - 0.10 + + # self.velocities = np.random.uniform( + # size=(n_particles, self.particle_depth)) + # 입력받은 파티클의 개수 * 검색할 차원의 크기 만큼의 속도를 무작위로 초기화 + # 최대 사이즈로 전역 최적갑 저장 - global best + self.g_best = self.model.get_weights() # 전역 최적값(최적의 가중치) + self.p_best = self.particles_weights # 각 파티클의 최적값(최적의 가중치) + self.p_best_score = [np.inf for i in range( + n_particles)] # 각 파티클의 최적값의 점수 + self.g_best_score = np.inf # 전역 최적값의 점수(초기화 - 무한대) + self.g_history = [] + self.history = [] + + def _update_weights(self, weights, v): + """ + Update particle position + + Args: + weights (array-like) : 파티클의 현재 가중치 + v (array-like) : 가중치의 속도 + + Returns: + (array-like) : 파티클의 새로운 가중치(위치) + """ + # w = np.array(w) # 각 파티클의 위치 + # v = np.array(v) # 각 파티클의 속도(방향과 속력을 가짐) + # print(f"len(w) > {len(w)}") + # print(f"len(v) > {len(v)}") + new_weights = [0 for i in range(len(weights))] + for i in range(len(weights)): + # print(f"shape > w : {np.shape(w[i])}, v : {np.shape(v[i])}") + new_weights[i] = tf.add(weights[i], v[i]) + # new_w = tf.add(w, v) # 각 파티클을 랜덤한 속도만큼 진행 + return new_weights # 진행한 파티클들의 위치를 반환 + + def _update_velocity(self, weights, v, p_best, c0=0.5, c1=1.5, w=0.75): + """ + Update particle velocity + + Args: + weights (array-like) : 파티클의 현재 가중치 + v (array-like) : 속도 + p_best(array-like) : 각 파티클의 최적의 위치 (최적의 가중치) + c0 (float) : 인지 스케일링 상수 (가중치의 중요도 - 지역) - 지역 관성 + c1 (float) : 사회 스케일링 상수 (가중치의 중요도 - 전역) - 전역 관성 + w (float) : 관성 상수 (현재 속도의 중요도) + + Returns: + (array-like) : 각 파티클의 새로운 속도 + """ + # x = np.array(x) + # v = np.array(v) + # assert np.shape(weights) == np.shape(v), "Position and velocity must have same shape." + # 두 데이터의 shape 이 같지 않으면 오류 출력 + # 0에서 1사이의 숫자를 랜덤 생성 + r0 = np.random.rand() + r1 = np.random.rand() + # print(f"type > weights : {type(weights)}") + # print(f"type > v : {type(v)}") + # print( + # f"shape > weights : {np.shape(weights[0])}, v : {np.shape(v[0])}") + # print(f"len > weights : {len(weights)}, v : {len(v)}") + # p_best = np.array(p_best) + # g_best = np.array(g_best) + + # 가중치(상수)*속도 + \ + # 스케일링 상수*랜덤 가중치*(나의 최적값 - 처음 위치) + \ + # 전역 스케일링 상수*랜덤 가중치*(전체 최적값 - 처음 위치) + # for i, layer in enumerate(weights): + new_velocity = [None] * len(weights) + for i, layer in enumerate(weights): + + new_v = w*v[i] + new_v = new_v + c0*r0*(p_best[i] - layer) + # m2 = tf.multiply(tf.multiply(c0, r0), + # tf.subtract(p_best[i], layer)) + new_v = new_v + c1*r1*(self.g_best[i] - layer) + # m3 = tf.multiply(tf.multiply(c1, r1), + # tf.subtract(g_best[i], layer)) + new_velocity[i] = new_v + # new_v[i] = tf.add(m1, tf.add(m2, m3)) + # new_v[i] = tf.add_n([m1, m2, m3]) + # new_v[i] = tf.add_n( + # tf.multiply(w, v[i]), + # tf.multiply(tf.multiply(c0, r0), + # tf.subtract(p_best[i], layer)), + # tf.multiply(tf.multiply(c1, r1), + # tf.subtract(g_best[i], layer))) + # new_v = w*v + c0*r0*(p_best - weights) + c1*r1*(g_best - weights) + return new_velocity + + def _get_score(self, x, y): + """ + Compute the score of the current position of the particles. + + Args: + x (array-like): The current position of the particles + y (array-like): The current position of the particles + Returns: + (array-like) : 추론에 대한 점수 + """ + # = self.model + # model.set_weights(weights) + score = self.model.evaluate(x, y, verbose=0) + + return score + + def optimize(self, x_train, y_train, x_test, y_test, maxiter=20, epoch=10, verbose=0): + """ + Run the PSO optimization process utill the stoping critera is met. + Cas for minization. The aim is to minimize the cost function + + Args: + maxiter (int): the maximum number of iterations before stopping the optimization + 파티클의 최종 위치를 위한 반복 횟수 + Returns: + The best solution found (array-like) + """ + for _ in range(maxiter): + loss = 0 + acc = 0 + for i in tqdm(range(self.n_particles), desc=f"Iteration {_} / {maxiter}", ascii=True): + weights = self.particles_weights[i] # 각 파티클 추출 + v = self.velocities[i] # 각 파티클의 다음 속도 추출 + p_best = self.p_best[i] # 결과치 저장할 변수 지정 + self.velocities[i] = self._update_velocity( + weights, v, p_best) + # 다음에 움직일 속도 = 최초 위치, 현재 속도, 현재 위치, 최종 위치 + self.particles_weights[i] = self._update_weights(weights, v) + # 현재 위치 = 최초 위치 현재 속도 + # Update the besst position for particle i + # 내 현재 위치가 내 위치의 최소치보다 작으면 갱신 + + self.model.set_weights(self.particles_weights[i]) + self.model.fit(x_train, y_train, epochs=epoch, + verbose=0, validation_data=(x_test, y_test)) + self.particles_weights[i] = self.model.get_weights() + + score = self._get_score(x_test, y_test) + + # print(f"score : {score}") + # print(f"loss : {loss}") + # print(f"p_best_score : {self.p_best_score[i]}") + + if score[0] < self.p_best_score[i]: + self.p_best_score[i] = score[0] + self.p_best[i] = self.particles_weights[i] + if score[0] < self.g_best_score: + self.g_best_score = score[0] + self.g_best = self.particles_weights[i].copy() + self.g_history.append(self.g_best.copy()) + + self.score = score[0] + loss = score[0] + acc = score[1] + # if self.func(self.particles_weights[i]) < self.func(p_best): + # self.p_best[i] = self.particles_weights[i] + # if self. + # Update the best position overall + # 내 현재 위치가 전체 위치 최소치보다 작으면 갱신 + # if self.func(self.particles_weights[i]) < self.func(self.g_best): + # self.g_best = self.particles_weights[i] + # self.g_history.append(self.g_best) + # print(f"{i} particle score : {score[0]}") + print(f"loss : {loss} | acc : {acc}") + + # self.history.append(self.particles_weights.copy()) + + # 전체 최소 위치, 전체 최소 벡터 + return self.g_best, self._get_score(x_test, y_test) + + """ + Returns: + 현재 전체 위치 + """ + + def position(self): + return self.particles_weights.copy() + + """ + Returns: + 전체 위치 벡터 history + """ + + def position_history(self): + return self.history.copy() + + """ + Returns: + global best 의 갱신된 값의 변화를 반환 + """ + + def global_history(self): + return self.g_history.copy() diff --git a/readme.md b/readme.md new file mode 100644 index 0000000..03093ee --- /dev/null +++ b/readme.md @@ -0,0 +1,53 @@ +# PSO 알고리즘 구현 및 새로운 시도 + +pso 알고리즘을 사용하여 새로운 학습 방법을 찾는중 입니다 +병렬처리로 사용하는 논문을 찾아보았지만 이보다 더 좋은 방법이 있을 것 같아서 찾아보고 있습니다 - A Distribute Deep Learning System Using PSO Algorithm.pdf + +기본 pso 알고리즘의 수식은 다음과 같습니다 + +> $$V_{id(t+1)} = W_{V_id(t)} + c_1 * r_1 (p_{id(t)} - x_{id(t)}) + c_2 * r_2(p_{gd(t)} - x_{id(t)})$$ + +다음 속도을 구하는 수식입니다 + +> $$x_{id(t+1)} = x_{id(t)} + V_{id(t+1)}$$ + +다음 위치를 구하는 수식입니다 + +> $$p_{id(t+1)} = \begin{cases} x_{id(t+1)} & \text{if } f(x_{id(t+1)}) < f(p_{id(t)}) \\ p_{id(t)} & \text{otherwise} \end{cases}$$ + +### 위치를 가장 최적값으로 변경(덮어쓰기)하면 안되는 이유 + +위치를 가장 최적값으로 변경하면 지역 최적값에서 벗어나지 못합니다. 따라서 전역 최적값을 찾을 수 없습니다. + +# 현재 진행 상황 + +## 1. PSO 알고리즘 구현 + +```plain text +pso.py # PSO 알고리즘 구현 +pso_tf.py # tensorflow 모델을 이용가능한 PSO 알고리즘 구현 + +xor.ipynb # xor 문제를 pso 알고리즘으로 풀이 +mnist.ipynb # mnist 문제를 pso 알고리즘으로 풀이 +``` + +## 2. PSO 알고리즘을 이용한 최적화 문제 풀이 + +pso 알고리즘을 이용하여 오차역전파 함수를 최적화 하는 방법을 찾는 중입니다 + +### 임시 아이디어 + +1. 오차역전파 함수를 1~5회 실행하여 오차를 구합니다 +2. 오차가 가장 적은 다른 노드(particle) 가중치로 유도합니다. + + 2-1. 만약 오차가 가장 작은 다른 노드가 현재 노드보다 오차가 크다면, 현재 노드의 가중치를 유지합니다. - 현재의 가중치를 최적값으로 업로드합니다 + + 2-2. 지역 최적값을 찾았다면, 전역 최적값을 찾을 때까지 1~2 과정을 반복합니다 + +3. 전역 최적값이 특정 임계치에서 변화율이 적다면 학습을 종료합니다 + +### 개인적인 생각 + +> 머신러닝 분류 방식에 존재하는 random forest 방식을 이용하여, 오차역전파 함수를 최적화 하는 방법이 있을것 같습니다 +> +> > pso 와 random forest 방식이 매우 유사하다고 생각하여 학습할 때 뿐만 아니라 예측 할 때도 이러한 방식으로 사용할 수 있을 것 같습니다 diff --git a/xor.ipynb b/xor.ipynb new file mode 100644 index 0000000..948dd9a --- /dev/null +++ b/xor.ipynb @@ -0,0 +1,431 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-05-21 01:52:28.471404: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2.10.0\n" + ] + } + ], + "source": [ + "import os\n", + "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'\n", + "import tensorflow as tf\n", + "# tf.random.set_seed(777) # for reproducibility\n", + "\n", + "from pso_tf import PSO\n", + "from tensorflow import keras\n", + "\n", + "print(tf.__version__)\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "\n", + "from tensorflow import keras\n", + "from tensorflow.keras.models import Sequential\n", + "from tensorflow.keras import layers\n", + "\n", + "import matplotlib.pyplot as plt\n", + "\n", + "def get_data():\n", + " x = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])\n", + " y = np.array([[0], [1], [1], [0]])\n", + " \n", + " return x, y\n", + "\n", + "def make_model():\n", + " leyer = []\n", + " leyer.append(layers.Dense(2, activation='sigmoid', input_shape=(2,)))\n", + " leyer.append(layers.Dense(1, activation='sigmoid'))\n", + "\n", + " model = Sequential(leyer)\n", + "\n", + " sgd = keras.optimizers.SGD(lr=0.1, momentum=1, decay=1e-05, nesterov=True)\n", + " adam = keras.optimizers.Adam(lr=0.1, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.)\n", + " model.compile(loss='mse', optimizer=sgd, metrics=['accuracy'])\n", + "\n", + " print(model.summary())\n", + "\n", + " return model" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: \"sequential\"\n", + "_________________________________________________________________\n", + " Layer (type) Output Shape Param # \n", + "=================================================================\n", + " dense (Dense) (None, 2) 6 \n", + " \n", + " dense_1 (Dense) (None, 1) 3 \n", + " \n", + "=================================================================\n", + "Total params: 9\n", + "Trainable params: 9\n", + "Non-trainable params: 0\n", + "_________________________________________________________________\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/pieroot/miniconda3/envs/pso/lib/python3.8/site-packages/keras/optimizers/optimizer_v2/gradient_descent.py:111: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.\n", + " super().__init__(name, **kwargs)\n", + "/home/pieroot/miniconda3/envs/pso/lib/python3.8/site-packages/keras/optimizers/optimizer_v2/adam.py:114: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.\n", + " super().__init__(name, **kwargs)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "None\n", + "0 particle score : 0.24921603500843048\n", + "1 particle score : 0.2509610056877136\n", + "2 particle score : 0.28712478280067444\n", + "3 particle score : 0.2665291726589203\n", + "4 particle score : 0.2513682246208191\n", + "0 particle score : 0.26079031825065613\n", + "1 particle score : 0.24931921064853668\n", + "2 particle score : 0.2679133415222168\n", + "3 particle score : 0.27925199270248413\n", + "4 particle score : 0.2605195641517639\n", + "0 particle score : 0.30758577585220337\n", + "1 particle score : 0.26747316122055054\n", + "2 particle score : 0.36957648396492004\n", + "3 particle score : 0.19372068345546722\n", + "4 particle score : 0.3671383857727051\n", + "0 particle score : 0.24090810120105743\n", + "1 particle score : 0.3176509141921997\n", + "2 particle score : 0.23225924372673035\n", + "3 particle score : 0.37263113260269165\n", + "4 particle score : 0.47822105884552\n", + "0 particle score : 0.37611791491508484\n", + "1 particle score : 0.27166277170181274\n", + "2 particle score : 0.21416285634040833\n", + "3 particle score : 0.23324625194072723\n", + "4 particle score : 0.024583835154771805\n", + "0 particle score : 0.05194556713104248\n", + "1 particle score : 0.3102635443210602\n", + "2 particle score : 0.31894028186798096\n", + "3 particle score : 0.12679985165596008\n", + "4 particle score : 0.012038745917379856\n", + "0 particle score : 0.004551469348371029\n", + "1 particle score : 0.03923884406685829\n", + "2 particle score : 0.003701586974784732\n", + "3 particle score : 0.0026527238078415394\n", + "4 particle score : 0.0430503748357296\n", + "0 particle score : 0.000214503234019503\n", + "1 particle score : 0.0025649480521678925\n", + "2 particle score : 0.008843829855322838\n", + "3 particle score : 0.23036976158618927\n", + "4 particle score : 0.21686825156211853\n", + "0 particle score : 4.901693273495766e-07\n", + "1 particle score : 0.003860481781885028\n", + "2 particle score : 0.00047884139348752797\n", + "3 particle score : 0.1563722789287567\n", + "4 particle score : 1.1759411222556082e-07\n", + "0 particle score : 0.24969959259033203\n", + "1 particle score : 2.8646991268033162e-05\n", + "2 particle score : 1.0552450024903237e-09\n", + "3 particle score : 3.566808572941227e-07\n", + "4 particle score : 8.882947003831243e-14\n", + "0 particle score : 0.2497878521680832\n", + "1 particle score : 1.879385969766334e-12\n", + "2 particle score : 0.44945281744003296\n", + "3 particle score : 2.485284791549705e-14\n", + "4 particle score : 2.431787924306583e-26\n", + "0 particle score : 3.854774978241029e-18\n", + "1 particle score : 1.2515056546646974e-08\n", + "2 particle score : 0.49999988079071045\n", + "3 particle score : 2.8881452344524906e-22\n", + "4 particle score : 4.162688806996304e-30\n", + "0 particle score : 9.170106118851775e-37\n", + "1 particle score : 0.49933725595474243\n", + "2 particle score : 0.43209874629974365\n", + "3 particle score : 7.681456478781658e-30\n", + "4 particle score : 1.1656278206215614e-33\n", + "0 particle score : 0.0\n", + "1 particle score : 0.49545660614967346\n", + "2 particle score : 0.25\n", + "3 particle score : 0.0\n", + "4 particle score : 0.0\n", + "0 particle score : 0.0\n", + "1 particle score : 0.25\n", + "2 particle score : 0.25\n", + "3 particle score : 0.25\n", + "4 particle score : 0.25\n", + "0 particle score : 0.0\n", + "1 particle score : 0.0\n", + "2 particle score : 0.25\n", + "3 particle score : 0.25\n", + "4 particle score : 0.25\n", + "0 particle score : 0.0\n", + "1 particle score : 0.25\n", + "2 particle score : 0.25\n", + "3 particle score : 0.25\n", + "4 particle score : 0.25\n", + "0 particle score : 0.0\n", + "1 particle score : 0.0\n", + "2 particle score : 0.0\n", + "3 particle score : 0.0\n", + "4 particle score : 0.5\n", + "0 particle score : 1.2923532081356227e-22\n", + "1 particle score : 0.5\n", + "2 particle score : 0.25\n", + "3 particle score : 0.942779541015625\n", + "4 particle score : 0.5\n", + "0 particle score : 0.4959273338317871\n", + "1 particle score : 0.5\n", + "2 particle score : 0.5\n", + "3 particle score : 0.75\n", + "4 particle score : 0.75\n", + "0 particle score : 0.23154164850711823\n", + "1 particle score : 0.5\n", + "2 particle score : 0.5\n", + "3 particle score : 0.5\n", + "4 particle score : 0.5\n", + "0 particle score : 0.0\n", + "1 particle score : 0.5\n", + "2 particle score : 0.25\n", + "3 particle score : 0.0\n", + "4 particle score : 0.25\n", + "0 particle score : 0.0\n", + "1 particle score : 0.5\n", + "2 particle score : 0.25\n", + "3 particle score : 0.25\n", + "4 particle score : 0.25\n", + "0 particle score : 0.25\n", + "1 particle score : 0.25\n", + "2 particle score : 0.25\n", + "3 particle score : 0.25\n", + "4 particle score : 0.25\n", + "0 particle score : 0.0\n", + "1 particle score : 0.25\n", + "2 particle score : 0.25\n", + "3 particle score : 0.25\n", + "4 particle score : 0.25\n", + "0 particle score : 0.5760642290115356\n", + "1 particle score : 0.25\n", + "2 particle score : 0.25000467896461487\n", + "3 particle score : 0.5\n", + "4 particle score : 0.5\n", + "0 particle score : 0.5\n", + "1 particle score : 0.25\n", + "2 particle score : 0.4998854398727417\n", + "3 particle score : 0.5\n", + "4 particle score : 0.5\n", + "0 particle score : 0.5\n", + "1 particle score : 0.25\n", + "2 particle score : 0.5000014305114746\n", + "3 particle score : 0.5\n", + "4 particle score : 0.5\n", + "0 particle score : 0.5\n", + "1 particle score : 0.007790721021592617\n", + "2 particle score : 0.5\n", + "3 particle score : 0.75\n", + "4 particle score : 0.5\n", + "0 particle score : 0.25\n", + "1 particle score : 0.5\n", + "2 particle score : 0.5\n", + "3 particle score : 0.0\n", + "4 particle score : 0.5\n", + "1/1 [==============================] - 0s 40ms/step\n", + "[[1.8555788e-26]\n", + " [1.0000000e+00]\n", + " [1.0000000e+00]\n", + " [1.8555788e-26]]\n", + "[[0]\n", + " [1]\n", + " [1]\n", + " [0]]\n", + "history > [[array([[-0.9191145, -0.7256227],\n", + " [ 1.2947526, 1.0081983]], dtype=float32), array([ 0.01203067, -0.07866445], dtype=float32), array([[-0.72274315],\n", + " [ 0.88691926]], dtype=float32), array([-0.08449478], dtype=float32)], [array([[-0.7327981, -2.120965 ],\n", + " [ 3.5870228, 2.0618958]], dtype=float32), array([-0.06788628, -2.1460009 ], dtype=float32), array([[-1.8084345],\n", + " [ 3.2274616]], dtype=float32), array([0.40823892], dtype=float32)], [array([[-6.749437, -5.01979 ],\n", + " [ 9.477569, 9.011221]], dtype=float32), array([ 1.0140182, -5.089527 ], dtype=float32), array([[-5.9527373],\n", + " [ 8.538484 ]], dtype=float32), array([0.8423419], dtype=float32)], [array([[-4.4376955, -7.542317 ],\n", + " [13.042126 , 9.401183 ]], dtype=float32), array([ 1.7249748, -6.2829194], dtype=float32), array([[-4.6019 ],\n", + " [12.654787]], dtype=float32), array([2.11288], dtype=float32)], [array([[-7.9655757, -8.855807 ],\n", + " [14.27012 , 12.6986265]], dtype=float32), array([ 2.2102568, -7.4656196], dtype=float32), array([[-5.386531],\n", + " [16.770058]], dtype=float32), array([2.2161639], dtype=float32)], [array([[-10.937471, -9.346545],\n", + " [ 15.040345, 13.547635]], dtype=float32), array([ 3.7305086, -8.93729 ], dtype=float32), array([[-9.661456],\n", + " [14.314214]], dtype=float32), array([1.9838718], dtype=float32)], [array([[-7.618989 , -8.295806 ],\n", + " [ 9.591193 , 7.3881774]], dtype=float32), array([ 2.9443424, -6.85388 ], dtype=float32), array([[-6.120155],\n", + " [ 9.558391]], dtype=float32), array([2.900807], dtype=float32)], [array([[-12.431582, -14.683373],\n", + " [ 24.192898, 18.607504]], dtype=float32), array([ 4.375762, -11.899742], dtype=float32), array([[-11.140665],\n", + " [ 25.361753]], dtype=float32), array([3.5045836], dtype=float32)], [array([[-16.167437, -19.325432],\n", + " [ 25.197618, 15.928284]], dtype=float32), array([ 6.8536587, -14.406519 ], dtype=float32), array([[-16.149462],\n", + " [ 21.955147]], dtype=float32), array([6.5853295], dtype=float32)], [array([[-20.64401 , -25.207134],\n", + " [ 28.023142, 19.938404]], dtype=float32), array([ 7.2551775, -17.74039 ], dtype=float32), array([[-15.623163],\n", + " [ 30.90391 ]], dtype=float32), array([7.7026973], dtype=float32)], [array([[-27.585245, -28.003128],\n", + " [ 46.606903, 34.010803]], dtype=float32), array([ 9.391173, -25.379646], dtype=float32), array([[-27.2021 ],\n", + " [ 44.79025]], dtype=float32), array([9.642486], dtype=float32)], [array([[-44.09209, -37.20285],\n", + " [ 47.20231, 40.34598]], dtype=float32), array([ 13.101824, -25.8866 ], dtype=float32), array([[-33.470924],\n", + " [ 47.784706]], dtype=float32), array([14.320648], dtype=float32)], [array([[-36.38443 , -39.23304 ],\n", + " [ 52.953644, 38.646732]], dtype=float32), array([ 10.276208, -30.864595], dtype=float32), array([[-31.08338],\n", + " [ 52.16088]], dtype=float32), array([15.342434], dtype=float32)], [array([[-62.84543 , -47.409748],\n", + " [ 63.300335, 56.867214]], dtype=float32), array([ 17.78217, -33.01626], dtype=float32), array([[-48.512455],\n", + " [ 61.87751 ]], dtype=float32), array([19.369736], dtype=float32)], [array([[-71.16499 , -57.702408],\n", + " [ 80.223915, 69.13328 ]], dtype=float32), array([ 19.08833 , -41.566013], dtype=float32), array([[-57.950104],\n", + " [ 76.35351 ]], dtype=float32), array([24.470982], dtype=float32)], [array([[-120.93972, -92.38105],\n", + " [ 107.01377, 110.19025]], dtype=float32), array([ 28.39684 , -59.285316], dtype=float32), array([[-75.1781 ],\n", + " [129.59488]], dtype=float32), array([34.034805], dtype=float32)], [array([[-161.36476, -114.62916],\n", + " [ 142.47905, 152.3887 ]], dtype=float32), array([ 36.139404, -74.1054 ], dtype=float32), array([[-101.517525],\n", + " [ 171.30031 ]], dtype=float32), array([42.26851], dtype=float32)]]\n", + "score > [0.5, 0.5]\n" + ] + } + ], + "source": [ + "\n", + "x, y = get_data()\n", + "x_test = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])\n", + "y_test = np.array([[0], [1], [1], [0]])\n", + "\n", + "model = make_model()\n", + "\n", + "loss = keras.losses.MeanSquaredError()\n", + "optimizer = keras.optimizers.SGD(lr=0.1, momentum=1, decay=1e-05, nesterov=True)\n", + "\n", + "\n", + "\n", + "pso_xor = PSO(model=model, loss=loss, optimizer=optimizer, n_particles=5)\n", + "\n", + "best_weights, score = pso_xor.optimize(x, y, x_test, y_test, maxiter=30)\n", + "\n", + "model.set_weights(best_weights)\n", + "\n", + "y_pred = model.predict(x_test)\n", + "print(y_pred)\n", + "print(y_test)\n", + "\n", + "history = pso_xor.global_history()\n", + "\n", + "# print(f\"history > {history}\")\n", + "# print(f\"score > {score}\")\n", + "# plt.plot(history)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def test():\n", + " model = make_model()\n", + " x, y = get_data()\n", + " x_test = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])\n", + " y_test = np.array([[0], [1], [1], [0]])\n", + " \n", + " callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)\n", + "\n", + " hist = model.fit(x, y, epochs=50000, verbose=1, callbacks=[callback] , validation_data=(x_test, y_test))\n", + " y_pred=model.predict(x_test)\n", + " print(y_pred)\n", + " print(y_test)\n", + " \n", + " return hist" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def plot_history(history):\n", + " fig, loss_ax = plt.subplots()\n", + " acc_ax = loss_ax.twinx()\n", + "\n", + " loss_ax.plot(hist.history['loss'], 'y', label='train loss')\n", + " loss_ax.plot(hist.history['val_loss'], 'r', label='val loss')\n", + " loss_ax.set_xlabel('epoch')\n", + " loss_ax.set_ylabel('loss')\n", + " loss_ax.legend(loc='upper left')\n", + "\n", + " acc_ax.plot(hist.history['accuracy'], 'b', label='train acc')\n", + " acc_ax.plot(hist.history['val_accuracy'], 'g', label='val acc')\n", + " acc_ax.set_ylabel('accuracy')\n", + " acc_ax.legend(loc='upper right')\n", + "\n", + " plt.show()\n", + "hist = test()\n", + "plot_history(hist)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "print(model(x).numpy())\n", + "print(y)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pso", + "language": "python", + "name": "pso" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.16" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": {}, + "version_major": 2, + "version_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}