package llm_test

import (
	"context"
	"encoding/json"
	"testing"

	openai "github.com/sashabaranov/go-openai"

	"qbank/internal/llm"
)

// mockChat implements llm.ChatClient for testing.
//
// It ignores the context and the request entirely and always answers with a
// single choice whose message content is body. It never returns an error.
type mockChat struct{ body string }

func (m *mockChat) CreateChatCompletion(_ context.Context, _ openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
	choice := openai.ChatCompletionChoice{
		Message: openai.ChatCompletionMessage{Content: m.body},
	}
	return openai.ChatCompletionResponse{
		Choices: []openai.ChatCompletionChoice{choice},
	}, nil
}

// mockClient builds an llm.Client whose chat backend replies with
// {"questions": <questions>} encoded as JSON. The client is created with the
// model name "test-model". A marshalling failure aborts the calling test.
func mockClient(t *testing.T, questions []map[string]any) *llm.Client {
	t.Helper()

	envelope := map[string]any{"questions": questions}
	payload, err := json.Marshal(envelope)
	if err != nil {
		t.Fatal(err)
	}
	return llm.NewWithClient(&mockChat{body: string(payload)}, "test-model")
}

// TestExtractQuestions_HappyPath checks that a single question with exactly
// one correct answer comes back with its text and all three answers intact.
func TestExtractQuestions_HappyPath(t *testing.T) {
	arithmetic := map[string]any{
		"question": "What is 2+2?",
		"answers": []map[string]any{
			{"text": "3", "correct": false},
			{"text": "4", "correct": true},
			{"text": "5", "correct": false},
		},
	}
	client := mockClient(t, []map[string]any{arithmetic})

	got, err := client.ExtractQuestions(context.Background(), "text")
	if err != nil {
		t.Fatalf("ExtractQuestions: %v", err)
	}
	if n := len(got); n != 1 {
		t.Fatalf("want 1 question, got %d", n)
	}
	if text := got[0].Question; text != "What is 2+2?" {
		t.Errorf("wrong question text: %q", text)
	}
	if n := len(got[0].Answers); n != 3 {
		t.Errorf("want 3 answers, got %d", n)
	}
}

// TestExtractQuestions_DropsInvalid checks that questions with more than one
// or with zero correct answers are filtered out, leaving only the question
// that has exactly one correct answer.
func TestExtractQuestions_DropsInvalid(t *testing.T) {
	twoCorrect := map[string]any{
		"question": "Two correct — should drop",
		"answers": []map[string]any{
			{"text": "A", "correct": true},
			{"text": "B", "correct": true},
		},
	}
	noneCorrect := map[string]any{
		"question": "Zero correct — should drop",
		"answers": []map[string]any{
			{"text": "A", "correct": false},
			{"text": "B", "correct": false},
		},
	}
	valid := map[string]any{
		"question": "Valid question",
		"answers": []map[string]any{
			{"text": "Wrong", "correct": false},
			{"text": "Right", "correct": true},
		},
	}
	client := mockClient(t, []map[string]any{twoCorrect, noneCorrect, valid})

	got, err := client.ExtractQuestions(context.Background(), "text")
	if err != nil {
		t.Fatalf("ExtractQuestions: %v", err)
	}
	if n := len(got); n != 1 {
		t.Fatalf("want 1 question after dropping invalid, got %d", n)
	}
	if text := got[0].Question; text != "Valid question" {
		t.Errorf("wrong question kept: %q", text)
	}
}

// TestExtractQuestions_Dedup checks that two identical questions in the model
// response collapse into a single result.
func TestExtractQuestions_Dedup(t *testing.T) {
	// The same literal is listed twice; json.Marshal encodes both copies
	// identically, so the mock response contains two equal entries.
	dup := map[string]any{
		"question": "Duplicate?",
		"answers": []map[string]any{
			{"text": "Yes", "correct": true},
			{"text": "No", "correct": false},
		},
	}
	client := mockClient(t, []map[string]any{dup, dup})

	got, err := client.ExtractQuestions(context.Background(), "text")
	if err != nil {
		t.Fatalf("ExtractQuestions: %v", err)
	}
	if n := len(got); n != 1 {
		t.Errorf("want 1 unique question after dedup, got %d", n)
	}
}

// TestExtractQuestions_EmptyResponse checks that an empty "questions" array
// is not an error and yields no questions.
func TestExtractQuestions_EmptyResponse(t *testing.T) {
	// Non-nil and empty on purpose: it must encode as [] rather than null.
	none := make([]map[string]any, 0)
	client := mockClient(t, none)

	got, err := client.ExtractQuestions(context.Background(), "text")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if n := len(got); n != 0 {
		t.Errorf("want 0 questions for empty response, got %d", n)
	}
}