{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "8a637c69-9071-4012-ac1e-93037548b3e9", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2023-05-24 15:37:52.889357: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "2.10.0\n", "[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]\n" ] } ], "source": [
"import os\n",
"os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'\n",
"\n",
"import tensorflow as tf\n",
"# tf.random.set_seed(777) # for reproducibility\n",
"\n",
"from tensorflow import keras\n",
"from keras.datasets import mnist\n",
"from keras.models import Sequential\n",
"from keras.layers import Dense, Dropout, Flatten\n",
"from keras.layers import Conv2D, MaxPooling2D\n",
"from keras import backend as K\n",
"\n",
"from pso_tf import PSO\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from datetime import date\n",
"from tqdm import tqdm\n",
"import json\n",
"\n",
"print(tf.__version__)\n",
"print(tf.config.list_physical_devices())\n",
"\n",
"def get_data():\n",
"    \"\"\"Load MNIST, scale pixels to [0, 1] and add a trailing channel axis.\n",
"\n",
"    Returns:\n",
"        (x_train, y_train, x_test, y_test). Images have shape\n",
"        (n_samples, 28, 28, 1); labels are left as returned by\n",
"        `mnist.load_data()` (one integer class id per sample, not one-hot).\n",
"    \"\"\"\n",
"    (x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
"\n",
"    x_train, x_test = x_train / 255.0, x_test / 255.0\n",
"    # -1 lets numpy infer the sample count instead of hard-coding 60000/10000.\n",
"    x_train = x_train.reshape((-1, 28, 28, 1))\n",
"    x_test = x_test.reshape((-1, 28, 28, 1))\n",
"\n",
"    print(f\"x_train : {x_train[0].shape} | y_train : {y_train[0].shape}\")\n",
"    print(f\"x_test : {x_test[0].shape} | y_test : {y_test[0].shape}\")\n",
"    return x_train, y_train, x_test, y_test\n",
"\n",
"def make_model():\n",
"    \"\"\"Build the CNN classifier used in this notebook (returned uncompiled).\n",
"\n",
"    Input is a (28, 28, 1) image; the output layer is a 10-way softmax.\n",
"    Loss, optimizer and metrics are chosen by the caller when compiling.\n",
"    \"\"\"\n",
"    model = Sequential()\n",
"    model.add(Conv2D(32, kernel_size=(5, 5), activation='relu', input_shape=(28,28,1)))\n",
"    model.add(MaxPooling2D(pool_size=(3, 3)))\n",
"    model.add(Conv2D(64, kernel_size=(3, 3), activation='relu'))\n",
"    model.add(MaxPooling2D(pool_size=(2, 2)))\n",
"    model.add(Dropout(0.25))\n",
"    model.add(Flatten())\n",
"    model.add(Dense(128, activation='relu'))\n",
"    model.add(Dense(10, activation='softmax'))\n",
"\n",
"    # model.summary()\n",
"\n",
"    return model" ] }, { "cell_type": "code", "execution_count": 2, "id": "a2d9891d", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "x_train : (28, 28, 1) | y_train : ()\n", "x_test : (28, 28, 1) | y_test : ()\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "init particles position: 100%|██████████| 30/30 [00:00<00:00, 36.65it/s]\n", "init velocities: 100%|██████████| 30/30 [00:00<00:00, 681.12it/s]\n", "Iter 0/20: 13%|#3 | 4/30 [00:04<00:20, 1.28it/s]" ] } ], "source": [
"# Optimizer parameters (kept for reference; not referenced elsewhere in this notebook).\n",
"lr = 0.1\n",
"momentum = 0.8\n",
"decay = 1e-04\n",
"nesterov = True\n",
"\n",
"# PSO parameters\n",
"n_particles = 30\n",
"maxiter = 20\n",
"# epochs = 1\n",
"w = 0.8   # inertia: how strongly a particle keeps its current velocity\n",
"c0 = 0.6  # weight of each particle's own (local) best position\n",
"c1 = 1.6  # weight of the swarm's (global) best position\n",
"\n",
"\n",
"x_train, y_train, x_test, y_test = get_data()\n",
"model = make_model()\n",
"\n",
"# Labels are integer class ids and the model ends in a 10-way softmax, so\n",
"# sparse categorical cross-entropy is the matching loss; MSE between a scalar\n",
"# class id and a probability vector is not a meaningful classification loss.\n",
"# NOTE(review): assumes PSO passes `loss_method` through to model.compile.\n",
"loss = keras.losses.SparseCategoricalCrossentropy()\n",
"\n",
"\n",
"pso_m = PSO(model=model, loss_method=loss, n_particles=n_particles)\n",
"# Keep the optimizer's own score under its own name; `score` is reused below\n",
"# for the rounded test accuracy that goes into the output file names.\n",
"best_weights, best_score = pso_m.optimize(x_train, y_train, x_test, y_test, maxiter=maxiter, c0=c0, c1=c1, w=w)\n",
"model.set_weights(best_weights)\n",
"\n",
"score_ = model.evaluate(x_test, y_test, verbose=2)\n",
"print(f\" Test [loss, accuracy]: {score_}\")\n",
"score = round(score_[1]*100, 2)\n",
"\n",
"day = date.today().strftime(\"%Y-%m-%d\")\n",
"\n",
"os.makedirs('./model', exist_ok=True)\n",
"model.save(f'./model/{day}_{score}_mnist.h5')\n",
"json_save = {\n",
"    \"name\" : f\"{day}_{score}_mnist.h5\",\n",
"    \"score\" : score_,\n",
"    \"maxiter\" : maxiter,\n",
"    \"c0\" : c0,\n",
"    \"c1\" : c1,\n",
"    \"w\" : w\n",
"}\n",
"with 
open(f'./model/{day}_{score}_mnist.json', 'a') as f:\n", " json.dump(json_save, f)\n", " f.write(',\n')\n", "\n", "\n", "# auto_tuning(n_particles=30, maxiter=1000, c0=0.5, c1=1.5, w=0.75)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "1af7569b", "metadata": {}, "outputs": [], "source": [
"# Plot the loss and accuracy histories returned by pso_m.all_history();\n",
"# every element of each history is drawn as its own line.\n",
"loss_, acc_ = pso_m.all_history()\n",
"\n",
"fig, (ax_loss, ax_acc) = plt.subplots(2, 1)\n",
"\n",
"for layer in loss_:\n",
"    ax_loss.plot(layer)\n",
"ax_loss.set_title('loss history')\n",
"\n",
"for layer in acc_:\n",
"    ax_acc.plot(layer)\n",
"ax_acc.set_title('acc history')\n",
"\n",
"fig.tight_layout()\n",
"plt.show()" ] }, { "cell_type": "code", "execution_count": 3, "id": "1a38f3c1-8291-40d9-838e-4ffbf4578be5", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "x_train : (28, 28, 1) | y_train : ()\n", "x_test : (28, 28, 1) | y_test : ()\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/pieroot/miniconda3/envs/pso/lib/python3.8/site-packages/keras/optimizers/optimizer_v2/gradient_descent.py:111: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.\n", " super().__init__(name, **kwargs)\n", "init particles position: 100%|██████████| 30/30 [00:00<00:00, 36.95it/s]\n", "init velocities: 100%|██████████| 30/30 [00:00<00:00, 1399.35it/s]\n", "Iter 0/50: 100%|##########| 30/30 [00:15<00:00, 1.98it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.9084339777628581 | acc avg : 0.0019799999892711638 | best loss : 0.15219999849796295\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 1/50: 100%|##########| 30/30 [00:11<00:00, 2.54it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.9090563456217448 | acc avg : 0.0031199999153614043 | best loss : 0.20149999856948853\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 2/50: 100%|##########| 30/30 [00:11<00:00, 2.59it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 
0.9103448867797852 | acc avg : 0.005286666750907898 | best loss : 0.20149999856948853\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 3/50: 100%|##########| 30/30 [00:11<00:00, 2.55it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.9113266626993816 | acc avg : 0.004926666617393494 | best loss : 0.20149999856948853\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 4/50: 100%|##########| 30/30 [00:11<00:00, 2.54it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.9113243738810222 | acc avg : 0.004126666734615962 | best loss : 0.20149999856948853\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 5/50: 100%|##########| 30/30 [00:11<00:00, 2.56it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.9113284428914388 | acc avg : 0.002809999883174896 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 6/50: 100%|##########| 30/30 [00:11<00:00, 2.51it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.9113288243611654 | acc avg : 0.0034666667381922406 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 7/50: 100%|##########| 30/30 [00:11<00:00, 2.60it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.9113253911336263 | acc avg : 0.0029633333285649615 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 8/50: 100%|##########| 30/30 [00:11<00:00, 2.57it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.9113227208455403 | acc avg : 0.002809999883174896 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 9/50: 100%|##########| 30/30 [00:11<00:00, 2.53it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 
0.911251958211263 | acc avg : 0.005486666659514109 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 10/50: 100%|##########| 30/30 [00:11<00:00, 2.57it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.9113367716471354 | acc avg : 0.004316666722297668 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 11/50: 100%|##########| 30/30 [00:12<00:00, 2.47it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.9113431294759115 | acc avg : 0.002943333238363266 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 12/50: 100%|##########| 30/30 [00:11<00:00, 2.57it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.9113429387410482 | acc avg : 0.004413333535194397 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 13/50: 100%|##########| 30/30 [00:11<00:00, 2.58it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.9113424936930339 | acc avg : 0.004670000076293946 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 14/50: 100%|##########| 30/30 [00:11<00:00, 2.55it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.9113433202107747 | acc avg : 0.0024433332184950513 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 15/50: 100%|##########| 30/30 [00:12<00:00, 2.48it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.9113349914550781 | acc avg : 0.0030966666837533314 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 16/50: 100%|##########| 30/30 [00:11<00:00, 2.57it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 
0.911343765258789 | acc avg : 0.002956666549046834 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 17/50: 100%|##########| 30/30 [00:11<00:00, 2.56it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.002806666741768519 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 18/50: 100%|##########| 30/30 [00:11<00:00, 2.55it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.002503333240747452 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 19/50: 100%|##########| 30/30 [00:12<00:00, 2.47it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.9113413492838541 | acc avg : 0.003179999937613805 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 20/50: 100%|##########| 30/30 [00:11<00:00, 2.60it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.004823333521684011 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 21/50: 100%|##########| 30/30 [00:11<00:00, 2.60it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.003663333257039388 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 22/50: 100%|##########| 30/30 [00:11<00:00, 2.54it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.002916666616996129 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 23/50: 100%|##########| 30/30 [00:11<00:00, 2.55it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 
| acc avg : 0.0026966666181882223 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 24/50: 100%|##########| 30/30 [00:12<00:00, 2.43it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.0028999999165534975 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 25/50: 100%|##########| 30/30 [00:11<00:00, 2.56it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.0028833332161108654 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 26/50: 100%|##########| 30/30 [00:11<00:00, 2.55it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.0027433333297570547 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 27/50: 100%|##########| 30/30 [00:11<00:00, 2.59it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.0024033332864443462 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 28/50: 100%|##########| 30/30 [00:11<00:00, 2.60it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.004453333218892416 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 29/50: 100%|##########| 30/30 [00:11<00:00, 2.56it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.00338333323597908 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 30/50: 100%|##########| 30/30 [00:12<00:00, 2.44it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 
0.0028333333631356556 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 31/50: 100%|##########| 30/30 [00:11<00:00, 2.58it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.002480000009139379 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 32/50: 100%|##########| 30/30 [00:11<00:00, 2.58it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.0030733334521452584 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 33/50: 100%|##########| 30/30 [00:11<00:00, 2.58it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.0028366667528947195 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 34/50: 100%|##########| 30/30 [00:11<00:00, 2.58it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.002760000030199687 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 35/50: 100%|##########| 30/30 [00:11<00:00, 2.57it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.002463333308696747 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 36/50: 100%|##########| 30/30 [00:11<00:00, 2.55it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.004286666711171468 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 37/50: 100%|##########| 30/30 [00:12<00:00, 2.39it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 
0.003916666656732559 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 38/50: 100%|##########| 30/30 [00:11<00:00, 2.59it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.0037066665788491565 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 39/50: 100%|##########| 30/30 [00:11<00:00, 2.60it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.003233333428700765 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 40/50: 100%|##########| 30/30 [00:11<00:00, 2.60it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.0020900001128514607 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 41/50: 100%|##########| 30/30 [00:11<00:00, 2.60it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.002956666549046834 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 42/50: 100%|##########| 30/30 [00:11<00:00, 2.60it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.0028566665947437286 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 43/50: 100%|##########| 30/30 [00:11<00:00, 2.57it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.0026866666972637176 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 44/50: 100%|##########| 30/30 [00:11<00:00, 2.58it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 
0.0045466666420300806 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 45/50: 100%|##########| 30/30 [00:12<00:00, 2.36it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.004050000011920929 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 46/50: 100%|##########| 30/30 [00:11<00:00, 2.58it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.0037399999797344207 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 47/50: 100%|##########| 30/30 [00:11<00:00, 2.60it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.00264999990661939 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 48/50: 100%|##########| 30/30 [00:11<00:00, 2.61it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.0029666667183240254 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 49/50: 100%|##########| 30/30 [00:11<00:00, 2.56it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.002906666696071625 | best loss : 0.20180000364780426\n", "313/313 - 0s - loss: 27.3092 - accuracy: 0.2018 - 247ms/epoch - 788us/step\n", " Test loss: [27.309202194213867, 0.20180000364780426]\n", "x_train : (28, 28, 1) | y_train : ()\n", "x_test : (28, 28, 1) | y_test : ()\n", "313/313 [==============================] - 0s 691us/step\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "진행도: 100%|██████████| 10000/10000 [00:00<00:00, 2226867.00it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "틀린 갯수 > 7982/10000\n" ] }, { "name": "stderr", 
"output_type": "stream", "text": [ "\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAa90lEQVR4nO3df2zU9R3H8deB9EBtj5XaXjsKK4igAl3GoDYKA2mAmhh+/QHqEjAEIytm0DlNnQL+SLpB5hys0y3ZQBcBdRGIJGKw2BK3FgJKCJnrKOukhLYoSXul0IL0sz8Itx20wve467s9no/km9C776fft1/PPv1yxxefc84JAIAe1s96AADAzYkAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAE7dYD3Clzs5OnTx5UsnJyfL5fNbjAAA8cs6ptbVVWVlZ6tev++ucXhegkydPKjs723oMAMANqq+v19ChQ7t9vtcFKDk5WdKlwVNSUoynAQB4FQqFlJ2dHf553p24BaisrEzr1q1TY2OjcnNztWHDBk2aNOma6y7/tltKSgoBAoA+7Fpvo8TlQwjvvPOOiouLtXr1an322WfKzc3VzJkzderUqXgcDgDQB8UlQK+++qqWLl2qxx9/XPfcc4/eeOMN3Xrrrfrzn/8cj8MBAPqgmAfo/PnzOnjwoAoKCv53kH79VFBQoKqqqqv27+joUCgUitgAAIkv5gH6+uuvdfHiRWVkZEQ8npGRocbGxqv2Ly0tVSAQCG98Ag4Abg7mfxC1pKRELS0t4a2+vt56JABAD4j5p+DS0tLUv39/NTU1RTze1NSkYDB41f5+v19+vz/WYwAAermYXwElJSVpwoQJKi8vDz/W2dmp8vJy5efnx/pwAIA+Ki5/Dqi4uFiLFi3SD3/4Q02aNEmvvfaa2tra9Pjjj8fjcACAPiguAVqwYIG++uorrVq1So2Njfr+97+vXbt2XfXBBADAzcvnnHPWQ/y/UCikQCCglpYW7oQAAH3Q9f4cN/8UHADg5kSAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYuMV6AOBa6uvrPa+ZNm1aVMeqra2Nah2ic+TIEc9rhg0b5nlNSkqK5zWIP66AAAAmCBAAwETMA7RmzRr5fL6IbcyYMbE+DACgj4vLe0D33nuvPv744/8d5BbeagIARIpLGW655RYFg8F4fGsAQIKIy3tAR48eVVZWlkaMGKHHHntMx48f73bfjo4OhUKhiA0AkPhiHqC8vDxt2rRJu3bt0uuvv666ujpNnjxZra2tXe5fWlqqQCAQ3rKzs2M9EgCgF/I551w8D9Dc3Kzhw4fr1Vdf1ZIlS656vqOjQx0dHeGvQ6GQsrOz1dLSwmf3IYk/B5TI+HNAiSkUCikQCFzz53jcPx0wePBg3XXXXd3+h+33++X3++M9BgCgl4n7nwM6c+aMjh07pszMzHgfCgDQh8Q8QE8//bQqKyv1n//8R3//+981d+5c9e/fX4888kisDwUA6MNi/ltwJ06c0COPPKLTp0/rjjvu0AMPPKDq6mrdcccdsT4UAKAPi3mAtm7dGutviZvc7t27Pa9pb2+PwySItb/+9a+e13z11Vee15SVlXleg/jjXnAAABMECABgggABAEwQI
ACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgIm4/4V0wP/r7Oz0vGbbtm1xmAS9weTJkz2v+cUvfuF5zfnz5z2vkaSkpKSo1uH6cAUEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAE9wNGz3qiy++8Lzmww8/9Lxm3bp1nteg5506dcrzmgMHDnhe880333heI3E37HjjCggAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMMHNSBG1hoYGz2sefPBBz2vuuecez2uKioo8r0HPe/fdd61HgCGugAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAE9yMFFF75ZVXPK9pbW31vGb//v2e1yQlJXlegxtz7tw5z2u2b9/ueU2/fvx/c6Lg3yQAwAQBAgCY8BygvXv36uGHH1ZWVpZ8Pt9Vl9DOOa1atUqZmZkaNGiQCgoKdPTo0VjNCwBIEJ4D1NbWptzcXJWVlXX5/Nq1a7V+/Xq98cYb2rdvn2677TbNnDlT7e3tNzwsACBxeP4QQmFhoQoLC7t8zjmn1157Tc8//7xmz54tSXrrrbeUkZGh7du3a+HChTc2LQAgYcT0PaC6ujo1NjaqoKAg/FggEFBeXp6qqqq6XNPR0aFQKBSxAQASX0wD1NjYKEnKyMiIeDwjIyP83JVKS0sVCATCW3Z2dixHAgD0UuafgispKVFLS0t4q6+vtx4JANADYhqgYDAoSWpqaop4vKmpKfzclfx+v1JSUiI2AEDii2mAcnJyFAwGVV5eHn4sFApp3759ys/Pj+WhAAB9nOdPwZ05c0a1tbXhr+vq6nTo0CGlpqZq2LBhWrFihV555RWNGjVKOTk5euGFF5SVlaU5c+bEcm4AQB/nOUAHDhzQtGnTwl8XFxdLkhYtWqRNmzbpmWeeUVtbm5544gk1NzfrgQce0K5duzRw4MDYTQ0A6PM8B2jq1KlyznX7vM/n00svvaSXXnrphgZDz6muro5q3dtvv+15zbhx4zyvGT58uOc16Hm//e1vPa+J5sai8+bN87zG7/d7XoP4M/8UHADg5kSAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATnu+GjcTz1ltvRbXuzJkzntc899xzUR0LPau5udnzmg0bNnhe079/f89rXn755R45DuKPKyAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQ3I00w7e3tntd89NFHcZika7Nnz+6xYyF6Gzdu9LymqanJ85oJEyZ4XjNmzBjPa9A7cQUEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJjgZqQJ5uLFi57XfPnll1Edq6ioKKp16P2OHj3aI8eZOHFijxwHvRNXQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACW5GmmCSkpI8r5k8eXJUx9q/f7/nNefOnfO8ZtCgQZ7X4JK2trao1v3hD3+I8SRdKygo6JHjoHfiCggAYIIAAQBMeA7Q3r179fDDDysrK0s+n0/bt2+PeH7x4sXy+XwR26xZs2I1LwAgQXgOUFtbm3Jzc1VWVtbtPrNmzVJDQ0N427Jlyw0NCQBIPJ4/hFBYWKjCwsJv3cfv9ysYDEY9FAAg8cXlPaCKigqlp6dr9OjRWrZsmU6fPt3tvh0dHQqFQhEbACDxxTxAs2bN0ltvvaXy8nL96le/UmVlpQoLC3Xx4sUu9y8tLVUgEAhv2dnZsR4JANALx
fzPAS1cuDD863Hjxmn8+PEaOXKkKioqNH369Kv2LykpUXFxcfjrUChEhADgJhD3j2GPGDFCaWlpqq2t7fJ5v9+vlJSUiA0AkPjiHqATJ07o9OnTyszMjPehAAB9iOffgjtz5kzE1UxdXZ0OHTqk1NRUpaam6sUXX9T8+fMVDAZ17NgxPfPMM7rzzjs1c+bMmA4OAOjbPAfowIEDmjZtWvjry+/fLFq0SK+//roOHz6sN998U83NzcrKytKMGTP08ssvy+/3x25qAECf5zlAU6dOlXOu2+c/+uijGxoIN2bAgAGe19x9991RHeuPf/yj5zVz5871vGb16tWe1/R2n332mec1//rXvzyv+fe//+15jST5fL6o1vXW46B34l5wAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMBHzv5Ibfc+aNWuiWvdtd0Xvzl/+8hfPayZPnux5TW+XkZHheU00d45uamryvKYnPfTQQ9YjwBBXQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACZ+L5o6ScRQKhRQIBNTS0qKUlBTrcRBjJ06c6JE1vd19993XI8cpLi6Oat369etjPEnXvvnmmx45DnrW9f4c5woIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBxi/UAuLkMHTq0R9bgklGjRlmP8K0aGho8r8nMzIzDJLDAFRAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIKbkQIJzDnXo+u84saiNzeugAAAJggQAMCEpwCVlpZq4sSJSk5OVnp6uubMmaOampqIfdrb21VUVKQhQ4bo9ttv1/z589XU1BTToQEAfZ+nAFVWVqqoqEjV1dXavXu3Lly4oBkzZqitrS28z8qVK/XBBx/ovffeU2VlpU6ePKl58+bFfHAAQN/m6UMIu3btivh606ZNSk9P18GDBzVlyhS1tLToT3/6kzZv3qwHH3xQkrRx40bdfffdqq6u1n333Re7yQEAfdoNvQfU0tIiSUpNTZUkHTx4UBcuXFBBQUF4nzFjxmjYsGGqqqrq8nt0dHQoFApFbACAxBd1gDo7O7VixQrdf//9Gjt2rCSpsbFRSUlJGjx4cMS+GRkZamxs7PL7lJaWKhAIhLfs7OxoRwIA9CFRB6ioqEhHjhzR1q1bb2iAkpIStbS0hLf6+vob+n4AgL4hqj+Iunz5cu3cuVN79+7V0KFDw48Hg0GdP39ezc3NEVdBTU1NCgaDXX4vv98vv98fzRgAgD7M0xWQc07Lly/Xtm3btGfPHuXk5EQ8P2HCBA0YMEDl5eXhx2pqanT8+HHl5+fHZmIAQELwdAVUVFSkzZs3a8eOHUpOTg6/rxMIBDRo0CAFAgEtWbJExcXFSk1NVUpKip566inl5+fzCTgAQARPAXr99dclSVOnTo14fOPGjVq8eLEk6Te/+Y369eun+fPnq6OjQzNnztTvf//7mAwLAEgcngJ0PTcoHDhwoMrKylRWVhb1UABiw+fz9eg6wAvuBQcAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATUf2NqAD6hnPnzvXYsQYNGtRjx0Ji4AoIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADDBzUiBBPbrX/86qnVDhgzxvOZ3v/tdVMfCzYsrIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABDcjBRJYQUFBVOtKSko8rxkzZkxUx8LNiysgAIAJAgQAM
EGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAENyMFEtibb75pPQLQLa6AAAAmCBAAwISnAJWWlmrixIlKTk5Wenq65syZo5qamoh9pk6dKp/PF7E9+eSTMR0aAND3eQpQZWWlioqKVF1drd27d+vChQuaMWOG2traIvZbunSpGhoawtvatWtjOjQAoO/z9CGEXbt2RXy9adMmpaen6+DBg5oyZUr48VtvvVXBYDA2EwIAEtINvQfU0tIiSUpNTY14/O2331ZaWprGjh2rkpISnT17ttvv0dHRoVAoFLEBABJf1B/D7uzs1IoVK3T//fdr7Nix4ccfffRRDR8+XFlZWTp8+LCeffZZ1dTU6P333+/y+5SWlurFF1+MdgwAQB/lc865aBYuW7ZMH374oT799FMNHTq02/327Nmj6dOnq7a2ViNHjrzq+Y6ODnV0dIS/DoVCys7OVktLi1JSUqIZDQBgKBQKKRAIXPPneFRXQMuXL9fOnTu1d+/eb42PJOXl5UlStwHy+/3y+/3RjAEA6MM8Bcg5p6eeekrbtm1TRUWFcnJyrrnm0KFDkqTMzMyoBgQAJCZPASoqKtLmzZu1Y8cOJScnq7GxUZIUCAQ0aNAgHTt2TJs3b9ZDDz2kIUOG6PDhw1q5cqWmTJmi8ePHx+UfAADQN3l6D8jn83X5+MaNG7V48WLV19frxz/+sY4cOaK2tjZlZ2dr7ty5ev7556/7/Zzr/b1DAEDvFJf3gK7VquzsbFVWVnr5lgCAmxT3ggMAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmLjFeoArOeckSaFQyHgSAEA0Lv/8vvzzvDu9LkCtra2SpOzsbONJAAA3orW1VYFAoNvnfe5aiephnZ2dOnnypJKTk+Xz+SKeC4VCys7OVn19vVJSUowmtMd5uITzcAnn4RLOwyW94Tw459Ta2qqsrCz169f9Oz297gqoX79+Gjp06Lfuk5KSclO/wC7jPFzCebiE83AJ5+ES6/PwbVc+l/EhBACACQIEADDRpwLk9/u1evVq+f1+61FMcR4u4Txcwnm4hPNwSV86D73uQwgAgJtDn7oCAgAkDgIEADBBgAAAJggQAMBEnwlQWVmZvve972ngwIHKy8vT/v37rUfqcWvWrJHP54vYxowZYz1W3O3du1cPP/ywsrKy5PP5tH379ojnnXNatWqVMjMzNWjQIBUUFOjo0aM2w8bRtc7D4sWLr3p9zJo1y2bYOCktLdXEiROVnJys9PR0zZkzRzU1NRH7tLe3q6ioSEOGDNHtt9+u+fPnq6mpyWji+Lie8zB16tSrXg9PPvmk0cRd6xMBeuedd1RcXKzVq1frs88+U25urmbOnKlTp05Zj9bj7r33XjU0NIS3Tz/91HqkuGtra1Nubq7Kysq6fH7t2rVav3693njjDe3bt0+33XabZs6cqfb29h6eNL6udR4kadasWRGvjy1btvTghPFXWVmpoqIiVVdXa/fu3bpw4YJmzJihtra28D4rV67UBx98oPfee0+VlZU6efKk5s2bZzh17F3PeZCkpUuXRrwe1q5dazRxN1wfMGnSJFdUVBT++uLFiy4rK8uVlpYaTtXzVq9e7XJzc63HMCXJbdu2Lfx1Z2enCwaDbt26deHHmpubnd/vd1u2bDGYsGdceR6cc27RokVu9uzZJvNYOXXqlJPkKisrnXOX/t0PGDDAvffee+F9vvjiCyfJVVVVWY0Zd1eeB+ec+9GPfuR++tOf2g11HXr9FdD58+d18OBBFRQUhB/r16+fCgoKVFVVZTiZjaNHjyorK0sjRozQY489puPHj1uPZKqurk6NjY0Rr49AIKC8vLyb8vVRU
VGh9PR0jR49WsuWLdPp06etR4qrlpYWSVJqaqok6eDBg7pw4ULE62HMmDEaNmxYQr8erjwPl7399ttKS0vT2LFjVVJSorNnz1qM161edzPSK3399de6ePGiMjIyIh7PyMjQP//5T6OpbOTl5WnTpk0aPXq0Ghoa9OKLL2ry5Mk6cuSIkpOTrccz0djYKEldvj4uP3ezmDVrlubNm6ecnBwdO3ZMzz33nAoLC1VVVaX+/ftbjxdznZ2dWrFihe6//36NHTtW0qXXQ1JSkgYPHhyxbyK/Hro6D5L06KOPavjw4crKytLhw4f17LPPqqamRu+//77htJF6fYDwP4WFheFfjx8/Xnl5eRo+fLjeffddLVmyxHAy9AYLFy4M/3rcuHEaP368Ro4cqYqKCk2fPt1wsvgoKirSkSNHbor3Qb9Nd+fhiSeeCP963LhxyszM1PTp03Xs2DGNHDmyp8fsUq//Lbi0tDT179//qk+xNDU1KRgMGk3VOwwePFh33XWXamtrrUcxc/k1wOvjaiNGjFBaWlpCvj6WL1+unTt36pNPPon461uCwaDOnz+v5ubmiP0T9fXQ3XnoSl5eniT1qtdDrw9QUlKSJkyYoPLy8vBjnZ2dKi8vV35+vuFk9s6cOaNjx44pMzPTehQzOTk5CgaDEa+PUCikffv23fSvjxMnTuj06dMJ9fpwzmn58uXatm2b9uzZo5ycnIjnJ0yYoAEDBkS8HmpqanT8+PGEej1c6zx05dChQ5LUu14P1p+CuB5bt251fr/fbdq0yf3jH/9wTzzxhBs8eLBrbGy0Hq1H/exnP3MVFRWurq7O/e1vf3MFBQUuLS3NnTp1ynq0uGptbXWff/65+/zzz50k9+qrr7rPP//cffnll8455375y1+6wYMHux07drjDhw+72bNnu5ycHHfu3DnjyWPr285Da2ure/rpp11VVZWrq6tzH3/8sfvBD37gRo0a5drb261Hj5lly5a5QCDgKioqXENDQ3g7e/ZseJ8nn3zSDRs2zO3Zs8cdOHDA5efnu/z8fMOpY+9a56G2tta99NJL7sCBA66urs7t2LHDjRgxwk2ZMsV48kh9IkDOObdhwwY3bNgwl5SU5CZNmuSqq6utR+pxCxYscJmZmS4pKcl997vfdQsWLHC1tbXWY8XdJ5984iRdtS1atMg5d+mj2C+88ILLyMhwfr/fTZ8+3dXU1NgOHQffdh7Onj3rZsyY4e644w43YMAAN3z4cLd06dKE+5+0rv75JbmNGzeG9zl37pz7yU9+4r7zne+4W2+91c2dO9c1NDTYDR0H1zoPx48fd1OmTHGpqanO7/e7O++80/385z93LS0ttoNfgb+OAQBgote/BwQASEwECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgIn/AtNbpDSoQnmvAAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"# Check the PSO-optimised model against the test set and inspect its mistakes.\n",
"predicted_result = model.predict(x_test)\n",
"predicted_labels = np.argmax(predicted_result, axis=1)\n",
"# Indices of misclassified test samples (vectorised comparison).\n",
"not_correct = np.flatnonzero(predicted_labels != y_test)\n",
"\n",
"print(f\"틀린 갯수 > {len(not_correct)}/{len(y_test)}\")\n",
"\n",
"# Give each digit its own axes (repeated imshow calls on one axes overwrite\n",
"# each other) and cap the count so fewer than three errors cannot IndexError.\n",
"n_show = min(3, len(not_correct))\n",
"if n_show:\n",
"    fig, axes = plt.subplots(1, n_show, squeeze=False)\n",
"    for ax, idx in zip(axes[0], not_correct[:n_show]):\n",
"        ax.imshow(x_test[idx].reshape(28, 28), cmap='Greys')\n",
"        ax.set_title(f\"pred {predicted_labels[idx]} / true {y_test[idx]}\")\n",
"        ax.axis('off')\n",
"    plt.show()\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "fc2d7044", "metadata": {}, "outputs": [], "source": [
"def default_mnist(epochs=5):\n",
"    \"\"\"Baseline: evaluate the CNN from make_model() compiled with Adam.\n",
"\n",
"    Training is currently disabled (see the commented-out `fit` call), so\n",
"    `epochs` is unused and the score is that of the freshly initialised\n",
"    network.\n",
"\n",
"    Returns:\n",
"        The `[loss, accuracy]` list from `model.evaluate` on the test set.\n",
"    \"\"\"\n",
"    x_train, y_train, x_test, y_test = get_data()\n",
"    model = make_model()\n",
"\n",
"    # Integer labels + 10-way softmax output -> sparse categorical cross-entropy.\n",
"    model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])\n",
"    # hist = model.fit(x_train, y_train, epochs=epochs, batch_size=32, verbose=1)\n",
"    score = model.evaluate(x_test, y_test, verbose=2)\n",
"    print(f\"score : {score}\")\n",
"    return score\n",
"\n",
"# default_mnist()" ] }, { "cell_type": "code", "execution_count": null, "id": "27024a0b", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "pso", "language": "python", "name": "pso" }, "language_info": { "codemirror_mode": { "name": "ipython", 
"version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.16" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "state": {}, "version_major": 2, "version_minor": 0 } } }, "nbformat": 4, "nbformat_minor": 5 }