diff --git a/server/models/poll.go b/server/models/poll.go index fcc1505..e1d5f38 100644 --- a/server/models/poll.go +++ b/server/models/poll.go @@ -3,6 +3,10 @@ package models type Poll struct { ID int64 `json:"id"` Question string `json:"question"` + MemberYes int64 `json:"member_yes"` + MemberNo int64 `json:"member_no"` + NonMemberYes int64 `json:"non_member_yes` + NonMemberNo int64 `json:"non_member_no` TotalVotes int `json:"total_votes"` WhoVoted []string `json:"who_voted"` CreatedAt string `json:"created_at"` diff --git a/server/services/poll.go b/server/services/poll.go index ac22f3f..c9ee6b7 100644 --- a/server/services/poll.go +++ b/server/services/poll.go @@ -9,6 +9,7 @@ import ( ) var ErrQuestionAlreadyExists = errors.New("Question already exists") +var ErrQuestionDoesntExist = errors.New("Question does not exist yet") func CreatePoll(poll *models.Poll) (*models.Poll, error) { new_poll := models.Poll{} @@ -62,4 +63,66 @@ func CreatePoll(poll *models.Poll) (*models.Poll, error) { new_poll.ID, err = res.LastInsertId() return &new_poll, err +} + +func GetPollByQuestion(question string) (*models.Poll, error) { + new_poll := models.Poll{} + + db_conn, err := db.Connect() + if err != nil { + return nil, err + } + defer db.Close() + + get_poll_stmt, err := db_conn.Prepare(` + SELECT + id, question, + member_yes_votes, member_no_votes, + non_member_yes_votes, non_member_no_votes, + created_at, updated_at, + expires_at + FROM polls + WHERE question == $1 + `) + if err != nil { + return nil, err + } + defer get_poll_stmt.Close() + + err = get_poll_stmt.QueryRow(question).Scan( + &new_poll.ID, &new_poll.Question, + &new_poll.MemberYes, &new_poll.MemberNo, + &new_poll.NonMemberYes, &new_poll.NonMemberNo, + &new_poll.CreatedAt, &new_poll.UpdatedAt, + &new_poll.ExpiresAt, + ) + + if err == sql.ErrNoRows { + return nil, ErrQuestionDoesntExist + } else if err != nil { + return nil, err + } + + get_voters_stmt, err := db_conn.Prepare (` + SELECT voter_email + FROM voters + WHERE poll_id == $1 + `) + if err != nil { + return nil, err + } + defer get_voters_stmt.Close() + + rows, err := get_voters_stmt.Query(new_poll.ID) + + for rows.Next() { + var voter_email string + err = rows.Scan(&voter_email) + if err != nil { + return nil, err + } + new_poll.WhoVoted = append(new_poll.WhoVoted, voter_email) + } + + return &new_poll, nil } \ No newline at end of file diff --git a/server/services/services_test.go b/server/services/services_test.go index 1e71387..5b37305 100644 --- a/server/services/services_test.go +++ b/server/services/services_test.go @@ -113,4 +113,50 @@ func TestAlreadyExists(t *testing.T) { if err != ErrQuestionAlreadyExists { t.Fatalf(`Should have failed adding %s as it already exists`, question) } -} \ No newline at end of file +} + +func TestGetPollByQuestion(t *testing.T) { + question := "TestQuestion" + + tmp_db, err := os.CreateTemp("", "vote_test.*.db") + if err != nil { + t.Errorf(`Failed to create temporary db file: %v`, err) + } + + init_conf := &config.Config{ + DBPath: string(tmp_db.Name()), + } + config.SetConfig(init_conf) + + defer os.Remove(tmp_db.Name()) + tmp_db.Close() + + if _, err := db.Connect(); err != nil { + t.Errorf(`Failed to create the database: %v`, err) + } + + create_poll := &models.Poll{ + Question: question, + ExpiresAt: time.Now().Add(time.Hour * 10).Format("2006-01-02 15:04:05"), + } + + new_poll, err := CreatePoll(create_poll) + + if err != nil { + t.Fatalf(`Failed to create new poll %s: %v`, question, err) + } + + if new_poll == nil { + t.Fatalf(`Failed to insert %s into table`, question) + } + + get_poll, err := GetPollByQuestion(question) + + if err != nil { + t.Fatalf(`Failed to get the poll %s: %v`, question, err) + } + + if get_poll.Question != question { + t.Fatalf(`Questions don't match: expected %s: recieved %s`, question, get_poll.Question) + } +}