{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8a637c69-9071-4012-ac1e-93037548b3e9",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-05-24 15:37:52.889357: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.10.0\n",
      "[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'\n",
    "\n",
    "import tensorflow as tf\n",
    "# tf.random.set_seed(777) # for reproducibility\n",
    "\n",
    "from tensorflow import keras\n",
    "from keras.datasets import mnist\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Dense, Dropout, Flatten\n",
    "from keras.layers import Conv2D, MaxPooling2D\n",
    "from keras import backend as K\n",
    "\n",
    "from pso_tf import PSO\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from datetime import date\n",
    "from tqdm import tqdm\n",
    "import json\n",
    "\n",
    "print(tf.__version__)\n",
    "print(tf.config.list_physical_devices())\n",
    "\n",
    "def get_data():\n",
    "    \"\"\"Load MNIST and return ``(x_train, y_train, x_test, y_test)``.\n",
    "\n",
    "    Images are divided by 255.0 and reshaped to (N, 28, 28, 1), i.e. a\n",
    "    trailing channel axis is added. Labels are returned exactly as loaded\n",
    "    (one integer class id per sample, not one-hot encoded).\n",
    "    \"\"\"\n",
    "    (x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
    "\n",
    "    x_train, x_test = x_train / 255.0, x_test / 255.0\n",
    "    # -1 lets numpy infer the sample count instead of hard-coding 60000 / 10000.\n",
    "    x_train = x_train.reshape((-1, 28, 28, 1))\n",
    "    x_test = x_test.reshape((-1, 28, 28, 1))\n",
    "\n",
    "    print(f\"x_train : {x_train[0].shape} | y_train : {y_train[0].shape}\")\n",
    "    print(f\"x_test : {x_test[0].shape} | y_test : {y_test[0].shape}\")\n",
    "    return x_train, y_train, x_test, y_test\n",
    "\n",
    "def make_model():\n",
    "    \"\"\"Build and return the (uncompiled) CNN used for MNIST.\n",
    "\n",
    "    Two Conv2D + MaxPooling2D stages, dropout, then Dense(128) and a\n",
    "    10-way softmax output. Input shape is (28, 28, 1).\n",
    "    \"\"\"\n",
    "    model = Sequential()\n",
    "    model.add(Conv2D(32, kernel_size=(5, 5), activation='relu', input_shape=(28,28,1)))\n",
    "    model.add(MaxPooling2D(pool_size=(3, 3)))\n",
    "    model.add(Conv2D(64, kernel_size=(3, 3), activation='relu'))\n",
    "    model.add(MaxPooling2D(pool_size=(2, 2)))\n",
    "    model.add(Dropout(0.25))\n",
    "    model.add(Flatten())\n",
    "    model.add(Dense(128, activation='relu'))\n",
    "    model.add(Dense(10, activation='softmax'))\n",
    "\n",
    "    # model.summary()\n",
    "\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a2d9891d",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x_train : (28, 28, 1) | y_train : ()\n",
      "x_test : (28, 28, 1) | y_test : ()\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "init particles position: 100%|██████████| 30/30 [00:00<00:00, 36.65it/s]\n",
      "init velocities: 100%|██████████| 30/30 [00:00<00:00, 681.12it/s]\n",
      "Iter 0/20: 13%|#3 | 4/30 [00:04<00:20, 1.28it/s]"
     ]
    }
   ],
   "source": [
    "# Gradient-optimizer settings. Nothing in this cell passes them on to the\n",
    "# PSO run; the names (including their spelling) are left unchanged in case\n",
    "# other code refers to them.\n",
    "lr = 0.1\n",
    "momentun = 0.8\n",
    "decay = 1e-04\n",
    "nestrov = True\n",
    "\n",
    "# PSO settings\n",
    "n_particles = 30\n",
    "maxiter = 20\n",
    "# epochs = 1\n",
    "w = 0.8   # inertia: how much of the current velocity is kept\n",
    "c0 = 0.6  # weight of the local (per-particle) best\n",
    "c1 = 1.6  # weight of the global best\n",
    "\n",
    "\n",
    "x_train, y_train, x_test, y_test = get_data()\n",
    "model = make_model()\n",
    "\n",
    "# get_data() returns integer class ids while the model ends in a 10-way\n",
    "# softmax, so a sparse categorical loss is needed. MeanSquaredError would\n",
    "# broadcast the raw label values (0..9) against the predicted probabilities.\n",
    "loss = keras.losses.SparseCategoricalCrossentropy()\n",
    "\n",
    "\n",
    "pso_m = PSO(model=model, loss_method=loss, n_particles=n_particles)\n",
    "best_weights, score = pso_m.optimize(x_train, y_train, x_test, y_test, maxiter=maxiter, c0=c0, c1=c1, w=w)\n",
    "model.set_weights(best_weights)\n",
    "\n",
    "# model.evaluate returns [loss, accuracy]; the accuracy (in percent) is\n",
    "# used to name the saved artifacts.\n",
    "score_ = model.evaluate(x_test, y_test, verbose=2)\n",
    "print(f\" Test loss: {score_}\")\n",
    "score = round(score_[1]*100, 2)\n",
    "\n",
    "day = date.today().strftime(\"%Y-%m-%d\")\n",
    "\n",
    "os.makedirs('./model', exist_ok=True)\n",
    "model.save(f'./model/{day}_{score}_mnist.h5')\n",
    "json_save = {\n",
    "    \"name\" : f\"{day}_{score}_mnist.h5\",\n",
    "    \"score\" : score_,\n",
    "    \"maxiter\" : maxiter,\n",
    "    \"c0\" : c0,\n",
    "    \"c1\" : c1,\n",
    "    \"w\" : w\n",
    "}\n",
    "# Write a single self-contained JSON document. Appending with a trailing\n",
    "# comma produced a file that json.load() cannot parse.\n",
    "with open(f'./model/{day}_{score}_mnist.json', 'w') as f:\n",
    "    json.dump(json_save, f)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1af7569b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot every series returned by the PSO run: losses on top, accuracies below.\n",
    "loss_, acc_ = pso_m.all_history()\n",
    "\n",
    "fig, (ax_loss, ax_acc) = plt.subplots(2, 1)\n",
    "\n",
    "for series in loss_:\n",
    "    ax_loss.plot(series)\n",
    "ax_loss.set_title('loss history')\n",
    "\n",
    "for series in acc_:\n",
    "    ax_acc.plot(series)\n",
    "ax_acc.set_title('acc history')\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1a38f3c1-8291-40d9-838e-4ffbf4578be5",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x_train : (28, 28, 1) | y_train : ()\n",
      "x_test : (28, 28, 1) | y_test : ()\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/pieroot/miniconda3/envs/pso/lib/python3.8/site-packages/keras/optimizers/optimizer_v2/gradient_descent.py:111: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.\n",
      " super().__init__(name, **kwargs)\n",
      "init particles position: 100%|██████████| 30/30 [00:00<00:00, 36.95it/s]\n",
      "init velocities: 100%|██████████| 30/30 [00:00<00:00, 1399.35it/s]\n",
      "Iter 0/50: 100%|##########| 30/30 [00:15<00:00, 1.98it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss avg : 0.9084339777628581 | acc avg : 0.0019799999892711638 | best loss : 0.15219999849796295\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Iter 1/50: 100%|##########| 30/30 [00:11<00:00, 2.54it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss avg : 0.9090563456217448 | acc avg : 0.0031199999153614043 | best loss : 0.20149999856948853\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Iter 2/50: 100%|##########| 30/30 [00:11<00:00, 2.59it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss avg : 
0.9103448867797852 | acc avg : 0.005286666750907898 | best loss : 0.20149999856948853\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 3/50: 100%|##########| 30/30 [00:11<00:00, 2.55it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.9113266626993816 | acc avg : 0.004926666617393494 | best loss : 0.20149999856948853\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 4/50: 100%|##########| 30/30 [00:11<00:00, 2.54it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.9113243738810222 | acc avg : 0.004126666734615962 | best loss : 0.20149999856948853\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 5/50: 100%|##########| 30/30 [00:11<00:00, 2.56it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.9113284428914388 | acc avg : 0.002809999883174896 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 6/50: 100%|##########| 30/30 [00:11<00:00, 2.51it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.9113288243611654 | acc avg : 0.0034666667381922406 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 7/50: 100%|##########| 30/30 [00:11<00:00, 2.60it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.9113253911336263 | acc avg : 0.0029633333285649615 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 8/50: 100%|##########| 30/30 [00:11<00:00, 2.57it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.9113227208455403 | acc avg : 0.002809999883174896 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 9/50: 100%|##########| 30/30 [00:11<00:00, 2.53it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 
0.911251958211263 | acc avg : 0.005486666659514109 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 10/50: 100%|##########| 30/30 [00:11<00:00, 2.57it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.9113367716471354 | acc avg : 0.004316666722297668 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 11/50: 100%|##########| 30/30 [00:12<00:00, 2.47it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.9113431294759115 | acc avg : 0.002943333238363266 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 12/50: 100%|##########| 30/30 [00:11<00:00, 2.57it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.9113429387410482 | acc avg : 0.004413333535194397 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 13/50: 100%|##########| 30/30 [00:11<00:00, 2.58it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.9113424936930339 | acc avg : 0.004670000076293946 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 14/50: 100%|##########| 30/30 [00:11<00:00, 2.55it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.9113433202107747 | acc avg : 0.0024433332184950513 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 15/50: 100%|##########| 30/30 [00:12<00:00, 2.48it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.9113349914550781 | acc avg : 0.0030966666837533314 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 16/50: 100%|##########| 30/30 [00:11<00:00, 2.57it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 
0.911343765258789 | acc avg : 0.002956666549046834 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 17/50: 100%|##########| 30/30 [00:11<00:00, 2.56it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.002806666741768519 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 18/50: 100%|##########| 30/30 [00:11<00:00, 2.55it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.002503333240747452 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 19/50: 100%|##########| 30/30 [00:12<00:00, 2.47it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.9113413492838541 | acc avg : 0.003179999937613805 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 20/50: 100%|##########| 30/30 [00:11<00:00, 2.60it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.004823333521684011 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 21/50: 100%|##########| 30/30 [00:11<00:00, 2.60it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.003663333257039388 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 22/50: 100%|##########| 30/30 [00:11<00:00, 2.54it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.002916666616996129 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 23/50: 100%|##########| 30/30 [00:11<00:00, 2.55it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 
| acc avg : 0.0026966666181882223 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 24/50: 100%|##########| 30/30 [00:12<00:00, 2.43it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.0028999999165534975 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 25/50: 100%|##########| 30/30 [00:11<00:00, 2.56it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.0028833332161108654 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 26/50: 100%|##########| 30/30 [00:11<00:00, 2.55it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.0027433333297570547 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 27/50: 100%|##########| 30/30 [00:11<00:00, 2.59it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.0024033332864443462 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 28/50: 100%|##########| 30/30 [00:11<00:00, 2.60it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.004453333218892416 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 29/50: 100%|##########| 30/30 [00:11<00:00, 2.56it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.00338333323597908 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 30/50: 100%|##########| 30/30 [00:12<00:00, 2.44it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 
0.0028333333631356556 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 31/50: 100%|##########| 30/30 [00:11<00:00, 2.58it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.002480000009139379 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 32/50: 100%|##########| 30/30 [00:11<00:00, 2.58it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.0030733334521452584 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 33/50: 100%|##########| 30/30 [00:11<00:00, 2.58it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.0028366667528947195 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 34/50: 100%|##########| 30/30 [00:11<00:00, 2.58it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.002760000030199687 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 35/50: 100%|##########| 30/30 [00:11<00:00, 2.57it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.002463333308696747 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 36/50: 100%|##########| 30/30 [00:11<00:00, 2.55it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.004286666711171468 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 37/50: 100%|##########| 30/30 [00:12<00:00, 2.39it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 
0.003916666656732559 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 38/50: 100%|##########| 30/30 [00:11<00:00, 2.59it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.0037066665788491565 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 39/50: 100%|##########| 30/30 [00:11<00:00, 2.60it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.003233333428700765 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 40/50: 100%|##########| 30/30 [00:11<00:00, 2.60it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.0020900001128514607 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 41/50: 100%|##########| 30/30 [00:11<00:00, 2.60it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.002956666549046834 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 42/50: 100%|##########| 30/30 [00:11<00:00, 2.60it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.0028566665947437286 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 43/50: 100%|##########| 30/30 [00:11<00:00, 2.57it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.0026866666972637176 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 44/50: 100%|##########| 30/30 [00:11<00:00, 2.58it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 
0.0045466666420300806 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 45/50: 100%|##########| 30/30 [00:12<00:00, 2.36it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.004050000011920929 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 46/50: 100%|##########| 30/30 [00:11<00:00, 2.58it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.0037399999797344207 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 47/50: 100%|##########| 30/30 [00:11<00:00, 2.60it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.00264999990661939 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 48/50: 100%|##########| 30/30 [00:11<00:00, 2.61it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.0029666667183240254 | best loss : 0.20180000364780426\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Iter 49/50: 100%|##########| 30/30 [00:11<00:00, 2.56it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "loss avg : 0.911343765258789 | acc avg : 0.002906666696071625 | best loss : 0.20180000364780426\n", "313/313 - 0s - loss: 27.3092 - accuracy: 0.2018 - 247ms/epoch - 788us/step\n", " Test loss: [27.309202194213867, 0.20180000364780426]\n", "x_train : (28, 28, 1) | y_train : ()\n", "x_test : (28, 28, 1) | y_test : ()\n", "313/313 [==============================] - 0s 691us/step\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "진행도: 100%|██████████| 10000/10000 [00:00<00:00, 2226867.00it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "틀린 갯수 > 7982/10000\n" ] }, { "name": "stderr", 
"output_type": "stream", "text": [ "\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ],
   "source": [
    "# Reload the data so this cell has x_test / y_test; `model` must already\n",
    "# hold the weights set by the PSO cell.\n",
    "x_train, y_train, x_test, y_test = get_data()\n",
    "\n",
    "predicted_result = model.predict(x_test)\n",
    "predicted_labels = np.argmax(predicted_result, axis=1)\n",
    "\n",
    "# Indices of the misclassified test samples (vectorised comparison).\n",
    "not_correct = np.flatnonzero(predicted_labels != y_test)\n",
    "\n",
    "print(f\"틀린 갯수 > {len(not_correct)}/{len(y_test)}\")\n",
    "\n",
    "# Show up to three misclassified digits side by side. Drawing them on one\n",
    "# axes would leave only the last image visible, and indexing a fixed three\n",
    "# entries would fail when fewer samples are wrong.\n",
    "n_show = min(3, len(not_correct))\n",
    "if n_show:\n",
    "    fig, axes = plt.subplots(1, n_show, squeeze=False)\n",
    "    for ax, idx in zip(axes[0], not_correct[:n_show]):\n",
    "        ax.imshow(x_test[idx].reshape(28, 28), cmap='Greys')\n",
    "        ax.set_title(f\"pred {predicted_labels[idx]} / true {y_test[idx]}\")\n",
    "        ax.axis('off')\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fc2d7044",
   "metadata": {},
   "outputs": [],
   "source": [
    "def default_mnist(epochs=5):\n",
    "    \"\"\"Evaluate a freshly built model on the MNIST test set.\n",
    "\n",
    "    Gradient training is currently disabled (the ``fit`` call below is\n",
    "    commented out), so ``epochs`` is accepted but unused and the printed\n",
    "    score is that of the untrained network.\n",
    "\n",
    "    Returns:\n",
    "        The ``[loss, accuracy]`` list produced by ``model.evaluate``.\n",
    "    \"\"\"\n",
    "    x_train, y_train, x_test, y_test = get_data()\n",
    "    model = make_model()\n",
    "\n",
    "    # Labels are integer class ids, so use the sparse categorical loss;\n",
    "    # 'mse' would compare raw label values with the softmax outputs.\n",
    "    model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])\n",
    "    score = model.evaluate(x_test, y_test, verbose=2)\n",
    "    print(f\"score : {score}\")\n",
    "    # hist = model.fit(x_train, y_train, epochs=epochs, batch_size=32, verbose=1)\n",
    "    # print(hist.history['loss'][-1])\n",
    "    # print(hist.history['accuracy'][-1])\n",
    "    return score\n",
    "\n",
    "# default_mnist()"
   ]
  },
  { "cell_type": "code", "execution_count": null, "id": "27024a0b", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "pso", "language": "python", "name": "pso" }, "language_info": { "codemirror_mode": { "name": "ipython", 
"version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.16" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "state": {}, "version_major": 2, "version_minor": 0 } } }, "nbformat": 4, "nbformat_minor": 5 }