mirror of
https://github.com/jung-geun/PSO.git
synced 2025-12-19 20:44:39 +09:00
23-07-08
mse -> sparse_categorical_crossentropy 로 수정 ( BP 에서 mse 로는 학습이 되지 않음 )
This commit is contained in:
9
mnist.py
9
mnist.py
@@ -50,11 +50,12 @@ def make_model():
|
||||
|
||||
# %%
|
||||
model = make_model()
|
||||
x_test, y_test = get_data_test()
|
||||
x_train, y_train, x_test, y_test = get_data()
|
||||
|
||||
loss = [
|
||||
"mse",
|
||||
"categorical_crossentropy",
|
||||
"sparse_categorical_crossentropy",
|
||||
"binary_crossentropy",
|
||||
"kullback_leibler_divergence",
|
||||
"poisson",
|
||||
@@ -69,7 +70,7 @@ if __name__ == "__main__":
|
||||
try:
|
||||
pso_mnist = Optimizer(
|
||||
model,
|
||||
loss=loss[0],
|
||||
loss=loss[2],
|
||||
n_particles=100,
|
||||
c0=0.35,
|
||||
c1=0.8,
|
||||
@@ -80,8 +81,8 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
best_score = pso_mnist.fit(
|
||||
x_test,
|
||||
y_test,
|
||||
x_train,
|
||||
y_train,
|
||||
epochs=200,
|
||||
save=True,
|
||||
save_path="./result/mnist",
|
||||
|
||||
Reference in New Issue
Block a user