모델과 파티클 클래스에 타입 힌트 추가

This commit is contained in:
jung-geun
2024-03-08 20:15:54 +09:00
parent fa9af45a95
commit 940580b7a6
3 changed files with 35 additions and 28 deletions

View File

@@ -13,7 +13,7 @@ from sklearn.model_selection import train_test_split
from tensorboard.plugins.hparams import api as hp
from tensorflow import keras
from tqdm.auto import tqdm
from typing import Any
from .particle import Particle
@@ -33,22 +33,9 @@ class Optimizer:
def __init__(
self,
model: keras.models,
loss: any = None,
n_particles: int = None,
c0: float = 0.5,
c1: float = 0.3,
w_min: float = 0.1,
w_max: float = 0.9,
negative_swarm: float = 0,
mutation_swarm: float = 0,
np_seed: int = None,
tf_seed: int = None,
random_state: tuple = None,
convergence_reset: bool = False,
convergence_reset_patience: int = 10,
convergence_reset_min_delta: float = 0.0001,
convergence_reset_monitor: str = "loss",
model: keras.Model,
loss: Any,
**kwargs,
):
"""
particle swarm optimization
@@ -63,15 +50,33 @@ class Optimizer:
w_max (float): 최대 관성 수치
negative_swarm (float): 최적해와 반대로 이동할 파티클 비율 - 0 ~ 1 사이의 값
mutation_swarm (float): 돌연변이가 일어날 확률
np_seed (int, optional): numpy seed. Defaults to None.
tf_seed (int, optional): tensorflow seed. Defaults to None.
convergence_reset (bool, optional): early stopping 사용 여부. Defaults to False.
convergence_reset_patience (int, optional): early stopping 사용시 얼마나 기다릴지. Defaults to 10.
convergence_reset_min_delta (float, optional): early stopping 사용시 얼마나 기다릴지. Defaults to 0.0001.
convergence_reset_monitor (str, optional): early stopping 사용시 어떤 값을 기준으로 할지. Defaults to "loss". - "loss" or "acc" or "mse"
np_seed (int | None): numpy seed. Defaults to None.
tf_seed (int | None): tensorflow seed. Defaults to None.
random_state (tuple): numpy random state. Defaults to None.
convergence_reset (bool): early stopping 사용 여부. Defaults to False.
convergence_reset_patience (int): early stopping 사용시 얼마나 기다릴지. Defaults to 10.
convergence_reset_min_delta (float): early stopping 사용시 얼마나 기다릴지. Defaults to 0.0001.
convergence_reset_monitor (str): early stopping 사용시 어떤 값을 기준으로 할지. Defaults to "loss". - "loss" or "acc" or "mse"
"""
try:
n_particles = kwargs.get("n_particles", 10)
c0 = kwargs.get("c0", 0.5)
c1 = kwargs.get("c1", 0.3)
w_min = kwargs.get("w_min", 0.1)
w_max = kwargs.get("w_max", 0.9)
negative_swarm = kwargs.get("negative_swarm", 0)
mutation_swarm = kwargs.get("mutation_swarm", 0)
np_seed = kwargs.get("np_seed", None)
tf_seed = kwargs.get("tf_seed", None)
random_state = kwargs.get("random_state", None)
convergence_reset = kwargs.get("convergence_reset", False)
convergence_reset_patience = kwargs.get("convergence_reset_patience", 10)
convergence_reset_min_delta = kwargs.get(
"convergence_reset_min_delta", 0.0001
)
convergence_reset_monitor = kwargs.get("convergence_reset_monitor", "loss")
if model is None:
raise ValueError("model is None")
elif model is not None and not isinstance(model, keras.models.Model):