mirror of
https://github.com/jung-geun/PSO.git
synced 2025-12-20 04:50:45 +09:00
1032 lines
36 KiB
Plaintext
1032 lines
36 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "8a637c69-9071-4012-ac1e-93037548b3e9",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"2023-05-24 15:37:52.889357: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"2.10.0\n",
|
|
"[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import os\n",
|
|
"os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'\n",
|
|
"\n",
|
|
"import tensorflow as tf\n",
|
|
"# tf.random.set_seed(777) # for reproducibility\n",
|
|
"\n",
|
|
"from tensorflow import keras\n",
|
|
"from keras.datasets import mnist\n",
|
|
"from keras.models import Sequential\n",
|
|
"from keras.layers import Dense, Dropout, Flatten\n",
|
|
"from keras.layers import Conv2D, MaxPooling2D\n",
|
|
"from keras import backend as K\n",
|
|
"\n",
|
|
"from pso_tf import PSO\n",
|
|
"\n",
|
|
"import numpy as np\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"\n",
|
|
"from datetime import date\n",
|
|
"from tqdm import tqdm\n",
|
|
"import json\n",
|
|
"\n",
|
|
"tf_version = tf.__version__\n",
|
|
"print(tf_version)  # TensorFlow build in use\n",
|
|
"visible_devices = tf.config.list_physical_devices()\n",
|
|
"print(visible_devices)  # devices TensorFlow can see (CPU/GPU)\n",
|
|
"\n",
|
|
"def get_data():\n",
|
|
"    \"\"\"Load MNIST, divide pixel values by 255 and add a trailing channel axis.\n",
|
|
"\n",
|
|
"    Returns:\n",
|
|
"        (x_train, y_train, x_test, y_test). Images are reshaped to\n",
|
|
"        (n_samples, 28, 28, 1); labels are returned unchanged from\n",
|
|
"        keras.datasets.mnist.load_data().\n",
|
|
"    \"\"\"\n",
|
|
"    (x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
|
|
"\n",
|
|
"    x_train, x_test = x_train / 255.0, x_test / 255.0\n",
|
|
"    # -1 lets numpy infer the sample count instead of hard-coding 60000/10000,\n",
|
|
"    # so the same code keeps working on any slice of the dataset.\n",
|
|
"    x_train = x_train.reshape((-1, 28, 28, 1))\n",
|
|
"    x_test = x_test.reshape((-1, 28, 28, 1))\n",
|
|
"\n",
|
|
"    print(f\"x_train : {x_train[0].shape} | y_train : {y_train[0].shape}\")\n",
|
|
"    print(f\"x_test : {x_test[0].shape} | y_test : {y_test[0].shape}\")\n",
|
|
"    return x_train, y_train, x_test, y_test\n",
|
|
"\n",
|
|
"def make_model():\n",
|
|
"    \"\"\"Build the uncompiled CNN classifier used throughout this notebook.\n",
|
|
"\n",
|
|
"    Two Conv2D/MaxPooling2D stages, dropout, a 128-unit ReLU dense layer and a\n",
|
|
"    10-unit softmax output; expects 28x28x1 inputs.\n",
|
|
"    \"\"\"\n",
|
|
"    stack = [\n",
|
|
"        Conv2D(32, kernel_size=(5, 5), activation='relu', input_shape=(28, 28, 1)),\n",
|
|
"        MaxPooling2D(pool_size=(3, 3)),\n",
|
|
"        Conv2D(64, kernel_size=(3, 3), activation='relu'),\n",
|
|
"        MaxPooling2D(pool_size=(2, 2)),\n",
|
|
"        Dropout(0.25),\n",
|
|
"        Flatten(),\n",
|
|
"        Dense(128, activation='relu'),\n",
|
|
"        Dense(10, activation='softmax'),\n",
|
|
"    ]\n",
|
|
"    # Sequential(list) adds the layers in order, same as repeated model.add().\n",
|
|
"    return Sequential(stack)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "a2d9891d",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"x_train : (28, 28, 1) | y_train : ()\n",
|
|
"x_test : (28, 28, 1) | y_test : ()\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"init particles position: 100%|██████████| 30/30 [00:00<00:00, 36.65it/s]\n",
|
|
"init velocities: 100%|██████████| 30/30 [00:00<00:00, 681.12it/s]\n",
|
|
"Iter 0/20: 13%|#3 | 4/30 [00:04<00:20, 1.28it/s]"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# --- optimizer parameters ---\n",
|
|
"# NOTE: not passed to PSO or used anywhere else in this notebook; kept for reference.\n",
|
|
"lr = 0.1\n",
|
|
"momentum = 0.8\n",
|
|
"decay = 1e-04\n",
|
|
"nesterov = True\n",
|
|
"\n",
|
|
"# --- PSO parameters ---\n",
|
|
"n_particles = 30\n",
|
|
"maxiter = 20\n",
|
|
"w = 0.8   # inertia: how much of its current velocity a particle keeps\n",
|
|
"c0 = 0.6  # weight of each particle's own (local) best position\n",
|
|
"c1 = 1.6  # weight of the swarm's global best position\n",
|
|
"\n",
|
|
"x_train, y_train, x_test, y_test = get_data()\n",
|
|
"model = make_model()\n",
|
|
"\n",
|
|
"# Labels are integer class ids (one scalar per sample) and the model ends in a\n",
|
|
"# 10-way softmax, so sparse categorical cross-entropy is the matching loss.\n",
|
|
"# MeanSquaredError compared each integer label against all 10 probabilities,\n",
|
|
"# which does not measure classification error.\n",
|
|
"loss = keras.losses.SparseCategoricalCrossentropy()\n",
|
|
"\n",
|
|
"pso_m = PSO(model=model, loss_method=loss, n_particles=n_particles)\n",
|
|
"best_weights, score = pso_m.optimize(x_train, y_train, x_test, y_test, maxiter=maxiter, c0=c0, c1=c1, w=w)\n",
|
|
"model.set_weights(best_weights)\n",
|
|
"\n",
|
|
"score_ = model.evaluate(x_test, y_test, verbose=2)\n",
|
|
"print(f\" Test loss, accuracy: {score_}\")\n",
|
|
"score = round(score_[1] * 100, 2)  # test accuracy in percent, used in file names\n",
|
|
"\n",
|
|
"day = date.today().strftime(\"%Y-%m-%d\")\n",
|
|
"run_name = f\"{day}_{score}_mnist\"\n",
|
|
"\n",
|
|
"os.makedirs('./model', exist_ok=True)\n",
|
|
"model.save(f'./model/{run_name}.h5')\n",
|
|
"json_save = {\n",
|
|
"    \"name\": f\"{run_name}.h5\",\n",
|
|
"    \"score\": score_,\n",
|
|
"    \"maxiter\": maxiter,\n",
|
|
"    \"n_particles\": n_particles,\n",
|
|
"    \"c0\": c0,\n",
|
|
"    \"c1\": c1,\n",
|
|
"    \"w\": w,\n",
|
|
"}\n",
|
|
"# One JSON object per line (JSON Lines): the previous ',\\n' separator left the\n",
|
|
"# file unparseable by json.load / json.loads.\n",
|
|
"with open(f'./model/{run_name}.json', 'a') as f:\n",
|
|
"    f.write(json.dumps(json_save) + '\\n')\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "1af7569b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# NOTE(review): all_history() is assumed to return (loss histories, accuracy\n",
|
|
"# histories), each an iterable of per-particle curves; confirm against pso_tf.\n",
|
|
"all_loss, all_acc = pso_m.all_history()\n",
|
|
"\n",
|
|
"fig, (ax_loss, ax_acc) = plt.subplots(2, 1)\n",
|
|
"for history in all_loss:\n",
|
|
"    ax_loss.plot(history)\n",
|
|
"ax_loss.set_title('loss history')\n",
|
|
"\n",
|
|
"for history in all_acc:\n",
|
|
"    ax_acc.plot(history)\n",
|
|
"ax_acc.set_title('acc history')\n",
|
|
"\n",
|
|
"fig.tight_layout()\n",
|
|
"plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "1a38f3c1-8291-40d9-838e-4ffbf4578be5",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"x_train : (28, 28, 1) | y_train : ()\n",
|
|
"x_test : (28, 28, 1) | y_test : ()\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/home/pieroot/miniconda3/envs/pso/lib/python3.8/site-packages/keras/optimizers/optimizer_v2/gradient_descent.py:111: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.\n",
|
|
" super().__init__(name, **kwargs)\n",
|
|
"init particles position: 100%|██████████| 30/30 [00:00<00:00, 36.95it/s]\n",
|
|
"init velocities: 100%|██████████| 30/30 [00:00<00:00, 1399.35it/s]\n",
|
|
"Iter 0/50: 100%|##########| 30/30 [00:15<00:00, 1.98it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.9084339777628581 | acc avg : 0.0019799999892711638 | best loss : 0.15219999849796295\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 1/50: 100%|##########| 30/30 [00:11<00:00, 2.54it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.9090563456217448 | acc avg : 0.0031199999153614043 | best loss : 0.20149999856948853\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 2/50: 100%|##########| 30/30 [00:11<00:00, 2.59it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.9103448867797852 | acc avg : 0.005286666750907898 | best loss : 0.20149999856948853\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 3/50: 100%|##########| 30/30 [00:11<00:00, 2.55it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.9113266626993816 | acc avg : 0.004926666617393494 | best loss : 0.20149999856948853\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 4/50: 100%|##########| 30/30 [00:11<00:00, 2.54it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.9113243738810222 | acc avg : 0.004126666734615962 | best loss : 0.20149999856948853\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 5/50: 100%|##########| 30/30 [00:11<00:00, 2.56it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.9113284428914388 | acc avg : 0.002809999883174896 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 6/50: 100%|##########| 30/30 [00:11<00:00, 2.51it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.9113288243611654 | acc avg : 0.0034666667381922406 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 7/50: 100%|##########| 30/30 [00:11<00:00, 2.60it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.9113253911336263 | acc avg : 0.0029633333285649615 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 8/50: 100%|##########| 30/30 [00:11<00:00, 2.57it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.9113227208455403 | acc avg : 0.002809999883174896 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 9/50: 100%|##########| 30/30 [00:11<00:00, 2.53it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.911251958211263 | acc avg : 0.005486666659514109 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 10/50: 100%|##########| 30/30 [00:11<00:00, 2.57it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.9113367716471354 | acc avg : 0.004316666722297668 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 11/50: 100%|##########| 30/30 [00:12<00:00, 2.47it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.9113431294759115 | acc avg : 0.002943333238363266 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 12/50: 100%|##########| 30/30 [00:11<00:00, 2.57it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.9113429387410482 | acc avg : 0.004413333535194397 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 13/50: 100%|##########| 30/30 [00:11<00:00, 2.58it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.9113424936930339 | acc avg : 0.004670000076293946 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 14/50: 100%|##########| 30/30 [00:11<00:00, 2.55it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.9113433202107747 | acc avg : 0.0024433332184950513 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 15/50: 100%|##########| 30/30 [00:12<00:00, 2.48it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.9113349914550781 | acc avg : 0.0030966666837533314 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 16/50: 100%|##########| 30/30 [00:11<00:00, 2.57it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.911343765258789 | acc avg : 0.002956666549046834 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 17/50: 100%|##########| 30/30 [00:11<00:00, 2.56it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.911343765258789 | acc avg : 0.002806666741768519 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 18/50: 100%|##########| 30/30 [00:11<00:00, 2.55it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.911343765258789 | acc avg : 0.002503333240747452 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 19/50: 100%|##########| 30/30 [00:12<00:00, 2.47it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.9113413492838541 | acc avg : 0.003179999937613805 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 20/50: 100%|##########| 30/30 [00:11<00:00, 2.60it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.911343765258789 | acc avg : 0.004823333521684011 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 21/50: 100%|##########| 30/30 [00:11<00:00, 2.60it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.911343765258789 | acc avg : 0.003663333257039388 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 22/50: 100%|##########| 30/30 [00:11<00:00, 2.54it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.911343765258789 | acc avg : 0.002916666616996129 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 23/50: 100%|##########| 30/30 [00:11<00:00, 2.55it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.911343765258789 | acc avg : 0.0026966666181882223 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 24/50: 100%|##########| 30/30 [00:12<00:00, 2.43it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.911343765258789 | acc avg : 0.0028999999165534975 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 25/50: 100%|##########| 30/30 [00:11<00:00, 2.56it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.911343765258789 | acc avg : 0.0028833332161108654 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 26/50: 100%|##########| 30/30 [00:11<00:00, 2.55it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.911343765258789 | acc avg : 0.0027433333297570547 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 27/50: 100%|##########| 30/30 [00:11<00:00, 2.59it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.911343765258789 | acc avg : 0.0024033332864443462 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 28/50: 100%|##########| 30/30 [00:11<00:00, 2.60it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.911343765258789 | acc avg : 0.004453333218892416 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 29/50: 100%|##########| 30/30 [00:11<00:00, 2.56it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.911343765258789 | acc avg : 0.00338333323597908 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 30/50: 100%|##########| 30/30 [00:12<00:00, 2.44it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.911343765258789 | acc avg : 0.0028333333631356556 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 31/50: 100%|##########| 30/30 [00:11<00:00, 2.58it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.911343765258789 | acc avg : 0.002480000009139379 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 32/50: 100%|##########| 30/30 [00:11<00:00, 2.58it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.911343765258789 | acc avg : 0.0030733334521452584 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 33/50: 100%|##########| 30/30 [00:11<00:00, 2.58it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.911343765258789 | acc avg : 0.0028366667528947195 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 34/50: 100%|##########| 30/30 [00:11<00:00, 2.58it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.911343765258789 | acc avg : 0.002760000030199687 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 35/50: 100%|##########| 30/30 [00:11<00:00, 2.57it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.911343765258789 | acc avg : 0.002463333308696747 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 36/50: 100%|##########| 30/30 [00:11<00:00, 2.55it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.911343765258789 | acc avg : 0.004286666711171468 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 37/50: 100%|##########| 30/30 [00:12<00:00, 2.39it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.911343765258789 | acc avg : 0.003916666656732559 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 38/50: 100%|##########| 30/30 [00:11<00:00, 2.59it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.911343765258789 | acc avg : 0.0037066665788491565 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 39/50: 100%|##########| 30/30 [00:11<00:00, 2.60it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.911343765258789 | acc avg : 0.003233333428700765 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 40/50: 100%|##########| 30/30 [00:11<00:00, 2.60it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.911343765258789 | acc avg : 0.0020900001128514607 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 41/50: 100%|##########| 30/30 [00:11<00:00, 2.60it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.911343765258789 | acc avg : 0.002956666549046834 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 42/50: 100%|##########| 30/30 [00:11<00:00, 2.60it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.911343765258789 | acc avg : 0.0028566665947437286 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 43/50: 100%|##########| 30/30 [00:11<00:00, 2.57it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.911343765258789 | acc avg : 0.0026866666972637176 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 44/50: 100%|##########| 30/30 [00:11<00:00, 2.58it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.911343765258789 | acc avg : 0.0045466666420300806 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 45/50: 100%|##########| 30/30 [00:12<00:00, 2.36it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.911343765258789 | acc avg : 0.004050000011920929 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 46/50: 100%|##########| 30/30 [00:11<00:00, 2.58it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.911343765258789 | acc avg : 0.0037399999797344207 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 47/50: 100%|##########| 30/30 [00:11<00:00, 2.60it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.911343765258789 | acc avg : 0.00264999990661939 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 48/50: 100%|##########| 30/30 [00:11<00:00, 2.61it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.911343765258789 | acc avg : 0.0029666667183240254 | best loss : 0.20180000364780426\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iter 49/50: 100%|##########| 30/30 [00:11<00:00, 2.56it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss avg : 0.911343765258789 | acc avg : 0.002906666696071625 | best loss : 0.20180000364780426\n",
|
|
"313/313 - 0s - loss: 27.3092 - accuracy: 0.2018 - 247ms/epoch - 788us/step\n",
|
|
" Test loss: [27.309202194213867, 0.20180000364780426]\n",
|
|
"x_train : (28, 28, 1) | y_train : ()\n",
|
|
"x_test : (28, 28, 1) | y_test : ()\n",
|
|
"313/313 [==============================] - 0s 691us/step\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"진행도: 100%|██████████| 10000/10000 [00:00<00:00, 2226867.00it/s]"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"틀린 갯수 > 7982/10000\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "",
|
|
"text/plain": [
|
|
"<Figure size 640x480 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Re-load the data so this cell does not rely on variables from earlier cells;\n",
|
|
"# `model` still comes from the PSO cell above.\n",
|
|
"x_train, y_train, x_test, y_test = get_data()\n",
|
|
"\n",
|
|
"predicted_result = model.predict(x_test)\n",
|
|
"predicted_labels = np.argmax(predicted_result, axis=1)\n",
|
|
"# Vectorised comparison instead of a Python loop over every test sample.\n",
|
|
"not_correct = np.flatnonzero(predicted_labels != y_test)\n",
|
|
"\n",
|
|
"print(f\"틀린 갯수 > {len(not_correct)}/{len(y_test)}\")\n",
|
|
"\n",
|
|
"# Show up to three misclassified digits, each on its own axes (several imshow\n",
|
|
"# calls on one axes only leave the last image visible). min() avoids an\n",
|
|
"# IndexError when fewer than three samples are wrong.\n",
|
|
"n_show = min(3, len(not_correct))\n",
|
|
"if n_show:\n",
|
|
"    fig, axes = plt.subplots(1, n_show, squeeze=False)\n",
|
|
"    for ax, idx in zip(axes[0], not_correct[:n_show]):\n",
|
|
"        ax.imshow(x_test[idx].reshape(28, 28), cmap='Greys')\n",
|
|
"        ax.set_title(f\"pred {predicted_labels[idx]} / true {y_test[idx]}\")\n",
|
|
"        ax.axis('off')\n",
|
|
"    plt.show()\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "fc2d7044",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def default_mnist(epochs=5):\n",
|
|
"    \"\"\"Gradient-descent baseline: train the same CNN with Adam and evaluate it.\n",
|
|
"\n",
|
|
"    Args:\n",
|
|
"        epochs: number of training epochs; 0 evaluates the untrained model.\n",
|
|
"\n",
|
|
"    Returns:\n",
|
|
"        The model.evaluate() result, [loss, accuracy].\n",
|
|
"    \"\"\"\n",
|
|
"    x_train, y_train, x_test, y_test = get_data()\n",
|
|
"    model = make_model()\n",
|
|
"\n",
|
|
"    # Integer labels + softmax output -> sparse categorical cross-entropy.\n",
|
|
"    model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])\n",
|
|
"    if epochs > 0:\n",
|
|
"        model.fit(x_train, y_train, epochs=epochs, batch_size=32, verbose=1)\n",
|
|
"\n",
|
|
"    score = model.evaluate(x_test, y_test, verbose=2)\n",
|
|
"    print(f\"score : {score}\")\n",
|
|
"    return score\n",
|
|
"\n",
|
|
"# default_mnist()  # uncomment to run the gradient-descent baseline"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "27024a0b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "pso",
|
|
"language": "python",
|
|
"name": "pso"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.8.16"
|
|
},
|
|
"widgets": {
|
|
"application/vnd.jupyter.widget-state+json": {
|
|
"state": {},
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
}
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|