mirror of
https://github.com/jung-geun/PSO.git
synced 2025-12-19 20:44:39 +09:00
23-06-24
패키지 호출 단순 수정
This commit is contained in:
@@ -1,15 +1,16 @@
|
||||
import tensorflow as tf
|
||||
from tensorflow import keras
|
||||
import gc
|
||||
|
||||
# import cupy as cp
|
||||
import numpy as np
|
||||
import gc
|
||||
import tensorflow as tf
|
||||
from tensorflow import keras
|
||||
|
||||
|
||||
class Particle:
|
||||
"""
|
||||
Particle Swarm Optimization의 Particle을 구현한 클래스
|
||||
"""
|
||||
def __init__(self, model: keras.models, loss, negative: bool = False|True):
|
||||
def __init__(self, model: keras.models, loss, negative: bool = False, momentun: bool = False):
|
||||
"""
|
||||
Args:
|
||||
model (keras.models): 학습 및 검증을 위한 모델
|
||||
@@ -23,6 +24,7 @@ class Particle:
|
||||
i_w_ = np.random.rand(len(i_w_)) / 2 - 0.25
|
||||
self.velocities = self._decode(i_w_, s_, l_)
|
||||
self.negative = negative
|
||||
self.momentun = momentun
|
||||
self.best_score = 0
|
||||
self.best_weights = init_weights
|
||||
|
||||
@@ -146,6 +148,8 @@ class Particle:
|
||||
+ local_rate * r0 * (encode_p - encode_w)
|
||||
+ global_rate * r1 * (encode_g - encode_w)
|
||||
)
|
||||
if self.momentun:
|
||||
new_v += 0.5 * encode_v
|
||||
self.velocities = self._decode(new_v, w_sh, w_len)
|
||||
del encode_w, w_sh, w_len
|
||||
del encode_v, v_sh, v_len
|
||||
@@ -184,6 +188,8 @@ class Particle:
|
||||
+ local_rate * r0 * (w_p * encode_p - encode_w)
|
||||
+ global_rate * r1 * (w_g * encode_g - encode_w)
|
||||
)
|
||||
if self.momentun:
|
||||
new_v += 0.5 * encode_v
|
||||
self.velocities = self._decode(new_v, w_sh, w_len)
|
||||
del encode_w, w_sh, w_len
|
||||
del encode_v, v_sh, v_len
|
||||
|
||||
Reference in New Issue
Block a user