Files
PSO/mnist.ipynb
jung-geun 7c5f3a53a3 23-05-21
pso 기본 알고리즘을 이용한 tensorflow model의 pso 알고리즘화 - xor 문제, mnist 분류 성공
2023-05-21 14:00:03 +09:00

365 lines
20 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "8a637c69-9071-4012-ac1e-93037548b3e9",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-05-21 03:38:18.127052: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.10.0\n",
"[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]\n"
]
}
],
"source": [
"import os\n",
"os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'\n",
"\n",
"from pso_tf import PSO\n",
"\n",
"import numpy as np\n",
"\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"tf.random.set_seed(777) # for reproducibility\n",
"\n",
"from keras.datasets import mnist\n",
"from keras.models import Sequential\n",
"from keras.layers import Dense, Dropout, Flatten\n",
"from keras.layers import Conv2D, MaxPooling2D\n",
"from keras import backend as K\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from datetime import date\n",
"\n",
"print(tf.__version__)\n",
"print(tf.config.list_physical_devices())\n",
"\n",
"def get_data():\n",
"    \"\"\"Load the MNIST splits, rescale the images and add a channel axis.\n",
"\n",
"    Prints the shape of the first image and first label of each split.\n",
"\n",
"    Returns:\n",
"        Tuple ``(x_train, y_train, x_test, y_test)``. The image arrays are\n",
"        divided by 255.0 and reshaped to ``(n_samples, 28, 28, 1)``; the label\n",
"        arrays are returned exactly as ``mnist.load_data()`` provides them.\n",
"    \"\"\"\n",
"    (x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
"\n",
"    x_train, x_test = x_train / 255.0, x_test / 255.0\n",
"    # -1 lets NumPy infer the sample count instead of hard-coding 60000 / 10000;\n",
"    # the trailing 1 is the channel axis expected by input_shape=(28,28,1).\n",
"    x_train = x_train.reshape((-1, 28, 28, 1))\n",
"    x_test = x_test.reshape((-1, 28, 28, 1))\n",
"\n",
"    print(f\"x_train : {x_train[0].shape} | y_train : {y_train[0].shape}\")\n",
"    print(f\"x_test : {x_test[0].shape} | y_test : {y_test[0].shape}\")\n",
"    return x_train, y_train, x_test, y_test\n",
"\n",
"def make_model():\n",
"    \"\"\"Build, compile and summarise the CNN classifier used in this notebook.\n",
"\n",
"    The network is two Conv2D/MaxPooling2D stages, a Dropout, then a Flatten\n",
"    feeding a 128-unit ReLU Dense layer and a 10-unit softmax output.\n",
"\n",
"    Returns:\n",
"        The compiled Keras ``Sequential`` model.\n",
"    \"\"\"\n",
"    stack = [\n",
"        Conv2D(32, kernel_size=(5, 5), activation='relu', input_shape=(28,28,1)),\n",
"        MaxPooling2D(pool_size=(3, 3)),\n",
"        Conv2D(64, kernel_size=(3, 3), activation='relu'),\n",
"        MaxPooling2D(pool_size=(2, 2)),\n",
"        Dropout(0.25),\n",
"        Flatten(),\n",
"        Dense(128, activation='relu'),\n",
"        Dense(10, activation='softmax'),\n",
"    ]\n",
"    model = Sequential(stack)\n",
"\n",
"    model.compile(\n",
"        loss='sparse_categorical_crossentropy',\n",
"        optimizer='adam',\n",
"        metrics=['accuracy'],\n",
"    )\n",
"\n",
"    # Print the layer table so the architecture is visible in the cell output.\n",
"    model.summary()\n",
"\n",
"    return model"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a2d9891d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"x_train : (28, 28, 1) | y_train : ()\n",
"x_test : (28, 28, 1) | y_test : ()\n",
"Model: \"sequential\"\n",
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" conv2d (Conv2D) (None, 24, 24, 32) 832 \n",
" \n",
" max_pooling2d (MaxPooling2D (None, 8, 8, 32) 0 \n",
" ) \n",
" \n",
" conv2d_1 (Conv2D) (None, 6, 6, 64) 18496 \n",
" \n",
" max_pooling2d_1 (MaxPooling (None, 3, 3, 64) 0 \n",
" 2D) \n",
" \n",
" dropout (Dropout) (None, 3, 3, 64) 0 \n",
" \n",
" flatten (Flatten) (None, 576) 0 \n",
" \n",
" dense (Dense) (None, 128) 73856 \n",
" \n",
" dense_1 (Dense) (None, 10) 1290 \n",
" \n",
"=================================================================\n",
"Total params: 94,474\n",
"Trainable params: 94,474\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/pieroot/miniconda3/envs/pso/lib/python3.8/site-packages/keras/optimizers/optimizer_v2/gradient_descent.py:111: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.\n",
" super().__init__(name, **kwargs)\n",
"init particles position: 100%|██████████| 2/2 [00:00<00:00, 34.52it/s]\n",
"init velocities: 100%|██████████| 2/2 [00:00<00:00, 1203.19it/s]\n",
"Iteration 0 / 10: 100%|##########| 2/2 [00:48<00:00, 24.43s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss : 0.021180160343647003 | acc : 0.9930999875068665\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Iteration 1 / 10: 100%|##########| 2/2 [00:46<00:00, 23.46s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss : 0.029995786026120186 | acc : 0.9927999973297119\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Iteration 2 / 10: 100%|##########| 2/2 [00:47<00:00, 23.57s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss : 0.13020965456962585 | acc : 0.9929999709129333\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Iteration 3 / 10: 100%|##########| 2/2 [00:47<00:00, 23.62s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss : 0.032199352979660034 | acc : 0.9918000102043152\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Iteration 4 / 10: 100%|##########| 2/2 [00:47<00:00, 23.66s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss : 0.025606701150536537 | acc : 0.9925000071525574\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Iteration 5 / 10: 100%|##########| 2/2 [00:47<00:00, 23.64s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss : 0.04198306426405907 | acc : 0.9921000003814697\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Iteration 6 / 10: 100%|##########| 2/2 [00:47<00:00, 23.69s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss : 0.048351287841796875 | acc : 0.9919999837875366\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Iteration 7 / 10: 100%|##########| 2/2 [00:47<00:00, 23.73s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss : 0.0416271910071373 | acc : 0.9890999794006348\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Iteration 8 / 10: 100%|##########| 2/2 [00:47<00:00, 23.70s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss : 0.18129077553749084 | acc : 0.9502000212669373\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Iteration 9 / 10: 100%|##########| 2/2 [00:47<00:00, 23.69s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss : 0.8072962760925293 | acc : 0.7225000262260437\n",
"313/313 - 0s - loss: 0.0212 - accuracy: 0.9931 - 290ms/epoch - 926us/step\n",
"Test loss: 0.021180160343647003 / Test accuracy: 0.9930999875068665\n"
]
}
],
"source": [
"# Load the data and build the model whose weights PSO will search over.\n",
"x_train, y_train, x_test, y_test = get_data()\n",
"model = make_model()\n",
"\n",
"# NOTE(review): the model is compiled with sparse_categorical_crossentropy, but\n",
"# PSO is handed MeanSquaredError here - confirm which loss pso_tf.PSO really\n",
"# evaluates against the integer labels.\n",
"loss = keras.losses.MeanSquaredError()\n",
"# `learning_rate` replaces the deprecated `lr` keyword (Keras warns on `lr`).\n",
"# NOTE(review): momentum=1 looks unusual for SGD - confirm it is intended and\n",
"# whether pso_tf.PSO uses this optimizer at all.\n",
"optimizer = keras.optimizers.SGD(learning_rate=0.1, momentum=1, decay=1e-05, nesterov=True)\n",
"\n",
"pso_m = PSO(model=model, loss_method=loss, optimizer=optimizer, n_particles=2)\n",
"# Keep the optimiser's own score under a separate name so it is not silently\n",
"# overwritten by the evaluation result below.\n",
"best_weights, pso_score = pso_m.optimize(x_train, y_train, x_test, y_test, maxiter=10)\n",
"\n",
"# Load the best weights found by the swarm before evaluating and saving.\n",
"model.set_weights(best_weights)\n",
"\n",
"# score is [loss, accuracy], following the loss/metrics given to model.compile().\n",
"score = model.evaluate(x_test, y_test, verbose=2)\n",
"print(f\"Test loss: {score[0]} / Test accuracy: {score[1]}\")\n",
"\n",
"# Date-stamp the saved model file, e.g. ./model/2023-05-21_mnist.h5\n",
"day = date.today().strftime(\"%Y-%m-%d\")\n",
"\n",
"model.save(f'./model/{day}_mnist.h5')\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "1a38f3c1-8291-40d9-838e-4ffbf4578be5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"313/313 [==============================] - 0s 685us/step\n",
"틀린 것 갯수 > 69\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAbtUlEQVR4nO3df2zV9fXH8dctP66AvZfV0t52FFZAQQVKYNA1KuJoaLuEifCHPxMwBgcrRmROg1PRbUkd5qtGw3BZHIxE1JkJKMm6aLFlbgUHShhxNpR0A0NbJkvvLUUulb6/fxDuvFB+fC739vReno/kJvTee3qPHy59entvb33OOScAAPpYlvUCAIArEwECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmBlovcLaenh4dPnxY2dnZ8vl81usAADxyzqmzs1OFhYXKyjr/45x+F6DDhw+rqKjIeg0AwGU6dOiQRo4ced7L+12AsrOzJZ1ePBAIGG8DAPAqEomoqKgo9vX8fFIWoDVr1uj5559XW1ubSkpK9Morr2jGjBkXnTvzbbdAIECAACCNXexplJS8COGtt97SihUrtGrVKn3yyScqKSlRRUWFjhw5koqbAwCkoZQE6IUXXtDixYt1//3364YbbtCrr76qoUOH6ne/+10qbg4AkIaSHqCTJ09q9+7dKi8v/9+NZGWpvLxcjY2N51w/Go0qEonEnQAAmS/pAfryyy916tQp5efnx52fn5+vtra2c65fU1OjYDAYO/EKOAC4Mpj/IOrKlSsVDodjp0OHDlmvBADoA0l/FVxubq4GDBig9vb2uPPb29sVCoXOub7f75ff70/2GgCAfi7pj4AGDx6sadOmqa6uLnZeT0+P6urqVFZWluybAwCkqZT8HNCKFSu0cOFCffe739WMGTP00ksvqaurS/fff38qbg4AkIZSEqA777xT//nPf/T000+rra1NU6ZMUW1t7TkvTAAAXLl8zjlnvcQ3RSIRBYNBhcNh3gkBANLQpX4dN38VHADgykSAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYGGi9wJXitdde8zzj8/k8z0yZMsXzzNSpUz3PAMDl4hEQAMAEAQIAmEh6gJ555hn5fL6404QJE5J9MwCANJeS54BuvPFGffDBB/+7kYE81QQAiJeSMgwcOFChUCgVnxoAkCFS8hzQ/v37VVhYqDFjxujee+/VwYMHz3vdaDSqSCQSdwIAZL6kB6i0tFTr169XbW2t1q5dq5aWFt1yyy3q7Ozs9fo1NTUKBoOxU1FRUbJXAgD0Qz7nnEvlDXR0dGj06NF64YUX9MADD5xzeTQaVTQajX0ciURUVFSkcDisQCCQytX6FD8HBOBKEYlEFAwGL/p1POWvDhg+fLiuu+46NTc393q53++X3+9P9RoAgH4m5T8HdOzYMR04cEAFBQWpvikAQBpJeoAeffRRNTQ06F//+pf+9re/6Y477tCAAQN09913J/umAABpLOnfgvviiy9099136+jRoxoxYoRuvvlm7dixQyNGjEj2TQEA0ljKX4Tg1aU+eZVusrK8P9hM5EUIifzQ79ChQz3P4PIk8s9u7dq1nmf68vnVd9991/PMj370I88zZWVlnmfQty716zjvBQcAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgA
AAJggQAMAEAQIAmEj5L6RD3/r66689z0QikRRsggtJ5M1I77vvvhRsYqutrc3zTG1tbQo2gQUeAQEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAE74bdR7Zs2eJ55t13303BJslz5MgRzzNbt25NwSYA0hGPgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEz7nnLNe4psikYiCwaDC4bACgYD1OriAzz77zPPMpEmTUrBJ+hk7dqznmalTp6Zgk94l8veUl5fneeaHP/yh55n8/HzPM+hbl/p1nEdAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAICJgdYLIH21trZar5B0AwYM8Dzz0ksveZ656667PM/k5OR4ngH6Mx4BAQBMECAAgAnPAdq+fbvmzp2rwsJC+Xw+bd68Oe5y55yefvppFRQUaMiQISovL9f+/fuTtS8AIEN4DlBXV5dKSkq0Zs2aXi9fvXq1Xn75Zb366qvauXOnhg0bpoqKCp04ceKylwUAZA7PL0KoqqpSVVVVr5c55/TSSy/pySef1O233y5J2rBhg/Lz87V58+aEnngFAGSmpD4H1NLSora2NpWXl8fOCwaDKi0tVWNjY68z0WhUkUgk7gQAyHxJDVBbW5ukc39ne35+fuyys9XU1CgYDMZORUVFyVwJANBPmb8KbuXKlQqHw7HToUOHrFcCAPSBpAYoFApJktrb2+POb29vj112Nr/fr0AgEHcCAGS+pAaouLhYoVBIdXV1sfMikYh27typsrKyZN4UACDNeX4V3LFjx9Tc3Bz7uKWlRXv27FFOTo5GjRql5cuX65e//KWuvfZaFRcX66mnnlJhYaHmzZuXzL0BAGnOc4B27dql2267LfbxihUrJEkLFy7U+vXr9dhjj6mrq0sPPvigOjo6dPPNN6u2tlZXXXVV8rYGAKQ9n3POWS/xTZFIRMFgUOFwmOeD+kg0Gk1obtasWZ5nPv7444Ruy6sRI0YkNFdbW+t5ZsqUKQndFpCpLvXruPmr4AAAVyYCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCY8PzrGJB5Pv/884Tm9uzZk9xFkuj48eMJzW3YsKFPZqqqqjzPlJeXe57x+XyeZ4C+wiMgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMCEzznnrJf4pkgkomAwqHA4rEAgYL0OLuC3v/2t55klS5akYJP0k8g/uyeeeMLzzKBBgzzPSNJDDz3keWbYsGGeZxLZLyuL/2/u7y716zh/kwAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACd6MFAnr6OjwPLN582bPM2vXrvU8s2vXLs8zfSmRf3Y+ny8Fm9iqrq72PPOzn/3M80x+fr7nGSSONyMFAPRrBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJ3owU+IZE3sT0L3/5i+eZ7du3e57ZsmWL55lMlMiXrKVLlyZ0WzfffLPnmXvuuSeh28okvBkpAKBfI0AAABOeA7R9+3bNnTtXhYWF8vl85/x+l0WLFsnn88WdKisrk7UvACBDeA5QV1eXSkpKtGbNmvNep7KyUq2trbHTG2+8cVlLAgAyz0CvA1VVVaqqqrrgdfx+v0KhUMJLAQAyX0qeA6qvr1deXp7Gjx+vpUuX6ujRo+e9bjQaVSQSiTsBADJf0gNUWVmpDRs2qK6uTr/61a/U0NCgqqoqnTp1qtfr19TUKBgMxk5FRU
XJXgkA0A95/hbcxdx1112xP0+aNEmTJ0/W2LFjVV9fr9mzZ59z/ZUrV2rFihWxjyORCBECgCtAyl+GPWbMGOXm5qq5ubnXy/1+vwKBQNwJAJD5Uh6gL774QkePHlVBQUGqbwoAkEY8fwvu2LFjcY9mWlpatGfPHuXk5CgnJ0fPPvusFixYoFAopAMHDuixxx7TuHHjVFFRkdTFAQDpzXOAdu3apdtuuy328ZnnbxYuXKi1a9dq7969+v3vf6+Ojg4VFhZqzpw5+sUvfiG/35+8rQEAaY83IwUMnO9VoRfy9ddfe575zW9+43lGSuxNWV9//fWEbqs/GzJkiOeZDRs2eJ6ZP3++55n+jDcjBQD0awQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADDBu2EDOEd3d7fnmRMnTnie+e9//+t5Zvbs2Z5nWlpaPM/0pUTeHb0/492wAQD9GgECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgYqD1AgD6n0GDBvXJTHt7u+eZcePGeZ7p729GeqXiERAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYII3I4Vqa2sTmvu///s/zzNTpkzxPPP88897nsFp+/fvT2iuu7vb88yLL77oeeaPf/yj55lwOOx5pi/dcMMN1iukDR4BAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmeDPSDHP8+HHPMw8//HBCt9Xc3Ox55pNPPvE8U1lZ6Xlm3LhxnmcS9ec//9nzzAcffOB5xufzeZ557733PM9IUjQaTWgu0wwbNszzzEcffZSCTTITj4AAACYIEADAhKcA1dTUaPr06crOzlZeXp7mzZunpqamuOucOHFC1dXVuuaaa3T11VdrwYIFam9vT+rSAID05ylADQ0Nqq6u1o4dO/T++++ru7tbc+bMUVdXV+w6jzzyiN577z29/fbbamho0OHDhzV//vykLw4ASG+eXoRw9m/OXL9+vfLy8rR7927NnDlT4XBYr732mjZu3Kjvf//7kqR169bp+uuv144dO/S9730veZsDANLaZT0HdOZX4+bk5EiSdu/ere7ubpWXl8euM2HCBI0aNUqNjY29fo5oNKpIJBJ3AgBkvoQD1NPTo+XLl+umm27SxIkTJUltbW0aPHiwhg8fHnfd/Px8tbW19fp5ampqFAwGY6eioqJEVwIApJGEA1RdXa19+/bpzTffvKwFVq5cqXA4HDsdOnTosj4fACA9JPSDqMuWLdPWrVu1fft2jRw5MnZ+KBTSyZMn1dHREfcoqL29XaFQqNfP5ff75ff7E1kDAJDGPD0Ccs5p2bJl2rRpk7Zt26bi4uK4y6dNm6ZBgwaprq4udl5TU5MOHjyosrKy5GwMAMgInh4BVVdXa+PGjdqyZYuys7Njz+sEg0ENGTJEwWBQDzzwgFasWKGcnBwFAgE99NBDKisr4xVwAIA4ngK0du1aSdKsWbPizl+3bp0WLVokSXrxxReVlZWlBQsWKBqNqqKiQr/+9a+TsiwAIHP4nHPOeolvikQiCgaDCofDCgQC1uuknb///e+eZ2699daEbos3rExcT0+P55msrMx756whQ4Z4npkwYYLnmdzcXM8zkvTcc895npkyZUpCt5VJLvXreObdowEAaYEAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmEvqNqOi/pk+f7nmmoqIiodvauXOn55n29vaEbguJGTZsWEJzibx79JNPPul5ZurUqZ5neLfpzMEjIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBA
AwQYAAACYIEADABG9GCm3atCmhuUgk4nlm7ty5nmf+8Y9/eJ5J1MKFCz3P3HrrrSnYJDnGjx+f0Nz111+f5E2Ac/EICABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwwZuRImGBQMDzTENDQwo2AZCOeAQEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATHgKUE1NjaZPn67s7Gzl5eVp3rx5ampqirvOrFmz5PP54k5LlixJ6tIAgPTnKUANDQ2qrq7Wjh079P7776u7u1tz5sxRV1dX3PUWL16s1tbW2Gn16tVJXRoAkP48/UbU2trauI/Xr1+vvLw87d69WzNnzoydP3ToUIVCoeRsCADISJf1HFA4HJYk5eTkxJ3/+uuvKzc3VxMnTtTKlSt1/Pjx836OaDSqSCQSdwIAZD5Pj4C+qaenR8uXL9dNN92kiRMnxs6/5557NHr0aBUWFmrv3r16/PHH1dTUpHfeeafXz1NTU6Nnn3020TUAAGnK55xziQwuXbpUf/rTn/TRRx9p5MiR573etm3bNHv2bDU3N2vs2LHnXB6NRhWNRmMfRyIRFRUVKRwOKxAIJLIaAMBQJBJRMBi86NfxhB4BLVu2TFu3btX27dsvGB9JKi0tlaTzBsjv98vv9yeyBgAgjXkKkHNODz30kDZt2qT6+noVFxdfdGbPnj2SpIKCgoQWBABkJk8Bqq6u1saNG7VlyxZlZ2erra1NkhQMBjVkyBAdOHBAGzdu1A9+8ANdc8012rt3rx555BHNnDlTkydPTsl/AAAgPXl6Dsjn8/V6/rp167Ro0SIdOnRI9913n/bt26euri4VFRXpjjvu0JNPPnnJz+dc6vcOAQD9U0qeA7pYq4qKitTQ0ODlUwIArlC8FxwAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwMRA6wXO5pyTJEUiEeNNAACJOPP1+8zX8/PpdwHq7OyUJBUVFRlvAgC4HJ2dnQoGg+e93Oculqg+1tPTo8OHDys7O1s+ny/uskgkoqKiIh06dEiBQMBoQ3sch9M4DqdxHE7jOJzWH46Dc06dnZ0qLCxUVtb5n+npd4+AsrKyNHLkyAteJxAIXNF3sDM4DqdxHE7jOJzGcTjN+jhc6JHPGbwIAQBgggABAEykVYD8fr9WrVolv99vvYopjsNpHIfTOA6ncRxOS6fj0O9ehAAAuDKk1SMgAEDmIEAAABMECABgggABAEykTYDWrFmj73znO7rqqqtUWlqqjz/+2HqlPvfMM8/I5/PFnSZMmGC9Vspt375dc+fOVWFhoXw+nzZv3hx3uXNOTz/9tAoKCjRkyBCVl5dr//79Nsum0MWOw6JFi865f1RWVtosmyI1NTWaPn26srOzlZeXp3nz5qmpqSnuOidOnFB1dbWuueYaXX311VqwYIHa29uNNk6NSzkOs2bNOuf+sGTJEqONe5cWAXrrrbe0YsUKrVq1Sp988olKSkpUUVGhI0eOWK/W52688Ua1trbGTh999JH1SinX1dWlkpISrVmzptfLV69erZdfflmvvvqqdu7cqWHDhqmiokInTpzo401T62LHQZIqKyvj7h9vvPFGH26Yeg0NDaqurtaOHTv0/vvvq7u7W3PmzFFXV1fsOo888ojee+89vf3222poaNDhw4c1f/58w62T71KOgyQtXrw47v6wevVqo43Pw6WBGTNmuOrq6tjHp06dcoWFha6mpsZwq763atUqV1JSYr2GKUlu06ZNsY97enpcKBRyzz//fOy8jo4O5/f73RtvvGGwYd
84+zg459zChQvd7bffbrKPlSNHjjhJrqGhwTl3+u9+0KBB7u23345d55///KeT5BobG63WTLmzj4Nzzt16663u4YcftlvqEvT7R0AnT57U7t27VV5eHjsvKytL5eXlamxsNNzMxv79+1VYWKgxY8bo3nvv1cGDB61XMtXS0qK2tra4+0cwGFRpaekVef+or69XXl6exo8fr6VLl+ro0aPWK6VUOByWJOXk5EiSdu/ere7u7rj7w4QJEzRq1KiMvj+cfRzOeP3115Wbm6uJEydq5cqVOn78uMV659Xv3oz0bF9++aVOnTql/Pz8uPPz8/P1+eefG21lo7S0VOvXr9f48ePV2tqqZ599Vrfccov27dun7Oxs6/VMtLW1SVKv948zl10pKisrNX/+fBUXF+vAgQN64oknVFVVpcbGRg0YMMB6vaTr6enR8uXLddNNN2nixImSTt8fBg8erOHDh8ddN5PvD70dB0m65557NHr0aBUWFmrv3r16/PHH1dTUpHfeecdw23j9PkD4n6qqqtifJ0+erNLSUo0ePVp/+MMf9MADDxhuhv7grrvuiv150qRJmjx5ssaOHav6+nrNnj3bcLPUqK6u1r59+66I50Ev5HzH4cEHH4z9edKkSSooKNDs2bN14MABjR07tq/X7FW//xZcbm6uBgwYcM6rWNrb2xUKhYy26h+GDx+u6667Ts3NzdarmDlzH+D+ca4xY8YoNzc3I+8fy5Yt09atW/Xhhx/G/fqWUCikkydPqqOjI+76mXp/ON9x6E1paakk9av7Q78P0ODBgzVt2jTV1dXFzuvp6VFdXZ3KysoMN7N37NgxHThwQAUFBdarmCkuLlYoFIq7f0QiEe3cufOKv3988cUXOnr0aEbdP5xzWrZsmTZt2qRt27apuLg47vJp06Zp0KBBcfeHpqYmHTx4MKPuDxc7Dr3Zs2ePJPWv+4P1qyAuxZtvvun8fr9bv369++yzz9yDDz7ohg8f7tra2qxX61M/+clPXH19vWtpaXF//etfXXl5ucvNzXVHjhyxXi2lOjs73aeffuo+/fRTJ8m98MIL7tNPP3X//ve/nXPOPffcc2748OFuy5Ytbu/eve722293xcXF7quvvjLePLkudBw6Ozvdo48+6hobG11LS4v74IMP3NSpU921117rTpw4Yb160ixdutQFg0FXX1/vWltbY6fjx4/HrrNkyRI3atQot23bNrdr1y5XVlbmysrKDLdOvosdh+bmZvfzn//c7dq1y7W0tLgtW7a4MWPGuJkzZxpvHi8tAuScc6+88oobNWqUGzx4sJsxY4bbsWOH9Up97s4773QFBQVu8ODB7tvf/ra78847XXNzs/VaKffhhx86SeecFi5c6Jw7/VLsp556yuXn5zu/3+9mz57tmpqabJdOgQsdh+PHj7s5c+a4ESNGuEGDBrnRo0e7xYsXZ9z/pPX23y/JrVu3Lnadr776yv34xz923/rWt9zQoUPdHXfc4VpbW+2WToGLHYeDBw+6mTNnupycHOf3+924cePcT3/6UxcOh20XPwu/jgEAYKLfPwcEAMhMBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAICJ/wdNoTZjrQGkuQAAAABJRU5ErkJggg==",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Predict a class for every test image (argmax over the 10 softmax outputs).\n",
"predicted_result = model.predict(x_test)\n",
"predicted_labels = np.argmax(predicted_result, axis=1)\n",
"\n",
"# Indices of the test samples whose predicted class differs from the label.\n",
"not_correct = np.flatnonzero(predicted_labels != y_test)\n",
"print(f\"틀린 것 갯수 > {len(not_correct)}\")\n",
"\n",
"# Show up to three misclassified digits side by side. Each image gets its own\n",
"# axes: repeated plt.imshow() calls on one axes leave only the last one visible.\n",
"# Capping at len(not_correct) avoids an IndexError when there are fewer than 3.\n",
"n_show = min(3, len(not_correct))\n",
"if n_show > 0:\n",
"    fig, axes = plt.subplots(1, n_show, figsize=(3 * n_show, 3), squeeze=False)\n",
"    for ax, idx in zip(axes[0], not_correct[:n_show]):\n",
"        ax.imshow(x_test[idx].reshape(28,28), cmap='Greys')\n",
"        ax.set_title(f\"pred {predicted_labels[idx]} / true {y_test[idx]}\")\n",
"        ax.axis('off')\n",
"    plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fc2d7044",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "pso",
"language": "python",
"name": "pso"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"state": {},
"version_major": 2,
"version_minor": 0
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}