mirror of
https://github.com/jung-geun/PSO.git
synced 2025-12-20 04:50:45 +09:00
23-07-12
xor iris 수치 교정 파티클의 분포 조정 가능하게 수정 random 시드 추출
This commit is contained in:
@@ -41,6 +41,8 @@ class Optimizer:
|
||||
mutation_swarm: float = 0,
|
||||
np_seed: int = None,
|
||||
tf_seed: int = None,
|
||||
particle_min: float = -5,
|
||||
particle_max: float = 5,
|
||||
):
|
||||
"""
|
||||
particle swarm optimization
|
||||
@@ -63,6 +65,8 @@ class Optimizer:
|
||||
if tf_seed is not None:
|
||||
tf.random.set_seed(tf_seed)
|
||||
|
||||
self.random_state = np.random.get_state()
|
||||
|
||||
self.model = model # 모델 구조
|
||||
self.loss = loss # 손실함수
|
||||
self.n_particles = n_particles # 파티클 개수
|
||||
@@ -82,6 +86,7 @@ class Optimizer:
|
||||
self.renewal = "acc"
|
||||
self.Dispersion = False
|
||||
self.day = datetime.now().strftime("%m-%d-%H-%M")
|
||||
self.empirical_balance = False
|
||||
|
||||
negative_count = 0
|
||||
|
||||
@@ -89,7 +94,7 @@ class Optimizer:
|
||||
m = keras.models.model_from_json(model.to_json())
|
||||
init_weights = m.get_weights()
|
||||
w_, sh_, len_ = self._encode(init_weights)
|
||||
w_ = np.random.uniform(-3, 3, len(w_))
|
||||
w_ = np.random.uniform(particle_min, particle_max, len(w_))
|
||||
m.set_weights(self._decode(w_, sh_, len_))
|
||||
m.compile(loss=self.loss, optimizer="sgd", metrics=["accuracy"])
|
||||
self.particles[i] = Particle(
|
||||
@@ -254,6 +259,7 @@ class Optimizer:
|
||||
elif renewal == "both":
|
||||
if local_score[1] > self.g_best_score[0]:
|
||||
self.g_best_score[0] = local_score[1]
|
||||
self.g_best_score[1] = local_score[0]
|
||||
self.g_best = p.get_best_weights()
|
||||
self.g_best_ = p.get_best_weights()
|
||||
|
||||
@@ -274,7 +280,6 @@ class Optimizer:
|
||||
else:
|
||||
f.write("\n")
|
||||
|
||||
f.close()
|
||||
del local_score
|
||||
gc.collect()
|
||||
|
||||
@@ -424,7 +429,6 @@ class Optimizer:
|
||||
f.write(", ")
|
||||
else:
|
||||
f.write("\n")
|
||||
f.close()
|
||||
|
||||
if check_point is not None:
|
||||
if epoch % check_point == 0:
|
||||
@@ -498,6 +502,11 @@ class Optimizer:
|
||||
"Dispersion": self.Dispersion,
|
||||
"negative_swarm": self.negative_swarm,
|
||||
"mutation_swarm": self.mutation_swarm,
|
||||
"random_state_0": self.random_state[0],
|
||||
"random_state_1": self.random_state[1].tolist(),
|
||||
"random_state_2": self.random_state[2],
|
||||
"random_state_3": self.random_state[3],
|
||||
"random_state_4": self.random_state[4],
|
||||
"renewal": self.renewal,
|
||||
}
|
||||
|
||||
@@ -507,8 +516,6 @@ class Optimizer:
|
||||
) as f:
|
||||
json.dump(json_save, f, indent=4)
|
||||
|
||||
f.close()
|
||||
|
||||
def _check_point_save(self, save_path: str = f"./result/check_point"):
|
||||
"""
|
||||
중간 저장
|
||||
|
||||
Reference in New Issue
Block a user