package handlers

import (
	"encoding/binary"
	"encoding/hex"
	"fmt"
	"log/slog"
	"math/rand"
	"net/http"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/go-chi/chi/v5"

	"qbank/internal/auth"
	"qbank/internal/db"
	"qbank/internal/models"
	"qbank/internal/sampling"
)

// happyCatThreshold is the minimum score percentage to show a happy cat.
const happyCatThreshold = 60

// TestHandler serves the test workflow: creating a test, answering its
// questions one page at a time, and showing the results.
type TestHandler struct {
	auth   *auth.Manager
	repo   *db.Repo
	render *Renderer
}

// NewTestHandler returns a TestHandler wired to the given auth manager,
// repository and template renderer.
func NewTestHandler(a *auth.Manager, repo *db.Repo, r *Renderer) *TestHandler {
	return &TestHandler{auth: a, repo: repo, render: r}
}

// NewGet renders the "new test" form together with question counts.
func (h *TestHandler) NewGet(w http.ResponseWriter, r *http.Request) {
	h.renderNewForm(w, r, "")
}

// renderNewForm renders the "test_new" template with the total question count
// ("TotalQ") and per-source counts ("SourceStats"). A non-empty errMsg is
// exposed to the template as "Error". Count failures are logged but do not
// fail the page; whatever the repo returned alongside the error is rendered.
func (h *TestHandler) renderNewForm(w http.ResponseWriter, r *http.Request, errMsg string) {
	totalQ, err := h.repo.CountQuestions()
	if err != nil {
		slog.Error("count questions", "err", err)
	}
	sourceCounts, err := h.repo.CountBySource()
	if err != nil {
		slog.Error("count questions by source", "err", err)
	}
	data := BaseData(h.auth, r)
	data["TotalQ"] = totalQ
	data["SourceStats"] = sourceCounts
	if errMsg != "" {
		data["Error"] = errMsg
	}
	h.render.Render(w, http.StatusOK, "test_new", data)
}

// NewPost creates a test from the submitted form and redirects to its first
// question. Form fields:
//   - source: filter passed to ListQuestions.
//   - mode: "uniform" gives every question equal weight; any other value
//     weights each question from the user's stats via sampling.ComputeWeight.
//   - n: number of questions; missing, malformed or non-positive values
//     default to 10, and it is capped at the number of matching questions.
//
// If the filter matches no questions, the form is re-rendered with an error.
func (h *TestHandler) NewPost(w http.ResponseWriter, r *http.Request) {
	// Cap the body before anything can read it (CheckCSRF included), so the
	// limit holds regardless of how the CSRF token is looked up.
	r.Body = http.MaxBytesReader(w, r.Body, 4096)
	if !h.auth.CheckCSRF(r) {
		HTTPError(w, http.StatusForbidden)
		return
	}
	if err := r.ParseForm(); err != nil {
		HTTPError(w, http.StatusBadRequest)
		return
	}
	user := auth.UserFromCtx(r.Context())
	source := r.FormValue("source")
	mode := r.FormValue("mode")
	// Atoi yields 0 on a parse error, which the default below also covers.
	n, _ := strconv.Atoi(r.FormValue("n"))
	if n <= 0 {
		n = 10
	}

	questions, err := h.repo.ListQuestions(db.ListFilter{Source: source})
	if err != nil {
		slog.Error("list questions for test", "err", err)
		HTTPError(w, http.StatusInternalServerError)
		return
	}
	if len(questions) == 0 {
		h.renderNewForm(w, r, "No questions available for the selected filter.")
		return
	}
	if n > len(questions) {
		n = len(questions)
	}

	candidates := make([]sampling.Candidate, len(questions))
	if mode == "uniform" {
		for i, q := range questions {
			candidates[i] = sampling.Candidate{ID: q.ID, Weight: 1.0}
		}
	} else {
		ids := make([]string, len(questions))
		for i, q := range questions {
			ids[i] = q.ID
		}
		stats, err := h.repo.GetStatsForUser(user.ID, ids)
		if err != nil {
			slog.Error("get stats for test", "err", err)
			HTTPError(w, http.StatusInternalServerError)
			return
		}
		now := time.Now()
		for i, q := range questions {
			candidates[i] = sampling.Candidate{
				ID:     q.ID,
				Weight: sampling.ComputeWeight(stats[q.ID], now),
			}
		}
	}

	rng := rand.New(rand.NewSource(time.Now().UnixNano()))
	picked := sampling.SelectWeighted(candidates, n, rng)
	// Shuffle presentation order independently of selection order.
	rng.Shuffle(len(picked), func(i, j int) { picked[i], picked[j] = picked[j], picked[i] })
	pickedIDs := make([]string, len(picked))
	for i, c := range picked {
		pickedIDs[i] = c.ID
	}

	testID, err := h.repo.CreateTest(user.ID, pickedIDs)
	if err != nil {
		slog.Error("create test", "err", err)
		HTTPError(w, http.StatusInternalServerError)
		return
	}
	http.Redirect(w, r, fmt.Sprintf("/test/%d/q/1", testID), http.StatusSeeOther)
}

// QuestionGet renders question n (1-based) of a test. The answers are shown in
// an order that is stable per test and question, and ProgressPct is the
// percentage of questions before this one.
func (h *TestHandler) QuestionGet(w http.ResponseWriter, r *http.Request) {
	n, test, ok := h.loadTestAndN(w, r)
	if !ok {
		return
	}
	questionID := test.QuestionIDs[n-1]
	q, answers, err := h.repo.GetQuestion(questionID)
	if err != nil {
		slog.Error("get question for test", "err", err)
		HTTPError(w, http.StatusInternalServerError)
		return
	}
	data := BaseData(h.auth, r)
	data["TestID"] = test.ID
	data["N"] = n
	data["Total"] = test.NQuestions
	data["Question"] = q
	data["Answers"] = deterministicShuffle(answers, test.ID, questionID)
	data["ProgressPct"] = (n - 1) * 100 / test.NQuestions
	h.render.Render(w, http.StatusOK, "test_question", data)
}

// QuestionPost records the submitted answer for question n, updates the
// user's per-question stats (best effort), and redirects to the next question.
// After the last question it finishes the test and redirects to the results.
// Submissions for a test that is already completed are not recorded; the
// user is sent straight to the results page.
func (h *TestHandler) QuestionPost(w http.ResponseWriter, r *http.Request) {
	n, test, ok := h.loadTestAndN(w, r)
	if !ok {
		return
	}
	// Cap the body before anything can read it (CheckCSRF included).
	r.Body = http.MaxBytesReader(w, r.Body, 4096)
	if !h.auth.CheckCSRF(r) {
		HTTPError(w, http.StatusForbidden)
		return
	}
	// A finished test is read-only: a resubmit (e.g. via the back button) must
	// not record answers, skew the user's stats, or finish the test again.
	// NOTE(review): assumes CompletedAt is set by FinishTest — confirm.
	if test.CompletedAt.Valid {
		http.Redirect(w, r, fmt.Sprintf("/test/%d/results", test.ID), http.StatusSeeOther)
		return
	}
	if err := r.ParseForm(); err != nil {
		HTTPError(w, http.StatusBadRequest)
		return
	}
	user := auth.UserFromCtx(r.Context())
	questionID := test.QuestionIDs[n-1]
	_, answers, err := h.repo.GetQuestion(questionID)
	if err != nil {
		slog.Error("get question for answer", "err", err)
		HTTPError(w, http.StatusInternalServerError)
		return
	}

	// Only accept an answer ID that belongs to this question. Anything else
	// (empty, malformed, or another question's answer) counts as unanswered.
	var selectedID *int64
	var isCorrect bool
	if id, err := strconv.ParseInt(r.FormValue("answer_id"), 10, 64); err == nil {
		for _, a := range answers {
			if a.ID == id {
				selectedID = &id
				isCorrect = a.IsCorrect
				break
			}
		}
	}

	if err := h.repo.RecordAnswer(test.ID, questionID, selectedID, isCorrect); err != nil {
		slog.Error("record answer", "err", err)
		HTTPError(w, http.StatusInternalServerError)
		return
	}
	// Stats only influence future sampling; a failure here must not block the test.
	if err := h.repo.UpsertStat(user.ID, questionID, isCorrect); err != nil {
		slog.Error("upsert stat", "err", err)
	}

	if n < test.NQuestions {
		http.Redirect(w, r, fmt.Sprintf("/test/%d/q/%d", test.ID, n+1), http.StatusSeeOther)
		return
	}
	if err := h.repo.FinishTest(test.ID); err != nil {
		slog.Error("finish test", "err", err)
	}
	http.Redirect(w, r, fmt.Sprintf("/test/%d/results", test.ID), http.StatusSeeOther)
}

// ResultItem holds per-question data for the results page.
type ResultItem struct {
	Question   *models.Question
	Answers    []*ResultAnswer
	UserRight  bool // user selected the correct answer
	Unanswered bool // user skipped without selecting
}

// ResultAnswer annotates each answer with display markers.
type ResultAnswer struct { *models.Answer UserPicked bool // user selected this answer } func (h *TestHandler) ResultsGet(w http.ResponseWriter, r *http.Request) { user := auth.UserFromCtx(r.Context()) testID, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64) if err != nil { HTTPError(w, http.StatusNotFound) return } test, err := h.repo.GetTest(testID) if err != nil || test.UserID != user.ID { HTTPError(w, http.StatusNotFound) return } testAnswers, err := h.repo.GetTestAnswers(testID) if err != nil { slog.Error("get test answers", "err", err) HTTPError(w, http.StatusInternalServerError) return } // Index test answers by question ID for quick lookup. taByQ := make(map[string]*models.TestAnswer, len(testAnswers)) for _, ta := range testAnswers { taByQ[ta.QuestionID] = ta } var items []ResultItem nCorrect := 0 for _, qid := range test.QuestionIDs { q, answers, err := h.repo.GetQuestion(qid) if err != nil { slog.Error("get question for results", "qid", qid, "err", err) HTTPError(w, http.StatusInternalServerError) return } ta := taByQ[qid] var selectedID int64 unanswered := ta == nil || !ta.SelectedAnswerID.Valid if !unanswered { selectedID = ta.SelectedAnswerID.Int64 } userRight := ta != nil && ta.IsCorrect.Valid && ta.IsCorrect.Bool if userRight { nCorrect++ } ra := make([]*ResultAnswer, len(answers)) for i, a := range answers { ra[i] = &ResultAnswer{ Answer: a, UserPicked: !unanswered && a.ID == selectedID, } } items = append(items, ResultItem{ Question: q, Answers: ra, UserRight: userRight, Unanswered: unanswered, }) } var timeTaken string if test.CompletedAt.Valid { d := test.CompletedAt.Time.Sub(test.CreatedAt).Round(time.Second) h := int(d.Hours()) m := int(d.Minutes()) % 60 s := int(d.Seconds()) % 60 if h > 0 { timeTaken = fmt.Sprintf("%dh %dm %ds", h, m, s) } else if m > 0 { timeTaken = fmt.Sprintf("%dm %ds", m, s) } else { timeTaken = fmt.Sprintf("%ds", s) } } mood := "sad" if test.NQuestions > 0 && nCorrect*100/test.NQuestions >= happyCatThreshold { mood 
= "happy" } catURL := randomCatURL(mood) data := BaseData(h.auth, r) data["Test"] = test data["Items"] = items data["NCorrect"] = nCorrect data["TimeTaken"] = timeTaken data["CatURL"] = catURL h.render.Render(w, http.StatusOK, "test_results", data) } // loadTestAndN extracts and validates the test ID and question number (n) from // URL params. Returns the 1-based question index and the test, or writes an // error and returns false. func (h *TestHandler) loadTestAndN(w http.ResponseWriter, r *http.Request) (int, *models.Test, bool) { user := auth.UserFromCtx(r.Context()) testID, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64) if err != nil { HTTPError(w, http.StatusNotFound) return 0, nil, false } n, err := strconv.Atoi(chi.URLParam(r, "n")) if err != nil { HTTPError(w, http.StatusNotFound) return 0, nil, false } test, err := h.repo.GetTest(testID) if err != nil || test.UserID != user.ID { HTTPError(w, http.StatusNotFound) return 0, nil, false } if n < 1 || n > test.NQuestions { HTTPError(w, http.StatusNotFound) return 0, nil, false } return n, test, true } // randomCatURL picks a random image from web/static/cats// and returns // its URL path, or an empty string if the folder is missing or empty. func randomCatURL(mood string) string { dir := "web/static/cats/" + mood entries, err := os.ReadDir(dir) if err != nil { return "" } var images []string for _, e := range entries { if e.IsDir() { continue } n := strings.ToLower(e.Name()) if strings.HasSuffix(n, ".jpg") || strings.HasSuffix(n, ".jpeg") || strings.HasSuffix(n, ".png") || strings.HasSuffix(n, ".gif") || strings.HasSuffix(n, ".webp") { images = append(images, e.Name()) } } if len(images) == 0 { return "" } return "/static/cats/" + mood + "/" + images[rand.Intn(len(images))] } // deterministicShuffle returns a copy of answers shuffled by a seed derived // from the test ID and question ID, so the order is stable across page reloads. 
func deterministicShuffle(answers []*models.Answer, testID int64, questionID string) []*models.Answer { b, _ := hex.DecodeString(questionID) qInt := int64(binary.BigEndian.Uint64(b)) rng := rand.New(rand.NewSource(testID ^ qInt)) out := make([]*models.Answer, len(answers)) copy(out, answers) rng.Shuffle(len(out), func(i, j int) { out[i], out[j] = out[j], out[i] }) return out }