모델과 파티클 클래스에 타입 힌트 추가

This commit is contained in:
jung-geun
2024-03-08 20:15:54 +09:00
parent fa9af45a95
commit 940580b7a6
3 changed files with 35 additions and 28 deletions

View File

@@ -1,5 +1,6 @@
import numpy as np
from tensorflow import keras
from typing import Any
class Particle:
@@ -19,8 +20,8 @@ class Particle:
def __init__(
self,
model: keras.models,
loss: any = None,
model: keras.Model,
loss: Any = None,
negative: bool = False,
mutation: float = 0,
converge_reset: bool = False,
@@ -142,6 +143,7 @@ class Particle:
Returns:
(float): 점수
"""
score = self.model.evaluate(x, y, verbose=0)
if renewal == "loss":
if score[0] < self.best_score[0]:
@@ -163,7 +165,7 @@ class Particle:
def __check_converge_reset(
self,
score,
monitor: str = None,
monitor: str = "auto",
patience: int = 10,
min_delta: float = 0.0001,
):
@@ -176,7 +178,7 @@ class Particle:
patience (int, optional): early stop을 위한 기다리는 횟수. Defaults to 10.
min_delta (float, optional): early stop을 위한 최소 변화량. Defaults to 0.0001.
"""
if monitor is None:
if monitor == "auto":
monitor = "acc"
if monitor in ["loss"]:
self.score_history.append(score[0])