# 26 lines, 1.0 KiB, Python
import os

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential


class NeuralNetwork:
    """Small fully connected Keras binary classifier with weight checkpointing.

    Architecture: Dense(12, relu) -> Dense(8, relu) -> Dense(1, sigmoid),
    compiled with binary cross-entropy loss, the Adam optimizer and an
    accuracy metric. A ``ModelCheckpoint`` callback saves the weights to
    ``checkpoint_path`` while :meth:`train` runs, and previously saved
    weights are loaded on construction when they are found on disk.
    """

    def __init__(self, inputs_len: int, checkpoint_path: str = "./training/cp.ckpt"):
        """Build and compile the model, restoring saved weights if present.

        Args:
            inputs_len: Number of input features per sample; used as the
                ``input_shape`` of the first Dense layer.
            checkpoint_path: Location weights are saved to and loaded from.
                Defaults to ``"./training/cp.ckpt"``.
        """
        self.checkpoint_path = checkpoint_path
        # NOTE(review): Keras 3 (TF >= 2.16) requires weights-only checkpoint
        # paths to end in ".weights.h5"; the default path assumes tf.keras 2.x.
        self.cp_callback = tf.keras.callbacks.ModelCheckpoint(
            filepath=self.checkpoint_path, save_weights_only=True, verbose=1
        )

        self.model = Sequential()
        self.model.add(Dense(12, input_shape=(inputs_len,), activation="relu"))
        self.model.add(Dense(8, activation="relu"))
        self.model.add(Dense(1, activation="sigmoid"))
        self.model.compile(
            loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"]
        )

        if self._checkpoint_exists(self.checkpoint_path):
            self.model.load_weights(self.checkpoint_path)

    def train(self, inputs: list, outputs: list, epochs: int = 150, batch_size: int = 10):
        """Fit the model on ``(inputs, outputs)``, checkpointing the weights.

        Args:
            inputs: Training features, passed to ``model.fit`` as ``x``.
            outputs: Binary targets (the model ends in a sigmoid trained with
                binary cross-entropy), passed to ``model.fit`` as ``y``.
            epochs: Number of training epochs (default 150).
            batch_size: Samples per gradient update (default 10).

        Returns:
            The ``History`` object returned by ``model.fit``.
        """
        return self.model.fit(
            inputs,
            outputs,
            epochs=epochs,
            batch_size=batch_size,
            callbacks=[self.cp_callback],
        )

    def predict(self, new_input):
        """Return ``model.predict(new_input)`` (sigmoid outputs in [0, 1])."""
        return self.model.predict(new_input)

    @staticmethod
    def _checkpoint_exists(path: str) -> bool:
        """Return True if saved weights appear to exist at *path*.

        TensorFlow-format weight checkpoints are written as ``<path>.index``
        plus data shards, not as a file named exactly *path*, so a plain
        ``os.path.isfile(path)`` misses them. Single-file formats (e.g. HDF5)
        are a regular file at *path*, so both forms are accepted.
        """
        return os.path.isfile(path) or os.path.isfile(path + ".index")