mnist 파티클 개수 75 -> 150 으로 조정
tensorboard 로 log 분석할 수 있게 수정
pypi 패키지 파일 제거
conda env 파일 tensorflow 2.12 -> 2.11
This commit is contained in:
jung-geun
2023-07-13 21:39:40 +09:00
parent 5494df2bc3
commit 768d3ccee7
22 changed files with 157 additions and 1866 deletions

View File

@@ -81,8 +81,7 @@
" print(e)\n",
"\n",
"from tensorflow import keras\n",
"from tensorflow.keras.layers import (Conv2D, Dense, Dropout, Flatten,\n",
" MaxPooling2D)\n",
"from tensorflow.keras.layers import Conv2D, Dense, Dropout, Flatten, MaxPooling2D\n",
"from tensorflow.keras.models import Sequential"
]
},
@@ -145,19 +144,22 @@
"source": [
"def make_model():\n",
" model = Sequential()\n",
" model.add(Conv2D(32, kernel_size=(5, 5), activation='relu', input_shape=(28,28,1)))\n",
" model.add(\n",
" Conv2D(32, kernel_size=(5, 5), activation=\"relu\", input_shape=(28, 28, 1))\n",
" )\n",
" model.add(MaxPooling2D(pool_size=(3, 3)))\n",
" model.add(Conv2D(64, kernel_size=(3, 3), activation='relu'))\n",
" model.add(Conv2D(64, kernel_size=(3, 3), activation=\"relu\"))\n",
" model.add(MaxPooling2D(pool_size=(2, 2)))\n",
" model.add(Dropout(0.25))\n",
" model.add(Flatten())\n",
" model.add(Dense(128, activation='relu'))\n",
" model.add(Dense(10, activation='softmax'))\n",
" model.add(Dense(128, activation=\"relu\"))\n",
" model.add(Dense(10, activation=\"softmax\"))\n",
"\n",
" # model.summary()\n",
"\n",
" return model\n",
"\n",
"\n",
"model = make_model()\n",
"weights = model.get_weights()\n",
"\n",
@@ -165,7 +167,7 @@
"\n",
"for i in range(len(weights)):\n",
" print(weights[i].shape)\n",
" print(weights[i].min(), weights[i].max())\n"
" print(weights[i].min(), weights[i].max())"
]
},
{
@@ -197,7 +199,7 @@
"# json_ = model.to_json()\n",
"# print(json_)\n",
"# for layer in model.get_weights():\n",
" # print(layer.shape)\n",
"# print(layer.shape)\n",
"weight = model.get_weights()"
]
},
@@ -246,9 +248,10 @@
" w_ = layer.reshape(-1)\n",
" lenght.append(len(w_))\n",
" w_gpu = cp.append(w_gpu, w_)\n",
" \n",
"\n",
" return w_gpu, shape, lenght\n",
"\n",
"\n",
"def decode(weight, shape, lenght):\n",
" weights = []\n",
" start = 0\n",
@@ -263,15 +266,16 @@
"\n",
" return weights\n",
"\n",
"\n",
"w = 0.8\n",
"v,_,_ = encode(weight)\n",
"v, _, _ = encode(weight)\n",
"c0 = 0.5\n",
"c1 = 1.5\n",
"r0 = 0.2\n",
"r1 = 0.8\n",
"p_best,_,_ = encode(weight)\n",
"g_best,_,_ = encode(weight)\n",
"layer,shape,leng = encode(weight)\n",
"p_best, _, _ = encode(weight)\n",
"g_best, _, _ = encode(weight)\n",
"layer, shape, leng = encode(weight)\n",
"\n",
"# new_v = w*v[i]\n",
"# new_v = new_v + c0*r0*(p_best[i] - layer)\n",
@@ -313,7 +317,7 @@
"# print(\"not same\")\n",
"# break\n",
"# else:\n",
"# print(\"same\")\n"
"# print(\"same\")"
]
},
{
@@ -409,10 +413,11 @@
"\n",
"\n",
"def get_xor():\n",
" x = np.array([[0,0],[0,1],[1,0],[1,1]])\n",
" y = np.array([[0],[1],[1],[0]])\n",
" x = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])\n",
" y = np.array([[0], [1], [1], [0]])\n",
"\n",
" return x, y\n",
"\n",
" return x,y\n",
"\n",
"def get_iris():\n",
" iris = load_iris()\n",
@@ -421,10 +426,13 @@
"\n",
" y = keras.utils.to_categorical(y, 3)\n",
"\n",
" x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, shuffle=True, stratify=y)\n",
" x_train, x_test, y_train, y_test = train_test_split(\n",
" x, y, test_size=0.2, shuffle=True, stratify=y\n",
" )\n",
"\n",
" return x_train, x_test, y_train, y_test\n",
"\n",
"\n",
"# model = keras.models.load_model(\"./result/xor/06-02-13-31/75_0.35_0.8_0.6.h5\")\n",
"model = keras.models.load_model(\"./result/iris/06-02-13-48/50_0.4_0.8_0.7.h5\")\n",
"# x,y = get_xor()\n",
@@ -432,7 +440,7 @@
"\n",
"print(model.predict(x_test))\n",
"print(y_test)\n",
"print(model.evaluate(x_test,y_test))"
"print(model.evaluate(x_test, y_test))"
]
},
{
@@ -464,7 +472,7 @@
"import tensorflow.compiler as tf_cc\n",
"import tensorrt as trt\n",
"\n",
"linked_trt_ver=tf_cc.tf2tensorrt._pywrap_py_utils.get_linked_tensorrt_version()\n",
"linked_trt_ver = tf_cc.tf2tensorrt._pywrap_py_utils.get_linked_tensorrt_version()\n",
"print(f\"Linked TRT ver: {linked_trt_ver}\")"
]
},