mirror of
https://github.com/jung-geun/PSO.git
synced 2025-12-20 04:50:45 +09:00
23-07-10
mnist 46% 달성
This commit is contained in:
19
mnist_tf.py
19
mnist_tf.py
@@ -60,15 +60,30 @@ model = make_model()
|
||||
x_train, y_train, x_test, y_test = get_data()
|
||||
|
||||
model.compile(
|
||||
optimizer="sgd", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
|
||||
optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
|
||||
)
|
||||
|
||||
# model.compile(optimizer="adam", loss="mse", metrics=["accuracy"])
|
||||
|
||||
print("Training model...")
|
||||
model.fit(x_train, y_train, epochs=1000, batch_size=128, verbose=1)
|
||||
model.fit(x_train, y_train, epochs=100, batch_size=128, verbose=1)
|
||||
|
||||
print("Evaluating model...")
|
||||
model.evaluate(x_test, y_test, verbose=1)
|
||||
|
||||
weights = model.get_weights()
|
||||
|
||||
for w in weights:
|
||||
print(w.shape)
|
||||
print(w)
|
||||
print(w.min(), w.max())
|
||||
|
||||
model.save_weights("weights.h5")
|
||||
|
||||
# %%
|
||||
for w in weights:
|
||||
print(w.shape)
|
||||
print(w)
|
||||
print(w.min(), w.max())
|
||||
|
||||
# %%
|
||||
|
||||
Reference in New Issue
Block a user