Files
breakpilot-lehrer/school-service/internal/services/ai_service_test.go
Benjamin Boenisch 5a31f52310 Initial commit: breakpilot-lehrer - Lehrer KI Platform
Services: Admin-Lehrer, Backend-Lehrer, Studio v2, Website,
Klausur-Service, School-Service, Voice-Service, Geo-Service,
BreakPilot Drive, Agent-Core

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-11 23:47:26 +01:00

541 lines
11 KiB
Go

package services
import (
	"encoding/json"
	"strconv"
	"strings"
	"testing"

	"github.com/stretchr/testify/assert"
)
// TestAIService_ValidateVariationType checks that only the supported variation
// types ("rewrite", "alternative", "similar") are accepted and everything else,
// including the empty string, is rejected.
func TestAIService_ValidateVariationType(t *testing.T) {
	cases := []struct {
		name          string
		variationType string
		wantErr       bool
	}{
		{name: "valid - rewrite", variationType: "rewrite"},
		{name: "valid - alternative", variationType: "alternative"},
		{name: "valid - similar", variationType: "similar"},
		{name: "invalid type", variationType: "invalid", wantErr: true},
		{name: "empty type", variationType: "", wantErr: true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := validateVariationType(tc.variationType)
			if !tc.wantErr {
				assert.NoError(t, err)
				return
			}
			assert.Error(t, err)
		})
	}
}
// TestAIService_BuildExamVariantPrompt checks that the prompt for each variation
// type carries its type-specific instruction wording and the original task text.
func TestAIService_BuildExamVariantPrompt(t *testing.T) {
	cases := []struct {
		name            string
		originalContent string
		variationType   string
		mustContain     []string
	}{
		{
			name:            "rewrite prompt",
			originalContent: "Berechne 5 + 3",
			variationType:   "rewrite",
			mustContain:     []string{"Nachschreiber", "gleichen Schwierigkeitsgrad", "Berechne 5 + 3"},
		},
		{
			name:            "alternative prompt",
			originalContent: "Erkläre die Photosynthese",
			variationType:   "alternative",
			mustContain:     []string{"alternative", "gleichen Lernziele", "Erkläre die Photosynthese"},
		},
		{
			name:            "similar prompt",
			originalContent: "Löse die Gleichung x + 5 = 10",
			variationType:   "similar",
			mustContain:     []string{"ähnliche", "Übung", "Löse die Gleichung x + 5 = 10"},
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := buildExamVariantPrompt(tc.originalContent, tc.variationType)
			for _, want := range tc.mustContain {
				assert.Contains(t, got, want)
			}
		})
	}
}
// TestAIService_BuildFeedbackPrompt checks that the feedback prompt mentions the
// student, the subject and the one-decimal grade.
func TestAIService_BuildFeedbackPrompt(t *testing.T) {
	cases := []struct {
		name        string
		studentName string
		subject     string
		grade       float64
		mustContain []string
	}{
		{
			name:        "good grade feedback",
			studentName: "Max Mustermann",
			subject:     "Mathematik",
			grade:       1.5,
			mustContain: []string{"Max Mustermann", "Mathematik", "1.5", "Zeugnis"},
		},
		{
			name:        "improvement needed feedback",
			studentName: "Anna Schmidt",
			subject:     "Deutsch",
			grade:       4.0,
			mustContain: []string{"Anna Schmidt", "Deutsch", "4.0"},
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := buildFeedbackPrompt(tc.studentName, tc.subject, tc.grade)
			for _, want := range tc.mustContain {
				assert.Contains(t, got, want)
			}
		})
	}
}
// TestAIService_ValidateContentLength checks the content length rules: empty
// content and content longer than maxLength are rejected, while content of
// exactly maxLength is accepted.
func TestAIService_ValidateContentLength(t *testing.T) {
	cases := []struct {
		name      string
		content   string
		maxLength int
		wantErr   bool
	}{
		{name: "valid content length", content: "Short content", maxLength: 1000},
		{name: "empty content", content: "", maxLength: 1000, wantErr: true},
		{name: "content too long", content: generateLongString(10001), maxLength: 10000, wantErr: true},
		{name: "exactly at max length", content: generateLongString(1000), maxLength: 1000},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := validateContentLength(tc.content, tc.maxLength)
			if tc.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
		})
	}
}
// TestAIService_ParseLLMResponse checks that a JSON "content" field is
// extracted, plain text passes through unchanged, and an empty response is an
// error.
func TestAIService_ParseLLMResponse(t *testing.T) {
	cases := []struct {
		name     string
		response string
		want     string
		wantErr  bool
	}{
		{
			name:     "valid response",
			response: `{"content": "Generated exam content here"}`,
			want:     "Generated exam content here",
		},
		{name: "empty response", response: "", wantErr: true},
		{
			name:     "plain text response",
			response: "This is a plain text response",
			want:     "This is a plain text response",
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := parseLLMResponse(tc.response)
			if tc.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
			assert.Equal(t, tc.want, got)
		})
	}
}
// TestAIService_EstimateTokenCount checks that the heuristic token estimate
// lands near a rough expected value for a few inputs.
func TestAIService_EstimateTokenCount(t *testing.T) {
	cases := []struct {
		name   string
		text   string
		approx int
	}{
		// Rough estimate: words + overhead.
		{name: "short text", text: "Hello world", approx: 3},
		{name: "empty text", text: "", approx: 0},
		// Rough estimate.
		{
			name:   "longer text",
			text:   "This is a longer text with multiple words that should result in more tokens",
			approx: 15,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// The estimate is a heuristic, so allow ±50% of the expectation plus 2.
			tolerance := float64(tc.approx)*0.5 + 2
			got := estimateTokenCount(tc.text)
			assert.InDelta(t, tc.approx, got, tolerance)
		})
	}
}
// TestAIService_SanitizePrompt checks that sanitizePrompt trims surrounding
// blanks, collapses runs of spaces/tabs into a single space and keeps interior
// newlines intact.
func TestAIService_SanitizePrompt(t *testing.T) {
	tests := []struct {
		name     string
		input    string
		expected string
	}{
		{
			name:     "clean input",
			input:    "Calculate the sum of 5 and 3",
			expected: "Calculate the sum of 5 and 3",
		},
		{
			name:     "input with newlines",
			input:    "Line 1\nLine 2\nLine 3",
			expected: "Line 1\nLine 2\nLine 3",
		},
		{
			// The input must really contain runs of blanks; with single spaces
			// this case would pass even if collapsing were broken.
			name:     "input with excessive whitespace",
			input:    "Word   with \t\t spaces",
			expected: "Word with spaces",
		},
		{
			name:     "input with leading/trailing whitespace",
			input:    "  \t trimmed content \t  ",
			expected: "trimmed content",
		},
		{
			name:     "whitespace only",
			input:    " \t ",
			expected: "",
		},
		{
			name:     "empty input",
			input:    "",
			expected: "",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := sanitizePrompt(tt.input)
			assert.Equal(t, tt.expected, result)
		})
	}
}
func TestAIService_DetermineModel(t *testing.T) {
tests := []struct {
name string
taskType string
expectedModel string
}{
{
name: "exam generation - complex task",
taskType: "exam_generation",
expectedModel: "gpt-4",
},
{
name: "feedback generation - simpler task",
taskType: "feedback",
expectedModel: "gpt-3.5-turbo",
},
{
name: "improvement - complex task",
taskType: "improvement",
expectedModel: "gpt-4",
},
{
name: "unknown task - default",
taskType: "unknown",
expectedModel: "gpt-3.5-turbo",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
model := determineModel(tt.taskType)
assert.Equal(t, tt.expectedModel, model)
})
}
}
// Helper functions
// validateVariationType returns nil when varType is one of the supported exam
// variation types ("rewrite", "alternative", "similar"), and assert.AnError
// otherwise, including for the empty string.
//
// NOTE(review): this and the helpers below are defined in the test file, so the
// tests exercise these local copies rather than any production AIService code.
// Confirm whether that is intended.
func validateVariationType(varType string) error {
	switch varType {
	case "rewrite", "alternative", "similar":
		return nil
	}
	return assert.AnError
}
// buildExamVariantPrompt assembles the German LLM prompt for producing a
// variant of originalContent. It consists of a fixed teacher persona line, an
// instruction chosen by variationType, and the original task text.
// An unrecognized variationType contributes an empty instruction line.
func buildExamVariantPrompt(originalContent, variationType string) string {
	const persona = "Du bist ein erfahrener Lehrer.\n\n"
	prompt := persona
	switch variationType {
	case "rewrite":
		prompt += "Erstelle eine Nachschreiber-Version mit dem gleichen Schwierigkeitsgrad."
	case "alternative":
		prompt += "Erstelle eine alternative Version mit den gleichen Lernzielen."
	case "similar":
		prompt += "Erstelle ähnliche Aufgaben für Übung."
	}
	return prompt + "\n\nOriginal:\n" + originalContent
}
// buildFeedbackPrompt builds the German LLM prompt requesting a report-card
// ("Zeugnis") comment for studentName in subject. The grade is rendered by
// formatGrade.
//
// The previous version first computed a grade string with rune arithmetic and
// then unconditionally overwrote it with formatGrade. That dead (and incorrect)
// computation has been removed; the output is unchanged.
func buildFeedbackPrompt(studentName, subject string, grade float64) string {
	return "Erstelle einen Zeugnis-Kommentar für " + studentName +
		" im Fach " + subject +
		" mit Note " + formatGrade(grade) + "."
}
// formatGrade renders grade with exactly one decimal place, e.g. 1.5 -> "1.5"
// and 4 -> "4.0".
//
// strconv.FormatFloat rounds to the nearest one-decimal value. Truncating
// (grade - whole) * 10 instead turned 2.3 (stored as 2.2999...) into "2.2".
// Adding to '0' also produced non-digit runes for grades >= 10 or < 0.
func formatGrade(grade float64) string {
	return strconv.FormatFloat(grade, 'f', 1, 64)
}
// validateContentLength accepts content that is non-empty and at most maxLength
// bytes long (len counts bytes, not runes). Any other content yields
// assert.AnError.
func validateContentLength(content string, maxLength int) error {
	if content != "" && len(content) <= maxLength {
		return nil
	}
	return assert.AnError
}
// generateLongString returns a string of length repeated 'a' bytes, used to
// build oversized test inputs. A non-positive length yields "". strings.Repeat
// would panic on a negative count, so that case is guarded explicitly.
//
// strings.Repeat allocates once; the previous += loop was quadratic in length.
func generateLongString(length int) string {
	if length <= 0 {
		return ""
	}
	return strings.Repeat("a", length)
}
// parseLLMResponse extracts the generated text from a raw LLM response.
//
// If the response (ignoring leading whitespace) starts with a JSON object that
// has a string "content" field, the decoded value of that field is returned.
// Any other non-empty response is returned unchanged: plain text, JSON without
// a string "content" field, or malformed JSON. An empty response yields
// assert.AnError.
func parseLLMResponse(response string) (string, error) {
	if response == "" {
		return "", assert.AnError
	}
	if strings.HasPrefix(strings.TrimSpace(response), "{") {
		// Decode the object properly instead of scanning for `"content": "`.
		// This accepts any whitespace around the colon and unescapes quoted
		// characters. A Decoder reads only the first JSON value, so trailing
		// text after the object is tolerated, as before.
		var payload struct {
			Content *string `json:"content"`
		}
		dec := json.NewDecoder(strings.NewReader(response))
		if err := dec.Decode(&payload); err == nil && payload.Content != nil {
			return *payload.Content, nil
		}
	}
	// Plain text, or JSON we cannot interpret: return as-is.
	return response, nil
}
// findString returns the byte index of the first occurrence of substr in s, or
// -1 if substr is absent. An empty substr matches at index 0.
// It delegates to strings.Index rather than comparing every substring by hand.
func findString(s, substr string) int {
	return strings.Index(s, substr)
}
func estimateTokenCount(text string) int {
if text == "" {
return 0
}
// Rough estimation: ~4 characters per token on average
// Plus some overhead for special tokens
return len(text)/4 + 1
}
func sanitizePrompt(input string) string {
// Trim leading/trailing whitespace
result := trimSpace(input)
// Collapse multiple spaces into one
result = collapseSpaces(result)
return result
}
// trimSpace removes all leading and trailing whitespace from s, as defined by
// unicode.IsSpace: spaces, tabs, newlines, carriage returns and so on.
//
// Only ' ' and '\t' used to be stripped, so a prompt padded with blank lines
// (or a trailing "\r\n") kept that padding. That contradicted sanitizePrompt's
// stated intent to trim leading/trailing whitespace. Interior characters,
// including newlines, are untouched.
func trimSpace(s string) string {
	return strings.TrimSpace(s)
}
// collapseSpaces replaces every run of spaces and tabs in s with a single
// space. All other characters, including newlines, are copied unchanged.
// Invalid UTF-8 bytes come out as U+FFFD, as with the range-over-string
// decoding used here.
func collapseSpaces(s string) string {
	// A strings.Builder keeps this linear; repeated string += was quadratic.
	var b strings.Builder
	b.Grow(len(s))
	inBlankRun := false
	for _, c := range s {
		if c == ' ' || c == '\t' {
			if !inBlankRun {
				b.WriteByte(' ')
				inBlankRun = true
			}
			continue
		}
		b.WriteRune(c)
		inBlankRun = false
	}
	return b.String()
}
// determineModel selects the LLM model name for a task type. "exam_generation"
// and "improvement" get "gpt-4"; every other task type, including unknown
// ones, gets "gpt-3.5-turbo".
func determineModel(taskType string) string {
	switch taskType {
	case "exam_generation", "improvement":
		return "gpt-4"
	default:
		return "gpt-3.5-turbo"
	}
}
// TestAIService_RetryLogic checks the retry budget. One initial attempt plus
// maxRetries retries are made, so up to maxRetries failures can be absorbed.
// The boundary cases (success on the very last retry, zero retries allowed)
// pin down the off-by-one behavior that the original cases did not cover.
func TestAIService_RetryLogic(t *testing.T) {
	tests := []struct {
		name          string
		maxRetries    int
		failuresCount int
		shouldSucceed bool
	}{
		{
			name:          "succeeds first try",
			maxRetries:    3,
			failuresCount: 0,
			shouldSucceed: true,
		},
		{
			name:          "succeeds after retries",
			maxRetries:    3,
			failuresCount: 2,
			shouldSucceed: true,
		},
		{
			name:          "succeeds on final retry",
			maxRetries:    3,
			failuresCount: 3,
			shouldSucceed: true,
		},
		{
			name:          "fails after max retries",
			maxRetries:    3,
			failuresCount: 4,
			shouldSucceed: false,
		},
		{
			name:          "no retries allowed, no failures",
			maxRetries:    0,
			failuresCount: 0,
			shouldSucceed: true,
		},
		{
			name:          "no retries allowed, one failure",
			maxRetries:    0,
			failuresCount: 1,
			shouldSucceed: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			success := simulateRetryLogic(tt.maxRetries, tt.failuresCount)
			assert.Equal(t, tt.shouldSucceed, success)
		})
	}
}
// simulateRetryLogic models a retry loop: one initial attempt plus up to
// maxRetries retries, where the first failuresCount attempts fail. It reports
// whether any attempt succeeded. A negative maxRetries means no attempt is
// made at all, so the result is false.
func simulateRetryLogic(maxRetries, failuresCount int) bool {
	for attempt := 0; attempt <= maxRetries; attempt++ {
		// Attempts are 0-indexed, so attempt number failuresCount is the first
		// one that does not fail.
		if attempt >= failuresCount {
			return true
		}
	}
	return false
}