Initial setup for training
This commit is contained in:
@@ -1,26 +1,40 @@
|
||||
import os
|
||||
import numpy as np
|
||||
from tensorflow.keras.models import Sequential
|
||||
from tensorflow.keras.layers import Dense
|
||||
import tensorflow as tf
|
||||
|
||||
# create NeuralNetwork class
|
||||
class NeuralNetwork:
|
||||
def __init__(self, inputs_len: int):
|
||||
def __init__(self, input_length: int):
|
||||
self.model = tf.keras.Sequential([
|
||||
tf.keras.Input(shape=(input_length,), dtype=tf.int64),
|
||||
tf.keras.layers.Dense(512, activation='relu'),
|
||||
tf.keras.layers.Dropout(0.5),
|
||||
tf.keras.layers.Dense(256, activation='relu'),
|
||||
tf.keras.layers.Dropout(0.5),
|
||||
tf.keras.layers.Dense(128, activation='relu'),
|
||||
tf.keras.layers.Dense(1, activation='sigmoid'),
|
||||
])
|
||||
|
||||
self.model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
|
||||
|
||||
# Setup checkpoint
|
||||
self.checkpoint_path = "./training/cp.ckpt"
|
||||
self.cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, save_weights_only=True, verbose=1)
|
||||
self.checkpoint_path = "./training/cp.ckpt.weights.h5"
|
||||
self.cp_callback = tf.keras.callbacks.ModelCheckpoint(
|
||||
filepath=self.checkpoint_path,
|
||||
save_weights_only=True,
|
||||
save_best_only=True,
|
||||
monitor='loss',
|
||||
mode='min',
|
||||
)
|
||||
|
||||
# Setup model
|
||||
self.model = Sequential()
|
||||
self.model.add(Dense(12, input_shape=(inputs_len,), activation='relu'))
|
||||
self.model.add(Dense(8, activation='relu'))
|
||||
self.model.add(Dense(1, activation='sigmoid'))
|
||||
self.model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
|
||||
#if os.path.isfile(self.checkpoint_path):
|
||||
# self.model.load_weights(self.checkpoint_path)
|
||||
|
||||
if os.path.isfile(self.checkpoint_path):
|
||||
self.model.load_weights(self.checkpoint_path)
|
||||
def train(self, inputs :list, outputs :list):
|
||||
self.model.fit(inputs, outputs, epochs=100, batch_size=64, callbacks=[self.cp_callback])
|
||||
|
||||
def train(inputs :list, outputs :list):
|
||||
self.model.fit(inputs, outputs, epochs=150, batch_size=10, callbacks=[self.cp_callback])
|
||||
def summary(self):
|
||||
print(self.model.summary())
|
||||
|
||||
def predict(self, new_input):
|
||||
return self.model.predict(new_input)
|
||||
Reference in New Issue
Block a user