xor iris 수치 교정
파티클의 분포 조정 가능하게 수정
random 시드 추출
This commit is contained in:
jung-geun
2023-07-12 05:03:18 +09:00
parent 2b010c4257
commit 7d22ededc7
19 changed files with 201 additions and 280 deletions

View File

@@ -41,6 +41,8 @@ class Optimizer:
mutation_swarm: float = 0,
np_seed: int = None,
tf_seed: int = None,
particle_min: float = -5,
particle_max: float = 5,
):
"""
particle swarm optimization
@@ -63,6 +65,8 @@ class Optimizer:
if tf_seed is not None:
tf.random.set_seed(tf_seed)
self.random_state = np.random.get_state()
self.model = model # 모델 구조
self.loss = loss # 손실함수
self.n_particles = n_particles # 파티클 개수
@@ -82,6 +86,7 @@ class Optimizer:
self.renewal = "acc"
self.Dispersion = False
self.day = datetime.now().strftime("%m-%d-%H-%M")
self.empirical_balance = False
negative_count = 0
@@ -89,7 +94,7 @@ class Optimizer:
m = keras.models.model_from_json(model.to_json())
init_weights = m.get_weights()
w_, sh_, len_ = self._encode(init_weights)
w_ = np.random.uniform(-3, 3, len(w_))
w_ = np.random.uniform(particle_min, particle_max, len(w_))
m.set_weights(self._decode(w_, sh_, len_))
m.compile(loss=self.loss, optimizer="sgd", metrics=["accuracy"])
self.particles[i] = Particle(
@@ -254,6 +259,7 @@ class Optimizer:
elif renewal == "both":
if local_score[1] > self.g_best_score[0]:
self.g_best_score[0] = local_score[1]
self.g_best_score[1] = local_score[0]
self.g_best = p.get_best_weights()
self.g_best_ = p.get_best_weights()
@@ -274,7 +280,6 @@ class Optimizer:
else:
f.write("\n")
f.close()
del local_score
gc.collect()
@@ -424,7 +429,6 @@ class Optimizer:
f.write(", ")
else:
f.write("\n")
f.close()
if check_point is not None:
if epoch % check_point == 0:
@@ -498,6 +502,11 @@ class Optimizer:
"Dispersion": self.Dispersion,
"negative_swarm": self.negative_swarm,
"mutation_swarm": self.mutation_swarm,
"random_state_0": self.random_state[0],
"random_state_1": self.random_state[1].tolist(),
"random_state_2": self.random_state[2],
"random_state_3": self.random_state[3],
"random_state_4": self.random_state[4],
"renewal": self.renewal,
}
@@ -507,8 +516,6 @@ class Optimizer:
) as f:
json.dump(json_save, f, indent=4)
f.close()
def _check_point_save(self, save_path: str = f"./result/check_point"):
"""
중간 저장