mirror of
https://github.com/jung-geun/PSO.git
synced 2025-12-20 04:50:45 +09:00
23-06-24
패키지 호출 단순 수정
This commit is contained in:
29
mnist.py
29
mnist.py
@@ -1,30 +1,28 @@
|
||||
# %%
|
||||
import os
|
||||
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
tf.random.set_seed(777) # for reproducibility
|
||||
|
||||
import numpy as np
|
||||
|
||||
np.random.seed(777)
|
||||
|
||||
from tensorflow import keras
|
||||
from keras.datasets import mnist
|
||||
from keras.models import Sequential
|
||||
from keras.layers import Dense, Dropout, Flatten
|
||||
from keras.layers import Conv2D, MaxPooling2D
|
||||
from keras import backend as K
|
||||
import gc
|
||||
from datetime import date
|
||||
|
||||
from keras import backend as K
|
||||
from keras.datasets import mnist
|
||||
from keras.layers import Conv2D, Dense, Dropout, Flatten, MaxPooling2D
|
||||
from keras.models import Sequential
|
||||
# from pso_tf import PSO
|
||||
from pso import Optimizer
|
||||
# from optimizer import Optimizer
|
||||
|
||||
|
||||
from datetime import date
|
||||
from tensorflow import keras
|
||||
from tqdm import tqdm
|
||||
|
||||
import gc
|
||||
|
||||
# print(tf.__version__)
|
||||
# print(tf.config.list_physical_devices())
|
||||
# print(f"Num GPUs Available: {len(tf.config.list_physical_devices('GPU'))}")
|
||||
@@ -74,12 +72,13 @@ if __name__ == "__main__":
|
||||
pso_mnist = Optimizer(
|
||||
model,
|
||||
loss=loss[0],
|
||||
n_particles=50,
|
||||
n_particles=75,
|
||||
c0=0.35,
|
||||
c1=0.8,
|
||||
w_min=0.7,
|
||||
w_max=1.15,
|
||||
negative_swarm=0.25
|
||||
negative_swarm=0.25,
|
||||
momentun_swarm=0.25,
|
||||
)
|
||||
|
||||
best_score = pso_mnist.fit(
|
||||
@@ -88,7 +87,7 @@ if __name__ == "__main__":
|
||||
epochs=200,
|
||||
save=True,
|
||||
save_path="./result/mnist",
|
||||
renewal="acc",
|
||||
renewal="loss",
|
||||
empirical_balance=False,
|
||||
Dispersion=False,
|
||||
check_point=25
|
||||
|
||||
Reference in New Issue
Block a user