Building the database using sql file

This commit is contained in:
2025-04-02 15:11:46 -04:00
parent fa870f3f54
commit c78aa38db9
8 changed files with 642 additions and 34 deletions

41
main.py
View File

@@ -1,5 +1,6 @@
import numpy as np # helps with the math
import matplotlib.pyplot as plt # to plot error during training
from data.db_connect import Database
# input data
inputs = np.array([[0, 0, 1, 0],
@@ -55,22 +56,30 @@ class NeuralNetwork:
prediction = self.sigmoid(np.dot(new_input, self.weights))
return prediction
# create neural network
NN = NeuralNetwork(inputs, outputs)
# train neural network
NN.train()
if __name__ == '__main__':
sql_file = "./data/build_db.sql"
db_file = "./database/baseball.db"
# create two new examples to predict
example = np.array([[1, 1, 1, 0]])
example_2 = np.array([[0, 0, 1, 1]])
db_conn = Database(db_file)
db_conn.build_database(sql_file)
# print the predictions for both examples
print(NN.predict(example), ' - Correct: ', example[0][0])
print(NN.predict(example_2), ' - Correct: ', example_2[0][0])
else:
# create neural network
NN = NeuralNetwork(inputs, outputs)
# train neural network
NN.train()
# plot the error over the entire training duration
plt.figure(figsize=(15,5))
plt.plot(NN.epoch_list, NN.error_history)
plt.xlabel('Epoch')
plt.ylabel('Error')
plt.savefig('plot.png')
# create two new examples to predict
example = np.array([[1, 1, 1, 0]])
example_2 = np.array([[0, 0, 1, 1]])
# print the predictions for both examples
print(NN.predict(example), ' - Correct: ', example[0][0])
print(NN.predict(example_2), ' - Correct: ', example_2[0][0])
# plot the error over the entire training duration
plt.figure(figsize=(15,5))
plt.plot(NN.epoch_list, NN.error_history)
plt.xlabel('Epoch')
plt.ylabel('Error')
plt.savefig('plot.png')