mirror of
https://github.com/jung-geun/PSO.git
synced 2025-12-20 04:50:45 +09:00
모델과 파티클 클래스에 타입 힌트 추가
This commit is contained in:
@@ -13,7 +13,7 @@ from sklearn.model_selection import train_test_split
|
||||
from tensorboard.plugins.hparams import api as hp
|
||||
from tensorflow import keras
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from typing import Any
|
||||
from .particle import Particle
|
||||
|
||||
|
||||
@@ -33,22 +33,9 @@ class Optimizer:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: keras.models,
|
||||
loss: any = None,
|
||||
n_particles: int = None,
|
||||
c0: float = 0.5,
|
||||
c1: float = 0.3,
|
||||
w_min: float = 0.1,
|
||||
w_max: float = 0.9,
|
||||
negative_swarm: float = 0,
|
||||
mutation_swarm: float = 0,
|
||||
np_seed: int = None,
|
||||
tf_seed: int = None,
|
||||
random_state: tuple = None,
|
||||
convergence_reset: bool = False,
|
||||
convergence_reset_patience: int = 10,
|
||||
convergence_reset_min_delta: float = 0.0001,
|
||||
convergence_reset_monitor: str = "loss",
|
||||
model: keras.Model,
|
||||
loss: Any,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
particle swarm optimization
|
||||
@@ -63,15 +50,33 @@ class Optimizer:
|
||||
w_max (float): 최대 관성 수치
|
||||
negative_swarm (float): 최적해와 반대로 이동할 파티클 비율 - 0 ~ 1 사이의 값
|
||||
mutation_swarm (float): 돌연변이가 일어날 확률
|
||||
np_seed (int, optional): numpy seed. Defaults to None.
|
||||
tf_seed (int, optional): tensorflow seed. Defaults to None.
|
||||
convergence_reset (bool, optional): early stopping 사용 여부. Defaults to False.
|
||||
convergence_reset_patience (int, optional): early stopping 사용시 얼마나 기다릴지. Defaults to 10.
|
||||
convergence_reset_min_delta (float, optional): early stopping 사용시 얼마나 기다릴지. Defaults to 0.0001.
|
||||
convergence_reset_monitor (str, optional): early stopping 사용시 어떤 값을 기준으로 할지. Defaults to "loss". - "loss" or "acc" or "mse"
|
||||
np_seed (int | None): numpy seed. Defaults to None.
|
||||
tf_seed (int | None): tensorflow seed. Defaults to None.
|
||||
random_state (tuple): numpy random state. Defaults to None.
|
||||
convergence_reset (bool): early stopping 사용 여부. Defaults to False.
|
||||
convergence_reset_patience (int): early stopping 사용시 얼마나 기다릴지. Defaults to 10.
|
||||
convergence_reset_min_delta (float): early stopping 사용시 얼마나 기다릴지. Defaults to 0.0001.
|
||||
convergence_reset_monitor (str): early stopping 사용시 어떤 값을 기준으로 할지. Defaults to "loss". - "loss" or "acc" or "mse"
|
||||
"""
|
||||
|
||||
try:
|
||||
n_particles = kwargs.get("n_particles", 10)
|
||||
c0 = kwargs.get("c0", 0.5)
|
||||
c1 = kwargs.get("c1", 0.3)
|
||||
w_min = kwargs.get("w_min", 0.1)
|
||||
w_max = kwargs.get("w_max", 0.9)
|
||||
negative_swarm = kwargs.get("negative_swarm", 0)
|
||||
mutation_swarm = kwargs.get("mutation_swarm", 0)
|
||||
np_seed = kwargs.get("np_seed", None)
|
||||
tf_seed = kwargs.get("tf_seed", None)
|
||||
random_state = kwargs.get("random_state", None)
|
||||
convergence_reset = kwargs.get("convergence_reset", False)
|
||||
convergence_reset_patience = kwargs.get("convergence_reset_patience", 10)
|
||||
convergence_reset_min_delta = kwargs.get(
|
||||
"convergence_reset_min_delta", 0.0001
|
||||
)
|
||||
convergence_reset_monitor = kwargs.get("convergence_reset_monitor", "loss")
|
||||
|
||||
if model is None:
|
||||
raise ValueError("model is None")
|
||||
elif model is not None and not isinstance(model, keras.models.Model):
|
||||
|
||||
Reference in New Issue
Block a user