mirror of
https://github.com/jung-geun/PSO.git
synced 2025-12-19 20:44:39 +09:00
402 lines
10 KiB
Plaintext
402 lines
10 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"2.10.0\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import os\n",
|
|
"os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # hide TF C++ INFO/WARNING logs; must be set before importing TensorFlow\n",
|
|
"\n",
|
|
"import numpy as np\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"import tensorflow as tf\n",
|
|
"from tensorflow import keras\n",
|
|
"from tensorflow.keras import layers\n",
|
|
"from tensorflow.keras.models import Sequential\n",
|
|
"\n",
|
|
"from pso_tf import PSO\n",
|
|
"\n",
|
|
"SEED = 777\n",
|
|
"# Seeds Python `random`, NumPy and TensorFlow in one call so runs are reproducible.\n",
|
|
"# NOTE(review): this presumably also fixes PSO's randomness if it draws from those global RNGs; not verified.\n",
|
|
"tf.keras.utils.set_random_seed(SEED)\n",
|
|
"\n",
|
|
"print(tf.__version__)\n",
|
|
"\n",
|
|
"def get_data():\n",
|
|
"    \"\"\"Return the XOR truth table as ``(inputs, targets)``.\n",
|
|
"\n",
|
|
"    ``inputs`` has shape (4, 2) and lists every pair of bits; ``targets`` has\n",
|
|
"    shape (4, 1) and holds the XOR of the matching input row.\n",
|
|
"    \"\"\"\n",
|
|
"    inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])\n",
|
|
"    targets = np.array([[0], [1], [1], [0]])\n",
|
|
"    return inputs, targets\n",
|
|
"\n",
|
|
"def make_model(learning_rate=0.1, momentum=0.9):\n",
|
|
"    \"\"\"Build and compile a 2-2-1 sigmoid MLP for the XOR problem.\n",
|
|
"\n",
|
|
"    Args:\n",
|
|
"        learning_rate: SGD step size.\n",
|
|
"        momentum: Nesterov momentum factor. Keep it below 1: with momentum == 1 the\n",
|
|
"            velocity accumulates every past gradient without ever decaying.\n",
|
|
"\n",
|
|
"    Returns:\n",
|
|
"        A ``Sequential`` model compiled with MSE loss and an accuracy metric.\n",
|
|
"    \"\"\"\n",
|
|
"    model = Sequential([\n",
|
|
"        layers.Dense(2, activation='sigmoid', input_shape=(2,)),\n",
|
|
"        layers.Dense(1, activation='sigmoid'),\n",
|
|
"    ])\n",
|
|
"\n",
|
|
"    # `learning_rate` replaces the deprecated `lr` alias. `decay` is only accepted by the\n",
|
|
"    # TF <= 2.10 optimizers (use tf.keras.optimizers.legacy.SGD on newer versions).\n",
|
|
"    sgd = keras.optimizers.SGD(\n",
|
|
"        learning_rate=learning_rate, momentum=momentum, decay=1e-05, nesterov=True\n",
|
|
"    )\n",
|
|
"    model.compile(loss='mse', optimizer=sgd, metrics=['accuracy'])\n",
|
|
"\n",
|
|
"    # summary() prints the table itself and returns None; wrapping it in print() emits a stray 'None'.\n",
|
|
"    model.summary()\n",
|
|
"\n",
|
|
"    return model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Model: \"sequential_11\"\n",
|
|
"_________________________________________________________________\n",
|
|
" Layer (type) Output Shape Param # \n",
|
|
"=================================================================\n",
|
|
" dense_22 (Dense) (None, 2) 6 \n",
|
|
" \n",
|
|
" dense_23 (Dense) (None, 1) 3 \n",
|
|
" \n",
|
|
"=================================================================\n",
|
|
"Total params: 9\n",
|
|
"Trainable params: 9\n",
|
|
"Non-trainable params: 0\n",
|
|
"_________________________________________________________________\n",
|
|
"None\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"init particles position: 100%|██████████| 15/15 [00:00<00:00, 85.12it/s]\n",
|
|
"init velocities: 100%|██████████| 15/15 [00:00<00:00, 46465.70it/s]\n",
|
|
"Iteration 0 / 10: 100%|##########| 15/15 [00:05<00:00, 2.63it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss : 0.2620863338311513 | acc : 0.5 | best loss : 0.24143654108047485\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iteration 1 / 10: 100%|##########| 15/15 [00:05<00:00, 2.69it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss : 0.24147300918896994 | acc : 0.6333333333333333 | best loss : 0.20360520482063293\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iteration 2 / 10: 100%|##########| 15/15 [00:05<00:00, 2.72it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss : 0.211648628115654 | acc : 0.65 | best loss : 0.17383326590061188\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iteration 3 / 10: 100%|##########| 15/15 [00:05<00:00, 2.72it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss : 0.21790608167648315 | acc : 0.6833333333333333 | best loss : 0.16785581409931183\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iteration 4 / 10: 100%|##########| 15/15 [00:05<00:00, 2.73it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss : 0.20557119349638622 | acc : 0.7333333333333333 | best loss : 0.16668711602687836\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iteration 5 / 10: 100%|##########| 15/15 [00:05<00:00, 2.68it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss : 0.20823073089122773 | acc : 0.7 | best loss : 0.16668711602687836\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iteration 6 / 10: 100%|##########| 15/15 [00:05<00:00, 2.53it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss : 0.21380058924357095 | acc : 0.7166666666666667 | best loss : 0.16668711602687836\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iteration 7 / 10: 100%|##########| 15/15 [00:06<00:00, 2.30it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss : 0.2561836312214533 | acc : 0.6833333333333333 | best loss : 0.16667115688323975\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iteration 8 / 10: 100%|##########| 15/15 [00:05<00:00, 2.55it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss : 0.30372582376003265 | acc : 0.65 | best loss : 0.16667115688323975\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Iteration 9 / 10: 100%|##########| 15/15 [00:05<00:00, 2.65it/s]"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"loss : 0.281868569056193 | acc : 0.7 | best loss : 0.16667115688323975\n",
|
|
"1/1 [==============================] - 0s 26ms/step\n",
|
|
"[[0. ]\n",
|
|
" [0.66422266]\n",
|
|
" [0.6642227 ]\n",
|
|
" [0.6642227 ]]\n",
|
|
"[[0]\n",
|
|
" [1]\n",
|
|
" [1]\n",
|
|
" [0]]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"x, y = get_data()\n",
|
|
"# XOR has only four possible inputs, so the training set doubles as the evaluation set.\n",
|
|
"x_test, y_test = get_data()\n",
|
|
"\n",
|
|
"model = make_model()\n",
|
|
"\n",
|
|
"loss = keras.losses.MeanSquaredError()\n",
|
|
"# `learning_rate` replaces the deprecated `lr` alias (same value as before).\n",
|
|
"optimizer = keras.optimizers.SGD(learning_rate=0.1, momentum=0.9, decay=1e-05, nesterov=True)\n",
|
|
"\n",
|
|
"pso_xor = PSO(model=model, loss_method=loss, optimizer=optimizer, n_particles=15)\n",
|
|
"\n",
|
|
"best_weights, score = pso_xor.optimize(x, y, x_test, y_test, maxiter=10, epochs=20)\n",
|
|
"\n",
|
|
"model.set_weights(best_weights)\n",
|
|
"\n",
|
|
"y_pred = model.predict(x_test)\n",
|
|
"print(y_pred)\n",
|
|
"print(y_test)\n",
|
|
"\n",
|
|
"history = pso_xor.global_history()\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Same four XOR inputs in a different order. Stored under new names so the global\n",
|
|
"# `x_test` (whose rows line up with `y_test`) and `y_pred` are not silently overwritten.\n",
|
|
"x_test_shuffled = np.array([[0, 1], [0, 0], [1, 1], [1, 0]])\n",
|
|
"y_pred_shuffled = model.predict(x_test_shuffled)\n",
|
|
"print(y_pred_shuffled)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def test(epochs=50000, verbose=0):\n",
|
|
"    \"\"\"Train a fresh model on XOR with Keras' own SGD optimizer and print its predictions.\n",
|
|
"\n",
|
|
"    Training stops early once the training loss has not improved for 3 epochs.\n",
|
|
"\n",
|
|
"    Args:\n",
|
|
"        epochs: Upper bound on the number of training epochs.\n",
|
|
"        verbose: Keras verbosity. Defaults to 0 because per-epoch logging for up to\n",
|
|
"            ``epochs`` epochs floods the notebook output.\n",
|
|
"\n",
|
|
"    Returns:\n",
|
|
"        The ``History`` object returned by ``model.fit``.\n",
|
|
"    \"\"\"\n",
|
|
"    model = make_model()\n",
|
|
"    x, y = get_data()\n",
|
|
"    # XOR has only four inputs, so validation reuses the training points.\n",
|
|
"    x_test, y_test = get_data()\n",
|
|
"\n",
|
|
"    callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)\n",
|
|
"\n",
|
|
"    hist = model.fit(\n",
|
|
"        x, y, epochs=epochs, verbose=verbose, callbacks=[callback],\n",
|
|
"        validation_data=(x_test, y_test),\n",
|
|
"    )\n",
|
|
"    y_pred = model.predict(x_test)\n",
|
|
"    print(y_pred)\n",
|
|
"    print(y_test)\n",
|
|
"\n",
|
|
"    return hist"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# The model ends in a single sigmoid unit, so each prediction is a score in (0, 1):\n",
|
|
"# threshold it at 0.5. (argmax over a length-1 axis would always return 0.)\n",
|
|
"# Rebuild the evaluation data so the inputs are guaranteed to line up with the labels.\n",
|
|
"x_eval, y_eval = get_data()\n",
|
|
"predicted_result = model.predict(x_eval)\n",
|
|
"predicted_labels = (predicted_result > 0.5).astype(int).ravel()\n",
|
|
"not_correct = np.flatnonzero(predicted_labels != y_eval.ravel())\n",
|
|
"\n",
|
|
"print(f\"틀린 것 갯수 > {len(not_correct)}\")\n",
|
|
"# Inputs are 2-D bit pairs, not 28x28 images, so list every misclassified row instead of plotting.\n",
|
|
"for i in not_correct:\n",
|
|
"    print(f\"입력 > {x_eval[i]} | 추론 > {predicted_labels[i]} | 정답 > {y_eval[i, 0]}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def plot_history(history):\n",
|
|
"    \"\"\"Plot per-epoch loss (left axis) and accuracy (right axis) for train and validation.\n",
|
|
"\n",
|
|
"    Args:\n",
|
|
"        history: ``History`` object returned by ``model.fit`` with validation data, so that\n",
|
|
"            ``loss``, ``val_loss``, ``accuracy`` and ``val_accuracy`` were recorded.\n",
|
|
"    \"\"\"\n",
|
|
"    # Read from the argument, not the global `hist`, so any History object can be plotted.\n",
|
|
"    metrics = history.history\n",
|
|
"    fig, loss_ax = plt.subplots()\n",
|
|
"    acc_ax = loss_ax.twinx()\n",
|
|
"\n",
|
|
"    loss_ax.plot(metrics['loss'], 'y', label='train loss')\n",
|
|
"    loss_ax.plot(metrics['val_loss'], 'r', label='val loss')\n",
|
|
"    loss_ax.set_xlabel('epoch')\n",
|
|
"    loss_ax.set_ylabel('loss')\n",
|
|
"    loss_ax.legend(loc='upper left')\n",
|
|
"\n",
|
|
"    acc_ax.plot(metrics['accuracy'], 'b', label='train acc')\n",
|
|
"    acc_ax.plot(metrics['val_accuracy'], 'g', label='val acc')\n",
|
|
"    acc_ax.set_ylabel('accuracy')\n",
|
|
"    acc_ax.legend(loc='upper right')\n",
|
|
"\n",
|
|
"    loss_ax.set_title('training history')\n",
|
|
"    plt.show()\n",
|
|
"\n",
|
|
"\n",
|
|
"hist = test()\n",
|
|
"plot_history(hist)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Raw outputs of the global `model` on the training inputs, followed by the targets.\n",
|
|
"train_pred = model(x).numpy()\n",
|
|
"print(train_pred)\n",
|
|
"print(y)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "pso",
|
|
"language": "python",
|
|
"name": "pso"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.8.16"
|
|
},
|
|
"widgets": {
|
|
"application/vnd.jupyter.widget-state+json": {
|
|
"state": {},
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
}
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 4
|
|
}
|