2477130dd9
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
568 lines
15 KiB
Go
568 lines
15 KiB
Go
package db
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"database/sql"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"qbank/internal/models"
|
|
)
|
|
|
|
// timeLayout is the text form of SQLite's CURRENT_TIMESTAMP
// ("YYYY-MM-DD HH:MM:SS", UTC), which this file writes for completed_at,
// answered_at and last_seen_at.
const timeLayout = "2006-01-02 15:04:05"

// timeLayouts lists the timestamp encodings parseSQLiteTime accepts, tried in
// order. timeLayout comes first so the common CURRENT_TIMESTAMP case costs a
// single parse. The others cover SQLite's "T"-separated time strings, the
// "YYYY-MM-DD HH:MM:SS±hh:mm" form, and RFC 3339 — the text database/sql
// produces when a driver returns a time.Time that is scanned into a string.
// Fractional seconds after the seconds field are accepted by every layout
// (time.Parse allows them even when the layout omits them).
var timeLayouts = []string{
	timeLayout,
	"2006-01-02T15:04:05",
	"2006-01-02 15:04:05Z07:00",
	time.RFC3339Nano,
}

// parseSQLiteTime parses s with the first entry of timeLayouts that accepts
// it. Timestamps without a zone are interpreted as UTC. If no layout matches,
// the error from the first (primary) layout is returned.
func parseSQLiteTime(s string) (time.Time, error) {
	var firstErr error
	for _, layout := range timeLayouts {
		t, err := time.Parse(layout, s)
		if err == nil {
			return t, nil
		}
		if firstErr == nil {
			firstErr = err
		}
	}
	return time.Time{}, firstErr
}

// parseTime converts a non-NULL timestamp column to time.Time. Unparseable
// input yields the zero time (check with IsZero) rather than an error.
func parseTime(s string) time.Time {
	t, _ := parseSQLiteTime(s) // best-effort: zero time on failure
	return t
}

// parseNullTime converts a nullable timestamp column. NULL and unparseable
// values both map to an invalid (Valid == false) sql.NullTime.
func parseNullTime(s sql.NullString) sql.NullTime {
	if !s.Valid {
		return sql.NullTime{}
	}
	t, err := parseSQLiteTime(s.String)
	if err != nil {
		return sql.NullTime{}
	}
	return sql.NullTime{Time: t, Valid: true}
}
|
|
|
|
// QuestionID derives a question's canonical ID from its text: the first 8
// bytes of the text's SHA-256 digest, rendered as 16 lowercase hex digits.
// Questions with identical text therefore always share an ID.
func QuestionID(text string) string {
	sum := sha256.Sum256([]byte(text))
	prefix := sum[:8]
	return fmt.Sprintf("%x", prefix)
}
|
|
|
|
// SortOrder selects the ordering ListQuestions applies to its results.
type SortOrder int

const (
	// SortAlpha orders questions alphabetically by text. It is the zero
	// value, so a ListFilter with Sort unset sorts alphabetically.
	SortAlpha SortOrder = 0
	// SortWeakest puts the questions with the lowest accuracy first
	// (requires ListFilter.UserID).
	SortWeakest SortOrder = 1
	// SortMostSeen puts the most-seen questions first (requires
	// ListFilter.UserID).
	SortMostSeen SortOrder = 2
)

// ListFilter narrows and orders the result of ListQuestions. The zero value
// applies no filtering and sorts alphabetically.
type ListFilter struct {
	Source string    // exact source to match; "" matches any source
	Search string    // text to look for within question text; "" disables search
	Sort   SortOrder // ordering of the results
	UserID int64     // user whose stats drive SortWeakest/SortMostSeen; 0 means none
}
|
|
|
|
// Repo is the data-access layer of this package; every query goes through
// its *sql.DB handle.
type Repo struct {
	db *sql.DB
}

// New wraps an already-open database handle in a Repo. It performs no
// queries; the handle is stored as given.
func New(db *sql.DB) *Repo {
	repo := new(Repo)
	repo.db = db
	return repo
}
|
|
|
|
func (r *Repo) CreateUser(name, passwordHash string) (int64, error) {
|
|
res, err := r.db.Exec("INSERT INTO users (name, password_hash) VALUES (?, ?)", name, passwordHash)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return res.LastInsertId()
|
|
}
|
|
|
|
func (r *Repo) GetUserByName(name string) (*models.User, error) {
|
|
u := &models.User{}
|
|
var createdAt string
|
|
err := r.db.QueryRow(
|
|
"SELECT id, name, password_hash, created_at FROM users WHERE name = ?", name,
|
|
).Scan(&u.ID, &u.Name, &u.PasswordHash, &createdAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
u.CreatedAt = parseTime(createdAt)
|
|
return u, nil
|
|
}
|
|
|
|
// InsertQuestion inserts q and its answers in a transaction. Duplicate questions
|
|
// (same text hash) are silently ignored; their answers are not re-inserted.
|
|
func (r *Repo) InsertQuestion(q *models.Question, answers []*models.Answer) error {
|
|
q.ID = QuestionID(q.Text)
|
|
tx, err := r.db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
res, err := tx.Exec(
|
|
"INSERT OR IGNORE INTO questions (id, text, source) VALUES (?, ?, ?)",
|
|
q.ID, q.Text, q.Source,
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if n, _ := res.RowsAffected(); n == 0 {
|
|
return tx.Commit() // already exists
|
|
}
|
|
|
|
for i, a := range answers {
|
|
a.QuestionID = q.ID
|
|
a.Position = i
|
|
res, err := tx.Exec(
|
|
"INSERT INTO answers (question_id, text, is_correct, position) VALUES (?, ?, ?, ?)",
|
|
a.QuestionID, a.Text, a.IsCorrect, a.Position,
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
a.ID, _ = res.LastInsertId()
|
|
}
|
|
return tx.Commit()
|
|
}
|
|
|
|
func (r *Repo) GetQuestion(id string) (*models.Question, []*models.Answer, error) {
|
|
q := &models.Question{}
|
|
var createdAt string
|
|
err := r.db.QueryRow(
|
|
"SELECT id, text, source, created_at FROM questions WHERE id = ?", id,
|
|
).Scan(&q.ID, &q.Text, &q.Source, &createdAt)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
q.CreatedAt = parseTime(createdAt)
|
|
|
|
rows, err := r.db.Query(
|
|
"SELECT id, question_id, text, is_correct, position FROM answers WHERE question_id = ? ORDER BY position",
|
|
id,
|
|
)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var answers []*models.Answer
|
|
for rows.Next() {
|
|
a := &models.Answer{}
|
|
if err := rows.Scan(&a.ID, &a.QuestionID, &a.Text, &a.IsCorrect, &a.Position); err != nil {
|
|
return nil, nil, err
|
|
}
|
|
answers = append(answers, a)
|
|
}
|
|
return q, answers, rows.Err()
|
|
}
|
|
|
|
func (r *Repo) ListQuestions(f ListFilter) ([]*models.Question, error) {
|
|
var args []any
|
|
|
|
join := ""
|
|
if f.UserID != 0 && (f.Sort == SortWeakest || f.Sort == SortMostSeen) {
|
|
join = " LEFT JOIN user_question_stats s ON s.question_id = q.id AND s.user_id = ?"
|
|
args = append(args, f.UserID)
|
|
}
|
|
|
|
var where []string
|
|
if f.Source != "" {
|
|
where = append(where, "q.source = ?")
|
|
args = append(args, f.Source)
|
|
}
|
|
if f.Search != "" {
|
|
where = append(where, "q.text LIKE ?")
|
|
args = append(args, "%"+f.Search+"%")
|
|
}
|
|
|
|
query := "SELECT q.id, q.text, q.source, q.created_at FROM questions q" + join
|
|
if len(where) > 0 {
|
|
query += " WHERE " + strings.Join(where, " AND ")
|
|
}
|
|
|
|
switch f.Sort {
|
|
case SortWeakest:
|
|
query += " ORDER BY CASE WHEN s.times_seen IS NULL OR s.times_seen = 0 THEN 0.0 ELSE CAST(s.times_correct AS REAL) / s.times_seen END ASC, q.text COLLATE NOCASE ASC"
|
|
case SortMostSeen:
|
|
query += " ORDER BY COALESCE(s.times_seen, 0) DESC, q.text COLLATE NOCASE ASC"
|
|
default:
|
|
query += " ORDER BY q.text COLLATE NOCASE ASC"
|
|
}
|
|
|
|
rows, err := r.db.Query(query, args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var qs []*models.Question
|
|
for rows.Next() {
|
|
q := &models.Question{}
|
|
var createdAt string
|
|
if err := rows.Scan(&q.ID, &q.Text, &q.Source, &createdAt); err != nil {
|
|
return nil, err
|
|
}
|
|
q.CreatedAt = parseTime(createdAt)
|
|
qs = append(qs, q)
|
|
}
|
|
return qs, rows.Err()
|
|
}
|
|
|
|
func (r *Repo) ListSources() ([]string, error) {
|
|
rows, err := r.db.Query(
|
|
"SELECT DISTINCT source FROM questions WHERE source != '' ORDER BY source",
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var sources []string
|
|
for rows.Next() {
|
|
var s string
|
|
if err := rows.Scan(&s); err != nil {
|
|
return nil, err
|
|
}
|
|
sources = append(sources, s)
|
|
}
|
|
return sources, rows.Err()
|
|
}
|
|
|
|
// SourceStat holds a source name with its question count.
|
|
type SourceStat struct {
|
|
Source string
|
|
Count int
|
|
}
|
|
|
|
func (r *Repo) CountBySource() ([]SourceStat, error) {
|
|
rows, err := r.db.Query(
|
|
"SELECT source, COUNT(*) FROM questions WHERE source != '' GROUP BY source ORDER BY source",
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var stats []SourceStat
|
|
for rows.Next() {
|
|
var s SourceStat
|
|
if err := rows.Scan(&s.Source, &s.Count); err != nil {
|
|
return nil, err
|
|
}
|
|
stats = append(stats, s)
|
|
}
|
|
return stats, rows.Err()
|
|
}
|
|
|
|
func (r *Repo) UpdateQuestion(id, text, source string) error {
|
|
_, err := r.db.Exec("UPDATE questions SET text = ?, source = ? WHERE id = ?", text, source, id)
|
|
return err
|
|
}
|
|
|
|
// AnswerUpdate carries the fields to write for a single answer row.
|
|
type AnswerUpdate struct {
|
|
ID int64
|
|
Text string
|
|
IsCorrect bool
|
|
}
|
|
|
|
func (r *Repo) UpdateAnswers(updates []AnswerUpdate) error {
|
|
for _, u := range updates {
|
|
if _, err := r.db.Exec(
|
|
"UPDATE answers SET text = ?, is_correct = ? WHERE id = ?",
|
|
u.Text, u.IsCorrect, u.ID,
|
|
); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *Repo) DeleteQuestion(id string) error {
|
|
_, err := r.db.Exec("DELETE FROM questions WHERE id = ?", id)
|
|
return err
|
|
}
|
|
|
|
func (r *Repo) CountQuestions() (int, error) {
|
|
var n int
|
|
return n, r.db.QueryRow("SELECT COUNT(*) FROM questions").Scan(&n)
|
|
}
|
|
|
|
func (r *Repo) CountAnswers() (int, error) {
|
|
var n int
|
|
return n, r.db.QueryRow("SELECT COUNT(*) FROM answers").Scan(&n)
|
|
}
|
|
|
|
func (r *Repo) CreateTest(userID int64, questionIDs []string) (int64, error) {
|
|
ids, err := json.Marshal(questionIDs)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
res, err := r.db.Exec(
|
|
"INSERT INTO tests (user_id, n_questions, question_ids) VALUES (?, ?, ?)",
|
|
userID, len(questionIDs), string(ids),
|
|
)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return res.LastInsertId()
|
|
}
|
|
|
|
func (r *Repo) RecordAnswer(testID int64, questionID string, selectedAnswerID *int64, isCorrect bool) error {
|
|
_, err := r.db.Exec(`
|
|
INSERT INTO test_answers (test_id, question_id, selected_answer_id, is_correct, answered_at)
|
|
VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP)
|
|
ON CONFLICT (test_id, question_id) DO UPDATE SET
|
|
selected_answer_id = excluded.selected_answer_id,
|
|
is_correct = excluded.is_correct,
|
|
answered_at = excluded.answered_at`,
|
|
testID, questionID, selectedAnswerID, isCorrect,
|
|
)
|
|
return err
|
|
}
|
|
|
|
func (r *Repo) FinishTest(id int64) error {
|
|
_, err := r.db.Exec("UPDATE tests SET completed_at = CURRENT_TIMESTAMP WHERE id = ?", id)
|
|
return err
|
|
}
|
|
|
|
func (r *Repo) GetTest(id int64) (*models.Test, error) {
|
|
t := &models.Test{}
|
|
var createdAt string
|
|
var completedAt sql.NullString
|
|
var ids string
|
|
|
|
err := r.db.QueryRow(
|
|
"SELECT id, user_id, created_at, completed_at, n_questions, question_ids FROM tests WHERE id = ?", id,
|
|
).Scan(&t.ID, &t.UserID, &createdAt, &completedAt, &t.NQuestions, &ids)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
t.CreatedAt = parseTime(createdAt)
|
|
t.CompletedAt = parseNullTime(completedAt)
|
|
if err := json.Unmarshal([]byte(ids), &t.QuestionIDs); err != nil {
|
|
return nil, err
|
|
}
|
|
return t, nil
|
|
}
|
|
|
|
func (r *Repo) ListTestsForUser(userID int64) ([]*models.Test, error) {
|
|
rows, err := r.db.Query(`
|
|
SELECT id, user_id, created_at, completed_at, n_questions, question_ids
|
|
FROM tests WHERE user_id = ? ORDER BY created_at DESC`, userID,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var tests []*models.Test
|
|
for rows.Next() {
|
|
t := &models.Test{}
|
|
var createdAt string
|
|
var completedAt sql.NullString
|
|
var ids string
|
|
if err := rows.Scan(&t.ID, &t.UserID, &createdAt, &completedAt, &t.NQuestions, &ids); err != nil {
|
|
return nil, err
|
|
}
|
|
t.CreatedAt = parseTime(createdAt)
|
|
t.CompletedAt = parseNullTime(completedAt)
|
|
if err := json.Unmarshal([]byte(ids), &t.QuestionIDs); err != nil {
|
|
return nil, err
|
|
}
|
|
tests = append(tests, t)
|
|
}
|
|
return tests, rows.Err()
|
|
}
|
|
|
|
func (r *Repo) GetTestAnswers(testID int64) ([]*models.TestAnswer, error) {
|
|
rows, err := r.db.Query(`
|
|
SELECT test_id, question_id, selected_answer_id, is_correct, answered_at
|
|
FROM test_answers WHERE test_id = ?`, testID,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var answers []*models.TestAnswer
|
|
for rows.Next() {
|
|
ta := &models.TestAnswer{}
|
|
var answeredAt sql.NullString
|
|
if err := rows.Scan(&ta.TestID, &ta.QuestionID, &ta.SelectedAnswerID, &ta.IsCorrect, &answeredAt); err != nil {
|
|
return nil, err
|
|
}
|
|
ta.AnsweredAt = parseNullTime(answeredAt)
|
|
answers = append(answers, ta)
|
|
}
|
|
return answers, rows.Err()
|
|
}
|
|
|
|
func (r *Repo) UpsertStat(userID int64, questionID string, gotItRight bool) error {
|
|
correct := 0
|
|
if gotItRight {
|
|
correct = 1
|
|
}
|
|
_, err := r.db.Exec(`
|
|
INSERT INTO user_question_stats (user_id, question_id, times_seen, times_correct, last_seen_at)
|
|
VALUES (?, ?, 1, ?, CURRENT_TIMESTAMP)
|
|
ON CONFLICT (user_id, question_id) DO UPDATE SET
|
|
times_seen = times_seen + 1,
|
|
times_correct = times_correct + excluded.times_correct,
|
|
last_seen_at = CURRENT_TIMESTAMP`,
|
|
userID, questionID, correct,
|
|
)
|
|
return err
|
|
}
|
|
|
|
func (r *Repo) GetStatsForUser(userID int64, questionIDs []string) (map[string]*models.UserQuestionStat, error) {
|
|
result := make(map[string]*models.UserQuestionStat, len(questionIDs))
|
|
if len(questionIDs) == 0 {
|
|
return result, nil
|
|
}
|
|
|
|
placeholders := make([]string, len(questionIDs))
|
|
args := make([]any, 0, len(questionIDs)+1)
|
|
args = append(args, userID)
|
|
for i, id := range questionIDs {
|
|
placeholders[i] = "?"
|
|
args = append(args, id)
|
|
}
|
|
|
|
rows, err := r.db.Query(fmt.Sprintf(`
|
|
SELECT user_id, question_id, times_seen, times_correct, last_seen_at
|
|
FROM user_question_stats
|
|
WHERE user_id = ? AND question_id IN (%s)`,
|
|
strings.Join(placeholders, ",")),
|
|
args...,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
s := &models.UserQuestionStat{}
|
|
var lastSeen sql.NullString
|
|
if err := rows.Scan(&s.UserID, &s.QuestionID, &s.TimesSeen, &s.TimesCorrect, &lastSeen); err != nil {
|
|
return nil, err
|
|
}
|
|
s.LastSeenAt = parseNullTime(lastSeen)
|
|
result[s.QuestionID] = s
|
|
}
|
|
return result, rows.Err()
|
|
}
|
|
|
|
// ── History ──────────────────────────────────────────────────────────────────
|
|
|
|
// GetCorrectCountsForUser returns a map of test_id → correct-answer count for
|
|
// all completed tests belonging to userID.
|
|
func (r *Repo) GetCorrectCountsForUser(userID int64) (map[int64]int, error) {
|
|
rows, err := r.db.Query(`
|
|
SELECT ta.test_id, SUM(CASE WHEN ta.is_correct = 1 THEN 1 ELSE 0 END)
|
|
FROM test_answers ta
|
|
JOIN tests t ON ta.test_id = t.id
|
|
WHERE t.user_id = ? AND t.completed_at IS NOT NULL
|
|
GROUP BY ta.test_id`, userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
result := make(map[int64]int)
|
|
for rows.Next() {
|
|
var testID int64
|
|
var correct int
|
|
if err := rows.Scan(&testID, &correct); err != nil {
|
|
return nil, err
|
|
}
|
|
result[testID] = correct
|
|
}
|
|
return result, rows.Err()
|
|
}
|
|
|
|
// GetAggregateStats returns total correct and total answered across all
|
|
// completed tests for userID.
|
|
func (r *Repo) GetAggregateStats(userID int64) (totalCorrect, totalAnswered int, err error) {
|
|
err = r.db.QueryRow(`
|
|
SELECT
|
|
COALESCE(SUM(CASE WHEN ta.is_correct = 1 THEN 1 ELSE 0 END), 0),
|
|
COALESCE(COUNT(ta.question_id), 0)
|
|
FROM test_answers ta
|
|
JOIN tests t ON ta.test_id = t.id
|
|
WHERE t.user_id = ? AND t.completed_at IS NOT NULL`, userID,
|
|
).Scan(&totalCorrect, &totalAnswered)
|
|
return
|
|
}
|
|
|
|
// WeakSpot is a question the user has answered incorrectly more than once.
|
|
type WeakSpot struct {
|
|
QuestionID string
|
|
QuestionText string
|
|
TimesWrong int
|
|
TimesSeen int
|
|
}
|
|
|
|
// GetWeakSpots returns up to 10 questions the user has gotten wrong more than
|
|
// once, ordered by wrong-answer count descending.
|
|
func (r *Repo) GetWeakSpots(userID int64) ([]*WeakSpot, error) {
|
|
rows, err := r.db.Query(`
|
|
SELECT uqs.question_id, q.text, uqs.times_seen,
|
|
(uqs.times_seen - uqs.times_correct) AS times_wrong
|
|
FROM user_question_stats uqs
|
|
JOIN questions q ON uqs.question_id = q.id
|
|
WHERE uqs.user_id = ? AND (uqs.times_seen - uqs.times_correct) > 1
|
|
ORDER BY times_wrong DESC, uqs.times_seen DESC
|
|
LIMIT 10`, userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
var spots []*WeakSpot
|
|
for rows.Next() {
|
|
s := &WeakSpot{}
|
|
if err := rows.Scan(&s.QuestionID, &s.QuestionText, &s.TimesSeen, &s.TimesWrong); err != nil {
|
|
return nil, err
|
|
}
|
|
spots = append(spots, s)
|
|
}
|
|
return spots, rows.Err()
|
|
}
|
|
|
|
// ── Draft (import review) ────────────────────────────────────────────────────
|
|
|
|
// newDraftID returns a fresh random import-draft ID: 16 bytes from
// crypto/rand, hex-encoded as 32 lowercase characters.
//
// A failed random read is treated as fatal. Ignoring the error would leave
// the buffer zero-filled and give every draft the same predictable ID. As of
// Go 1.24, crypto/rand.Read is documented never to return an error. On older
// toolchains an error means the OS entropy source is unusable.
func newDraftID() string {
	b := make([]byte, 16)
	if _, err := rand.Read(b); err != nil {
		panic(fmt.Errorf("db: generating draft id: %w", err))
	}
	return hex.EncodeToString(b)
}
|
|
|
|
func (r *Repo) CreateDraft(userID int64, source string, questions []models.DraftQuestion) (string, error) {
|
|
data, err := json.Marshal(questions)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
id := newDraftID()
|
|
_, err = r.db.Exec(
|
|
"INSERT INTO import_drafts (id, user_id, source, questions) VALUES (?, ?, ?, ?)",
|
|
id, userID, source, string(data),
|
|
)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return id, nil
|
|
}
|
|
|
|
func (r *Repo) GetDraftForUser(id string, userID int64) (*models.Draft, error) {
|
|
d := &models.Draft{}
|
|
var questionsJSON, createdAt string
|
|
err := r.db.QueryRow(
|
|
"SELECT id, user_id, source, questions, created_at FROM import_drafts WHERE id = ? AND user_id = ?",
|
|
id, userID,
|
|
).Scan(&d.ID, &d.UserID, &d.Source, &questionsJSON, &createdAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
d.CreatedAt = parseTime(createdAt)
|
|
if err := json.Unmarshal([]byte(questionsJSON), &d.Questions); err != nil {
|
|
return nil, err
|
|
}
|
|
return d, nil
|
|
}
|
|
|
|
func (r *Repo) DeleteDraft(id string) error {
|
|
_, err := r.db.Exec("DELETE FROM import_drafts WHERE id = ?", id)
|
|
return err
|
|
}
|