mirror of
https://github.com/jung-geun/PSO.git
synced 2025-12-20 04:50:45 +09:00
23-10-21
loss + mse 로 조기 수렴 시 초기화 적용 파티클의 초기화를 opeimizer 에서 particle 객체로 변경 메모리의 점진적인 누수 #6 현재 누수가 다시 조금씩 증가하는것이 보임
This commit is contained in:
@@ -38,7 +38,7 @@ class Particle:
|
||||
self.loss = loss
|
||||
|
||||
try:
|
||||
if converge_reset and converge_reset_monitor not in ["acc", "accuracy", "loss"]:
|
||||
if converge_reset and converge_reset_monitor not in ["acc", "accuracy", "loss", "mse"]:
|
||||
raise ValueError(
|
||||
"converge_reset_monitor must be 'acc' or 'accuracy' or 'loss'"
|
||||
)
|
||||
@@ -50,10 +50,12 @@ class Particle:
|
||||
print(e)
|
||||
exit(1)
|
||||
|
||||
self.reset_particle()
|
||||
self.__reset_particle__()
|
||||
self.best_weights = self.model.get_weights()
|
||||
self.before_best = self.model.get_weights()
|
||||
self.negative = negative
|
||||
self.mutation = mutation
|
||||
self.best_score = 0
|
||||
self.best_score = [np.inf, 0, np.inf]
|
||||
self.before_w = 0
|
||||
self.score_history = []
|
||||
self.converge_reset = converge_reset
|
||||
@@ -131,14 +133,20 @@ class Particle:
|
||||
(float): 점수
|
||||
"""
|
||||
score = self.model.evaluate(x, y, verbose=0)
|
||||
if renewal == "acc":
|
||||
if score[1] > self.best_score:
|
||||
self.best_score = score[1]
|
||||
if renewal == "loss":
|
||||
if score[0] < self.best_score[0]:
|
||||
self.best_score[0] = score[0]
|
||||
self.best_weights = self.model.get_weights()
|
||||
elif renewal == "loss":
|
||||
if score[0] < self.best_score:
|
||||
self.best_score = score[0]
|
||||
elif renewal == "acc":
|
||||
if score[1] > self.best_score[1]:
|
||||
self.best_score[1] = score[1]
|
||||
self.best_weights = self.model.get_weights()
|
||||
elif renewal == "mse":
|
||||
if score[2] < self.best_score[2]:
|
||||
self.best_score[2] = score[2]
|
||||
self.best_weights = self.model.get_weights()
|
||||
else:
|
||||
raise ValueError("renewal must be 'acc' or 'loss' or 'mse'")
|
||||
|
||||
return score
|
||||
|
||||
@@ -156,6 +164,11 @@ class Particle:
|
||||
self.score_history.append(score[1])
|
||||
elif monitor in ["loss"]:
|
||||
self.score_history.append(score[0])
|
||||
elif monitor in ["mse"]:
|
||||
self.score_history.append(score[2])
|
||||
else:
|
||||
raise ValueError(
|
||||
"monitor must be 'acc' or 'accuracy' or 'loss' or 'mse'")
|
||||
|
||||
if len(self.score_history) > patience:
|
||||
last_scores = self.score_history[-patience:]
|
||||
@@ -163,19 +176,18 @@ class Particle:
|
||||
return True
|
||||
return False
|
||||
|
||||
def reset_particle(self):
|
||||
def __reset_particle__(self):
|
||||
self.model = keras.models.model_from_json(self.model.to_json())
|
||||
self.model.compile(optimizer="adam", loss=self.loss,
|
||||
metrics=["accuracy"])
|
||||
init_weights = self.model.get_weights()
|
||||
i_w_, i_s, i_l = self._encode(init_weights)
|
||||
self.model.compile(
|
||||
optimizer="adam",
|
||||
loss=self.loss,
|
||||
metrics=["accuracy", "mse"]
|
||||
)
|
||||
i_w_, i_s, i_l = self._encode(self.model.get_weights())
|
||||
i_w_ = np.random.uniform(-0.05, 0.05, len(i_w_))
|
||||
self.velocities = self._decode(i_w_, i_s, i_l)
|
||||
|
||||
self.best_weights = init_weights
|
||||
self.before_best = init_weights
|
||||
|
||||
del init_weights, i_w_, i_s, i_l
|
||||
del i_w_, i_s, i_l
|
||||
self.score_history = []
|
||||
|
||||
def _update_velocity(self, local_rate, global_rate, w, g_best):
|
||||
@@ -212,7 +224,7 @@ class Particle:
|
||||
+ -1 * global_rate * r_1 * (encode_g - encode_w)
|
||||
)
|
||||
if len(self.score_history) > 10 and max(self.score_history[-10:]) - min(self.score_history[-10:]) < 0.01:
|
||||
self.reset_particle()
|
||||
self.__reset_particle__()
|
||||
|
||||
else:
|
||||
new_v = (
|
||||
@@ -324,7 +336,7 @@ class Particle:
|
||||
|
||||
if self.converge_reset and self.__check_converge_reset__(
|
||||
score, self.converge_reset_monitor, self.converge_reset_patience, self.converge_reset_min_delta):
|
||||
self.reset_particle()
|
||||
self.__reset_particle__()
|
||||
|
||||
return score
|
||||
|
||||
|
||||
Reference in New Issue
Block a user