diff --git a/pso/particle.py b/pso/particle.py index 764c782..282cfaa 100644 --- a/pso/particle.py +++ b/pso/particle.py @@ -293,9 +293,8 @@ class Particle: * r_0 * (self.best_weights - self.weights) + global_rate - * Particle.g_best_score[1] - * r_1 - * (best_particle_weights - self.weights) + # * Particle.g_best_score[1] + * r_1 * (best_particle_weights - self.weights) ) if self.mutation != 0.0 and rng.random() < self.mutation: @@ -310,9 +309,9 @@ class Particle: """ 가중치 업데이트 """ - new_w = np.add(self.weights, self.velocities) + self.weights = np.add(self.weights, self.velocities) - self.model.set_weights(self._decode(new_w)) + self.model.set_weights(self.get_weights()) def step(self, x, y, local_rate, global_rate, w, renewal: str = "acc"): """