mirror of
https://github.com/jung-geun/PSO.git
synced 2025-12-19 20:44:39 +09:00
23-06-03
tensorflow gpu 의 메모리 용량 제한을 추가 readme에 분류 문제별 해결 현황 추가
This commit is contained in:
@@ -12,7 +12,7 @@ class Particle:
|
||||
self.loss = loss
|
||||
init_weights = self.model.get_weights()
|
||||
i_w_, s_, l_ = self._encode(init_weights)
|
||||
i_w_ = np.random.rand(len(i_w_)) / 5 - 0.10
|
||||
i_w_ = np.random.rand(len(i_w_)) / 2 - 0.25
|
||||
self.velocities = self._decode(i_w_, s_, l_)
|
||||
self.negative = negative
|
||||
self.best_score = 0
|
||||
@@ -40,7 +40,7 @@ class Particle:
|
||||
lenght.append(len(w_))
|
||||
# w_gpu = cp.append(w_gpu, w_)
|
||||
w_gpu = np.append(w_gpu, w_)
|
||||
gc.collect()
|
||||
|
||||
return w_gpu, shape, lenght
|
||||
|
||||
"""
|
||||
@@ -62,7 +62,7 @@ class Particle:
|
||||
del start, end, w_
|
||||
del shape, lenght
|
||||
del weight
|
||||
gc.collect()
|
||||
|
||||
return weights
|
||||
|
||||
def get_score(self, x, y, renewal: str = "acc"):
|
||||
@@ -77,7 +77,7 @@ class Particle:
|
||||
if score[0] < self.best_score:
|
||||
self.best_score = score[0]
|
||||
self.best_weights = self.model.get_weights()
|
||||
gc.collect()
|
||||
|
||||
return score
|
||||
|
||||
def _update_velocity(self, local_rate, global_rate, w, g_best):
|
||||
@@ -105,7 +105,6 @@ class Particle:
|
||||
del encode_p, p_sh, p_len
|
||||
del encode_g, g_sh, g_len
|
||||
del r0, r1
|
||||
gc.collect()
|
||||
|
||||
def _update_velocity_w(self, local_rate, global_rate, w, w_p, w_g, g_best):
|
||||
encode_w, w_sh, w_len = self._encode(weights=self.model.get_weights())
|
||||
@@ -132,7 +131,6 @@ class Particle:
|
||||
del encode_p, p_sh, p_len
|
||||
del encode_g, g_sh, g_len
|
||||
del r0, r1
|
||||
gc.collect()
|
||||
|
||||
def _update_weights(self):
|
||||
encode_w, w_sh, w_len = self._encode(weights=self.model.get_weights())
|
||||
@@ -141,12 +139,10 @@ class Particle:
|
||||
self.model.set_weights(self._decode(new_w, w_sh, w_len))
|
||||
del encode_w, w_sh, w_len
|
||||
del encode_v, v_sh, v_len
|
||||
gc.collect()
|
||||
|
||||
def f(self, x, y, weights):
|
||||
self.model.set_weights(weights)
|
||||
score = self.model.evaluate(x, y, verbose=0)[1]
|
||||
gc.collect()
|
||||
if score > 0:
|
||||
return 1 / (1 + score)
|
||||
else:
|
||||
@@ -155,7 +151,6 @@ class Particle:
|
||||
def step(self, x, y, local_rate, global_rate, w, g_best, renewal: str = "acc"):
|
||||
self._update_velocity(local_rate, global_rate, w, g_best)
|
||||
self._update_weights()
|
||||
gc.collect()
|
||||
return self.get_score(x, y, renewal)
|
||||
|
||||
def step_w(
|
||||
@@ -163,7 +158,6 @@ class Particle:
|
||||
):
|
||||
self._update_velocity_w(local_rate, global_rate, w, w_p, w_g, g_best)
|
||||
self._update_weights()
|
||||
gc.collect()
|
||||
return self.get_score(x, y, renewal)
|
||||
|
||||
def get_best_score(self):
|
||||
|
||||
Reference in New Issue
Block a user