package services

import (
	"encoding/json"
	"errors"
	"fmt"
	"strconv"
	"strings"
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestAIService_ValidateVariationType(t *testing.T) {
	tests := []struct {
		name          string
		variationType string
		wantErr       bool
	}{
		{
			name:          "valid - rewrite",
			variationType: "rewrite",
			wantErr:       false,
		},
		{
			name:          "valid - alternative",
			variationType: "alternative",
			wantErr:       false,
		},
		{
			name:          "valid - similar",
			variationType: "similar",
			wantErr:       false,
		},
		{
			name:          "invalid type",
			variationType: "invalid",
			wantErr:       true,
		},
		{
			name:          "empty type",
			variationType: "",
			wantErr:       true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := validateVariationType(tt.variationType)
			if tt.wantErr {
				assert.Error(t, err)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}

func TestAIService_BuildExamVariantPrompt(t *testing.T) {
	tests := []struct {
		name             string
		originalContent  string
		variationType    string
		expectedContains []string
	}{
		{
			name:            "rewrite prompt",
			originalContent: "Berechne 5 + 3",
			variationType:   "rewrite",
			expectedContains: []string{
				"Nachschreiber",
				"gleichen Schwierigkeitsgrad",
				"Berechne 5 + 3",
			},
		},
		{
			name:            "alternative prompt",
			originalContent: "Erkläre die Photosynthese",
			variationType:   "alternative",
			expectedContains: []string{
				"alternative",
				"gleichen Lernziele",
				"Erkläre die Photosynthese",
			},
		},
		{
			name:            "similar prompt",
			originalContent: "Löse die Gleichung x + 5 = 10",
			variationType:   "similar",
			expectedContains: []string{
				"ähnliche",
				"Übung",
				"Löse die Gleichung x + 5 = 10",
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			prompt := buildExamVariantPrompt(tt.originalContent, tt.variationType)
			for _, expected := range tt.expectedContains {
				assert.Contains(t, prompt, expected)
			}
		})
	}
}

func TestAIService_BuildFeedbackPrompt(t *testing.T) {
	tests := []struct {
		name             string
		studentName      string
		subject          string
		grade            float64
		expectedContains []string
	}{
		{
			name:        "good grade feedback",
			studentName: "Max Mustermann",
			subject:     "Mathematik",
			grade:       1.5,
			expectedContains: []string{
				"Max Mustermann",
				"Mathematik",
				"1.5",
				"Zeugnis",
			},
		},
		{
			name:        "improvement needed feedback",
			studentName: "Anna Schmidt",
			subject:     "Deutsch",
			grade:       4.0,
			expectedContains: []string{
				"Anna Schmidt",
				"Deutsch",
				"4.0",
			},
		},
		{
			// 2.3 is not exactly representable in binary floating point; a
			// truncating digit extraction would render it as "2.2".
			name:        "grade not exactly representable in binary",
			studentName: "Lena Weber",
			subject:     "Englisch",
			grade:       2.3,
			expectedContains: []string{
				"Lena Weber",
				"Englisch",
				"2.3",
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			prompt := buildFeedbackPrompt(tt.studentName, tt.subject, tt.grade)
			for _, expected := range tt.expectedContains {
				assert.Contains(t, prompt, expected)
			}
		})
	}
}

func TestAIService_ValidateContentLength(t *testing.T) {
	tests := []struct {
		name      string
		content   string
		maxLength int
		wantErr   bool
	}{
		{
			name:      "valid content length",
			content:   "Short content",
			maxLength: 1000,
			wantErr:   false,
		},
		{
			name:      "empty content",
			content:   "",
			maxLength: 1000,
			wantErr:   true,
		},
		{
			name:      "content too long",
			content:   generateLongString(10001),
			maxLength: 10000,
			wantErr:   true,
		},
		{
			name:      "exactly at max length",
			content:   generateLongString(1000),
			maxLength: 1000,
			wantErr:   false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := validateContentLength(tt.content, tt.maxLength)
			if tt.wantErr {
				assert.Error(t, err)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}

func TestAIService_ParseLLMResponse(t *testing.T) {
	tests := []struct {
		name           string
		response       string
		expectedResult string
		wantErr        bool
	}{
		{
			name:           "valid response",
			response:       `{"content": "Generated exam content here"}`,
			expectedResult: "Generated exam content here",
			wantErr:        false,
		},
		{
			name:           "compact JSON without space after colon",
			response:       `{"content":"Compact content"}`,
			expectedResult: "Compact content",
			wantErr:        false,
		},
		{
			name:           "JSON with escaped quotes",
			response:       `{"content": "Say \"hello\""}`,
			expectedResult: `Say "hello"`,
			wantErr:        false,
		},
		{
			name:           "JSON without content field",
			response:       `{"other": "value"}`,
			expectedResult: `{"other": "value"}`,
			wantErr:        false,
		},
		{
			name:           "empty response",
			response:       "",
			expectedResult: "",
			wantErr:        true,
		},
		{
			name:           "plain text response",
			response:       "This is a plain text response",
			expectedResult: "This is a plain text response",
			wantErr:        false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, err := parseLLMResponse(tt.response)
			if tt.wantErr {
				assert.Error(t, err)
			} else {
				assert.NoError(t, err)
				assert.Equal(t, tt.expectedResult, result)
			}
		})
	}
}

func TestAIService_EstimateTokenCount(t *testing.T) {
	tests := []struct {
		name           string
		text           string
		expectedTokens int
	}{
		{
			name:           "short text",
			text:           "Hello world",
			expectedTokens: 3, // Rough estimate: ~4 bytes per token, plus one
		},
		{
			name:           "empty text",
			text:           "",
			expectedTokens: 0,
		},
		{
			name:           "longer text",
			text:           "This is a longer text with multiple words that should result in more tokens",
			expectedTokens: 15, // Rough estimate
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tokens := estimateTokenCount(tt.text)
			// Allow some variance in token estimation
			assert.InDelta(t, tt.expectedTokens, tokens, float64(tt.expectedTokens)*0.5+2)
		})
	}
}

func TestAIService_SanitizePrompt(t *testing.T) {
	tests := []struct {
		name     string
		input    string
		expected string
	}{
		{
			name:     "clean input",
			input:    "Calculate the sum of 5 and 3",
			expected: "Calculate the sum of 5 and 3",
		},
		{
			name:     "input with newlines",
			input:    "Line 1\nLine 2\nLine 3",
			expected: "Line 1\nLine 2\nLine 3",
		},
		{
			name:     "input with excessive whitespace",
			input:    "Word   with    spaces",
			expected: "Word with spaces",
		},
		{
			name:     "input with tabs",
			input:    "Tab\t\tseparated",
			expected: "Tab separated",
		},
		{
			name:     "input with leading/trailing whitespace",
			input:    " trimmed content ",
			expected: "trimmed content",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := sanitizePrompt(tt.input)
			assert.Equal(t, tt.expected, result)
		})
	}
}

func TestAIService_DetermineModel(t *testing.T) {
	tests := []struct {
		name          string
		taskType      string
		expectedModel string
	}{
		{
			name:          "exam generation - complex task",
			taskType:      "exam_generation",
			expectedModel: "gpt-4",
		},
		{
			name:          "feedback generation - simpler task",
			taskType:      "feedback",
			expectedModel: "gpt-3.5-turbo",
		},
		{
			name:          "improvement - complex task",
			taskType:      "improvement",
			expectedModel: "gpt-4",
		},
		{
			name:          "unknown task - default",
			taskType:      "unknown",
			expectedModel: "gpt-3.5-turbo",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			model := determineModel(tt.taskType)
			assert.Equal(t, tt.expectedModel, model)
		})
	}
}

// Helper functions used by the tests in this file.

// validateVariationType returns an error unless varType is one of the
// supported variation types: "rewrite", "alternative" or "similar".
func validateVariationType(varType string) error {
	switch varType {
	case "rewrite", "alternative", "similar":
		return nil
	default:
		return fmt.Errorf("invalid variation type %q", varType)
	}
}

// buildExamVariantPrompt builds a German-language prompt that asks for a
// variant of originalContent. The instruction sentence depends on
// variationType; an unrecognized type yields an empty instruction.
func buildExamVariantPrompt(originalContent, variationType string) string {
	var instruction string
	switch variationType {
	case "rewrite":
		instruction = "Erstelle eine Nachschreiber-Version mit dem gleichen Schwierigkeitsgrad."
	case "alternative":
		instruction = "Erstelle eine alternative Version mit den gleichen Lernzielen."
	case "similar":
		instruction = "Erstelle ähnliche Aufgaben für Übung."
	}
	return "Du bist ein erfahrener Lehrer.\n\n" + instruction + "\n\n" + "Original:\n" + originalContent
}

// buildFeedbackPrompt builds a German-language prompt requesting a
// "Zeugnis" comment for studentName in subject. The grade is rendered with
// formatGrade (one decimal place).
func buildFeedbackPrompt(studentName, subject string, grade float64) string {
	return "Erstelle einen Zeugnis-Kommentar für " + studentName +
		" im Fach " + subject +
		" mit Note " + formatGrade(grade) + "."
}

// formatGrade renders grade with exactly one decimal place, rounded to the
// nearest tenth (e.g. 1.5 -> "1.5", 4 -> "4.0", 2.3 -> "2.3").
//
// strconv is used instead of extracting digits arithmetically: computing the
// tenths digit as int((grade-whole)*10) truncates float error (2.3 -> "2.2"),
// and '0'+whole only works for single-digit, non-negative values.
func formatGrade(grade float64) string {
	return strconv.FormatFloat(grade, 'f', 1, 64)
}

// validateContentLength returns an error if content is empty or longer than
// maxLength. Length is measured in bytes (len), not runes, so multi-byte
// characters such as "ä" count as more than one.
func validateContentLength(content string, maxLength int) error {
	if content == "" {
		return errors.New("content must not be empty")
	}
	if len(content) > maxLength {
		return fmt.Errorf("content length %d exceeds maximum of %d bytes", len(content), maxLength)
	}
	return nil
}

// generateLongString returns a string of length 'a' characters, or "" when
// length is not positive.
func generateLongString(length int) string {
	// strings.Repeat panics on a negative count.
	if length <= 0 {
		return ""
	}
	return strings.Repeat("a", length)
}

// parseLLMResponse extracts the generated text from an LLM response.
//
// It handles the response as follows:
//   - An empty response is an error.
//   - A JSON object with a string "content" field yields that field's decoded
//     value, with escapes such as \" unescaped.
//   - Anything else is returned unchanged: plain text, malformed JSON, or JSON
//     without a usable "content" field.
func parseLLMResponse(response string) (string, error) {
	if response == "" {
		return "", errors.New("empty LLM response")
	}

	// Only attempt JSON decoding for object-shaped payloads; plain text is
	// passed through untouched.
	if strings.HasPrefix(response, "{") {
		var payload struct {
			// Pointer so an explicit empty "content" can be distinguished
			// from a missing or null field.
			Content *string `json:"content"`
		}
		if err := json.Unmarshal([]byte(response), &payload); err == nil && payload.Content != nil {
			return *payload.Content, nil
		}
	}

	// Return as-is for plain text or JSON we could not extract content from.
	return response, nil
}

// findString returns the byte index of the first occurrence of substr in s,
// or -1 if it is absent. It is a thin wrapper over strings.Index, kept so
// existing callers continue to compile.
func findString(s, substr string) int {
	return strings.Index(s, substr)
}

// estimateTokenCount returns a rough token estimate for text: 0 for empty
// input, otherwise len(text)/4 + 1 (about four bytes per token, plus one).
func estimateTokenCount(text string) int {
	if text == "" {
		return 0
	}
	return len(text)/4 + 1
}

// sanitizePrompt trims leading/trailing spaces and tabs, then collapses each
// run of spaces/tabs into a single space. Newlines are preserved.
func sanitizePrompt(input string) string {
	return collapseSpaces(trimSpace(input))
}

// trimSpace removes leading and trailing ' ' and '\t' characters only.
// Unlike strings.TrimSpace, it leaves newlines and other whitespace intact.
func trimSpace(s string) string {
	return strings.Trim(s, " \t")
}

// collapseSpaces replaces every run of spaces and/or tabs in s with a single
// space; all other characters, including newlines, are copied unchanged.
func collapseSpaces(s string) string {
	var b strings.Builder
	b.Grow(len(s))
	lastWasSpace := false
	for _, c := range s {
		if c == ' ' || c == '\t' {
			if !lastWasSpace {
				b.WriteByte(' ')
				lastWasSpace = true
			}
			continue
		}
		b.WriteRune(c)
		lastWasSpace = false
	}
	return b.String()
}

// determineModel returns "gpt-4" for the task types "exam_generation" and
// "improvement", and "gpt-3.5-turbo" for every other task type.
func determineModel(taskType string) string {
	switch taskType {
	case "exam_generation", "improvement":
		return "gpt-4"
	default:
		return "gpt-3.5-turbo"
	}
}

func TestAIService_RetryLogic(t *testing.T) {
	tests := []struct {
		name          string
		maxRetries    int
		failuresCount int
		shouldSucceed bool
	}{
		{
			name:          "succeeds first try",
			maxRetries:    3,
			failuresCount: 0,
			shouldSucceed: true,
		},
		{
			name:          "succeeds after retries",
			maxRetries:    3,
			failuresCount: 2,
			shouldSucceed: true,
		},
		{
			name:          "succeeds on last retry",
			maxRetries:    3,
			failuresCount: 3,
			shouldSucceed: true,
		},
		{
			name:          "fails after max retries",
			maxRetries:    3,
			failuresCount: 4,
			shouldSucceed: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			success := simulateRetryLogic(tt.maxRetries, tt.failuresCount)
			assert.Equal(t, tt.shouldSucceed, success)
		})
	}
}

// simulateRetryLogic models one initial attempt plus up to maxRetries
// retries, where the first failuresCount attempts fail. It reports whether
// any attempt succeeded, i.e. whether failuresCount <= maxRetries.
func simulateRetryLogic(maxRetries, failuresCount int) bool {
	for attempt := 0; attempt <= maxRetries; attempt++ {
		if attempt >= failuresCount {
			return true // Success
		}
	}
	return false // All retries failed
}