패키지 호출 단순 수정
This commit is contained in:
jung-geun
2023-06-24 03:31:40 +00:00
parent 2a28b7fa04
commit 983913f2d2
5 changed files with 70 additions and 94 deletions

View File

@@ -1,15 +1,16 @@
import tensorflow as tf
from tensorflow import keras
import gc
# import cupy as cp
import numpy as np
import gc
import tensorflow as tf
from tensorflow import keras
class Particle:
"""
Particle Swarm Optimization의 Particle을 구현한 클래스
"""
def __init__(self, model: keras.models, loss, negative: bool = False|True):
def __init__(self, model: keras.models, loss, negative: bool = False, momentun: bool = False):
"""
Args:
model (keras.models): 학습 및 검증을 위한 모델
@@ -23,6 +24,7 @@ class Particle:
i_w_ = np.random.rand(len(i_w_)) / 2 - 0.25
self.velocities = self._decode(i_w_, s_, l_)
self.negative = negative
self.momentun = momentun
self.best_score = 0
self.best_weights = init_weights
@@ -146,6 +148,8 @@ class Particle:
+ local_rate * r0 * (encode_p - encode_w)
+ global_rate * r1 * (encode_g - encode_w)
)
if self.momentun:
new_v += 0.5 * encode_v
self.velocities = self._decode(new_v, w_sh, w_len)
del encode_w, w_sh, w_len
del encode_v, v_sh, v_len
@@ -184,6 +188,8 @@ class Particle:
+ local_rate * r0 * (w_p * encode_p - encode_w)
+ global_rate * r1 * (w_g * encode_g - encode_w)
)
if self.momentun:
new_v += 0.5 * encode_v
self.velocities = self._decode(new_v, w_sh, w_len)
del encode_w, w_sh, w_len
del encode_v, v_sh, v_len