diff --git a/server/models/poll.go b/server/models/poll.go index fcc1505..e1d5f38 100644 --- a/server/models/poll.go +++ b/server/models/poll.go @@ -3,6 +3,10 @@ package models type Poll struct { ID int64 `json:"id"` Question string `json:"question"` + MemberYes int64 `json:"member_yes"` + MemberNo int64 `json:"member_no"` + NonMemberYes int64 `json:"non_member_yes` + NonMemberNo int64 `json:"non_member_no` TotalVotes int `json:"total_votes"` WhoVoted []string `json:"who_voted"` CreatedAt string `json:"created_at"` diff --git a/server/services/poll.go b/server/services/poll.go index ac22f3f..dee5ea7 100644 --- a/server/services/poll.go +++ b/server/services/poll.go @@ -3,12 +3,15 @@ package services import ( "database/sql" "errors" + "time" "go-sjles-pta-vote/server/db" "go-sjles-pta-vote/server/models" ) var ErrQuestionAlreadyExists = errors.New("Question already exists") +var ErrQuestionDoesntExist = errors.New("Question does not exist yet") +var ErrVoterAlreadyVoted = errors.New("Voter already voted") func CreatePoll(poll *models.Poll) (*models.Poll, error) { new_poll := models.Poll{} @@ -62,4 +65,172 @@ func CreatePoll(poll *models.Poll) (*models.Poll, error) { new_poll.ID, err = res.LastInsertId() return &new_poll, err +} + +func GetPollByQuestion(question string) (*models.Poll, error) { + new_poll := models.Poll{} + + db_conn, err := db.Connect() + if err != nil { + return nil, err + } + defer db.Close() + + get_poll_stmt, err := db_conn.Prepare(` + SELECT + id, question, + member_yes_votes, member_no_votes, + non_member_yes_votes, non_member_no_votes, + created_at, updated_at, + expires_at + FROM polls + WHERE question == $1 + `) + if err != nil { + return nil, err + } + defer get_poll_stmt.Close() + + err = get_poll_stmt.QueryRow(question).Scan( + &new_poll.ID, &new_poll.Question, + &new_poll.MemberYes, &new_poll.MemberNo, + &new_poll.NonMemberYes, &new_poll.NonMemberNo, + &new_poll.CreatedAt, &new_poll.UpdatedAt, + &new_poll.ExpiresAt, + ) + + if err == sql.ErrNoRows { + return nil, ErrQuestionDoesntExist + } else if err != nil { + return nil, err + } + + get_voters_stmt, err := db_conn.Prepare (` + SELECT voter_email + FROM voters + WHERE poll_id == $1 + `) + if err != nil { + return nil, err + } + defer get_voters_stmt.Close() + + rows, err := get_voters_stmt.Query(new_poll.ID) + + for rows.Next() { + var voter_email string + err = rows.Scan(&voter_email) + if err != nil { + return nil, err + } + new_poll.WhoVoted = append(new_poll.WhoVoted, voter_email) + } + + return &new_poll, nil +} + +func GetAndCreatePollByQuestion(question string) (*models.Poll, error) { + new_poll, err := GetPollByQuestion(question) + + if err == ErrQuestionDoesntExist { + create_poll := &models.Poll{ + Question: question, + ExpiresAt: time.Now().Add(time.Hour * 10).Format("2006-01-02 15:04:05"), + } + + if _, err = CreatePoll(create_poll); err != nil { + return nil, err + } + + return GetPollByQuestion(question) + } else if err != nil { + return nil, err + } else { + return new_poll, err + } +} + +func SetVote(poll_id int64, email string, vote bool) error { + db_conn, err := db.Connect() + if err != nil { + return err + } + defer db.Close() + + set_voter_stmt, err := db_conn.Prepare(` + INSERT IGNORE INTO voters + (poll_id, voter_email) + VALUES ($1, $2) + `) + if err != nil { + return err + } + defer set_voter_stmt.Close() + + res, err := set_voter_stmt.Exec(poll_id, email) + if err != nil { + return err + } else { + rows_changed, err := res.RowsAffected() + if rows_changed != 1 { + return ErrVoterAlreadyVoted + } else if err != nil { + return err + } + } + + is_voter_member_stmt, err := db_conn.Prepare(` + SELECT 1 + FROM members + WHERE email == $1 + `) + if err != nil { + return err + } + defer is_voter_member_stmt.Close() + + var member_check int64 + is_member := true + err = is_voter_member_stmt.QueryRow(email).Scan(&member_check) + if err == sql.ErrNoRows { + is_member = false + } else if err != nil { + return err + } + + // Member column name is not dependant on user input + // So it's ok to put it directly in the query + member_column_name := "member_" + if !is_member { + member_column_name = "non_" + member_column_name + } + + if vote { + member_column_name += "yes_votes" + } else { + member_column_name += "no_votes" + } + + add_vote_stmt, err := db_conn.Prepare(` + UPDATE polls + SET ` + member_column_name + ` = ` + member_column_name + ` 1 + WHERE id == $1 + `) + if err != nil { + return err + } + defer add_vote_stmt.Close() + + res, err = add_vote_stmt.Exec(poll_id) + if err != nil { + return err + } + + if num, err := res.RowsAffected(); num != 1 { + return errors.New("Failed to update votes") + } else if err != nil { + return err + } + + return nil } \ No newline at end of file diff --git a/server/services/services_test.go b/server/services/services_test.go index 1e71387..6d1ef47 100644 --- a/server/services/services_test.go +++ b/server/services/services_test.go @@ -30,8 +30,10 @@ func TestCreatePoll(t *testing.T) { {RandString(10) + "2", 2}, {RandString(10) + "3", 3}, {"\"" + RandString(10) + "4", 4}, - {"'" + RandString(10) + "5", 5}, - {";" + RandString(10) + "6", 6}, + {"\\\"" + RandString(10) + "5", 5}, + {"'" + RandString(10) + "6", 6}, + {";" + RandString(10) + "7", 7}, + {"\\" + RandString(10) + "8", 8}, } tmp_db, err := os.CreateTemp("", "vote_test.*.db") @@ -113,4 +115,102 @@ func TestAlreadyExists(t *testing.T) { if err != ErrQuestionAlreadyExists { t.Fatalf(`Should have failed adding %s as it already exists`, question) } -} \ No newline at end of file +} + +func TestGetPollByQuestion(t *testing.T) { + question := "TestQuestion" + + tmp_db, err := os.CreateTemp("", "vote_test.*.db") + if err != nil { + t.Errorf(`Failed to create temporary db file: %v`, err) + } + + init_conf := &config.Config{ + DBPath: string(tmp_db.Name()), + } + config.SetConfig(init_conf) + + defer os.Remove(tmp_db.Name()) + tmp_db.Close() + + if _, err := db.Connect(); err != nil { + t.Errorf(`Failed to create the database: %v`, err) + } + + create_poll := &models.Poll{ + Question: question, + ExpiresAt: time.Now().Add(time.Hour * 10).Format("2006-01-02 15:04:05"), + } + + new_poll, err := CreatePoll(create_poll) + + if err != nil { + t.Fatalf(`Failed to create new poll %s: %v`, question, err) + } + + if new_poll == nil { + t.Fatalf(`Failed to insert %s into table`, question) + } + + get_poll, err := GetPollByQuestion(question) + + if err != nil { + t.Fatalf(`Failed to get the poll %s: %v`, question, err) + } + + if get_poll.Question != question { + t.Fatalf(`Questions don't match: expected %s: recieved %s`, question, get_poll.Question) + } +} + +func TestGetCreatePollByQuestion(t *testing.T) { + parameters := []struct{ + question string + table_index int64 + }{ + {RandString(10) + "1", 1}, + {RandString(10) + "2", 2}, + {RandString(10) + "3", 3}, + {"\"" + RandString(10) + "4", 4}, + {"'" + RandString(10) + "5", 5}, + {";" + RandString(10) + "6", 6}, + } + + tmp_db, err := os.CreateTemp("", "vote_test.*.db") + if err != nil { + t.Errorf(`Failed to create temporary db file: %v`, err) + } + + init_conf := &config.Config{ + DBPath: string(tmp_db.Name()), + } + config.SetConfig(init_conf) + + defer os.Remove(tmp_db.Name()) + tmp_db.Close() + + if _, err := db.Connect(); err != nil { + t.Errorf(`Failed to create the database: %v`, err) + } + + for i := range parameters { + new_poll, err := GetAndCreatePollByQuestion(parameters[i].question) + + if err != nil { + t.Fatalf(`Failed to create new poll %s: %v`, parameters[i].question, err) + } + + if new_poll == nil { + t.Fatalf(`Failed to insert %s into table`, parameters[i].question) + } + + if new_poll.ID != parameters[i].table_index { + t.Fatalf(`Incorrect increment in index for %s: expected %d != %d`, parameters[i].question, parameters[i].table_index, new_poll.ID) + } + + if new_poll.Question != parameters[i].question { + t.Fatalf(`Incorrect question returned: Expected %s != %s`, parameters[i].question, new_poll.Question) + } + } +} +