From 968479ff51813d66a71488f7a27722942c51bb41 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C4=81nis=20Kac=C4=93ns?= Date: Mon, 11 May 2026 13:56:44 +0300 Subject: [PATCH] Phase 6: take a test (weighted sampling + question flow) - internal/sampling: ComputeWeight (Laplace-smoothed error rate + recency multiplier, floor 0.15) and SelectWeighted (A-Res reservoir algorithm). 10k-run statistical test verifies weak questions appear >3x more often than mastered, and mastered questions still appear (floor exercised). - GET/POST /test/new: source filter with live available-count JS update, n-questions input, weighted vs uniform mode radio. - GET /test/{id}/q/{n}: deterministic answer shuffle per (test_id, question_id), progress bar, mobile-friendly large tap targets. - POST /test/{id}/q/{n}: records answer + upserts stat; advances to next question or finishes test and redirects to results stub. - GET /test/{id}/results: stub (Phase 7 will add full review). - Ownership enforced: all test routes 404 for wrong user. 
Co-Authored-By: Claude Sonnet 4.6 --- cmd/server/main.go | 6 + internal/handlers/test.go | 276 +++++++++++++++++++++++++++++ internal/sampling/sampling_test.go | 114 ++++++++++++ internal/sampling/select.go | 46 +++++ internal/sampling/weight.go | 41 +++++ web/templates/test_new.html | 86 +++++++++ web/templates/test_question.html | 41 +++++ web/templates/test_results.html | 18 ++ 8 files changed, 628 insertions(+) create mode 100644 internal/handlers/test.go create mode 100644 internal/sampling/sampling_test.go create mode 100644 internal/sampling/select.go create mode 100644 internal/sampling/weight.go create mode 100644 web/templates/test_new.html create mode 100644 web/templates/test_question.html create mode 100644 web/templates/test_results.html diff --git a/cmd/server/main.go b/cmd/server/main.go index b2d0edb..895d9dc 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -59,6 +59,7 @@ func main() { homeH := handlers.NewHomeHandler(authMgr, repo, renderer) uploadH := handlers.NewUploadHandler(authMgr, repo, llmClient, renderer, cfg.DataDir) questionH := handlers.NewQuestionHandler(authMgr, repo, renderer) + testH := handlers.NewTestHandler(authMgr, repo, renderer) r := chi.NewRouter() r.Use(middleware.RequestID) @@ -89,6 +90,11 @@ func main() { r.Get("/questions/{id}", questionH.Show) r.Post("/questions/{id}", questionH.Edit) r.Post("/questions/{id}/delete", questionH.Delete) + r.Get("/test/new", testH.NewGet) + r.Post("/test/new", testH.NewPost) + r.Get("/test/{id}/q/{n}", testH.QuestionGet) + r.Post("/test/{id}/q/{n}", testH.QuestionPost) + r.Get("/test/{id}/results", testH.ResultsGet) }) srv := &http.Server{ diff --git a/internal/handlers/test.go b/internal/handlers/test.go new file mode 100644 index 0000000..5e223b1 --- /dev/null +++ b/internal/handlers/test.go @@ -0,0 +1,276 @@ +package handlers + +import ( + "encoding/binary" + "encoding/hex" + "fmt" + "log/slog" + "math/rand" + "net/http" + "strconv" + "time" + + "github.com/go-chi/chi/v5" + + 
"qbank/internal/auth" + "qbank/internal/db" + "qbank/internal/models" + "qbank/internal/sampling" +) + +type TestHandler struct { + auth *auth.Manager + repo *db.Repo + render *Renderer +} + +func NewTestHandler(a *auth.Manager, repo *db.Repo, r *Renderer) *TestHandler { + return &TestHandler{auth: a, repo: repo, render: r} +} + +func (h *TestHandler) NewGet(w http.ResponseWriter, r *http.Request) { + totalQ, _ := h.repo.CountQuestions() + sourceCounts, _ := h.repo.CountBySource() + + data := BaseData(h.auth, r) + data["TotalQ"] = totalQ + data["SourceStats"] = sourceCounts + h.render.Render(w, http.StatusOK, "test_new", data) +} + +func (h *TestHandler) NewPost(w http.ResponseWriter, r *http.Request) { + if !h.auth.CheckCSRF(r) { + HTTPError(w, http.StatusForbidden) + return + } + r.Body = http.MaxBytesReader(w, r.Body, 4096) + if err := r.ParseForm(); err != nil { + HTTPError(w, http.StatusBadRequest) + return + } + + user := auth.UserFromCtx(r.Context()) + source := r.FormValue("source") + mode := r.FormValue("mode") + + n, _ := strconv.Atoi(r.FormValue("n")) + if n <= 0 { + n = 10 + } + + questions, err := h.repo.ListQuestions(db.ListFilter{Source: source}) + if err != nil { + slog.Error("list questions for test", "err", err) + HTTPError(w, http.StatusInternalServerError) + return + } + + if len(questions) == 0 { + totalQ, _ := h.repo.CountQuestions() + sourceCounts, _ := h.repo.CountBySource() + data := BaseData(h.auth, r) + data["TotalQ"] = totalQ + data["SourceStats"] = sourceCounts + data["Error"] = "No questions available for the selected filter." 
+ h.render.Render(w, http.StatusOK, "test_new", data) + return + } + + if n > len(questions) { + n = len(questions) + } + + ids := make([]string, len(questions)) + for i, q := range questions { + ids[i] = q.ID + } + + var candidates []sampling.Candidate + if mode == "uniform" { + candidates = make([]sampling.Candidate, len(questions)) + for i, q := range questions { + candidates[i] = sampling.Candidate{ID: q.ID, Weight: 1.0} + } + } else { + stats, err := h.repo.GetStatsForUser(user.ID, ids) + if err != nil { + slog.Error("get stats for test", "err", err) + HTTPError(w, http.StatusInternalServerError) + return + } + now := time.Now() + candidates = make([]sampling.Candidate, len(questions)) + for i, q := range questions { + candidates[i] = sampling.Candidate{ + ID: q.ID, + Weight: sampling.ComputeWeight(stats[q.ID], now), + } + } + } + + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + picked := sampling.SelectWeighted(candidates, n, rng) + // Shuffle presentation order independently of selection order. 
+ rng.Shuffle(len(picked), func(i, j int) { picked[i], picked[j] = picked[j], picked[i] }) + + pickedIDs := make([]string, len(picked)) + for i, c := range picked { + pickedIDs[i] = c.ID + } + + testID, err := h.repo.CreateTest(user.ID, pickedIDs) + if err != nil { + slog.Error("create test", "err", err) + HTTPError(w, http.StatusInternalServerError) + return + } + + http.Redirect(w, r, fmt.Sprintf("/test/%d/q/1", testID), http.StatusSeeOther) +} + +func (h *TestHandler) QuestionGet(w http.ResponseWriter, r *http.Request) { + n, test, ok := h.loadTestAndN(w, r) + if !ok { + return + } + + questionID := test.QuestionIDs[n-1] + q, answers, err := h.repo.GetQuestion(questionID) + if err != nil { + slog.Error("get question for test", "err", err) + HTTPError(w, http.StatusInternalServerError) + return + } + + data := BaseData(h.auth, r) + data["TestID"] = test.ID + data["N"] = n + data["Total"] = test.NQuestions + data["Question"] = q + data["Answers"] = deterministicShuffle(answers, test.ID, questionID) + data["ProgressPct"] = (n - 1) * 100 / test.NQuestions + + h.render.Render(w, http.StatusOK, "test_question", data) +} + +func (h *TestHandler) QuestionPost(w http.ResponseWriter, r *http.Request) { + n, test, ok := h.loadTestAndN(w, r) + if !ok { + return + } + + if !h.auth.CheckCSRF(r) { + HTTPError(w, http.StatusForbidden) + return + } + r.Body = http.MaxBytesReader(w, r.Body, 4096) + if err := r.ParseForm(); err != nil { + HTTPError(w, http.StatusBadRequest) + return + } + + user := auth.UserFromCtx(r.Context()) + questionID := test.QuestionIDs[n-1] + + _, answers, err := h.repo.GetQuestion(questionID) + if err != nil { + slog.Error("get question for answer", "err", err) + HTTPError(w, http.StatusInternalServerError) + return + } + + var selectedID *int64 + var isCorrect bool + if rawID := r.FormValue("answer_id"); rawID != "" { + id, err := strconv.ParseInt(rawID, 10, 64) + if err == nil { + selectedID = &id + for _, a := range answers { + if a.ID == id && 
a.IsCorrect { + isCorrect = true + break + } + } + } + } + + if err := h.repo.RecordAnswer(test.ID, questionID, selectedID, isCorrect); err != nil { + slog.Error("record answer", "err", err) + HTTPError(w, http.StatusInternalServerError) + return + } + if err := h.repo.UpsertStat(user.ID, questionID, isCorrect); err != nil { + slog.Error("upsert stat", "err", err) + } + + if n < test.NQuestions { + http.Redirect(w, r, fmt.Sprintf("/test/%d/q/%d", test.ID, n+1), http.StatusSeeOther) + return + } + + if err := h.repo.FinishTest(test.ID); err != nil { + slog.Error("finish test", "err", err) + } + http.Redirect(w, r, fmt.Sprintf("/test/%d/results", test.ID), http.StatusSeeOther) +} + +func (h *TestHandler) ResultsGet(w http.ResponseWriter, r *http.Request) { + user := auth.UserFromCtx(r.Context()) + testID, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64) + if err != nil { + HTTPError(w, http.StatusNotFound) + return + } + test, err := h.repo.GetTest(testID) + if err != nil || test.UserID != user.ID { + HTTPError(w, http.StatusNotFound) + return + } + data := BaseData(h.auth, r) + data["Test"] = test + h.render.Render(w, http.StatusOK, "test_results", data) +} + +// loadTestAndN extracts and validates the test ID and question number (n) from +// URL params. Returns the 1-based question index and the test, or writes an +// error and returns false. 
+func (h *TestHandler) loadTestAndN(w http.ResponseWriter, r *http.Request) (int, *models.Test, bool) { + user := auth.UserFromCtx(r.Context()) + + testID, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64) + if err != nil { + HTTPError(w, http.StatusNotFound) + return 0, nil, false + } + + n, err := strconv.Atoi(chi.URLParam(r, "n")) + if err != nil { + HTTPError(w, http.StatusNotFound) + return 0, nil, false + } + + test, err := h.repo.GetTest(testID) + if err != nil || test.UserID != user.ID { + HTTPError(w, http.StatusNotFound) + return 0, nil, false + } + + if n < 1 || n > test.NQuestions { + HTTPError(w, http.StatusNotFound) + return 0, nil, false + } + + return n, test, true +} + +// deterministicShuffle returns a copy of answers shuffled by a seed derived +// from the test ID and question ID, so the order is stable across page reloads. +func deterministicShuffle(answers []*models.Answer, testID int64, questionID string) []*models.Answer { + b, _ := hex.DecodeString(questionID) + qInt := int64(binary.BigEndian.Uint64(b)) + rng := rand.New(rand.NewSource(testID ^ qInt)) + out := make([]*models.Answer, len(answers)) + copy(out, answers) + rng.Shuffle(len(out), func(i, j int) { out[i], out[j] = out[j], out[i] }) + return out +} diff --git a/internal/sampling/sampling_test.go b/internal/sampling/sampling_test.go new file mode 100644 index 0000000..789feab --- /dev/null +++ b/internal/sampling/sampling_test.go @@ -0,0 +1,114 @@ +package sampling_test + +import ( + "database/sql" + "fmt" + "math/rand" + "testing" + "time" + + "qbank/internal/models" + "qbank/internal/sampling" +) + +func TestSelectWeighted_Distribution(t *testing.T) { + // Fixed reference time so weights don't drift with wall clock. 
+ now := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + recentSeen := sql.NullTime{Time: now, Valid: true} + + // Build a pool of 100 candidates: + // 10 mastered (seen=10, correct=10) → base=max(0.15, 1/12)=0.15, recency=1.0, w=0.15 + // 10 weak (seen=10, correct=1) → base=max(0.15, 10/12)≈0.833, recency=1.0, w≈0.833 + // 10 unseen (no stat row) → w=1.0 (UnseenBaseWeight*RecencyMaxMult) + // 70 average (seen=5, correct=3) → base≈(2+1)/(5+2)≈0.429, recency=1.0, w≈0.429 + type entry struct { + id string + stat *models.UserQuestionStat + } + + var entries []entry + statWith := func(seen, correct int) *models.UserQuestionStat { + return &models.UserQuestionStat{TimesSeen: seen, TimesCorrect: correct, LastSeenAt: recentSeen} + } + + for i := 0; i < 10; i++ { + entries = append(entries, entry{fmt.Sprintf("mastered-%d", i), statWith(10, 10)}) + } + for i := 0; i < 10; i++ { + entries = append(entries, entry{fmt.Sprintf("weak-%d", i), statWith(10, 1)}) + } + for i := 0; i < 10; i++ { + entries = append(entries, entry{fmt.Sprintf("unseen-%d", i), nil}) + } + for i := 0; i < 70; i++ { + entries = append(entries, entry{fmt.Sprintf("avg-%d", i), statWith(5, 3)}) + } + + // Build candidate list. + candidates := make([]sampling.Candidate, len(entries)) + for i, e := range entries { + candidates[i] = sampling.Candidate{ + ID: e.id, + Weight: sampling.ComputeWeight(e.stat, now), + } + } + + // Sample n=1 from the pool 10,000 times with a seeded RNG. + rng := rand.New(rand.NewSource(42)) + counts := make(map[string]int, len(candidates)) + const runs = 10_000 + for range runs { + sel := sampling.SelectWeighted(candidates, 1, rng) + counts[sel[0].ID]++ + } + + // Compute group averages. 
+ masteredTotal, weakTotal := 0, 0 + for i := 0; i < 10; i++ { + masteredTotal += counts[fmt.Sprintf("mastered-%d", i)] + weakTotal += counts[fmt.Sprintf("weak-%d", i)] + } + masteredAvg := float64(masteredTotal) / 10.0 + weakAvg := float64(weakTotal) / 10.0 + + t.Logf("mastered avg %.1f, weak avg %.1f, ratio %.2f", masteredAvg, weakAvg, weakAvg/masteredAvg) + + // Weak questions must appear >3× more often than mastered ones. + if weakAvg < masteredAvg*3 { + t.Errorf("want weakAvg > masteredAvg*3, got weakAvg=%.1f masteredAvg=%.1f", weakAvg, masteredAvg) + } + + // Mastered questions must still appear (floor weight working). + if masteredTotal < 50 { + t.Errorf("want masteredTotal >= 50 (floor weight), got %d", masteredTotal) + } +} + +func TestComputeWeight_Unseen(t *testing.T) { + w := sampling.ComputeWeight(nil, time.Now()) + if w != sampling.UnseenBaseWeight*sampling.RecencyMaxMult { + t.Errorf("unseen weight: got %v, want %v", w, sampling.UnseenBaseWeight*sampling.RecencyMaxMult) + } +} + +func TestComputeWeight_FloorEnforced(t *testing.T) { + now := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + stat := &models.UserQuestionStat{ + TimesSeen: 100, + TimesCorrect: 100, + LastSeenAt: sql.NullTime{Time: now, Valid: true}, + } + w := sampling.ComputeWeight(stat, now) + if w < sampling.FloorWeight { + t.Errorf("weight %v below FloorWeight %v", w, sampling.FloorWeight) + } +} + +func TestSelectWeighted_AllReturned_WhenNGeLen(t *testing.T) { + rng := rand.New(rand.NewSource(1)) + cands := []sampling.Candidate{{ID: "a", Weight: 1}, {ID: "b", Weight: 2}} + got := sampling.SelectWeighted(cands, 10, rng) + if len(got) != 2 { + t.Errorf("want 2, got %d", len(got)) + } +} diff --git a/internal/sampling/select.go b/internal/sampling/select.go new file mode 100644 index 0000000..07b38eb --- /dev/null +++ b/internal/sampling/select.go @@ -0,0 +1,46 @@ +package sampling + +import ( + "math" + "math/rand" + "sort" +) + +// Candidate is a question ID paired with its sampling 
// Candidate is a question ID paired with its sampling weight.
type Candidate struct {
	ID     string
	Weight float64
}

// SelectWeighted picks n distinct candidates using the A-Res weighted
// reservoir algorithm (Efraimidis–Spirakis): each candidate draws
// key = u^(1/weight) with u uniform in (0,1), and the n largest keys win, so
// heavier candidates are favoured. O(m log m) time for m candidates.
//
// Edge cases:
//   - n <= 0 returns an empty slice (a negative n no longer panics).
//   - n >= len(candidates) returns a copy of all candidates, in input order,
//     without consuming any randomness.
//   - Candidates whose weight is zero, negative or NaN rank below every
//     positively weighted candidate; they are only returned when there are
//     not enough positively weighted ones to fill n.
//
// Exactly one value is drawn from rng per candidate, regardless of weight.
func SelectWeighted(candidates []Candidate, n int, rng *rand.Rand) []Candidate {
	if n <= 0 {
		return []Candidate{}
	}
	if n >= len(candidates) {
		out := make([]Candidate, len(candidates))
		copy(out, candidates)
		return out
	}

	type keyed struct {
		c   Candidate
		key float64
	}

	keys := make([]keyed, len(candidates))
	for i, c := range candidates {
		u := rng.Float64()
		if u == 0 {
			u = 1e-12 // avoid log(0) / pow weirdness
		}
		// Valid keys lie in [0, 1]; -1 sorts after all of them. The
		// comparison is false for NaN, so NaN weights take the sentinel too.
		key := -1.0
		if c.Weight > 0 {
			key = math.Pow(u, 1.0/c.Weight)
		}
		keys[i] = keyed{c: c, key: key}
	}

	sort.Slice(keys, func(i, j int) bool { return keys[i].key > keys[j].key })

	out := make([]Candidate, n)
	for i := range out {
		out[i] = keys[i].c
	}
	return out
}
+ errorRate := (s - c + 1) / (s + 2) + base := math.Max(FloorWeight, errorRate) + + var daysSince float64 + if stat.LastSeenAt.Valid { + daysSince = now.Sub(stat.LastSeenAt.Time).Hours() / 24 + } else { + daysSince = RecencyCapDays + } + recency := 1 + math.Min(daysSince/RecencyCapDays, 1.0) + + return base * recency +} diff --git a/web/templates/test_new.html b/web/templates/test_new.html new file mode 100644 index 0000000..67cfa32 --- /dev/null +++ b/web/templates/test_new.html @@ -0,0 +1,86 @@ +{{define "content"}} +
+

New Test

+ {{if .Error}} +
+ {{.Error}} +
+ {{end}} +
+ +{{if eq .TotalQ 0}} +
+

No questions in the library yet.

+ Upload a document first +
+{{else}} +
+ + +
+ +
+ + + {{.TotalQ}} available + +
+
+ + {{if .SourceStats}} +
+ + +
+ {{end}} + +
+

Sampling mode

+
+ + +
+
+ + +
+{{end}} +{{end}} diff --git a/web/templates/test_question.html b/web/templates/test_question.html new file mode 100644 index 0000000..8155b9d --- /dev/null +++ b/web/templates/test_question.html @@ -0,0 +1,41 @@ +{{define "content"}} +
+
+ Question {{.N}} of {{.Total}} + {{.ProgressPct}}% done +
+
+
+
+
+ +
+

{{.Question.Text}}

+ {{if .Question.Source}} +

{{.Question.Source}}

+ {{end}} +
+ +
+ + +
+ {{range .Answers}} + + {{end}} +
+ + +
+{{end}} diff --git a/web/templates/test_results.html b/web/templates/test_results.html new file mode 100644 index 0000000..0fd492b --- /dev/null +++ b/web/templates/test_results.html @@ -0,0 +1,18 @@ +{{define "content"}} +
+

Test Complete!

+

Detailed results coming in the next phase.

+ +
+{{end}}