mirror of
https://github.com/jung-geun/PSO.git
synced 2025-12-20 04:50:45 +09:00
23-10-21
version 1.0.2 back propagation 설정 가능 => 초기에 한해서 역전파 1회 실행 가능
This commit is contained in:
@@ -115,6 +115,7 @@ best_score = pso_mnist.fit(
|
|||||||
empirical_balance=False,
|
empirical_balance=False,
|
||||||
dispersion=False,
|
dispersion=False,
|
||||||
batch_size=5000,
|
batch_size=5000,
|
||||||
|
back_propagation=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
print("Done!")
|
print("Done!")
|
||||||
|
|||||||
1
mnist.py
1
mnist.py
@@ -117,6 +117,7 @@ best_score = pso_mnist.fit(
|
|||||||
empirical_balance=False,
|
empirical_balance=False,
|
||||||
dispersion=False,
|
dispersion=False,
|
||||||
batch_size=5000,
|
batch_size=5000,
|
||||||
|
back_propagation=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
print("Done!")
|
print("Done!")
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from .optimizer import Optimizer as optimizer
|
from .optimizer import Optimizer as optimizer
|
||||||
from .particle import Particle as particle
|
from .particle import Particle as particle
|
||||||
|
|
||||||
__version__ = "1.0.1"
|
__version__ = "1.0.2"
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"optimizer",
|
"optimizer",
|
||||||
|
|||||||
@@ -328,6 +328,7 @@ class Optimizer:
|
|||||||
check_point: int = None,
|
check_point: int = None,
|
||||||
batch_size: int = None,
|
batch_size: int = None,
|
||||||
validate_data: any = None,
|
validate_data: any = None,
|
||||||
|
back_propagation: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
# Args:
|
# Args:
|
||||||
@@ -393,6 +394,7 @@ class Optimizer:
|
|||||||
except ValueError as ve:
|
except ValueError as ve:
|
||||||
sys.exit(ve)
|
sys.exit(ve)
|
||||||
|
|
||||||
|
if back_propagation:
|
||||||
model_ = keras.models.model_from_json(self.model.to_json())
|
model_ = keras.models.model_from_json(self.model.to_json())
|
||||||
model_.compile(
|
model_.compile(
|
||||||
loss=self.loss,
|
loss=self.loss,
|
||||||
|
|||||||
Reference in New Issue
Block a user