diff --git a/pso/optimizer.py b/pso/optimizer.py index e2142fd..11c6401 100644 --- a/pso/optimizer.py +++ b/pso/optimizer.py @@ -13,7 +13,7 @@ from sklearn.model_selection import train_test_split from tensorboard.plugins.hparams import api as hp from tensorflow import keras from tqdm.auto import tqdm - +from typing import Any from .particle import Particle @@ -33,22 +33,9 @@ class Optimizer: def __init__( self, - model: keras.models, - loss: any = None, - n_particles: int = None, - c0: float = 0.5, - c1: float = 0.3, - w_min: float = 0.1, - w_max: float = 0.9, - negative_swarm: float = 0, - mutation_swarm: float = 0, - np_seed: int = None, - tf_seed: int = None, - random_state: tuple = None, - convergence_reset: bool = False, - convergence_reset_patience: int = 10, - convergence_reset_min_delta: float = 0.0001, - convergence_reset_monitor: str = "loss", + model: keras.Model, + loss: Any, + **kwargs, ): """ particle swarm optimization @@ -63,15 +50,33 @@ class Optimizer: w_max (float): 최대 관성 수치 negative_swarm (float): 최적해와 반대로 이동할 파티클 비율 - 0 ~ 1 사이의 값 mutation_swarm (float): 돌연변이가 일어날 확률 - np_seed (int, optional): numpy seed. Defaults to None. - tf_seed (int, optional): tensorflow seed. Defaults to None. - convergence_reset (bool, optional): early stopping 사용 여부. Defaults to False. - convergence_reset_patience (int, optional): early stopping 사용시 얼마나 기다릴지. Defaults to 10. - convergence_reset_min_delta (float, optional): early stopping 사용시 얼마나 기다릴지. Defaults to 0.0001. - convergence_reset_monitor (str, optional): early stopping 사용시 어떤 값을 기준으로 할지. Defaults to "loss". - "loss" or "acc" or "mse" + np_seed (int | None): numpy seed. Defaults to None. + tf_seed (int | None): tensorflow seed. Defaults to None. + random_state (tuple): numpy random state. Defaults to None. + convergence_reset (bool): early stopping 사용 여부. Defaults to False. + convergence_reset_patience (int): early stopping 사용시 얼마나 기다릴지. Defaults to 10. + convergence_reset_min_delta (float): early stopping 사용시 얼마나 기다릴지. Defaults to 0.0001. + convergence_reset_monitor (str): early stopping 사용시 어떤 값을 기준으로 할지. Defaults to "loss". - "loss" or "acc" or "mse" """ try: + n_particles = kwargs.get("n_particles", 10) + c0 = kwargs.get("c0", 0.5) + c1 = kwargs.get("c1", 0.3) + w_min = kwargs.get("w_min", 0.1) + w_max = kwargs.get("w_max", 0.9) + negative_swarm = kwargs.get("negative_swarm", 0) + mutation_swarm = kwargs.get("mutation_swarm", 0) + np_seed = kwargs.get("np_seed", None) + tf_seed = kwargs.get("tf_seed", None) + random_state = kwargs.get("random_state", None) + convergence_reset = kwargs.get("convergence_reset", False) + convergence_reset_patience = kwargs.get("convergence_reset_patience", 10) + convergence_reset_min_delta = kwargs.get( + "convergence_reset_min_delta", 0.0001 + ) + convergence_reset_monitor = kwargs.get("convergence_reset_monitor", "loss") + if model is None: raise ValueError("model is None") elif model is not None and not isinstance(model, keras.models.Model): diff --git a/pso/particle.py b/pso/particle.py index ee97518..3f48677 100644 --- a/pso/particle.py +++ b/pso/particle.py @@ -1,5 +1,6 @@ import numpy as np from tensorflow import keras +from typing import Any class Particle: @@ -19,8 +20,8 @@ class Particle: def __init__( self, - model: keras.models, - loss: any = None, + model: keras.Model, + loss: Any = None, negative: bool = False, mutation: float = 0, converge_reset: bool = False, @@ -142,6 +143,7 @@ class Particle: Returns: (float): 점수 """ + score = self.model.evaluate(x, y, verbose=0) if renewal == "loss": if score[0] < self.best_score[0]: @@ -163,7 +165,7 @@ class Particle: def __check_converge_reset( self, score, - monitor: str = None, + monitor: str = "auto", patience: int = 10, min_delta: float = 0.0001, ): @@ -176,7 +178,7 @@ class Particle: patience (int, optional): early stop을 위한 기다리는 횟수. Defaults to 10. min_delta (float, optional): early stop을 위한 최소 변화량. Defaults to 0.0001. """ - if monitor is None: + if monitor == "auto": monitor = "acc" if monitor in ["loss"]: self.score_history.append(score[0]) diff --git a/test/iris.py b/test/iris.py index 6a356a4..e0e8cf7 100644 --- a/test/iris.py +++ b/test/iris.py @@ -41,7 +41,7 @@ x_train, x_test, y_train, y_test = load_data() pso_iris = optimizer( - model, + model=model, loss="categorical_crossentropy", n_particles=100, c0=0.5,