mirror of
https://github.com/jung-geun/PSO.git
synced 2025-12-20 04:50:45 +09:00
23-05-31
전체 파티클 중 일부를 현재 속도의 음수 방향으로 진행하도록 하여 지역해에 갇혀 조기수렴하는 문제의 방안으로 사용
This commit is contained in:
@@ -6,14 +6,14 @@ from tensorflow import keras
|
||||
import numpy as np
|
||||
|
||||
class Particle:
|
||||
def __init__(self, model:keras.models, loss):
|
||||
def __init__(self, model:keras.models, loss, random:bool = False):
|
||||
self.model = model
|
||||
self.loss = loss
|
||||
self.init_weights = self.model.get_weights()
|
||||
i_w_,s_,l_ = self._encode(self.init_weights)
|
||||
i_w_ = np.random.rand(len(i_w_)) / 5 - 0.10
|
||||
self.velocities = self._decode(i_w_,s_,l_)
|
||||
|
||||
self.random = random
|
||||
self.best_score = 0
|
||||
self.best_weights = self.init_weights
|
||||
|
||||
@@ -94,6 +94,8 @@ class Particle:
|
||||
def _update_weights(self):
|
||||
encode_w, w_sh, w_len = self._encode(weights = self.model.get_weights())
|
||||
encode_v, _, _ = self._encode(weights = self.velocities)
|
||||
if self.random:
|
||||
encode_v = -1 * encode_v
|
||||
new_w = encode_w + encode_v
|
||||
self.model.set_weights(self._decode(new_w, w_sh, w_len))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user