package sampling_test

import (
	"database/sql"
	"fmt"
	"math/rand"
	"testing"
	"time"

	"qbank/internal/models"
	"qbank/internal/sampling"
)

// TestSelectWeighted_Distribution builds a mixed pool of 100 candidates,
// draws a single candidate from it 10,000 times with a seeded RNG, and checks
// that weak questions are picked more than three times as often as mastered
// ones while mastered questions are still picked occasionally.
func TestSelectWeighted_Distribution(t *testing.T) {
	// Fixed reference time so weights don't drift with wall clock.
	now := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
	recentSeen := sql.NullTime{Time: now, Valid: true}

	// Build a pool of 100 candidates:
	//   10 mastered (seen=10, correct=10) → base=max(0.15, 1/12)=0.15,    recency=1.0, w=0.15
	//   10 weak     (seen=10, correct=1)  → base=max(0.15, 10/12)≈0.833, recency=1.0, w≈0.833
	//   10 unseen   (no stat row)         → w=1.0 (UnseenBaseWeight*RecencyMaxMult)
	//   70 average  (seen=5,  correct=3)  → base≈(2+1)/(5+2)≈0.429,       recency=1.0, w≈0.429
	type entry struct {
		id   string
		stat *models.UserQuestionStat
	}

	// statWith returns a stat row last seen at the reference time.
	statWith := func(seen, correct int) *models.UserQuestionStat {
		return &models.UserQuestionStat{
			TimesSeen:    seen,
			TimesCorrect: correct,
			LastSeenAt:   recentSeen,
		}
	}

	const (
		groupSize = 10 // size of the mastered, weak and unseen groups
		avgSize   = 70
	)

	entries := make([]entry, 0, 3*groupSize+avgSize)
	for i := 0; i < groupSize; i++ {
		entries = append(entries, entry{fmt.Sprintf("mastered-%d", i), statWith(10, 10)})
	}
	for i := 0; i < groupSize; i++ {
		entries = append(entries, entry{fmt.Sprintf("weak-%d", i), statWith(10, 1)})
	}
	for i := 0; i < groupSize; i++ {
		// A nil stat stands for a question the user has never seen.
		entries = append(entries, entry{fmt.Sprintf("unseen-%d", i), nil})
	}
	for i := 0; i < avgSize; i++ {
		entries = append(entries, entry{fmt.Sprintf("avg-%d", i), statWith(5, 3)})
	}

	// Build candidate list.
	candidates := make([]sampling.Candidate, len(entries))
	for i, e := range entries {
		candidates[i] = sampling.Candidate{
			ID:     e.id,
			Weight: sampling.ComputeWeight(e.stat, now),
		}
	}

	// Sample n=1 from the pool 10,000 times with a seeded RNG.
	rng := rand.New(rand.NewSource(42))
	counts := make(map[string]int, len(candidates))
	const runs = 10_000
	for run := range runs {
		sel := sampling.SelectWeighted(candidates, 1, rng)
		// Guard before indexing: an empty result must fail this test cleanly
		// rather than panic and abort the whole test binary.
		if len(sel) != 1 {
			t.Fatalf("run %d: want exactly 1 selected candidate, got %d", run, len(sel))
		}
		counts[sel[0].ID]++
	}

	// Compute group averages.
	masteredTotal, weakTotal := 0, 0
	for i := 0; i < groupSize; i++ {
		masteredTotal += counts[fmt.Sprintf("mastered-%d", i)]
		weakTotal += counts[fmt.Sprintf("weak-%d", i)]
	}
	masteredAvg := float64(masteredTotal) / groupSize
	weakAvg := float64(weakTotal) / groupSize

	t.Logf("mastered avg %.1f, weak avg %.1f, ratio %.2f", masteredAvg, weakAvg, weakAvg/masteredAvg)

	// Weak questions must appear >3× more often than mastered ones. The
	// comparison is strict so that equality (including 0 vs 0) is a failure,
	// matching the message below.
	if weakAvg <= masteredAvg*3 {
		t.Errorf("want weakAvg > masteredAvg*3, got weakAvg=%.1f masteredAvg=%.1f", weakAvg, masteredAvg)
	}

	// Mastered questions must still appear (floor weight working).
	if masteredTotal < 50 {
		t.Errorf("want masteredTotal >= 50 (floor weight), got %d", masteredTotal)
	}
}

// TestComputeWeight_Unseen checks that a question with no stat row (nil) gets
// the weight UnseenBaseWeight*RecencyMaxMult.
func TestComputeWeight_Unseen(t *testing.T) {
	// Fixed reference time, as in the other tests, so the result cannot
	// depend on the wall clock.
	now := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
	want := sampling.UnseenBaseWeight * sampling.RecencyMaxMult

	w := sampling.ComputeWeight(nil, now)
	if w != want {
		t.Errorf("unseen weight: got %v, want %v", w, want)
	}
}

// TestComputeWeight_FloorEnforced checks that a question answered correctly
// every time (100 of 100) and seen at the reference time still has a weight
// of at least FloorWeight.
func TestComputeWeight_FloorEnforced(t *testing.T) {
	now := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
	stat := &models.UserQuestionStat{
		TimesSeen:    100,
		TimesCorrect: 100,
		LastSeenAt:   sql.NullTime{Time: now, Valid: true},
	}

	w := sampling.ComputeWeight(stat, now)
	if w < sampling.FloorWeight {
		t.Errorf("weight %v below FloorWeight %v", w, sampling.FloorWeight)
	}
}

// TestSelectWeighted_AllReturned_WhenNGeLen checks that asking for more
// candidates than the pool holds returns as many results as the pool has.
func TestSelectWeighted_AllReturned_WhenNGeLen(t *testing.T) {
	rng := rand.New(rand.NewSource(1))
	cands := []sampling.Candidate{{ID: "a", Weight: 1}, {ID: "b", Weight: 2}}

	got := sampling.SelectWeighted(cands, 10, rng)
	if len(got) != 2 {
		t.Errorf("want 2, got %d", len(got))
	}
}