Files
2026-05-11 13:03:04 +03:00

130 lines
3.2 KiB
Go

package llm_test
import (
"context"
"encoding/json"
"testing"
openai "github.com/sashabaranov/go-openai"
"qbank/internal/llm"
)
// mockChat is a test double for llm.ChatClient that replies with a fixed,
// pre-rendered message body.
type mockChat struct {
	// body is returned verbatim as the content of the single reply message.
	body string
}
// CreateChatCompletion disregards both the context and the request. It always
// returns a response holding exactly one choice whose message content is the
// canned body, together with a nil error.
func (c *mockChat) CreateChatCompletion(_ context.Context, _ openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
	reply := openai.ChatCompletionMessage{Content: c.body}

	var resp openai.ChatCompletionResponse
	resp.Choices = []openai.ChatCompletionChoice{{Message: reply}}
	return resp, nil
}
// mockClient returns an *llm.Client backed by a mockChat whose canned reply is
// the given questions JSON-encoded under a top-level "questions" key. The
// client is created with "test-model" as its second argument (presumably the
// model name). A marshalling failure aborts the calling test.
func mockClient(t *testing.T, questions []map[string]any) *llm.Client {
	t.Helper()

	payload := map[string]any{"questions": questions}
	raw, err := json.Marshal(payload)
	if err != nil {
		t.Fatal(err)
	}

	chat := &mockChat{body: string(raw)}
	return llm.NewWithClient(chat, "test-model")
}
// TestExtractQuestions_HappyPath mocks a reply containing one question with
// three answers (exactly one marked correct) and expects ExtractQuestions to
// return that single question with its text and all three answers intact.
func TestExtractQuestions_HappyPath(t *testing.T) {
	answers := []map[string]any{
		{"text": "3", "correct": false},
		{"text": "4", "correct": true},
		{"text": "5", "correct": false},
	}
	input := []map[string]any{
		{"question": "What is 2+2?", "answers": answers},
	}
	client := mockClient(t, input)

	qs, err := client.ExtractQuestions(context.Background(), "text")
	if err != nil {
		t.Fatalf("ExtractQuestions: %v", err)
	}
	// Fatal here: the checks below index qs[0].
	if n := len(qs); n != 1 {
		t.Fatalf("want 1 question, got %d", n)
	}
	if text := qs[0].Question; text != "What is 2+2?" {
		t.Errorf("wrong question text: %q", text)
	}
	if n := len(qs[0].Answers); n != 3 {
		t.Errorf("want 3 answers, got %d", n)
	}
}
// TestExtractQuestions_DropsInvalid mocks a reply with three questions: one
// with two correct answers, one with no correct answer, and one with exactly
// one correct answer. It expects only the last one to be returned.
func TestExtractQuestions_DropsInvalid(t *testing.T) {
	// answer builds a single answer object for the mocked reply.
	answer := func(text string, correct bool) map[string]any {
		return map[string]any{"text": text, "correct": correct}
	}
	// question builds a question object carrying the given answers.
	question := func(prompt string, answers ...map[string]any) map[string]any {
		return map[string]any{"question": prompt, "answers": answers}
	}

	input := []map[string]any{
		question("Two correct — should drop", answer("A", true), answer("B", true)),
		question("Zero correct — should drop", answer("A", false), answer("B", false)),
		question("Valid question", answer("Wrong", false), answer("Right", true)),
	}

	qs, err := mockClient(t, input).ExtractQuestions(context.Background(), "text")
	if err != nil {
		t.Fatalf("ExtractQuestions: %v", err)
	}
	// Fatal here: the check below indexes qs[0].
	if n := len(qs); n != 1 {
		t.Fatalf("want 1 question after dropping invalid, got %d", n)
	}
	if kept := qs[0].Question; kept != "Valid question" {
		t.Errorf("wrong question kept: %q", kept)
	}
}
// TestExtractQuestions_Dedup mocks a reply in which the same question (same
// text, same answers) appears twice and expects a single question back.
func TestExtractQuestions_Dedup(t *testing.T) {
	// The same fixture is listed twice, so both entries encode to identical JSON.
	dup := map[string]any{
		"question": "Duplicate?",
		"answers": []map[string]any{
			{"text": "Yes", "correct": true},
			{"text": "No", "correct": false},
		},
	}
	input := []map[string]any{dup, dup}

	qs, err := mockClient(t, input).ExtractQuestions(context.Background(), "text")
	if err != nil {
		t.Fatalf("ExtractQuestions: %v", err)
	}
	if n := len(qs); n != 1 {
		t.Errorf("want 1 unique question after dedup, got %d", n)
	}
}
// TestExtractQuestions_EmptyResponse mocks a reply whose "questions" list is
// empty and expects ExtractQuestions to return no error and no questions.
func TestExtractQuestions_EmptyResponse(t *testing.T) {
	// Non-nil empty slice: json.Marshal encodes it as [] (a nil slice would
	// encode as null).
	none := []map[string]any{}
	client := mockClient(t, none)

	qs, err := client.ExtractQuestions(context.Background(), "text")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if n := len(qs); n != 0 {
		t.Errorf("want 0 questions for empty response, got %d", n)
	}
}