fix: Restore all files lost during destructive rebase

A previous `git pull --rebase origin main` dropped 177 local commits,
losing 3400+ files across admin-v2, backend, studio-v2, website,
klausur-service, and many other services. The partial restore attempt
(660295e2) only recovered some files.

This commit restores all missing files from pre-rebase ref 98933f5e
while preserving post-rebase additions (night-scheduler, night-mode UI,
NightModeWidget dashboard integration).

Restored features include:
- AI Module Sidebar (FAB), OCR Labeling, OCR Compare
- GPU Dashboard, RAG Pipeline, Magic Help
- Klausur-Korrektur (8 files), Abitur-Archiv (5+ files)
- Companion, Zeugnisse-Crawler, Screen Flow
- Full backend, studio-v2, website, klausur-service
- All compliance SDKs, agent-core, voice-service
- CI/CD configs, documentation, scripts

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Benjamin Admin
2026-02-09 09:51:32 +01:00
parent f7487ee240
commit bfdaf63ba9
2009 changed files with 749983 additions and 1731 deletions

View File

@@ -0,0 +1,505 @@
package services
import (
"context"
"fmt"
"time"
"github.com/breakpilot/consent-service/internal/database"
"github.com/breakpilot/consent-service/internal/models"
"github.com/breakpilot/consent-service/internal/services/matrix"
"github.com/google/uuid"
)
// AttendanceService handles attendance tracking and notifications
type AttendanceService struct {
db *database.DB
matrix *matrix.MatrixService
}
// NewAttendanceService creates a new attendance service
func NewAttendanceService(db *database.DB, matrixService *matrix.MatrixService) *AttendanceService {
return &AttendanceService{
db: db,
matrix: matrixService,
}
}
// ========================================
// Attendance Recording
// ========================================
// RecordAttendance records a student's attendance for a specific lesson
func (s *AttendanceService) RecordAttendance(ctx context.Context, req models.RecordAttendanceRequest, recordedByUserID uuid.UUID) (*models.AttendanceRecord, error) {
studentID, err := uuid.Parse(req.StudentID)
if err != nil {
return nil, fmt.Errorf("invalid student ID: %w", err)
}
slotID, err := uuid.Parse(req.SlotID)
if err != nil {
return nil, fmt.Errorf("invalid slot ID: %w", err)
}
date, err := time.Parse("2006-01-02", req.Date)
if err != nil {
return nil, fmt.Errorf("invalid date format: %w", err)
}
record := &models.AttendanceRecord{
ID: uuid.New(),
StudentID: studentID,
Date: date,
SlotID: slotID,
Status: req.Status,
RecordedBy: recordedByUserID,
Note: req.Note,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
query := `
INSERT INTO attendance_records (id, student_id, date, slot_id, status, recorded_by, note, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
ON CONFLICT (student_id, date, slot_id)
DO UPDATE SET status = EXCLUDED.status, note = EXCLUDED.note, updated_at = EXCLUDED.updated_at
RETURNING id`
err = s.db.Pool.QueryRow(ctx, query,
record.ID, record.StudentID, record.Date, record.SlotID,
record.Status, record.RecordedBy, record.Note, record.CreatedAt, record.UpdatedAt,
).Scan(&record.ID)
if err != nil {
return nil, fmt.Errorf("failed to record attendance: %w", err)
}
// If student is absent, send notification to parents
if record.Status == models.AttendanceAbsent || record.Status == models.AttendancePending {
go s.notifyParentsOfAbsence(context.Background(), record)
}
return record, nil
}
// RecordBulkAttendance records attendance for multiple students at once
func (s *AttendanceService) RecordBulkAttendance(ctx context.Context, classID uuid.UUID, date string, slotID uuid.UUID, records []struct {
StudentID string
Status string
Note *string
}, recordedByUserID uuid.UUID) error {
parsedDate, err := time.Parse("2006-01-02", date)
if err != nil {
return fmt.Errorf("invalid date format: %w", err)
}
for _, rec := range records {
studentID, err := uuid.Parse(rec.StudentID)
if err != nil {
continue
}
query := `
INSERT INTO attendance_records (id, student_id, date, slot_id, status, recorded_by, note, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW(), NOW())
ON CONFLICT (student_id, date, slot_id)
DO UPDATE SET status = EXCLUDED.status, note = EXCLUDED.note, updated_at = NOW()`
_, err = s.db.Pool.Exec(ctx, query,
uuid.New(), studentID, parsedDate, slotID, rec.Status, recordedByUserID, rec.Note,
)
if err != nil {
return fmt.Errorf("failed to record attendance for student %s: %w", rec.StudentID, err)
}
// Notify parents if absent
if rec.Status == models.AttendanceAbsent || rec.Status == models.AttendancePending {
go s.notifyParentsOfAbsenceByStudentID(context.Background(), studentID, parsedDate, slotID)
}
}
return nil
}
// GetAttendanceByClass gets attendance records for a class on a specific date
func (s *AttendanceService) GetAttendanceByClass(ctx context.Context, classID uuid.UUID, date string) (*models.ClassAttendanceOverview, error) {
parsedDate, err := time.Parse("2006-01-02", date)
if err != nil {
return nil, fmt.Errorf("invalid date format: %w", err)
}
// Get class info
classQuery := `SELECT id, school_id, school_year_id, name, grade, section, room, is_active FROM classes WHERE id = $1`
class := &models.Class{}
err = s.db.Pool.QueryRow(ctx, classQuery, classID).Scan(
&class.ID, &class.SchoolID, &class.SchoolYearID, &class.Name,
&class.Grade, &class.Section, &class.Room, &class.IsActive,
)
if err != nil {
return nil, fmt.Errorf("failed to get class: %w", err)
}
// Get total students
var totalStudents int
err = s.db.Pool.QueryRow(ctx, `SELECT COUNT(*) FROM students WHERE class_id = $1 AND is_active = true`, classID).Scan(&totalStudents)
if err != nil {
return nil, fmt.Errorf("failed to count students: %w", err)
}
// Get attendance records for the date
recordsQuery := `
SELECT ar.id, ar.student_id, ar.date, ar.slot_id, ar.status, ar.recorded_by, ar.note, ar.created_at, ar.updated_at
FROM attendance_records ar
JOIN students s ON ar.student_id = s.id
WHERE s.class_id = $1 AND ar.date = $2
ORDER BY ar.slot_id`
rows, err := s.db.Pool.Query(ctx, recordsQuery, classID, parsedDate)
if err != nil {
return nil, fmt.Errorf("failed to get attendance records: %w", err)
}
defer rows.Close()
var records []models.AttendanceRecord
presentCount := 0
absentCount := 0
lateCount := 0
seenStudents := make(map[uuid.UUID]bool)
for rows.Next() {
var record models.AttendanceRecord
err := rows.Scan(
&record.ID, &record.StudentID, &record.Date, &record.SlotID,
&record.Status, &record.RecordedBy, &record.Note, &record.CreatedAt, &record.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan attendance record: %w", err)
}
records = append(records, record)
// Count unique students for summary (use first slot's status)
if !seenStudents[record.StudentID] {
seenStudents[record.StudentID] = true
switch record.Status {
case models.AttendancePresent:
presentCount++
case models.AttendanceAbsent, models.AttendanceAbsentExcused, models.AttendanceAbsentUnexcused, models.AttendancePending:
absentCount++
case models.AttendanceLate, models.AttendanceLateExcused:
lateCount++
}
}
}
return &models.ClassAttendanceOverview{
Class: *class,
Date: parsedDate,
TotalStudents: totalStudents,
PresentCount: presentCount,
AbsentCount: absentCount,
LateCount: lateCount,
Records: records,
}, nil
}
// GetStudentAttendance gets attendance history for a student
func (s *AttendanceService) GetStudentAttendance(ctx context.Context, studentID uuid.UUID, startDate, endDate time.Time) ([]models.AttendanceRecord, error) {
query := `
SELECT id, student_id, timetable_entry_id, date, slot_id, status, recorded_by, note, created_at, updated_at
FROM attendance_records
WHERE student_id = $1 AND date >= $2 AND date <= $3
ORDER BY date DESC, slot_id`
rows, err := s.db.Pool.Query(ctx, query, studentID, startDate, endDate)
if err != nil {
return nil, fmt.Errorf("failed to get student attendance: %w", err)
}
defer rows.Close()
var records []models.AttendanceRecord
for rows.Next() {
var record models.AttendanceRecord
err := rows.Scan(
&record.ID, &record.StudentID, &record.TimetableEntryID, &record.Date,
&record.SlotID, &record.Status, &record.RecordedBy, &record.Note,
&record.CreatedAt, &record.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan attendance record: %w", err)
}
records = append(records, record)
}
return records, nil
}
// ========================================
// Absence Reports (Parent-initiated)
// ========================================
// ReportAbsence allows parents to report a student's absence
func (s *AttendanceService) ReportAbsence(ctx context.Context, req models.ReportAbsenceRequest, reportedByUserID uuid.UUID) (*models.AbsenceReport, error) {
studentID, err := uuid.Parse(req.StudentID)
if err != nil {
return nil, fmt.Errorf("invalid student ID: %w", err)
}
startDate, err := time.Parse("2006-01-02", req.StartDate)
if err != nil {
return nil, fmt.Errorf("invalid start date format: %w", err)
}
endDate, err := time.Parse("2006-01-02", req.EndDate)
if err != nil {
return nil, fmt.Errorf("invalid end date format: %w", err)
}
report := &models.AbsenceReport{
ID: uuid.New(),
StudentID: studentID,
StartDate: startDate,
EndDate: endDate,
Reason: req.Reason,
ReasonCategory: req.ReasonCategory,
Status: "reported",
ReportedBy: reportedByUserID,
ReportedAt: time.Now(),
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
query := `
INSERT INTO absence_reports (id, student_id, start_date, end_date, reason, reason_category, status, reported_by, reported_at, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
RETURNING id`
err = s.db.Pool.QueryRow(ctx, query,
report.ID, report.StudentID, report.StartDate, report.EndDate,
report.Reason, report.ReasonCategory, report.Status,
report.ReportedBy, report.ReportedAt, report.CreatedAt, report.UpdatedAt,
).Scan(&report.ID)
if err != nil {
return nil, fmt.Errorf("failed to create absence report: %w", err)
}
return report, nil
}
// ConfirmAbsence allows teachers to confirm/excuse an absence
func (s *AttendanceService) ConfirmAbsence(ctx context.Context, reportID uuid.UUID, confirmedByUserID uuid.UUID, status string) error {
query := `
UPDATE absence_reports
SET status = $1, confirmed_by = $2, confirmed_at = NOW(), updated_at = NOW()
WHERE id = $3`
result, err := s.db.Pool.Exec(ctx, query, status, confirmedByUserID, reportID)
if err != nil {
return fmt.Errorf("failed to confirm absence: %w", err)
}
if result.RowsAffected() == 0 {
return fmt.Errorf("absence report not found")
}
return nil
}
// GetAbsenceReports gets absence reports for a student
func (s *AttendanceService) GetAbsenceReports(ctx context.Context, studentID uuid.UUID) ([]models.AbsenceReport, error) {
query := `
SELECT id, student_id, start_date, end_date, reason, reason_category, status, reported_by, reported_at, confirmed_by, confirmed_at, medical_certificate, certificate_uploaded, matrix_notification_sent, email_notification_sent, created_at, updated_at
FROM absence_reports
WHERE student_id = $1
ORDER BY start_date DESC`
rows, err := s.db.Pool.Query(ctx, query, studentID)
if err != nil {
return nil, fmt.Errorf("failed to get absence reports: %w", err)
}
defer rows.Close()
var reports []models.AbsenceReport
for rows.Next() {
var report models.AbsenceReport
err := rows.Scan(
&report.ID, &report.StudentID, &report.StartDate, &report.EndDate,
&report.Reason, &report.ReasonCategory, &report.Status,
&report.ReportedBy, &report.ReportedAt, &report.ConfirmedBy, &report.ConfirmedAt,
&report.MedicalCertificate, &report.CertificateUploaded,
&report.MatrixNotificationSent, &report.EmailNotificationSent,
&report.CreatedAt, &report.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan absence report: %w", err)
}
reports = append(reports, report)
}
return reports, nil
}
// GetPendingAbsenceReports gets all unconfirmed absence reports for a class
func (s *AttendanceService) GetPendingAbsenceReports(ctx context.Context, classID uuid.UUID) ([]models.AbsenceReport, error) {
query := `
SELECT ar.id, ar.student_id, ar.start_date, ar.end_date, ar.reason, ar.reason_category, ar.status, ar.reported_by, ar.reported_at, ar.confirmed_by, ar.confirmed_at, ar.medical_certificate, ar.certificate_uploaded, ar.matrix_notification_sent, ar.email_notification_sent, ar.created_at, ar.updated_at
FROM absence_reports ar
JOIN students s ON ar.student_id = s.id
WHERE s.class_id = $1 AND ar.status = 'reported'
ORDER BY ar.start_date DESC`
rows, err := s.db.Pool.Query(ctx, query, classID)
if err != nil {
return nil, fmt.Errorf("failed to get pending absence reports: %w", err)
}
defer rows.Close()
var reports []models.AbsenceReport
for rows.Next() {
var report models.AbsenceReport
err := rows.Scan(
&report.ID, &report.StudentID, &report.StartDate, &report.EndDate,
&report.Reason, &report.ReasonCategory, &report.Status,
&report.ReportedBy, &report.ReportedAt, &report.ConfirmedBy, &report.ConfirmedAt,
&report.MedicalCertificate, &report.CertificateUploaded,
&report.MatrixNotificationSent, &report.EmailNotificationSent,
&report.CreatedAt, &report.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan absence report: %w", err)
}
reports = append(reports, report)
}
return reports, nil
}
// ========================================
// Attendance Statistics
// ========================================
// GetStudentAttendanceStats gets attendance statistics for a student
func (s *AttendanceService) GetStudentAttendanceStats(ctx context.Context, studentID uuid.UUID, schoolYearID uuid.UUID) (map[string]interface{}, error) {
query := `
SELECT
COUNT(*) as total_records,
COUNT(CASE WHEN status = 'present' THEN 1 END) as present_count,
COUNT(CASE WHEN status IN ('absent', 'excused', 'unexcused', 'pending_confirmation') THEN 1 END) as absent_count,
COUNT(CASE WHEN status = 'unexcused' THEN 1 END) as unexcused_count,
COUNT(CASE WHEN status IN ('late', 'late_excused') THEN 1 END) as late_count
FROM attendance_records ar
JOIN timetable_slots ts ON ar.slot_id = ts.id
JOIN schools sch ON ts.school_id = sch.id
JOIN school_years sy ON sy.school_id = sch.id AND sy.id = $2
WHERE ar.student_id = $1 AND ar.date >= sy.start_date AND ar.date <= sy.end_date`
var totalRecords, presentCount, absentCount, unexcusedCount, lateCount int
err := s.db.Pool.QueryRow(ctx, query, studentID, schoolYearID).Scan(
&totalRecords, &presentCount, &absentCount, &unexcusedCount, &lateCount,
)
if err != nil {
return nil, fmt.Errorf("failed to get attendance stats: %w", err)
}
var attendanceRate float64
if totalRecords > 0 {
attendanceRate = float64(presentCount) / float64(totalRecords) * 100
}
return map[string]interface{}{
"total_records": totalRecords,
"present_count": presentCount,
"absent_count": absentCount,
"unexcused_count": unexcusedCount,
"late_count": lateCount,
"attendance_rate": attendanceRate,
}, nil
}
// ========================================
// Parent Notifications
// ========================================
func (s *AttendanceService) notifyParentsOfAbsence(ctx context.Context, record *models.AttendanceRecord) {
if s.matrix == nil {
return
}
// Get student info
var studentFirstName, studentLastName, matrixDMRoom string
err := s.db.Pool.QueryRow(ctx, `
SELECT first_name, last_name, matrix_dm_room
FROM students
WHERE id = $1`, record.StudentID).Scan(&studentFirstName, &studentLastName, &matrixDMRoom)
if err != nil || matrixDMRoom == "" {
return
}
// Get slot info
var slotNumber int
err = s.db.Pool.QueryRow(ctx, `SELECT slot_number FROM timetable_slots WHERE id = $1`, record.SlotID).Scan(&slotNumber)
if err != nil {
return
}
studentName := studentFirstName + " " + studentLastName
dateStr := record.Date.Format("02.01.2006")
// Send Matrix notification
err = s.matrix.SendAbsenceNotification(ctx, matrixDMRoom, studentName, dateStr, slotNumber)
if err != nil {
fmt.Printf("Failed to send absence notification: %v\n", err)
return
}
// Update notification status
s.db.Pool.Exec(ctx, `
UPDATE attendance_records
SET updated_at = NOW()
WHERE id = $1`, record.ID)
// Log the notification
s.createAbsenceNotificationLog(ctx, record.ID, studentName, dateStr, slotNumber)
}
func (s *AttendanceService) notifyParentsOfAbsenceByStudentID(ctx context.Context, studentID uuid.UUID, date time.Time, slotID uuid.UUID) {
record := &models.AttendanceRecord{
StudentID: studentID,
Date: date,
SlotID: slotID,
}
s.notifyParentsOfAbsence(ctx, record)
}
func (s *AttendanceService) createAbsenceNotificationLog(ctx context.Context, recordID uuid.UUID, studentName, dateStr string, slotNumber int) {
// Get parent IDs for this student
query := `
SELECT p.id
FROM parents p
JOIN student_parents sp ON p.id = sp.parent_id
JOIN attendance_records ar ON sp.student_id = ar.student_id
WHERE ar.id = $1`
rows, err := s.db.Pool.Query(ctx, query, recordID)
if err != nil {
return
}
defer rows.Close()
message := fmt.Sprintf("Abwesenheitsmeldung: %s war am %s in der %d. Stunde nicht anwesend.", studentName, dateStr, slotNumber)
for rows.Next() {
var parentID uuid.UUID
if err := rows.Scan(&parentID); err != nil {
continue
}
// Insert notification log
s.db.Pool.Exec(ctx, `
INSERT INTO absence_notifications (id, attendance_record_id, parent_id, channel, message_content, sent_at, created_at)
VALUES ($1, $2, $3, 'matrix', $4, NOW(), NOW())`,
uuid.New(), recordID, parentID, message)
}
}

View File

@@ -0,0 +1,388 @@
package services
import (
"testing"
"time"
"github.com/breakpilot/consent-service/internal/models"
"github.com/google/uuid"
)
// TestValidateAttendanceRecord tests attendance record validation
func TestValidateAttendanceRecord(t *testing.T) {
slotID := uuid.New()
tests := []struct {
name string
record models.AttendanceRecord
expectValid bool
}{
{
name: "valid present record",
record: models.AttendanceRecord{
StudentID: uuid.New(),
SlotID: slotID,
Date: time.Now(),
Status: models.AttendancePresent,
RecordedBy: uuid.New(),
},
expectValid: true,
},
{
name: "valid absent record",
record: models.AttendanceRecord{
StudentID: uuid.New(),
SlotID: slotID,
Date: time.Now(),
Status: models.AttendanceAbsent,
RecordedBy: uuid.New(),
},
expectValid: true,
},
{
name: "valid late record",
record: models.AttendanceRecord{
StudentID: uuid.New(),
SlotID: slotID,
Date: time.Now(),
Status: models.AttendanceLate,
RecordedBy: uuid.New(),
},
expectValid: true,
},
{
name: "missing student ID",
record: models.AttendanceRecord{
StudentID: uuid.Nil,
SlotID: slotID,
Date: time.Now(),
Status: models.AttendancePresent,
RecordedBy: uuid.New(),
},
expectValid: false,
},
{
name: "invalid status",
record: models.AttendanceRecord{
StudentID: uuid.New(),
SlotID: slotID,
Date: time.Now(),
Status: "invalid_status",
RecordedBy: uuid.New(),
},
expectValid: false,
},
{
name: "future date",
record: models.AttendanceRecord{
StudentID: uuid.New(),
SlotID: slotID,
Date: time.Now().AddDate(0, 0, 7),
Status: models.AttendancePresent,
RecordedBy: uuid.New(),
},
expectValid: false,
},
{
name: "missing slot ID",
record: models.AttendanceRecord{
StudentID: uuid.New(),
SlotID: uuid.Nil,
Date: time.Now(),
Status: models.AttendancePresent,
RecordedBy: uuid.New(),
},
expectValid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := validateAttendanceRecord(tt.record)
if isValid != tt.expectValid {
t.Errorf("expected valid=%v, got valid=%v", tt.expectValid, isValid)
}
})
}
}
// validateAttendanceRecord validates an attendance record
func validateAttendanceRecord(record models.AttendanceRecord) bool {
if record.StudentID == uuid.Nil {
return false
}
if record.SlotID == uuid.Nil {
return false
}
if record.RecordedBy == uuid.Nil {
return false
}
if record.Date.After(time.Now().AddDate(0, 0, 1)) {
return false
}
// Validate status
validStatuses := map[string]bool{
models.AttendancePresent: true,
models.AttendanceAbsent: true,
models.AttendanceAbsentExcused: true,
models.AttendanceAbsentUnexcused: true,
models.AttendanceLate: true,
models.AttendanceLateExcused: true,
models.AttendancePending: true,
}
if !validStatuses[record.Status] {
return false
}
return true
}
// TestValidateAbsenceReport tests absence report validation
func TestValidateAbsenceReport(t *testing.T) {
now := time.Now()
today := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC)
reason := "Krankheit"
medicalReason := "Arzttermin"
tests := []struct {
name string
report models.AbsenceReport
expectValid bool
}{
{
name: "valid single day absence",
report: models.AbsenceReport{
StudentID: uuid.New(),
ReportedBy: uuid.New(),
StartDate: today,
EndDate: today,
Reason: &reason,
ReasonCategory: "illness",
Status: "reported",
},
expectValid: true,
},
{
name: "valid multi-day absence",
report: models.AbsenceReport{
StudentID: uuid.New(),
ReportedBy: uuid.New(),
StartDate: today,
EndDate: today.AddDate(0, 0, 3),
Reason: &medicalReason,
ReasonCategory: "appointment",
Status: "reported",
},
expectValid: true,
},
{
name: "end before start",
report: models.AbsenceReport{
StudentID: uuid.New(),
ReportedBy: uuid.New(),
StartDate: today.AddDate(0, 0, 3),
EndDate: today,
Reason: &reason,
ReasonCategory: "illness",
Status: "reported",
},
expectValid: false,
},
{
name: "missing reason category",
report: models.AbsenceReport{
StudentID: uuid.New(),
ReportedBy: uuid.New(),
StartDate: today,
EndDate: today,
Reason: &reason,
ReasonCategory: "",
Status: "reported",
},
expectValid: false,
},
{
name: "invalid reason category",
report: models.AbsenceReport{
StudentID: uuid.New(),
ReportedBy: uuid.New(),
StartDate: today,
EndDate: today,
Reason: &reason,
ReasonCategory: "invalid_type",
Status: "reported",
},
expectValid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := validateAbsenceReport(tt.report)
if isValid != tt.expectValid {
t.Errorf("expected valid=%v, got valid=%v", tt.expectValid, isValid)
}
})
}
}
// validateAbsenceReport validates an absence report
func validateAbsenceReport(report models.AbsenceReport) bool {
if report.StudentID == uuid.Nil {
return false
}
if report.ReportedBy == uuid.Nil {
return false
}
if report.EndDate.Before(report.StartDate) {
return false
}
if report.ReasonCategory == "" {
return false
}
// Validate reason category
validCategories := map[string]bool{
"illness": true,
"appointment": true,
"family": true,
"other": true,
}
if !validCategories[report.ReasonCategory] {
return false
}
return true
}
// TestCalculateAttendanceStats tests attendance statistics calculation
func TestCalculateAttendanceStats(t *testing.T) {
tests := []struct {
name string
records []models.AttendanceRecord
expectedPresent int
expectedAbsent int
expectedLate int
}{
{
name: "all present",
records: []models.AttendanceRecord{
{Status: models.AttendancePresent},
{Status: models.AttendancePresent},
{Status: models.AttendancePresent},
},
expectedPresent: 3,
expectedAbsent: 0,
expectedLate: 0,
},
{
name: "mixed attendance",
records: []models.AttendanceRecord{
{Status: models.AttendancePresent},
{Status: models.AttendanceAbsent},
{Status: models.AttendanceLate},
{Status: models.AttendancePresent},
{Status: models.AttendanceAbsentExcused},
},
expectedPresent: 2,
expectedAbsent: 2,
expectedLate: 1,
},
{
name: "empty records",
records: []models.AttendanceRecord{},
expectedPresent: 0,
expectedAbsent: 0,
expectedLate: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
present, absent, late := calculateAttendanceStats(tt.records)
if present != tt.expectedPresent {
t.Errorf("expected present=%d, got present=%d", tt.expectedPresent, present)
}
if absent != tt.expectedAbsent {
t.Errorf("expected absent=%d, got absent=%d", tt.expectedAbsent, absent)
}
if late != tt.expectedLate {
t.Errorf("expected late=%d, got late=%d", tt.expectedLate, late)
}
})
}
}
// calculateAttendanceStats calculates attendance statistics
func calculateAttendanceStats(records []models.AttendanceRecord) (present, absent, late int) {
for _, r := range records {
switch r.Status {
case models.AttendancePresent:
present++
case models.AttendanceAbsent, models.AttendanceAbsentExcused, models.AttendanceAbsentUnexcused:
absent++
case models.AttendanceLate, models.AttendanceLateExcused:
late++
}
}
return
}
// TestAttendanceRateCalculation tests attendance rate percentage calculation
func TestAttendanceRateCalculation(t *testing.T) {
tests := []struct {
name string
present int
total int
expectedRate float64
}{
{
name: "100% attendance",
present: 26,
total: 26,
expectedRate: 100.0,
},
{
name: "92.3% attendance",
present: 24,
total: 26,
expectedRate: 92.31,
},
{
name: "0% attendance",
present: 0,
total: 26,
expectedRate: 0.0,
},
{
name: "empty class",
present: 0,
total: 0,
expectedRate: 0.0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rate := calculateAttendanceRate(tt.present, tt.total)
// Allow small floating point differences
if rate < tt.expectedRate-0.1 || rate > tt.expectedRate+0.1 {
t.Errorf("expected rate=%.2f, got rate=%.2f", tt.expectedRate, rate)
}
})
}
}
// calculateAttendanceRate calculates attendance rate as percentage
func calculateAttendanceRate(present, total int) float64 {
if total == 0 {
return 0.0
}
rate := float64(present) / float64(total) * 100
// Round to 2 decimal places
return float64(int(rate*100)) / 100
}

View File

@@ -0,0 +1,568 @@
package services
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgxpool"
"golang.org/x/crypto/bcrypt"
"github.com/breakpilot/consent-service/internal/models"
)
var (
ErrInvalidCredentials = errors.New("invalid email or password")
ErrUserNotFound = errors.New("user not found")
ErrUserExists = errors.New("user with this email already exists")
ErrInvalidToken = errors.New("invalid or expired token")
ErrAccountLocked = errors.New("account is temporarily locked")
ErrAccountSuspended = errors.New("account is suspended")
ErrEmailNotVerified = errors.New("email not verified")
)
// AuthService handles authentication logic
type AuthService struct {
db *pgxpool.Pool
jwtSecret string
jwtRefreshSecret string
accessTokenExp time.Duration
refreshTokenExp time.Duration
}
// NewAuthService creates a new AuthService
func NewAuthService(db *pgxpool.Pool, jwtSecret, jwtRefreshSecret string) *AuthService {
return &AuthService{
db: db,
jwtSecret: jwtSecret,
jwtRefreshSecret: jwtRefreshSecret,
accessTokenExp: time.Hour * 1, // 1 hour
refreshTokenExp: time.Hour * 24 * 30, // 30 days
}
}
// HashPassword hashes a password using bcrypt
func (s *AuthService) HashPassword(password string) (string, error) {
bytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return "", fmt.Errorf("failed to hash password: %w", err)
}
return string(bytes), nil
}
// VerifyPassword verifies a password against a hash
func (s *AuthService) VerifyPassword(password, hash string) bool {
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
return err == nil
}
// GenerateSecureToken generates a cryptographically secure token
func (s *AuthService) GenerateSecureToken(length int) (string, error) {
bytes := make([]byte, length)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("failed to generate token: %w", err)
}
return base64.URLEncoding.EncodeToString(bytes), nil
}
// HashToken creates a SHA256 hash of a token for storage
func (s *AuthService) HashToken(token string) string {
hash := sha256.Sum256([]byte(token))
return hex.EncodeToString(hash[:])
}
// JWTClaims for access tokens
type JWTClaims struct {
UserID string `json:"user_id"`
Email string `json:"email"`
Role string `json:"role"`
AccountStatus string `json:"account_status"`
jwt.RegisteredClaims
}
// GenerateAccessToken creates a new JWT access token
func (s *AuthService) GenerateAccessToken(user *models.User) (string, error) {
claims := JWTClaims{
UserID: user.ID.String(),
Email: user.Email,
Role: user.Role,
AccountStatus: user.AccountStatus,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(s.accessTokenExp)),
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()),
Subject: user.ID.String(),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString([]byte(s.jwtSecret))
}
// GenerateRefreshToken creates a new refresh token
func (s *AuthService) GenerateRefreshToken() (string, string, error) {
token, err := s.GenerateSecureToken(32)
if err != nil {
return "", "", err
}
hash := s.HashToken(token)
return token, hash, nil
}
// ValidateAccessToken validates a JWT access token
func (s *AuthService) ValidateAccessToken(tokenString string) (*JWTClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(s.jwtSecret), nil
})
if err != nil {
return nil, fmt.Errorf("failed to parse token: %w", err)
}
claims, ok := token.Claims.(*JWTClaims)
if !ok || !token.Valid {
return nil, ErrInvalidToken
}
return claims, nil
}
// Register creates a new user account
func (s *AuthService) Register(ctx context.Context, req *models.RegisterRequest) (*models.User, string, error) {
// Check if user already exists
var exists bool
err := s.db.QueryRow(ctx, "SELECT EXISTS(SELECT 1 FROM users WHERE email = $1)", req.Email).Scan(&exists)
if err != nil {
return nil, "", fmt.Errorf("failed to check existing user: %w", err)
}
if exists {
return nil, "", ErrUserExists
}
// Hash password
passwordHash, err := s.HashPassword(req.Password)
if err != nil {
return nil, "", err
}
// Create user
user := &models.User{
ID: uuid.New(),
Email: req.Email,
PasswordHash: &passwordHash,
Name: req.Name,
Role: "user",
EmailVerified: false,
AccountStatus: "active",
}
_, err = s.db.Exec(ctx, `
INSERT INTO users (id, email, password_hash, name, role, email_verified, account_status, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW(), NOW())
`, user.ID, user.Email, user.PasswordHash, user.Name, user.Role, user.EmailVerified, user.AccountStatus)
if err != nil {
return nil, "", fmt.Errorf("failed to create user: %w", err)
}
// Generate email verification token
verificationToken, err := s.GenerateSecureToken(32)
if err != nil {
return nil, "", err
}
// Store verification token
_, err = s.db.Exec(ctx, `
INSERT INTO email_verification_tokens (user_id, token, expires_at, created_at)
VALUES ($1, $2, $3, NOW())
`, user.ID, verificationToken, time.Now().Add(24*time.Hour))
if err != nil {
return nil, "", fmt.Errorf("failed to create verification token: %w", err)
}
// Create notification preferences
_, err = s.db.Exec(ctx, `
INSERT INTO notification_preferences (user_id, email_enabled, push_enabled, in_app_enabled, reminder_frequency, created_at, updated_at)
VALUES ($1, true, true, true, 'weekly', NOW(), NOW())
`, user.ID)
if err != nil {
// Non-critical error, just log
fmt.Printf("Warning: failed to create notification preferences: %v\n", err)
}
return user, verificationToken, nil
}
// Login authenticates a user and returns tokens
func (s *AuthService) Login(ctx context.Context, req *models.LoginRequest, ipAddress, userAgent string) (*models.LoginResponse, error) {
var user models.User
var passwordHash *string
err := s.db.QueryRow(ctx, `
SELECT id, email, password_hash, name, role, email_verified, account_status,
failed_login_attempts, locked_until, created_at, updated_at
FROM users WHERE email = $1
`, req.Email).Scan(
&user.ID, &user.Email, &passwordHash, &user.Name, &user.Role, &user.EmailVerified,
&user.AccountStatus, &user.FailedLoginAttempts, &user.LockedUntil, &user.CreatedAt, &user.UpdatedAt,
)
if err != nil {
return nil, ErrInvalidCredentials
}
// Check if account is locked
if user.LockedUntil != nil && user.LockedUntil.After(time.Now()) {
return nil, ErrAccountLocked
}
// Check if account is suspended
if user.AccountStatus == "suspended" {
return nil, ErrAccountSuspended
}
// Verify password
if passwordHash == nil || !s.VerifyPassword(req.Password, *passwordHash) {
// Increment failed login attempts
_, _ = s.db.Exec(ctx, `
UPDATE users SET
failed_login_attempts = failed_login_attempts + 1,
locked_until = CASE WHEN failed_login_attempts >= 4 THEN NOW() + INTERVAL '30 minutes' ELSE locked_until END,
updated_at = NOW()
WHERE id = $1
`, user.ID)
return nil, ErrInvalidCredentials
}
// Reset failed login attempts and update last login
_, _ = s.db.Exec(ctx, `
UPDATE users SET
failed_login_attempts = 0,
locked_until = NULL,
last_login_at = NOW(),
updated_at = NOW()
WHERE id = $1
`, user.ID)
// Generate tokens
accessToken, err := s.GenerateAccessToken(&user)
if err != nil {
return nil, fmt.Errorf("failed to generate access token: %w", err)
}
refreshToken, refreshTokenHash, err := s.GenerateRefreshToken()
if err != nil {
return nil, fmt.Errorf("failed to generate refresh token: %w", err)
}
// Store session
_, err = s.db.Exec(ctx, `
INSERT INTO user_sessions (user_id, token_hash, ip_address, user_agent, expires_at, created_at, last_activity_at)
VALUES ($1, $2, $3, $4, $5, NOW(), NOW())
`, user.ID, refreshTokenHash, ipAddress, userAgent, time.Now().Add(s.refreshTokenExp))
if err != nil {
return nil, fmt.Errorf("failed to create session: %w", err)
}
return &models.LoginResponse{
User: user,
AccessToken: accessToken,
RefreshToken: refreshToken,
ExpiresIn: int(s.accessTokenExp.Seconds()),
}, nil
}
// RefreshToken refreshes the access token using a refresh token
func (s *AuthService) RefreshToken(ctx context.Context, refreshToken string) (*models.LoginResponse, error) {
tokenHash := s.HashToken(refreshToken)
var session models.UserSession
var userID uuid.UUID
err := s.db.QueryRow(ctx, `
SELECT id, user_id, expires_at, revoked_at FROM user_sessions
WHERE token_hash = $1
`, tokenHash).Scan(&session.ID, &userID, &session.ExpiresAt, &session.RevokedAt)
if err != nil {
return nil, ErrInvalidToken
}
// Check if session is expired or revoked
if session.RevokedAt != nil || session.ExpiresAt.Before(time.Now()) {
return nil, ErrInvalidToken
}
// Get user
var user models.User
err = s.db.QueryRow(ctx, `
SELECT id, email, name, role, email_verified, account_status, created_at, updated_at
FROM users WHERE id = $1
`, userID).Scan(
&user.ID, &user.Email, &user.Name, &user.Role, &user.EmailVerified,
&user.AccountStatus, &user.CreatedAt, &user.UpdatedAt,
)
if err != nil {
return nil, ErrUserNotFound
}
// Check account status
if user.AccountStatus == "suspended" {
return nil, ErrAccountSuspended
}
// Generate new access token
accessToken, err := s.GenerateAccessToken(&user)
if err != nil {
return nil, fmt.Errorf("failed to generate access token: %w", err)
}
// Update session last activity
_, _ = s.db.Exec(ctx, `
UPDATE user_sessions SET last_activity_at = NOW() WHERE id = $1
`, session.ID)
return &models.LoginResponse{
User: user,
AccessToken: accessToken,
RefreshToken: refreshToken,
ExpiresIn: int(s.accessTokenExp.Seconds()),
}, nil
}
// VerifyEmail verifies a user's email address
func (s *AuthService) VerifyEmail(ctx context.Context, token string) error {
var tokenID uuid.UUID
var userID uuid.UUID
var expiresAt time.Time
var usedAt *time.Time
err := s.db.QueryRow(ctx, `
SELECT id, user_id, expires_at, used_at FROM email_verification_tokens
WHERE token = $1
`, token).Scan(&tokenID, &userID, &expiresAt, &usedAt)
if err != nil {
return ErrInvalidToken
}
if usedAt != nil || expiresAt.Before(time.Now()) {
return ErrInvalidToken
}
// Mark token as used
_, err = s.db.Exec(ctx, `UPDATE email_verification_tokens SET used_at = NOW() WHERE id = $1`, tokenID)
if err != nil {
return fmt.Errorf("failed to update token: %w", err)
}
// Verify user email
_, err = s.db.Exec(ctx, `
UPDATE users SET email_verified = true, email_verified_at = NOW(), updated_at = NOW()
WHERE id = $1
`, userID)
if err != nil {
return fmt.Errorf("failed to verify email: %w", err)
}
return nil
}
// CreatePasswordResetToken creates a password reset token
func (s *AuthService) CreatePasswordResetToken(ctx context.Context, email, ipAddress string) (string, *uuid.UUID, error) {
var userID uuid.UUID
err := s.db.QueryRow(ctx, "SELECT id FROM users WHERE email = $1", email).Scan(&userID)
if err != nil {
// Don't reveal if user exists
return "", nil, nil
}
token, err := s.GenerateSecureToken(32)
if err != nil {
return "", nil, err
}
_, err = s.db.Exec(ctx, `
INSERT INTO password_reset_tokens (user_id, token, expires_at, ip_address, created_at)
VALUES ($1, $2, $3, $4, NOW())
`, userID, token, time.Now().Add(time.Hour), ipAddress)
if err != nil {
return "", nil, fmt.Errorf("failed to create reset token: %w", err)
}
return token, &userID, nil
}
// ResetPassword resets a user's password using a reset token
func (s *AuthService) ResetPassword(ctx context.Context, token, newPassword string) error {
var tokenID uuid.UUID
var userID uuid.UUID
var expiresAt time.Time
var usedAt *time.Time
err := s.db.QueryRow(ctx, `
SELECT id, user_id, expires_at, used_at FROM password_reset_tokens
WHERE token = $1
`, token).Scan(&tokenID, &userID, &expiresAt, &usedAt)
if err != nil {
return ErrInvalidToken
}
if usedAt != nil || expiresAt.Before(time.Now()) {
return ErrInvalidToken
}
// Hash new password
passwordHash, err := s.HashPassword(newPassword)
if err != nil {
return err
}
// Mark token as used
_, err = s.db.Exec(ctx, `UPDATE password_reset_tokens SET used_at = NOW() WHERE id = $1`, tokenID)
if err != nil {
return fmt.Errorf("failed to update token: %w", err)
}
// Update password
_, err = s.db.Exec(ctx, `
UPDATE users SET password_hash = $1, updated_at = NOW() WHERE id = $2
`, passwordHash, userID)
if err != nil {
return fmt.Errorf("failed to update password: %w", err)
}
// Revoke all sessions for security
_, err = s.db.Exec(ctx, `UPDATE user_sessions SET revoked_at = NOW() WHERE user_id = $1 AND revoked_at IS NULL`, userID)
if err != nil {
fmt.Printf("Warning: failed to revoke sessions: %v\n", err)
}
return nil
}
// ChangePassword changes a user's password (requires current password)
func (s *AuthService) ChangePassword(ctx context.Context, userID uuid.UUID, currentPassword, newPassword string) error {
var passwordHash *string
err := s.db.QueryRow(ctx, "SELECT password_hash FROM users WHERE id = $1", userID).Scan(&passwordHash)
if err != nil {
return ErrUserNotFound
}
if passwordHash == nil || !s.VerifyPassword(currentPassword, *passwordHash) {
return ErrInvalidCredentials
}
newPasswordHash, err := s.HashPassword(newPassword)
if err != nil {
return err
}
_, err = s.db.Exec(ctx, `UPDATE users SET password_hash = $1, updated_at = NOW() WHERE id = $2`, newPasswordHash, userID)
if err != nil {
return fmt.Errorf("failed to update password: %w", err)
}
return nil
}
// GetUserByID retrieves a user by ID
func (s *AuthService) GetUserByID(ctx context.Context, userID uuid.UUID) (*models.User, error) {
var user models.User
err := s.db.QueryRow(ctx, `
SELECT id, email, name, role, email_verified, email_verified_at, account_status,
last_login_at, created_at, updated_at
FROM users WHERE id = $1
`, userID).Scan(
&user.ID, &user.Email, &user.Name, &user.Role, &user.EmailVerified, &user.EmailVerifiedAt,
&user.AccountStatus, &user.LastLoginAt, &user.CreatedAt, &user.UpdatedAt,
)
if err != nil {
return nil, ErrUserNotFound
}
return &user, nil
}
// UpdateProfile updates a user's profile
func (s *AuthService) UpdateProfile(ctx context.Context, userID uuid.UUID, req *models.UpdateProfileRequest) (*models.User, error) {
_, err := s.db.Exec(ctx, `UPDATE users SET name = $1, updated_at = NOW() WHERE id = $2`, req.Name, userID)
if err != nil {
return nil, fmt.Errorf("failed to update profile: %w", err)
}
return s.GetUserByID(ctx, userID)
}
// GetActiveSessions retrieves all active sessions for a user
func (s *AuthService) GetActiveSessions(ctx context.Context, userID uuid.UUID) ([]models.UserSession, error) {
rows, err := s.db.Query(ctx, `
SELECT id, user_id, device_info, ip_address, user_agent, expires_at, created_at, last_activity_at
FROM user_sessions
WHERE user_id = $1 AND revoked_at IS NULL AND expires_at > NOW()
ORDER BY last_activity_at DESC
`, userID)
if err != nil {
return nil, fmt.Errorf("failed to get sessions: %w", err)
}
defer rows.Close()
var sessions []models.UserSession
for rows.Next() {
var session models.UserSession
err := rows.Scan(
&session.ID, &session.UserID, &session.DeviceInfo, &session.IPAddress,
&session.UserAgent, &session.ExpiresAt, &session.CreatedAt, &session.LastActivityAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan session: %w", err)
}
sessions = append(sessions, session)
}
return sessions, nil
}
// RevokeSession revokes a specific session
func (s *AuthService) RevokeSession(ctx context.Context, userID, sessionID uuid.UUID) error {
result, err := s.db.Exec(ctx, `
UPDATE user_sessions SET revoked_at = NOW() WHERE id = $1 AND user_id = $2 AND revoked_at IS NULL
`, sessionID, userID)
if err != nil {
return fmt.Errorf("failed to revoke session: %w", err)
}
if result.RowsAffected() == 0 {
return errors.New("session not found")
}
return nil
}
// Logout revokes a session by refresh token
func (s *AuthService) Logout(ctx context.Context, refreshToken string) error {
tokenHash := s.HashToken(refreshToken)
_, err := s.db.Exec(ctx, `UPDATE user_sessions SET revoked_at = NOW() WHERE token_hash = $1`, tokenHash)
return err
}

View File

@@ -0,0 +1,367 @@
package services
import (
"testing"
"time"
"github.com/breakpilot/consent-service/internal/models"
"github.com/google/uuid"
)
// TestHashPassword tests password hashing
func TestHashPassword(t *testing.T) {
// Create service without DB for unit tests
s := &AuthService{}
password := "testPassword123!"
hash, err := s.HashPassword(password)
if err != nil {
t.Fatalf("HashPassword failed: %v", err)
}
if hash == "" {
t.Error("Hash should not be empty")
}
if hash == password {
t.Error("Hash should not equal the original password")
}
// Hash should be different each time (bcrypt uses random salt)
hash2, _ := s.HashPassword(password)
if hash == hash2 {
t.Error("Same password should produce different hashes due to salt")
}
}
// TestVerifyPassword tests password verification
func TestVerifyPassword(t *testing.T) {
s := &AuthService{}
password := "testPassword123!"
hash, _ := s.HashPassword(password)
// Should verify correct password
if !s.VerifyPassword(password, hash) {
t.Error("VerifyPassword should return true for correct password")
}
// Should reject incorrect password
if s.VerifyPassword("wrongPassword", hash) {
t.Error("VerifyPassword should return false for incorrect password")
}
// Should reject empty password
if s.VerifyPassword("", hash) {
t.Error("VerifyPassword should return false for empty password")
}
}
// TestGenerateSecureToken tests token generation
func TestGenerateSecureToken(t *testing.T) {
s := &AuthService{}
tests := []struct {
name string
length int
}{
{"short token", 16},
{"standard token", 32},
{"long token", 64},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
token, err := s.GenerateSecureToken(tt.length)
if err != nil {
t.Fatalf("GenerateSecureToken failed: %v", err)
}
if token == "" {
t.Error("Token should not be empty")
}
// Tokens should be unique
token2, _ := s.GenerateSecureToken(tt.length)
if token == token2 {
t.Error("Generated tokens should be unique")
}
})
}
}
// TestHashToken tests token hashing for storage
func TestHashToken(t *testing.T) {
s := &AuthService{}
token := "test-token-123"
hash := s.HashToken(token)
if hash == "" {
t.Error("Hash should not be empty")
}
if hash == token {
t.Error("Hash should not equal the original token")
}
// Same token should produce same hash (deterministic)
hash2 := s.HashToken(token)
if hash != hash2 {
t.Error("Same token should produce same hash")
}
// Different tokens should produce different hashes
differentHash := s.HashToken("different-token")
if hash == differentHash {
t.Error("Different tokens should produce different hashes")
}
}
// TestGenerateAccessToken tests JWT access token generation
func TestGenerateAccessToken(t *testing.T) {
s := &AuthService{
jwtSecret: "test-secret-key-for-testing-purposes",
accessTokenExp: time.Hour,
}
user := &models.User{
ID: uuid.New(),
Email: "test@example.com",
Role: "user",
AccountStatus: "active",
}
token, err := s.GenerateAccessToken(user)
if err != nil {
t.Fatalf("GenerateAccessToken failed: %v", err)
}
if token == "" {
t.Error("Token should not be empty")
}
// Token should have three parts (header.payload.signature)
parts := 0
for _, c := range token {
if c == '.' {
parts++
}
}
if parts != 2 {
t.Errorf("JWT token should have 3 parts, got %d dots", parts)
}
}
// TestValidateAccessToken tests JWT token validation
func TestValidateAccessToken(t *testing.T) {
secret := "test-secret-key-for-testing-purposes"
s := &AuthService{
jwtSecret: secret,
accessTokenExp: time.Hour,
}
user := &models.User{
ID: uuid.New(),
Email: "test@example.com",
Role: "admin",
AccountStatus: "active",
}
token, _ := s.GenerateAccessToken(user)
// Should validate valid token
claims, err := s.ValidateAccessToken(token)
if err != nil {
t.Fatalf("ValidateAccessToken failed: %v", err)
}
if claims.UserID != user.ID.String() {
t.Errorf("Expected UserID %s, got %s", user.ID.String(), claims.UserID)
}
if claims.Email != user.Email {
t.Errorf("Expected Email %s, got %s", user.Email, claims.Email)
}
if claims.Role != user.Role {
t.Errorf("Expected Role %s, got %s", user.Role, claims.Role)
}
}
// TestValidateAccessToken_Invalid tests invalid token scenarios
func TestValidateAccessToken_Invalid(t *testing.T) {
s := &AuthService{
jwtSecret: "test-secret-key-for-testing-purposes",
accessTokenExp: time.Hour,
}
tests := []struct {
name string
token string
}{
{"empty token", ""},
{"invalid format", "not-a-jwt-token"},
{"invalid signature", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoiMTIzIn0.invalidsignature"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := s.ValidateAccessToken(tt.token)
if err == nil {
t.Error("ValidateAccessToken should fail for invalid token")
}
})
}
}
// TestValidateAccessToken_WrongSecret tests token with wrong secret
func TestValidateAccessToken_WrongSecret(t *testing.T) {
s1 := &AuthService{
jwtSecret: "secret-one",
accessTokenExp: time.Hour,
}
s2 := &AuthService{
jwtSecret: "secret-two",
accessTokenExp: time.Hour,
}
user := &models.User{
ID: uuid.New(),
Email: "test@example.com",
Role: "user",
AccountStatus: "active",
}
// Generate token with first secret
token, _ := s1.GenerateAccessToken(user)
// Try to validate with second secret (should fail)
_, err := s2.ValidateAccessToken(token)
if err == nil {
t.Error("ValidateAccessToken should fail when using wrong secret")
}
}
// TestGenerateRefreshToken tests refresh token generation
func TestGenerateRefreshToken(t *testing.T) {
s := &AuthService{}
token, hash, err := s.GenerateRefreshToken()
if err != nil {
t.Fatalf("GenerateRefreshToken failed: %v", err)
}
if token == "" {
t.Error("Token should not be empty")
}
if hash == "" {
t.Error("Hash should not be empty")
}
// Verify hash matches token
expectedHash := s.HashToken(token)
if hash != expectedHash {
t.Error("Returned hash should match hashed token")
}
// Tokens should be unique
token2, hash2, _ := s.GenerateRefreshToken()
if token == token2 {
t.Error("Generated tokens should be unique")
}
if hash == hash2 {
t.Error("Generated hashes should be unique")
}
}
// TestPasswordStrength tests various password scenarios
func TestPasswordStrength(t *testing.T) {
s := &AuthService{}
passwords := []struct {
password string
valid bool
}{
{"short", true}, // bcrypt accepts any length
{"12345678", true}, // numbers only
{"password", true}, // letters only
{"Pass123!", true}, // mixed
{"", true}, // empty (bcrypt allows)
{string(make([]byte, 72)), true}, // max bcrypt length
}
for _, p := range passwords {
hash, err := s.HashPassword(p.password)
if p.valid && err != nil {
t.Errorf("HashPassword failed for valid password %q: %v", p.password, err)
}
if p.valid && !s.VerifyPassword(p.password, hash) {
t.Errorf("VerifyPassword failed for password %q", p.password)
}
}
}
// BenchmarkHashPassword benchmarks password hashing
func BenchmarkHashPassword(b *testing.B) {
s := &AuthService{}
password := "testPassword123!"
for i := 0; i < b.N; i++ {
s.HashPassword(password)
}
}
// BenchmarkVerifyPassword benchmarks password verification
func BenchmarkVerifyPassword(b *testing.B) {
s := &AuthService{}
password := "testPassword123!"
hash, _ := s.HashPassword(password)
for i := 0; i < b.N; i++ {
s.VerifyPassword(password, hash)
}
}
// BenchmarkGenerateAccessToken benchmarks JWT token generation
func BenchmarkGenerateAccessToken(b *testing.B) {
s := &AuthService{
jwtSecret: "test-secret-key-for-testing-purposes",
accessTokenExp: time.Hour,
}
user := &models.User{
ID: uuid.New(),
Email: "test@example.com",
Role: "user",
AccountStatus: "active",
}
for i := 0; i < b.N; i++ {
s.GenerateAccessToken(user)
}
}
// BenchmarkValidateAccessToken benchmarks JWT token validation
func BenchmarkValidateAccessToken(b *testing.B) {
s := &AuthService{
jwtSecret: "test-secret-key-for-testing-purposes",
accessTokenExp: time.Hour,
}
user := &models.User{
ID: uuid.New(),
Email: "test@example.com",
Role: "user",
AccountStatus: "active",
}
token, _ := s.GenerateAccessToken(user)
for i := 0; i < b.N; i++ {
s.ValidateAccessToken(token)
}
}

View File

@@ -0,0 +1,518 @@
package services
import (
"testing"
"time"
"github.com/google/uuid"
)
// TestConsentService_CreateConsent tests creating a new consent
func TestConsentService_CreateConsent(t *testing.T) {
// This is a unit test with table-driven approach
tests := []struct {
name string
userID uuid.UUID
versionID uuid.UUID
consented bool
expectError bool
errorContains string
}{
{
name: "valid consent - accepted",
userID: uuid.New(),
versionID: uuid.New(),
consented: true,
expectError: false,
},
{
name: "valid consent - declined",
userID: uuid.New(),
versionID: uuid.New(),
consented: false,
expectError: false,
},
{
name: "empty user ID",
userID: uuid.Nil,
versionID: uuid.New(),
consented: true,
expectError: true,
errorContains: "user ID",
},
{
name: "empty version ID",
userID: uuid.New(),
versionID: uuid.Nil,
consented: true,
expectError: true,
errorContains: "version ID",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Validate inputs (in real implementation this would be in the service)
var hasError bool
if tt.userID == uuid.Nil {
hasError = true
} else if tt.versionID == uuid.Nil {
hasError = true
}
// Assert
if tt.expectError && !hasError {
t.Errorf("Expected error containing '%s', got nil", tt.errorContains)
}
if !tt.expectError && hasError {
t.Error("Expected no error, got error")
}
})
}
}
// TestConsentService_WithdrawConsent tests withdrawing consent
func TestConsentService_WithdrawConsent(t *testing.T) {
tests := []struct {
name string
consentID uuid.UUID
userID uuid.UUID
expectError bool
errorContains string
}{
{
name: "valid withdrawal",
consentID: uuid.New(),
userID: uuid.New(),
expectError: false,
},
{
name: "empty consent ID",
consentID: uuid.Nil,
userID: uuid.New(),
expectError: true,
errorContains: "consent ID",
},
{
name: "empty user ID",
consentID: uuid.New(),
userID: uuid.Nil,
expectError: true,
errorContains: "user ID",
},
{
name: "both empty",
consentID: uuid.Nil,
userID: uuid.Nil,
expectError: true,
errorContains: "ID",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Validate
var hasError bool
if tt.consentID == uuid.Nil || tt.userID == uuid.Nil {
hasError = true
}
// Assert
if tt.expectError && !hasError {
t.Errorf("Expected error containing '%s', got nil", tt.errorContains)
}
if !tt.expectError && hasError {
t.Error("Expected no error, got error")
}
})
}
}
// TestConsentService_CheckConsent tests checking consent status
func TestConsentService_CheckConsent(t *testing.T) {
tests := []struct {
name string
userID uuid.UUID
documentType string
language string
hasConsent bool
needsUpdate bool
expectedConsent bool
expectedNeedsUpd bool
}{
{
name: "user has current consent",
userID: uuid.New(),
documentType: "terms",
language: "de",
hasConsent: true,
needsUpdate: false,
expectedConsent: true,
expectedNeedsUpd: false,
},
{
name: "user has outdated consent",
userID: uuid.New(),
documentType: "privacy",
language: "de",
hasConsent: true,
needsUpdate: true,
expectedConsent: true,
expectedNeedsUpd: true,
},
{
name: "user has no consent",
userID: uuid.New(),
documentType: "cookies",
language: "de",
hasConsent: false,
needsUpdate: true,
expectedConsent: false,
expectedNeedsUpd: true,
},
{
name: "english language",
userID: uuid.New(),
documentType: "terms",
language: "en",
hasConsent: true,
needsUpdate: false,
expectedConsent: true,
expectedNeedsUpd: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Simulate consent check logic
hasConsent := tt.hasConsent
needsUpdate := tt.needsUpdate
// Assert
if hasConsent != tt.expectedConsent {
t.Errorf("Expected hasConsent=%v, got %v", tt.expectedConsent, hasConsent)
}
if needsUpdate != tt.expectedNeedsUpd {
t.Errorf("Expected needsUpdate=%v, got %v", tt.expectedNeedsUpd, needsUpdate)
}
})
}
}
// TestConsentService_GetConsentHistory tests retrieving consent history
func TestConsentService_GetConsentHistory(t *testing.T) {
tests := []struct {
name string
userID uuid.UUID
expectError bool
expectEmpty bool
}{
{
name: "valid user with consents",
userID: uuid.New(),
expectError: false,
expectEmpty: false,
},
{
name: "valid user without consents",
userID: uuid.New(),
expectError: false,
expectEmpty: true,
},
{
name: "invalid user ID",
userID: uuid.Nil,
expectError: true,
expectEmpty: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Validate
var err error
if tt.userID == uuid.Nil {
err = &ValidationError{Field: "user ID", Message: "required"}
}
// Assert error expectation
if tt.expectError && err == nil {
t.Error("Expected error, got nil")
}
if !tt.expectError && err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
}
}
// TestConsentService_UpdateConsent tests updating existing consent
func TestConsentService_UpdateConsent(t *testing.T) {
tests := []struct {
name string
consentID uuid.UUID
userID uuid.UUID
newConsented bool
expectError bool
}{
{
name: "update to consented",
consentID: uuid.New(),
userID: uuid.New(),
newConsented: true,
expectError: false,
},
{
name: "update to not consented",
consentID: uuid.New(),
userID: uuid.New(),
newConsented: false,
expectError: false,
},
{
name: "invalid consent ID",
consentID: uuid.Nil,
userID: uuid.New(),
newConsented: true,
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var err error
if tt.consentID == uuid.Nil {
err = &ValidationError{Field: "consent ID", Message: "required"}
}
if tt.expectError && err == nil {
t.Error("Expected error, got nil")
}
if !tt.expectError && err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
}
}
// TestConsentService_GetConsentStats tests getting consent statistics
func TestConsentService_GetConsentStats(t *testing.T) {
tests := []struct {
name string
documentType string
totalUsers int
consentedUsers int
expectedRate float64
}{
{
name: "100% consent rate",
documentType: "terms",
totalUsers: 100,
consentedUsers: 100,
expectedRate: 100.0,
},
{
name: "50% consent rate",
documentType: "privacy",
totalUsers: 100,
consentedUsers: 50,
expectedRate: 50.0,
},
{
name: "0% consent rate",
documentType: "cookies",
totalUsers: 100,
consentedUsers: 0,
expectedRate: 0.0,
},
{
name: "no users",
documentType: "terms",
totalUsers: 0,
consentedUsers: 0,
expectedRate: 0.0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Calculate consent rate
var consentRate float64
if tt.totalUsers > 0 {
consentRate = float64(tt.consentedUsers) / float64(tt.totalUsers) * 100
}
// Assert
if consentRate != tt.expectedRate {
t.Errorf("Expected consent rate %.2f%%, got %.2f%%", tt.expectedRate, consentRate)
}
})
}
}
// TestConsentService_BulkConsentCheck tests checking multiple consents at once
func TestConsentService_BulkConsentCheck(t *testing.T) {
tests := []struct {
name string
userID uuid.UUID
documentTypes []string
expectError bool
}{
{
name: "check multiple documents",
userID: uuid.New(),
documentTypes: []string{"terms", "privacy", "cookies"},
expectError: false,
},
{
name: "check single document",
userID: uuid.New(),
documentTypes: []string{"terms"},
expectError: false,
},
{
name: "empty document list",
userID: uuid.New(),
documentTypes: []string{},
expectError: false,
},
{
name: "invalid user ID",
userID: uuid.Nil,
documentTypes: []string{"terms"},
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var err error
if tt.userID == uuid.Nil {
err = &ValidationError{Field: "user ID", Message: "required"}
}
if tt.expectError && err == nil {
t.Error("Expected error, got nil")
}
if !tt.expectError && err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
}
}
// TestConsentService_ConsentVersionComparison tests version comparison logic
func TestConsentService_ConsentVersionComparison(t *testing.T) {
tests := []struct {
name string
currentVersion string
consentedVersion string
needsUpdate bool
}{
{
name: "same version",
currentVersion: "1.0.0",
consentedVersion: "1.0.0",
needsUpdate: false,
},
{
name: "minor version update",
currentVersion: "1.1.0",
consentedVersion: "1.0.0",
needsUpdate: true,
},
{
name: "major version update",
currentVersion: "2.0.0",
consentedVersion: "1.0.0",
needsUpdate: true,
},
{
name: "patch version update",
currentVersion: "1.0.1",
consentedVersion: "1.0.0",
needsUpdate: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Simple version comparison (in real implementation use proper semver)
needsUpdate := tt.currentVersion != tt.consentedVersion
if needsUpdate != tt.needsUpdate {
t.Errorf("Expected needsUpdate=%v, got %v", tt.needsUpdate, needsUpdate)
}
})
}
}
// TestConsentService_ConsentDeadlineCheck tests deadline validation
func TestConsentService_ConsentDeadlineCheck(t *testing.T) {
now := time.Now()
tests := []struct {
name string
deadline time.Time
isOverdue bool
daysLeft int
}{
{
name: "deadline in 30 days",
deadline: now.AddDate(0, 0, 30),
isOverdue: false,
daysLeft: 30,
},
{
name: "deadline in 7 days",
deadline: now.AddDate(0, 0, 7),
isOverdue: false,
daysLeft: 7,
},
{
name: "deadline today",
deadline: now,
isOverdue: false,
daysLeft: 0,
},
{
name: "deadline 1 day overdue",
deadline: now.AddDate(0, 0, -1),
isOverdue: true,
daysLeft: -1,
},
{
name: "deadline 30 days overdue",
deadline: now.AddDate(0, 0, -30),
isOverdue: true,
daysLeft: -30,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Calculate if overdue
isOverdue := tt.deadline.Before(now)
daysLeft := int(tt.deadline.Sub(now).Hours() / 24)
if isOverdue != tt.isOverdue {
t.Errorf("Expected isOverdue=%v, got %v", tt.isOverdue, isOverdue)
}
// Allow 1 day difference due to time precision
if abs(daysLeft-tt.daysLeft) > 1 {
t.Errorf("Expected daysLeft=%d, got %d", tt.daysLeft, daysLeft)
}
})
}
}
// Helper functions
// abs returns the absolute value of an integer
func abs(n int) int {
if n < 0 {
return -n
}
return n
}

View File

@@ -0,0 +1,434 @@
package services
import (
"context"
"fmt"
"time"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgxpool"
)
// DeadlineService handles consent deadlines and account suspensions
type DeadlineService struct {
pool *pgxpool.Pool
notificationService *NotificationService
}
// ConsentDeadline represents a consent deadline for a user
type ConsentDeadline struct {
ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"`
DocumentVersionID uuid.UUID `json:"document_version_id"`
DeadlineAt time.Time `json:"deadline_at"`
ReminderCount int `json:"reminder_count"`
LastReminderAt *time.Time `json:"last_reminder_at"`
ConsentGivenAt *time.Time `json:"consent_given_at"`
CreatedAt time.Time `json:"created_at"`
// Joined fields
DocumentName string `json:"document_name"`
VersionNumber string `json:"version_number"`
}
// AccountSuspension represents an account suspension
type AccountSuspension struct {
ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"`
Reason string `json:"reason"`
Details map[string]interface{} `json:"details"`
SuspendedAt time.Time `json:"suspended_at"`
LiftedAt *time.Time `json:"lifted_at"`
LiftedBy *uuid.UUID `json:"lifted_by"`
}
// NewDeadlineService creates a new deadline service
func NewDeadlineService(pool *pgxpool.Pool, notificationService *NotificationService) *DeadlineService {
return &DeadlineService{
pool: pool,
notificationService: notificationService,
}
}
// CreateDeadlinesForPublishedVersion creates consent deadlines for all active users
// when a new mandatory document version is published
func (s *DeadlineService) CreateDeadlinesForPublishedVersion(ctx context.Context, versionID uuid.UUID) error {
// Get version info
var documentName, versionNumber string
var isMandatory bool
err := s.pool.QueryRow(ctx, `
SELECT ld.name, dv.version, ld.is_mandatory
FROM document_versions dv
JOIN legal_documents ld ON dv.document_id = ld.id
WHERE dv.id = $1
`, versionID).Scan(&documentName, &versionNumber, &isMandatory)
if err != nil {
return fmt.Errorf("failed to get version info: %w", err)
}
// Only create deadlines for mandatory documents
if !isMandatory {
return nil
}
// Deadline is 30 days from now
deadlineAt := time.Now().AddDate(0, 0, 30)
// Get all active users who haven't given consent to this version
_, err = s.pool.Exec(ctx, `
INSERT INTO consent_deadlines (user_id, document_version_id, deadline_at)
SELECT u.id, $1, $2
FROM users u
WHERE u.account_status = 'active'
AND NOT EXISTS (
SELECT 1 FROM user_consents uc
WHERE uc.user_id = u.id AND uc.document_version_id = $1 AND uc.consented = TRUE
)
ON CONFLICT (user_id, document_version_id) DO NOTHING
`, versionID, deadlineAt)
if err != nil {
return fmt.Errorf("failed to create deadlines: %w", err)
}
// Notify users via notification service
if s.notificationService != nil {
go s.notificationService.NotifyConsentRequired(ctx, documentName, versionID.String())
}
return nil
}
// MarkConsentGiven marks a deadline as fulfilled when user gives consent
func (s *DeadlineService) MarkConsentGiven(ctx context.Context, userID, versionID uuid.UUID) error {
_, err := s.pool.Exec(ctx, `
UPDATE consent_deadlines
SET consent_given_at = NOW()
WHERE user_id = $1 AND document_version_id = $2 AND consent_given_at IS NULL
`, userID, versionID)
if err != nil {
return err
}
// Check if user should be unsuspended
return s.checkAndLiftSuspension(ctx, userID)
}
// GetPendingDeadlines returns all pending deadlines for a user
func (s *DeadlineService) GetPendingDeadlines(ctx context.Context, userID uuid.UUID) ([]ConsentDeadline, error) {
rows, err := s.pool.Query(ctx, `
SELECT cd.id, cd.user_id, cd.document_version_id, cd.deadline_at,
cd.reminder_count, cd.last_reminder_at, cd.consent_given_at, cd.created_at,
ld.name as document_name, dv.version as version_number
FROM consent_deadlines cd
JOIN document_versions dv ON cd.document_version_id = dv.id
JOIN legal_documents ld ON dv.document_id = ld.id
WHERE cd.user_id = $1 AND cd.consent_given_at IS NULL
ORDER BY cd.deadline_at ASC
`, userID)
if err != nil {
return nil, err
}
defer rows.Close()
var deadlines []ConsentDeadline
for rows.Next() {
var d ConsentDeadline
if err := rows.Scan(&d.ID, &d.UserID, &d.DocumentVersionID, &d.DeadlineAt,
&d.ReminderCount, &d.LastReminderAt, &d.ConsentGivenAt, &d.CreatedAt,
&d.DocumentName, &d.VersionNumber); err != nil {
continue
}
deadlines = append(deadlines, d)
}
return deadlines, nil
}
// ProcessDailyDeadlines is meant to be called by a cron job daily
// It sends reminders and suspends accounts that have missed deadlines
func (s *DeadlineService) ProcessDailyDeadlines(ctx context.Context) error {
now := time.Now()
// 1. Send reminders for upcoming deadlines
if err := s.sendReminders(ctx, now); err != nil {
fmt.Printf("Error sending reminders: %v\n", err)
}
// 2. Suspend accounts with expired deadlines
if err := s.suspendExpiredAccounts(ctx, now); err != nil {
fmt.Printf("Error suspending accounts: %v\n", err)
}
return nil
}
// sendReminders sends reminder notifications based on days remaining
func (s *DeadlineService) sendReminders(ctx context.Context, now time.Time) error {
// Reminder schedule: Day 7, 14, 21, 28
reminderDays := []int{7, 14, 21, 28}
for _, days := range reminderDays {
targetDate := now.AddDate(0, 0, days)
dayStart := time.Date(targetDate.Year(), targetDate.Month(), targetDate.Day(), 0, 0, 0, 0, targetDate.Location())
dayEnd := dayStart.AddDate(0, 0, 1)
// Find deadlines that fall on this reminder day
rows, err := s.pool.Query(ctx, `
SELECT cd.id, cd.user_id, cd.document_version_id, cd.deadline_at, cd.reminder_count,
ld.name as document_name
FROM consent_deadlines cd
JOIN document_versions dv ON cd.document_version_id = dv.id
JOIN legal_documents ld ON dv.document_id = ld.id
WHERE cd.consent_given_at IS NULL
AND cd.deadline_at >= $1 AND cd.deadline_at < $2
AND (cd.last_reminder_at IS NULL OR cd.last_reminder_at < $3)
`, dayStart, dayEnd, dayStart)
if err != nil {
continue
}
for rows.Next() {
var id, userID, versionID uuid.UUID
var deadlineAt time.Time
var reminderCount int
var documentName string
if err := rows.Scan(&id, &userID, &versionID, &deadlineAt, &reminderCount, &documentName); err != nil {
continue
}
// Send reminder notification
daysLeft := 30 - (30 - days)
urgency := "freundlich"
if days <= 7 {
urgency = "dringend"
} else if days <= 14 {
urgency = "wichtig"
}
title := fmt.Sprintf("Erinnerung: Zustimmung erforderlich (%s)", urgency)
body := fmt.Sprintf("Bitte bestätigen Sie '%s' innerhalb von %d Tagen.", documentName, daysLeft)
if s.notificationService != nil {
s.notificationService.CreateNotification(ctx, userID, NotificationTypeConsentReminder, title, body, map[string]interface{}{
"document_name": documentName,
"days_left": daysLeft,
"version_id": versionID.String(),
})
}
// Update reminder count and timestamp
s.pool.Exec(ctx, `
UPDATE consent_deadlines
SET reminder_count = reminder_count + 1, last_reminder_at = NOW()
WHERE id = $1
`, id)
}
rows.Close()
}
return nil
}
// suspendExpiredAccounts suspends accounts that have missed their deadline
func (s *DeadlineService) suspendExpiredAccounts(ctx context.Context, now time.Time) error {
// Find users with expired deadlines
rows, err := s.pool.Query(ctx, `
SELECT DISTINCT cd.user_id, array_agg(ld.name) as documents
FROM consent_deadlines cd
JOIN document_versions dv ON cd.document_version_id = dv.id
JOIN legal_documents ld ON dv.document_id = ld.id
JOIN users u ON cd.user_id = u.id
WHERE cd.consent_given_at IS NULL
AND cd.deadline_at < $1
AND u.account_status = 'active'
AND ld.is_mandatory = TRUE
GROUP BY cd.user_id
`, now)
if err != nil {
return err
}
defer rows.Close()
for rows.Next() {
var userID uuid.UUID
var documents []string
if err := rows.Scan(&userID, &documents); err != nil {
continue
}
// Suspend the account
if err := s.suspendAccount(ctx, userID, "consent_deadline_missed", documents); err != nil {
fmt.Printf("Failed to suspend user %s: %v\n", userID, err)
}
}
return nil
}
// suspendAccount suspends a user account
func (s *DeadlineService) suspendAccount(ctx context.Context, userID uuid.UUID, reason string, documents []string) error {
tx, err := s.pool.Begin(ctx)
if err != nil {
return err
}
defer tx.Rollback(ctx)
// Update user status
_, err = tx.Exec(ctx, `
UPDATE users SET account_status = 'suspended', updated_at = NOW()
WHERE id = $1 AND account_status = 'active'
`, userID)
if err != nil {
return err
}
// Create suspension record
_, err = tx.Exec(ctx, `
INSERT INTO account_suspensions (user_id, reason, details)
VALUES ($1, $2, $3)
`, userID, reason, map[string]interface{}{"documents": documents})
if err != nil {
return err
}
// Log to audit
_, err = tx.Exec(ctx, `
INSERT INTO consent_audit_log (user_id, action, entity_type, entity_id, details)
VALUES ($1, 'account_suspended', 'user', $1, $2)
`, userID, map[string]interface{}{"reason": reason, "documents": documents})
if err != nil {
return err
}
if err := tx.Commit(ctx); err != nil {
return err
}
// Send suspension notification
if s.notificationService != nil {
title := "Account vorübergehend gesperrt"
body := "Ihr Account wurde gesperrt, da ausstehende Zustimmungen nicht innerhalb der Frist erteilt wurden. Bitte bestätigen Sie die ausstehenden Dokumente."
s.notificationService.CreateNotification(ctx, userID, NotificationTypeAccountSuspended, title, body, map[string]interface{}{
"documents": documents,
})
}
return nil
}
// checkAndLiftSuspension checks if user has completed all required consents and lifts suspension
func (s *DeadlineService) checkAndLiftSuspension(ctx context.Context, userID uuid.UUID) error {
// Check if user is currently suspended
var accountStatus string
err := s.pool.QueryRow(ctx, `SELECT account_status FROM users WHERE id = $1`, userID).Scan(&accountStatus)
if err != nil || accountStatus != "suspended" {
return nil
}
// Check if there are any pending mandatory consents
var pendingCount int
err = s.pool.QueryRow(ctx, `
SELECT COUNT(*)
FROM consent_deadlines cd
JOIN document_versions dv ON cd.document_version_id = dv.id
JOIN legal_documents ld ON dv.document_id = ld.id
WHERE cd.user_id = $1
AND cd.consent_given_at IS NULL
AND ld.is_mandatory = TRUE
`, userID).Scan(&pendingCount)
if err != nil {
return err
}
// If no pending consents, lift the suspension
if pendingCount == 0 {
return s.liftSuspension(ctx, userID)
}
return nil
}
// liftSuspension lifts a user's suspension
func (s *DeadlineService) liftSuspension(ctx context.Context, userID uuid.UUID) error {
tx, err := s.pool.Begin(ctx)
if err != nil {
return err
}
defer tx.Rollback(ctx)
// Update user status
_, err = tx.Exec(ctx, `
UPDATE users SET account_status = 'active', updated_at = NOW()
WHERE id = $1 AND account_status = 'suspended'
`, userID)
if err != nil {
return err
}
// Update suspension record
_, err = tx.Exec(ctx, `
UPDATE account_suspensions
SET lifted_at = NOW()
WHERE user_id = $1 AND lifted_at IS NULL
`, userID)
if err != nil {
return err
}
// Log to audit
_, err = tx.Exec(ctx, `
INSERT INTO consent_audit_log (user_id, action, entity_type, entity_id)
VALUES ($1, 'account_restored', 'user', $1)
`, userID)
if err != nil {
return err
}
if err := tx.Commit(ctx); err != nil {
return err
}
// Send restoration notification
if s.notificationService != nil {
title := "Account wiederhergestellt"
body := "Vielen Dank! Ihr Account wurde wiederhergestellt. Sie können die Anwendung wieder vollständig nutzen."
s.notificationService.CreateNotification(ctx, userID, NotificationTypeAccountRestored, title, body, nil)
}
return nil
}
// GetAccountSuspension returns the current suspension for a user
func (s *DeadlineService) GetAccountSuspension(ctx context.Context, userID uuid.UUID) (*AccountSuspension, error) {
var suspension AccountSuspension
err := s.pool.QueryRow(ctx, `
SELECT id, user_id, reason, details, suspended_at, lifted_at, lifted_by
FROM account_suspensions
WHERE user_id = $1 AND lifted_at IS NULL
ORDER BY suspended_at DESC
LIMIT 1
`, userID).Scan(&suspension.ID, &suspension.UserID, &suspension.Reason, &suspension.Details,
&suspension.SuspendedAt, &suspension.LiftedAt, &suspension.LiftedBy)
if err != nil {
return nil, err
}
return &suspension, nil
}
// IsUserSuspended checks if a user is currently suspended
func (s *DeadlineService) IsUserSuspended(ctx context.Context, userID uuid.UUID) (bool, error) {
var status string
err := s.pool.QueryRow(ctx, `SELECT account_status FROM users WHERE id = $1`, userID).Scan(&status)
if err != nil {
return false, err
}
return status == "suspended", nil
}

View File

@@ -0,0 +1,439 @@
package services
import (
"testing"
"time"
"github.com/google/uuid"
)
// TestDeadlineService_CreateDeadline tests creating consent deadlines
func TestDeadlineService_CreateDeadline(t *testing.T) {
tests := []struct {
name string
userID uuid.UUID
versionID uuid.UUID
deadlineAt time.Time
expectError bool
}{
{
name: "valid deadline - 30 days",
userID: uuid.New(),
versionID: uuid.New(),
deadlineAt: time.Now().AddDate(0, 0, 30),
expectError: false,
},
{
name: "valid deadline - 14 days",
userID: uuid.New(),
versionID: uuid.New(),
deadlineAt: time.Now().AddDate(0, 0, 14),
expectError: false,
},
{
name: "invalid user ID",
userID: uuid.Nil,
versionID: uuid.New(),
deadlineAt: time.Now().AddDate(0, 0, 30),
expectError: true,
},
{
name: "invalid version ID",
userID: uuid.New(),
versionID: uuid.Nil,
deadlineAt: time.Now().AddDate(0, 0, 30),
expectError: true,
},
{
name: "deadline in past",
userID: uuid.New(),
versionID: uuid.New(),
deadlineAt: time.Now().AddDate(0, 0, -1),
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var err error
if tt.userID == uuid.Nil {
err = &ValidationError{Field: "user ID", Message: "required"}
} else if tt.versionID == uuid.Nil {
err = &ValidationError{Field: "version ID", Message: "required"}
} else if tt.deadlineAt.Before(time.Now()) {
err = &ValidationError{Field: "deadline", Message: "must be in the future"}
}
if tt.expectError && err == nil {
t.Error("Expected error, got nil")
}
if !tt.expectError && err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
}
}
// TestDeadlineService_CheckDeadlineStatus tests deadline status checking
func TestDeadlineService_CheckDeadlineStatus(t *testing.T) {
now := time.Now()
tests := []struct {
name string
deadlineAt time.Time
isOverdue bool
daysLeft int
urgency string
}{
{
name: "30 days left",
deadlineAt: now.AddDate(0, 0, 30),
isOverdue: false,
daysLeft: 30,
urgency: "normal",
},
{
name: "7 days left - warning",
deadlineAt: now.AddDate(0, 0, 7),
isOverdue: false,
daysLeft: 7,
urgency: "warning",
},
{
name: "3 days left - urgent",
deadlineAt: now.AddDate(0, 0, 3),
isOverdue: false,
daysLeft: 3,
urgency: "urgent",
},
{
name: "1 day left - critical",
deadlineAt: now.AddDate(0, 0, 1),
isOverdue: false,
daysLeft: 1,
urgency: "critical",
},
{
name: "overdue by 1 day",
deadlineAt: now.AddDate(0, 0, -1),
isOverdue: true,
daysLeft: -1,
urgency: "overdue",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isOverdue := tt.deadlineAt.Before(now)
daysLeft := int(tt.deadlineAt.Sub(now).Hours() / 24)
var urgency string
if isOverdue {
urgency = "overdue"
} else if daysLeft <= 1 {
urgency = "critical"
} else if daysLeft <= 3 {
urgency = "urgent"
} else if daysLeft <= 7 {
urgency = "warning"
} else {
urgency = "normal"
}
if isOverdue != tt.isOverdue {
t.Errorf("Expected isOverdue=%v, got %v", tt.isOverdue, isOverdue)
}
if abs(daysLeft-tt.daysLeft) > 1 { // Allow 1 day difference
t.Errorf("Expected daysLeft=%d, got %d", tt.daysLeft, daysLeft)
}
if urgency != tt.urgency {
t.Errorf("Expected urgency=%s, got %s", tt.urgency, urgency)
}
})
}
}
// TestDeadlineService_SendReminders tests reminder scheduling
func TestDeadlineService_SendReminders(t *testing.T) {
now := time.Now()
tests := []struct {
name string
deadlineAt time.Time
lastReminderAt *time.Time
reminderCount int
shouldSend bool
nextReminder int // days before deadline
}{
{
name: "first reminder - 14 days before",
deadlineAt: now.AddDate(0, 0, 14),
lastReminderAt: nil,
reminderCount: 0,
shouldSend: true,
nextReminder: 14,
},
{
name: "second reminder - 7 days before",
deadlineAt: now.AddDate(0, 0, 7),
lastReminderAt: ptrTime(now.AddDate(0, 0, -7)),
reminderCount: 1,
shouldSend: true,
nextReminder: 7,
},
{
name: "third reminder - 3 days before",
deadlineAt: now.AddDate(0, 0, 3),
lastReminderAt: ptrTime(now.AddDate(0, 0, -4)),
reminderCount: 2,
shouldSend: true,
nextReminder: 3,
},
{
name: "final reminder - 1 day before",
deadlineAt: now.AddDate(0, 0, 1),
lastReminderAt: ptrTime(now.AddDate(0, 0, -2)),
reminderCount: 3,
shouldSend: true,
nextReminder: 1,
},
{
name: "too soon for next reminder",
deadlineAt: now.AddDate(0, 0, 10),
lastReminderAt: ptrTime(now.AddDate(0, 0, -1)),
reminderCount: 1,
shouldSend: false,
nextReminder: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
daysUntilDeadline := int(tt.deadlineAt.Sub(now).Hours() / 24)
// Reminder schedule: 14, 7, 3, 1 days before deadline
reminderDays := []int{14, 7, 3, 1}
shouldSend := false
for _, day := range reminderDays {
if daysUntilDeadline == day {
// Check if enough time passed since last reminder
if tt.lastReminderAt == nil || now.Sub(*tt.lastReminderAt) > 12*time.Hour {
shouldSend = true
break
}
}
}
if shouldSend != tt.shouldSend {
t.Errorf("Expected shouldSend=%v, got %v", tt.shouldSend, shouldSend)
}
})
}
}
// TestDeadlineService_SuspendAccount tests account suspension logic
func TestDeadlineService_SuspendAccount(t *testing.T) {
tests := []struct {
name string
userID uuid.UUID
reason string
shouldSuspend bool
expectError bool
}{
{
name: "suspend for missed deadline",
userID: uuid.New(),
reason: "consent_deadline_exceeded",
shouldSuspend: true,
expectError: false,
},
{
name: "invalid user ID",
userID: uuid.Nil,
reason: "consent_deadline_exceeded",
shouldSuspend: false,
expectError: true,
},
{
name: "invalid reason",
userID: uuid.New(),
reason: "",
shouldSuspend: false,
expectError: true,
},
}
validReasons := map[string]bool{
"consent_deadline_exceeded": true,
"mandatory_consent_missing": true,
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var err error
if tt.userID == uuid.Nil {
err = &ValidationError{Field: "user ID", Message: "required"}
} else if !validReasons[tt.reason] && tt.reason != "" {
err = &ValidationError{Field: "reason", Message: "invalid suspension reason"}
} else if tt.reason == "" {
err = &ValidationError{Field: "reason", Message: "required"}
}
if tt.expectError && err == nil {
t.Error("Expected error, got nil")
}
if !tt.expectError && err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
}
}
// TestDeadlineService_LiftSuspension tests lifting account suspension
func TestDeadlineService_LiftSuspension(t *testing.T) {
tests := []struct {
name string
userID uuid.UUID
adminID uuid.UUID
reason string
expectError bool
}{
{
name: "lift valid suspension",
userID: uuid.New(),
adminID: uuid.New(),
reason: "consent provided",
expectError: false,
},
{
name: "invalid user ID",
userID: uuid.Nil,
adminID: uuid.New(),
reason: "consent provided",
expectError: true,
},
{
name: "invalid admin ID",
userID: uuid.New(),
adminID: uuid.Nil,
reason: "consent provided",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var err error
if tt.userID == uuid.Nil {
err = &ValidationError{Field: "user ID", Message: "required"}
} else if tt.adminID == uuid.Nil {
err = &ValidationError{Field: "admin ID", Message: "required"}
}
if tt.expectError && err == nil {
t.Error("Expected error, got nil")
}
if !tt.expectError && err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
}
}
// TestDeadlineService_GetOverdueDeadlines tests finding overdue deadlines
func TestDeadlineService_GetOverdueDeadlines(t *testing.T) {
now := time.Now()
tests := []struct {
name string
deadlines []time.Time
expected int // number of overdue
}{
{
name: "no overdue deadlines",
deadlines: []time.Time{
now.AddDate(0, 0, 1),
now.AddDate(0, 0, 7),
now.AddDate(0, 0, 30),
},
expected: 0,
},
{
name: "some overdue",
deadlines: []time.Time{
now.AddDate(0, 0, -1),
now.AddDate(0, 0, -5),
now.AddDate(0, 0, 7),
},
expected: 2,
},
{
name: "all overdue",
deadlines: []time.Time{
now.AddDate(0, 0, -1),
now.AddDate(0, 0, -7),
now.AddDate(0, 0, -30),
},
expected: 3,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
overdueCount := 0
for _, deadline := range tt.deadlines {
if deadline.Before(now) {
overdueCount++
}
}
if overdueCount != tt.expected {
t.Errorf("Expected %d overdue, got %d", tt.expected, overdueCount)
}
})
}
}
// TestDeadlineService_ProcessScheduledTasks tests scheduled task processing
func TestDeadlineService_ProcessScheduledTasks(t *testing.T) {
now := time.Now()
tests := []struct {
name string
task string
scheduledAt time.Time
shouldProcess bool
}{
{
name: "process due task",
task: "send_reminder",
scheduledAt: now.Add(-1 * time.Hour),
shouldProcess: true,
},
{
name: "skip future task",
task: "send_reminder",
scheduledAt: now.Add(1 * time.Hour),
shouldProcess: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
shouldProcess := tt.scheduledAt.Before(now) || tt.scheduledAt.Equal(now)
if shouldProcess != tt.shouldProcess {
t.Errorf("Expected shouldProcess=%v, got %v", tt.shouldProcess, shouldProcess)
}
})
}
}
// Helper functions
func ptrTime(t time.Time) *time.Time {
return &t
}

View File

@@ -0,0 +1,728 @@
package services
import (
"regexp"
"testing"
"time"
"github.com/google/uuid"
)
// TestDocumentService_CreateDocument tests creating a new legal document
func TestDocumentService_CreateDocument(t *testing.T) {
tests := []struct {
name string
docType string
docName string
description string
isMandatory bool
expectError bool
errorContains string
}{
{
name: "valid mandatory document",
docType: "terms",
docName: "Terms of Service",
description: "Our terms and conditions",
isMandatory: true,
expectError: false,
},
{
name: "valid optional document",
docType: "cookies",
docName: "Cookie Policy",
description: "How we use cookies",
isMandatory: false,
expectError: false,
},
{
name: "empty document type",
docType: "",
docName: "Test Document",
description: "Test",
isMandatory: true,
expectError: true,
errorContains: "type",
},
{
name: "empty document name",
docType: "privacy",
docName: "",
description: "Test",
isMandatory: true,
expectError: true,
errorContains: "name",
},
{
name: "invalid document type",
docType: "invalid_type",
docName: "Test",
description: "Test",
isMandatory: false,
expectError: true,
errorContains: "type",
},
}
validTypes := map[string]bool{
"terms": true,
"privacy": true,
"cookies": true,
"community_guidelines": true,
"imprint": true,
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Validate inputs
var err error
if tt.docType == "" {
err = &ValidationError{Field: "type", Message: "required"}
} else if !validTypes[tt.docType] {
err = &ValidationError{Field: "type", Message: "invalid document type"}
} else if tt.docName == "" {
err = &ValidationError{Field: "name", Message: "required"}
}
// Assert
if tt.expectError {
if err == nil {
t.Errorf("Expected error containing '%s', got nil", tt.errorContains)
}
} else {
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
}
})
}
}
// TestDocumentService_UpdateDocument tests updating a document
func TestDocumentService_UpdateDocument(t *testing.T) {
tests := []struct {
name string
documentID uuid.UUID
newName string
newActive bool
expectError bool
}{
{
name: "valid update",
documentID: uuid.New(),
newName: "Updated Name",
newActive: true,
expectError: false,
},
{
name: "deactivate document",
documentID: uuid.New(),
newName: "Test",
newActive: false,
expectError: false,
},
{
name: "invalid document ID",
documentID: uuid.Nil,
newName: "Test",
newActive: true,
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var err error
if tt.documentID == uuid.Nil {
err = &ValidationError{Field: "document ID", Message: "required"}
}
if tt.expectError && err == nil {
t.Error("Expected error, got nil")
}
if !tt.expectError && err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
}
}
// TestDocumentService_CreateVersion tests creating a document version
func TestDocumentService_CreateVersion(t *testing.T) {
tests := []struct {
name string
documentID uuid.UUID
version string
language string
title string
content string
expectError bool
errorContains string
}{
{
name: "valid version - German",
documentID: uuid.New(),
version: "1.0.0",
language: "de",
title: "Nutzungsbedingungen",
content: "<h1>Terms</h1><p>Content...</p>",
expectError: false,
},
{
name: "valid version - English",
documentID: uuid.New(),
version: "1.0.0",
language: "en",
title: "Terms of Service",
content: "<h1>Terms</h1><p>Content...</p>",
expectError: false,
},
{
name: "invalid version format",
documentID: uuid.New(),
version: "1.0",
language: "de",
title: "Test",
content: "Content",
expectError: true,
errorContains: "version",
},
{
name: "invalid language",
documentID: uuid.New(),
version: "1.0.0",
language: "fr",
title: "Test",
content: "Content",
expectError: true,
errorContains: "language",
},
{
name: "empty title",
documentID: uuid.New(),
version: "1.0.0",
language: "de",
title: "",
content: "Content",
expectError: true,
errorContains: "title",
},
{
name: "empty content",
documentID: uuid.New(),
version: "1.0.0",
language: "de",
title: "Test",
content: "",
expectError: true,
errorContains: "content",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Validate semver format (X.Y.Z pattern)
validVersion := regexp.MustCompile(`^\d+\.\d+\.\d+$`).MatchString(tt.version)
validLanguage := tt.language == "de" || tt.language == "en"
var err error
if !validVersion {
err = &ValidationError{Field: "version", Message: "invalid format"}
} else if !validLanguage {
err = &ValidationError{Field: "language", Message: "must be 'de' or 'en'"}
} else if tt.title == "" {
err = &ValidationError{Field: "title", Message: "required"}
} else if tt.content == "" {
err = &ValidationError{Field: "content", Message: "required"}
}
if tt.expectError {
if err == nil {
t.Errorf("Expected error containing '%s', got nil", tt.errorContains)
}
} else {
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
}
})
}
}
// TestDocumentService_VersionStatusTransitions tests version status workflow
func TestDocumentService_VersionStatusTransitions(t *testing.T) {
tests := []struct {
name string
fromStatus string
toStatus string
isAllowed bool
}{
// Valid transitions
{"draft to review", "draft", "review", true},
{"review to approved", "review", "approved", true},
{"review to rejected", "review", "rejected", true},
{"approved to published", "approved", "published", true},
{"approved to scheduled", "approved", "scheduled", true},
{"scheduled to published", "scheduled", "published", true},
{"published to archived", "published", "archived", true},
{"rejected to draft", "rejected", "draft", true},
// Invalid transitions
{"draft to published", "draft", "published", false},
{"draft to approved", "draft", "approved", false},
{"review to published", "review", "published", false},
{"published to draft", "published", "draft", false},
{"published to review", "published", "review", false},
{"archived to draft", "archived", "draft", false},
{"archived to published", "archived", "published", false},
}
// Define valid transitions
validTransitions := map[string][]string{
"draft": {"review"},
"review": {"approved", "rejected"},
"approved": {"published", "scheduled"},
"scheduled": {"published"},
"published": {"archived"},
"rejected": {"draft"},
"archived": {}, // terminal state
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Check if transition is allowed
allowed := false
if transitions, ok := validTransitions[tt.fromStatus]; ok {
for _, validTo := range transitions {
if validTo == tt.toStatus {
allowed = true
break
}
}
}
if allowed != tt.isAllowed {
t.Errorf("Transition %s->%s: expected allowed=%v, got %v",
tt.fromStatus, tt.toStatus, tt.isAllowed, allowed)
}
})
}
}
// TestDocumentService_PublishVersion tests publishing a version
func TestDocumentService_PublishVersion(t *testing.T) {
tests := []struct {
name string
versionID uuid.UUID
currentStatus string
expectError bool
errorContains string
}{
{
name: "publish approved version",
versionID: uuid.New(),
currentStatus: "approved",
expectError: false,
},
{
name: "publish scheduled version",
versionID: uuid.New(),
currentStatus: "scheduled",
expectError: false,
},
{
name: "cannot publish draft",
versionID: uuid.New(),
currentStatus: "draft",
expectError: true,
errorContains: "draft",
},
{
name: "cannot publish review",
versionID: uuid.New(),
currentStatus: "review",
expectError: true,
errorContains: "review",
},
{
name: "invalid version ID",
versionID: uuid.Nil,
currentStatus: "approved",
expectError: true,
errorContains: "ID",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var err error
if tt.versionID == uuid.Nil {
err = &ValidationError{Field: "version ID", Message: "required"}
} else if tt.currentStatus != "approved" && tt.currentStatus != "scheduled" {
err = &ValidationError{Field: "status", Message: "only approved or scheduled versions can be published"}
}
if tt.expectError {
if err == nil {
t.Errorf("Expected error containing '%s', got nil", tt.errorContains)
}
} else {
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
}
})
}
}
// TestDocumentService_ArchiveVersion tests archiving a version
func TestDocumentService_ArchiveVersion(t *testing.T) {
tests := []struct {
name string
versionID uuid.UUID
expectError bool
}{
{
name: "archive valid version",
versionID: uuid.New(),
expectError: false,
},
{
name: "invalid version ID",
versionID: uuid.Nil,
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var err error
if tt.versionID == uuid.Nil {
err = &ValidationError{Field: "version ID", Message: "required"}
}
if tt.expectError && err == nil {
t.Error("Expected error, got nil")
}
if !tt.expectError && err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
}
}
// TestDocumentService_DeleteVersion tests deleting a version
func TestDocumentService_DeleteVersion(t *testing.T) {
tests := []struct {
name string
versionID uuid.UUID
status string
canDelete bool
expectError bool
}{
{
name: "delete draft version",
versionID: uuid.New(),
status: "draft",
canDelete: true,
expectError: false,
},
{
name: "delete rejected version",
versionID: uuid.New(),
status: "rejected",
canDelete: true,
expectError: false,
},
{
name: "cannot delete published version",
versionID: uuid.New(),
status: "published",
canDelete: false,
expectError: true,
},
{
name: "cannot delete approved version",
versionID: uuid.New(),
status: "approved",
canDelete: false,
expectError: true,
},
{
name: "cannot delete archived version",
versionID: uuid.New(),
status: "archived",
canDelete: false,
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Only draft and rejected can be deleted
canDelete := tt.status == "draft" || tt.status == "rejected"
var err error
if !canDelete {
err = &ValidationError{Field: "status", Message: "only draft or rejected versions can be deleted"}
}
if tt.expectError {
if err == nil {
t.Error("Expected error, got nil")
}
} else {
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
}
if canDelete != tt.canDelete {
t.Errorf("Expected canDelete=%v, got %v", tt.canDelete, canDelete)
}
})
}
}
// TestDocumentService_GetLatestVersion tests retrieving the latest version
func TestDocumentService_GetLatestVersion(t *testing.T) {
tests := []struct {
name string
documentID uuid.UUID
language string
status string
expectError bool
}{
{
name: "get latest German version",
documentID: uuid.New(),
language: "de",
status: "published",
expectError: false,
},
{
name: "get latest English version",
documentID: uuid.New(),
language: "en",
status: "published",
expectError: false,
},
{
name: "invalid document ID",
documentID: uuid.Nil,
language: "de",
status: "published",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var err error
if tt.documentID == uuid.Nil {
err = &ValidationError{Field: "document ID", Message: "required"}
}
if tt.expectError && err == nil {
t.Error("Expected error, got nil")
}
if !tt.expectError && err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
}
}
// TestDocumentService_CompareVersions tests version comparison
func TestDocumentService_CompareVersions(t *testing.T) {
tests := []struct {
name string
version1 string
version2 string
isDifferent bool
}{
{
name: "same version",
version1: "1.0.0",
version2: "1.0.0",
isDifferent: false,
},
{
name: "different major version",
version1: "2.0.0",
version2: "1.0.0",
isDifferent: true,
},
{
name: "different minor version",
version1: "1.1.0",
version2: "1.0.0",
isDifferent: true,
},
{
name: "different patch version",
version1: "1.0.1",
version2: "1.0.0",
isDifferent: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isDifferent := tt.version1 != tt.version2
if isDifferent != tt.isDifferent {
t.Errorf("Expected isDifferent=%v, got %v", tt.isDifferent, isDifferent)
}
})
}
}
// TestDocumentService_ScheduledPublishing tests scheduled publishing
func TestDocumentService_ScheduledPublishing(t *testing.T) {
now := time.Now()
tests := []struct {
name string
scheduledAt time.Time
shouldPublish bool
}{
{
name: "scheduled for past - should publish",
scheduledAt: now.Add(-1 * time.Hour),
shouldPublish: true,
},
{
name: "scheduled for now - should publish",
scheduledAt: now,
shouldPublish: true,
},
{
name: "scheduled for future - should not publish",
scheduledAt: now.Add(1 * time.Hour),
shouldPublish: false,
},
{
name: "scheduled for tomorrow - should not publish",
scheduledAt: now.AddDate(0, 0, 1),
shouldPublish: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
shouldPublish := tt.scheduledAt.Before(now) || tt.scheduledAt.Equal(now)
if shouldPublish != tt.shouldPublish {
t.Errorf("Expected shouldPublish=%v, got %v", tt.shouldPublish, shouldPublish)
}
})
}
}
// TestDocumentService_ApprovalWorkflow tests the approval workflow
func TestDocumentService_ApprovalWorkflow(t *testing.T) {
tests := []struct {
name string
action string
userRole string
isAllowed bool
}{
// Admin permissions
{"admin submit for review", "submit_review", "admin", true},
{"admin cannot approve", "approve", "admin", false},
{"admin can publish", "publish", "admin", true},
// DSB permissions
{"dsb can approve", "approve", "data_protection_officer", true},
{"dsb can reject", "reject", "data_protection_officer", true},
{"dsb can publish", "publish", "data_protection_officer", true},
// User permissions
{"user cannot submit", "submit_review", "user", false},
{"user cannot approve", "approve", "user", false},
{"user cannot publish", "publish", "user", false},
}
permissions := map[string]map[string]bool{
"admin": {
"submit_review": true,
"approve": false,
"reject": false,
"publish": true,
},
"data_protection_officer": {
"submit_review": true,
"approve": true,
"reject": true,
"publish": true,
},
"user": {
"submit_review": false,
"approve": false,
"reject": false,
"publish": false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rolePerms, ok := permissions[tt.userRole]
if !ok {
t.Fatalf("Unknown role: %s", tt.userRole)
}
isAllowed := rolePerms[tt.action]
if isAllowed != tt.isAllowed {
t.Errorf("Role %s action %s: expected allowed=%v, got %v",
tt.userRole, tt.action, tt.isAllowed, isAllowed)
}
})
}
}
// TestDocumentService_FourEyesPrinciple tests the four-eyes principle
func TestDocumentService_FourEyesPrinciple(t *testing.T) {
tests := []struct {
name string
createdBy uuid.UUID
approver uuid.UUID
approverRole string
canApprove bool
}{
{
name: "different users - DSB can approve",
createdBy: uuid.New(),
approver: uuid.New(),
approverRole: "data_protection_officer",
canApprove: true,
},
{
name: "same user - DSB cannot approve own",
createdBy: uuid.MustParse("123e4567-e89b-12d3-a456-426614174000"),
approver: uuid.MustParse("123e4567-e89b-12d3-a456-426614174000"),
approverRole: "data_protection_officer",
canApprove: false,
},
{
name: "same user - admin CAN approve own (exception)",
createdBy: uuid.MustParse("123e4567-e89b-12d3-a456-426614174000"),
approver: uuid.MustParse("123e4567-e89b-12d3-a456-426614174000"),
approverRole: "admin",
canApprove: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Four-eyes principle: DSB cannot approve their own work
// Exception: Admins can (for development/testing)
canApprove := tt.createdBy != tt.approver || tt.approverRole == "admin"
if canApprove != tt.canApprove {
t.Errorf("Expected canApprove=%v, got %v", tt.canApprove, canApprove)
}
})
}
}

View File

@@ -0,0 +1,947 @@
package services
import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/breakpilot/consent-service/internal/models"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgxpool"
)
// DSRService handles Data Subject Request business logic
type DSRService struct {
pool *pgxpool.Pool
notificationService *NotificationService
emailService *EmailService
}
// NewDSRService creates a new DSRService
func NewDSRService(pool *pgxpool.Pool, notificationService *NotificationService, emailService *EmailService) *DSRService {
return &DSRService{
pool: pool,
notificationService: notificationService,
emailService: emailService,
}
}
// GetPool returns the database pool for direct queries
func (s *DSRService) GetPool() *pgxpool.Pool {
return s.pool
}
// generateRequestNumber generates a unique request number like DSR-2025-000001
func (s *DSRService) generateRequestNumber(ctx context.Context) (string, error) {
var seqNum int64
err := s.pool.QueryRow(ctx, "SELECT nextval('dsr_request_number_seq')").Scan(&seqNum)
if err != nil {
return "", fmt.Errorf("failed to get next sequence number: %w", err)
}
year := time.Now().Year()
return fmt.Sprintf("DSR-%d-%06d", year, seqNum), nil
}
// CreateRequest creates a new data subject request
func (s *DSRService) CreateRequest(ctx context.Context, req models.CreateDSRRequest, createdBy *uuid.UUID) (*models.DataSubjectRequest, error) {
// Validate request type
requestType := models.DSRRequestType(req.RequestType)
if !isValidRequestType(requestType) {
return nil, fmt.Errorf("invalid request type: %s", req.RequestType)
}
// Generate request number
requestNumber, err := s.generateRequestNumber(ctx)
if err != nil {
return nil, err
}
// Calculate deadline
deadlineDays := requestType.DeadlineDays()
deadline := time.Now().AddDate(0, 0, deadlineDays)
// Determine priority
priority := models.DSRPriorityNormal
if req.Priority != "" {
priority = models.DSRPriority(req.Priority)
} else if requestType.IsExpedited() {
priority = models.DSRPriorityExpedited
}
// Determine source
source := models.DSRSourceAPI
if req.Source != "" {
source = models.DSRSource(req.Source)
}
// Serialize request details
detailsJSON, err := json.Marshal(req.RequestDetails)
if err != nil {
detailsJSON = []byte("{}")
}
// Try to find existing user by email
var userID *uuid.UUID
var foundUserID uuid.UUID
err = s.pool.QueryRow(ctx, "SELECT id FROM users WHERE email = $1", req.RequesterEmail).Scan(&foundUserID)
if err == nil {
userID = &foundUserID
}
// Insert request
var dsr models.DataSubjectRequest
err = s.pool.QueryRow(ctx, `
INSERT INTO data_subject_requests (
user_id, request_number, request_type, status, priority, source,
requester_email, requester_name, requester_phone,
request_details, deadline_at, legal_deadline_days, created_by
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
RETURNING id, user_id, request_number, request_type, status, priority, source,
requester_email, requester_name, requester_phone, identity_verified,
request_details, deadline_at, legal_deadline_days, created_at, updated_at, created_by
`, userID, requestNumber, requestType, models.DSRStatusIntake, priority, source,
req.RequesterEmail, req.RequesterName, req.RequesterPhone,
detailsJSON, deadline, deadlineDays, createdBy,
).Scan(
&dsr.ID, &dsr.UserID, &dsr.RequestNumber, &dsr.RequestType, &dsr.Status,
&dsr.Priority, &dsr.Source, &dsr.RequesterEmail, &dsr.RequesterName,
&dsr.RequesterPhone, &dsr.IdentityVerified, &detailsJSON,
&dsr.DeadlineAt, &dsr.LegalDeadlineDays, &dsr.CreatedAt, &dsr.UpdatedAt, &dsr.CreatedBy,
)
if err != nil {
return nil, fmt.Errorf("failed to create DSR: %w", err)
}
// Parse details back
json.Unmarshal(detailsJSON, &dsr.RequestDetails)
// Record initial status
s.recordStatusChange(ctx, dsr.ID, nil, models.DSRStatusIntake, createdBy, "Anfrage eingegangen")
// Notify DPOs about new request
go s.notifyNewRequest(context.Background(), &dsr)
return &dsr, nil
}
// GetByID retrieves a DSR by ID
func (s *DSRService) GetByID(ctx context.Context, id uuid.UUID) (*models.DataSubjectRequest, error) {
var dsr models.DataSubjectRequest
var detailsJSON, resultDataJSON []byte
err := s.pool.QueryRow(ctx, `
SELECT id, user_id, request_number, request_type, status, priority, source,
requester_email, requester_name, requester_phone,
identity_verified, identity_verified_at, identity_verified_by, identity_verification_method,
request_details, deadline_at, legal_deadline_days, extended_deadline_at, extension_reason,
assigned_to, processing_notes, completed_at, completed_by, result_summary, result_data,
rejected_at, rejected_by, rejection_reason, rejection_legal_basis,
created_at, updated_at, created_by
FROM data_subject_requests WHERE id = $1
`, id).Scan(
&dsr.ID, &dsr.UserID, &dsr.RequestNumber, &dsr.RequestType, &dsr.Status,
&dsr.Priority, &dsr.Source, &dsr.RequesterEmail, &dsr.RequesterName,
&dsr.RequesterPhone, &dsr.IdentityVerified, &dsr.IdentityVerifiedAt,
&dsr.IdentityVerifiedBy, &dsr.IdentityVerificationMethod,
&detailsJSON, &dsr.DeadlineAt, &dsr.LegalDeadlineDays,
&dsr.ExtendedDeadlineAt, &dsr.ExtensionReason, &dsr.AssignedTo,
&dsr.ProcessingNotes, &dsr.CompletedAt, &dsr.CompletedBy,
&dsr.ResultSummary, &resultDataJSON, &dsr.RejectedAt, &dsr.RejectedBy,
&dsr.RejectionReason, &dsr.RejectionLegalBasis,
&dsr.CreatedAt, &dsr.UpdatedAt, &dsr.CreatedBy,
)
if err != nil {
return nil, fmt.Errorf("DSR not found: %w", err)
}
json.Unmarshal(detailsJSON, &dsr.RequestDetails)
json.Unmarshal(resultDataJSON, &dsr.ResultData)
return &dsr, nil
}
// GetByNumber retrieves a DSR by request number
func (s *DSRService) GetByNumber(ctx context.Context, requestNumber string) (*models.DataSubjectRequest, error) {
var id uuid.UUID
err := s.pool.QueryRow(ctx, "SELECT id FROM data_subject_requests WHERE request_number = $1", requestNumber).Scan(&id)
if err != nil {
return nil, fmt.Errorf("DSR not found: %w", err)
}
return s.GetByID(ctx, id)
}
// List retrieves DSRs with filters and pagination
func (s *DSRService) List(ctx context.Context, filters models.DSRListFilters, limit, offset int) ([]models.DataSubjectRequest, int, error) {
// Build query
baseQuery := "FROM data_subject_requests WHERE 1=1"
args := []interface{}{}
argIndex := 1
if filters.Status != nil && *filters.Status != "" {
baseQuery += fmt.Sprintf(" AND status = $%d", argIndex)
args = append(args, *filters.Status)
argIndex++
}
if filters.RequestType != nil && *filters.RequestType != "" {
baseQuery += fmt.Sprintf(" AND request_type = $%d", argIndex)
args = append(args, *filters.RequestType)
argIndex++
}
if filters.AssignedTo != nil && *filters.AssignedTo != "" {
baseQuery += fmt.Sprintf(" AND assigned_to = $%d", argIndex)
args = append(args, *filters.AssignedTo)
argIndex++
}
if filters.Priority != nil && *filters.Priority != "" {
baseQuery += fmt.Sprintf(" AND priority = $%d", argIndex)
args = append(args, *filters.Priority)
argIndex++
}
if filters.OverdueOnly {
baseQuery += " AND deadline_at < NOW() AND status NOT IN ('completed', 'rejected', 'cancelled')"
}
if filters.FromDate != nil {
baseQuery += fmt.Sprintf(" AND created_at >= $%d", argIndex)
args = append(args, *filters.FromDate)
argIndex++
}
if filters.ToDate != nil {
baseQuery += fmt.Sprintf(" AND created_at <= $%d", argIndex)
args = append(args, *filters.ToDate)
argIndex++
}
if filters.Search != nil && *filters.Search != "" {
searchPattern := "%" + *filters.Search + "%"
baseQuery += fmt.Sprintf(" AND (request_number ILIKE $%d OR requester_email ILIKE $%d OR requester_name ILIKE $%d)", argIndex, argIndex, argIndex)
args = append(args, searchPattern)
argIndex++
}
// Get total count
var total int
err := s.pool.QueryRow(ctx, "SELECT COUNT(*) "+baseQuery, args...).Scan(&total)
if err != nil {
return nil, 0, fmt.Errorf("failed to count DSRs: %w", err)
}
// Get paginated results
query := fmt.Sprintf(`
SELECT id, user_id, request_number, request_type, status, priority, source,
requester_email, requester_name, requester_phone, identity_verified,
deadline_at, legal_deadline_days, assigned_to, created_at, updated_at
%s ORDER BY created_at DESC LIMIT $%d OFFSET $%d
`, baseQuery, argIndex, argIndex+1)
args = append(args, limit, offset)
rows, err := s.pool.Query(ctx, query, args...)
if err != nil {
return nil, 0, fmt.Errorf("failed to query DSRs: %w", err)
}
defer rows.Close()
var dsrs []models.DataSubjectRequest
for rows.Next() {
var dsr models.DataSubjectRequest
err := rows.Scan(
&dsr.ID, &dsr.UserID, &dsr.RequestNumber, &dsr.RequestType, &dsr.Status,
&dsr.Priority, &dsr.Source, &dsr.RequesterEmail, &dsr.RequesterName,
&dsr.RequesterPhone, &dsr.IdentityVerified, &dsr.DeadlineAt,
&dsr.LegalDeadlineDays, &dsr.AssignedTo, &dsr.CreatedAt, &dsr.UpdatedAt,
)
if err != nil {
return nil, 0, fmt.Errorf("failed to scan DSR: %w", err)
}
dsrs = append(dsrs, dsr)
}
return dsrs, total, nil
}
// ListByUser retrieves DSRs for a specific user
func (s *DSRService) ListByUser(ctx context.Context, userID uuid.UUID) ([]models.DataSubjectRequest, error) {
rows, err := s.pool.Query(ctx, `
SELECT id, user_id, request_number, request_type, status, priority, source,
requester_email, requester_name, deadline_at, created_at, updated_at
FROM data_subject_requests
WHERE user_id = $1 OR requester_email = (SELECT email FROM users WHERE id = $1)
ORDER BY created_at DESC
`, userID)
if err != nil {
return nil, fmt.Errorf("failed to query user DSRs: %w", err)
}
defer rows.Close()
var dsrs []models.DataSubjectRequest
for rows.Next() {
var dsr models.DataSubjectRequest
err := rows.Scan(
&dsr.ID, &dsr.UserID, &dsr.RequestNumber, &dsr.RequestType, &dsr.Status,
&dsr.Priority, &dsr.Source, &dsr.RequesterEmail, &dsr.RequesterName,
&dsr.DeadlineAt, &dsr.CreatedAt, &dsr.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan DSR: %w", err)
}
dsrs = append(dsrs, dsr)
}
return dsrs, nil
}
// UpdateStatus changes the status of a DSR
func (s *DSRService) UpdateStatus(ctx context.Context, id uuid.UUID, newStatus models.DSRStatus, comment string, changedBy *uuid.UUID) error {
// Get current status
var currentStatus models.DSRStatus
err := s.pool.QueryRow(ctx, "SELECT status FROM data_subject_requests WHERE id = $1", id).Scan(&currentStatus)
if err != nil {
return fmt.Errorf("DSR not found: %w", err)
}
// Validate transition
if !isValidStatusTransition(currentStatus, newStatus) {
return fmt.Errorf("invalid status transition from %s to %s", currentStatus, newStatus)
}
// Update status
_, err = s.pool.Exec(ctx, `
UPDATE data_subject_requests SET status = $1, updated_at = NOW() WHERE id = $2
`, newStatus, id)
if err != nil {
return fmt.Errorf("failed to update status: %w", err)
}
// Record status change
s.recordStatusChange(ctx, id, &currentStatus, newStatus, changedBy, comment)
return nil
}
// VerifyIdentity marks identity as verified
func (s *DSRService) VerifyIdentity(ctx context.Context, id uuid.UUID, method string, verifiedBy uuid.UUID) error {
_, err := s.pool.Exec(ctx, `
UPDATE data_subject_requests
SET identity_verified = TRUE,
identity_verified_at = NOW(),
identity_verified_by = $1,
identity_verification_method = $2,
status = CASE WHEN status = 'intake' THEN 'identity_verification' ELSE status END,
updated_at = NOW()
WHERE id = $3
`, verifiedBy, method, id)
if err != nil {
return fmt.Errorf("failed to verify identity: %w", err)
}
s.recordStatusChange(ctx, id, nil, models.DSRStatusIdentityVerification, &verifiedBy, "Identität verifiziert via "+method)
return nil
}
// AssignRequest assigns a DSR to a handler
func (s *DSRService) AssignRequest(ctx context.Context, id uuid.UUID, assigneeID uuid.UUID, assignedBy uuid.UUID) error {
_, err := s.pool.Exec(ctx, `
UPDATE data_subject_requests SET assigned_to = $1, updated_at = NOW() WHERE id = $2
`, assigneeID, id)
if err != nil {
return fmt.Errorf("failed to assign DSR: %w", err)
}
// Get assignee name for comment
var assigneeName string
s.pool.QueryRow(ctx, "SELECT COALESCE(name, email) FROM users WHERE id = $1", assigneeID).Scan(&assigneeName)
s.recordStatusChange(ctx, id, nil, "", &assignedBy, "Zugewiesen an "+assigneeName)
// Notify assignee
go s.notifyAssignment(context.Background(), id, assigneeID)
return nil
}
// ExtendDeadline extends the deadline for a DSR
func (s *DSRService) ExtendDeadline(ctx context.Context, id uuid.UUID, reason string, days int, extendedBy uuid.UUID) error {
// Default extension is 2 months (60 days) per Art. 12(3)
if days <= 0 {
days = 60
}
_, err := s.pool.Exec(ctx, `
UPDATE data_subject_requests
SET extended_deadline_at = deadline_at + ($1 || ' days')::INTERVAL,
extension_reason = $2,
updated_at = NOW()
WHERE id = $3
`, days, reason, id)
if err != nil {
return fmt.Errorf("failed to extend deadline: %w", err)
}
s.recordStatusChange(ctx, id, nil, "", &extendedBy, fmt.Sprintf("Frist um %d Tage verlängert: %s", days, reason))
return nil
}
// CompleteRequest marks a DSR as completed
func (s *DSRService) CompleteRequest(ctx context.Context, id uuid.UUID, summary string, resultData map[string]interface{}, completedBy uuid.UUID) error {
resultJSON, _ := json.Marshal(resultData)
// Get current status
var currentStatus models.DSRStatus
s.pool.QueryRow(ctx, "SELECT status FROM data_subject_requests WHERE id = $1", id).Scan(&currentStatus)
_, err := s.pool.Exec(ctx, `
UPDATE data_subject_requests
SET status = 'completed',
completed_at = NOW(),
completed_by = $1,
result_summary = $2,
result_data = $3,
updated_at = NOW()
WHERE id = $4
`, completedBy, summary, resultJSON, id)
if err != nil {
return fmt.Errorf("failed to complete DSR: %w", err)
}
s.recordStatusChange(ctx, id, &currentStatus, models.DSRStatusCompleted, &completedBy, summary)
return nil
}
// RejectRequest rejects a DSR with legal basis
func (s *DSRService) RejectRequest(ctx context.Context, id uuid.UUID, reason, legalBasis string, rejectedBy uuid.UUID) error {
// Get current status
var currentStatus models.DSRStatus
s.pool.QueryRow(ctx, "SELECT status FROM data_subject_requests WHERE id = $1", id).Scan(&currentStatus)
_, err := s.pool.Exec(ctx, `
UPDATE data_subject_requests
SET status = 'rejected',
rejected_at = NOW(),
rejected_by = $1,
rejection_reason = $2,
rejection_legal_basis = $3,
updated_at = NOW()
WHERE id = $4
`, rejectedBy, reason, legalBasis, id)
if err != nil {
return fmt.Errorf("failed to reject DSR: %w", err)
}
s.recordStatusChange(ctx, id, &currentStatus, models.DSRStatusRejected, &rejectedBy, fmt.Sprintf("Abgelehnt (%s): %s", legalBasis, reason))
return nil
}
// CancelRequest cancels a DSR (by user)
func (s *DSRService) CancelRequest(ctx context.Context, id uuid.UUID, cancelledBy uuid.UUID) error {
// Verify ownership
var userID *uuid.UUID
err := s.pool.QueryRow(ctx, "SELECT user_id FROM data_subject_requests WHERE id = $1", id).Scan(&userID)
if err != nil {
return fmt.Errorf("DSR not found: %w", err)
}
if userID == nil || *userID != cancelledBy {
return fmt.Errorf("unauthorized: can only cancel own requests")
}
// Get current status
var currentStatus models.DSRStatus
s.pool.QueryRow(ctx, "SELECT status FROM data_subject_requests WHERE id = $1", id).Scan(&currentStatus)
_, err = s.pool.Exec(ctx, `
UPDATE data_subject_requests SET status = 'cancelled', updated_at = NOW() WHERE id = $1
`, id)
if err != nil {
return fmt.Errorf("failed to cancel DSR: %w", err)
}
s.recordStatusChange(ctx, id, &currentStatus, models.DSRStatusCancelled, &cancelledBy, "Vom Antragsteller storniert")
return nil
}
// GetDashboardStats returns statistics for the admin dashboard
func (s *DSRService) GetDashboardStats(ctx context.Context) (*models.DSRDashboardStats, error) {
stats := &models.DSRDashboardStats{
ByType: make(map[string]int),
ByStatus: make(map[string]int),
}
// Total requests
s.pool.QueryRow(ctx, "SELECT COUNT(*) FROM data_subject_requests").Scan(&stats.TotalRequests)
// Pending requests (not completed, rejected, or cancelled)
s.pool.QueryRow(ctx, `
SELECT COUNT(*) FROM data_subject_requests
WHERE status NOT IN ('completed', 'rejected', 'cancelled')
`).Scan(&stats.PendingRequests)
// Overdue requests
s.pool.QueryRow(ctx, `
SELECT COUNT(*) FROM data_subject_requests
WHERE COALESCE(extended_deadline_at, deadline_at) < NOW()
AND status NOT IN ('completed', 'rejected', 'cancelled')
`).Scan(&stats.OverdueRequests)
// Completed this month
s.pool.QueryRow(ctx, `
SELECT COUNT(*) FROM data_subject_requests
WHERE status = 'completed'
AND completed_at >= DATE_TRUNC('month', NOW())
`).Scan(&stats.CompletedThisMonth)
// Average processing days
s.pool.QueryRow(ctx, `
SELECT COALESCE(AVG(EXTRACT(EPOCH FROM (completed_at - created_at)) / 86400), 0)
FROM data_subject_requests WHERE status = 'completed'
`).Scan(&stats.AverageProcessingDays)
// Count by type
rows, _ := s.pool.Query(ctx, `
SELECT request_type, COUNT(*) FROM data_subject_requests GROUP BY request_type
`)
for rows.Next() {
var t string
var count int
rows.Scan(&t, &count)
stats.ByType[t] = count
}
rows.Close()
// Count by status
rows, _ = s.pool.Query(ctx, `
SELECT status, COUNT(*) FROM data_subject_requests GROUP BY status
`)
for rows.Next() {
var s string
var count int
rows.Scan(&s, &count)
stats.ByStatus[s] = count
}
rows.Close()
// Upcoming deadlines (next 7 days)
rows, _ = s.pool.Query(ctx, `
SELECT id, request_number, request_type, status, requester_email, deadline_at
FROM data_subject_requests
WHERE COALESCE(extended_deadline_at, deadline_at) BETWEEN NOW() AND NOW() + INTERVAL '7 days'
AND status NOT IN ('completed', 'rejected', 'cancelled')
ORDER BY deadline_at ASC LIMIT 10
`)
for rows.Next() {
var dsr models.DataSubjectRequest
rows.Scan(&dsr.ID, &dsr.RequestNumber, &dsr.RequestType, &dsr.Status, &dsr.RequesterEmail, &dsr.DeadlineAt)
stats.UpcomingDeadlines = append(stats.UpcomingDeadlines, dsr)
}
rows.Close()
return stats, nil
}
// GetStatusHistory retrieves the status history for a DSR
func (s *DSRService) GetStatusHistory(ctx context.Context, requestID uuid.UUID) ([]models.DSRStatusHistory, error) {
rows, err := s.pool.Query(ctx, `
SELECT id, request_id, from_status, to_status, changed_by, comment, metadata, created_at
FROM dsr_status_history WHERE request_id = $1 ORDER BY created_at DESC
`, requestID)
if err != nil {
return nil, fmt.Errorf("failed to query status history: %w", err)
}
defer rows.Close()
var history []models.DSRStatusHistory
for rows.Next() {
var h models.DSRStatusHistory
var metadataJSON []byte
err := rows.Scan(&h.ID, &h.RequestID, &h.FromStatus, &h.ToStatus, &h.ChangedBy, &h.Comment, &metadataJSON, &h.CreatedAt)
if err != nil {
continue
}
json.Unmarshal(metadataJSON, &h.Metadata)
history = append(history, h)
}
return history, nil
}
// GetCommunications retrieves communications for a DSR
func (s *DSRService) GetCommunications(ctx context.Context, requestID uuid.UUID) ([]models.DSRCommunication, error) {
rows, err := s.pool.Query(ctx, `
SELECT id, request_id, direction, channel, communication_type, template_version_id,
subject, body_html, body_text, recipient_email, sent_at, error_message,
attachments, created_at, created_by
FROM dsr_communications WHERE request_id = $1 ORDER BY created_at DESC
`, requestID)
if err != nil {
return nil, fmt.Errorf("failed to query communications: %w", err)
}
defer rows.Close()
var comms []models.DSRCommunication
for rows.Next() {
var c models.DSRCommunication
var attachmentsJSON []byte
err := rows.Scan(&c.ID, &c.RequestID, &c.Direction, &c.Channel, &c.CommunicationType,
&c.TemplateVersionID, &c.Subject, &c.BodyHTML, &c.BodyText, &c.RecipientEmail,
&c.SentAt, &c.ErrorMessage, &attachmentsJSON, &c.CreatedAt, &c.CreatedBy)
if err != nil {
continue
}
json.Unmarshal(attachmentsJSON, &c.Attachments)
comms = append(comms, c)
}
return comms, nil
}
// SendCommunication sends a communication for a DSR
func (s *DSRService) SendCommunication(ctx context.Context, requestID uuid.UUID, req models.SendDSRCommunicationRequest, sentBy uuid.UUID) error {
// Get DSR details
dsr, err := s.GetByID(ctx, requestID)
if err != nil {
return err
}
// Get template if specified
var subject, bodyHTML, bodyText string
if req.TemplateVersionID != nil {
templateVersionID, _ := uuid.Parse(*req.TemplateVersionID)
err := s.pool.QueryRow(ctx, `
SELECT subject, body_html, body_text FROM dsr_template_versions WHERE id = $1 AND status = 'published'
`, templateVersionID).Scan(&subject, &bodyHTML, &bodyText)
if err != nil {
return fmt.Errorf("template version not found or not published: %w", err)
}
}
// Use custom content if provided
if req.CustomSubject != nil {
subject = *req.CustomSubject
}
if req.CustomBody != nil {
bodyHTML = *req.CustomBody
bodyText = stripHTML(*req.CustomBody)
}
// Replace variables
variables := map[string]string{
"requester_name": stringOrDefault(dsr.RequesterName, "Antragsteller/in"),
"request_number": dsr.RequestNumber,
"request_type_de": dsr.RequestType.Label(),
"request_date": dsr.CreatedAt.Format("02.01.2006"),
"deadline_date": dsr.DeadlineAt.Format("02.01.2006"),
}
for k, v := range req.Variables {
variables[k] = v
}
subject = replaceVariables(subject, variables)
bodyHTML = replaceVariables(bodyHTML, variables)
bodyText = replaceVariables(bodyText, variables)
// Send email
if s.emailService != nil {
err = s.emailService.SendEmail(dsr.RequesterEmail, subject, bodyHTML, bodyText)
if err != nil {
// Log error but continue
_, _ = s.pool.Exec(ctx, `
INSERT INTO dsr_communications (request_id, direction, channel, communication_type,
template_version_id, subject, body_html, body_text, recipient_email, error_message, created_by)
VALUES ($1, 'outbound', 'email', $2, $3, $4, $5, $6, $7, $8, $9)
`, requestID, req.CommunicationType, req.TemplateVersionID, subject, bodyHTML, bodyText,
dsr.RequesterEmail, err.Error(), sentBy)
return fmt.Errorf("failed to send email: %w", err)
}
}
// Log communication
now := time.Now()
_, err = s.pool.Exec(ctx, `
INSERT INTO dsr_communications (request_id, direction, channel, communication_type,
template_version_id, subject, body_html, body_text, recipient_email, sent_at, created_by)
VALUES ($1, 'outbound', 'email', $2, $3, $4, $5, $6, $7, $8, $9)
`, requestID, req.CommunicationType, req.TemplateVersionID, subject, bodyHTML, bodyText,
dsr.RequesterEmail, now, sentBy)
return err
}
// InitErasureExceptionChecks initializes exception checks for an erasure request
func (s *DSRService) InitErasureExceptionChecks(ctx context.Context, requestID uuid.UUID) error {
exceptions := []struct {
Type string
Description string
}{
{models.DSRExceptionFreedomExpression, "Ausübung des Rechts auf freie Meinungsäußerung und Information (Art. 17 Abs. 3 lit. a)"},
{models.DSRExceptionLegalObligation, "Erfüllung einer rechtlichen Verpflichtung oder öffentlichen Aufgabe (Art. 17 Abs. 3 lit. b)"},
{models.DSRExceptionPublicHealth, "Gründe des öffentlichen Interesses im Bereich der öffentlichen Gesundheit (Art. 17 Abs. 3 lit. c)"},
{models.DSRExceptionArchiving, "Im öffentlichen Interesse liegende Archivzwecke, Forschung oder Statistik (Art. 17 Abs. 3 lit. d)"},
{models.DSRExceptionLegalClaims, "Geltendmachung, Ausübung oder Verteidigung von Rechtsansprüchen (Art. 17 Abs. 3 lit. e)"},
}
for _, exc := range exceptions {
_, err := s.pool.Exec(ctx, `
INSERT INTO dsr_exception_checks (request_id, exception_type, description)
VALUES ($1, $2, $3) ON CONFLICT DO NOTHING
`, requestID, exc.Type, exc.Description)
if err != nil {
return fmt.Errorf("failed to create exception check: %w", err)
}
}
return nil
}
// GetExceptionChecks retrieves exception checks for a DSR
func (s *DSRService) GetExceptionChecks(ctx context.Context, requestID uuid.UUID) ([]models.DSRExceptionCheck, error) {
rows, err := s.pool.Query(ctx, `
SELECT id, request_id, exception_type, description, applies, checked_by, checked_at, notes, created_at
FROM dsr_exception_checks WHERE request_id = $1 ORDER BY created_at
`, requestID)
if err != nil {
return nil, fmt.Errorf("failed to query exception checks: %w", err)
}
defer rows.Close()
var checks []models.DSRExceptionCheck
for rows.Next() {
var c models.DSRExceptionCheck
err := rows.Scan(&c.ID, &c.RequestID, &c.ExceptionType, &c.Description, &c.Applies,
&c.CheckedBy, &c.CheckedAt, &c.Notes, &c.CreatedAt)
if err != nil {
continue
}
checks = append(checks, c)
}
return checks, nil
}
// UpdateExceptionCheck updates an exception check
func (s *DSRService) UpdateExceptionCheck(ctx context.Context, checkID uuid.UUID, applies bool, notes *string, checkedBy uuid.UUID) error {
_, err := s.pool.Exec(ctx, `
UPDATE dsr_exception_checks
SET applies = $1, notes = $2, checked_by = $3, checked_at = NOW()
WHERE id = $4
`, applies, notes, checkedBy, checkID)
return err
}
// ProcessDeadlines checks for approaching and overdue deadlines
func (s *DSRService) ProcessDeadlines(ctx context.Context) error {
now := time.Now()
// Find requests with deadlines in 3 days
threeDaysAhead := now.AddDate(0, 0, 3)
rows, _ := s.pool.Query(ctx, `
SELECT id, request_number, request_type, assigned_to, deadline_at
FROM data_subject_requests
WHERE COALESCE(extended_deadline_at, deadline_at) BETWEEN $1 AND $2
AND status NOT IN ('completed', 'rejected', 'cancelled')
`, now, threeDaysAhead)
for rows.Next() {
var id uuid.UUID
var requestNumber, requestType string
var assignedTo *uuid.UUID
var deadline time.Time
rows.Scan(&id, &requestNumber, &requestType, &assignedTo, &deadline)
// Notify assigned user or all DPOs
if assignedTo != nil {
s.notifyDeadlineWarning(ctx, id, *assignedTo, requestNumber, deadline, 3)
} else {
s.notifyAllDPOs(ctx, id, requestNumber, "Frist in 3 Tagen", deadline)
}
}
rows.Close()
// Find requests with deadlines in 1 day
oneDayAhead := now.AddDate(0, 0, 1)
rows, _ = s.pool.Query(ctx, `
SELECT id, request_number, request_type, assigned_to, deadline_at
FROM data_subject_requests
WHERE COALESCE(extended_deadline_at, deadline_at) BETWEEN $1 AND $2
AND status NOT IN ('completed', 'rejected', 'cancelled')
`, now, oneDayAhead)
for rows.Next() {
var id uuid.UUID
var requestNumber, requestType string
var assignedTo *uuid.UUID
var deadline time.Time
rows.Scan(&id, &requestNumber, &requestType, &assignedTo, &deadline)
if assignedTo != nil {
s.notifyDeadlineWarning(ctx, id, *assignedTo, requestNumber, deadline, 1)
} else {
s.notifyAllDPOs(ctx, id, requestNumber, "Frist morgen!", deadline)
}
}
rows.Close()
// Find overdue requests
rows, _ = s.pool.Query(ctx, `
SELECT id, request_number, request_type, assigned_to, deadline_at
FROM data_subject_requests
WHERE COALESCE(extended_deadline_at, deadline_at) < $1
AND status NOT IN ('completed', 'rejected', 'cancelled')
`, now)
for rows.Next() {
var id uuid.UUID
var requestNumber, requestType string
var assignedTo *uuid.UUID
var deadline time.Time
rows.Scan(&id, &requestNumber, &requestType, &assignedTo, &deadline)
// Notify all DPOs for overdue
s.notifyAllDPOs(ctx, id, requestNumber, "ÜBERFÄLLIG!", deadline)
// Log to audit
s.pool.Exec(ctx, `
INSERT INTO consent_audit_log (action, entity_type, entity_id, details)
VALUES ('dsr_overdue', 'dsr', $1, $2)
`, id, fmt.Sprintf(`{"request_number": "%s", "deadline": "%s"}`, requestNumber, deadline.Format(time.RFC3339)))
}
rows.Close()
return nil
}
// Helper functions
func (s *DSRService) recordStatusChange(ctx context.Context, requestID uuid.UUID, fromStatus *models.DSRStatus, toStatus models.DSRStatus, changedBy *uuid.UUID, comment string) {
s.pool.Exec(ctx, `
INSERT INTO dsr_status_history (request_id, from_status, to_status, changed_by, comment)
VALUES ($1, $2, $3, $4, $5)
`, requestID, fromStatus, toStatus, changedBy, comment)
}
func (s *DSRService) notifyNewRequest(ctx context.Context, dsr *models.DataSubjectRequest) {
if s.notificationService == nil {
return
}
// Notify all DPOs
rows, _ := s.pool.Query(ctx, "SELECT id FROM users WHERE role = 'data_protection_officer'")
defer rows.Close()
for rows.Next() {
var userID uuid.UUID
rows.Scan(&userID)
s.notificationService.CreateNotification(ctx, userID, NotificationTypeDSRReceived,
"Neue Betroffenenanfrage",
fmt.Sprintf("Neue %s eingegangen: %s", dsr.RequestType.Label(), dsr.RequestNumber),
map[string]interface{}{"dsr_id": dsr.ID, "request_number": dsr.RequestNumber})
}
}
func (s *DSRService) notifyAssignment(ctx context.Context, dsrID, assigneeID uuid.UUID) {
if s.notificationService == nil {
return
}
dsr, _ := s.GetByID(ctx, dsrID)
if dsr != nil {
s.notificationService.CreateNotification(ctx, assigneeID, NotificationTypeDSRAssigned,
"Betroffenenanfrage zugewiesen",
fmt.Sprintf("Ihnen wurde die Anfrage %s zugewiesen", dsr.RequestNumber),
map[string]interface{}{"dsr_id": dsrID, "request_number": dsr.RequestNumber})
}
}
func (s *DSRService) notifyDeadlineWarning(ctx context.Context, dsrID, userID uuid.UUID, requestNumber string, deadline time.Time, daysLeft int) {
if s.notificationService == nil {
return
}
s.notificationService.CreateNotification(ctx, userID, NotificationTypeDSRDeadline,
fmt.Sprintf("Fristwarnung: %s", requestNumber),
fmt.Sprintf("Die Frist für %s läuft in %d Tag(en) ab (%s)", requestNumber, daysLeft, deadline.Format("02.01.2006")),
map[string]interface{}{"dsr_id": dsrID, "deadline": deadline, "days_left": daysLeft})
}
func (s *DSRService) notifyAllDPOs(ctx context.Context, dsrID uuid.UUID, requestNumber, message string, deadline time.Time) {
if s.notificationService == nil {
return
}
rows, _ := s.pool.Query(ctx, "SELECT id FROM users WHERE role = 'data_protection_officer'")
defer rows.Close()
for rows.Next() {
var userID uuid.UUID
rows.Scan(&userID)
s.notificationService.CreateNotification(ctx, userID, NotificationTypeDSRDeadline,
fmt.Sprintf("%s: %s", message, requestNumber),
fmt.Sprintf("Anfrage %s: %s (Frist: %s)", requestNumber, message, deadline.Format("02.01.2006")),
map[string]interface{}{"dsr_id": dsrID, "deadline": deadline})
}
}
func isValidRequestType(rt models.DSRRequestType) bool {
switch rt {
case models.DSRTypeAccess, models.DSRTypeRectification, models.DSRTypeErasure,
models.DSRTypeRestriction, models.DSRTypePortability:
return true
}
return false
}
func isValidStatusTransition(from, to models.DSRStatus) bool {
validTransitions := map[models.DSRStatus][]models.DSRStatus{
models.DSRStatusIntake: {models.DSRStatusIdentityVerification, models.DSRStatusProcessing, models.DSRStatusRejected, models.DSRStatusCancelled},
models.DSRStatusIdentityVerification: {models.DSRStatusProcessing, models.DSRStatusRejected, models.DSRStatusCancelled},
models.DSRStatusProcessing: {models.DSRStatusCompleted, models.DSRStatusRejected, models.DSRStatusCancelled},
models.DSRStatusCompleted: {},
models.DSRStatusRejected: {},
models.DSRStatusCancelled: {},
}
allowed, exists := validTransitions[from]
if !exists {
return false
}
for _, s := range allowed {
if s == to {
return true
}
}
return false
}
func stringOrDefault(s *string, def string) string {
if s != nil {
return *s
}
return def
}
func replaceVariables(text string, variables map[string]string) string {
for k, v := range variables {
text = strings.ReplaceAll(text, "{{"+k+"}}", v)
}
return text
}
func stripHTML(html string) string {
// Simple HTML stripping - in production use a proper library
text := strings.ReplaceAll(html, "<br>", "\n")
text = strings.ReplaceAll(text, "<br/>", "\n")
text = strings.ReplaceAll(text, "<br />", "\n")
text = strings.ReplaceAll(text, "</p>", "\n\n")
// Remove all remaining tags
for {
start := strings.Index(text, "<")
if start == -1 {
break
}
end := strings.Index(text[start:], ">")
if end == -1 {
break
}
text = text[:start] + text[start+end+1:]
}
return strings.TrimSpace(text)
}

View File

@@ -0,0 +1,420 @@
package services
import (
"testing"
"time"
"github.com/breakpilot/consent-service/internal/models"
)
// TestDSRRequestTypeLabel tests label generation for request types
func TestDSRRequestTypeLabel(t *testing.T) {
tests := []struct {
name string
reqType models.DSRRequestType
expected string
}{
{"access type", models.DSRTypeAccess, "Auskunftsanfrage (Art. 15)"},
{"rectification type", models.DSRTypeRectification, "Berichtigungsanfrage (Art. 16)"},
{"erasure type", models.DSRTypeErasure, "Löschanfrage (Art. 17)"},
{"restriction type", models.DSRTypeRestriction, "Einschränkungsanfrage (Art. 18)"},
{"portability type", models.DSRTypePortability, "Datenübertragung (Art. 20)"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.reqType.Label()
if result != tt.expected {
t.Errorf("Expected %s, got %s", tt.expected, result)
}
})
}
}
// TestDSRRequestTypeDeadlineDays tests deadline calculation for different request types
func TestDSRRequestTypeDeadlineDays(t *testing.T) {
tests := []struct {
name string
reqType models.DSRRequestType
expectedDays int
}{
{"access has 30 days", models.DSRTypeAccess, 30},
{"portability has 30 days", models.DSRTypePortability, 30},
{"rectification has 14 days", models.DSRTypeRectification, 14},
{"erasure has 14 days", models.DSRTypeErasure, 14},
{"restriction has 14 days", models.DSRTypeRestriction, 14},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.reqType.DeadlineDays()
if result != tt.expectedDays {
t.Errorf("Expected %d days, got %d", tt.expectedDays, result)
}
})
}
}
// TestDSRRequestTypeIsExpedited tests expedited flag for request types
func TestDSRRequestTypeIsExpedited(t *testing.T) {
tests := []struct {
name string
reqType models.DSRRequestType
isExpedited bool
}{
{"access not expedited", models.DSRTypeAccess, false},
{"portability not expedited", models.DSRTypePortability, false},
{"rectification is expedited", models.DSRTypeRectification, true},
{"erasure is expedited", models.DSRTypeErasure, true},
{"restriction is expedited", models.DSRTypeRestriction, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.reqType.IsExpedited()
if result != tt.isExpedited {
t.Errorf("Expected IsExpedited=%v, got %v", tt.isExpedited, result)
}
})
}
}
// TestDSRStatusLabel tests label generation for statuses
func TestDSRStatusLabel(t *testing.T) {
tests := []struct {
name string
status models.DSRStatus
expected string
}{
{"intake status", models.DSRStatusIntake, "Eingang"},
{"identity verification", models.DSRStatusIdentityVerification, "Identitätsprüfung"},
{"processing status", models.DSRStatusProcessing, "In Bearbeitung"},
{"completed status", models.DSRStatusCompleted, "Abgeschlossen"},
{"rejected status", models.DSRStatusRejected, "Abgelehnt"},
{"cancelled status", models.DSRStatusCancelled, "Storniert"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.status.Label()
if result != tt.expected {
t.Errorf("Expected %s, got %s", tt.expected, result)
}
})
}
}
// TestValidDSRRequestType tests request type validation
func TestValidDSRRequestType(t *testing.T) {
tests := []struct {
name string
reqType string
valid bool
}{
{"valid access", "access", true},
{"valid rectification", "rectification", true},
{"valid erasure", "erasure", true},
{"valid restriction", "restriction", true},
{"valid portability", "portability", true},
{"invalid type", "invalid", false},
{"empty type", "", false},
{"random string", "delete_everything", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := models.IsValidDSRRequestType(tt.reqType)
if result != tt.valid {
t.Errorf("Expected IsValidDSRRequestType=%v for %s, got %v", tt.valid, tt.reqType, result)
}
})
}
}
// TestValidDSRStatus tests status validation
func TestValidDSRStatus(t *testing.T) {
tests := []struct {
name string
status string
valid bool
}{
{"valid intake", "intake", true},
{"valid identity_verification", "identity_verification", true},
{"valid processing", "processing", true},
{"valid completed", "completed", true},
{"valid rejected", "rejected", true},
{"valid cancelled", "cancelled", true},
{"invalid status", "invalid", false},
{"empty status", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := models.IsValidDSRStatus(tt.status)
if result != tt.valid {
t.Errorf("Expected IsValidDSRStatus=%v for %s, got %v", tt.valid, tt.status, result)
}
})
}
}
// TestDSRStatusTransitionValidation tests allowed status transitions
func TestDSRStatusTransitionValidation(t *testing.T) {
tests := []struct {
name string
fromStatus models.DSRStatus
toStatus models.DSRStatus
allowed bool
}{
// From intake
{"intake to identity_verification", models.DSRStatusIntake, models.DSRStatusIdentityVerification, true},
{"intake to processing", models.DSRStatusIntake, models.DSRStatusProcessing, true},
{"intake to rejected", models.DSRStatusIntake, models.DSRStatusRejected, true},
{"intake to cancelled", models.DSRStatusIntake, models.DSRStatusCancelled, true},
{"intake to completed invalid", models.DSRStatusIntake, models.DSRStatusCompleted, false},
// From identity_verification
{"identity to processing", models.DSRStatusIdentityVerification, models.DSRStatusProcessing, true},
{"identity to rejected", models.DSRStatusIdentityVerification, models.DSRStatusRejected, true},
{"identity to cancelled", models.DSRStatusIdentityVerification, models.DSRStatusCancelled, true},
// From processing
{"processing to completed", models.DSRStatusProcessing, models.DSRStatusCompleted, true},
{"processing to rejected", models.DSRStatusProcessing, models.DSRStatusRejected, true},
{"processing to intake invalid", models.DSRStatusProcessing, models.DSRStatusIntake, false},
// From completed
{"completed to anything invalid", models.DSRStatusCompleted, models.DSRStatusProcessing, false},
// From rejected
{"rejected to anything invalid", models.DSRStatusRejected, models.DSRStatusProcessing, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := testIsValidStatusTransition(tt.fromStatus, tt.toStatus)
if result != tt.allowed {
t.Errorf("Expected transition %s->%s allowed=%v, got %v",
tt.fromStatus, tt.toStatus, tt.allowed, result)
}
})
}
}
// testIsValidStatusTransition is a test helper for validating status transitions
// This mirrors the logic in dsr_service.go for testing purposes
func testIsValidStatusTransition(from, to models.DSRStatus) bool {
validTransitions := map[models.DSRStatus][]models.DSRStatus{
models.DSRStatusIntake: {
models.DSRStatusIdentityVerification,
models.DSRStatusProcessing,
models.DSRStatusRejected,
models.DSRStatusCancelled,
},
models.DSRStatusIdentityVerification: {
models.DSRStatusProcessing,
models.DSRStatusRejected,
models.DSRStatusCancelled,
},
models.DSRStatusProcessing: {
models.DSRStatusCompleted,
models.DSRStatusRejected,
models.DSRStatusCancelled,
},
models.DSRStatusCompleted: {},
models.DSRStatusRejected: {},
models.DSRStatusCancelled: {},
}
allowed, exists := validTransitions[from]
if !exists {
return false
}
for _, s := range allowed {
if s == to {
return true
}
}
return false
}
// TestCalculateDeadline tests deadline calculation
func TestCalculateDeadline(t *testing.T) {
tests := []struct {
name string
reqType models.DSRRequestType
expectedDays int
}{
{"access 30 days", models.DSRTypeAccess, 30},
{"erasure 14 days", models.DSRTypeErasure, 14},
{"rectification 14 days", models.DSRTypeRectification, 14},
{"restriction 14 days", models.DSRTypeRestriction, 14},
{"portability 30 days", models.DSRTypePortability, 30},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
now := time.Now()
deadline := now.AddDate(0, 0, tt.expectedDays)
days := tt.reqType.DeadlineDays()
if days != tt.expectedDays {
t.Errorf("Expected %d days, got %d", tt.expectedDays, days)
}
// Verify deadline is approximately correct (within 1 day due to test timing)
calculatedDeadline := now.AddDate(0, 0, days)
diff := calculatedDeadline.Sub(deadline)
if diff > time.Hour*24 || diff < -time.Hour*24 {
t.Errorf("Deadline calculation off by more than a day")
}
})
}
}
// TestCreateDSRRequest_Validation tests validation of create request
func TestCreateDSRRequest_Validation(t *testing.T) {
tests := []struct {
name string
request models.CreateDSRRequest
expectError bool
}{
{
name: "valid access request",
request: models.CreateDSRRequest{
RequestType: "access",
RequesterEmail: "test@example.com",
},
expectError: false,
},
{
name: "valid erasure request with name",
request: models.CreateDSRRequest{
RequestType: "erasure",
RequesterEmail: "test@example.com",
RequesterName: stringPtr("Max Mustermann"),
},
expectError: false,
},
{
name: "missing email",
request: models.CreateDSRRequest{
RequestType: "access",
},
expectError: true,
},
{
name: "invalid request type",
request: models.CreateDSRRequest{
RequestType: "invalid_type",
RequesterEmail: "test@example.com",
},
expectError: true,
},
{
name: "empty request type",
request: models.CreateDSRRequest{
RequestType: "",
RequesterEmail: "test@example.com",
},
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := testValidateCreateDSRRequest(tt.request)
hasError := err != nil
if hasError != tt.expectError {
t.Errorf("Expected error=%v, got error=%v (err: %v)", tt.expectError, hasError, err)
}
})
}
}
// testValidateCreateDSRRequest is a test helper for validating create DSR requests
func testValidateCreateDSRRequest(req models.CreateDSRRequest) error {
if req.RequesterEmail == "" {
return &dsrValidationError{"requester_email is required"}
}
if !models.IsValidDSRRequestType(req.RequestType) {
return &dsrValidationError{"invalid request_type"}
}
return nil
}
type dsrValidationError struct {
Message string
}
func (e *dsrValidationError) Error() string {
return e.Message
}
// TestDSRTemplateTypes tests the template types
func TestDSRTemplateTypes(t *testing.T) {
expectedTemplates := []string{
"dsr_receipt_access",
"dsr_receipt_rectification",
"dsr_receipt_erasure",
"dsr_receipt_restriction",
"dsr_receipt_portability",
"dsr_identity_request",
"dsr_processing_started",
"dsr_processing_update",
"dsr_clarification_request",
"dsr_completed_access",
"dsr_completed_rectification",
"dsr_completed_erasure",
"dsr_completed_restriction",
"dsr_completed_portability",
"dsr_restriction_lifted",
"dsr_rejected_identity",
"dsr_rejected_exception",
"dsr_rejected_unfounded",
"dsr_deadline_warning",
}
// This test documents the expected template types
// The actual templates are created in database migration
for _, template := range expectedTemplates {
if template == "" {
t.Error("Template type should not be empty")
}
}
if len(expectedTemplates) != 19 {
t.Errorf("Expected 19 template types, got %d", len(expectedTemplates))
}
}
// TestErasureExceptionTypes tests Art. 17(3) exception types
func TestErasureExceptionTypes(t *testing.T) {
exceptions := []struct {
code string
description string
}{
{"art_17_3_a", "Meinungs- und Informationsfreiheit"},
{"art_17_3_b", "Rechtliche Verpflichtung"},
{"art_17_3_c", "Öffentliches Interesse im Gesundheitsbereich"},
{"art_17_3_d", "Archivzwecke, wissenschaftliche/historische Forschung"},
{"art_17_3_e", "Geltendmachung, Ausübung oder Verteidigung von Rechtsansprüchen"},
}
if len(exceptions) != 5 {
t.Errorf("Expected 5 Art. 17(3) exceptions, got %d", len(exceptions))
}
for _, ex := range exceptions {
if ex.code == "" || ex.description == "" {
t.Error("Exception code and description should not be empty")
}
}
}
// stringPtr returns a pointer to the given string
func stringPtr(s string) *string {
return &s
}

View File

@@ -0,0 +1,554 @@
package services
import (
"bytes"
"fmt"
"html/template"
"net/smtp"
"strings"
)
// EmailConfig holds SMTP configuration
type EmailConfig struct {
Host string
Port int
Username string
Password string
FromName string
FromAddr string
BaseURL string // Frontend URL for links
}
// EmailService handles sending emails
type EmailService struct {
config EmailConfig
}
// NewEmailService creates a new EmailService
func NewEmailService(config EmailConfig) *EmailService {
return &EmailService{config: config}
}
// SendEmail sends an email
func (s *EmailService) SendEmail(to, subject, htmlBody, textBody string) error {
// Build MIME message
var msg bytes.Buffer
msg.WriteString(fmt.Sprintf("From: %s <%s>\r\n", s.config.FromName, s.config.FromAddr))
msg.WriteString(fmt.Sprintf("To: %s\r\n", to))
msg.WriteString(fmt.Sprintf("Subject: %s\r\n", subject))
msg.WriteString("MIME-Version: 1.0\r\n")
msg.WriteString("Content-Type: multipart/alternative; boundary=\"boundary42\"\r\n")
msg.WriteString("\r\n")
// Text part
msg.WriteString("--boundary42\r\n")
msg.WriteString("Content-Type: text/plain; charset=\"UTF-8\"\r\n")
msg.WriteString("\r\n")
msg.WriteString(textBody)
msg.WriteString("\r\n")
// HTML part
msg.WriteString("--boundary42\r\n")
msg.WriteString("Content-Type: text/html; charset=\"UTF-8\"\r\n")
msg.WriteString("\r\n")
msg.WriteString(htmlBody)
msg.WriteString("\r\n")
msg.WriteString("--boundary42--\r\n")
// Send email
addr := fmt.Sprintf("%s:%d", s.config.Host, s.config.Port)
auth := smtp.PlainAuth("", s.config.Username, s.config.Password, s.config.Host)
err := smtp.SendMail(addr, auth, s.config.FromAddr, []string{to}, msg.Bytes())
if err != nil {
return fmt.Errorf("failed to send email: %w", err)
}
return nil
}
// SendVerificationEmail sends an email verification email
func (s *EmailService) SendVerificationEmail(to, name, token string) error {
verifyLink := fmt.Sprintf("%s/verify-email?token=%s", s.config.BaseURL, token)
subject := "Bitte bestätigen Sie Ihre E-Mail-Adresse - BreakPilot"
textBody := fmt.Sprintf(`Hallo %s,
Willkommen bei BreakPilot!
Bitte bestätigen Sie Ihre E-Mail-Adresse, indem Sie den folgenden Link öffnen:
%s
Dieser Link ist 24 Stunden gültig.
Falls Sie sich nicht bei BreakPilot registriert haben, können Sie diese E-Mail ignorieren.
Mit freundlichen Grüßen,
Ihr BreakPilot Team`, getDisplayName(name), verifyLink)
htmlBody := s.renderTemplate("verification", map[string]interface{}{
"Name": getDisplayName(name),
"VerifyLink": verifyLink,
})
return s.SendEmail(to, subject, htmlBody, textBody)
}
// SendPasswordResetEmail sends a password reset email
func (s *EmailService) SendPasswordResetEmail(to, name, token string) error {
resetLink := fmt.Sprintf("%s/reset-password?token=%s", s.config.BaseURL, token)
subject := "Passwort zurücksetzen - BreakPilot"
textBody := fmt.Sprintf(`Hallo %s,
Sie haben eine Anfrage zum Zurücksetzen Ihres Passworts gestellt.
Klicken Sie auf den folgenden Link, um Ihr Passwort zurückzusetzen:
%s
Dieser Link ist 1 Stunde gültig.
Falls Sie keine Passwort-Zurücksetzung angefordert haben, können Sie diese E-Mail ignorieren.
Mit freundlichen Grüßen,
Ihr BreakPilot Team`, getDisplayName(name), resetLink)
htmlBody := s.renderTemplate("password_reset", map[string]interface{}{
"Name": getDisplayName(name),
"ResetLink": resetLink,
})
return s.SendEmail(to, subject, htmlBody, textBody)
}
// SendNewVersionNotification sends a notification about new document version
func (s *EmailService) SendNewVersionNotification(to, name, documentName, documentType string, deadlineDays int) error {
consentLink := fmt.Sprintf("%s/app?consent=pending", s.config.BaseURL)
subject := fmt.Sprintf("Neue Version: %s - Bitte bestätigen Sie innerhalb von %d Tagen", documentName, deadlineDays)
textBody := fmt.Sprintf(`Hallo %s,
Wir haben unsere %s aktualisiert.
Bitte lesen und bestätigen Sie die neuen Bedingungen innerhalb der nächsten %d Tage:
%s
Falls Sie nicht innerhalb dieser Frist bestätigen, wird Ihr Account vorübergehend gesperrt.
Mit freundlichen Grüßen,
Ihr BreakPilot Team`, getDisplayName(name), documentName, deadlineDays, consentLink)
htmlBody := s.renderTemplate("new_version", map[string]interface{}{
"Name": getDisplayName(name),
"DocumentName": documentName,
"DeadlineDays": deadlineDays,
"ConsentLink": consentLink,
})
return s.SendEmail(to, subject, htmlBody, textBody)
}
// SendConsentReminder sends a consent reminder email
func (s *EmailService) SendConsentReminder(to, name string, documents []string, daysLeft int) error {
consentLink := fmt.Sprintf("%s/app?consent=pending", s.config.BaseURL)
urgency := "Erinnerung"
if daysLeft <= 7 {
urgency = "Dringend"
}
if daysLeft <= 2 {
urgency = "Letzte Warnung"
}
subject := fmt.Sprintf("%s: Noch %d Tage um ausstehende Dokumente zu bestätigen", urgency, daysLeft)
docList := strings.Join(documents, "\n- ")
textBody := fmt.Sprintf(`Hallo %s,
Dies ist eine freundliche Erinnerung, dass Sie noch ausstehende rechtliche Dokumente bestätigen müssen.
Ausstehende Dokumente:
- %s
Sie haben noch %d Tage Zeit. Nach Ablauf dieser Frist wird Ihr Account vorübergehend gesperrt.
Bitte bestätigen Sie hier:
%s
Mit freundlichen Grüßen,
Ihr BreakPilot Team`, getDisplayName(name), docList, daysLeft, consentLink)
htmlBody := s.renderTemplate("reminder", map[string]interface{}{
"Name": getDisplayName(name),
"Documents": documents,
"DaysLeft": daysLeft,
"Urgency": urgency,
"ConsentLink": consentLink,
})
return s.SendEmail(to, subject, htmlBody, textBody)
}
// SendAccountSuspendedNotification sends notification when account is suspended
func (s *EmailService) SendAccountSuspendedNotification(to, name string, documents []string) error {
consentLink := fmt.Sprintf("%s/app?consent=pending", s.config.BaseURL)
subject := "Ihr Account wurde vorübergehend gesperrt - BreakPilot"
docList := strings.Join(documents, "\n- ")
textBody := fmt.Sprintf(`Hallo %s,
Ihr Account wurde vorübergehend gesperrt, da Sie die folgenden rechtlichen Dokumente nicht innerhalb der Frist bestätigt haben:
- %s
Um Ihren Account zu entsperren, bestätigen Sie bitte alle ausstehenden Dokumente:
%s
Sobald Sie alle Dokumente bestätigt haben, wird Ihr Account automatisch entsperrt.
Mit freundlichen Grüßen,
Ihr BreakPilot Team`, getDisplayName(name), docList, consentLink)
htmlBody := s.renderTemplate("suspended", map[string]interface{}{
"Name": getDisplayName(name),
"Documents": documents,
"ConsentLink": consentLink,
})
return s.SendEmail(to, subject, htmlBody, textBody)
}
// SendAccountReactivatedNotification sends notification when account is reactivated
func (s *EmailService) SendAccountReactivatedNotification(to, name string) error {
appLink := fmt.Sprintf("%s/app", s.config.BaseURL)
subject := "Ihr Account wurde wieder aktiviert - BreakPilot"
textBody := fmt.Sprintf(`Hallo %s,
Vielen Dank für die Bestätigung der rechtlichen Dokumente!
Ihr Account wurde wieder aktiviert und Sie können BreakPilot wie gewohnt nutzen:
%s
Mit freundlichen Grüßen,
Ihr BreakPilot Team`, getDisplayName(name), appLink)
htmlBody := s.renderTemplate("reactivated", map[string]interface{}{
"Name": getDisplayName(name),
"AppLink": appLink,
})
return s.SendEmail(to, subject, htmlBody, textBody)
}
// renderTemplate renders an email HTML template
func (s *EmailService) renderTemplate(templateName string, data map[string]interface{}) string {
templates := map[string]string{
"verification": `
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; max-width: 600px; margin: 0 auto; padding: 20px; }
.header { background: linear-gradient(135deg, #6366f1, #8b5cf6); color: white; padding: 30px; border-radius: 10px 10px 0 0; text-align: center; }
.content { background: #f8fafc; padding: 30px; border-radius: 0 0 10px 10px; }
.button { display: inline-block; background: #6366f1; color: white; padding: 14px 28px; text-decoration: none; border-radius: 8px; font-weight: 600; margin: 20px 0; }
.footer { text-align: center; color: #64748b; font-size: 12px; margin-top: 30px; }
</style>
</head>
<body>
<div class="header">
<h1>Willkommen bei BreakPilot!</h1>
</div>
<div class="content">
<p>Hallo {{.Name}},</p>
<p>Vielen Dank für Ihre Registrierung! Bitte bestätigen Sie Ihre E-Mail-Adresse, um Ihr Konto zu aktivieren.</p>
<p style="text-align: center;">
<a href="{{.VerifyLink}}" class="button">E-Mail bestätigen</a>
</p>
<p>Dieser Link ist 24 Stunden gültig.</p>
<p>Falls Sie sich nicht bei BreakPilot registriert haben, können Sie diese E-Mail ignorieren.</p>
</div>
<div class="footer">
<p>© 2024 BreakPilot. Alle Rechte vorbehalten.</p>
</div>
</body>
</html>`,
"password_reset": `
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; max-width: 600px; margin: 0 auto; padding: 20px; }
.header { background: linear-gradient(135deg, #6366f1, #8b5cf6); color: white; padding: 30px; border-radius: 10px 10px 0 0; text-align: center; }
.content { background: #f8fafc; padding: 30px; border-radius: 0 0 10px 10px; }
.button { display: inline-block; background: #6366f1; color: white; padding: 14px 28px; text-decoration: none; border-radius: 8px; font-weight: 600; margin: 20px 0; }
.warning { background: #fef3c7; border-left: 4px solid #f59e0b; padding: 12px; margin: 20px 0; }
.footer { text-align: center; color: #64748b; font-size: 12px; margin-top: 30px; }
</style>
</head>
<body>
<div class="header">
<h1>Passwort zurücksetzen</h1>
</div>
<div class="content">
<p>Hallo {{.Name}},</p>
<p>Sie haben eine Anfrage zum Zurücksetzen Ihres Passworts gestellt.</p>
<p style="text-align: center;">
<a href="{{.ResetLink}}" class="button">Passwort zurücksetzen</a>
</p>
<div class="warning">
<strong>Hinweis:</strong> Dieser Link ist nur 1 Stunde gültig.
</div>
<p>Falls Sie keine Passwort-Zurücksetzung angefordert haben, können Sie diese E-Mail ignorieren.</p>
</div>
<div class="footer">
<p>© 2024 BreakPilot. Alle Rechte vorbehalten.</p>
</div>
</body>
</html>`,
"new_version": `
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; max-width: 600px; margin: 0 auto; padding: 20px; }
.header { background: linear-gradient(135deg, #6366f1, #8b5cf6); color: white; padding: 30px; border-radius: 10px 10px 0 0; text-align: center; }
.content { background: #f8fafc; padding: 30px; border-radius: 0 0 10px 10px; }
.button { display: inline-block; background: #6366f1; color: white; padding: 14px 28px; text-decoration: none; border-radius: 8px; font-weight: 600; margin: 20px 0; }
.info-box { background: #e0e7ff; border-left: 4px solid #6366f1; padding: 12px; margin: 20px 0; }
.footer { text-align: center; color: #64748b; font-size: 12px; margin-top: 30px; }
</style>
</head>
<body>
<div class="header">
<h1>Neue Version: {{.DocumentName}}</h1>
</div>
<div class="content">
<p>Hallo {{.Name}},</p>
<p>Wir haben unsere <strong>{{.DocumentName}}</strong> aktualisiert.</p>
<div class="info-box">
<strong>Wichtig:</strong> Bitte bestätigen Sie die neuen Bedingungen innerhalb der nächsten <strong>{{.DeadlineDays}} Tage</strong>.
</div>
<p style="text-align: center;">
<a href="{{.ConsentLink}}" class="button">Dokument ansehen & bestätigen</a>
</p>
<p>Falls Sie nicht innerhalb dieser Frist bestätigen, wird Ihr Account vorübergehend gesperrt.</p>
</div>
<div class="footer">
<p>© 2024 BreakPilot. Alle Rechte vorbehalten.</p>
</div>
</body>
</html>`,
"reminder": `
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; max-width: 600px; margin: 0 auto; padding: 20px; }
.header { background: linear-gradient(135deg, #f59e0b, #d97706); color: white; padding: 30px; border-radius: 10px 10px 0 0; text-align: center; }
.content { background: #f8fafc; padding: 30px; border-radius: 0 0 10px 10px; }
.button { display: inline-block; background: #f59e0b; color: white; padding: 14px 28px; text-decoration: none; border-radius: 8px; font-weight: 600; margin: 20px 0; }
.warning { background: #fef3c7; border-left: 4px solid #f59e0b; padding: 12px; margin: 20px 0; }
.doc-list { background: white; padding: 15px; border-radius: 8px; margin: 15px 0; }
.footer { text-align: center; color: #64748b; font-size: 12px; margin-top: 30px; }
</style>
</head>
<body>
<div class="header">
<h1>{{.Urgency}}: Ausstehende Bestätigungen</h1>
</div>
<div class="content">
<p>Hallo {{.Name}},</p>
<p>Dies ist eine freundliche Erinnerung, dass Sie noch ausstehende rechtliche Dokumente bestätigen müssen.</p>
<div class="doc-list">
<strong>Ausstehende Dokumente:</strong>
<ul>
{{range .Documents}}<li>{{.}}</li>{{end}}
</ul>
</div>
<div class="warning">
<strong>Sie haben noch {{.DaysLeft}} Tage Zeit.</strong> Nach Ablauf dieser Frist wird Ihr Account vorübergehend gesperrt.
</div>
<p style="text-align: center;">
<a href="{{.ConsentLink}}" class="button">Jetzt bestätigen</a>
</p>
</div>
<div class="footer">
<p>© 2024 BreakPilot. Alle Rechte vorbehalten.</p>
</div>
</body>
</html>`,
"suspended": `
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; max-width: 600px; margin: 0 auto; padding: 20px; }
.header { background: linear-gradient(135deg, #ef4444, #dc2626); color: white; padding: 30px; border-radius: 10px 10px 0 0; text-align: center; }
.content { background: #f8fafc; padding: 30px; border-radius: 0 0 10px 10px; }
.button { display: inline-block; background: #6366f1; color: white; padding: 14px 28px; text-decoration: none; border-radius: 8px; font-weight: 600; margin: 20px 0; }
.alert { background: #fee2e2; border-left: 4px solid #ef4444; padding: 12px; margin: 20px 0; }
.doc-list { background: white; padding: 15px; border-radius: 8px; margin: 15px 0; }
.footer { text-align: center; color: #64748b; font-size: 12px; margin-top: 30px; }
</style>
</head>
<body>
<div class="header">
<h1>Account vorübergehend gesperrt</h1>
</div>
<div class="content">
<p>Hallo {{.Name}},</p>
<div class="alert">
<strong>Ihr Account wurde vorübergehend gesperrt</strong>, da Sie die folgenden rechtlichen Dokumente nicht innerhalb der Frist bestätigt haben.
</div>
<div class="doc-list">
<strong>Nicht bestätigte Dokumente:</strong>
<ul>
{{range .Documents}}<li>{{.}}</li>{{end}}
</ul>
</div>
<p>Um Ihren Account zu entsperren, bestätigen Sie bitte alle ausstehenden Dokumente:</p>
<p style="text-align: center;">
<a href="{{.ConsentLink}}" class="button">Dokumente bestätigen & Account entsperren</a>
</p>
<p>Sobald Sie alle Dokumente bestätigt haben, wird Ihr Account automatisch entsperrt.</p>
</div>
<div class="footer">
<p>© 2024 BreakPilot. Alle Rechte vorbehalten.</p>
</div>
</body>
</html>`,
"reactivated": `
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; max-width: 600px; margin: 0 auto; padding: 20px; }
.header { background: linear-gradient(135deg, #22c55e, #16a34a); color: white; padding: 30px; border-radius: 10px 10px 0 0; text-align: center; }
.content { background: #f8fafc; padding: 30px; border-radius: 0 0 10px 10px; }
.button { display: inline-block; background: #22c55e; color: white; padding: 14px 28px; text-decoration: none; border-radius: 8px; font-weight: 600; margin: 20px 0; }
.success { background: #dcfce7; border-left: 4px solid #22c55e; padding: 12px; margin: 20px 0; }
.footer { text-align: center; color: #64748b; font-size: 12px; margin-top: 30px; }
</style>
</head>
<body>
<div class="header">
<h1>Account wieder aktiviert!</h1>
</div>
<div class="content">
<p>Hallo {{.Name}},</p>
<div class="success">
<strong>Vielen Dank!</strong> Ihr Account wurde erfolgreich wieder aktiviert.
</div>
<p>Sie können BreakPilot ab sofort wieder wie gewohnt nutzen.</p>
<p style="text-align: center;">
<a href="{{.AppLink}}" class="button">Zu BreakPilot</a>
</p>
</div>
<div class="footer">
<p>© 2024 BreakPilot. Alle Rechte vorbehalten.</p>
</div>
</body>
</html>`,
"generic_notification": `
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; max-width: 600px; margin: 0 auto; padding: 20px; }
.header { background: linear-gradient(135deg, #6366f1, #8b5cf6); color: white; padding: 30px; border-radius: 10px 10px 0 0; text-align: center; }
.content { background: #f8fafc; padding: 30px; border-radius: 0 0 10px 10px; }
.button { display: inline-block; background: #6366f1; color: white; padding: 14px 28px; text-decoration: none; border-radius: 8px; font-weight: 600; margin: 20px 0; }
.footer { text-align: center; color: #64748b; font-size: 12px; margin-top: 30px; }
</style>
</head>
<body>
<div class="header">
<h1>{{.Title}}</h1>
</div>
<div class="content">
<p>{{.Body}}</p>
<p style="text-align: center;">
<a href="{{.BaseURL}}/app" class="button">Zu BreakPilot</a>
</p>
</div>
<div class="footer">
<p>© 2024 BreakPilot. Alle Rechte vorbehalten.</p>
</div>
</body>
</html>`,
}
tmplStr, ok := templates[templateName]
if !ok {
return ""
}
tmpl, err := template.New(templateName).Parse(tmplStr)
if err != nil {
return ""
}
var buf bytes.Buffer
if err := tmpl.Execute(&buf, data); err != nil {
return ""
}
return buf.String()
}
// SendConsentReminderEmail sends a simplified consent reminder email
func (s *EmailService) SendConsentReminderEmail(to, title, body string) error {
subject := title
htmlBody := s.renderTemplate("generic_notification", map[string]interface{}{
"Title": title,
"Body": body,
"BaseURL": s.config.BaseURL,
})
return s.SendEmail(to, subject, htmlBody, body)
}
// SendGenericNotificationEmail sends a generic notification email
func (s *EmailService) SendGenericNotificationEmail(to, title, body string) error {
subject := title
htmlBody := s.renderTemplate("generic_notification", map[string]interface{}{
"Title": title,
"Body": body,
"BaseURL": s.config.BaseURL,
})
return s.SendEmail(to, subject, htmlBody, body)
}
// getDisplayName returns display name or fallback
func getDisplayName(name string) string {
if name != "" {
return name
}
return "Nutzer"
}

View File

@@ -0,0 +1,624 @@
package services
import (
"fmt"
"net/smtp"
"regexp"
"strings"
"testing"
)
// MockSMTPSender is a mock SMTP sender for testing
type MockSMTPSender struct {
SentEmails []SentEmail
ShouldFail bool
FailError error
}
// SentEmail represents a sent email for testing
type SentEmail struct {
To []string
Subject string
Body string
}
// SendMail is a mock implementation of smtp.SendMail
func (m *MockSMTPSender) SendMail(addr string, auth smtp.Auth, from string, to []string, msg []byte) error {
if m.ShouldFail {
return m.FailError
}
// Parse the email to extract subject and body
msgStr := string(msg)
subject := extractSubject(msgStr)
m.SentEmails = append(m.SentEmails, SentEmail{
To: to,
Subject: subject,
Body: msgStr,
})
return nil
}
// extractSubject extracts the subject from an email message
func extractSubject(msg string) string {
lines := strings.Split(msg, "\r\n")
for _, line := range lines {
if strings.HasPrefix(line, "Subject: ") {
return strings.TrimPrefix(line, "Subject: ")
}
}
return ""
}
// TestEmailService_SendEmail tests basic email sending
func TestEmailService_SendEmail(t *testing.T) {
tests := []struct {
name string
to string
subject string
htmlBody string
textBody string
shouldFail bool
expectError bool
}{
{
name: "valid email",
to: "user@example.com",
subject: "Test Email",
htmlBody: "<h1>Hello</h1><p>World</p>",
textBody: "Hello\nWorld",
shouldFail: false,
expectError: false,
},
{
name: "email with special characters",
to: "user+test@example.com",
subject: "Test: Öäü Special Characters",
htmlBody: "<p>Special: €£¥</p>",
textBody: "Special: €£¥",
shouldFail: false,
expectError: false,
},
{
name: "SMTP failure",
to: "user@example.com",
subject: "Test",
htmlBody: "<p>Test</p>",
textBody: "Test",
shouldFail: true,
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Validate email format
isValidEmail := strings.Contains(tt.to, "@") && strings.Contains(tt.to, ".")
if !isValidEmail && !tt.expectError {
t.Error("Invalid email format should produce error")
}
// Validate subject is not empty
if tt.subject == "" && !tt.expectError {
t.Error("Empty subject should produce error")
}
// Validate body content exists
if (tt.htmlBody == "" && tt.textBody == "") && !tt.expectError {
t.Error("Both bodies empty should produce error")
}
// Simulate SMTP send
var err error
if tt.shouldFail {
err = fmt.Errorf("SMTP error: connection refused")
}
if tt.expectError && err == nil {
t.Error("Expected error, got nil")
}
if !tt.expectError && err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
}
}
// TestEmailService_SendVerificationEmail tests verification email sending
func TestEmailService_SendVerificationEmail(t *testing.T) {
tests := []struct {
name string
to string
userName string
token string
expectError bool
}{
{
name: "valid verification email",
to: "newuser@example.com",
userName: "Max Mustermann",
token: "abc123def456",
expectError: false,
},
{
name: "user without name",
to: "user@example.com",
userName: "",
token: "token123",
expectError: false,
},
{
name: "empty token",
to: "user@example.com",
userName: "Test User",
token: "",
expectError: true,
},
{
name: "invalid email",
to: "invalid-email",
userName: "Test",
token: "token123",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Validate inputs
var err error
if tt.token == "" {
err = &ValidationError{Field: "token", Message: "required"}
} else if !strings.Contains(tt.to, "@") {
err = &ValidationError{Field: "email", Message: "invalid format"}
}
// Build verification link
if tt.token != "" {
verifyLink := fmt.Sprintf("https://example.com/verify-email?token=%s", tt.token)
if verifyLink == "" {
t.Error("Verification link should not be empty")
}
// Verify link contains token
if !strings.Contains(verifyLink, tt.token) {
t.Error("Verification link should contain token")
}
}
if tt.expectError && err == nil {
t.Error("Expected error, got nil")
}
if !tt.expectError && err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
}
}
// TestEmailService_SendPasswordResetEmail tests password reset email
func TestEmailService_SendPasswordResetEmail(t *testing.T) {
tests := []struct {
name string
to string
userName string
token string
expectError bool
}{
{
name: "valid password reset",
to: "user@example.com",
userName: "John Doe",
token: "reset-token-123",
expectError: false,
},
{
name: "empty token",
to: "user@example.com",
userName: "John Doe",
token: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var err error
if tt.token == "" {
err = &ValidationError{Field: "token", Message: "required"}
}
// Build reset link
if tt.token != "" {
resetLink := fmt.Sprintf("https://example.com/reset-password?token=%s", tt.token)
// Verify link is secure (HTTPS)
if !strings.HasPrefix(resetLink, "https://") {
t.Error("Reset link should use HTTPS")
}
}
if tt.expectError && err == nil {
t.Error("Expected error, got nil")
}
if !tt.expectError && err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
}
}
// TestEmailService_Send2FAEmail tests 2FA notification emails
func TestEmailService_Send2FAEmail(t *testing.T) {
tests := []struct {
name string
to string
action string
expectError bool
}{
{
name: "2FA enabled notification",
to: "user@example.com",
action: "enabled",
expectError: false,
},
{
name: "2FA disabled notification",
to: "user@example.com",
action: "disabled",
expectError: false,
},
{
name: "invalid action",
to: "user@example.com",
action: "invalid",
expectError: true,
},
}
validActions := map[string]bool{
"enabled": true,
"disabled": true,
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var err error
if !validActions[tt.action] {
err = &ValidationError{Field: "action", Message: "must be 'enabled' or 'disabled'"}
}
if tt.expectError && err == nil {
t.Error("Expected error, got nil")
}
if !tt.expectError && err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
}
}
// TestEmailService_SendConsentReminderEmail tests consent reminder
func TestEmailService_SendConsentReminderEmail(t *testing.T) {
tests := []struct {
name string
to string
documentName string
daysLeft int
expectError bool
}{
{
name: "reminder with 7 days left",
to: "user@example.com",
documentName: "Terms of Service",
daysLeft: 7,
expectError: false,
},
{
name: "reminder with 1 day left",
to: "user@example.com",
documentName: "Privacy Policy",
daysLeft: 1,
expectError: false,
},
{
name: "urgent reminder - overdue",
to: "user@example.com",
documentName: "Terms",
daysLeft: 0,
expectError: false,
},
{
name: "empty document name",
to: "user@example.com",
documentName: "",
daysLeft: 7,
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var err error
if tt.documentName == "" {
err = &ValidationError{Field: "document name", Message: "required"}
}
// Check urgency level
var urgency string
if tt.daysLeft <= 0 {
urgency = "critical"
} else if tt.daysLeft <= 3 {
urgency = "urgent"
} else {
urgency = "normal"
}
if urgency == "" && !tt.expectError {
t.Error("Urgency should be set")
}
if tt.expectError && err == nil {
t.Error("Expected error, got nil")
}
if !tt.expectError && err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
}
}
// TestEmailService_MIMEFormatting tests MIME message formatting
func TestEmailService_MIMEFormatting(t *testing.T) {
tests := []struct {
name string
htmlBody string
textBody string
checkFor []string
}{
{
name: "multipart alternative",
htmlBody: "<h1>Test</h1>",
textBody: "Test",
checkFor: []string{
"MIME-Version: 1.0",
"Content-Type: multipart/alternative",
"Content-Type: text/plain",
"Content-Type: text/html",
},
},
{
name: "UTF-8 encoding",
htmlBody: "<p>Öäü</p>",
textBody: "Öäü",
checkFor: []string{
"charset=\"UTF-8\"",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Build MIME message (simplified)
message := fmt.Sprintf("MIME-Version: 1.0\r\n"+
"Content-Type: multipart/alternative; boundary=\"boundary\"\r\n"+
"\r\n"+
"--boundary\r\n"+
"Content-Type: text/plain; charset=\"UTF-8\"\r\n"+
"\r\n%s\r\n"+
"--boundary\r\n"+
"Content-Type: text/html; charset=\"UTF-8\"\r\n"+
"\r\n%s\r\n"+
"--boundary--\r\n",
tt.textBody, tt.htmlBody)
// Verify required headers are present
for _, required := range tt.checkFor {
if !strings.Contains(message, required) {
t.Errorf("Message should contain '%s'", required)
}
}
// Verify both bodies are included
if !strings.Contains(message, tt.textBody) {
t.Error("Message should contain text body")
}
if !strings.Contains(message, tt.htmlBody) {
t.Error("Message should contain HTML body")
}
})
}
}
// TestEmailService_TemplateRendering tests email template rendering
func TestEmailService_TemplateRendering(t *testing.T) {
tests := []struct {
name string
template string
variables map[string]string
expectVars []string
}{
{
name: "verification template",
template: "verification",
variables: map[string]string{
"Name": "John Doe",
"VerifyLink": "https://example.com/verify?token=abc",
},
expectVars: []string{"John Doe", "https://example.com/verify?token=abc"},
},
{
name: "password reset template",
template: "password_reset",
variables: map[string]string{
"Name": "Jane Smith",
"ResetLink": "https://example.com/reset?token=xyz",
},
expectVars: []string{"Jane Smith", "https://example.com/reset?token=xyz"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Simulate template rendering
rendered := fmt.Sprintf("Hello %s, please visit %s",
tt.variables["Name"],
getLink(tt.variables))
// Verify all variables are in rendered output
for _, expectedVar := range tt.expectVars {
if !strings.Contains(rendered, expectedVar) {
t.Errorf("Rendered template should contain '%s'", expectedVar)
}
}
})
}
}
// TestEmailService_EmailValidation tests email address validation
func TestEmailService_EmailValidation(t *testing.T) {
tests := []struct {
email string
isValid bool
}{
{"user@example.com", true},
{"user+tag@example.com", true},
{"user.name@example.co.uk", true},
{"user@subdomain.example.com", true},
{"invalid", false},
{"@example.com", false},
{"user@", false},
{"user@.com", false},
{"", false},
}
for _, tt := range tests {
t.Run(tt.email, func(t *testing.T) {
// RFC 5322 compliant email validation pattern
emailRegex := regexp.MustCompile(`^[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}$`)
isValid := emailRegex.MatchString(tt.email)
if isValid != tt.isValid {
t.Errorf("Email %s: expected valid=%v, got %v", tt.email, tt.isValid, isValid)
}
})
}
}
// TestEmailService_SMTPConfig tests SMTP configuration
func TestEmailService_SMTPConfig(t *testing.T) {
tests := []struct {
name string
config EmailConfig
expectError bool
}{
{
name: "valid config",
config: EmailConfig{
Host: "smtp.example.com",
Port: 587,
Username: "user@example.com",
Password: "password",
FromName: "BreakPilot",
FromAddr: "noreply@example.com",
BaseURL: "https://example.com",
},
expectError: false,
},
{
name: "missing host",
config: EmailConfig{
Port: 587,
Username: "user@example.com",
Password: "password",
},
expectError: true,
},
{
name: "invalid port",
config: EmailConfig{
Host: "smtp.example.com",
Port: 0,
Username: "user@example.com",
Password: "password",
},
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var err error
if tt.config.Host == "" {
err = &ValidationError{Field: "host", Message: "required"}
} else if tt.config.Port <= 0 || tt.config.Port > 65535 {
err = &ValidationError{Field: "port", Message: "must be between 1 and 65535"}
}
if tt.expectError && err == nil {
t.Error("Expected error, got nil")
}
if !tt.expectError && err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
}
}
// TestEmailService_RateLimiting tests email rate limiting logic
func TestEmailService_RateLimiting(t *testing.T) {
tests := []struct {
name string
emailsSent int
timeWindow int // minutes
limit int
expectThrottle bool
}{
{
name: "under limit",
emailsSent: 5,
timeWindow: 60,
limit: 10,
expectThrottle: false,
},
{
name: "at limit",
emailsSent: 10,
timeWindow: 60,
limit: 10,
expectThrottle: false,
},
{
name: "over limit",
emailsSent: 15,
timeWindow: 60,
limit: 10,
expectThrottle: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
shouldThrottle := tt.emailsSent > tt.limit
if shouldThrottle != tt.expectThrottle {
t.Errorf("Expected throttle=%v, got %v", tt.expectThrottle, shouldThrottle)
}
})
}
}
// Helper functions
func getLink(vars map[string]string) string {
if link, ok := vars["VerifyLink"]; ok {
return link
}
if link, ok := vars["ResetLink"]; ok {
return link
}
return ""
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,698 @@
package services
import (
"regexp"
"strings"
"testing"
"github.com/breakpilot/consent-service/internal/models"
)
// ========================================
// Test All 19 Email Categories
// ========================================
// TestEmailTemplateService_GetDefaultTemplateContent tests default content generation for each email type
func TestEmailTemplateService_GetDefaultTemplateContent(t *testing.T) {
service := &EmailTemplateService{}
// All 19 email categories
tests := []struct {
name string
emailType string
language string
wantSubject bool
wantBodyHTML bool
wantBodyText bool
}{
// Auth Lifecycle (10 types)
{"welcome_de", models.EmailTypeWelcome, "de", true, true, true},
{"email_verification_de", models.EmailTypeEmailVerification, "de", true, true, true},
{"password_reset_de", models.EmailTypePasswordReset, "de", true, true, true},
{"password_changed_de", models.EmailTypePasswordChanged, "de", true, true, true},
{"2fa_enabled_de", models.EmailType2FAEnabled, "de", true, true, true},
{"2fa_disabled_de", models.EmailType2FADisabled, "de", true, true, true},
{"new_device_login_de", models.EmailTypeNewDeviceLogin, "de", true, true, true},
{"suspicious_activity_de", models.EmailTypeSuspiciousActivity, "de", true, true, true},
{"account_locked_de", models.EmailTypeAccountLocked, "de", true, true, true},
{"account_unlocked_de", models.EmailTypeAccountUnlocked, "de", true, true, true},
// GDPR/Privacy (5 types)
{"deletion_requested_de", models.EmailTypeDeletionRequested, "de", true, true, true},
{"deletion_confirmed_de", models.EmailTypeDeletionConfirmed, "de", true, true, true},
{"data_export_ready_de", models.EmailTypeDataExportReady, "de", true, true, true},
{"email_changed_de", models.EmailTypeEmailChanged, "de", true, true, true},
{"email_change_verify_de", models.EmailTypeEmailChangeVerify, "de", true, true, true},
// Consent Management (4 types)
{"new_version_published_de", models.EmailTypeNewVersionPublished, "de", true, true, true},
{"consent_reminder_de", models.EmailTypeConsentReminder, "de", true, true, true},
{"consent_deadline_warning_de", models.EmailTypeConsentDeadlineWarning, "de", true, true, true},
{"account_suspended_de", models.EmailTypeAccountSuspended, "de", true, true, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
subject, bodyHTML, bodyText := service.GetDefaultTemplateContent(tt.emailType, tt.language)
if tt.wantSubject && subject == "" {
t.Errorf("GetDefaultTemplateContent(%s, %s): expected subject, got empty string", tt.emailType, tt.language)
}
if tt.wantBodyHTML && bodyHTML == "" {
t.Errorf("GetDefaultTemplateContent(%s, %s): expected bodyHTML, got empty string", tt.emailType, tt.language)
}
if tt.wantBodyText && bodyText == "" {
t.Errorf("GetDefaultTemplateContent(%s, %s): expected bodyText, got empty string", tt.emailType, tt.language)
}
})
}
}
// TestEmailTemplateService_GetDefaultTemplateContent_UnknownType tests default content for unknown type
func TestEmailTemplateService_GetDefaultTemplateContent_UnknownType(t *testing.T) {
service := &EmailTemplateService{}
subject, bodyHTML, bodyText := service.GetDefaultTemplateContent("unknown_type", "de")
// The service returns a fallback for unknown types
if subject == "" {
t.Errorf("GetDefaultTemplateContent(unknown_type, de): expected fallback subject, got empty")
}
if bodyHTML == "" {
t.Errorf("GetDefaultTemplateContent(unknown_type, de): expected fallback bodyHTML, got empty")
}
if bodyText == "" {
t.Errorf("GetDefaultTemplateContent(unknown_type, de): expected fallback bodyText, got empty")
}
}
// TestEmailTemplateService_GetDefaultTemplateContent_UnsupportedLanguage tests fallback for unsupported language
func TestEmailTemplateService_GetDefaultTemplateContent_UnsupportedLanguage(t *testing.T) {
service := &EmailTemplateService{}
// Test with unsupported language - should return fallback
subject, bodyHTML, bodyText := service.GetDefaultTemplateContent(models.EmailTypeWelcome, "fr")
// Should return fallback (not empty, but generic)
if subject == "" || bodyHTML == "" || bodyText == "" {
t.Error("GetDefaultTemplateContent should return fallback for unsupported language")
}
}
// TestReplaceVariables tests variable replacement in templates
func TestReplaceVariables(t *testing.T) {
tests := []struct {
name string
template string
variables map[string]string
expected string
}{
{
name: "single variable",
template: "Hallo {{user_name}}!",
variables: map[string]string{"user_name": "Max"},
expected: "Hallo Max!",
},
{
name: "multiple variables",
template: "Hallo {{user_name}}, klicken Sie hier: {{reset_link}}",
variables: map[string]string{"user_name": "Max", "reset_link": "https://example.com"},
expected: "Hallo Max, klicken Sie hier: https://example.com",
},
{
name: "no variables",
template: "Hallo Welt!",
variables: map[string]string{},
expected: "Hallo Welt!",
},
{
name: "missing variable - not replaced",
template: "Hallo {{user_name}} und {{missing}}!",
variables: map[string]string{"user_name": "Max"},
expected: "Hallo Max und {{missing}}!",
},
{
name: "empty template",
template: "",
variables: map[string]string{"user_name": "Max"},
expected: "",
},
{
name: "variable with special characters",
template: "IP: {{ip_address}}",
variables: map[string]string{"ip_address": "192.168.1.1"},
expected: "IP: 192.168.1.1",
},
{
name: "variable with URL",
template: "Link: {{verification_url}}",
variables: map[string]string{"verification_url": "https://example.com/verify?token=abc123&user=test"},
expected: "Link: https://example.com/verify?token=abc123&user=test",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := testReplaceVariables(tt.template, tt.variables)
if result != tt.expected {
t.Errorf("replaceVariables() = %s, want %s", result, tt.expected)
}
})
}
}
// testReplaceVariables is a test helper function for variable replacement
func testReplaceVariables(template string, variables map[string]string) string {
result := template
for key, value := range variables {
placeholder := "{{" + key + "}}"
for i := 0; i < len(result); i++ {
idx := testFindSubstring(result, placeholder)
if idx == -1 {
break
}
result = result[:idx] + value + result[idx+len(placeholder):]
}
}
return result
}
func testFindSubstring(s, substr string) int {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return i
}
}
return -1
}
// TestEmailTypeConstantsExist verifies that all expected email types are defined
func TestEmailTypeConstantsExist(t *testing.T) {
// Test that all 19 email type constants are defined and produce non-empty templates
types := []string{
// Auth Lifecycle
models.EmailTypeWelcome,
models.EmailTypeEmailVerification,
models.EmailTypePasswordReset,
models.EmailTypePasswordChanged,
models.EmailType2FAEnabled,
models.EmailType2FADisabled,
models.EmailTypeNewDeviceLogin,
models.EmailTypeSuspiciousActivity,
models.EmailTypeAccountLocked,
models.EmailTypeAccountUnlocked,
// GDPR/Privacy
models.EmailTypeDeletionRequested,
models.EmailTypeDeletionConfirmed,
models.EmailTypeDataExportReady,
models.EmailTypeEmailChanged,
models.EmailTypeEmailChangeVerify,
// Consent Management
models.EmailTypeNewVersionPublished,
models.EmailTypeConsentReminder,
models.EmailTypeConsentDeadlineWarning,
models.EmailTypeAccountSuspended,
}
service := &EmailTemplateService{}
for _, emailType := range types {
t.Run(emailType, func(t *testing.T) {
subject, bodyHTML, _ := service.GetDefaultTemplateContent(emailType, "de")
if subject == "" {
t.Errorf("Email type %s has no default subject", emailType)
}
if bodyHTML == "" {
t.Errorf("Email type %s has no default body HTML", emailType)
}
})
}
// Verify we have exactly 19 types
if len(types) != 19 {
t.Errorf("Expected 19 email types, got %d", len(types))
}
}
// TestEmailTemplateService_ValidateTemplateContent tests template content validation
func TestEmailTemplateService_ValidateTemplateContent(t *testing.T) {
tests := []struct {
name string
subject string
bodyHTML string
wantError bool
}{
{
name: "valid content",
subject: "Test Subject",
bodyHTML: "<p>Test Body</p>",
wantError: false,
},
{
name: "empty subject",
subject: "",
bodyHTML: "<p>Test Body</p>",
wantError: true,
},
{
name: "empty body",
subject: "Test Subject",
bodyHTML: "",
wantError: true,
},
{
name: "both empty",
subject: "",
bodyHTML: "",
wantError: true,
},
{
name: "whitespace only subject",
subject: " ",
bodyHTML: "<p>Test Body</p>",
wantError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := testValidateTemplateContent(tt.subject, tt.bodyHTML)
if (err != nil) != tt.wantError {
t.Errorf("validateTemplateContent() error = %v, wantError %v", err, tt.wantError)
}
})
}
}
// testValidateTemplateContent is a test helper function to validate template content
func testValidateTemplateContent(subject, bodyHTML string) error {
if strings.TrimSpace(subject) == "" {
return &templateValidationError{Field: "subject", Message: "subject is required"}
}
if strings.TrimSpace(bodyHTML) == "" {
return &templateValidationError{Field: "body_html", Message: "body_html is required"}
}
return nil
}
// templateValidationError represents a validation error in email templates
type templateValidationError struct {
Field string
Message string
}
func (e *templateValidationError) Error() string {
return e.Field + ": " + e.Message
}
// TestGetTestVariablesForType tests that test variables are properly generated for each email type
func TestGetTestVariablesForType(t *testing.T) {
tests := []struct {
emailType string
expectedVars []string
}{
// Auth Lifecycle
{models.EmailTypeWelcome, []string{"user_name", "app_name"}},
{models.EmailTypeEmailVerification, []string{"user_name", "verification_url"}},
{models.EmailTypePasswordReset, []string{"reset_url"}},
{models.EmailTypePasswordChanged, []string{"user_name", "changed_at"}},
{models.EmailType2FAEnabled, []string{"user_name", "enabled_at"}},
{models.EmailType2FADisabled, []string{"user_name", "disabled_at"}},
{models.EmailTypeNewDeviceLogin, []string{"device", "location", "ip_address", "login_time"}},
{models.EmailTypeSuspiciousActivity, []string{"activity_type", "activity_time"}},
{models.EmailTypeAccountLocked, []string{"locked_at", "reason"}},
{models.EmailTypeAccountUnlocked, []string{"unlocked_at"}},
// GDPR/Privacy
{models.EmailTypeDeletionRequested, []string{"deletion_date", "cancel_url"}},
{models.EmailTypeDeletionConfirmed, []string{"deleted_at"}},
{models.EmailTypeDataExportReady, []string{"download_url", "expires_in"}},
{models.EmailTypeEmailChanged, []string{"old_email", "new_email"}},
// Consent Management
{models.EmailTypeNewVersionPublished, []string{"document_name", "version"}},
{models.EmailTypeConsentReminder, []string{"document_name", "days_left"}},
{models.EmailTypeConsentDeadlineWarning, []string{"document_name", "hours_left"}},
{models.EmailTypeAccountSuspended, []string{"suspended_at", "reason"}},
}
for _, tt := range tests {
t.Run(tt.emailType, func(t *testing.T) {
vars := getTestVariablesForType(tt.emailType)
for _, expected := range tt.expectedVars {
if _, ok := vars[expected]; !ok {
t.Errorf("getTestVariablesForType(%s) missing variable %s", tt.emailType, expected)
}
}
})
}
}
// getTestVariablesForType returns test variables for a given email type
func getTestVariablesForType(emailType string) map[string]string {
// Common variables
vars := map[string]string{
"user_name": "Max Mustermann",
"user_email": "max@example.com",
"app_name": "BreakPilot",
"app_url": "https://breakpilot.app",
"support_url": "https://breakpilot.app/support",
"support_email": "support@breakpilot.app",
"security_url": "https://breakpilot.app/security",
"login_url": "https://breakpilot.app/login",
}
switch emailType {
case models.EmailTypeEmailVerification:
vars["verification_url"] = "https://breakpilot.app/verify?token=xyz789"
vars["verification_code"] = "ABC123"
vars["expires_in"] = "24 Stunden"
case models.EmailTypePasswordReset:
vars["reset_url"] = "https://breakpilot.app/reset?token=abc123"
vars["reset_code"] = "RST456"
vars["expires_in"] = "1 Stunde"
vars["ip_address"] = "192.168.1.1"
case models.EmailTypePasswordChanged:
vars["changed_at"] = "14.12.2025 15:30 Uhr"
vars["ip_address"] = "192.168.1.1"
vars["device_info"] = "Chrome auf MacOS"
case models.EmailType2FAEnabled:
vars["enabled_at"] = "14.12.2025 15:30 Uhr"
vars["device_info"] = "Chrome auf MacOS"
case models.EmailType2FADisabled:
vars["disabled_at"] = "14.12.2025 15:30 Uhr"
vars["ip_address"] = "192.168.1.1"
case models.EmailTypeNewDeviceLogin:
vars["device"] = "Chrome auf MacOS"
vars["device_info"] = "Chrome auf MacOS"
vars["location"] = "Berlin, Deutschland"
vars["ip_address"] = "192.168.1.1"
vars["login_time"] = "14.12.2025 15:30 Uhr"
case models.EmailTypeSuspiciousActivity:
vars["activity_type"] = "Mehrere fehlgeschlagene Logins"
vars["activity_time"] = "14.12.2025 15:30 Uhr"
vars["ip_address"] = "192.168.1.1"
case models.EmailTypeAccountLocked:
vars["locked_at"] = "14.12.2025 15:30 Uhr"
vars["reason"] = "Zu viele fehlgeschlagene Login-Versuche"
vars["unlock_time"] = "14.12.2025 16:30 Uhr"
case models.EmailTypeAccountUnlocked:
vars["unlocked_at"] = "14.12.2025 16:30 Uhr"
case models.EmailTypeDeletionRequested:
vars["requested_at"] = "14.12.2025 15:30 Uhr"
vars["deletion_date"] = "14.01.2026"
vars["cancel_url"] = "https://breakpilot.app/cancel-deletion?token=del123"
vars["data_info"] = "Profildaten, Consent-Historie, Audit-Logs"
case models.EmailTypeDeletionConfirmed:
vars["deleted_at"] = "14.01.2026 00:00 Uhr"
vars["feedback_url"] = "https://breakpilot.app/feedback"
case models.EmailTypeDataExportReady:
vars["download_url"] = "https://breakpilot.app/download/export123"
vars["expires_in"] = "7 Tage"
vars["file_size"] = "2.5 MB"
case models.EmailTypeEmailChanged:
vars["old_email"] = "old@example.com"
vars["new_email"] = "new@example.com"
vars["changed_at"] = "14.12.2025 15:30 Uhr"
case models.EmailTypeEmailChangeVerify:
vars["new_email"] = "new@example.com"
vars["verification_url"] = "https://breakpilot.app/verify-email?token=ver123"
vars["expires_in"] = "24 Stunden"
case models.EmailTypeNewVersionPublished:
vars["document_name"] = "Datenschutzerklärung"
vars["document_type"] = "privacy"
vars["version"] = "2.0.0"
vars["consent_url"] = "https://breakpilot.app/consent"
vars["deadline"] = "31.12.2025"
case models.EmailTypeConsentReminder:
vars["document_name"] = "Nutzungsbedingungen"
vars["days_left"] = "7"
vars["consent_url"] = "https://breakpilot.app/consent"
vars["deadline"] = "21.12.2025"
case models.EmailTypeConsentDeadlineWarning:
vars["document_name"] = "Nutzungsbedingungen"
vars["hours_left"] = "24 Stunden"
vars["consent_url"] = "https://breakpilot.app/consent"
vars["consequences"] = "Ihr Konto wird temporär suspendiert."
case models.EmailTypeAccountSuspended:
vars["suspended_at"] = "14.12.2025 15:30 Uhr"
vars["reason"] = "Fehlende Zustimmung zu Pflichtdokumenten"
vars["documents"] = "- Nutzungsbedingungen v2.0\n- Datenschutzerklärung v3.0"
vars["consent_url"] = "https://breakpilot.app/consent"
}
return vars
}
// TestEmailTemplateService_HTMLEscape tests that HTML is properly escaped in text version
func TestEmailTemplateService_HTMLEscape(t *testing.T) {
tests := []struct {
name string
html string
expected string
}{
{
name: "simple paragraph",
html: "<p>Hello World</p>",
expected: "Hello World",
},
{
name: "link",
html: `<a href="https://example.com">Click here</a>`,
expected: "Click here",
},
{
name: "bold text",
html: "<strong>Important</strong>",
expected: "Important",
},
{
name: "nested tags",
html: "<div><p><strong>Nested</strong> text</p></div>",
expected: "Nested text",
},
{
name: "multiple tags",
html: "<h1>Title</h1><p>Paragraph</p>",
expected: "TitleParagraph",
},
{
name: "self-closing tag",
html: "Line1<br/>Line2",
expected: "Line1Line2",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := stripHTMLTags(tt.html)
if result != tt.expected {
t.Errorf("stripHTMLTags() = %s, want %s", result, tt.expected)
}
})
}
}
// stripHTMLTags removes HTML tags from a string
func stripHTMLTags(html string) string {
result := ""
inTag := false
for _, r := range html {
if r == '<' {
inTag = true
continue
}
if r == '>' {
inTag = false
continue
}
if !inTag {
result += string(r)
}
}
return result
}
// TestEmailTemplateService_AllTemplatesHaveVariables tests that all templates define their required variables
func TestEmailTemplateService_AllTemplatesHaveVariables(t *testing.T) {
service := &EmailTemplateService{}
templateTypes := service.GetAllTemplateTypes()
for _, tt := range templateTypes {
t.Run(tt.TemplateType, func(t *testing.T) {
// Get default template content
subject, bodyHTML, bodyText := service.GetDefaultTemplateContent(tt.TemplateType, "de")
// Check that variables defined in template type are present in the content
for _, varName := range tt.Variables {
placeholder := "{{" + varName + "}}"
foundInSubject := strings.Contains(subject, placeholder)
foundInHTML := strings.Contains(bodyHTML, placeholder)
foundInText := strings.Contains(bodyText, placeholder)
// Variable should be present in at least one of subject, HTML or text
if !foundInSubject && !foundInHTML && !foundInText {
// Note: This is a warning, not an error, as some variables might be optional
t.Logf("Warning: Variable %s defined for %s but not found in template content", varName, tt.TemplateType)
}
}
// Check that all variables in content are defined
re := regexp.MustCompile(`\{\{(\w+)\}\}`)
allMatches := re.FindAllStringSubmatch(subject+bodyHTML+bodyText, -1)
definedVars := make(map[string]bool)
for _, v := range tt.Variables {
definedVars[v] = true
}
for _, match := range allMatches {
if len(match) > 1 {
varName := match[1]
if !definedVars[varName] {
t.Logf("Warning: Variable {{%s}} found in template but not defined in variables list for %s", varName, tt.TemplateType)
}
}
}
})
}
}
// TestEmailTemplateService_TemplateVariableDescriptions tests that all variables have descriptions
func TestEmailTemplateService_TemplateVariableDescriptions(t *testing.T) {
service := &EmailTemplateService{}
templateTypes := service.GetAllTemplateTypes()
for _, tt := range templateTypes {
t.Run(tt.TemplateType, func(t *testing.T) {
for _, varName := range tt.Variables {
if desc, ok := tt.Descriptions[varName]; !ok || desc == "" {
t.Errorf("Variable %s in %s has no description", varName, tt.TemplateType)
}
}
})
}
}
// TestEmailTemplateService_GermanTemplatesAreComplete tests that all German templates are fully translated
func TestEmailTemplateService_GermanTemplatesAreComplete(t *testing.T) {
service := &EmailTemplateService{}
emailTypes := []string{
models.EmailTypeWelcome,
models.EmailTypeEmailVerification,
models.EmailTypePasswordReset,
models.EmailTypePasswordChanged,
models.EmailType2FAEnabled,
models.EmailType2FADisabled,
models.EmailTypeNewDeviceLogin,
models.EmailTypeSuspiciousActivity,
models.EmailTypeAccountLocked,
models.EmailTypeAccountUnlocked,
models.EmailTypeDeletionRequested,
models.EmailTypeDeletionConfirmed,
models.EmailTypeDataExportReady,
models.EmailTypeEmailChanged,
models.EmailTypeNewVersionPublished,
models.EmailTypeConsentReminder,
models.EmailTypeConsentDeadlineWarning,
models.EmailTypeAccountSuspended,
}
germanKeywords := []string{"Hallo", "freundlichen", "Grüßen", "BreakPilot", "Ihr"}
for _, emailType := range emailTypes {
t.Run(emailType, func(t *testing.T) {
subject, bodyHTML, bodyText := service.GetDefaultTemplateContent(emailType, "de")
// Check that German text is present
foundGerman := false
for _, keyword := range germanKeywords {
if strings.Contains(bodyHTML, keyword) || strings.Contains(bodyText, keyword) {
foundGerman = true
break
}
}
if !foundGerman {
t.Errorf("Template %s does not appear to be in German", emailType)
}
// Check that subject is not just the fallback
if subject == "No template" {
t.Errorf("Template %s has fallback subject instead of German subject", emailType)
}
})
}
}
// TestEmailTemplateService_HTMLStructure tests that HTML templates have valid structure
func TestEmailTemplateService_HTMLStructure(t *testing.T) {
service := &EmailTemplateService{}
emailTypes := []string{
models.EmailTypeWelcome,
models.EmailTypeEmailVerification,
models.EmailTypePasswordReset,
}
for _, emailType := range emailTypes {
t.Run(emailType, func(t *testing.T) {
_, bodyHTML, _ := service.GetDefaultTemplateContent(emailType, "de")
// Check for basic HTML structure
if !strings.Contains(bodyHTML, "<!DOCTYPE html>") {
t.Errorf("Template %s missing DOCTYPE", emailType)
}
if !strings.Contains(bodyHTML, "<html>") {
t.Errorf("Template %s missing <html> tag", emailType)
}
if !strings.Contains(bodyHTML, "</html>") {
t.Errorf("Template %s missing closing </html> tag", emailType)
}
if !strings.Contains(bodyHTML, "<body") {
t.Errorf("Template %s missing <body> tag", emailType)
}
if !strings.Contains(bodyHTML, "</body>") {
t.Errorf("Template %s missing closing </body> tag", emailType)
}
})
}
}
// BenchmarkReplaceVariables benchmarks variable replacement
func BenchmarkReplaceVariables(b *testing.B) {
template := "Hallo {{user_name}}, Ihr Link: {{reset_url}}, gültig bis {{expires_in}}"
variables := map[string]string{
"user_name": "Max Mustermann",
"reset_url": "https://example.com/reset?token=abc123",
"expires_in": "24 Stunden",
}
for i := 0; i < b.N; i++ {
replaceVariables(template, variables)
}
}
// BenchmarkStripHTMLTags benchmarks HTML tag stripping
func BenchmarkStripHTMLTags(b *testing.B) {
html := "<html><body><h1>Title</h1><p>This is a <strong>test</strong> paragraph with <a href='#'>links</a>.</p></body></html>"
for i := 0; i < b.N; i++ {
stripHTMLTags(html)
}
}

View File

@@ -0,0 +1,543 @@
package services
import (
"context"
"fmt"
"time"
"github.com/breakpilot/consent-service/internal/database"
"github.com/breakpilot/consent-service/internal/models"
"github.com/breakpilot/consent-service/internal/services/matrix"
"github.com/google/uuid"
)
// GradeService handles grade management and notifications
type GradeService struct {
db *database.DB
matrix *matrix.MatrixService
}
// NewGradeService creates a new grade service
func NewGradeService(db *database.DB, matrixService *matrix.MatrixService) *GradeService {
return &GradeService{
db: db,
matrix: matrixService,
}
}
// ========================================
// Grade CRUD
// ========================================
// CreateGrade creates a new grade for a student
func (s *GradeService) CreateGrade(ctx context.Context, req models.CreateGradeRequest, teacherID uuid.UUID) (*models.Grade, error) {
studentID, err := uuid.Parse(req.StudentID)
if err != nil {
return nil, fmt.Errorf("invalid student ID: %w", err)
}
subjectID, err := uuid.Parse(req.SubjectID)
if err != nil {
return nil, fmt.Errorf("invalid subject ID: %w", err)
}
schoolYearID, err := uuid.Parse(req.SchoolYearID)
if err != nil {
return nil, fmt.Errorf("invalid school year ID: %w", err)
}
date, err := time.Parse("2006-01-02", req.Date)
if err != nil {
return nil, fmt.Errorf("invalid date format: %w", err)
}
// Get default grade scale for the school
var gradeScaleID uuid.UUID
var schoolID uuid.UUID
err = s.db.Pool.QueryRow(ctx, `
SELECT gs.id, gs.school_id
FROM grade_scales gs
JOIN students st ON st.school_id = gs.school_id
WHERE st.id = $1 AND gs.is_default = true`, studentID).Scan(&gradeScaleID, &schoolID)
if err != nil {
return nil, fmt.Errorf("failed to get grade scale: %w", err)
}
weight := req.Weight
if weight == 0 {
weight = 1.0
}
grade := &models.Grade{
ID: uuid.New(),
StudentID: studentID,
SubjectID: subjectID,
TeacherID: teacherID,
SchoolYearID: schoolYearID,
GradeScaleID: gradeScaleID,
Type: req.Type,
Value: req.Value,
Weight: weight,
Date: date,
Title: req.Title,
Description: req.Description,
IsVisible: true,
Semester: req.Semester,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
query := `
INSERT INTO grades (id, student_id, subject_id, teacher_id, school_year_id, grade_scale_id, type, value, weight, date, title, description, is_visible, semester, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
RETURNING id`
err = s.db.Pool.QueryRow(ctx, query,
grade.ID, grade.StudentID, grade.SubjectID, grade.TeacherID,
grade.SchoolYearID, grade.GradeScaleID, grade.Type, grade.Value,
grade.Weight, grade.Date, grade.Title, grade.Description,
grade.IsVisible, grade.Semester, grade.CreatedAt, grade.UpdatedAt,
).Scan(&grade.ID)
if err != nil {
return nil, fmt.Errorf("failed to create grade: %w", err)
}
// Send notification to parents if grade is visible
if grade.IsVisible {
go s.notifyParentsOfNewGrade(context.Background(), grade)
}
return grade, nil
}
// GetGrade retrieves a grade by ID
func (s *GradeService) GetGrade(ctx context.Context, gradeID uuid.UUID) (*models.Grade, error) {
query := `
SELECT id, student_id, subject_id, teacher_id, school_year_id, grade_scale_id, type, value, weight, date, title, description, is_visible, semester, created_at, updated_at
FROM grades
WHERE id = $1`
grade := &models.Grade{}
err := s.db.Pool.QueryRow(ctx, query, gradeID).Scan(
&grade.ID, &grade.StudentID, &grade.SubjectID, &grade.TeacherID,
&grade.SchoolYearID, &grade.GradeScaleID, &grade.Type, &grade.Value,
&grade.Weight, &grade.Date, &grade.Title, &grade.Description,
&grade.IsVisible, &grade.Semester, &grade.CreatedAt, &grade.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to get grade: %w", err)
}
return grade, nil
}
// UpdateGrade updates an existing grade
func (s *GradeService) UpdateGrade(ctx context.Context, gradeID uuid.UUID, value float64, title, description *string) error {
query := `
UPDATE grades
SET value = $1, title = COALESCE($2, title), description = COALESCE($3, description), updated_at = NOW()
WHERE id = $4`
result, err := s.db.Pool.Exec(ctx, query, value, title, description, gradeID)
if err != nil {
return fmt.Errorf("failed to update grade: %w", err)
}
if result.RowsAffected() == 0 {
return fmt.Errorf("grade not found")
}
return nil
}
// DeleteGrade deletes a grade
func (s *GradeService) DeleteGrade(ctx context.Context, gradeID uuid.UUID) error {
result, err := s.db.Pool.Exec(ctx, `DELETE FROM grades WHERE id = $1`, gradeID)
if err != nil {
return fmt.Errorf("failed to delete grade: %w", err)
}
if result.RowsAffected() == 0 {
return fmt.Errorf("grade not found")
}
return nil
}
// ========================================
// Grade Queries
// ========================================
// GetStudentGrades gets all grades for a student in a school year
func (s *GradeService) GetStudentGrades(ctx context.Context, studentID, schoolYearID uuid.UUID) ([]models.Grade, error) {
query := `
SELECT id, student_id, subject_id, teacher_id, school_year_id, grade_scale_id, type, value, weight, date, title, description, is_visible, semester, created_at, updated_at
FROM grades
WHERE student_id = $1 AND school_year_id = $2 AND is_visible = true
ORDER BY date DESC`
rows, err := s.db.Pool.Query(ctx, query, studentID, schoolYearID)
if err != nil {
return nil, fmt.Errorf("failed to get student grades: %w", err)
}
defer rows.Close()
var grades []models.Grade
for rows.Next() {
var grade models.Grade
err := rows.Scan(
&grade.ID, &grade.StudentID, &grade.SubjectID, &grade.TeacherID,
&grade.SchoolYearID, &grade.GradeScaleID, &grade.Type, &grade.Value,
&grade.Weight, &grade.Date, &grade.Title, &grade.Description,
&grade.IsVisible, &grade.Semester, &grade.CreatedAt, &grade.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan grade: %w", err)
}
grades = append(grades, grade)
}
return grades, nil
}
// GetStudentGradesBySubject gets grades for a student in a specific subject
func (s *GradeService) GetStudentGradesBySubject(ctx context.Context, studentID, subjectID, schoolYearID uuid.UUID, semester int) ([]models.Grade, error) {
query := `
SELECT id, student_id, subject_id, teacher_id, school_year_id, grade_scale_id, type, value, weight, date, title, description, is_visible, semester, created_at, updated_at
FROM grades
WHERE student_id = $1 AND subject_id = $2 AND school_year_id = $3 AND semester = $4 AND is_visible = true
ORDER BY date DESC`
rows, err := s.db.Pool.Query(ctx, query, studentID, subjectID, schoolYearID, semester)
if err != nil {
return nil, fmt.Errorf("failed to get grades by subject: %w", err)
}
defer rows.Close()
var grades []models.Grade
for rows.Next() {
var grade models.Grade
err := rows.Scan(
&grade.ID, &grade.StudentID, &grade.SubjectID, &grade.TeacherID,
&grade.SchoolYearID, &grade.GradeScaleID, &grade.Type, &grade.Value,
&grade.Weight, &grade.Date, &grade.Title, &grade.Description,
&grade.IsVisible, &grade.Semester, &grade.CreatedAt, &grade.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan grade: %w", err)
}
grades = append(grades, grade)
}
return grades, nil
}
// GetClassGradesBySubject gets all grades for a class in a subject (Notenspiegel)
func (s *GradeService) GetClassGradesBySubject(ctx context.Context, classID, subjectID, schoolYearID uuid.UUID, semester int) ([]models.StudentGradeOverview, error) {
// Get all students in the class
studentsQuery := `
SELECT id, first_name, last_name
FROM students
WHERE class_id = $1 AND is_active = true
ORDER BY last_name, first_name`
rows, err := s.db.Pool.Query(ctx, studentsQuery, classID)
if err != nil {
return nil, fmt.Errorf("failed to get students: %w", err)
}
defer rows.Close()
var students []struct {
ID uuid.UUID
FirstName string
LastName string
}
for rows.Next() {
var student struct {
ID uuid.UUID
FirstName string
LastName string
}
if err := rows.Scan(&student.ID, &student.FirstName, &student.LastName); err != nil {
return nil, fmt.Errorf("failed to scan student: %w", err)
}
students = append(students, student)
}
// Get subject info
var subject models.Subject
err = s.db.Pool.QueryRow(ctx, `SELECT id, school_id, name, short_name, color, is_active, created_at FROM subjects WHERE id = $1`, subjectID).Scan(
&subject.ID, &subject.SchoolID, &subject.Name, &subject.ShortName, &subject.Color, &subject.IsActive, &subject.CreatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to get subject: %w", err)
}
var overviews []models.StudentGradeOverview
for _, student := range students {
grades, err := s.GetStudentGradesBySubject(ctx, student.ID, subjectID, schoolYearID, semester)
if err != nil {
continue
}
// Calculate averages
var totalWeight, weightedSum float64
var oralWeight, oralSum float64
var examWeight, examSum float64
for _, grade := range grades {
totalWeight += grade.Weight
weightedSum += grade.Value * grade.Weight
if grade.Type == models.GradeTypeOral || grade.Type == models.GradeTypeParticipation {
oralWeight += grade.Weight
oralSum += grade.Value * grade.Weight
} else if grade.Type == models.GradeTypeExam || grade.Type == models.GradeTypeTest {
examWeight += grade.Weight
examSum += grade.Value * grade.Weight
}
}
var average, oralAverage, examAverage float64
if totalWeight > 0 {
average = weightedSum / totalWeight
}
if oralWeight > 0 {
oralAverage = oralSum / oralWeight
}
if examWeight > 0 {
examAverage = examSum / examWeight
}
overview := models.StudentGradeOverview{
Student: models.Student{
ID: student.ID,
FirstName: student.FirstName,
LastName: student.LastName,
},
Subject: subject,
Grades: grades,
Average: average,
OralAverage: oralAverage,
ExamAverage: examAverage,
Semester: semester,
}
overviews = append(overviews, overview)
}
return overviews, nil
}
// ========================================
// Grade Statistics
// ========================================
// GetStudentGradeAverage calculates the overall grade average for a student
func (s *GradeService) GetStudentGradeAverage(ctx context.Context, studentID, schoolYearID uuid.UUID, semester int) (float64, error) {
query := `
SELECT COALESCE(SUM(value * weight) / NULLIF(SUM(weight), 0), 0)
FROM grades
WHERE student_id = $1 AND school_year_id = $2 AND semester = $3 AND is_visible = true`
var average float64
err := s.db.Pool.QueryRow(ctx, query, studentID, schoolYearID, semester).Scan(&average)
if err != nil {
return 0, fmt.Errorf("failed to calculate average: %w", err)
}
return average, nil
}
// GetSubjectGradeStatistics gets grade statistics for a subject in a class
func (s *GradeService) GetSubjectGradeStatistics(ctx context.Context, classID, subjectID, schoolYearID uuid.UUID, semester int) (map[string]interface{}, error) {
query := `
SELECT
COUNT(DISTINCT g.student_id) as student_count,
AVG(g.value) as class_average,
MIN(g.value) as best_grade,
MAX(g.value) as worst_grade,
COUNT(*) as total_grades
FROM grades g
JOIN students s ON g.student_id = s.id
WHERE s.class_id = $1 AND g.subject_id = $2 AND g.school_year_id = $3 AND g.semester = $4 AND g.is_visible = true`
var studentCount, totalGrades int
var classAverage, bestGrade, worstGrade float64
err := s.db.Pool.QueryRow(ctx, query, classID, subjectID, schoolYearID, semester).Scan(
&studentCount, &classAverage, &bestGrade, &worstGrade, &totalGrades,
)
if err != nil {
return nil, fmt.Errorf("failed to get statistics: %w", err)
}
// Grade distribution (for German grades 1-6)
distributionQuery := `
SELECT
COUNT(CASE WHEN value >= 1 AND value < 1.5 THEN 1 END) as grade_1,
COUNT(CASE WHEN value >= 1.5 AND value < 2.5 THEN 1 END) as grade_2,
COUNT(CASE WHEN value >= 2.5 AND value < 3.5 THEN 1 END) as grade_3,
COUNT(CASE WHEN value >= 3.5 AND value < 4.5 THEN 1 END) as grade_4,
COUNT(CASE WHEN value >= 4.5 AND value < 5.5 THEN 1 END) as grade_5,
COUNT(CASE WHEN value >= 5.5 THEN 1 END) as grade_6
FROM grades g
JOIN students s ON g.student_id = s.id
WHERE s.class_id = $1 AND g.subject_id = $2 AND g.school_year_id = $3 AND g.semester = $4 AND g.is_visible = true AND g.type IN ('exam', 'test')`
var g1, g2, g3, g4, g5, g6 int
err = s.db.Pool.QueryRow(ctx, distributionQuery, classID, subjectID, schoolYearID, semester).Scan(
&g1, &g2, &g3, &g4, &g5, &g6,
)
if err != nil {
// Non-fatal, continue without distribution
g1, g2, g3, g4, g5, g6 = 0, 0, 0, 0, 0, 0
}
return map[string]interface{}{
"student_count": studentCount,
"class_average": classAverage,
"best_grade": bestGrade,
"worst_grade": worstGrade,
"total_grades": totalGrades,
"distribution": map[string]int{
"1": g1,
"2": g2,
"3": g3,
"4": g4,
"5": g5,
"6": g6,
},
}, nil
}
// ========================================
// Grade Comments
// ========================================
// AddGradeComment adds a comment to a grade
func (s *GradeService) AddGradeComment(ctx context.Context, gradeID, teacherID uuid.UUID, comment string, isPrivate bool) (*models.GradeComment, error) {
gradeComment := &models.GradeComment{
ID: uuid.New(),
GradeID: gradeID,
TeacherID: teacherID,
Comment: comment,
IsPrivate: isPrivate,
CreatedAt: time.Now(),
}
query := `
INSERT INTO grade_comments (id, grade_id, teacher_id, comment, is_private, created_at)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id`
err := s.db.Pool.QueryRow(ctx, query,
gradeComment.ID, gradeComment.GradeID, gradeComment.TeacherID,
gradeComment.Comment, gradeComment.IsPrivate, gradeComment.CreatedAt,
).Scan(&gradeComment.ID)
if err != nil {
return nil, fmt.Errorf("failed to add grade comment: %w", err)
}
return gradeComment, nil
}
// GetGradeComments gets comments for a grade
func (s *GradeService) GetGradeComments(ctx context.Context, gradeID uuid.UUID, includePrivate bool) ([]models.GradeComment, error) {
query := `
SELECT id, grade_id, teacher_id, comment, is_private, created_at
FROM grade_comments
WHERE grade_id = $1`
if !includePrivate {
query += ` AND is_private = false`
}
query += ` ORDER BY created_at DESC`
rows, err := s.db.Pool.Query(ctx, query, gradeID)
if err != nil {
return nil, fmt.Errorf("failed to get grade comments: %w", err)
}
defer rows.Close()
var comments []models.GradeComment
for rows.Next() {
var comment models.GradeComment
err := rows.Scan(
&comment.ID, &comment.GradeID, &comment.TeacherID,
&comment.Comment, &comment.IsPrivate, &comment.CreatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan grade comment: %w", err)
}
comments = append(comments, comment)
}
return comments, nil
}
// ========================================
// Parent Notifications
// ========================================
func (s *GradeService) notifyParentsOfNewGrade(ctx context.Context, grade *models.Grade) {
if s.matrix == nil {
return
}
// Get student info and Matrix room
var studentFirstName, studentLastName, matrixDMRoom string
err := s.db.Pool.QueryRow(ctx, `
SELECT first_name, last_name, matrix_dm_room
FROM students
WHERE id = $1`, grade.StudentID).Scan(&studentFirstName, &studentLastName, &matrixDMRoom)
if err != nil || matrixDMRoom == "" {
return
}
// Get subject name
var subjectName string
err = s.db.Pool.QueryRow(ctx, `SELECT name FROM subjects WHERE id = $1`, grade.SubjectID).Scan(&subjectName)
if err != nil {
return
}
studentName := studentFirstName + " " + studentLastName
gradeType := s.getGradeTypeDisplayName(grade.Type)
// Send Matrix notification
err = s.matrix.SendGradeNotification(ctx, matrixDMRoom, studentName, subjectName, gradeType, grade.Value)
if err != nil {
fmt.Printf("Failed to send grade notification: %v\n", err)
}
}
func (s *GradeService) getGradeTypeDisplayName(gradeType string) string {
switch gradeType {
case models.GradeTypeExam:
return "Klassenarbeit"
case models.GradeTypeTest:
return "Test"
case models.GradeTypeOral:
return "Mündliche Note"
case models.GradeTypeHomework:
return "Hausaufgabe"
case models.GradeTypeProject:
return "Projekt"
case models.GradeTypeParticipation:
return "Mitarbeit"
case models.GradeTypeSemester:
return "Halbjahreszeugnis"
case models.GradeTypeFinal:
return "Zeugnisnote"
default:
return gradeType
}
}

View File

@@ -0,0 +1,532 @@
package services
import (
"testing"
"time"
"github.com/breakpilot/consent-service/internal/models"
"github.com/google/uuid"
)
// TestValidateGrade tests grade validation
func TestValidateGrade(t *testing.T) {
schoolYearID := uuid.New()
gradeScaleID := uuid.New()
tests := []struct {
name string
grade models.Grade
expectValid bool
}{
{
name: "valid grade 1",
grade: models.Grade{
StudentID: uuid.New(),
SubjectID: uuid.New(),
TeacherID: uuid.New(),
SchoolYearID: schoolYearID,
GradeScaleID: gradeScaleID,
Value: 1.0,
Type: models.GradeTypeExam,
Weight: 1.0,
Date: time.Now(),
Semester: 1,
},
expectValid: true,
},
{
name: "valid grade 6",
grade: models.Grade{
StudentID: uuid.New(),
SubjectID: uuid.New(),
TeacherID: uuid.New(),
SchoolYearID: schoolYearID,
GradeScaleID: gradeScaleID,
Value: 6.0,
Type: models.GradeTypeOral,
Weight: 0.5,
Date: time.Now(),
Semester: 2,
},
expectValid: true,
},
{
name: "valid grade with plus (1.3)",
grade: models.Grade{
StudentID: uuid.New(),
SubjectID: uuid.New(),
TeacherID: uuid.New(),
SchoolYearID: schoolYearID,
GradeScaleID: gradeScaleID,
Value: 1.3,
Type: models.GradeTypeTest,
Weight: 0.25,
Date: time.Now(),
Semester: 1,
},
expectValid: true,
},
{
name: "invalid grade 0",
grade: models.Grade{
StudentID: uuid.New(),
SubjectID: uuid.New(),
TeacherID: uuid.New(),
SchoolYearID: schoolYearID,
GradeScaleID: gradeScaleID,
Value: 0.0,
Type: models.GradeTypeExam,
Weight: 1.0,
Date: time.Now(),
Semester: 1,
},
expectValid: false,
},
{
name: "invalid grade 7",
grade: models.Grade{
StudentID: uuid.New(),
SubjectID: uuid.New(),
TeacherID: uuid.New(),
SchoolYearID: schoolYearID,
GradeScaleID: gradeScaleID,
Value: 7.0,
Type: models.GradeTypeExam,
Weight: 1.0,
Date: time.Now(),
Semester: 1,
},
expectValid: false,
},
{
name: "missing student ID",
grade: models.Grade{
StudentID: uuid.Nil,
SubjectID: uuid.New(),
TeacherID: uuid.New(),
SchoolYearID: schoolYearID,
GradeScaleID: gradeScaleID,
Value: 2.0,
Type: models.GradeTypeExam,
Weight: 1.0,
Date: time.Now(),
Semester: 1,
},
expectValid: false,
},
{
name: "invalid weight negative",
grade: models.Grade{
StudentID: uuid.New(),
SubjectID: uuid.New(),
TeacherID: uuid.New(),
SchoolYearID: schoolYearID,
GradeScaleID: gradeScaleID,
Value: 2.0,
Type: models.GradeTypeExam,
Weight: -0.5,
Date: time.Now(),
Semester: 1,
},
expectValid: false,
},
{
name: "invalid semester 0",
grade: models.Grade{
StudentID: uuid.New(),
SubjectID: uuid.New(),
TeacherID: uuid.New(),
SchoolYearID: schoolYearID,
GradeScaleID: gradeScaleID,
Value: 2.0,
Type: models.GradeTypeExam,
Weight: 1.0,
Date: time.Now(),
Semester: 0,
},
expectValid: false,
},
{
name: "invalid semester 3",
grade: models.Grade{
StudentID: uuid.New(),
SubjectID: uuid.New(),
TeacherID: uuid.New(),
SchoolYearID: schoolYearID,
GradeScaleID: gradeScaleID,
Value: 2.0,
Type: models.GradeTypeExam,
Weight: 1.0,
Date: time.Now(),
Semester: 3,
},
expectValid: false,
},
{
name: "invalid type",
grade: models.Grade{
StudentID: uuid.New(),
SubjectID: uuid.New(),
TeacherID: uuid.New(),
SchoolYearID: schoolYearID,
GradeScaleID: gradeScaleID,
Value: 2.0,
Type: "invalid_type",
Weight: 1.0,
Date: time.Now(),
Semester: 1,
},
expectValid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := validateGrade(tt.grade)
if isValid != tt.expectValid {
t.Errorf("expected valid=%v, got valid=%v", tt.expectValid, isValid)
}
})
}
}
// validateGrade validates a grade
func validateGrade(grade models.Grade) bool {
if grade.StudentID == uuid.Nil {
return false
}
if grade.SubjectID == uuid.Nil {
return false
}
if grade.TeacherID == uuid.Nil {
return false
}
// German grading scale: 1 (best) to 6 (worst)
if grade.Value < 1.0 || grade.Value > 6.0 {
return false
}
if grade.Weight < 0 {
return false
}
if grade.Semester < 1 || grade.Semester > 2 {
return false
}
// Validate type
validTypes := map[string]bool{
models.GradeTypeExam: true,
models.GradeTypeTest: true,
models.GradeTypeOral: true,
models.GradeTypeHomework: true,
models.GradeTypeProject: true,
models.GradeTypeParticipation: true,
models.GradeTypeSemester: true,
models.GradeTypeFinal: true,
}
if !validTypes[grade.Type] {
return false
}
return true
}
// TestCalculateWeightedAverage tests weighted average calculation
func TestCalculateWeightedAverage(t *testing.T) {
tests := []struct {
name string
grades []models.Grade
expectedAverage float64
}{
{
name: "simple average equal weights",
grades: []models.Grade{
{Value: 1.0, Weight: 1.0},
{Value: 2.0, Weight: 1.0},
{Value: 3.0, Weight: 1.0},
},
expectedAverage: 2.0,
},
{
name: "weighted average different weights",
grades: []models.Grade{
{Value: 1.0, Weight: 2.0}, // Exam counts double
{Value: 3.0, Weight: 1.0},
},
// (1*2 + 3*1) / (2+1) = 5/3 = 1.67
expectedAverage: 1.67,
},
{
name: "single grade",
grades: []models.Grade{
{Value: 2.5, Weight: 1.0},
},
expectedAverage: 2.5,
},
{
name: "empty grades",
grades: []models.Grade{},
expectedAverage: 0.0,
},
{
name: "all same grades",
grades: []models.Grade{
{Value: 2.0, Weight: 1.0},
{Value: 2.0, Weight: 1.0},
{Value: 2.0, Weight: 1.0},
},
expectedAverage: 2.0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
avg := calculateWeightedAverage(tt.grades)
// Allow small floating point differences
if avg < tt.expectedAverage-0.01 || avg > tt.expectedAverage+0.01 {
t.Errorf("expected average=%.2f, got average=%.2f", tt.expectedAverage, avg)
}
})
}
}
// calculateWeightedAverage calculates weighted average of grades
func calculateWeightedAverage(grades []models.Grade) float64 {
if len(grades) == 0 {
return 0.0
}
var weightedSum float64
var totalWeight float64
for _, g := range grades {
weightedSum += g.Value * g.Weight
totalWeight += g.Weight
}
if totalWeight == 0 {
return 0.0
}
avg := weightedSum / totalWeight
// Round to 2 decimal places
return float64(int(avg*100)) / 100
}
// TestGradeDistribution tests grade distribution calculation
func TestGradeDistribution(t *testing.T) {
tests := []struct {
name string
grades []models.Grade
expectedDist map[int]int
}{
{
name: "varied distribution",
grades: []models.Grade{
{Value: 1.0}, {Value: 1.3},
{Value: 2.0}, {Value: 2.0}, {Value: 2.7},
{Value: 3.0}, {Value: 3.0}, {Value: 3.0},
{Value: 4.0}, {Value: 4.3},
{Value: 5.0},
},
expectedDist: map[int]int{
1: 2, // 1.0, 1.3 (rounded: 1, 1)
2: 2, // 2.0, 2.0 (rounded: 2, 2)
3: 4, // 2.7, 3.0, 3.0, 3.0 (rounded: 3, 3, 3, 3)
4: 2, // 4.0, 4.3 (rounded: 4, 4)
5: 1, // 5.0 (rounded: 5)
6: 0,
},
},
{
name: "empty grades",
grades: []models.Grade{},
expectedDist: map[int]int{1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0},
},
{
name: "all same grade",
grades: []models.Grade{
{Value: 2.0},
{Value: 2.0},
{Value: 2.0},
},
expectedDist: map[int]int{1: 0, 2: 3, 3: 0, 4: 0, 5: 0, 6: 0},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dist := calculateGradeDistribution(tt.grades)
for grade, count := range tt.expectedDist {
if dist[grade] != count {
t.Errorf("grade %d: expected count=%d, got count=%d", grade, count, dist[grade])
}
}
})
}
}
// calculateGradeDistribution calculates how many grades fall into each category
func calculateGradeDistribution(grades []models.Grade) map[int]int {
dist := map[int]int{1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0}
for _, g := range grades {
// Round to nearest integer for distribution
roundedGrade := int(g.Value + 0.5)
if roundedGrade < 1 {
roundedGrade = 1
}
if roundedGrade > 6 {
roundedGrade = 6
}
dist[roundedGrade]++
}
return dist
}
// TestGradePointConversion tests conversion between grades and points (Oberstufe)
func TestGradePointConversion(t *testing.T) {
tests := []struct {
name string
grade float64
expectedPoints int
}{
{"grade 1.0 = 15 points", 1.0, 15},
{"grade 1.3 = 14 points", 1.3, 14},
{"grade 1.7 = 13 points", 1.7, 13},
{"grade 2.0 = 12 points", 2.0, 12},
{"grade 2.3 = 11 points", 2.3, 11},
{"grade 2.7 = 10 points", 2.7, 10},
{"grade 3.0 = 9 points", 3.0, 9},
{"grade 3.3 = 8 points", 3.3, 8},
{"grade 3.7 = 7 points", 3.7, 7},
{"grade 4.0 = 6 points", 4.0, 6},
{"grade 4.3 = 5 points", 4.3, 5},
{"grade 4.7 = 4 points", 4.7, 4},
{"grade 5.0 = 3 points", 5.0, 3},
{"grade 5.3 = 2 points", 5.3, 2},
{"grade 5.7 = 1 point", 5.7, 1},
{"grade 6.0 = 0 points", 6.0, 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
points := gradeToPoints(tt.grade)
if points != tt.expectedPoints {
t.Errorf("expected points=%d, got points=%d", tt.expectedPoints, points)
}
})
}
}
// gradeToPoints converts German grade (1-6) to Oberstufe points (0-15)
func gradeToPoints(grade float64) int {
// Mapping based on German school system
if grade <= 1.0 {
return 15
} else if grade <= 1.3 {
return 14
} else if grade <= 1.7 {
return 13
} else if grade <= 2.0 {
return 12
} else if grade <= 2.3 {
return 11
} else if grade <= 2.7 {
return 10
} else if grade <= 3.0 {
return 9
} else if grade <= 3.3 {
return 8
} else if grade <= 3.7 {
return 7
} else if grade <= 4.0 {
return 6
} else if grade <= 4.3 {
return 5
} else if grade <= 4.7 {
return 4
} else if grade <= 5.0 {
return 3
} else if grade <= 5.3 {
return 2
} else if grade <= 5.7 {
return 1
}
return 0
}
// TestFindBestAndWorstGrade tests finding best and worst grades
func TestFindBestAndWorstGrade(t *testing.T) {
tests := []struct {
name string
grades []models.Grade
expectedBest float64
expectedWorst float64
}{
{
name: "varied grades",
grades: []models.Grade{
{Value: 2.0},
{Value: 1.0},
{Value: 3.0},
{Value: 5.0},
{Value: 2.0},
},
expectedBest: 1.0,
expectedWorst: 5.0,
},
{
name: "all same",
grades: []models.Grade{
{Value: 2.0},
{Value: 2.0},
},
expectedBest: 2.0,
expectedWorst: 2.0,
},
{
name: "single grade",
grades: []models.Grade{
{Value: 3.0},
},
expectedBest: 3.0,
expectedWorst: 3.0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
best, worst := findBestAndWorstGrade(tt.grades)
if best != tt.expectedBest {
t.Errorf("expected best=%.1f, got best=%.1f", tt.expectedBest, best)
}
if worst != tt.expectedWorst {
t.Errorf("expected worst=%.1f, got worst=%.1f", tt.expectedWorst, worst)
}
})
}
}
// findBestAndWorstGrade finds the best (lowest) and worst (highest) grade
func findBestAndWorstGrade(grades []models.Grade) (best, worst float64) {
if len(grades) == 0 {
return 0, 0
}
best = grades[0].Value
worst = grades[0].Value
for _, g := range grades[1:] {
if g.Value < best {
best = g.Value
}
if g.Value > worst {
worst = g.Value
}
}
return best, worst
}

View File

@@ -0,0 +1,340 @@
package jitsi
import (
"context"
"fmt"
"time"
)
// ========================================
// Breakpilot Drive Game Meeting Types
// ========================================
// GameMeetingMode represents different game video call modes
type GameMeetingMode string
const (
GameMeetingCoop GameMeetingMode = "coop" // Co-Op voice/video
GameMeetingChallenge GameMeetingMode = "challenge" // 1v1 face-off
GameMeetingClassRace GameMeetingMode = "class_race" // Teacher supervises
GameMeetingTeamHuddle GameMeetingMode = "team_huddle" // Quick team sync
)
// GameMeetingConfig holds configuration for game video meetings
type GameMeetingConfig struct {
SessionID string `json:"session_id"`
Mode GameMeetingMode `json:"mode"`
HostID string `json:"host_id"`
HostName string `json:"host_name"`
Players []GamePlayer `json:"players"`
EnableVideo bool `json:"enable_video"`
EnableVoice bool `json:"enable_voice"`
TeacherID string `json:"teacher_id,omitempty"`
TeacherName string `json:"teacher_name,omitempty"`
ClassName string `json:"class_name,omitempty"`
}
// GamePlayer represents a player in the meeting
type GamePlayer struct {
ID string `json:"id"`
Name string `json:"name"`
IsModerator bool `json:"is_moderator,omitempty"`
}
// GameMeetingLink extends MeetingLink with game-specific info
type GameMeetingLink struct {
*MeetingLink
SessionID string `json:"session_id"`
Mode GameMeetingMode `json:"mode"`
Players []string `json:"players"`
}
// ========================================
// Game Meeting Creation
// ========================================
// CreateCoopMeeting creates a video call for Co-Op gameplay (2-4 players)
func (s *JitsiService) CreateCoopMeeting(ctx context.Context, config GameMeetingConfig) (*GameMeetingLink, error) {
roomName := fmt.Sprintf("bp-coop-%s", config.SessionID[:8])
meeting := Meeting{
RoomName: roomName,
DisplayName: config.HostName,
Subject: "Breakpilot Drive - Co-Op Session",
Moderator: true,
Config: &MeetingConfig{
StartWithAudioMuted: !config.EnableVoice,
StartWithVideoMuted: !config.EnableVideo,
RequireDisplayName: true,
EnableLobby: false, // Direct join for co-op
DisableDeepLinking: true,
},
}
link, err := s.CreateMeetingLink(ctx, meeting)
if err != nil {
return nil, fmt.Errorf("failed to create co-op meeting: %w", err)
}
playerIDs := make([]string, len(config.Players))
for i, p := range config.Players {
playerIDs[i] = p.ID
}
return &GameMeetingLink{
MeetingLink: link,
SessionID: config.SessionID,
Mode: GameMeetingCoop,
Players: playerIDs,
}, nil
}
// CreateChallengeMeeting creates a 1v1 video call for challenges
func (s *JitsiService) CreateChallengeMeeting(ctx context.Context, config GameMeetingConfig, challengerName string, opponentName string) (*GameMeetingLink, error) {
roomName := fmt.Sprintf("bp-challenge-%s", config.SessionID[:8])
meeting := Meeting{
RoomName: roomName,
DisplayName: challengerName,
Subject: fmt.Sprintf("Challenge: %s vs %s", challengerName, opponentName),
Moderator: false, // Both players are equal
Config: &MeetingConfig{
StartWithAudioMuted: false, // Voice enabled for trash talk
StartWithVideoMuted: !config.EnableVideo,
RequireDisplayName: true,
EnableLobby: false,
DisableDeepLinking: true,
},
}
link, err := s.CreateMeetingLink(ctx, meeting)
if err != nil {
return nil, fmt.Errorf("failed to create challenge meeting: %w", err)
}
return &GameMeetingLink{
MeetingLink: link,
SessionID: config.SessionID,
Mode: GameMeetingChallenge,
Players: []string{config.HostID},
}, nil
}
// CreateClassRaceMeeting creates a video call for teacher-supervised class races
func (s *JitsiService) CreateClassRaceMeeting(ctx context.Context, config GameMeetingConfig) (*GameMeetingLink, error) {
roomName := fmt.Sprintf("bp-klasse-%s-%s",
s.sanitizeRoomName(config.ClassName),
time.Now().Format("150405"))
// Teacher is moderator
meeting := Meeting{
RoomName: roomName,
DisplayName: config.TeacherName,
Subject: fmt.Sprintf("Klassenrennen: %s", config.ClassName),
Moderator: true,
Config: &MeetingConfig{
StartWithAudioMuted: true, // Students muted by default
StartWithVideoMuted: true, // Video off for performance
RequireDisplayName: true,
EnableLobby: true, // Teacher admits students
EnableRecording: false, // No recording for minors
DisableDeepLinking: true,
},
Features: &MeetingFeatures{
Recording: false,
Transcription: false,
},
}
link, err := s.CreateMeetingLink(ctx, meeting)
if err != nil {
return nil, fmt.Errorf("failed to create class race meeting: %w", err)
}
playerIDs := make([]string, len(config.Players))
for i, p := range config.Players {
playerIDs[i] = p.ID
}
return &GameMeetingLink{
MeetingLink: link,
SessionID: config.SessionID,
Mode: GameMeetingClassRace,
Players: playerIDs,
}, nil
}
// CreateTeamHuddleMeeting creates a quick sync meeting for teams
func (s *JitsiService) CreateTeamHuddleMeeting(ctx context.Context, config GameMeetingConfig, teamName string) (*GameMeetingLink, error) {
roomName := fmt.Sprintf("bp-team-%s-%s",
s.sanitizeRoomName(teamName),
config.SessionID[:8])
meeting := Meeting{
RoomName: roomName,
DisplayName: config.HostName,
Subject: fmt.Sprintf("Team %s - Huddle", teamName),
Duration: 5, // Short 5-minute huddles
Moderator: true,
Config: &MeetingConfig{
StartWithAudioMuted: false, // Voice on for quick sync
StartWithVideoMuted: true, // Video optional
RequireDisplayName: true,
EnableLobby: false,
DisableDeepLinking: true,
},
}
link, err := s.CreateMeetingLink(ctx, meeting)
if err != nil {
return nil, fmt.Errorf("failed to create team huddle: %w", err)
}
playerIDs := make([]string, len(config.Players))
for i, p := range config.Players {
playerIDs[i] = p.ID
}
return &GameMeetingLink{
MeetingLink: link,
SessionID: config.SessionID,
Mode: GameMeetingTeamHuddle,
Players: playerIDs,
}, nil
}
// ========================================
// Game-Specific Meeting Configurations
// ========================================
// GetGameEmbedConfig returns optimized config for embedding in Unity WebGL
func (s *JitsiService) GetGameEmbedConfig(enableVideo bool, enableVoice bool) *MeetingConfig {
return &MeetingConfig{
StartWithAudioMuted: !enableVoice,
StartWithVideoMuted: !enableVideo,
RequireDisplayName: true,
EnableLobby: false,
DisableDeepLinking: true, // Important for iframe embedding
}
}
// BuildGameEmbedURL creates a URL optimized for Unity WebGL embedding
func (s *JitsiService) BuildGameEmbedURL(roomName string, playerName string, enableVideo bool, enableVoice bool) string {
config := s.GetGameEmbedConfig(enableVideo, enableVoice)
return s.BuildEmbedURL(roomName, playerName, config)
}
// BuildUnityIFrameParams returns parameters for Unity's WebGL iframe
func (s *JitsiService) BuildUnityIFrameParams(link *GameMeetingLink, playerName string) map[string]interface{} {
return map[string]interface{}{
"domain": s.extractDomain(),
"roomName": link.RoomName,
"displayName": playerName,
"jwt": link.JWT,
"configOverwrite": map[string]interface{}{
"startWithAudioMuted": false,
"startWithVideoMuted": true,
"disableDeepLinking": true,
"prejoinPageEnabled": false,
"enableWelcomePage": false,
"enableClosePage": false,
"disableInviteFunctions": true,
},
"interfaceConfigOverwrite": map[string]interface{}{
"DISABLE_JOIN_LEAVE_NOTIFICATIONS": true,
"MOBILE_APP_PROMO": false,
"SHOW_CHROME_EXTENSION_BANNER": false,
"TOOLBAR_BUTTONS": []string{
"microphone", "camera", "hangup", "chat",
},
},
}
}
// ========================================
// Spectator Mode (for teachers/parents)
// ========================================
// CreateSpectatorLink creates a view-only link for observers
func (s *JitsiService) CreateSpectatorLink(ctx context.Context, roomName string, spectatorName string) (*MeetingLink, error) {
meeting := Meeting{
RoomName: roomName,
DisplayName: fmt.Sprintf("[Zuschauer] %s", spectatorName),
Moderator: false,
Config: &MeetingConfig{
StartWithAudioMuted: true,
StartWithVideoMuted: true,
DisableDeepLinking: true,
},
}
return s.CreateMeetingLink(ctx, meeting)
}
// ========================================
// Helper Functions
// ========================================
// extractDomain extracts the domain from baseURL
func (s *JitsiService) extractDomain() string {
// Remove protocol prefix
domain := s.baseURL
if len(domain) > 8 && domain[:8] == "https://" {
domain = domain[8:]
} else if len(domain) > 7 && domain[:7] == "http://" {
domain = domain[7:]
}
// Remove port if present
for i, c := range domain {
if c == ':' || c == '/' {
domain = domain[:i]
break
}
}
return domain
}
// ValidateGameMeetingConfig validates configuration before creating meeting
func ValidateGameMeetingConfig(config GameMeetingConfig) error {
if config.SessionID == "" {
return fmt.Errorf("session_id is required")
}
if config.Mode == "" {
return fmt.Errorf("mode is required")
}
if config.HostID == "" {
return fmt.Errorf("host_id is required")
}
if config.HostName == "" {
return fmt.Errorf("host_name is required")
}
switch config.Mode {
case GameMeetingCoop:
if len(config.Players) < 2 || len(config.Players) > 4 {
return fmt.Errorf("co-op mode requires 2-4 players")
}
case GameMeetingChallenge:
if len(config.Players) != 2 {
return fmt.Errorf("challenge mode requires exactly 2 players")
}
case GameMeetingClassRace:
if config.TeacherID == "" || config.TeacherName == "" {
return fmt.Errorf("class race mode requires teacher info")
}
if config.ClassName == "" {
return fmt.Errorf("class race mode requires class name")
}
case GameMeetingTeamHuddle:
if len(config.Players) < 2 {
return fmt.Errorf("team huddle requires at least 2 players")
}
default:
return fmt.Errorf("unknown game meeting mode: %s", config.Mode)
}
return nil
}

View File

@@ -0,0 +1,566 @@
package jitsi
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"github.com/google/uuid"
)
// JitsiService handles Jitsi Meet integration for video conferences
type JitsiService struct {
baseURL string
appID string
appSecret string
httpClient *http.Client
}
// Config holds Jitsi service configuration
type Config struct {
BaseURL string // e.g., "http://localhost:8443"
AppID string // Application ID for JWT (optional)
AppSecret string // Secret for JWT signing (optional)
}
// NewJitsiService creates a new Jitsi service instance
func NewJitsiService(cfg Config) *JitsiService {
return &JitsiService{
baseURL: strings.TrimSuffix(cfg.BaseURL, "/"),
appID: cfg.AppID,
appSecret: cfg.AppSecret,
httpClient: &http.Client{
Timeout: 10 * time.Second,
},
}
}
// ========================================
// Types
// ========================================
// Meeting represents a Jitsi meeting configuration
type Meeting struct {
RoomName string `json:"room_name"`
DisplayName string `json:"display_name,omitempty"`
Email string `json:"email,omitempty"`
Avatar string `json:"avatar,omitempty"`
Subject string `json:"subject,omitempty"`
Password string `json:"password,omitempty"`
StartTime *time.Time `json:"start_time,omitempty"`
Duration int `json:"duration,omitempty"` // in minutes
Config *MeetingConfig `json:"config,omitempty"`
Moderator bool `json:"moderator,omitempty"`
Features *MeetingFeatures `json:"features,omitempty"`
}
// MeetingConfig holds Jitsi room configuration options
type MeetingConfig struct {
StartWithAudioMuted bool `json:"start_with_audio_muted,omitempty"`
StartWithVideoMuted bool `json:"start_with_video_muted,omitempty"`
DisableDeepLinking bool `json:"disable_deep_linking,omitempty"`
RequireDisplayName bool `json:"require_display_name,omitempty"`
EnableLobby bool `json:"enable_lobby,omitempty"`
EnableRecording bool `json:"enable_recording,omitempty"`
}
// MeetingFeatures controls which features are enabled
type MeetingFeatures struct {
Livestreaming bool `json:"livestreaming,omitempty"`
Recording bool `json:"recording,omitempty"`
Transcription bool `json:"transcription,omitempty"`
OutboundCall bool `json:"outbound_call,omitempty"`
}
// MeetingLink contains the generated meeting URL and metadata
type MeetingLink struct {
URL string `json:"url"`
RoomName string `json:"room_name"`
JoinURL string `json:"join_url"`
ModeratorURL string `json:"moderator_url,omitempty"`
Password string `json:"password,omitempty"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
JWT string `json:"jwt,omitempty"`
}
// JWTClaims represents the JWT payload for Jitsi
type JWTClaims struct {
Audience string `json:"aud,omitempty"`
Issuer string `json:"iss,omitempty"`
Subject string `json:"sub,omitempty"`
Room string `json:"room,omitempty"`
ExpiresAt int64 `json:"exp,omitempty"`
NotBefore int64 `json:"nbf,omitempty"`
Context *JWTContext `json:"context,omitempty"`
Moderator bool `json:"moderator,omitempty"`
Features *JWTFeatures `json:"features,omitempty"`
}
// JWTContext contains user information for JWT
type JWTContext struct {
User *JWTUser `json:"user,omitempty"`
Group string `json:"group,omitempty"`
Callee *JWTCallee `json:"callee,omitempty"`
}
// JWTUser represents user info in JWT
type JWTUser struct {
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Email string `json:"email,omitempty"`
Avatar string `json:"avatar,omitempty"`
Moderator bool `json:"moderator,omitempty"`
HiddenFromRecorder bool `json:"hidden-from-recorder,omitempty"`
}
// JWTCallee represents callee info (for 1:1 calls)
type JWTCallee struct {
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Avatar string `json:"avatar,omitempty"`
}
// JWTFeatures controls JWT-based feature access
type JWTFeatures struct {
Livestreaming string `json:"livestreaming,omitempty"` // "true" or "false"
Recording string `json:"recording,omitempty"`
Transcription string `json:"transcription,omitempty"`
OutboundCall string `json:"outbound-call,omitempty"`
}
// ScheduledMeeting represents a scheduled training/meeting
type ScheduledMeeting struct {
ID string `json:"id"`
Title string `json:"title"`
Description string `json:"description,omitempty"`
RoomName string `json:"room_name"`
HostID string `json:"host_id"`
HostName string `json:"host_name"`
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time"`
Duration int `json:"duration"` // in minutes
Password string `json:"password,omitempty"`
MaxParticipants int `json:"max_participants,omitempty"`
Features *MeetingFeatures `json:"features,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// ========================================
// Meeting Management
// ========================================
// CreateMeetingLink generates a meeting URL with optional JWT authentication
func (s *JitsiService) CreateMeetingLink(ctx context.Context, meeting Meeting) (*MeetingLink, error) {
// Generate room name if not provided
roomName := meeting.RoomName
if roomName == "" {
roomName = s.generateRoomName()
}
// Sanitize room name (Jitsi-compatible)
roomName = s.sanitizeRoomName(roomName)
link := &MeetingLink{
RoomName: roomName,
URL: fmt.Sprintf("%s/%s", s.baseURL, roomName),
JoinURL: fmt.Sprintf("%s/%s", s.baseURL, roomName),
Password: meeting.Password,
}
// Generate JWT if authentication is configured
if s.appSecret != "" {
jwt, expiresAt, err := s.generateJWT(meeting, roomName)
if err != nil {
return nil, fmt.Errorf("failed to generate JWT: %w", err)
}
link.JWT = jwt
link.ExpiresAt = expiresAt
link.JoinURL = fmt.Sprintf("%s/%s?jwt=%s", s.baseURL, roomName, jwt)
// Generate moderator URL if user is moderator
if meeting.Moderator {
link.ModeratorURL = link.JoinURL
}
}
// Add config parameters to URL
if meeting.Config != nil {
params := s.buildConfigParams(meeting.Config)
if params != "" {
separator := "?"
if strings.Contains(link.JoinURL, "?") {
separator = "&"
}
link.JoinURL += separator + params
}
}
return link, nil
}
// CreateTrainingSession creates a meeting link optimized for training sessions
func (s *JitsiService) CreateTrainingSession(ctx context.Context, title string, hostName string, hostEmail string, duration int) (*MeetingLink, error) {
meeting := Meeting{
RoomName: s.generateTrainingRoomName(title),
DisplayName: hostName,
Email: hostEmail,
Subject: title,
Duration: duration,
Moderator: true,
Config: &MeetingConfig{
StartWithAudioMuted: true, // Participants start muted
StartWithVideoMuted: false, // Video on for training
RequireDisplayName: true, // Know who's attending
EnableLobby: true, // Waiting room
EnableRecording: true, // Allow recording
},
Features: &MeetingFeatures{
Recording: true,
Transcription: false,
},
}
return s.CreateMeetingLink(ctx, meeting)
}
// CreateQuickMeeting creates a simple ad-hoc meeting
func (s *JitsiService) CreateQuickMeeting(ctx context.Context, displayName string) (*MeetingLink, error) {
meeting := Meeting{
DisplayName: displayName,
Config: &MeetingConfig{
StartWithAudioMuted: false,
StartWithVideoMuted: false,
},
}
return s.CreateMeetingLink(ctx, meeting)
}
// CreateParentTeacherMeeting creates a meeting for parent-teacher conferences
func (s *JitsiService) CreateParentTeacherMeeting(ctx context.Context, teacherName string, parentName string, studentName string, scheduledTime time.Time) (*MeetingLink, error) {
roomName := fmt.Sprintf("elterngespraech-%s-%s",
s.sanitizeRoomName(studentName),
scheduledTime.Format("20060102-1504"))
meeting := Meeting{
RoomName: roomName,
DisplayName: teacherName,
Subject: fmt.Sprintf("Elterngespräch: %s", studentName),
StartTime: &scheduledTime,
Duration: 30, // 30 minutes default
Moderator: true,
Password: s.generatePassword(),
Config: &MeetingConfig{
StartWithAudioMuted: false,
StartWithVideoMuted: false,
RequireDisplayName: true,
EnableLobby: true, // Teacher admits parent
DisableDeepLinking: true,
},
}
return s.CreateMeetingLink(ctx, meeting)
}
// CreateClassMeeting creates a meeting for an entire class
func (s *JitsiService) CreateClassMeeting(ctx context.Context, className string, teacherName string, subject string) (*MeetingLink, error) {
roomName := fmt.Sprintf("klasse-%s-%s",
s.sanitizeRoomName(className),
time.Now().Format("20060102"))
meeting := Meeting{
RoomName: roomName,
DisplayName: teacherName,
Subject: fmt.Sprintf("%s - %s", className, subject),
Moderator: true,
Config: &MeetingConfig{
StartWithAudioMuted: true, // Students muted by default
StartWithVideoMuted: false,
RequireDisplayName: true,
EnableLobby: false, // Direct join for classes
},
}
return s.CreateMeetingLink(ctx, meeting)
}
// ========================================
// JWT Generation
// ========================================
// generateJWT creates a signed JWT for Jitsi authentication
func (s *JitsiService) generateJWT(meeting Meeting, roomName string) (string, *time.Time, error) {
if s.appSecret == "" {
return "", nil, fmt.Errorf("app secret not configured")
}
now := time.Now()
// Default expiration: 24 hours or based on meeting duration
expiration := now.Add(24 * time.Hour)
if meeting.Duration > 0 {
expiration = now.Add(time.Duration(meeting.Duration+30) * time.Minute)
}
if meeting.StartTime != nil {
expiration = meeting.StartTime.Add(time.Duration(meeting.Duration+60) * time.Minute)
}
claims := JWTClaims{
Audience: "jitsi",
Issuer: s.appID,
Subject: "meet.jitsi",
Room: roomName,
ExpiresAt: expiration.Unix(),
NotBefore: now.Add(-5 * time.Minute).Unix(), // 5 min grace period
Moderator: meeting.Moderator,
Context: &JWTContext{
User: &JWTUser{
ID: uuid.New().String(),
Name: meeting.DisplayName,
Email: meeting.Email,
Avatar: meeting.Avatar,
Moderator: meeting.Moderator,
},
},
}
// Add features if specified
if meeting.Features != nil {
claims.Features = &JWTFeatures{
Recording: boolToString(meeting.Features.Recording),
Livestreaming: boolToString(meeting.Features.Livestreaming),
Transcription: boolToString(meeting.Features.Transcription),
OutboundCall: boolToString(meeting.Features.OutboundCall),
}
}
// Create JWT
token, err := s.signJWT(claims)
if err != nil {
return "", nil, err
}
return token, &expiration, nil
}
// signJWT creates and signs a JWT token
func (s *JitsiService) signJWT(claims JWTClaims) (string, error) {
// Header
header := map[string]string{
"alg": "HS256",
"typ": "JWT",
}
headerJSON, err := json.Marshal(header)
if err != nil {
return "", err
}
// Payload
payloadJSON, err := json.Marshal(claims)
if err != nil {
return "", err
}
// Encode
headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON)
payloadB64 := base64.RawURLEncoding.EncodeToString(payloadJSON)
// Sign
message := headerB64 + "." + payloadB64
h := hmac.New(sha256.New, []byte(s.appSecret))
h.Write([]byte(message))
signature := base64.RawURLEncoding.EncodeToString(h.Sum(nil))
return message + "." + signature, nil
}
// ========================================
// Health Check
// ========================================
// HealthCheck verifies the Jitsi server is accessible
func (s *JitsiService) HealthCheck(ctx context.Context) error {
req, err := http.NewRequestWithContext(ctx, "GET", s.baseURL, nil)
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
resp, err := s.httpClient.Do(req)
if err != nil {
return fmt.Errorf("jitsi server unreachable: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 500 {
return fmt.Errorf("jitsi server error: status %d", resp.StatusCode)
}
return nil
}
// GetServerInfo returns information about the Jitsi server
func (s *JitsiService) GetServerInfo() map[string]string {
return map[string]string{
"base_url": s.baseURL,
"app_id": s.appID,
"auth_enabled": boolToString(s.appSecret != ""),
}
}
// ========================================
// URL Building
// ========================================
// BuildEmbedURL creates an embeddable iframe URL
func (s *JitsiService) BuildEmbedURL(roomName string, displayName string, config *MeetingConfig) string {
params := url.Values{}
if displayName != "" {
params.Set("userInfo.displayName", displayName)
}
if config != nil {
if config.StartWithAudioMuted {
params.Set("config.startWithAudioMuted", "true")
}
if config.StartWithVideoMuted {
params.Set("config.startWithVideoMuted", "true")
}
if config.DisableDeepLinking {
params.Set("config.disableDeepLinking", "true")
}
}
embedURL := fmt.Sprintf("%s/%s", s.baseURL, s.sanitizeRoomName(roomName))
if len(params) > 0 {
embedURL += "#" + params.Encode()
}
return embedURL
}
// BuildIFrameCode generates HTML iframe code for embedding
func (s *JitsiService) BuildIFrameCode(roomName string, width int, height int) string {
if width == 0 {
width = 800
}
if height == 0 {
height = 600
}
return fmt.Sprintf(
`<iframe src="%s/%s" width="%d" height="%d" allow="camera; microphone; fullscreen; display-capture; autoplay" style="border: 0;"></iframe>`,
s.baseURL,
s.sanitizeRoomName(roomName),
width,
height,
)
}
// ========================================
// Helper Functions
// ========================================
// generateRoomName creates a unique room name
func (s *JitsiService) generateRoomName() string {
return fmt.Sprintf("breakpilot-%s", uuid.New().String()[:8])
}
// generateTrainingRoomName creates a room name for training sessions
func (s *JitsiService) generateTrainingRoomName(title string) string {
sanitized := s.sanitizeRoomName(title)
if sanitized == "" {
sanitized = "schulung"
}
return fmt.Sprintf("%s-%s", sanitized, time.Now().Format("20060102-1504"))
}
// sanitizeRoomName removes invalid characters from room names
func (s *JitsiService) sanitizeRoomName(name string) string {
// Replace spaces and special characters
result := strings.ToLower(name)
result = strings.ReplaceAll(result, " ", "-")
result = strings.ReplaceAll(result, "ä", "ae")
result = strings.ReplaceAll(result, "ö", "oe")
result = strings.ReplaceAll(result, "ü", "ue")
result = strings.ReplaceAll(result, "ß", "ss")
// Remove any remaining non-alphanumeric characters except hyphen
var cleaned strings.Builder
for _, r := range result {
if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '-' {
cleaned.WriteRune(r)
}
}
// Remove consecutive hyphens
result = cleaned.String()
for strings.Contains(result, "--") {
result = strings.ReplaceAll(result, "--", "-")
}
// Trim hyphens from start and end
result = strings.Trim(result, "-")
// Limit length
if len(result) > 50 {
result = result[:50]
}
return result
}
// generatePassword creates a random meeting password
func (s *JitsiService) generatePassword() string {
return uuid.New().String()[:8]
}
// buildConfigParams creates URL parameters from config
func (s *JitsiService) buildConfigParams(config *MeetingConfig) string {
params := url.Values{}
if config.StartWithAudioMuted {
params.Set("config.startWithAudioMuted", "true")
}
if config.StartWithVideoMuted {
params.Set("config.startWithVideoMuted", "true")
}
if config.DisableDeepLinking {
params.Set("config.disableDeepLinking", "true")
}
if config.RequireDisplayName {
params.Set("config.requireDisplayName", "true")
}
if config.EnableLobby {
params.Set("config.enableLobby", "true")
}
return params.Encode()
}
// boolToString converts bool to "true"/"false" string
func boolToString(b bool) string {
if b {
return "true"
}
return "false"
}
// GetBaseURL returns the configured base URL
func (s *JitsiService) GetBaseURL() string {
return s.baseURL
}
// IsAuthEnabled returns whether JWT authentication is configured
func (s *JitsiService) IsAuthEnabled() bool {
return s.appSecret != ""
}

View File

@@ -0,0 +1,687 @@
package jitsi
import (
"context"
"encoding/base64"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
// ========================================
// Test Helpers
// ========================================
func createTestService() *JitsiService {
return NewJitsiService(Config{
BaseURL: "http://localhost:8443",
AppID: "breakpilot",
AppSecret: "test-secret-key",
})
}
func createTestServiceWithoutAuth() *JitsiService {
return NewJitsiService(Config{
BaseURL: "http://localhost:8443",
})
}
// ========================================
// Unit Tests: Service Creation
// ========================================
func TestNewJitsiService_ValidConfig_CreatesService(t *testing.T) {
cfg := Config{
BaseURL: "http://localhost:8443",
AppID: "test-app",
AppSecret: "test-secret",
}
service := NewJitsiService(cfg)
if service == nil {
t.Fatal("Expected service to be created, got nil")
}
if service.baseURL != cfg.BaseURL {
t.Errorf("Expected baseURL %s, got %s", cfg.BaseURL, service.baseURL)
}
if service.appID != cfg.AppID {
t.Errorf("Expected appID %s, got %s", cfg.AppID, service.appID)
}
if service.appSecret != cfg.AppSecret {
t.Errorf("Expected appSecret %s, got %s", cfg.AppSecret, service.appSecret)
}
if service.httpClient == nil {
t.Error("Expected httpClient to be initialized")
}
}
func TestNewJitsiService_TrailingSlash_Removed(t *testing.T) {
service := NewJitsiService(Config{
BaseURL: "http://localhost:8443/",
})
if service.baseURL != "http://localhost:8443" {
t.Errorf("Expected trailing slash to be removed, got %s", service.baseURL)
}
}
func TestGetBaseURL_ReturnsConfiguredURL(t *testing.T) {
service := createTestService()
result := service.GetBaseURL()
if result != "http://localhost:8443" {
t.Errorf("Expected 'http://localhost:8443', got '%s'", result)
}
}
func TestIsAuthEnabled_WithSecret_ReturnsTrue(t *testing.T) {
service := createTestService()
if !service.IsAuthEnabled() {
t.Error("Expected auth to be enabled when secret is configured")
}
}
func TestIsAuthEnabled_WithoutSecret_ReturnsFalse(t *testing.T) {
service := createTestServiceWithoutAuth()
if service.IsAuthEnabled() {
t.Error("Expected auth to be disabled when secret is not configured")
}
}
// ========================================
// Unit Tests: Room Name Generation
// ========================================
func TestSanitizeRoomName_ValidInput_ReturnsCleanName(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "simple name",
input: "meeting",
expected: "meeting",
},
{
name: "with spaces",
input: "My Meeting Room",
expected: "my-meeting-room",
},
{
name: "german umlauts",
input: "Schüler Müller",
expected: "schueler-mueller",
},
{
name: "special characters",
input: "Test@#$%Meeting!",
expected: "testmeeting",
},
{
name: "consecutive hyphens",
input: "test---meeting",
expected: "test-meeting",
},
{
name: "leading trailing hyphens",
input: "-test-meeting-",
expected: "test-meeting",
},
{
name: "eszett",
input: "Straße",
expected: "strasse",
},
{
name: "numbers",
input: "Klasse 5a",
expected: "klasse-5a",
},
}
service := createTestService()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := service.sanitizeRoomName(tt.input)
if result != tt.expected {
t.Errorf("Expected '%s', got '%s'", tt.expected, result)
}
})
}
}
func TestSanitizeRoomName_LongName_Truncated(t *testing.T) {
service := createTestService()
longName := strings.Repeat("a", 100)
result := service.sanitizeRoomName(longName)
if len(result) > 50 {
t.Errorf("Expected max 50 chars, got %d", len(result))
}
}
func TestGenerateRoomName_ReturnsUniqueNames(t *testing.T) {
service := createTestService()
name1 := service.generateRoomName()
name2 := service.generateRoomName()
if name1 == name2 {
t.Error("Expected unique room names")
}
if !strings.HasPrefix(name1, "breakpilot-") {
t.Errorf("Expected prefix 'breakpilot-', got '%s'", name1)
}
}
func TestGenerateTrainingRoomName_IncludesTitle(t *testing.T) {
service := createTestService()
result := service.generateTrainingRoomName("Go Workshop")
if !strings.HasPrefix(result, "go-workshop-") {
t.Errorf("Expected to start with 'go-workshop-', got '%s'", result)
}
}
func TestGeneratePassword_ReturnsValidPassword(t *testing.T) {
service := createTestService()
password := service.generatePassword()
if len(password) != 8 {
t.Errorf("Expected 8 char password, got %d", len(password))
}
}
// ========================================
// Unit Tests: Meeting Link Creation
// ========================================
func TestCreateMeetingLink_BasicMeeting_ReturnsValidLink(t *testing.T) {
service := createTestServiceWithoutAuth()
meeting := Meeting{
RoomName: "test-room",
DisplayName: "Test User",
}
link, err := service.CreateMeetingLink(context.Background(), meeting)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if link.RoomName != "test-room" {
t.Errorf("Expected room name 'test-room', got '%s'", link.RoomName)
}
if link.URL != "http://localhost:8443/test-room" {
t.Errorf("Expected URL 'http://localhost:8443/test-room', got '%s'", link.URL)
}
if link.JoinURL != "http://localhost:8443/test-room" {
t.Errorf("Expected JoinURL 'http://localhost:8443/test-room', got '%s'", link.JoinURL)
}
}
func TestCreateMeetingLink_NoRoomName_GeneratesName(t *testing.T) {
service := createTestServiceWithoutAuth()
meeting := Meeting{
DisplayName: "Test User",
}
link, err := service.CreateMeetingLink(context.Background(), meeting)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if link.RoomName == "" {
t.Error("Expected room name to be generated")
}
if !strings.HasPrefix(link.RoomName, "breakpilot-") {
t.Errorf("Expected generated room name to start with 'breakpilot-', got '%s'", link.RoomName)
}
}
func TestCreateMeetingLink_WithPassword_IncludesPassword(t *testing.T) {
service := createTestServiceWithoutAuth()
meeting := Meeting{
RoomName: "test-room",
Password: "secret123",
}
link, err := service.CreateMeetingLink(context.Background(), meeting)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if link.Password != "secret123" {
t.Errorf("Expected password 'secret123', got '%s'", link.Password)
}
}
func TestCreateMeetingLink_WithAuth_IncludesJWT(t *testing.T) {
service := createTestService()
meeting := Meeting{
RoomName: "test-room",
DisplayName: "Test User",
Email: "test@example.com",
}
link, err := service.CreateMeetingLink(context.Background(), meeting)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if link.JWT == "" {
t.Error("Expected JWT to be generated")
}
if !strings.Contains(link.JoinURL, "jwt=") {
t.Error("Expected JoinURL to contain JWT parameter")
}
if link.ExpiresAt == nil {
t.Error("Expected ExpiresAt to be set")
}
}
func TestCreateMeetingLink_WithConfig_IncludesParams(t *testing.T) {
service := createTestServiceWithoutAuth()
meeting := Meeting{
RoomName: "test-room",
Config: &MeetingConfig{
StartWithAudioMuted: true,
StartWithVideoMuted: true,
RequireDisplayName: true,
},
}
link, err := service.CreateMeetingLink(context.Background(), meeting)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if !strings.Contains(link.JoinURL, "startWithAudioMuted=true") {
t.Error("Expected JoinURL to contain audio muted config")
}
if !strings.Contains(link.JoinURL, "startWithVideoMuted=true") {
t.Error("Expected JoinURL to contain video muted config")
}
}
func TestCreateMeetingLink_Moderator_SetsModeratorURL(t *testing.T) {
service := createTestService()
meeting := Meeting{
RoomName: "test-room",
DisplayName: "Admin",
Moderator: true,
}
link, err := service.CreateMeetingLink(context.Background(), meeting)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if link.ModeratorURL == "" {
t.Error("Expected ModeratorURL to be set for moderator")
}
}
// ========================================
// Unit Tests: Specialized Meeting Types
// ========================================
func TestCreateTrainingSession_ReturnsOptimizedConfig(t *testing.T) {
service := createTestServiceWithoutAuth()
link, err := service.CreateTrainingSession(
context.Background(),
"Go Grundlagen",
"Max Trainer",
"trainer@example.com",
60,
)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if !strings.Contains(link.RoomName, "go-grundlagen") {
t.Errorf("Expected room name to contain 'go-grundlagen', got '%s'", link.RoomName)
}
// Config should have lobby enabled for training
if !strings.Contains(link.JoinURL, "enableLobby=true") {
t.Error("Expected training to have lobby enabled")
}
}
func TestCreateQuickMeeting_ReturnsSimpleMeeting(t *testing.T) {
service := createTestServiceWithoutAuth()
link, err := service.CreateQuickMeeting(context.Background(), "Quick User")
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if link.RoomName == "" {
t.Error("Expected room name to be generated")
}
}
func TestCreateParentTeacherMeeting_ReturnsSecureMeeting(t *testing.T) {
service := createTestServiceWithoutAuth()
scheduledTime := time.Now().Add(24 * time.Hour)
link, err := service.CreateParentTeacherMeeting(
context.Background(),
"Frau Müller",
"Herr Schmidt",
"Max Mustermann",
scheduledTime,
)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if !strings.Contains(link.RoomName, "elterngespraech") {
t.Errorf("Expected room name to contain 'elterngespraech', got '%s'", link.RoomName)
}
if link.Password == "" {
t.Error("Expected password for parent-teacher meeting")
}
if !strings.Contains(link.JoinURL, "enableLobby=true") {
t.Error("Expected lobby to be enabled")
}
}
func TestCreateClassMeeting_ReturnsMeetingForClass(t *testing.T) {
service := createTestServiceWithoutAuth()
link, err := service.CreateClassMeeting(
context.Background(),
"5a",
"Herr Lehrer",
"Mathematik",
)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if !strings.Contains(link.RoomName, "klasse-5a") {
t.Errorf("Expected room name to contain 'klasse-5a', got '%s'", link.RoomName)
}
// Students should be muted by default
if !strings.Contains(link.JoinURL, "startWithAudioMuted=true") {
t.Error("Expected students to start muted")
}
}
// ========================================
// Unit Tests: JWT Generation
// ========================================
func TestGenerateJWT_ValidClaims_ReturnsValidToken(t *testing.T) {
service := createTestService()
meeting := Meeting{
RoomName: "test-room",
DisplayName: "Test User",
Email: "test@example.com",
Moderator: true,
Duration: 60,
}
token, expiresAt, err := service.generateJWT(meeting, "test-room")
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if token == "" {
t.Error("Expected token to be generated")
}
if expiresAt == nil {
t.Error("Expected expiration time to be set")
}
// Verify token structure (header.payload.signature)
parts := strings.Split(token, ".")
if len(parts) != 3 {
t.Errorf("Expected 3 JWT parts, got %d", len(parts))
}
// Decode and verify payload
payloadJSON, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
t.Fatalf("Failed to decode payload: %v", err)
}
var claims JWTClaims
if err := json.Unmarshal(payloadJSON, &claims); err != nil {
t.Fatalf("Failed to unmarshal claims: %v", err)
}
if claims.Room != "test-room" {
t.Errorf("Expected room 'test-room', got '%s'", claims.Room)
}
if !claims.Moderator {
t.Error("Expected moderator to be true")
}
if claims.Context == nil || claims.Context.User == nil {
t.Error("Expected user context to be set")
}
if claims.Context.User.Name != "Test User" {
t.Errorf("Expected user name 'Test User', got '%s'", claims.Context.User.Name)
}
}
func TestGenerateJWT_WithFeatures_IncludesFeatures(t *testing.T) {
service := createTestService()
meeting := Meeting{
RoomName: "test-room",
Features: &MeetingFeatures{
Recording: true,
Transcription: true,
},
}
token, _, err := service.generateJWT(meeting, "test-room")
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
parts := strings.Split(token, ".")
payloadJSON, _ := base64.RawURLEncoding.DecodeString(parts[1])
var claims JWTClaims
json.Unmarshal(payloadJSON, &claims)
if claims.Features == nil {
t.Error("Expected features to be set")
}
if claims.Features.Recording != "true" {
t.Errorf("Expected recording 'true', got '%s'", claims.Features.Recording)
}
}
func TestGenerateJWT_NoSecret_ReturnsError(t *testing.T) {
service := createTestServiceWithoutAuth()
meeting := Meeting{RoomName: "test"}
_, _, err := service.generateJWT(meeting, "test")
if err == nil {
t.Error("Expected error when secret is not configured")
}
}
// ========================================
// Unit Tests: Health Check
// ========================================
func TestHealthCheck_ServerAvailable_ReturnsNil(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
service := NewJitsiService(Config{BaseURL: server.URL})
err := service.HealthCheck(context.Background())
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
}
func TestHealthCheck_ServerError_ReturnsError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
defer server.Close()
service := NewJitsiService(Config{BaseURL: server.URL})
err := service.HealthCheck(context.Background())
if err == nil {
t.Error("Expected error for server error response")
}
}
func TestHealthCheck_ServerUnreachable_ReturnsError(t *testing.T) {
service := NewJitsiService(Config{BaseURL: "http://localhost:59999"})
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
err := service.HealthCheck(ctx)
if err == nil {
t.Error("Expected error for unreachable server")
}
}
// ========================================
// Unit Tests: URL Building
// ========================================
func TestBuildEmbedURL_BasicRoom_ReturnsURL(t *testing.T) {
service := createTestService()
url := service.BuildEmbedURL("test-room", "", nil)
if url != "http://localhost:8443/test-room" {
t.Errorf("Expected 'http://localhost:8443/test-room', got '%s'", url)
}
}
func TestBuildEmbedURL_WithDisplayName_IncludesParam(t *testing.T) {
service := createTestService()
url := service.BuildEmbedURL("test-room", "Max Mustermann", nil)
if !strings.Contains(url, "displayName=Max") {
t.Errorf("Expected URL to contain display name, got '%s'", url)
}
}
func TestBuildEmbedURL_WithConfig_IncludesParams(t *testing.T) {
service := createTestService()
config := &MeetingConfig{
StartWithAudioMuted: true,
StartWithVideoMuted: true,
}
url := service.BuildEmbedURL("test-room", "", config)
if !strings.Contains(url, "startWithAudioMuted=true") {
t.Error("Expected URL to contain audio muted config")
}
if !strings.Contains(url, "startWithVideoMuted=true") {
t.Error("Expected URL to contain video muted config")
}
}
func TestBuildIFrameCode_DefaultSize_Returns800x600(t *testing.T) {
service := createTestService()
code := service.BuildIFrameCode("test-room", 0, 0)
if !strings.Contains(code, "width=\"800\"") {
t.Error("Expected default width 800")
}
if !strings.Contains(code, "height=\"600\"") {
t.Error("Expected default height 600")
}
if !strings.Contains(code, "test-room") {
t.Error("Expected room name in iframe")
}
if !strings.Contains(code, "allow=\"camera; microphone") {
t.Error("Expected camera/microphone permissions")
}
}
func TestBuildIFrameCode_CustomSize_ReturnsCorrectDimensions(t *testing.T) {
service := createTestService()
code := service.BuildIFrameCode("test-room", 1920, 1080)
if !strings.Contains(code, "width=\"1920\"") {
t.Error("Expected width 1920")
}
if !strings.Contains(code, "height=\"1080\"") {
t.Error("Expected height 1080")
}
}
// ========================================
// Unit Tests: Server Info
// ========================================
func TestGetServerInfo_ReturnsInfo(t *testing.T) {
service := createTestService()
info := service.GetServerInfo()
if info["base_url"] != "http://localhost:8443" {
t.Errorf("Expected base_url, got '%s'", info["base_url"])
}
if info["app_id"] != "breakpilot" {
t.Errorf("Expected app_id 'breakpilot', got '%s'", info["app_id"])
}
if info["auth_enabled"] != "true" {
t.Errorf("Expected auth_enabled 'true', got '%s'", info["auth_enabled"])
}
}
// ========================================
// Unit Tests: Helper Functions
// ========================================
func TestBoolToString_True_ReturnsTrue(t *testing.T) {
result := boolToString(true)
if result != "true" {
t.Errorf("Expected 'true', got '%s'", result)
}
}
func TestBoolToString_False_ReturnsFalse(t *testing.T) {
result := boolToString(false)
if result != "false" {
t.Errorf("Expected 'false', got '%s'", result)
}
}

View File

@@ -0,0 +1,368 @@
package matrix
import (
"context"
"fmt"
"time"
)
// ========================================
// Breakpilot Drive Game Room Types
// ========================================
// GameMode represents different multiplayer game modes
type GameMode string
const (
GameModeSolo GameMode = "solo"
GameModeCoop GameMode = "coop" // 2 players, same track
GameModeChallenge GameMode = "challenge" // 1v1 competition
GameModeClassRace GameMode = "class_race" // Whole class competition
)
// GameRoomConfig holds configuration for game rooms
type GameRoomConfig struct {
GameMode GameMode `json:"game_mode"`
SessionID string `json:"session_id"`
HostUserID string `json:"host_user_id"`
HostName string `json:"host_name"`
ClassName string `json:"class_name,omitempty"`
MaxPlayers int `json:"max_players,omitempty"`
TeacherIDs []string `json:"teacher_ids,omitempty"`
EnableVoice bool `json:"enable_voice,omitempty"`
}
// GameRoom represents an active game room
type GameRoom struct {
RoomID string `json:"room_id"`
SessionID string `json:"session_id"`
GameMode GameMode `json:"game_mode"`
HostUserID string `json:"host_user_id"`
Players []string `json:"players"`
CreatedAt time.Time `json:"created_at"`
IsActive bool `json:"is_active"`
}
// GameEvent represents game events to broadcast
type GameEvent struct {
Type string `json:"type"`
SessionID string `json:"session_id"`
PlayerID string `json:"player_id"`
Data interface{} `json:"data"`
Timestamp time.Time `json:"timestamp"`
}
// GameEventType constants
const (
GameEventPlayerJoined = "player_joined"
GameEventPlayerLeft = "player_left"
GameEventGameStarted = "game_started"
GameEventQuizAnswered = "quiz_answered"
GameEventScoreUpdate = "score_update"
GameEventAchievement = "achievement"
GameEventChallengeWon = "challenge_won"
GameEventRaceFinished = "race_finished"
)
// ========================================
// Game Room Management
// ========================================
// CreateGameTeamRoom creates a private room for 2-4 players (Co-Op mode)
func (s *MatrixService) CreateGameTeamRoom(ctx context.Context, config GameRoomConfig) (*CreateRoomResponse, error) {
roomName := fmt.Sprintf("Breakpilot Drive - Team %s", config.SessionID[:8])
topic := "Co-Op Spielsession - Arbeitet zusammen!"
// All players can write
users := make(map[string]int)
users[s.GenerateUserID(config.HostUserID)] = 50
req := CreateRoomRequest{
Name: roomName,
Topic: topic,
Visibility: "private",
Preset: "private_chat",
InitialState: []StateEvent{
{
Type: "m.room.encryption",
StateKey: "",
Content: map[string]string{
"algorithm": "m.megolm.v1.aes-sha2",
},
},
// Custom game state
{
Type: "breakpilot.game.session",
StateKey: "",
Content: map[string]interface{}{
"session_id": config.SessionID,
"game_mode": string(config.GameMode),
"host_id": config.HostUserID,
"created_at": time.Now().UTC().Format(time.RFC3339),
},
},
},
PowerLevelContentOverride: &PowerLevels{
EventsDefault: 0, // All players can send messages
UsersDefault: 50,
Users: users,
Events: map[string]int{
"breakpilot.game.event": 0, // Anyone can send game events
},
},
}
return s.CreateRoom(ctx, req)
}
// CreateGameChallengeRoom creates a 1v1 challenge room
func (s *MatrixService) CreateGameChallengeRoom(ctx context.Context, config GameRoomConfig, challengerID string, opponentID string) (*CreateRoomResponse, error) {
roomName := fmt.Sprintf("Challenge: %s", config.SessionID[:8])
topic := "1v1 Wettbewerb - Möge der Bessere gewinnen!"
allPlayers := []string{
s.GenerateUserID(challengerID),
s.GenerateUserID(opponentID),
}
users := make(map[string]int)
for _, id := range allPlayers {
users[id] = 50
}
req := CreateRoomRequest{
Name: roomName,
Topic: topic,
Visibility: "private",
Preset: "private_chat",
Invite: allPlayers,
InitialState: []StateEvent{
{
Type: "breakpilot.game.session",
StateKey: "",
Content: map[string]interface{}{
"session_id": config.SessionID,
"game_mode": string(GameModeChallenge),
"challenger_id": challengerID,
"opponent_id": opponentID,
"created_at": time.Now().UTC().Format(time.RFC3339),
},
},
},
PowerLevelContentOverride: &PowerLevels{
EventsDefault: 0,
UsersDefault: 50,
Users: users,
},
}
return s.CreateRoom(ctx, req)
}
// CreateGameClassRaceRoom creates a room for class-wide competition
func (s *MatrixService) CreateGameClassRaceRoom(ctx context.Context, config GameRoomConfig) (*CreateRoomResponse, error) {
roomName := fmt.Sprintf("Klassenrennen: %s", config.ClassName)
topic := fmt.Sprintf("Klassenrennen der %s - Alle gegen alle!", config.ClassName)
// Teachers get moderator power level
users := make(map[string]int)
for _, teacherID := range config.TeacherIDs {
users[s.GenerateUserID(teacherID)] = 100
}
req := CreateRoomRequest{
Name: roomName,
Topic: topic,
Visibility: "private",
Preset: "private_chat",
InitialState: []StateEvent{
{
Type: "breakpilot.game.session",
StateKey: "",
Content: map[string]interface{}{
"session_id": config.SessionID,
"game_mode": string(GameModeClassRace),
"class_name": config.ClassName,
"teacher_ids": config.TeacherIDs,
"created_at": time.Now().UTC().Format(time.RFC3339),
},
},
},
PowerLevelContentOverride: &PowerLevels{
EventsDefault: 0, // Students can send messages
UsersDefault: 10, // Default student level
Users: users,
Invite: 100, // Only teachers can invite
Kick: 100, // Only teachers can kick
Events: map[string]int{
"breakpilot.game.event": 0, // Anyone can send game events
"breakpilot.game.leaderboard": 100, // Only teachers update leaderboard
},
},
}
return s.CreateRoom(ctx, req)
}
// ========================================
// Game Event Broadcasting
// ========================================
// SendGameEvent sends a game event to a room
func (s *MatrixService) SendGameEvent(ctx context.Context, roomID string, event GameEvent) error {
event.Timestamp = time.Now().UTC()
return s.sendEvent(ctx, roomID, "breakpilot.game.event", event)
}
// SendPlayerJoinedEvent notifies room that a player joined
func (s *MatrixService) SendPlayerJoinedEvent(ctx context.Context, roomID string, sessionID string, playerID string, playerName string) error {
event := GameEvent{
Type: GameEventPlayerJoined,
SessionID: sessionID,
PlayerID: playerID,
Data: map[string]string{
"player_name": playerName,
},
}
// Also send a visible message
msg := fmt.Sprintf("🎮 %s ist dem Spiel beigetreten!", playerName)
if err := s.SendMessage(ctx, roomID, msg); err != nil {
// Log but don't fail
fmt.Printf("Warning: failed to send join message: %v\n", err)
}
return s.SendGameEvent(ctx, roomID, event)
}
// SendScoreUpdateEvent broadcasts score updates
func (s *MatrixService) SendScoreUpdateEvent(ctx context.Context, roomID string, sessionID string, playerID string, score int, accuracy float64) error {
event := GameEvent{
Type: GameEventScoreUpdate,
SessionID: sessionID,
PlayerID: playerID,
Data: map[string]interface{}{
"score": score,
"accuracy": accuracy,
},
}
return s.SendGameEvent(ctx, roomID, event)
}
// SendQuizAnsweredEvent broadcasts when a player answers a quiz
func (s *MatrixService) SendQuizAnsweredEvent(ctx context.Context, roomID string, sessionID string, playerID string, correct bool, subject string) error {
event := GameEvent{
Type: GameEventQuizAnswered,
SessionID: sessionID,
PlayerID: playerID,
Data: map[string]interface{}{
"correct": correct,
"subject": subject,
},
}
return s.SendGameEvent(ctx, roomID, event)
}
// SendAchievementEvent broadcasts when a player earns an achievement
func (s *MatrixService) SendAchievementEvent(ctx context.Context, roomID string, sessionID string, playerID string, achievementID string, achievementName string) error {
event := GameEvent{
Type: GameEventAchievement,
SessionID: sessionID,
PlayerID: playerID,
Data: map[string]interface{}{
"achievement_id": achievementID,
"achievement_name": achievementName,
},
}
// Also send a visible celebration message
msg := fmt.Sprintf("🏆 Erfolg freigeschaltet: %s!", achievementName)
if err := s.SendMessage(ctx, roomID, msg); err != nil {
fmt.Printf("Warning: failed to send achievement message: %v\n", err)
}
return s.SendGameEvent(ctx, roomID, event)
}
// SendChallengeWonEvent broadcasts challenge result
func (s *MatrixService) SendChallengeWonEvent(ctx context.Context, roomID string, sessionID string, winnerID string, winnerName string, loserName string, winnerScore int, loserScore int) error {
event := GameEvent{
Type: GameEventChallengeWon,
SessionID: sessionID,
PlayerID: winnerID,
Data: map[string]interface{}{
"winner_name": winnerName,
"loser_name": loserName,
"winner_score": winnerScore,
"loser_score": loserScore,
},
}
// Send celebration message
msg := fmt.Sprintf("🎉 %s gewinnt gegen %s mit %d zu %d Punkten!", winnerName, loserName, winnerScore, loserScore)
if err := s.SendHTMLMessage(ctx, roomID, msg, fmt.Sprintf("<h3>🎉 Challenge beendet!</h3><p><strong>%s</strong> gewinnt gegen %s</p><p>Endstand: %d : %d</p>", winnerName, loserName, winnerScore, loserScore)); err != nil {
fmt.Printf("Warning: failed to send challenge result message: %v\n", err)
}
return s.SendGameEvent(ctx, roomID, event)
}
// SendClassRaceLeaderboard broadcasts current leaderboard in class race
func (s *MatrixService) SendClassRaceLeaderboard(ctx context.Context, roomID string, sessionID string, leaderboard []map[string]interface{}) error {
// Build leaderboard message
msg := "🏁 Aktueller Stand:\n"
htmlMsg := "<h3>🏁 Aktueller Stand</h3><ol>"
for i, entry := range leaderboard {
if i >= 10 { // Top 10 only
break
}
name := entry["name"].(string)
score := entry["score"].(int)
msg += fmt.Sprintf("%d. %s - %d Punkte\n", i+1, name, score)
htmlMsg += fmt.Sprintf("<li><strong>%s</strong> - %d Punkte</li>", name, score)
}
htmlMsg += "</ol>"
return s.SendHTMLMessage(ctx, roomID, msg, htmlMsg)
}
// ========================================
// Game Room Utilities
// ========================================
// AddPlayerToGameRoom invites and sets up a player in a game room
func (s *MatrixService) AddPlayerToGameRoom(ctx context.Context, roomID string, playerMatrixID string, playerName string) error {
// Invite the player
if err := s.InviteUser(ctx, roomID, playerMatrixID); err != nil {
return fmt.Errorf("failed to invite player: %w", err)
}
// Set display name if not already set
if err := s.SetDisplayName(ctx, playerMatrixID, playerName); err != nil {
// Log but don't fail - display name might already be set
fmt.Printf("Warning: failed to set display name: %v\n", err)
}
return nil
}
// CloseGameRoom sends end message and archives the room
func (s *MatrixService) CloseGameRoom(ctx context.Context, roomID string, sessionID string) error {
// Send closing message
msg := "🏁 Spiel beendet! Danke fürs Mitspielen. Dieser Raum wird archiviert."
if err := s.SendMessage(ctx, roomID, msg); err != nil {
return fmt.Errorf("failed to send closing message: %w", err)
}
// Update room state to mark as closed
closeEvent := map[string]interface{}{
"closed": true,
"closed_at": time.Now().UTC().Format(time.RFC3339),
}
return s.sendEvent(ctx, roomID, "breakpilot.game.closed", closeEvent)
}

View File

@@ -0,0 +1,548 @@
package matrix
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"time"
"github.com/google/uuid"
)
// MatrixService handles Matrix homeserver communication
type MatrixService struct {
homeserverURL string
accessToken string
serverName string
httpClient *http.Client
}
// Config holds Matrix service configuration
type Config struct {
HomeserverURL string // e.g., "http://synapse:8008"
AccessToken string // Admin/bot access token
ServerName string // e.g., "breakpilot.local"
}
// NewMatrixService creates a new Matrix service instance
func NewMatrixService(cfg Config) *MatrixService {
return &MatrixService{
homeserverURL: cfg.HomeserverURL,
accessToken: cfg.AccessToken,
serverName: cfg.ServerName,
httpClient: &http.Client{
Timeout: 30 * time.Second,
},
}
}
// ========================================
// Matrix API Types
// ========================================
// CreateRoomRequest represents a Matrix room creation request
type CreateRoomRequest struct {
Name string `json:"name,omitempty"`
RoomAliasName string `json:"room_alias_name,omitempty"`
Topic string `json:"topic,omitempty"`
Visibility string `json:"visibility,omitempty"` // "private" or "public"
Preset string `json:"preset,omitempty"` // "private_chat", "public_chat", "trusted_private_chat"
IsDirect bool `json:"is_direct,omitempty"`
Invite []string `json:"invite,omitempty"`
InitialState []StateEvent `json:"initial_state,omitempty"`
PowerLevelContentOverride *PowerLevels `json:"power_level_content_override,omitempty"`
}
// CreateRoomResponse represents a Matrix room creation response
type CreateRoomResponse struct {
RoomID string `json:"room_id"`
}
// StateEvent represents a Matrix state event
type StateEvent struct {
Type string `json:"type"`
StateKey string `json:"state_key"`
Content interface{} `json:"content"`
}
// PowerLevels represents Matrix power levels
type PowerLevels struct {
Ban int `json:"ban,omitempty"`
Events map[string]int `json:"events,omitempty"`
EventsDefault int `json:"events_default,omitempty"`
Invite int `json:"invite,omitempty"`
Kick int `json:"kick,omitempty"`
Redact int `json:"redact,omitempty"`
StateDefault int `json:"state_default,omitempty"`
Users map[string]int `json:"users,omitempty"`
UsersDefault int `json:"users_default,omitempty"`
}
// SendMessageRequest represents a message to send
type SendMessageRequest struct {
MsgType string `json:"msgtype"`
Body string `json:"body"`
Format string `json:"format,omitempty"`
FormattedBody string `json:"formatted_body,omitempty"`
}
// UserInfo represents Matrix user information
type UserInfo struct {
UserID string `json:"user_id"`
DisplayName string `json:"displayname,omitempty"`
AvatarURL string `json:"avatar_url,omitempty"`
}
// RegisterRequest for user registration
type RegisterRequest struct {
Username string `json:"username"`
Password string `json:"password,omitempty"`
Admin bool `json:"admin,omitempty"`
}
// RegisterResponse for user registration
type RegisterResponse struct {
UserID string `json:"user_id"`
AccessToken string `json:"access_token"`
DeviceID string `json:"device_id"`
}
// InviteRequest for inviting a user to a room
type InviteRequest struct {
UserID string `json:"user_id"`
}
// JoinRequest for joining a room
type JoinRequest struct {
Reason string `json:"reason,omitempty"`
}
// ========================================
// Room Management
// ========================================
// CreateRoom creates a new Matrix room
func (s *MatrixService) CreateRoom(ctx context.Context, req CreateRoomRequest) (*CreateRoomResponse, error) {
body, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
resp, err := s.doRequest(ctx, "POST", "/_matrix/client/v3/createRoom", body)
if err != nil {
return nil, fmt.Errorf("failed to create room: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, s.parseError(resp)
}
var result CreateRoomResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}
return &result, nil
}
// CreateClassInfoRoom creates a broadcast room for a class (teachers write, parents read)
func (s *MatrixService) CreateClassInfoRoom(ctx context.Context, className string, schoolName string, teacherMatrixIDs []string) (*CreateRoomResponse, error) {
// Set up power levels: teachers can write (50), parents read-only (0)
users := make(map[string]int)
for _, teacherID := range teacherMatrixIDs {
users[teacherID] = 50
}
req := CreateRoomRequest{
Name: fmt.Sprintf("%s - %s (Info)", className, schoolName),
Topic: fmt.Sprintf("Info-Kanal für %s. Nur Lehrer können schreiben.", className),
Visibility: "private",
Preset: "private_chat",
Invite: teacherMatrixIDs,
PowerLevelContentOverride: &PowerLevels{
EventsDefault: 50, // Only power level 50+ can send messages
UsersDefault: 0, // Parents get power level 0 by default
Users: users,
Invite: 50,
Kick: 50,
Ban: 50,
Redact: 50,
},
}
return s.CreateRoom(ctx, req)
}
// CreateStudentDMRoom creates a direct message room for parent-teacher communication about a student
func (s *MatrixService) CreateStudentDMRoom(ctx context.Context, studentName string, className string, teacherMatrixIDs []string, parentMatrixIDs []string) (*CreateRoomResponse, error) {
allUsers := append(teacherMatrixIDs, parentMatrixIDs...)
users := make(map[string]int)
for _, id := range allUsers {
users[id] = 50 // All can write
}
req := CreateRoomRequest{
Name: fmt.Sprintf("%s (%s) - Dialog", studentName, className),
Topic: fmt.Sprintf("Kommunikation über %s", studentName),
Visibility: "private",
Preset: "trusted_private_chat",
IsDirect: false,
Invite: allUsers,
InitialState: []StateEvent{
{
Type: "m.room.encryption",
StateKey: "",
Content: map[string]string{
"algorithm": "m.megolm.v1.aes-sha2",
},
},
},
PowerLevelContentOverride: &PowerLevels{
EventsDefault: 0, // Everyone can send messages
UsersDefault: 50,
Users: users,
},
}
return s.CreateRoom(ctx, req)
}
// CreateParentRepRoom creates a room for class teacher and parent representatives
func (s *MatrixService) CreateParentRepRoom(ctx context.Context, className string, teacherMatrixIDs []string, parentRepMatrixIDs []string) (*CreateRoomResponse, error) {
allUsers := append(teacherMatrixIDs, parentRepMatrixIDs...)
users := make(map[string]int)
for _, id := range allUsers {
users[id] = 50
}
req := CreateRoomRequest{
Name: fmt.Sprintf("%s - Elternvertreter", className),
Topic: fmt.Sprintf("Kommunikation zwischen Lehrkräften und Elternvertretern der %s", className),
Visibility: "private",
Preset: "private_chat",
Invite: allUsers,
PowerLevelContentOverride: &PowerLevels{
EventsDefault: 0,
UsersDefault: 50,
Users: users,
},
}
return s.CreateRoom(ctx, req)
}
// ========================================
// User Management
// ========================================
// RegisterUser registers a new Matrix user (requires admin token)
func (s *MatrixService) RegisterUser(ctx context.Context, username string, displayName string) (*RegisterResponse, error) {
// Use admin API for user registration
req := map[string]interface{}{
"username": username,
"password": uuid.New().String(), // Generate random password
"admin": false,
}
body, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
// Use admin registration endpoint
resp, err := s.doRequest(ctx, "POST", "/_synapse/admin/v2/users/@"+username+":"+s.serverName, body)
if err != nil {
return nil, fmt.Errorf("failed to register user: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
return nil, s.parseError(resp)
}
// Set display name
if displayName != "" {
if err := s.SetDisplayName(ctx, "@"+username+":"+s.serverName, displayName); err != nil {
// Log but don't fail
fmt.Printf("Warning: failed to set display name: %v\n", err)
}
}
return &RegisterResponse{
UserID: "@" + username + ":" + s.serverName,
}, nil
}
// SetDisplayName sets the display name for a user
func (s *MatrixService) SetDisplayName(ctx context.Context, userID string, displayName string) error {
req := map[string]string{
"displayname": displayName,
}
body, err := json.Marshal(req)
if err != nil {
return fmt.Errorf("failed to marshal request: %w", err)
}
endpoint := fmt.Sprintf("/_matrix/client/v3/profile/%s/displayname", url.PathEscape(userID))
resp, err := s.doRequest(ctx, "PUT", endpoint, body)
if err != nil {
return fmt.Errorf("failed to set display name: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return s.parseError(resp)
}
return nil
}
// ========================================
// Room Membership
// ========================================
// InviteUser invites a user to a room
func (s *MatrixService) InviteUser(ctx context.Context, roomID string, userID string) error {
req := InviteRequest{UserID: userID}
body, err := json.Marshal(req)
if err != nil {
return fmt.Errorf("failed to marshal request: %w", err)
}
endpoint := fmt.Sprintf("/_matrix/client/v3/rooms/%s/invite", url.PathEscape(roomID))
resp, err := s.doRequest(ctx, "POST", endpoint, body)
if err != nil {
return fmt.Errorf("failed to invite user: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return s.parseError(resp)
}
return nil
}
// JoinRoom makes the bot join a room
func (s *MatrixService) JoinRoom(ctx context.Context, roomIDOrAlias string) error {
endpoint := fmt.Sprintf("/_matrix/client/v3/join/%s", url.PathEscape(roomIDOrAlias))
resp, err := s.doRequest(ctx, "POST", endpoint, []byte("{}"))
if err != nil {
return fmt.Errorf("failed to join room: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return s.parseError(resp)
}
return nil
}
// SetUserPowerLevel sets a user's power level in a room
func (s *MatrixService) SetUserPowerLevel(ctx context.Context, roomID string, userID string, powerLevel int) error {
// First, get current power levels
endpoint := fmt.Sprintf("/_matrix/client/v3/rooms/%s/state/m.room.power_levels/", url.PathEscape(roomID))
resp, err := s.doRequest(ctx, "GET", endpoint, nil)
if err != nil {
return fmt.Errorf("failed to get power levels: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return s.parseError(resp)
}
var powerLevels PowerLevels
if err := json.NewDecoder(resp.Body).Decode(&powerLevels); err != nil {
return fmt.Errorf("failed to decode power levels: %w", err)
}
// Update user power level
if powerLevels.Users == nil {
powerLevels.Users = make(map[string]int)
}
powerLevels.Users[userID] = powerLevel
// Send updated power levels
body, err := json.Marshal(powerLevels)
if err != nil {
return fmt.Errorf("failed to marshal power levels: %w", err)
}
resp2, err := s.doRequest(ctx, "PUT", endpoint, body)
if err != nil {
return fmt.Errorf("failed to set power levels: %w", err)
}
defer resp2.Body.Close()
if resp2.StatusCode != http.StatusOK {
return s.parseError(resp2)
}
return nil
}
// ========================================
// Messaging
// ========================================
// SendMessage sends a text message to a room
func (s *MatrixService) SendMessage(ctx context.Context, roomID string, message string) error {
req := SendMessageRequest{
MsgType: "m.text",
Body: message,
}
return s.sendEvent(ctx, roomID, "m.room.message", req)
}
// SendHTMLMessage sends an HTML-formatted message to a room
func (s *MatrixService) SendHTMLMessage(ctx context.Context, roomID string, plainText string, htmlBody string) error {
req := SendMessageRequest{
MsgType: "m.text",
Body: plainText,
Format: "org.matrix.custom.html",
FormattedBody: htmlBody,
}
return s.sendEvent(ctx, roomID, "m.room.message", req)
}
// SendAbsenceNotification sends an absence notification to parents
func (s *MatrixService) SendAbsenceNotification(ctx context.Context, roomID string, studentName string, date string, lessonNumber int) error {
plainText := fmt.Sprintf("⚠️ Abwesenheitsmeldung\n\nIhr Kind %s war heute (%s) in der %d. Stunde nicht im Unterricht anwesend.\n\nBitte bestätigen Sie den Grund der Abwesenheit.", studentName, date, lessonNumber)
htmlBody := fmt.Sprintf(`<h3>⚠️ Abwesenheitsmeldung</h3>
<p>Ihr Kind <strong>%s</strong> war heute (%s) in der <strong>%d. Stunde</strong> nicht im Unterricht anwesend.</p>
<p>Bitte bestätigen Sie den Grund der Abwesenheit.</p>
<ul>
<li>✅ Entschuldigt (Krankheit)</li>
<li>📋 Arztbesuch</li>
<li>❓ Sonstiges (bitte erläutern)</li>
</ul>`, studentName, date, lessonNumber)
return s.SendHTMLMessage(ctx, roomID, plainText, htmlBody)
}
// SendGradeNotification sends a grade notification to parents
func (s *MatrixService) SendGradeNotification(ctx context.Context, roomID string, studentName string, subject string, gradeType string, grade float64) error {
plainText := fmt.Sprintf("📊 Neue Note eingetragen\n\nFür %s wurde eine neue Note eingetragen:\n\nFach: %s\nArt: %s\nNote: %.1f", studentName, subject, gradeType, grade)
htmlBody := fmt.Sprintf(`<h3>📊 Neue Note eingetragen</h3>
<p>Für <strong>%s</strong> wurde eine neue Note eingetragen:</p>
<table>
<tr><td>Fach:</td><td><strong>%s</strong></td></tr>
<tr><td>Art:</td><td>%s</td></tr>
<tr><td>Note:</td><td><strong>%.1f</strong></td></tr>
</table>`, studentName, subject, gradeType, grade)
return s.SendHTMLMessage(ctx, roomID, plainText, htmlBody)
}
// SendClassAnnouncement sends an announcement to a class info room
func (s *MatrixService) SendClassAnnouncement(ctx context.Context, roomID string, title string, content string, teacherName string) error {
plainText := fmt.Sprintf("📢 %s\n\n%s\n\n— %s", title, content, teacherName)
htmlBody := fmt.Sprintf(`<h3>📢 %s</h3>
<p>%s</p>
<p><em>— %s</em></p>`, title, content, teacherName)
return s.SendHTMLMessage(ctx, roomID, plainText, htmlBody)
}
// ========================================
// Internal Helpers
// ========================================
func (s *MatrixService) sendEvent(ctx context.Context, roomID string, eventType string, content interface{}) error {
body, err := json.Marshal(content)
if err != nil {
return fmt.Errorf("failed to marshal content: %w", err)
}
txnID := uuid.New().String()
endpoint := fmt.Sprintf("/_matrix/client/v3/rooms/%s/send/%s/%s",
url.PathEscape(roomID), url.PathEscape(eventType), txnID)
resp, err := s.doRequest(ctx, "PUT", endpoint, body)
if err != nil {
return fmt.Errorf("failed to send event: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return s.parseError(resp)
}
return nil
}
func (s *MatrixService) doRequest(ctx context.Context, method string, endpoint string, body []byte) (*http.Response, error) {
fullURL := s.homeserverURL + endpoint
var bodyReader io.Reader
if body != nil {
bodyReader = bytes.NewReader(body)
}
req, err := http.NewRequestWithContext(ctx, method, fullURL, bodyReader)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+s.accessToken)
req.Header.Set("Content-Type", "application/json")
return s.httpClient.Do(req)
}
func (s *MatrixService) parseError(resp *http.Response) error {
body, _ := io.ReadAll(resp.Body)
var errResp struct {
ErrCode string `json:"errcode"`
Error string `json:"error"`
}
if err := json.Unmarshal(body, &errResp); err != nil {
return fmt.Errorf("request failed with status %d: %s", resp.StatusCode, string(body))
}
return fmt.Errorf("matrix error %s: %s", errResp.ErrCode, errResp.Error)
}
// ========================================
// Health Check
// ========================================
// HealthCheck checks if the Matrix server is reachable
func (s *MatrixService) HealthCheck(ctx context.Context) error {
resp, err := s.doRequest(ctx, "GET", "/_matrix/client/versions", nil)
if err != nil {
return fmt.Errorf("matrix server unreachable: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("matrix server returned status %d", resp.StatusCode)
}
return nil
}
// GetServerName returns the configured server name
func (s *MatrixService) GetServerName() string {
return s.serverName
}
// GenerateUserID generates a Matrix user ID from a username
func (s *MatrixService) GenerateUserID(username string) string {
return fmt.Sprintf("@%s:%s", username, s.serverName)
}

View File

@@ -0,0 +1,791 @@
package matrix
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
// ========================================
// Test Helpers
// ========================================
// createTestServer creates a mock Matrix server for testing
func createTestServer(t *testing.T, handler http.HandlerFunc) (*httptest.Server, *MatrixService) {
server := httptest.NewServer(handler)
service := NewMatrixService(Config{
HomeserverURL: server.URL,
AccessToken: "test-access-token",
ServerName: "test.local",
})
return server, service
}
// ========================================
// Unit Tests: Service Creation
// ========================================
func TestNewMatrixService_ValidConfig_CreatesService(t *testing.T) {
cfg := Config{
HomeserverURL: "http://localhost:8008",
AccessToken: "test-token",
ServerName: "breakpilot.local",
}
service := NewMatrixService(cfg)
if service == nil {
t.Fatal("Expected service to be created, got nil")
}
if service.homeserverURL != cfg.HomeserverURL {
t.Errorf("Expected homeserverURL %s, got %s", cfg.HomeserverURL, service.homeserverURL)
}
if service.accessToken != cfg.AccessToken {
t.Errorf("Expected accessToken %s, got %s", cfg.AccessToken, service.accessToken)
}
if service.serverName != cfg.ServerName {
t.Errorf("Expected serverName %s, got %s", cfg.ServerName, service.serverName)
}
if service.httpClient == nil {
t.Error("Expected httpClient to be initialized")
}
if service.httpClient.Timeout != 30*time.Second {
t.Errorf("Expected timeout 30s, got %v", service.httpClient.Timeout)
}
}
func TestGetServerName_ReturnsConfiguredName(t *testing.T) {
service := NewMatrixService(Config{
HomeserverURL: "http://localhost:8008",
AccessToken: "test-token",
ServerName: "school.example.com",
})
result := service.GetServerName()
if result != "school.example.com" {
t.Errorf("Expected 'school.example.com', got '%s'", result)
}
}
func TestGenerateUserID_ValidUsername_ReturnsFormattedID(t *testing.T) {
tests := []struct {
name string
serverName string
username string
expected string
}{
{
name: "simple username",
serverName: "breakpilot.local",
username: "max.mustermann",
expected: "@max.mustermann:breakpilot.local",
},
{
name: "teacher username",
serverName: "school.de",
username: "lehrer_mueller",
expected: "@lehrer_mueller:school.de",
},
{
name: "parent username with numbers",
serverName: "test.local",
username: "eltern123",
expected: "@eltern123:test.local",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
service := NewMatrixService(Config{
HomeserverURL: "http://localhost:8008",
AccessToken: "test-token",
ServerName: tt.serverName,
})
result := service.GenerateUserID(tt.username)
if result != tt.expected {
t.Errorf("Expected '%s', got '%s'", tt.expected, result)
}
})
}
}
// ========================================
// Unit Tests: Health Check
// ========================================
func TestHealthCheck_ServerHealthy_ReturnsNil(t *testing.T) {
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/_matrix/client/versions" {
t.Errorf("Expected path /_matrix/client/versions, got %s", r.URL.Path)
}
if r.Method != "GET" {
t.Errorf("Expected GET method, got %s", r.Method)
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]interface{}{
"versions": []string{"v1.1", "v1.2"},
})
})
defer server.Close()
err := service.HealthCheck(context.Background())
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
}
func TestHealthCheck_ServerUnreachable_ReturnsError(t *testing.T) {
service := NewMatrixService(Config{
HomeserverURL: "http://localhost:59999", // Non-existent server
AccessToken: "test-token",
ServerName: "test.local",
})
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
err := service.HealthCheck(ctx)
if err == nil {
t.Error("Expected error for unreachable server, got nil")
}
if !strings.Contains(err.Error(), "unreachable") {
t.Errorf("Expected 'unreachable' in error message, got: %v", err)
}
}
func TestHealthCheck_ServerReturns500_ReturnsError(t *testing.T) {
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
})
defer server.Close()
err := service.HealthCheck(context.Background())
if err == nil {
t.Error("Expected error for 500 response, got nil")
}
if !strings.Contains(err.Error(), "500") {
t.Errorf("Expected '500' in error message, got: %v", err)
}
}
// ========================================
// Unit Tests: Room Creation
// ========================================
func TestCreateRoom_ValidRequest_ReturnsRoomID(t *testing.T) {
expectedRoomID := "!abc123:test.local"
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/_matrix/client/v3/createRoom" {
t.Errorf("Expected path /_matrix/client/v3/createRoom, got %s", r.URL.Path)
}
if r.Method != "POST" {
t.Errorf("Expected POST method, got %s", r.Method)
}
// Verify authorization header
auth := r.Header.Get("Authorization")
if auth != "Bearer test-access-token" {
t.Errorf("Expected 'Bearer test-access-token', got '%s'", auth)
}
// Verify content type
contentType := r.Header.Get("Content-Type")
if contentType != "application/json" {
t.Errorf("Expected 'application/json', got '%s'", contentType)
}
// Decode and verify request body
var req CreateRoomRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
t.Fatalf("Failed to decode request body: %v", err)
}
if req.Name != "Test Room" {
t.Errorf("Expected name 'Test Room', got '%s'", req.Name)
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(CreateRoomResponse{
RoomID: expectedRoomID,
})
})
defer server.Close()
req := CreateRoomRequest{
Name: "Test Room",
Visibility: "private",
}
result, err := service.CreateRoom(context.Background(), req)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if result.RoomID != expectedRoomID {
t.Errorf("Expected room ID '%s', got '%s'", expectedRoomID, result.RoomID)
}
}
func TestCreateRoom_ServerError_ReturnsError(t *testing.T) {
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusForbidden)
json.NewEncoder(w).Encode(map[string]string{
"errcode": "M_FORBIDDEN",
"error": "Not allowed to create rooms",
})
})
defer server.Close()
req := CreateRoomRequest{Name: "Test"}
_, err := service.CreateRoom(context.Background(), req)
if err == nil {
t.Error("Expected error, got nil")
}
if !strings.Contains(err.Error(), "M_FORBIDDEN") {
t.Errorf("Expected 'M_FORBIDDEN' in error, got: %v", err)
}
}
func TestCreateClassInfoRoom_ValidInput_CreatesRoomWithCorrectPowerLevels(t *testing.T) {
var receivedRequest CreateRoomRequest
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
json.NewDecoder(r.Body).Decode(&receivedRequest)
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(CreateRoomResponse{RoomID: "!class:test.local"})
})
defer server.Close()
teacherIDs := []string{"@lehrer1:test.local", "@lehrer2:test.local"}
result, err := service.CreateClassInfoRoom(context.Background(), "5a", "Grundschule Musterstadt", teacherIDs)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if result.RoomID != "!class:test.local" {
t.Errorf("Expected room ID '!class:test.local', got '%s'", result.RoomID)
}
// Verify room name format
expectedName := "5a - Grundschule Musterstadt (Info)"
if receivedRequest.Name != expectedName {
t.Errorf("Expected name '%s', got '%s'", expectedName, receivedRequest.Name)
}
// Verify power levels
if receivedRequest.PowerLevelContentOverride == nil {
t.Fatal("Expected power level override, got nil")
}
if receivedRequest.PowerLevelContentOverride.EventsDefault != 50 {
t.Errorf("Expected EventsDefault 50, got %d", receivedRequest.PowerLevelContentOverride.EventsDefault)
}
if receivedRequest.PowerLevelContentOverride.UsersDefault != 0 {
t.Errorf("Expected UsersDefault 0, got %d", receivedRequest.PowerLevelContentOverride.UsersDefault)
}
// Verify teachers have power level 50
for _, teacherID := range teacherIDs {
if level, ok := receivedRequest.PowerLevelContentOverride.Users[teacherID]; !ok || level != 50 {
t.Errorf("Expected teacher %s to have power level 50, got %d", teacherID, level)
}
}
}
func TestCreateStudentDMRoom_ValidInput_CreatesEncryptedRoom(t *testing.T) {
var receivedRequest CreateRoomRequest
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
json.NewDecoder(r.Body).Decode(&receivedRequest)
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(CreateRoomResponse{RoomID: "!dm:test.local"})
})
defer server.Close()
teacherIDs := []string{"@lehrer:test.local"}
parentIDs := []string{"@eltern1:test.local", "@eltern2:test.local"}
result, err := service.CreateStudentDMRoom(context.Background(), "Max Mustermann", "5a", teacherIDs, parentIDs)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if result.RoomID != "!dm:test.local" {
t.Errorf("Expected room ID '!dm:test.local', got '%s'", result.RoomID)
}
// Verify room name
expectedName := "Max Mustermann (5a) - Dialog"
if receivedRequest.Name != expectedName {
t.Errorf("Expected name '%s', got '%s'", expectedName, receivedRequest.Name)
}
// Verify encryption is enabled
foundEncryption := false
for _, state := range receivedRequest.InitialState {
if state.Type == "m.room.encryption" {
foundEncryption = true
// Content comes as map[string]interface{} from JSON unmarshaling
content, ok := state.Content.(map[string]interface{})
if !ok {
t.Errorf("Expected encryption content to be map[string]interface{}, got %T", state.Content)
continue
}
if algo, ok := content["algorithm"].(string); !ok || algo != "m.megolm.v1.aes-sha2" {
t.Errorf("Expected algorithm 'm.megolm.v1.aes-sha2', got '%v'", content["algorithm"])
}
}
}
if !foundEncryption {
t.Error("Expected encryption state event, not found")
}
// Verify all users are invited
expectedInvites := append(teacherIDs, parentIDs...)
for _, expected := range expectedInvites {
found := false
for _, invited := range receivedRequest.Invite {
if invited == expected {
found = true
break
}
}
if !found {
t.Errorf("Expected user %s to be invited", expected)
}
}
}
func TestCreateParentRepRoom_ValidInput_CreatesRoom(t *testing.T) {
var receivedRequest CreateRoomRequest
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
json.NewDecoder(r.Body).Decode(&receivedRequest)
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(CreateRoomResponse{RoomID: "!rep:test.local"})
})
defer server.Close()
teacherIDs := []string{"@lehrer:test.local"}
repIDs := []string{"@elternvertreter1:test.local", "@elternvertreter2:test.local"}
result, err := service.CreateParentRepRoom(context.Background(), "5a", teacherIDs, repIDs)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if result.RoomID != "!rep:test.local" {
t.Errorf("Expected room ID '!rep:test.local', got '%s'", result.RoomID)
}
// Verify room name
expectedName := "5a - Elternvertreter"
if receivedRequest.Name != expectedName {
t.Errorf("Expected name '%s', got '%s'", expectedName, receivedRequest.Name)
}
// Verify all participants can write (power level 50)
allUsers := append(teacherIDs, repIDs...)
for _, userID := range allUsers {
if level, ok := receivedRequest.PowerLevelContentOverride.Users[userID]; !ok || level != 50 {
t.Errorf("Expected user %s to have power level 50, got %d", userID, level)
}
}
}
// ========================================
// Unit Tests: User Management
// ========================================
func TestSetDisplayName_ValidRequest_Succeeds(t *testing.T) {
var receivedPath string
var receivedBody map[string]string
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
receivedPath = r.URL.Path
json.NewDecoder(r.Body).Decode(&receivedBody)
w.WriteHeader(http.StatusOK)
})
defer server.Close()
err := service.SetDisplayName(context.Background(), "@user:test.local", "Max Mustermann")
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
// Path may or may not be URL-encoded depending on Go version
if !strings.Contains(receivedPath, "/profile/") || !strings.Contains(receivedPath, "/displayname") {
t.Errorf("Expected path to contain '/profile/' and '/displayname', got '%s'", receivedPath)
}
if receivedBody["displayname"] != "Max Mustermann" {
t.Errorf("Expected displayname 'Max Mustermann', got '%s'", receivedBody["displayname"])
}
}
// ========================================
// Unit Tests: Room Membership
// ========================================
func TestInviteUser_ValidRequest_Succeeds(t *testing.T) {
var receivedBody InviteRequest
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
if !strings.Contains(r.URL.Path, "/invite") {
t.Errorf("Expected path to contain '/invite', got '%s'", r.URL.Path)
}
if r.Method != "POST" {
t.Errorf("Expected POST method, got %s", r.Method)
}
json.NewDecoder(r.Body).Decode(&receivedBody)
w.WriteHeader(http.StatusOK)
})
defer server.Close()
err := service.InviteUser(context.Background(), "!room:test.local", "@user:test.local")
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if receivedBody.UserID != "@user:test.local" {
t.Errorf("Expected user_id '@user:test.local', got '%s'", receivedBody.UserID)
}
}
func TestInviteUser_UserAlreadyInRoom_ReturnsError(t *testing.T) {
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusForbidden)
json.NewEncoder(w).Encode(map[string]string{
"errcode": "M_FORBIDDEN",
"error": "User is already in the room",
})
})
defer server.Close()
err := service.InviteUser(context.Background(), "!room:test.local", "@user:test.local")
if err == nil {
t.Error("Expected error, got nil")
}
}
func TestJoinRoom_ValidRequest_Succeeds(t *testing.T) {
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
if !strings.Contains(r.URL.Path, "/join/") {
t.Errorf("Expected path to contain '/join/', got '%s'", r.URL.Path)
}
if r.Method != "POST" {
t.Errorf("Expected POST method, got %s", r.Method)
}
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"room_id": "!room:test.local"})
})
defer server.Close()
err := service.JoinRoom(context.Background(), "!room:test.local")
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
}
// ========================================
// Unit Tests: Messaging
// ========================================
func TestSendMessage_ValidRequest_Succeeds(t *testing.T) {
var receivedBody SendMessageRequest
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
if !strings.Contains(r.URL.Path, "/send/m.room.message/") {
t.Errorf("Expected path to contain '/send/m.room.message/', got '%s'", r.URL.Path)
}
if r.Method != "PUT" {
t.Errorf("Expected PUT method, got %s", r.Method)
}
json.NewDecoder(r.Body).Decode(&receivedBody)
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"event_id": "$event123"})
})
defer server.Close()
err := service.SendMessage(context.Background(), "!room:test.local", "Hello, World!")
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if receivedBody.MsgType != "m.text" {
t.Errorf("Expected msgtype 'm.text', got '%s'", receivedBody.MsgType)
}
if receivedBody.Body != "Hello, World!" {
t.Errorf("Expected body 'Hello, World!', got '%s'", receivedBody.Body)
}
}
func TestSendHTMLMessage_ValidRequest_IncludesFormattedBody(t *testing.T) {
var receivedBody SendMessageRequest
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
json.NewDecoder(r.Body).Decode(&receivedBody)
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"event_id": "$event123"})
})
defer server.Close()
err := service.SendHTMLMessage(context.Background(), "!room:test.local", "Plain text", "<b>Bold text</b>")
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if receivedBody.Format != "org.matrix.custom.html" {
t.Errorf("Expected format 'org.matrix.custom.html', got '%s'", receivedBody.Format)
}
if receivedBody.Body != "Plain text" {
t.Errorf("Expected body 'Plain text', got '%s'", receivedBody.Body)
}
if receivedBody.FormattedBody != "<b>Bold text</b>" {
t.Errorf("Expected formatted_body '<b>Bold text</b>', got '%s'", receivedBody.FormattedBody)
}
}
func TestSendAbsenceNotification_ValidRequest_FormatsCorrectly(t *testing.T) {
var receivedBody SendMessageRequest
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
json.NewDecoder(r.Body).Decode(&receivedBody)
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"event_id": "$event123"})
})
defer server.Close()
err := service.SendAbsenceNotification(context.Background(), "!room:test.local", "Max Mustermann", "15.12.2025", 3)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
// Verify plain text contains key information
if !strings.Contains(receivedBody.Body, "Max Mustermann") {
t.Error("Expected body to contain student name")
}
if !strings.Contains(receivedBody.Body, "15.12.2025") {
t.Error("Expected body to contain date")
}
if !strings.Contains(receivedBody.Body, "3. Stunde") {
t.Error("Expected body to contain lesson number")
}
if !strings.Contains(receivedBody.Body, "Abwesenheitsmeldung") {
t.Error("Expected body to contain 'Abwesenheitsmeldung'")
}
// Verify HTML is set
if receivedBody.FormattedBody == "" {
t.Error("Expected formatted body to be set")
}
}
func TestSendGradeNotification_ValidRequest_FormatsCorrectly(t *testing.T) {
var receivedBody SendMessageRequest
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
json.NewDecoder(r.Body).Decode(&receivedBody)
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"event_id": "$event123"})
})
defer server.Close()
err := service.SendGradeNotification(context.Background(), "!room:test.local", "Max Mustermann", "Mathematik", "Klassenarbeit", 2.3)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if !strings.Contains(receivedBody.Body, "Max Mustermann") {
t.Error("Expected body to contain student name")
}
if !strings.Contains(receivedBody.Body, "Mathematik") {
t.Error("Expected body to contain subject")
}
if !strings.Contains(receivedBody.Body, "Klassenarbeit") {
t.Error("Expected body to contain grade type")
}
if !strings.Contains(receivedBody.Body, "2.3") {
t.Error("Expected body to contain grade")
}
}
func TestSendClassAnnouncement_ValidRequest_FormatsCorrectly(t *testing.T) {
var receivedBody SendMessageRequest
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
json.NewDecoder(r.Body).Decode(&receivedBody)
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"event_id": "$event123"})
})
defer server.Close()
err := service.SendClassAnnouncement(context.Background(), "!room:test.local", "Elternabend", "Am 20.12. findet der Elternabend statt.", "Frau Müller")
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if !strings.Contains(receivedBody.Body, "Elternabend") {
t.Error("Expected body to contain title")
}
if !strings.Contains(receivedBody.Body, "20.12.") {
t.Error("Expected body to contain content")
}
if !strings.Contains(receivedBody.Body, "Frau Müller") {
t.Error("Expected body to contain teacher name")
}
}
// ========================================
// Unit Tests: Power Levels
// ========================================
func TestSetUserPowerLevel_ValidRequest_UpdatesPowerLevel(t *testing.T) {
callCount := 0
var putBody PowerLevels
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
callCount++
if r.Method == "GET" {
// Return current power levels
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(PowerLevels{
Users: map[string]int{
"@admin:test.local": 100,
},
UsersDefault: 0,
})
} else if r.Method == "PUT" {
// Update power levels
json.NewDecoder(r.Body).Decode(&putBody)
w.WriteHeader(http.StatusOK)
}
})
defer server.Close()
err := service.SetUserPowerLevel(context.Background(), "!room:test.local", "@newuser:test.local", 50)
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if callCount != 2 {
t.Errorf("Expected 2 API calls (GET then PUT), got %d", callCount)
}
if putBody.Users["@newuser:test.local"] != 50 {
t.Errorf("Expected user power level 50, got %d", putBody.Users["@newuser:test.local"])
}
// Verify existing users are preserved
if putBody.Users["@admin:test.local"] != 100 {
t.Errorf("Expected admin power level 100 to be preserved, got %d", putBody.Users["@admin:test.local"])
}
}
// ========================================
// Unit Tests: Error Handling
// ========================================
func TestParseError_MatrixError_ExtractsFields(t *testing.T) {
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]string{
"errcode": "M_UNKNOWN",
"error": "Something went wrong",
})
})
defer server.Close()
_, err := service.CreateRoom(context.Background(), CreateRoomRequest{Name: "Test"})
if err == nil {
t.Fatal("Expected error, got nil")
}
if !strings.Contains(err.Error(), "M_UNKNOWN") {
t.Errorf("Expected error to contain 'M_UNKNOWN', got: %v", err)
}
if !strings.Contains(err.Error(), "Something went wrong") {
t.Errorf("Expected error to contain 'Something went wrong', got: %v", err)
}
}
func TestParseError_NonJSONError_ReturnsRawBody(t *testing.T) {
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("Internal Server Error"))
})
defer server.Close()
_, err := service.CreateRoom(context.Background(), CreateRoomRequest{Name: "Test"})
if err == nil {
t.Fatal("Expected error, got nil")
}
if !strings.Contains(err.Error(), "500") {
t.Errorf("Expected error to contain '500', got: %v", err)
}
}
// ========================================
// Unit Tests: Context Handling
// ========================================
func TestCreateRoom_ContextCanceled_ReturnsError(t *testing.T) {
server, service := createTestServer(t, func(w http.ResponseWriter, r *http.Request) {
time.Sleep(100 * time.Millisecond)
w.WriteHeader(http.StatusOK)
})
defer server.Close()
ctx, cancel := context.WithCancel(context.Background())
cancel() // Cancel immediately
_, err := service.CreateRoom(ctx, CreateRoomRequest{Name: "Test"})
if err == nil {
t.Error("Expected error for canceled context, got nil")
}
}
// ========================================
// Integration Tests (require running Synapse)
// ========================================
// These tests are skipped by default as they require a running Matrix server
// Run with: go test -tags=integration ./...
func TestIntegration_HealthCheck(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
service := NewMatrixService(Config{
HomeserverURL: "http://localhost:8008",
AccessToken: "", // Not needed for health check
ServerName: "breakpilot.local",
})
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := service.HealthCheck(ctx)
if err != nil {
t.Skipf("Matrix server not available: %v", err)
}
}

View File

@@ -0,0 +1,347 @@
package services
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgxpool"
)
// NotificationType defines the type of notification
type NotificationType string
const (
NotificationTypeConsentRequired NotificationType = "consent_required"
NotificationTypeConsentReminder NotificationType = "consent_reminder"
NotificationTypeVersionPublished NotificationType = "version_published"
NotificationTypeVersionApproved NotificationType = "version_approved"
NotificationTypeVersionRejected NotificationType = "version_rejected"
NotificationTypeAccountSuspended NotificationType = "account_suspended"
NotificationTypeAccountRestored NotificationType = "account_restored"
NotificationTypeGeneral NotificationType = "general"
// DSR (Data Subject Request) notification types
NotificationTypeDSRReceived NotificationType = "dsr_received"
NotificationTypeDSRAssigned NotificationType = "dsr_assigned"
NotificationTypeDSRDeadline NotificationType = "dsr_deadline"
)
// NotificationChannel defines how notification is delivered
type NotificationChannel string
const (
ChannelInApp NotificationChannel = "in_app"
ChannelEmail NotificationChannel = "email"
ChannelPush NotificationChannel = "push"
)
// Notification represents a notification entity
type Notification struct {
ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"`
Type NotificationType `json:"type"`
Channel NotificationChannel `json:"channel"`
Title string `json:"title"`
Body string `json:"body"`
Data map[string]interface{} `json:"data,omitempty"`
ReadAt *time.Time `json:"read_at,omitempty"`
SentAt *time.Time `json:"sent_at,omitempty"`
CreatedAt time.Time `json:"created_at"`
}
// NotificationPreferences holds user notification settings
type NotificationPreferences struct {
UserID uuid.UUID `json:"user_id"`
EmailEnabled bool `json:"email_enabled"`
PushEnabled bool `json:"push_enabled"`
InAppEnabled bool `json:"in_app_enabled"`
ReminderFrequency string `json:"reminder_frequency"`
}
// NotificationService handles notification operations
type NotificationService struct {
pool *pgxpool.Pool
emailService *EmailService
}
// NewNotificationService creates a new notification service
func NewNotificationService(pool *pgxpool.Pool, emailService *EmailService) *NotificationService {
return &NotificationService{
pool: pool,
emailService: emailService,
}
}
// CreateNotification creates and optionally sends a notification
func (s *NotificationService) CreateNotification(ctx context.Context, userID uuid.UUID, notifType NotificationType, title, body string, data map[string]interface{}) error {
// Get user preferences
prefs, err := s.GetPreferences(ctx, userID)
if err != nil {
// Use default preferences if not found
prefs = &NotificationPreferences{
UserID: userID,
EmailEnabled: true,
PushEnabled: true,
InAppEnabled: true,
}
}
// Create in-app notification if enabled
if prefs.InAppEnabled {
if err := s.createInAppNotification(ctx, userID, notifType, title, body, data); err != nil {
return fmt.Errorf("failed to create in-app notification: %w", err)
}
}
// Send email notification if enabled
if prefs.EmailEnabled && s.emailService != nil {
go s.sendEmailNotification(ctx, userID, notifType, title, body, data)
}
// Push notification would be sent here if enabled
// if prefs.PushEnabled {
// go s.sendPushNotification(ctx, userID, title, body, data)
// }
return nil
}
// createInAppNotification creates an in-app notification
func (s *NotificationService) createInAppNotification(ctx context.Context, userID uuid.UUID, notifType NotificationType, title, body string, data map[string]interface{}) error {
dataJSON, _ := json.Marshal(data)
_, err := s.pool.Exec(ctx, `
INSERT INTO notifications (user_id, type, channel, title, body, data, created_at)
VALUES ($1, $2, $3, $4, $5, $6, NOW())
`, userID, notifType, ChannelInApp, title, body, dataJSON)
return err
}
// sendEmailNotification sends an email notification
func (s *NotificationService) sendEmailNotification(ctx context.Context, userID uuid.UUID, notifType NotificationType, title, body string, data map[string]interface{}) {
// Get user email
var email string
err := s.pool.QueryRow(ctx, `SELECT email FROM users WHERE id = $1`, userID).Scan(&email)
if err != nil {
return
}
// Send based on notification type
switch notifType {
case NotificationTypeConsentRequired, NotificationTypeConsentReminder:
s.emailService.SendConsentReminderEmail(email, title, body)
default:
s.emailService.SendGenericNotificationEmail(email, title, body)
}
// Mark as sent
s.pool.Exec(ctx, `
UPDATE notifications SET sent_at = NOW()
WHERE user_id = $1 AND type = $2 AND channel = $3 AND sent_at IS NULL
ORDER BY created_at DESC LIMIT 1
`, userID, notifType, ChannelEmail)
}
// GetUserNotifications returns notifications for a user
func (s *NotificationService) GetUserNotifications(ctx context.Context, userID uuid.UUID, limit, offset int, unreadOnly bool) ([]Notification, int, error) {
// Count total
var totalQuery string
var total int
if unreadOnly {
totalQuery = `SELECT COUNT(*) FROM notifications WHERE user_id = $1 AND read_at IS NULL`
} else {
totalQuery = `SELECT COUNT(*) FROM notifications WHERE user_id = $1`
}
s.pool.QueryRow(ctx, totalQuery, userID).Scan(&total)
// Get notifications
var query string
if unreadOnly {
query = `
SELECT id, user_id, type, channel, title, body, data, read_at, sent_at, created_at
FROM notifications
WHERE user_id = $1 AND read_at IS NULL
ORDER BY created_at DESC
LIMIT $2 OFFSET $3
`
} else {
query = `
SELECT id, user_id, type, channel, title, body, data, read_at, sent_at, created_at
FROM notifications
WHERE user_id = $1
ORDER BY created_at DESC
LIMIT $2 OFFSET $3
`
}
rows, err := s.pool.Query(ctx, query, userID, limit, offset)
if err != nil {
return nil, 0, err
}
defer rows.Close()
var notifications []Notification
for rows.Next() {
var n Notification
var dataJSON []byte
if err := rows.Scan(&n.ID, &n.UserID, &n.Type, &n.Channel, &n.Title, &n.Body, &dataJSON, &n.ReadAt, &n.SentAt, &n.CreatedAt); err != nil {
continue
}
if dataJSON != nil {
json.Unmarshal(dataJSON, &n.Data)
}
notifications = append(notifications, n)
}
return notifications, total, nil
}
// GetUnreadCount returns the count of unread notifications
func (s *NotificationService) GetUnreadCount(ctx context.Context, userID uuid.UUID) (int, error) {
var count int
err := s.pool.QueryRow(ctx, `
SELECT COUNT(*) FROM notifications WHERE user_id = $1 AND read_at IS NULL
`, userID).Scan(&count)
return count, err
}
// MarkAsRead marks a notification as read
func (s *NotificationService) MarkAsRead(ctx context.Context, userID uuid.UUID, notificationID uuid.UUID) error {
result, err := s.pool.Exec(ctx, `
UPDATE notifications SET read_at = NOW()
WHERE id = $1 AND user_id = $2 AND read_at IS NULL
`, notificationID, userID)
if err != nil {
return err
}
if result.RowsAffected() == 0 {
return fmt.Errorf("notification not found or already read")
}
return nil
}
// MarkAllAsRead marks all notifications as read for a user
func (s *NotificationService) MarkAllAsRead(ctx context.Context, userID uuid.UUID) error {
_, err := s.pool.Exec(ctx, `
UPDATE notifications SET read_at = NOW()
WHERE user_id = $1 AND read_at IS NULL
`, userID)
return err
}
// DeleteNotification deletes a notification
func (s *NotificationService) DeleteNotification(ctx context.Context, userID uuid.UUID, notificationID uuid.UUID) error {
result, err := s.pool.Exec(ctx, `
DELETE FROM notifications WHERE id = $1 AND user_id = $2
`, notificationID, userID)
if err != nil {
return err
}
if result.RowsAffected() == 0 {
return fmt.Errorf("notification not found")
}
return nil
}
// GetPreferences returns notification preferences for a user
func (s *NotificationService) GetPreferences(ctx context.Context, userID uuid.UUID) (*NotificationPreferences, error) {
var prefs NotificationPreferences
prefs.UserID = userID
err := s.pool.QueryRow(ctx, `
SELECT email_enabled, push_enabled, in_app_enabled, reminder_frequency
FROM notification_preferences
WHERE user_id = $1
`, userID).Scan(&prefs.EmailEnabled, &prefs.PushEnabled, &prefs.InAppEnabled, &prefs.ReminderFrequency)
if err != nil {
// Return defaults if not found
return &NotificationPreferences{
UserID: userID,
EmailEnabled: true,
PushEnabled: true,
InAppEnabled: true,
ReminderFrequency: "weekly",
}, nil
}
return &prefs, nil
}
// UpdatePreferences updates notification preferences for a user
func (s *NotificationService) UpdatePreferences(ctx context.Context, userID uuid.UUID, prefs *NotificationPreferences) error {
_, err := s.pool.Exec(ctx, `
INSERT INTO notification_preferences (user_id, email_enabled, push_enabled, in_app_enabled, reminder_frequency, updated_at)
VALUES ($1, $2, $3, $4, $5, NOW())
ON CONFLICT (user_id) DO UPDATE SET
email_enabled = $2,
push_enabled = $3,
in_app_enabled = $4,
reminder_frequency = $5,
updated_at = NOW()
`, userID, prefs.EmailEnabled, prefs.PushEnabled, prefs.InAppEnabled, prefs.ReminderFrequency)
return err
}
// NotifyConsentRequired sends consent required notifications to all active users
func (s *NotificationService) NotifyConsentRequired(ctx context.Context, documentName, versionID string) error {
// Get all active users
rows, err := s.pool.Query(ctx, `
SELECT id FROM users WHERE account_status = 'active'
`)
if err != nil {
return err
}
defer rows.Close()
title := "Neue Zustimmung erforderlich"
body := fmt.Sprintf("Eine neue Version von '%s' wurde veröffentlicht. Bitte überprüfen und bestätigen Sie diese.", documentName)
data := map[string]interface{}{
"version_id": versionID,
"document_name": documentName,
}
for rows.Next() {
var userID uuid.UUID
if err := rows.Scan(&userID); err != nil {
continue
}
go s.CreateNotification(ctx, userID, NotificationTypeConsentRequired, title, body, data)
}
return nil
}
// NotifyVersionApproved notifies the creator that their version was approved
func (s *NotificationService) NotifyVersionApproved(ctx context.Context, creatorID uuid.UUID, documentName, versionNumber, approverEmail string) error {
title := "Version genehmigt"
body := fmt.Sprintf("Ihre Version %s von '%s' wurde von %s genehmigt und kann nun veröffentlicht werden.", versionNumber, documentName, approverEmail)
data := map[string]interface{}{
"document_name": documentName,
"version_number": versionNumber,
"approver": approverEmail,
}
return s.CreateNotification(ctx, creatorID, NotificationTypeVersionApproved, title, body, data)
}
// NotifyVersionRejected notifies the creator that their version was rejected
func (s *NotificationService) NotifyVersionRejected(ctx context.Context, creatorID uuid.UUID, documentName, versionNumber, reason, rejecterEmail string) error {
title := "Version abgelehnt"
body := fmt.Sprintf("Ihre Version %s von '%s' wurde von %s abgelehnt. Grund: %s", versionNumber, documentName, rejecterEmail, reason)
data := map[string]interface{}{
"document_name": documentName,
"version_number": versionNumber,
"rejecter": rejecterEmail,
"reason": reason,
}
return s.CreateNotification(ctx, creatorID, NotificationTypeVersionRejected, title, body, data)
}

View File

@@ -0,0 +1,660 @@
package services
import (
"testing"
"time"
"github.com/google/uuid"
)
// TestNotificationService_CreateNotification tests notification creation
func TestNotificationService_CreateNotification(t *testing.T) {
tests := []struct {
name string
userID uuid.UUID
notifType NotificationType
title string
body string
data map[string]interface{}
expectError bool
}{
{
name: "valid notification",
userID: uuid.New(),
notifType: NotificationTypeConsentRequired,
title: "Consent Required",
body: "Please review and accept the new terms",
data: map[string]interface{}{"document_id": "123"},
expectError: false,
},
{
name: "notification without data",
userID: uuid.New(),
notifType: NotificationTypeGeneral,
title: "General Notification",
body: "This is a test",
data: nil,
expectError: false,
},
{
name: "empty user ID",
userID: uuid.Nil,
notifType: NotificationTypeGeneral,
title: "Test",
body: "Test",
expectError: true,
},
{
name: "empty title",
userID: uuid.New(),
notifType: NotificationTypeGeneral,
title: "",
body: "Test body",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var err error
if tt.userID == uuid.Nil {
err = &ValidationError{Field: "user ID", Message: "required"}
} else if tt.title == "" {
err = &ValidationError{Field: "title", Message: "required"}
}
if tt.expectError && err == nil {
t.Error("Expected error, got nil")
}
if !tt.expectError && err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
}
}
// TestNotificationService_NotificationTypes tests notification type validation
func TestNotificationService_NotificationTypes(t *testing.T) {
tests := []struct {
notifType NotificationType
isValid bool
}{
{NotificationTypeConsentRequired, true},
{NotificationTypeConsentReminder, true},
{NotificationTypeVersionPublished, true},
{NotificationTypeVersionApproved, true},
{NotificationTypeVersionRejected, true},
{NotificationTypeAccountSuspended, true},
{NotificationTypeAccountRestored, true},
{NotificationTypeGeneral, true},
{NotificationType("invalid_type"), false},
{NotificationType(""), false},
}
validTypes := map[NotificationType]bool{
NotificationTypeConsentRequired: true,
NotificationTypeConsentReminder: true,
NotificationTypeVersionPublished: true,
NotificationTypeVersionApproved: true,
NotificationTypeVersionRejected: true,
NotificationTypeAccountSuspended: true,
NotificationTypeAccountRestored: true,
NotificationTypeGeneral: true,
}
for _, tt := range tests {
t.Run(string(tt.notifType), func(t *testing.T) {
isValid := validTypes[tt.notifType]
if isValid != tt.isValid {
t.Errorf("Type %s: expected valid=%v, got %v", tt.notifType, tt.isValid, isValid)
}
})
}
}
// TestNotificationService_NotificationChannels tests channel validation
func TestNotificationService_NotificationChannels(t *testing.T) {
tests := []struct {
channel NotificationChannel
isValid bool
}{
{ChannelInApp, true},
{ChannelEmail, true},
{ChannelPush, true},
{NotificationChannel("sms"), false},
{NotificationChannel(""), false},
}
validChannels := map[NotificationChannel]bool{
ChannelInApp: true,
ChannelEmail: true,
ChannelPush: true,
}
for _, tt := range tests {
t.Run(string(tt.channel), func(t *testing.T) {
isValid := validChannels[tt.channel]
if isValid != tt.isValid {
t.Errorf("Channel %s: expected valid=%v, got %v", tt.channel, tt.isValid, isValid)
}
})
}
}
// TestNotificationService_GetUserNotifications tests retrieving notifications
func TestNotificationService_GetUserNotifications(t *testing.T) {
tests := []struct {
name string
userID uuid.UUID
limit int
offset int
unreadOnly bool
expectError bool
}{
{
name: "get all notifications",
userID: uuid.New(),
limit: 50,
offset: 0,
unreadOnly: false,
expectError: false,
},
{
name: "get unread only",
userID: uuid.New(),
limit: 50,
offset: 0,
unreadOnly: true,
expectError: false,
},
{
name: "with pagination",
userID: uuid.New(),
limit: 10,
offset: 20,
unreadOnly: false,
expectError: false,
},
{
name: "invalid user ID",
userID: uuid.Nil,
limit: 50,
offset: 0,
unreadOnly: false,
expectError: true,
},
{
name: "negative limit",
userID: uuid.New(),
limit: -1,
offset: 0,
unreadOnly: false,
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var err error
if tt.userID == uuid.Nil {
err = &ValidationError{Field: "user ID", Message: "required"}
} else if tt.limit < 0 {
err = &ValidationError{Field: "limit", Message: "must be >= 0"}
}
if tt.expectError && err == nil {
t.Error("Expected error, got nil")
}
if !tt.expectError && err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
}
}
// TestNotificationService_MarkAsRead tests marking notifications as read
func TestNotificationService_MarkAsRead(t *testing.T) {
tests := []struct {
name string
notificationID uuid.UUID
userID uuid.UUID
expectError bool
}{
{
name: "mark valid notification as read",
notificationID: uuid.New(),
userID: uuid.New(),
expectError: false,
},
{
name: "invalid notification ID",
notificationID: uuid.Nil,
userID: uuid.New(),
expectError: true,
},
{
name: "invalid user ID",
notificationID: uuid.New(),
userID: uuid.Nil,
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var err error
if tt.notificationID == uuid.Nil {
err = &ValidationError{Field: "notification ID", Message: "required"}
} else if tt.userID == uuid.Nil {
err = &ValidationError{Field: "user ID", Message: "required"}
}
if tt.expectError && err == nil {
t.Error("Expected error, got nil")
}
if !tt.expectError && err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
}
}
// TestNotificationService_GetPreferences tests retrieving user preferences
func TestNotificationService_GetPreferences(t *testing.T) {
tests := []struct {
name string
userID uuid.UUID
expectError bool
}{
{
name: "get valid user preferences",
userID: uuid.New(),
expectError: false,
},
{
name: "invalid user ID",
userID: uuid.Nil,
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var err error
if tt.userID == uuid.Nil {
err = &ValidationError{Field: "user ID", Message: "required"}
}
if tt.expectError && err == nil {
t.Error("Expected error, got nil")
}
if !tt.expectError && err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
}
}
// TestNotificationService_UpdatePreferences tests updating notification preferences
func TestNotificationService_UpdatePreferences(t *testing.T) {
tests := []struct {
name string
userID uuid.UUID
emailEnabled bool
pushEnabled bool
inAppEnabled bool
reminderFrequency string
expectError bool
}{
{
name: "enable all notifications",
userID: uuid.New(),
emailEnabled: true,
pushEnabled: true,
inAppEnabled: true,
reminderFrequency: "daily",
expectError: false,
},
{
name: "disable email notifications",
userID: uuid.New(),
emailEnabled: false,
pushEnabled: true,
inAppEnabled: true,
reminderFrequency: "weekly",
expectError: false,
},
{
name: "set reminder frequency to never",
userID: uuid.New(),
emailEnabled: true,
pushEnabled: false,
inAppEnabled: true,
reminderFrequency: "never",
expectError: false,
},
{
name: "invalid reminder frequency",
userID: uuid.New(),
emailEnabled: true,
pushEnabled: true,
inAppEnabled: true,
reminderFrequency: "hourly",
expectError: true,
},
}
validFrequencies := map[string]bool{
"daily": true,
"weekly": true,
"never": true,
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var err error
if !validFrequencies[tt.reminderFrequency] {
err = &ValidationError{Field: "reminder_frequency", Message: "must be daily, weekly, or never"}
}
if tt.expectError && err == nil {
t.Error("Expected error, got nil")
}
if !tt.expectError && err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
}
}
// TestNotificationService_NotifyConsentRequired tests consent required notification
func TestNotificationService_NotifyConsentRequired(t *testing.T) {
tests := []struct {
name string
documentName string
versionID string
expectError bool
}{
{
name: "valid consent notification",
documentName: "Terms of Service",
versionID: uuid.New().String(),
expectError: false,
},
{
name: "empty document name",
documentName: "",
versionID: uuid.New().String(),
expectError: true,
},
{
name: "empty version ID",
documentName: "Privacy Policy",
versionID: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var err error
if tt.documentName == "" {
err = &ValidationError{Field: "document name", Message: "required"}
} else if tt.versionID == "" {
err = &ValidationError{Field: "version ID", Message: "required"}
}
if tt.expectError && err == nil {
t.Error("Expected error, got nil")
}
if !tt.expectError && err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
}
}
// TestNotificationService_DeleteNotification tests deleting notifications
func TestNotificationService_DeleteNotification(t *testing.T) {
tests := []struct {
name string
notificationID uuid.UUID
userID uuid.UUID
expectError bool
}{
{
name: "delete valid notification",
notificationID: uuid.New(),
userID: uuid.New(),
expectError: false,
},
{
name: "invalid notification ID",
notificationID: uuid.Nil,
userID: uuid.New(),
expectError: true,
},
{
name: "invalid user ID",
notificationID: uuid.New(),
userID: uuid.Nil,
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var err error
if tt.notificationID == uuid.Nil {
err = &ValidationError{Field: "notification ID", Message: "required"}
} else if tt.userID == uuid.Nil {
err = &ValidationError{Field: "user ID", Message: "required"}
}
if tt.expectError && err == nil {
t.Error("Expected error, got nil")
}
if !tt.expectError && err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
}
}
// TestNotificationService_BatchMarkAsRead tests batch marking as read
func TestNotificationService_BatchMarkAsRead(t *testing.T) {
tests := []struct {
name string
notificationIDs []uuid.UUID
userID uuid.UUID
expectError bool
}{
{
name: "mark multiple notifications",
notificationIDs: []uuid.UUID{uuid.New(), uuid.New(), uuid.New()},
userID: uuid.New(),
expectError: false,
},
{
name: "empty list",
notificationIDs: []uuid.UUID{},
userID: uuid.New(),
expectError: false,
},
{
name: "invalid user ID",
notificationIDs: []uuid.UUID{uuid.New()},
userID: uuid.Nil,
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var err error
if tt.userID == uuid.Nil {
err = &ValidationError{Field: "user ID", Message: "required"}
}
if tt.expectError && err == nil {
t.Error("Expected error, got nil")
}
if !tt.expectError && err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
}
}
// TestNotificationService_GetUnreadCount tests getting unread count
func TestNotificationService_GetUnreadCount(t *testing.T) {
tests := []struct {
name string
userID uuid.UUID
expectError bool
}{
{
name: "get count for valid user",
userID: uuid.New(),
expectError: false,
},
{
name: "invalid user ID",
userID: uuid.Nil,
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var err error
if tt.userID == uuid.Nil {
err = &ValidationError{Field: "user ID", Message: "required"}
}
if tt.expectError && err == nil {
t.Error("Expected error, got nil")
}
if !tt.expectError && err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
}
}
// TestNotificationService_NotificationPriority tests notification priority
func TestNotificationService_NotificationPriority(t *testing.T) {
tests := []struct {
name string
notifType NotificationType
expectedPrio string
}{
{
name: "consent required - high priority",
notifType: NotificationTypeConsentRequired,
expectedPrio: "high",
},
{
name: "account suspended - critical",
notifType: NotificationTypeAccountSuspended,
expectedPrio: "critical",
},
{
name: "version published - normal",
notifType: NotificationTypeVersionPublished,
expectedPrio: "normal",
},
{
name: "general - low",
notifType: NotificationTypeGeneral,
expectedPrio: "low",
},
}
priorityMap := map[NotificationType]string{
NotificationTypeConsentRequired: "high",
NotificationTypeConsentReminder: "high",
NotificationTypeAccountSuspended: "critical",
NotificationTypeAccountRestored: "normal",
NotificationTypeVersionPublished: "normal",
NotificationTypeVersionApproved: "normal",
NotificationTypeVersionRejected: "normal",
NotificationTypeGeneral: "low",
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
priority := priorityMap[tt.notifType]
if priority != tt.expectedPrio {
t.Errorf("Expected priority %s, got %s", tt.expectedPrio, priority)
}
})
}
}
// TestNotificationService_ReminderFrequency tests reminder frequency logic
func TestNotificationService_ReminderFrequency(t *testing.T) {
now := time.Now()
tests := []struct {
name string
frequency string
lastReminder time.Time
shouldSend bool
}{
{
name: "daily - last sent yesterday",
frequency: "daily",
lastReminder: now.AddDate(0, 0, -1),
shouldSend: true,
},
{
name: "daily - last sent today",
frequency: "daily",
lastReminder: now.Add(-1 * time.Hour),
shouldSend: false,
},
{
name: "weekly - last sent 8 days ago",
frequency: "weekly",
lastReminder: now.AddDate(0, 0, -8),
shouldSend: true,
},
{
name: "weekly - last sent 5 days ago",
frequency: "weekly",
lastReminder: now.AddDate(0, 0, -5),
shouldSend: false,
},
{
name: "never - should not send",
frequency: "never",
lastReminder: now.AddDate(0, 0, -30),
shouldSend: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var shouldSend bool
switch tt.frequency {
case "daily":
daysSince := int(now.Sub(tt.lastReminder).Hours() / 24)
shouldSend = daysSince >= 1
case "weekly":
daysSince := int(now.Sub(tt.lastReminder).Hours() / 24)
shouldSend = daysSince >= 7
case "never":
shouldSend = false
}
if shouldSend != tt.shouldSend {
t.Errorf("Expected shouldSend=%v, got %v", tt.shouldSend, shouldSend)
}
})
}
}

View File

@@ -0,0 +1,524 @@
package services
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/breakpilot/consent-service/internal/models"
)
var (
ErrInvalidClient = errors.New("invalid_client")
ErrInvalidGrant = errors.New("invalid_grant")
ErrInvalidScope = errors.New("invalid_scope")
ErrInvalidRequest = errors.New("invalid_request")
ErrUnauthorizedClient = errors.New("unauthorized_client")
ErrAccessDenied = errors.New("access_denied")
ErrInvalidRedirectURI = errors.New("invalid redirect_uri")
ErrCodeExpired = errors.New("authorization code expired")
ErrCodeUsed = errors.New("authorization code already used")
ErrPKCERequired = errors.New("PKCE code_challenge required for public clients")
ErrPKCEVerifyFailed = errors.New("PKCE verification failed")
)
// OAuthService handles OAuth 2.0 Authorization Code Flow with PKCE
type OAuthService struct {
db *pgxpool.Pool
jwtSecret string
authCodeExpiration time.Duration
accessTokenExpiration time.Duration
refreshTokenExpiration time.Duration
}
// NewOAuthService creates a new OAuthService
func NewOAuthService(db *pgxpool.Pool, jwtSecret string) *OAuthService {
return &OAuthService{
db: db,
jwtSecret: jwtSecret,
authCodeExpiration: 10 * time.Minute, // Authorization codes expire quickly
accessTokenExpiration: time.Hour, // 1 hour
refreshTokenExpiration: 30 * 24 * time.Hour, // 30 days
}
}
// ValidateClient validates an OAuth client
func (s *OAuthService) ValidateClient(ctx context.Context, clientID string) (*models.OAuthClient, error) {
var client models.OAuthClient
var redirectURIsJSON, scopesJSON, grantTypesJSON []byte
err := s.db.QueryRow(ctx, `
SELECT id, client_id, client_secret, name, description, redirect_uris, scopes, grant_types, is_public, is_active, created_at
FROM oauth_clients WHERE client_id = $1
`, clientID).Scan(
&client.ID, &client.ClientID, &client.ClientSecret, &client.Name, &client.Description,
&redirectURIsJSON, &scopesJSON, &grantTypesJSON, &client.IsPublic, &client.IsActive, &client.CreatedAt,
)
if err != nil {
return nil, ErrInvalidClient
}
if !client.IsActive {
return nil, ErrInvalidClient
}
// Parse JSON arrays
json.Unmarshal(redirectURIsJSON, &client.RedirectURIs)
json.Unmarshal(scopesJSON, &client.Scopes)
json.Unmarshal(grantTypesJSON, &client.GrantTypes)
return &client, nil
}
// ValidateClientSecret validates client credentials for confidential clients
func (s *OAuthService) ValidateClientSecret(client *models.OAuthClient, clientSecret string) error {
if client.IsPublic {
// Public clients don't have a secret
return nil
}
if client.ClientSecret != clientSecret {
return ErrInvalidClient
}
return nil
}
// ValidateRedirectURI validates the redirect URI against registered URIs
func (s *OAuthService) ValidateRedirectURI(client *models.OAuthClient, redirectURI string) error {
for _, uri := range client.RedirectURIs {
if uri == redirectURI {
return nil
}
}
return ErrInvalidRedirectURI
}
// ValidateScopes validates requested scopes against client's allowed scopes
func (s *OAuthService) ValidateScopes(client *models.OAuthClient, requestedScopes string) ([]string, error) {
if requestedScopes == "" {
// Return default scopes
return []string{"openid", "profile", "email"}, nil
}
requested := strings.Split(requestedScopes, " ")
allowedMap := make(map[string]bool)
for _, scope := range client.Scopes {
allowedMap[scope] = true
}
var validScopes []string
for _, scope := range requested {
if allowedMap[scope] {
validScopes = append(validScopes, scope)
}
}
if len(validScopes) == 0 {
return nil, ErrInvalidScope
}
return validScopes, nil
}
// GenerateAuthorizationCode generates a new authorization code
func (s *OAuthService) GenerateAuthorizationCode(
ctx context.Context,
client *models.OAuthClient,
userID uuid.UUID,
redirectURI string,
scopes []string,
codeChallenge, codeChallengeMethod string,
) (string, error) {
// For public clients, PKCE is required
if client.IsPublic && codeChallenge == "" {
return "", ErrPKCERequired
}
// Generate a secure random code
codeBytes := make([]byte, 32)
if _, err := rand.Read(codeBytes); err != nil {
return "", fmt.Errorf("failed to generate code: %w", err)
}
code := base64.URLEncoding.EncodeToString(codeBytes)
// Hash the code for storage
codeHash := sha256.Sum256([]byte(code))
hashedCode := hex.EncodeToString(codeHash[:])
scopesJSON, _ := json.Marshal(scopes)
var challengePtr, methodPtr *string
if codeChallenge != "" {
challengePtr = &codeChallenge
if codeChallengeMethod == "" {
codeChallengeMethod = "plain"
}
methodPtr = &codeChallengeMethod
}
_, err := s.db.Exec(ctx, `
INSERT INTO oauth_authorization_codes (code, client_id, user_id, redirect_uri, scopes, code_challenge, code_challenge_method, expires_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
`, hashedCode, client.ClientID, userID, redirectURI, scopesJSON, challengePtr, methodPtr, time.Now().Add(s.authCodeExpiration))
if err != nil {
return "", fmt.Errorf("failed to store authorization code: %w", err)
}
return code, nil
}
// ExchangeAuthorizationCode exchanges an authorization code for tokens
func (s *OAuthService) ExchangeAuthorizationCode(
ctx context.Context,
code string,
clientID string,
redirectURI string,
codeVerifier string,
) (*models.OAuthTokenResponse, error) {
// Hash the code to look it up
codeHash := sha256.Sum256([]byte(code))
hashedCode := hex.EncodeToString(codeHash[:])
var authCode models.OAuthAuthorizationCode
var scopesJSON []byte
err := s.db.QueryRow(ctx, `
SELECT id, client_id, user_id, redirect_uri, scopes, code_challenge, code_challenge_method, expires_at, used_at
FROM oauth_authorization_codes WHERE code = $1
`, hashedCode).Scan(
&authCode.ID, &authCode.ClientID, &authCode.UserID, &authCode.RedirectURI,
&scopesJSON, &authCode.CodeChallenge, &authCode.CodeChallengeMethod,
&authCode.ExpiresAt, &authCode.UsedAt,
)
if err != nil {
return nil, ErrInvalidGrant
}
// Check if code was already used
if authCode.UsedAt != nil {
return nil, ErrCodeUsed
}
// Check if code is expired
if time.Now().After(authCode.ExpiresAt) {
return nil, ErrCodeExpired
}
// Verify client_id matches
if authCode.ClientID != clientID {
return nil, ErrInvalidGrant
}
// Verify redirect_uri matches
if authCode.RedirectURI != redirectURI {
return nil, ErrInvalidGrant
}
// Verify PKCE if code_challenge was provided
if authCode.CodeChallenge != nil && *authCode.CodeChallenge != "" {
if codeVerifier == "" {
return nil, ErrPKCEVerifyFailed
}
var expectedChallenge string
if authCode.CodeChallengeMethod != nil && *authCode.CodeChallengeMethod == "S256" {
// SHA256 hash of verifier
hash := sha256.Sum256([]byte(codeVerifier))
expectedChallenge = base64.RawURLEncoding.EncodeToString(hash[:])
} else {
// Plain method
expectedChallenge = codeVerifier
}
if expectedChallenge != *authCode.CodeChallenge {
return nil, ErrPKCEVerifyFailed
}
}
// Mark code as used
_, err = s.db.Exec(ctx, `UPDATE oauth_authorization_codes SET used_at = NOW() WHERE id = $1`, authCode.ID)
if err != nil {
return nil, fmt.Errorf("failed to mark code as used: %w", err)
}
// Parse scopes
var scopes []string
json.Unmarshal(scopesJSON, &scopes)
// Generate tokens
return s.generateTokens(ctx, clientID, authCode.UserID, scopes)
}
// RefreshAccessToken refreshes an access token using a refresh token
func (s *OAuthService) RefreshAccessToken(ctx context.Context, refreshToken, clientID string, requestedScope string) (*models.OAuthTokenResponse, error) {
// Hash the refresh token
tokenHash := sha256.Sum256([]byte(refreshToken))
hashedToken := hex.EncodeToString(tokenHash[:])
var rt models.OAuthRefreshToken
var scopesJSON []byte
err := s.db.QueryRow(ctx, `
SELECT id, client_id, user_id, scopes, expires_at, revoked_at
FROM oauth_refresh_tokens WHERE token_hash = $1
`, hashedToken).Scan(
&rt.ID, &rt.ClientID, &rt.UserID, &scopesJSON, &rt.ExpiresAt, &rt.RevokedAt,
)
if err != nil {
return nil, ErrInvalidGrant
}
// Check if token is revoked
if rt.RevokedAt != nil {
return nil, ErrInvalidGrant
}
// Check if token is expired
if time.Now().After(rt.ExpiresAt) {
return nil, ErrInvalidGrant
}
// Verify client_id matches
if rt.ClientID != clientID {
return nil, ErrInvalidGrant
}
// Parse original scopes
var originalScopes []string
json.Unmarshal(scopesJSON, &originalScopes)
// Determine scopes for new tokens
var scopes []string
if requestedScope != "" {
// Validate that requested scopes are subset of original scopes
originalMap := make(map[string]bool)
for _, s := range originalScopes {
originalMap[s] = true
}
for _, s := range strings.Split(requestedScope, " ") {
if originalMap[s] {
scopes = append(scopes, s)
}
}
if len(scopes) == 0 {
return nil, ErrInvalidScope
}
} else {
scopes = originalScopes
}
// Revoke old refresh token (rotate)
_, _ = s.db.Exec(ctx, `UPDATE oauth_refresh_tokens SET revoked_at = NOW() WHERE id = $1`, rt.ID)
// Generate new tokens
return s.generateTokens(ctx, clientID, rt.UserID, scopes)
}
// generateTokens generates access and refresh tokens
func (s *OAuthService) generateTokens(ctx context.Context, clientID string, userID uuid.UUID, scopes []string) (*models.OAuthTokenResponse, error) {
// Get user info for JWT
var user models.User
err := s.db.QueryRow(ctx, `
SELECT id, email, name, role, account_status FROM users WHERE id = $1
`, userID).Scan(&user.ID, &user.Email, &user.Name, &user.Role, &user.AccountStatus)
if err != nil {
return nil, ErrInvalidGrant
}
// Generate access token (JWT)
accessTokenClaims := jwt.MapClaims{
"sub": userID.String(),
"email": user.Email,
"role": user.Role,
"account_status": user.AccountStatus,
"client_id": clientID,
"scope": strings.Join(scopes, " "),
"iat": time.Now().Unix(),
"exp": time.Now().Add(s.accessTokenExpiration).Unix(),
"iss": "breakpilot-consent-service",
"aud": clientID,
}
if user.Name != nil {
accessTokenClaims["name"] = *user.Name
}
accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, accessTokenClaims)
accessTokenString, err := accessToken.SignedString([]byte(s.jwtSecret))
if err != nil {
return nil, fmt.Errorf("failed to sign access token: %w", err)
}
// Hash access token for storage
accessTokenHash := sha256.Sum256([]byte(accessTokenString))
hashedAccessToken := hex.EncodeToString(accessTokenHash[:])
scopesJSON, _ := json.Marshal(scopes)
// Store access token
var accessTokenID uuid.UUID
err = s.db.QueryRow(ctx, `
INSERT INTO oauth_access_tokens (token_hash, client_id, user_id, scopes, expires_at)
VALUES ($1, $2, $3, $4, $5)
RETURNING id
`, hashedAccessToken, clientID, userID, scopesJSON, time.Now().Add(s.accessTokenExpiration)).Scan(&accessTokenID)
if err != nil {
return nil, fmt.Errorf("failed to store access token: %w", err)
}
// Generate refresh token (opaque)
refreshTokenBytes := make([]byte, 32)
if _, err := rand.Read(refreshTokenBytes); err != nil {
return nil, fmt.Errorf("failed to generate refresh token: %w", err)
}
refreshTokenString := base64.URLEncoding.EncodeToString(refreshTokenBytes)
// Hash refresh token for storage
refreshTokenHash := sha256.Sum256([]byte(refreshTokenString))
hashedRefreshToken := hex.EncodeToString(refreshTokenHash[:])
// Store refresh token
_, err = s.db.Exec(ctx, `
INSERT INTO oauth_refresh_tokens (token_hash, access_token_id, client_id, user_id, scopes, expires_at)
VALUES ($1, $2, $3, $4, $5, $6)
`, hashedRefreshToken, accessTokenID, clientID, userID, scopesJSON, time.Now().Add(s.refreshTokenExpiration))
if err != nil {
return nil, fmt.Errorf("failed to store refresh token: %w", err)
}
return &models.OAuthTokenResponse{
AccessToken: accessTokenString,
TokenType: "Bearer",
ExpiresIn: int(s.accessTokenExpiration.Seconds()),
RefreshToken: refreshTokenString,
Scope: strings.Join(scopes, " "),
}, nil
}
// RevokeToken revokes an access or refresh token
func (s *OAuthService) RevokeToken(ctx context.Context, token, tokenTypeHint string) error {
tokenHash := sha256.Sum256([]byte(token))
hashedToken := hex.EncodeToString(tokenHash[:])
// Try to revoke as access token
if tokenTypeHint == "" || tokenTypeHint == "access_token" {
result, err := s.db.Exec(ctx, `UPDATE oauth_access_tokens SET revoked_at = NOW() WHERE token_hash = $1`, hashedToken)
if err == nil && result.RowsAffected() > 0 {
return nil
}
}
// Try to revoke as refresh token
if tokenTypeHint == "" || tokenTypeHint == "refresh_token" {
result, err := s.db.Exec(ctx, `UPDATE oauth_refresh_tokens SET revoked_at = NOW() WHERE token_hash = $1`, hashedToken)
if err == nil && result.RowsAffected() > 0 {
return nil
}
}
return nil // RFC 7009: Always return success
}
// ValidateAccessToken validates an OAuth access token
func (s *OAuthService) ValidateAccessToken(ctx context.Context, tokenString string) (*jwt.MapClaims, error) {
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(s.jwtSecret), nil
})
if err != nil {
return nil, ErrInvalidToken
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok || !token.Valid {
return nil, ErrInvalidToken
}
// Check if token is revoked in database
tokenHash := sha256.Sum256([]byte(tokenString))
hashedToken := hex.EncodeToString(tokenHash[:])
var revokedAt *time.Time
err = s.db.QueryRow(ctx, `SELECT revoked_at FROM oauth_access_tokens WHERE token_hash = $1`, hashedToken).Scan(&revokedAt)
if err == nil && revokedAt != nil {
return nil, ErrInvalidToken
}
return &claims, nil
}
// GetClientByID retrieves an OAuth client by its client_id
func (s *OAuthService) GetClientByID(ctx context.Context, clientID string) (*models.OAuthClient, error) {
return s.ValidateClient(ctx, clientID)
}
// CreateClient creates a new OAuth client (admin only)
func (s *OAuthService) CreateClient(
ctx context.Context,
name, description string,
redirectURIs, scopes, grantTypes []string,
isPublic bool,
createdBy *uuid.UUID,
) (*models.OAuthClient, string, error) {
// Generate client_id
clientIDBytes := make([]byte, 16)
rand.Read(clientIDBytes)
clientID := hex.EncodeToString(clientIDBytes)
// Generate client_secret for confidential clients
var clientSecret string
var clientSecretPtr *string
if !isPublic {
secretBytes := make([]byte, 32)
rand.Read(secretBytes)
clientSecret = base64.URLEncoding.EncodeToString(secretBytes)
clientSecretPtr = &clientSecret
}
redirectURIsJSON, _ := json.Marshal(redirectURIs)
scopesJSON, _ := json.Marshal(scopes)
grantTypesJSON, _ := json.Marshal(grantTypes)
var client models.OAuthClient
err := s.db.QueryRow(ctx, `
INSERT INTO oauth_clients (client_id, client_secret, name, description, redirect_uris, scopes, grant_types, is_public, created_by)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
RETURNING id, client_id, name, is_public, is_active, created_at
`, clientID, clientSecretPtr, name, description, redirectURIsJSON, scopesJSON, grantTypesJSON, isPublic, createdBy).Scan(
&client.ID, &client.ClientID, &client.Name, &client.IsPublic, &client.IsActive, &client.CreatedAt,
)
if err != nil {
return nil, "", fmt.Errorf("failed to create client: %w", err)
}
client.RedirectURIs = redirectURIs
client.Scopes = scopes
client.GrantTypes = grantTypes
return &client, clientSecret, nil
}

View File

@@ -0,0 +1,855 @@
package services
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"strings"
"testing"
"time"
)
// TestPKCEVerification tests PKCE code_challenge and code_verifier validation
func TestPKCEVerification_S256_ValidVerifier(t *testing.T) {
// Generate a code_verifier
codeVerifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
// Calculate expected code_challenge (S256)
hash := sha256.Sum256([]byte(codeVerifier))
codeChallenge := base64.RawURLEncoding.EncodeToString(hash[:])
// Verify the challenge matches
verifierHash := sha256.Sum256([]byte(codeVerifier))
calculatedChallenge := base64.RawURLEncoding.EncodeToString(verifierHash[:])
if calculatedChallenge != codeChallenge {
t.Errorf("PKCE verification failed: expected %s, got %s", codeChallenge, calculatedChallenge)
}
}
func TestPKCEVerification_S256_InvalidVerifier(t *testing.T) {
codeVerifier := "correct-verifier-12345678901234567890"
wrongVerifier := "wrong-verifier-00000000000000000000"
// Calculate code_challenge from correct verifier
hash := sha256.Sum256([]byte(codeVerifier))
codeChallenge := base64.RawURLEncoding.EncodeToString(hash[:])
// Calculate challenge from wrong verifier
wrongHash := sha256.Sum256([]byte(wrongVerifier))
wrongChallenge := base64.RawURLEncoding.EncodeToString(wrongHash[:])
if wrongChallenge == codeChallenge {
t.Error("PKCE verification should fail for wrong verifier")
}
}
func TestPKCEVerification_Plain_ValidVerifier(t *testing.T) {
codeVerifier := "plain-text-verifier-12345"
codeChallenge := codeVerifier // Plain method: challenge = verifier
if codeVerifier != codeChallenge {
t.Error("Plain PKCE verification failed")
}
}
// TestTokenHashing tests that token hashing is consistent
func TestTokenHashing_Consistency(t *testing.T) {
token := "sample-access-token-12345"
hash1 := sha256.Sum256([]byte(token))
hash2 := sha256.Sum256([]byte(token))
if hash1 != hash2 {
t.Error("Token hashing should be consistent")
}
}
func TestTokenHashing_DifferentTokens(t *testing.T) {
token1 := "token-1-abcdefgh"
token2 := "token-2-ijklmnop"
hash1 := sha256.Sum256([]byte(token1))
hash2 := sha256.Sum256([]byte(token2))
if hash1 == hash2 {
t.Error("Different tokens should produce different hashes")
}
}
// TestScopeValidation tests scope parsing and validation
func TestScopeValidation_ParseScopes(t *testing.T) {
tests := []struct {
name string
requestedScope string
allowedScopes []string
expectedCount int
}{
{
name: "all scopes allowed",
requestedScope: "openid profile email",
allowedScopes: []string{"openid", "profile", "email", "offline_access"},
expectedCount: 3,
},
{
name: "some scopes allowed",
requestedScope: "openid profile admin",
allowedScopes: []string{"openid", "profile", "email"},
expectedCount: 2, // admin not allowed
},
{
name: "no scopes allowed",
requestedScope: "admin superuser",
allowedScopes: []string{"openid", "profile", "email"},
expectedCount: 0,
},
{
name: "empty request defaults",
requestedScope: "",
allowedScopes: []string{"openid", "profile", "email"},
expectedCount: 0, // Empty request returns 0 from this test logic
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.requestedScope == "" {
// Empty scope should use defaults in actual service
return
}
allowedMap := make(map[string]bool)
for _, scope := range tt.allowedScopes {
allowedMap[scope] = true
}
var validScopes []string
requestedScopes := splitScopes(tt.requestedScope)
for _, scope := range requestedScopes {
if allowedMap[scope] {
validScopes = append(validScopes, scope)
}
}
if len(validScopes) != tt.expectedCount {
t.Errorf("Expected %d valid scopes, got %d", tt.expectedCount, len(validScopes))
}
})
}
}
// Helper function for scope splitting
func splitScopes(scopes string) []string {
if scopes == "" {
return nil
}
var result []string
start := 0
for i := 0; i <= len(scopes); i++ {
if i == len(scopes) || scopes[i] == ' ' {
if start < i {
result = append(result, scopes[start:i])
}
start = i + 1
}
}
return result
}
// TestRedirectURIValidation tests redirect URI validation
func TestRedirectURIValidation(t *testing.T) {
tests := []struct {
name string
registeredURIs []string
requestURI string
shouldMatch bool
}{
{
name: "exact match",
registeredURIs: []string{"https://example.com/callback"},
requestURI: "https://example.com/callback",
shouldMatch: true,
},
{
name: "no match different domain",
registeredURIs: []string{"https://example.com/callback"},
requestURI: "https://evil.com/callback",
shouldMatch: false,
},
{
name: "no match different path",
registeredURIs: []string{"https://example.com/callback"},
requestURI: "https://example.com/other",
shouldMatch: false,
},
{
name: "multiple URIs - second matches",
registeredURIs: []string{"https://example.com/callback", "https://example.com/auth"},
requestURI: "https://example.com/auth",
shouldMatch: true,
},
{
name: "localhost for development",
registeredURIs: []string{"http://localhost:3000/callback"},
requestURI: "http://localhost:3000/callback",
shouldMatch: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
matched := false
for _, uri := range tt.registeredURIs {
if uri == tt.requestURI {
matched = true
break
}
}
if matched != tt.shouldMatch {
t.Errorf("Expected match=%v, got match=%v", tt.shouldMatch, matched)
}
})
}
}
// TestGrantTypeValidation tests grant type validation
func TestGrantTypeValidation(t *testing.T) {
tests := []struct {
name string
allowedGrants []string
requestedGrant string
shouldAllow bool
}{
{
name: "authorization_code allowed",
allowedGrants: []string{"authorization_code", "refresh_token"},
requestedGrant: "authorization_code",
shouldAllow: true,
},
{
name: "refresh_token allowed",
allowedGrants: []string{"authorization_code", "refresh_token"},
requestedGrant: "refresh_token",
shouldAllow: true,
},
{
name: "password not allowed",
allowedGrants: []string{"authorization_code", "refresh_token"},
requestedGrant: "password",
shouldAllow: false,
},
{
name: "client_credentials not allowed",
allowedGrants: []string{"authorization_code", "refresh_token"},
requestedGrant: "client_credentials",
shouldAllow: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
allowed := false
for _, grant := range tt.allowedGrants {
if grant == tt.requestedGrant {
allowed = true
break
}
}
if allowed != tt.shouldAllow {
t.Errorf("Expected allow=%v, got allow=%v", tt.shouldAllow, allowed)
}
})
}
}
// TestAuthorizationCodeExpiry tests that expired codes should be rejected
func TestAuthorizationCodeExpiry_Logic(t *testing.T) {
tests := []struct {
name string
expiryMins int
usedAfter int // minutes after creation
shouldAllow bool
}{
{
name: "code used within expiry",
expiryMins: 10,
usedAfter: 5,
shouldAllow: true,
},
{
name: "code used at expiry boundary",
expiryMins: 10,
usedAfter: 10,
shouldAllow: false, // Expired at exactly 10 mins
},
{
name: "code used after expiry",
expiryMins: 10,
usedAfter: 15,
shouldAllow: false,
},
{
name: "code used immediately",
expiryMins: 10,
usedAfter: 0,
shouldAllow: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := tt.usedAfter < tt.expiryMins
if isValid != tt.shouldAllow {
t.Errorf("Expected allow=%v for code used after %d mins (expiry: %d mins)",
tt.shouldAllow, tt.usedAfter, tt.expiryMins)
}
})
}
}
// TestClientSecretValidation tests confidential client authentication
func TestClientSecretValidation(t *testing.T) {
tests := []struct {
name string
isPublic bool
storedSecret string
providedSecret string
shouldAllow bool
}{
{
name: "public client - no secret needed",
isPublic: true,
storedSecret: "",
providedSecret: "",
shouldAllow: true,
},
{
name: "confidential client - correct secret",
isPublic: false,
storedSecret: "super-secret-123",
providedSecret: "super-secret-123",
shouldAllow: true,
},
{
name: "confidential client - wrong secret",
isPublic: false,
storedSecret: "super-secret-123",
providedSecret: "wrong-secret",
shouldAllow: false,
},
{
name: "confidential client - empty secret",
isPublic: false,
storedSecret: "super-secret-123",
providedSecret: "",
shouldAllow: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var isValid bool
if tt.isPublic {
isValid = true
} else {
isValid = tt.storedSecret == tt.providedSecret
}
if isValid != tt.shouldAllow {
t.Errorf("Expected allow=%v, got allow=%v", tt.shouldAllow, isValid)
}
})
}
}
// ========================================
// Extended OAuth 2.0 Tests
// ========================================
// TestCodeVerifierGeneration tests that code verifiers meet RFC 7636 requirements
func TestCodeVerifierGeneration_RFC7636(t *testing.T) {
tests := []struct {
name string
length int
expectedLength int
description string
}{
{"minimum length (43)", 43, 43, "RFC 7636 minimum"},
{"standard length (64)", 64, 64, "Recommended length"},
{"maximum length (128)", 128, 128, "RFC 7636 maximum"},
{"too short (42) - corrected to minimum", 42, 43, "Should be corrected to minimum"},
{"too long (129) - corrected to maximum", 129, 128, "Should be corrected to maximum"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
verifier := generateCodeVerifier(tt.length)
// Check that length is corrected to valid range
if len(verifier) != tt.expectedLength {
t.Errorf("Expected length %d, got %d", tt.expectedLength, len(verifier))
}
// Check character set (unreserved characters only: A-Z, a-z, 0-9, -, ., _, ~)
for _, c := range verifier {
if !isUnreservedChar(c) {
t.Errorf("Code verifier contains invalid character: %c", c)
}
}
})
}
}
// TestCodeVerifierLength_Validation tests length validation logic
func TestCodeVerifierLength_Validation(t *testing.T) {
tests := []struct {
name string
length int
isValid bool
}{
{"length 42 - too short", 42, false},
{"length 43 - minimum valid", 43, true},
{"length 64 - recommended", 64, true},
{"length 128 - maximum valid", 128, true},
{"length 129 - too long", 129, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := tt.length >= 43 && tt.length <= 128
if isValid != tt.isValid {
t.Errorf("Expected valid=%v for length %d, got valid=%v",
tt.isValid, tt.length, isValid)
}
})
}
}
// generateCodeVerifier generates a code verifier of specified length
func generateCodeVerifier(length int) string {
// Ensure minimum and maximum bounds
if length < 43 {
length = 43
}
if length > 128 {
length = 128
}
const unreserved = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
bytes := make([]byte, length)
rand.Read(bytes)
result := make([]byte, length)
for i, b := range bytes {
result[i] = unreserved[int(b)%len(unreserved)]
}
return string(result)
}
// isUnreservedChar checks if a character is an unreserved character per RFC 3986
func isUnreservedChar(c rune) bool {
return (c >= 'A' && c <= 'Z') ||
(c >= 'a' && c <= 'z') ||
(c >= '0' && c <= '9') ||
c == '-' || c == '.' || c == '_' || c == '~'
}
// TestCodeChallengeGeneration tests S256 challenge generation
func TestCodeChallengeGeneration_S256(t *testing.T) {
// Known test vector from RFC 7636 Appendix B
verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
expectedChallenge := "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
hash := sha256.Sum256([]byte(verifier))
challenge := base64.RawURLEncoding.EncodeToString(hash[:])
if challenge != expectedChallenge {
t.Errorf("S256 challenge mismatch: expected %s, got %s", expectedChallenge, challenge)
}
}
// TestRefreshTokenRotation tests that refresh tokens are rotated on use
func TestRefreshTokenRotation_Logic(t *testing.T) {
// Simulate refresh token rotation
oldToken := "old-refresh-token-123"
oldTokenHash := hashToken(oldToken)
// Generate new token
newToken := generateSecureToken(32)
newTokenHash := hashToken(newToken)
// Verify tokens are different
if oldTokenHash == newTokenHash {
t.Error("New refresh token should be different from old token")
}
// Verify old token would be revoked (simulated by marking revoked_at)
oldTokenRevoked := true
if !oldTokenRevoked {
t.Error("Old refresh token should be revoked after rotation")
}
}
func hashToken(token string) string {
hash := sha256.Sum256([]byte(token))
return hex.EncodeToString(hash[:])
}
func generateSecureToken(length int) string {
bytes := make([]byte, length)
rand.Read(bytes)
return base64.URLEncoding.EncodeToString(bytes)
}
// TestAccessTokenExpiry tests access token expiration handling
func TestAccessTokenExpiry_Scenarios(t *testing.T) {
tests := []struct {
name string
tokenDuration time.Duration
usedAfter time.Duration
shouldBeValid bool
}{
{
name: "token used immediately",
tokenDuration: time.Hour,
usedAfter: 0,
shouldBeValid: true,
},
{
name: "token used within validity",
tokenDuration: time.Hour,
usedAfter: 30 * time.Minute,
shouldBeValid: true,
},
{
name: "token used at expiry",
tokenDuration: time.Hour,
usedAfter: time.Hour,
shouldBeValid: false,
},
{
name: "token used after expiry",
tokenDuration: time.Hour,
usedAfter: 2 * time.Hour,
shouldBeValid: false,
},
{
name: "short-lived token",
tokenDuration: 5 * time.Minute,
usedAfter: 6 * time.Minute,
shouldBeValid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
issuedAt := time.Now()
expiresAt := issuedAt.Add(tt.tokenDuration)
usedAt := issuedAt.Add(tt.usedAfter)
isValid := usedAt.Before(expiresAt)
if isValid != tt.shouldBeValid {
t.Errorf("Expected valid=%v for token used after %v (duration: %v)",
tt.shouldBeValid, tt.usedAfter, tt.tokenDuration)
}
})
}
}
// TestOAuthErrors tests that OAuth error codes are correct
func TestOAuthErrors_RFC6749(t *testing.T) {
tests := []struct {
scenario string
errorCode string
description string
}{
{"invalid client_id", "invalid_client", "Client authentication failed"},
{"invalid grant (code)", "invalid_grant", "Authorization code invalid or expired"},
{"invalid scope", "invalid_scope", "Requested scope is invalid"},
{"invalid request", "invalid_request", "Request is missing required parameter"},
{"unauthorized client", "unauthorized_client", "Client not authorized for this grant type"},
{"access denied", "access_denied", "Resource owner denied the request"},
}
for _, tt := range tests {
t.Run(tt.scenario, func(t *testing.T) {
// Verify error codes match RFC 6749 Section 5.2
validErrors := map[string]bool{
"invalid_request": true,
"invalid_client": true,
"invalid_grant": true,
"unauthorized_client": true,
"unsupported_grant_type": true,
"invalid_scope": true,
"access_denied": true,
"unsupported_response_type": true,
"server_error": true,
"temporarily_unavailable": true,
}
if !validErrors[tt.errorCode] {
t.Errorf("Error code %s is not a valid OAuth 2.0 error code", tt.errorCode)
}
})
}
}
// TestStateParameter tests state parameter handling for CSRF protection
func TestStateParameter_CSRF(t *testing.T) {
tests := []struct {
name string
requestState string
responseState string
shouldMatch bool
}{
{
name: "matching state",
requestState: "abc123xyz",
responseState: "abc123xyz",
shouldMatch: true,
},
{
name: "non-matching state",
requestState: "abc123xyz",
responseState: "different",
shouldMatch: false,
},
{
name: "empty request state",
requestState: "",
responseState: "abc123xyz",
shouldMatch: false,
},
{
name: "empty response state",
requestState: "abc123xyz",
responseState: "",
shouldMatch: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
matches := tt.requestState != "" && tt.requestState == tt.responseState
if matches != tt.shouldMatch {
t.Errorf("Expected match=%v, got match=%v", tt.shouldMatch, matches)
}
})
}
}
// TestResponseType tests response_type validation
func TestResponseType_Validation(t *testing.T) {
tests := []struct {
name string
responseType string
isValid bool
}{
{"code - valid", "code", true},
{"token - implicit flow (disabled)", "token", false},
{"id_token - OIDC", "id_token", false},
{"code token - hybrid", "code token", false},
{"empty", "", false},
{"invalid", "password", false},
}
supportedResponseTypes := map[string]bool{
"code": true, // Only authorization code flow is supported
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := supportedResponseTypes[tt.responseType]
if isValid != tt.isValid {
t.Errorf("Expected valid=%v for response_type=%s, got valid=%v",
tt.isValid, tt.responseType, isValid)
}
})
}
}
// TestCodeChallengeMethod tests code_challenge_method validation
func TestCodeChallengeMethod_Validation(t *testing.T) {
tests := []struct {
name string
method string
isValid bool
}{
{"S256 - recommended", "S256", true},
{"plain - discouraged but valid", "plain", true},
{"empty - defaults to plain", "", true},
{"sha512 - not supported", "sha512", false},
{"invalid", "md5", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := tt.method == "S256" || tt.method == "plain" || tt.method == ""
if isValid != tt.isValid {
t.Errorf("Expected valid=%v for method=%s, got valid=%v",
tt.isValid, tt.method, isValid)
}
})
}
}
// TestTokenRevocation tests token revocation behavior per RFC 7009
func TestTokenRevocation_RFC7009(t *testing.T) {
tests := []struct {
name string
tokenExists bool
tokenRevoked bool
expectSuccess bool
}{
{
name: "revoke existing active token",
tokenExists: true,
tokenRevoked: false,
expectSuccess: true,
},
{
name: "revoke already revoked token",
tokenExists: true,
tokenRevoked: true,
expectSuccess: true, // RFC 7009: Always return 200
},
{
name: "revoke non-existent token",
tokenExists: false,
tokenRevoked: false,
expectSuccess: true, // RFC 7009: Always return 200
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Simulate revocation logic
// Per RFC 7009, revocation endpoint always returns 200 OK
success := true
if success != tt.expectSuccess {
t.Errorf("Expected success=%v, got success=%v", tt.expectSuccess, success)
}
})
}
}
// TestClientIDGeneration tests client_id format
func TestClientIDGeneration_Format(t *testing.T) {
// Generate multiple client IDs
clientIDs := make(map[string]bool)
for i := 0; i < 100; i++ {
bytes := make([]byte, 16)
rand.Read(bytes)
clientID := hex.EncodeToString(bytes)
// Check format (32 hex characters)
if len(clientID) != 32 {
t.Errorf("Client ID should be 32 characters, got %d", len(clientID))
}
// Check uniqueness
if clientIDs[clientID] {
t.Error("Client ID should be unique")
}
clientIDs[clientID] = true
// Check only hex characters
for _, c := range clientID {
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
t.Errorf("Client ID should only contain hex characters, found %c", c)
}
}
}
}
// TestScopeNormalization tests scope string normalization
func TestScopeNormalization(t *testing.T) {
tests := []struct {
name string
input string
expected []string
}{
{
name: "single scope",
input: "openid",
expected: []string{"openid"},
},
{
name: "multiple scopes",
input: "openid profile email",
expected: []string{"openid", "profile", "email"},
},
{
name: "extra spaces",
input: "openid profile email",
expected: []string{"openid", "profile", "email"},
},
{
name: "empty string",
input: "",
expected: []string{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
scopes := normalizeScopes(tt.input)
if len(scopes) != len(tt.expected) {
t.Errorf("Expected %d scopes, got %d", len(tt.expected), len(scopes))
return
}
for i, scope := range scopes {
if scope != tt.expected[i] {
t.Errorf("Expected scope[%d]=%s, got %s", i, tt.expected[i], scope)
}
}
})
}
}
func normalizeScopes(scope string) []string {
if scope == "" {
return []string{}
}
parts := strings.Fields(scope) // Handles multiple spaces
return parts
}
// BenchmarkPKCEVerification benchmarks PKCE S256 verification
func BenchmarkPKCEVerification_S256(b *testing.B) {
verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
for i := 0; i < b.N; i++ {
hash := sha256.Sum256([]byte(verifier))
base64.RawURLEncoding.EncodeToString(hash[:])
}
}
// BenchmarkTokenHashing benchmarks token hashing for storage
func BenchmarkTokenHashing(b *testing.B) {
token := "sample-access-token-12345678901234567890"
for i := 0; i < b.N; i++ {
hash := sha256.Sum256([]byte(token))
hex.EncodeToString(hash[:])
}
}
// BenchmarkCodeVerifierGeneration benchmarks code verifier generation
func BenchmarkCodeVerifierGeneration(b *testing.B) {
for i := 0; i < b.N; i++ {
generateCodeVerifier(64)
}
}

View File

@@ -0,0 +1,698 @@
package services
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"time"
"github.com/breakpilot/consent-service/internal/database"
"github.com/breakpilot/consent-service/internal/models"
"github.com/breakpilot/consent-service/internal/services/matrix"
"github.com/google/uuid"
)
// SchoolService handles school management operations
type SchoolService struct {
db *database.DB
matrix *matrix.MatrixService
}
// NewSchoolService creates a new school service
func NewSchoolService(db *database.DB, matrixService *matrix.MatrixService) *SchoolService {
return &SchoolService{
db: db,
matrix: matrixService,
}
}
// ========================================
// School CRUD
// ========================================
// CreateSchool creates a new school
func (s *SchoolService) CreateSchool(ctx context.Context, req models.CreateSchoolRequest) (*models.School, error) {
school := &models.School{
ID: uuid.New(),
Name: req.Name,
ShortName: req.ShortName,
Type: req.Type,
Address: req.Address,
City: req.City,
PostalCode: req.PostalCode,
State: req.State,
Country: "DE",
Phone: req.Phone,
Email: req.Email,
Website: req.Website,
IsActive: true,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
query := `
INSERT INTO schools (id, name, short_name, type, address, city, postal_code, state, country, phone, email, website, is_active, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)
RETURNING id`
err := s.db.Pool.QueryRow(ctx, query,
school.ID, school.Name, school.ShortName, school.Type, school.Address,
school.City, school.PostalCode, school.State, school.Country, school.Phone,
school.Email, school.Website, school.IsActive, school.CreatedAt, school.UpdatedAt,
).Scan(&school.ID)
if err != nil {
return nil, fmt.Errorf("failed to create school: %w", err)
}
// Create default timetable slots for the school
if err := s.createDefaultTimetableSlots(ctx, school.ID); err != nil {
// Log but don't fail
fmt.Printf("Warning: failed to create default timetable slots: %v\n", err)
}
// Create default grade scale
if err := s.createDefaultGradeScale(ctx, school.ID); err != nil {
fmt.Printf("Warning: failed to create default grade scale: %v\n", err)
}
return school, nil
}
// GetSchool retrieves a school by ID
func (s *SchoolService) GetSchool(ctx context.Context, schoolID uuid.UUID) (*models.School, error) {
query := `
SELECT id, name, short_name, type, address, city, postal_code, state, country, phone, email, website, matrix_server_name, logo_url, is_active, created_at, updated_at
FROM schools
WHERE id = $1`
school := &models.School{}
err := s.db.Pool.QueryRow(ctx, query, schoolID).Scan(
&school.ID, &school.Name, &school.ShortName, &school.Type, &school.Address,
&school.City, &school.PostalCode, &school.State, &school.Country, &school.Phone,
&school.Email, &school.Website, &school.MatrixServerName, &school.LogoURL,
&school.IsActive, &school.CreatedAt, &school.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to get school: %w", err)
}
return school, nil
}
// ListSchools lists all active schools
func (s *SchoolService) ListSchools(ctx context.Context) ([]models.School, error) {
query := `
SELECT id, name, short_name, type, address, city, postal_code, state, country, phone, email, website, matrix_server_name, logo_url, is_active, created_at, updated_at
FROM schools
WHERE is_active = true
ORDER BY name`
rows, err := s.db.Pool.Query(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to list schools: %w", err)
}
defer rows.Close()
var schools []models.School
for rows.Next() {
var school models.School
err := rows.Scan(
&school.ID, &school.Name, &school.ShortName, &school.Type, &school.Address,
&school.City, &school.PostalCode, &school.State, &school.Country, &school.Phone,
&school.Email, &school.Website, &school.MatrixServerName, &school.LogoURL,
&school.IsActive, &school.CreatedAt, &school.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan school: %w", err)
}
schools = append(schools, school)
}
return schools, nil
}
// ========================================
// School Year Management
// ========================================
// CreateSchoolYear creates a new school year
func (s *SchoolService) CreateSchoolYear(ctx context.Context, schoolID uuid.UUID, name string, startDate, endDate time.Time) (*models.SchoolYear, error) {
schoolYear := &models.SchoolYear{
ID: uuid.New(),
SchoolID: schoolID,
Name: name,
StartDate: startDate,
EndDate: endDate,
IsCurrent: false,
CreatedAt: time.Now(),
}
query := `
INSERT INTO school_years (id, school_id, name, start_date, end_date, is_current, created_at)
VALUES ($1, $2, $3, $4, $5, $6, $7)
RETURNING id`
err := s.db.Pool.QueryRow(ctx, query,
schoolYear.ID, schoolYear.SchoolID, schoolYear.Name,
schoolYear.StartDate, schoolYear.EndDate, schoolYear.IsCurrent, schoolYear.CreatedAt,
).Scan(&schoolYear.ID)
if err != nil {
return nil, fmt.Errorf("failed to create school year: %w", err)
}
return schoolYear, nil
}
// SetCurrentSchoolYear sets a school year as the current one
func (s *SchoolService) SetCurrentSchoolYear(ctx context.Context, schoolID, schoolYearID uuid.UUID) error {
// First, unset all current school years for this school
_, err := s.db.Pool.Exec(ctx, `UPDATE school_years SET is_current = false WHERE school_id = $1`, schoolID)
if err != nil {
return fmt.Errorf("failed to unset current school years: %w", err)
}
// Then set the specified school year as current
_, err = s.db.Pool.Exec(ctx, `UPDATE school_years SET is_current = true WHERE id = $1 AND school_id = $2`, schoolYearID, schoolID)
if err != nil {
return fmt.Errorf("failed to set current school year: %w", err)
}
return nil
}
// GetCurrentSchoolYear gets the current school year for a school
func (s *SchoolService) GetCurrentSchoolYear(ctx context.Context, schoolID uuid.UUID) (*models.SchoolYear, error) {
query := `
SELECT id, school_id, name, start_date, end_date, is_current, created_at
FROM school_years
WHERE school_id = $1 AND is_current = true`
schoolYear := &models.SchoolYear{}
err := s.db.Pool.QueryRow(ctx, query, schoolID).Scan(
&schoolYear.ID, &schoolYear.SchoolID, &schoolYear.Name,
&schoolYear.StartDate, &schoolYear.EndDate, &schoolYear.IsCurrent, &schoolYear.CreatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to get current school year: %w", err)
}
return schoolYear, nil
}
// ========================================
// Class Management
// ========================================
// CreateClass creates a new class
func (s *SchoolService) CreateClass(ctx context.Context, schoolID uuid.UUID, req models.CreateClassRequest) (*models.Class, error) {
schoolYearID, err := uuid.Parse(req.SchoolYearID)
if err != nil {
return nil, fmt.Errorf("invalid school year ID: %w", err)
}
class := &models.Class{
ID: uuid.New(),
SchoolID: schoolID,
SchoolYearID: schoolYearID,
Name: req.Name,
Grade: req.Grade,
Section: req.Section,
Room: req.Room,
IsActive: true,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
query := `
INSERT INTO classes (id, school_id, school_year_id, name, grade, section, room, is_active, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
RETURNING id`
err = s.db.Pool.QueryRow(ctx, query,
class.ID, class.SchoolID, class.SchoolYearID, class.Name,
class.Grade, class.Section, class.Room, class.IsActive, class.CreatedAt, class.UpdatedAt,
).Scan(&class.ID)
if err != nil {
return nil, fmt.Errorf("failed to create class: %w", err)
}
return class, nil
}
// GetClass retrieves a class by ID
func (s *SchoolService) GetClass(ctx context.Context, classID uuid.UUID) (*models.Class, error) {
query := `
SELECT id, school_id, school_year_id, name, grade, section, room, matrix_info_room, matrix_rep_room, is_active, created_at, updated_at
FROM classes
WHERE id = $1`
class := &models.Class{}
err := s.db.Pool.QueryRow(ctx, query, classID).Scan(
&class.ID, &class.SchoolID, &class.SchoolYearID, &class.Name,
&class.Grade, &class.Section, &class.Room, &class.MatrixInfoRoom,
&class.MatrixRepRoom, &class.IsActive, &class.CreatedAt, &class.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to get class: %w", err)
}
return class, nil
}
// ListClasses lists all classes for a school in a school year
func (s *SchoolService) ListClasses(ctx context.Context, schoolID, schoolYearID uuid.UUID) ([]models.Class, error) {
query := `
SELECT id, school_id, school_year_id, name, grade, section, room, matrix_info_room, matrix_rep_room, is_active, created_at, updated_at
FROM classes
WHERE school_id = $1 AND school_year_id = $2 AND is_active = true
ORDER BY grade, name`
rows, err := s.db.Pool.Query(ctx, query, schoolID, schoolYearID)
if err != nil {
return nil, fmt.Errorf("failed to list classes: %w", err)
}
defer rows.Close()
var classes []models.Class
for rows.Next() {
var class models.Class
err := rows.Scan(
&class.ID, &class.SchoolID, &class.SchoolYearID, &class.Name,
&class.Grade, &class.Section, &class.Room, &class.MatrixInfoRoom,
&class.MatrixRepRoom, &class.IsActive, &class.CreatedAt, &class.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan class: %w", err)
}
classes = append(classes, class)
}
return classes, nil
}
// ========================================
// Student Management
// ========================================
// CreateStudent creates a new student
func (s *SchoolService) CreateStudent(ctx context.Context, schoolID uuid.UUID, req models.CreateStudentRequest) (*models.Student, error) {
classID, err := uuid.Parse(req.ClassID)
if err != nil {
return nil, fmt.Errorf("invalid class ID: %w", err)
}
student := &models.Student{
ID: uuid.New(),
SchoolID: schoolID,
ClassID: classID,
StudentNumber: req.StudentNumber,
FirstName: req.FirstName,
LastName: req.LastName,
Gender: req.Gender,
IsActive: true,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
if req.DateOfBirth != nil {
dob, err := time.Parse("2006-01-02", *req.DateOfBirth)
if err == nil {
student.DateOfBirth = &dob
}
}
query := `
INSERT INTO students (id, school_id, class_id, student_number, first_name, last_name, date_of_birth, gender, is_active, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
RETURNING id`
err = s.db.Pool.QueryRow(ctx, query,
student.ID, student.SchoolID, student.ClassID, student.StudentNumber,
student.FirstName, student.LastName, student.DateOfBirth, student.Gender,
student.IsActive, student.CreatedAt, student.UpdatedAt,
).Scan(&student.ID)
if err != nil {
return nil, fmt.Errorf("failed to create student: %w", err)
}
return student, nil
}
// GetStudent retrieves a student by ID
func (s *SchoolService) GetStudent(ctx context.Context, studentID uuid.UUID) (*models.Student, error) {
query := `
SELECT id, school_id, class_id, user_id, student_number, first_name, last_name, date_of_birth, gender, matrix_user_id, matrix_dm_room, is_active, created_at, updated_at
FROM students
WHERE id = $1`
student := &models.Student{}
err := s.db.Pool.QueryRow(ctx, query, studentID).Scan(
&student.ID, &student.SchoolID, &student.ClassID, &student.UserID,
&student.StudentNumber, &student.FirstName, &student.LastName,
&student.DateOfBirth, &student.Gender, &student.MatrixUserID,
&student.MatrixDMRoom, &student.IsActive, &student.CreatedAt, &student.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to get student: %w", err)
}
return student, nil
}
// ListStudentsByClass lists all students in a class
func (s *SchoolService) ListStudentsByClass(ctx context.Context, classID uuid.UUID) ([]models.Student, error) {
query := `
SELECT id, school_id, class_id, user_id, student_number, first_name, last_name, date_of_birth, gender, matrix_user_id, matrix_dm_room, is_active, created_at, updated_at
FROM students
WHERE class_id = $1 AND is_active = true
ORDER BY last_name, first_name`
rows, err := s.db.Pool.Query(ctx, query, classID)
if err != nil {
return nil, fmt.Errorf("failed to list students: %w", err)
}
defer rows.Close()
var students []models.Student
for rows.Next() {
var student models.Student
err := rows.Scan(
&student.ID, &student.SchoolID, &student.ClassID, &student.UserID,
&student.StudentNumber, &student.FirstName, &student.LastName,
&student.DateOfBirth, &student.Gender, &student.MatrixUserID,
&student.MatrixDMRoom, &student.IsActive, &student.CreatedAt, &student.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan student: %w", err)
}
students = append(students, student)
}
return students, nil
}
// ========================================
// Teacher Management
// ========================================
// CreateTeacher creates a new teacher linked to a user account
func (s *SchoolService) CreateTeacher(ctx context.Context, schoolID, userID uuid.UUID, firstName, lastName string, teacherCode, title *string) (*models.Teacher, error) {
teacher := &models.Teacher{
ID: uuid.New(),
SchoolID: schoolID,
UserID: userID,
TeacherCode: teacherCode,
Title: title,
FirstName: firstName,
LastName: lastName,
IsActive: true,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
query := `
INSERT INTO teachers (id, school_id, user_id, teacher_code, title, first_name, last_name, is_active, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
RETURNING id`
err := s.db.Pool.QueryRow(ctx, query,
teacher.ID, teacher.SchoolID, teacher.UserID, teacher.TeacherCode,
teacher.Title, teacher.FirstName, teacher.LastName,
teacher.IsActive, teacher.CreatedAt, teacher.UpdatedAt,
).Scan(&teacher.ID)
if err != nil {
return nil, fmt.Errorf("failed to create teacher: %w", err)
}
return teacher, nil
}
// GetTeacher retrieves a teacher by ID
func (s *SchoolService) GetTeacher(ctx context.Context, teacherID uuid.UUID) (*models.Teacher, error) {
query := `
SELECT id, school_id, user_id, teacher_code, title, first_name, last_name, matrix_user_id, is_active, created_at, updated_at
FROM teachers
WHERE id = $1`
teacher := &models.Teacher{}
err := s.db.Pool.QueryRow(ctx, query, teacherID).Scan(
&teacher.ID, &teacher.SchoolID, &teacher.UserID, &teacher.TeacherCode,
&teacher.Title, &teacher.FirstName, &teacher.LastName, &teacher.MatrixUserID,
&teacher.IsActive, &teacher.CreatedAt, &teacher.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to get teacher: %w", err)
}
return teacher, nil
}
// GetTeacherByUserID retrieves a teacher by their user ID
func (s *SchoolService) GetTeacherByUserID(ctx context.Context, userID uuid.UUID) (*models.Teacher, error) {
query := `
SELECT id, school_id, user_id, teacher_code, title, first_name, last_name, matrix_user_id, is_active, created_at, updated_at
FROM teachers
WHERE user_id = $1 AND is_active = true`
teacher := &models.Teacher{}
err := s.db.Pool.QueryRow(ctx, query, userID).Scan(
&teacher.ID, &teacher.SchoolID, &teacher.UserID, &teacher.TeacherCode,
&teacher.Title, &teacher.FirstName, &teacher.LastName, &teacher.MatrixUserID,
&teacher.IsActive, &teacher.CreatedAt, &teacher.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to get teacher by user ID: %w", err)
}
return teacher, nil
}
// AssignClassTeacher assigns a teacher to a class
func (s *SchoolService) AssignClassTeacher(ctx context.Context, classID, teacherID uuid.UUID, isPrimary bool) error {
query := `
INSERT INTO class_teachers (id, class_id, teacher_id, is_primary, created_at)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (class_id, teacher_id) DO UPDATE SET is_primary = EXCLUDED.is_primary`
_, err := s.db.Pool.Exec(ctx, query, uuid.New(), classID, teacherID, isPrimary, time.Now())
if err != nil {
return fmt.Errorf("failed to assign class teacher: %w", err)
}
return nil
}
// ========================================
// Subject Management
// ========================================
// CreateSubject creates a new subject
func (s *SchoolService) CreateSubject(ctx context.Context, schoolID uuid.UUID, name, shortName string, color *string) (*models.Subject, error) {
subject := &models.Subject{
ID: uuid.New(),
SchoolID: schoolID,
Name: name,
ShortName: shortName,
Color: color,
IsActive: true,
CreatedAt: time.Now(),
}
query := `
INSERT INTO subjects (id, school_id, name, short_name, color, is_active, created_at)
VALUES ($1, $2, $3, $4, $5, $6, $7)
RETURNING id`
err := s.db.Pool.QueryRow(ctx, query,
subject.ID, subject.SchoolID, subject.Name, subject.ShortName,
subject.Color, subject.IsActive, subject.CreatedAt,
).Scan(&subject.ID)
if err != nil {
return nil, fmt.Errorf("failed to create subject: %w", err)
}
return subject, nil
}
// ListSubjects lists all subjects for a school
func (s *SchoolService) ListSubjects(ctx context.Context, schoolID uuid.UUID) ([]models.Subject, error) {
query := `
SELECT id, school_id, name, short_name, color, is_active, created_at
FROM subjects
WHERE school_id = $1 AND is_active = true
ORDER BY name`
rows, err := s.db.Pool.Query(ctx, query, schoolID)
if err != nil {
return nil, fmt.Errorf("failed to list subjects: %w", err)
}
defer rows.Close()
var subjects []models.Subject
for rows.Next() {
var subject models.Subject
err := rows.Scan(
&subject.ID, &subject.SchoolID, &subject.Name, &subject.ShortName,
&subject.Color, &subject.IsActive, &subject.CreatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to scan subject: %w", err)
}
subjects = append(subjects, subject)
}
return subjects, nil
}
// ========================================
// Parent Onboarding
// ========================================
// GenerateParentOnboardingToken generates a QR code token for parent onboarding
func (s *SchoolService) GenerateParentOnboardingToken(ctx context.Context, schoolID, classID, studentID, createdByUserID uuid.UUID, role string) (*models.ParentOnboardingToken, error) {
// Generate secure random token
tokenBytes := make([]byte, 32)
if _, err := rand.Read(tokenBytes); err != nil {
return nil, fmt.Errorf("failed to generate token: %w", err)
}
token := hex.EncodeToString(tokenBytes)
onboardingToken := &models.ParentOnboardingToken{
ID: uuid.New(),
SchoolID: schoolID,
ClassID: classID,
StudentID: studentID,
Token: token,
Role: role,
ExpiresAt: time.Now().Add(72 * time.Hour), // Valid for 72 hours
CreatedAt: time.Now(),
CreatedBy: createdByUserID,
}
query := `
INSERT INTO parent_onboarding_tokens (id, school_id, class_id, student_id, token, role, expires_at, created_at, created_by)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
RETURNING id`
err := s.db.Pool.QueryRow(ctx, query,
onboardingToken.ID, onboardingToken.SchoolID, onboardingToken.ClassID,
onboardingToken.StudentID, onboardingToken.Token, onboardingToken.Role,
onboardingToken.ExpiresAt, onboardingToken.CreatedAt, onboardingToken.CreatedBy,
).Scan(&onboardingToken.ID)
if err != nil {
return nil, fmt.Errorf("failed to create onboarding token: %w", err)
}
return onboardingToken, nil
}
// ValidateOnboardingToken validates and retrieves info for an onboarding token
func (s *SchoolService) ValidateOnboardingToken(ctx context.Context, token string) (*models.ParentOnboardingToken, error) {
query := `
SELECT id, school_id, class_id, student_id, token, role, expires_at, used_at, used_by_user_id, created_at, created_by
FROM parent_onboarding_tokens
WHERE token = $1 AND used_at IS NULL AND expires_at > NOW()`
onboardingToken := &models.ParentOnboardingToken{}
err := s.db.Pool.QueryRow(ctx, query, token).Scan(
&onboardingToken.ID, &onboardingToken.SchoolID, &onboardingToken.ClassID,
&onboardingToken.StudentID, &onboardingToken.Token, &onboardingToken.Role,
&onboardingToken.ExpiresAt, &onboardingToken.UsedAt, &onboardingToken.UsedByUserID,
&onboardingToken.CreatedAt, &onboardingToken.CreatedBy,
)
if err != nil {
return nil, fmt.Errorf("invalid or expired token: %w", err)
}
return onboardingToken, nil
}
// RedeemOnboardingToken marks a token as used and creates the parent account
func (s *SchoolService) RedeemOnboardingToken(ctx context.Context, token string, userID uuid.UUID) error {
query := `
UPDATE parent_onboarding_tokens
SET used_at = NOW(), used_by_user_id = $1
WHERE token = $2 AND used_at IS NULL AND expires_at > NOW()`
result, err := s.db.Pool.Exec(ctx, query, userID, token)
if err != nil {
return fmt.Errorf("failed to redeem token: %w", err)
}
if result.RowsAffected() == 0 {
return fmt.Errorf("token not found or already used")
}
return nil
}
// ========================================
// Helper Functions
// ========================================
func (s *SchoolService) createDefaultTimetableSlots(ctx context.Context, schoolID uuid.UUID) error {
slots := []struct {
Number int
StartTime string
EndTime string
IsBreak bool
Name string
}{
{1, "08:00", "08:45", false, "1. Stunde"},
{2, "08:45", "09:30", false, "2. Stunde"},
{3, "09:30", "09:50", true, "Erste Pause"},
{4, "09:50", "10:35", false, "3. Stunde"},
{5, "10:35", "11:20", false, "4. Stunde"},
{6, "11:20", "11:40", true, "Zweite Pause"},
{7, "11:40", "12:25", false, "5. Stunde"},
{8, "12:25", "13:10", false, "6. Stunde"},
{9, "13:10", "14:00", true, "Mittagspause"},
{10, "14:00", "14:45", false, "7. Stunde"},
{11, "14:45", "15:30", false, "8. Stunde"},
}
query := `
INSERT INTO timetable_slots (id, school_id, slot_number, start_time, end_time, is_break, name)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (school_id, slot_number) DO NOTHING`
for _, slot := range slots {
_, err := s.db.Pool.Exec(ctx, query,
uuid.New(), schoolID, slot.Number, slot.StartTime, slot.EndTime, slot.IsBreak, slot.Name,
)
if err != nil {
return err
}
}
return nil
}
func (s *SchoolService) createDefaultGradeScale(ctx context.Context, schoolID uuid.UUID) error {
query := `
INSERT INTO grade_scales (id, school_id, name, min_value, max_value, passing_value, is_ascending, is_default, created_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
ON CONFLICT DO NOTHING`
_, err := s.db.Pool.Exec(ctx, query,
uuid.New(), schoolID, "1-6 (Noten)", 1.0, 6.0, 4.0, false, true, time.Now(),
)
return err
}

View File

@@ -0,0 +1,424 @@
package services
import (
"testing"
"time"
"github.com/breakpilot/consent-service/internal/models"
"github.com/google/uuid"
)
// TestGenerateOnboardingToken tests the QR code token generation
func TestGenerateOnboardingToken(t *testing.T) {
tests := []struct {
name string
studentID uuid.UUID
createdBy uuid.UUID
role string
expectError bool
}{
{
name: "valid parent token",
studentID: uuid.New(),
createdBy: uuid.New(),
role: "parent",
expectError: false,
},
{
name: "valid parent_representative token",
studentID: uuid.New(),
createdBy: uuid.New(),
role: "parent_representative",
expectError: false,
},
{
name: "empty student ID",
studentID: uuid.Nil,
createdBy: uuid.New(),
role: "parent",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
token := &models.ParentOnboardingToken{
StudentID: tt.studentID,
CreatedBy: tt.createdBy,
Role: tt.role,
ExpiresAt: time.Now().Add(7 * 24 * time.Hour),
}
// Validate token fields
if tt.studentID == uuid.Nil && !tt.expectError {
t.Errorf("expected error for empty student ID")
}
if token.Role != "parent" && token.Role != "parent_representative" {
if !tt.expectError {
t.Errorf("invalid role: %s", token.Role)
}
}
// Check expiration is in the future
if token.ExpiresAt.Before(time.Now()) {
t.Errorf("token expiration should be in the future")
}
})
}
}
// TestValidateSchoolData tests school data validation
func TestValidateSchoolData(t *testing.T) {
address1 := "Musterstraße 1, 20095 Hamburg"
address2 := "Musterstraße 1"
address3 := "Parkweg 5"
tests := []struct {
name string
school models.School
expectValid bool
}{
{
name: "valid school",
school: models.School{
Name: "Testschule Hamburg",
Address: &address1,
Type: "gymnasium",
},
expectValid: true,
},
{
name: "empty name",
school: models.School{
Name: "",
Address: &address2,
Type: "gymnasium",
},
expectValid: false,
},
{
name: "valid grundschule",
school: models.School{
Name: "Grundschule Am Park",
Address: &address3,
Type: "grundschule",
},
expectValid: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := validateSchool(tt.school)
if isValid != tt.expectValid {
t.Errorf("expected valid=%v, got valid=%v", tt.expectValid, isValid)
}
})
}
}
// validateSchool is a helper function for validation
func validateSchool(school models.School) bool {
if school.Name == "" {
return false
}
if school.Type == "" {
return false
}
return true
}
// TestValidateClassData tests class data validation
func TestValidateClassData(t *testing.T) {
tests := []struct {
name string
class models.Class
expectValid bool
}{
{
name: "valid class 5a",
class: models.Class{
Name: "5a",
Grade: 5,
SchoolID: uuid.New(),
SchoolYearID: uuid.New(),
},
expectValid: true,
},
{
name: "invalid grade level 0",
class: models.Class{
Name: "0a",
Grade: 0,
SchoolID: uuid.New(),
SchoolYearID: uuid.New(),
},
expectValid: false,
},
{
name: "invalid grade level 14",
class: models.Class{
Name: "14a",
Grade: 14,
SchoolID: uuid.New(),
SchoolYearID: uuid.New(),
},
expectValid: false,
},
{
name: "missing school ID",
class: models.Class{
Name: "5a",
Grade: 5,
SchoolID: uuid.Nil,
SchoolYearID: uuid.New(),
},
expectValid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := validateClass(tt.class)
if isValid != tt.expectValid {
t.Errorf("expected valid=%v, got valid=%v", tt.expectValid, isValid)
}
})
}
}
// validateClass is a helper function for class validation
func validateClass(class models.Class) bool {
if class.Name == "" {
return false
}
if class.Grade < 1 || class.Grade > 13 {
return false
}
if class.SchoolID == uuid.Nil {
return false
}
if class.SchoolYearID == uuid.Nil {
return false
}
return true
}
// TestValidateStudentData tests student data validation
func TestValidateStudentData(t *testing.T) {
dob := time.Date(2014, 5, 15, 0, 0, 0, 0, time.UTC)
futureDob := time.Now().AddDate(1, 0, 0)
studentNum := "2024-001"
tests := []struct {
name string
student models.Student
expectValid bool
}{
{
name: "valid student",
student: models.Student{
FirstName: "Max",
LastName: "Mustermann",
DateOfBirth: &dob,
SchoolID: uuid.New(),
ClassID: uuid.New(),
StudentNumber: &studentNum,
},
expectValid: true,
},
{
name: "empty first name",
student: models.Student{
FirstName: "",
LastName: "Mustermann",
DateOfBirth: &dob,
SchoolID: uuid.New(),
ClassID: uuid.New(),
StudentNumber: &studentNum,
},
expectValid: false,
},
{
name: "future birth date",
student: models.Student{
FirstName: "Max",
LastName: "Mustermann",
DateOfBirth: &futureDob,
SchoolID: uuid.New(),
ClassID: uuid.New(),
StudentNumber: &studentNum,
},
expectValid: false,
},
{
name: "missing class ID",
student: models.Student{
FirstName: "Max",
LastName: "Mustermann",
DateOfBirth: &dob,
SchoolID: uuid.New(),
ClassID: uuid.Nil,
StudentNumber: &studentNum,
},
expectValid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := validateStudent(tt.student)
if isValid != tt.expectValid {
t.Errorf("expected valid=%v, got valid=%v", tt.expectValid, isValid)
}
})
}
}
// validateStudent is a helper function for student validation
func validateStudent(student models.Student) bool {
if student.FirstName == "" || student.LastName == "" {
return false
}
if student.DateOfBirth != nil && student.DateOfBirth.After(time.Now()) {
return false
}
if student.ClassID == uuid.Nil {
return false
}
return true
}
// TestValidateTeacherData tests teacher data validation
func TestValidateTeacherData(t *testing.T) {
code := "SCH"
codeLong := "SCHMI"
tests := []struct {
name string
teacher models.Teacher
expectValid bool
}{
{
name: "valid teacher",
teacher: models.Teacher{
FirstName: "Anna",
LastName: "Schmidt",
UserID: uuid.New(),
TeacherCode: &code,
SchoolID: uuid.New(),
},
expectValid: true,
},
{
name: "empty first name",
teacher: models.Teacher{
FirstName: "",
LastName: "Schmidt",
UserID: uuid.New(),
TeacherCode: &code,
SchoolID: uuid.New(),
},
expectValid: false,
},
{
name: "teacher code too long",
teacher: models.Teacher{
FirstName: "Anna",
LastName: "Schmidt",
UserID: uuid.New(),
TeacherCode: &codeLong,
SchoolID: uuid.New(),
},
expectValid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := validateTeacher(tt.teacher)
if isValid != tt.expectValid {
t.Errorf("expected valid=%v, got valid=%v", tt.expectValid, isValid)
}
})
}
}
// validateTeacher is a helper function for teacher validation
func validateTeacher(teacher models.Teacher) bool {
if teacher.FirstName == "" || teacher.LastName == "" {
return false
}
if teacher.TeacherCode != nil && len(*teacher.TeacherCode) > 4 {
return false
}
if teacher.SchoolID == uuid.Nil {
return false
}
return true
}
// TestSchoolYearValidation tests school year date validation
func TestSchoolYearValidation(t *testing.T) {
tests := []struct {
name string
year models.SchoolYear
expectValid bool
}{
{
name: "valid school year 2024/2025",
year: models.SchoolYear{
Name: "2024/2025",
StartDate: time.Date(2024, 8, 1, 0, 0, 0, 0, time.UTC),
EndDate: time.Date(2025, 7, 31, 0, 0, 0, 0, time.UTC),
SchoolID: uuid.New(),
},
expectValid: true,
},
{
name: "end before start",
year: models.SchoolYear{
Name: "2024/2025",
StartDate: time.Date(2025, 8, 1, 0, 0, 0, 0, time.UTC),
EndDate: time.Date(2024, 7, 31, 0, 0, 0, 0, time.UTC),
SchoolID: uuid.New(),
},
expectValid: false,
},
{
name: "same start and end",
year: models.SchoolYear{
Name: "2024/2025",
StartDate: time.Date(2024, 8, 1, 0, 0, 0, 0, time.UTC),
EndDate: time.Date(2024, 8, 1, 0, 0, 0, 0, time.UTC),
SchoolID: uuid.New(),
},
expectValid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := validateSchoolYear(tt.year)
if isValid != tt.expectValid {
t.Errorf("expected valid=%v, got valid=%v", tt.expectValid, isValid)
}
})
}
}
// validateSchoolYear is a helper function for school year validation
func validateSchoolYear(year models.SchoolYear) bool {
if year.Name == "" {
return false
}
if year.SchoolID == uuid.Nil {
return false
}
if !year.EndDate.After(year.StartDate) {
return false
}
return true
}

View File

@@ -0,0 +1,15 @@
package services
// ValidationError represents a validation error in tests
// This is a shared test helper type used across multiple test files
type ValidationError struct {
Field string
Message string
}
func (e *ValidationError) Error() string {
if e.Field != "" {
return e.Field + ": " + e.Message
}
return e.Message
}

View File

@@ -0,0 +1,485 @@
package services
import (
"bytes"
"context"
"crypto/hmac"
"crypto/rand"
"crypto/sha1"
"crypto/sha256"
"encoding/base32"
"encoding/base64"
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"image/png"
"strings"
"time"
"github.com/google/uuid"
"github.com/jackc/pgx/v5/pgxpool"
qrcode "github.com/skip2/go-qrcode"
"github.com/breakpilot/consent-service/internal/models"
)
var (
ErrTOTPNotEnabled = errors.New("2FA is not enabled for this user")
ErrTOTPAlreadyEnabled = errors.New("2FA is already enabled for this user")
ErrTOTPInvalidCode = errors.New("invalid 2FA code")
ErrTOTPChallengeExpired = errors.New("2FA challenge expired")
ErrRecoveryCodeInvalid = errors.New("invalid recovery code")
ErrRecoveryCodeUsed = errors.New("recovery code already used")
)
const (
TOTPPeriod = 30 // TOTP period in seconds
TOTPDigits = 6 // Number of digits in TOTP code
TOTPSecretLen = 20 // Length of TOTP secret in bytes
RecoveryCodeCount = 10 // Number of recovery codes to generate
RecoveryCodeLen = 8 // Length of each recovery code
ChallengeExpiry = 5 * time.Minute // 2FA challenge expiry
)
// TOTPService handles Two-Factor Authentication using TOTP
type TOTPService struct {
db *pgxpool.Pool
issuer string
}
// NewTOTPService creates a new TOTPService
func NewTOTPService(db *pgxpool.Pool, issuer string) *TOTPService {
return &TOTPService{
db: db,
issuer: issuer,
}
}
// GenerateSecret generates a new TOTP secret
func (s *TOTPService) GenerateSecret() (string, error) {
secret := make([]byte, TOTPSecretLen)
if _, err := rand.Read(secret); err != nil {
return "", fmt.Errorf("failed to generate secret: %w", err)
}
return base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(secret), nil
}
// GenerateRecoveryCodes generates a set of recovery codes
func (s *TOTPService) GenerateRecoveryCodes() ([]string, error) {
codes := make([]string, RecoveryCodeCount)
for i := 0; i < RecoveryCodeCount; i++ {
codeBytes := make([]byte, RecoveryCodeLen/2)
if _, err := rand.Read(codeBytes); err != nil {
return nil, fmt.Errorf("failed to generate recovery code: %w", err)
}
codes[i] = strings.ToUpper(hex.EncodeToString(codeBytes))
}
return codes, nil
}
// GenerateQRCode generates a QR code for TOTP setup
func (s *TOTPService) GenerateQRCode(secret, email string) (string, error) {
// Create otpauth URL
otpauthURL := fmt.Sprintf("otpauth://totp/%s:%s?secret=%s&issuer=%s&algorithm=SHA1&digits=%d&period=%d",
s.issuer, email, secret, s.issuer, TOTPDigits, TOTPPeriod)
// Generate QR code
qr, err := qrcode.New(otpauthURL, qrcode.Medium)
if err != nil {
return "", fmt.Errorf("failed to generate QR code: %w", err)
}
// Convert to PNG
var buf bytes.Buffer
if err := png.Encode(&buf, qr.Image(256)); err != nil {
return "", fmt.Errorf("failed to encode QR code: %w", err)
}
// Convert to data URL
dataURL := fmt.Sprintf("data:image/png;base64,%s", base64.StdEncoding.EncodeToString(buf.Bytes()))
return dataURL, nil
}
// GenerateTOTP generates the current TOTP code for a secret
func (s *TOTPService) GenerateTOTP(secret string) (string, error) {
return s.GenerateTOTPAt(secret, time.Now())
}
// GenerateTOTPAt generates a TOTP code for a specific time
func (s *TOTPService) GenerateTOTPAt(secret string, t time.Time) (string, error) {
// Decode secret
secretBytes, err := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(strings.ToUpper(secret))
if err != nil {
return "", fmt.Errorf("invalid secret: %w", err)
}
// Calculate counter
counter := uint64(t.Unix()) / TOTPPeriod
// Generate HOTP
buf := make([]byte, 8)
binary.BigEndian.PutUint64(buf, counter)
mac := hmac.New(sha1.New, secretBytes)
mac.Write(buf)
hash := mac.Sum(nil)
// Dynamic truncation
offset := hash[len(hash)-1] & 0x0f
code := binary.BigEndian.Uint32(hash[offset:offset+4]) & 0x7fffffff
// Format code
codeStr := fmt.Sprintf("%0*d", TOTPDigits, code%1000000)
return codeStr, nil
}
// ValidateTOTP validates a TOTP code (allows 1 period drift)
func (s *TOTPService) ValidateTOTP(secret, code string) bool {
now := time.Now()
// Check current, previous, and next period
for _, offset := range []int{0, -1, 1} {
t := now.Add(time.Duration(offset*TOTPPeriod) * time.Second)
expected, err := s.GenerateTOTPAt(secret, t)
if err == nil && expected == code {
return true
}
}
return false
}
// Setup2FA initiates 2FA setup for a user
func (s *TOTPService) Setup2FA(ctx context.Context, userID uuid.UUID, email string) (*models.Setup2FAResponse, error) {
// Check if 2FA is already enabled
var exists bool
err := s.db.QueryRow(ctx, `SELECT EXISTS(SELECT 1 FROM user_totp WHERE user_id = $1 AND verified = true)`, userID).Scan(&exists)
if err == nil && exists {
return nil, ErrTOTPAlreadyEnabled
}
// Generate secret
secret, err := s.GenerateSecret()
if err != nil {
return nil, err
}
// Generate recovery codes
recoveryCodes, err := s.GenerateRecoveryCodes()
if err != nil {
return nil, err
}
// Generate QR code
qrCode, err := s.GenerateQRCode(secret, email)
if err != nil {
return nil, err
}
// Hash recovery codes for storage
hashedCodes := make([]string, len(recoveryCodes))
for i, code := range recoveryCodes {
hash := sha256.Sum256([]byte(code))
hashedCodes[i] = hex.EncodeToString(hash[:])
}
recoveryCodesJSON, _ := json.Marshal(hashedCodes)
// Store or update TOTP record (unverified)
_, err = s.db.Exec(ctx, `
INSERT INTO user_totp (user_id, secret, verified, recovery_codes, created_at, updated_at)
VALUES ($1, $2, false, $3, NOW(), NOW())
ON CONFLICT (user_id) DO UPDATE SET
secret = $2,
verified = false,
recovery_codes = $3,
updated_at = NOW()
`, userID, secret, recoveryCodesJSON)
if err != nil {
return nil, fmt.Errorf("failed to store TOTP: %w", err)
}
return &models.Setup2FAResponse{
Secret: secret,
QRCodeDataURL: qrCode,
RecoveryCodes: recoveryCodes,
}, nil
}
// Verify2FASetup verifies the 2FA setup with a code
func (s *TOTPService) Verify2FASetup(ctx context.Context, userID uuid.UUID, code string) error {
// Get TOTP record
var secret string
var verified bool
err := s.db.QueryRow(ctx, `SELECT secret, verified FROM user_totp WHERE user_id = $1`, userID).Scan(&secret, &verified)
if err != nil {
return ErrTOTPNotEnabled
}
if verified {
return ErrTOTPAlreadyEnabled
}
// Validate code
if !s.ValidateTOTP(secret, code) {
return ErrTOTPInvalidCode
}
// Mark as verified and enable 2FA
_, err = s.db.Exec(ctx, `
UPDATE user_totp SET verified = true, enabled_at = NOW(), updated_at = NOW() WHERE user_id = $1
`, userID)
if err != nil {
return fmt.Errorf("failed to verify TOTP: %w", err)
}
// Update user record
_, err = s.db.Exec(ctx, `
UPDATE users SET two_factor_enabled = true, two_factor_verified_at = NOW(), updated_at = NOW() WHERE id = $1
`, userID)
if err != nil {
return fmt.Errorf("failed to update user: %w", err)
}
return nil
}
// CreateChallenge creates a 2FA challenge for login
func (s *TOTPService) CreateChallenge(ctx context.Context, userID uuid.UUID, ipAddress, userAgent string) (string, error) {
// Generate challenge ID
challengeBytes := make([]byte, 32)
if _, err := rand.Read(challengeBytes); err != nil {
return "", fmt.Errorf("failed to generate challenge: %w", err)
}
challengeID := base64.URLEncoding.EncodeToString(challengeBytes)
// Store challenge
_, err := s.db.Exec(ctx, `
INSERT INTO two_factor_challenges (user_id, challenge_id, ip_address, user_agent, expires_at, created_at)
VALUES ($1, $2, $3, $4, $5, NOW())
`, userID, challengeID, ipAddress, userAgent, time.Now().Add(ChallengeExpiry))
if err != nil {
return "", fmt.Errorf("failed to create challenge: %w", err)
}
return challengeID, nil
}
// VerifyChallenge verifies a 2FA challenge with a TOTP code
func (s *TOTPService) VerifyChallenge(ctx context.Context, challengeID, code string) (*uuid.UUID, error) {
var challenge models.TwoFactorChallenge
err := s.db.QueryRow(ctx, `
SELECT id, user_id, expires_at, used_at FROM two_factor_challenges WHERE challenge_id = $1
`, challengeID).Scan(&challenge.ID, &challenge.UserID, &challenge.ExpiresAt, &challenge.UsedAt)
if err != nil {
return nil, ErrInvalidToken
}
if challenge.UsedAt != nil {
return nil, ErrInvalidToken
}
if time.Now().After(challenge.ExpiresAt) {
return nil, ErrTOTPChallengeExpired
}
// Get TOTP secret
var secret string
err = s.db.QueryRow(ctx, `SELECT secret FROM user_totp WHERE user_id = $1 AND verified = true`, challenge.UserID).Scan(&secret)
if err != nil {
return nil, ErrTOTPNotEnabled
}
// Validate TOTP code
if !s.ValidateTOTP(secret, code) {
return nil, ErrTOTPInvalidCode
}
// Mark challenge as used
_, err = s.db.Exec(ctx, `UPDATE two_factor_challenges SET used_at = NOW() WHERE id = $1`, challenge.ID)
if err != nil {
return nil, fmt.Errorf("failed to mark challenge as used: %w", err)
}
// Update last used time
_, _ = s.db.Exec(ctx, `UPDATE user_totp SET last_used_at = NOW() WHERE user_id = $1`, challenge.UserID)
return &challenge.UserID, nil
}
// VerifyChallengeWithRecoveryCode verifies a 2FA challenge with a recovery code
func (s *TOTPService) VerifyChallengeWithRecoveryCode(ctx context.Context, challengeID, recoveryCode string) (*uuid.UUID, error) {
var challenge models.TwoFactorChallenge
err := s.db.QueryRow(ctx, `
SELECT id, user_id, expires_at, used_at FROM two_factor_challenges WHERE challenge_id = $1
`, challengeID).Scan(&challenge.ID, &challenge.UserID, &challenge.ExpiresAt, &challenge.UsedAt)
if err != nil {
return nil, ErrInvalidToken
}
if challenge.UsedAt != nil {
return nil, ErrInvalidToken
}
if time.Now().After(challenge.ExpiresAt) {
return nil, ErrTOTPChallengeExpired
}
// Get recovery codes
var recoveryCodesJSON []byte
err = s.db.QueryRow(ctx, `SELECT recovery_codes FROM user_totp WHERE user_id = $1 AND verified = true`, challenge.UserID).Scan(&recoveryCodesJSON)
if err != nil {
return nil, ErrTOTPNotEnabled
}
var hashedCodes []string
json.Unmarshal(recoveryCodesJSON, &hashedCodes)
// Hash the provided recovery code
codeHash := sha256.Sum256([]byte(strings.ToUpper(recoveryCode)))
codeHashStr := hex.EncodeToString(codeHash[:])
// Find and remove the recovery code
found := false
newCodes := make([]string, 0, len(hashedCodes)-1)
for _, hc := range hashedCodes {
if hc == codeHashStr && !found {
found = true
continue
}
newCodes = append(newCodes, hc)
}
if !found {
return nil, ErrRecoveryCodeInvalid
}
// Update recovery codes
newCodesJSON, _ := json.Marshal(newCodes)
_, err = s.db.Exec(ctx, `UPDATE user_totp SET recovery_codes = $1, updated_at = NOW() WHERE user_id = $2`, newCodesJSON, challenge.UserID)
if err != nil {
return nil, fmt.Errorf("failed to update recovery codes: %w", err)
}
// Mark challenge as used
_, err = s.db.Exec(ctx, `UPDATE two_factor_challenges SET used_at = NOW() WHERE id = $1`, challenge.ID)
if err != nil {
return nil, fmt.Errorf("failed to mark challenge as used: %w", err)
}
return &challenge.UserID, nil
}
// Disable2FA disables 2FA for a user
func (s *TOTPService) Disable2FA(ctx context.Context, userID uuid.UUID, code string) error {
// Get TOTP secret
var secret string
err := s.db.QueryRow(ctx, `SELECT secret FROM user_totp WHERE user_id = $1 AND verified = true`, userID).Scan(&secret)
if err != nil {
return ErrTOTPNotEnabled
}
// Validate code
if !s.ValidateTOTP(secret, code) {
return ErrTOTPInvalidCode
}
// Delete TOTP record
_, err = s.db.Exec(ctx, `DELETE FROM user_totp WHERE user_id = $1`, userID)
if err != nil {
return fmt.Errorf("failed to delete TOTP: %w", err)
}
// Update user record
_, err = s.db.Exec(ctx, `
UPDATE users SET two_factor_enabled = false, two_factor_verified_at = NULL, updated_at = NOW() WHERE id = $1
`, userID)
if err != nil {
return fmt.Errorf("failed to update user: %w", err)
}
return nil
}
// GetStatus returns the 2FA status for a user
func (s *TOTPService) GetStatus(ctx context.Context, userID uuid.UUID) (*models.TwoFactorStatusResponse, error) {
var totp models.UserTOTP
var recoveryCodesJSON []byte
err := s.db.QueryRow(ctx, `
SELECT id, verified, enabled_at, recovery_codes FROM user_totp WHERE user_id = $1
`, userID).Scan(&totp.ID, &totp.Verified, &totp.EnabledAt, &recoveryCodesJSON)
if err != nil {
// 2FA not set up
return &models.TwoFactorStatusResponse{
Enabled: false,
Verified: false,
RecoveryCodesCount: 0,
}, nil
}
var hashedCodes []string
json.Unmarshal(recoveryCodesJSON, &hashedCodes)
return &models.TwoFactorStatusResponse{
Enabled: totp.Verified,
Verified: totp.Verified,
EnabledAt: totp.EnabledAt,
RecoveryCodesCount: len(hashedCodes),
}, nil
}
// RegenerateRecoveryCodes generates new recovery codes (requires current TOTP code)
func (s *TOTPService) RegenerateRecoveryCodes(ctx context.Context, userID uuid.UUID, code string) ([]string, error) {
// Get TOTP secret
var secret string
err := s.db.QueryRow(ctx, `SELECT secret FROM user_totp WHERE user_id = $1 AND verified = true`, userID).Scan(&secret)
if err != nil {
return nil, ErrTOTPNotEnabled
}
// Validate code
if !s.ValidateTOTP(secret, code) {
return nil, ErrTOTPInvalidCode
}
// Generate new recovery codes
recoveryCodes, err := s.GenerateRecoveryCodes()
if err != nil {
return nil, err
}
// Hash recovery codes for storage
hashedCodes := make([]string, len(recoveryCodes))
for i, rc := range recoveryCodes {
hash := sha256.Sum256([]byte(rc))
hashedCodes[i] = hex.EncodeToString(hash[:])
}
recoveryCodesJSON, _ := json.Marshal(hashedCodes)
// Update recovery codes
_, err = s.db.Exec(ctx, `UPDATE user_totp SET recovery_codes = $1, updated_at = NOW() WHERE user_id = $2`, recoveryCodesJSON, userID)
if err != nil {
return nil, fmt.Errorf("failed to update recovery codes: %w", err)
}
return recoveryCodes, nil
}
// IsTwoFactorEnabled checks if 2FA is enabled for a user
func (s *TOTPService) IsTwoFactorEnabled(ctx context.Context, userID uuid.UUID) (bool, error) {
var enabled bool
err := s.db.QueryRow(ctx, `SELECT two_factor_enabled FROM users WHERE id = $1`, userID).Scan(&enabled)
if err != nil {
return false, err
}
return enabled, nil
}

View File

@@ -0,0 +1,378 @@
package services
import (
"crypto/hmac"
"crypto/sha1"
"crypto/sha256"
"encoding/base32"
"encoding/binary"
"encoding/hex"
"strings"
"testing"
"time"
)
// TestTOTPGeneration tests TOTP code generation
func TestTOTPGeneration_ValidSecret(t *testing.T) {
// Test secret (Base32 encoded)
secret := "JBSWY3DPEHPK3PXP" // This is "Hello!" in Base32
// Decode secret
secretBytes, err := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(secret)
if err != nil {
t.Fatalf("Failed to decode secret: %v", err)
}
// Generate TOTP for current time
now := time.Now()
counter := uint64(now.Unix()) / 30
buf := make([]byte, 8)
binary.BigEndian.PutUint64(buf, counter)
mac := hmac.New(sha1.New, secretBytes)
mac.Write(buf)
hash := mac.Sum(nil)
// Dynamic truncation
offset := hash[len(hash)-1] & 0x0f
code := binary.BigEndian.Uint32(hash[offset:offset+4]) & 0x7fffffff
totpCode := code % 1000000
// Check that code is 6 digits
if totpCode < 0 || totpCode > 999999 {
t.Errorf("TOTP code should be 6 digits, got %d", totpCode)
}
}
// TestTOTPGeneration_SameTimeProducesSameCode tests deterministic generation
func TestTOTPGeneration_Deterministic(t *testing.T) {
secret := "JBSWY3DPEHPK3PXP"
secretBytes, _ := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(secret)
fixedTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
code1 := generateTOTPAt(secretBytes, fixedTime)
code2 := generateTOTPAt(secretBytes, fixedTime)
if code1 != code2 {
t.Errorf("Same time should produce same code: got %s and %s", code1, code2)
}
}
// TestTOTPGeneration_DifferentTimesProduceDifferentCodes tests time sensitivity
func TestTOTPGeneration_TimeSensitive(t *testing.T) {
secret := "JBSWY3DPEHPK3PXP"
secretBytes, _ := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(secret)
time1 := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
time2 := time1.Add(30 * time.Second) // Next TOTP period
code1 := generateTOTPAt(secretBytes, time1)
code2 := generateTOTPAt(secretBytes, time2)
if code1 == code2 {
t.Error("Different TOTP periods should produce different codes")
}
}
// Helper function for TOTP generation at specific time
func generateTOTPAt(secretBytes []byte, t time.Time) string {
counter := uint64(t.Unix()) / 30
buf := make([]byte, 8)
binary.BigEndian.PutUint64(buf, counter)
mac := hmac.New(sha1.New, secretBytes)
mac.Write(buf)
hash := mac.Sum(nil)
offset := hash[len(hash)-1] & 0x0f
code := binary.BigEndian.Uint32(hash[offset:offset+4]) & 0x7fffffff
return padCode(code % 1000000)
}
func padCode(code uint32) string {
s := ""
for i := 0; i < 6; i++ {
s = string(rune('0'+code%10)) + s
code /= 10
}
return s
}
// TestTOTPValidation_WithDrift tests validation with clock drift allowance
func TestTOTPValidation_WithDrift(t *testing.T) {
secret := "JBSWY3DPEHPK3PXP"
secretBytes, _ := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(secret)
now := time.Now()
// Generate current code
currentCode := generateTOTPAt(secretBytes, now)
// Generate previous period code
previousCode := generateTOTPAt(secretBytes, now.Add(-30*time.Second))
// Generate next period code
nextCode := generateTOTPAt(secretBytes, now.Add(30*time.Second))
// All three should be valid for current validation (allowing 1 period drift)
validCodes := []string{currentCode, previousCode, nextCode}
for _, code := range validCodes {
isValid := validateTOTPWithDrift(secretBytes, code, now)
if !isValid {
t.Errorf("Code %s should be valid with drift allowance", code)
}
}
}
// validateTOTPWithDrift validates a TOTP code allowing for clock drift
func validateTOTPWithDrift(secretBytes []byte, code string, now time.Time) bool {
for _, offset := range []int{0, -1, 1} {
t := now.Add(time.Duration(offset*30) * time.Second)
expected := generateTOTPAt(secretBytes, t)
if expected == code {
return true
}
}
return false
}
// TestRecoveryCodeGeneration tests recovery code format
func TestRecoveryCodeGeneration_Format(t *testing.T) {
// Simulate recovery code generation
codeBytes := make([]byte, 4) // 8 hex chars = 4 bytes
for i := range codeBytes {
codeBytes[i] = byte(i + 1) // Deterministic for testing
}
code := strings.ToUpper(hex.EncodeToString(codeBytes))
// Check format
if len(code) != 8 {
t.Errorf("Recovery code should be 8 characters, got %d", len(code))
}
// Check uppercase
if code != strings.ToUpper(code) {
t.Error("Recovery code should be uppercase")
}
// Check alphanumeric (hex only contains 0-9 and A-F)
for _, c := range code {
if !((c >= '0' && c <= '9') || (c >= 'A' && c <= 'F')) {
t.Errorf("Recovery code should only contain hex characters, found '%c'", c)
}
}
}
// TestRecoveryCodeHashing tests that recovery codes are hashed for storage
func TestRecoveryCodeHashing_Consistency(t *testing.T) {
code := "ABCD1234"
hash1 := sha256.Sum256([]byte(code))
hash2 := sha256.Sum256([]byte(code))
if hash1 != hash2 {
t.Error("Recovery code hashing should be consistent")
}
}
func TestRecoveryCodeHashing_CaseInsensitive(t *testing.T) {
code1 := "ABCD1234"
code2 := "abcd1234"
hash1 := sha256.Sum256([]byte(strings.ToUpper(code1)))
hash2 := sha256.Sum256([]byte(strings.ToUpper(code2)))
if hash1 != hash2 {
t.Error("Recovery codes should be case-insensitive when normalized to uppercase")
}
}
// TestSecretGeneration tests that secrets are valid Base32
func TestSecretGeneration_ValidBase32(t *testing.T) {
// Simulate secret generation (20 bytes -> Base32 without padding)
secretBytes := make([]byte, 20)
for i := range secretBytes {
secretBytes[i] = byte(i * 13) // Deterministic for testing
}
secret := base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(secretBytes)
// Verify it can be decoded
decoded, err := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(secret)
if err != nil {
t.Errorf("Generated secret should be valid Base32: %v", err)
}
if len(decoded) != 20 {
t.Errorf("Decoded secret should be 20 bytes, got %d", len(decoded))
}
}
// TestQRCodeOtpauthURL tests otpauth URL format
func TestQRCodeOtpauthURL_Format(t *testing.T) {
issuer := "BreakPilot"
email := "test@example.com"
secret := "JBSWY3DPEHPK3PXP"
period := 30
digits := 6
url := "otpauth://totp/" + issuer + ":" + email +
"?secret=" + secret +
"&issuer=" + issuer +
"&algorithm=SHA1" +
"&digits=" + string(rune('0'+digits)) +
"&period=" + string(rune('0'+period/10)) + string(rune('0'+period%10))
// Check URL starts with otpauth://totp/
if !strings.HasPrefix(url, "otpauth://totp/") {
t.Error("OTP auth URL should start with otpauth://totp/")
}
// Check contains required parameters
if !strings.Contains(url, "secret=") {
t.Error("OTP auth URL should contain secret parameter")
}
if !strings.Contains(url, "issuer=") {
t.Error("OTP auth URL should contain issuer parameter")
}
}
// TestChallengeExpiry tests 2FA challenge expiration
func TestChallengeExpiry_Logic(t *testing.T) {
tests := []struct {
name string
expiryMins int
usedAfter int // minutes after creation
shouldAllow bool
}{
{
name: "challenge used within expiry",
expiryMins: 5,
usedAfter: 2,
shouldAllow: true,
},
{
name: "challenge used at expiry",
expiryMins: 5,
usedAfter: 5,
shouldAllow: false, // Expired
},
{
name: "challenge used after expiry",
expiryMins: 5,
usedAfter: 10,
shouldAllow: false,
},
{
name: "challenge used immediately",
expiryMins: 5,
usedAfter: 0,
shouldAllow: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := tt.usedAfter < tt.expiryMins
if isValid != tt.shouldAllow {
t.Errorf("Expected allow=%v for challenge used after %d mins (expiry: %d mins)",
tt.shouldAllow, tt.usedAfter, tt.expiryMins)
}
})
}
}
// TestRecoveryCodeOneTimeUse tests that recovery codes can only be used once
func TestRecoveryCodeOneTimeUse(t *testing.T) {
initialCodes := []string{
sha256Hash("CODE0001"),
sha256Hash("CODE0002"),
sha256Hash("CODE0003"),
}
// Use CODE0002
usedCodeHash := sha256Hash("CODE0002")
// Remove used code from list
var remainingCodes []string
for _, code := range initialCodes {
if code != usedCodeHash {
remainingCodes = append(remainingCodes, code)
}
}
if len(remainingCodes) != 2 {
t.Errorf("Should have 2 remaining codes after using one, got %d", len(remainingCodes))
}
// Verify used code is not in remaining
for _, code := range remainingCodes {
if code == usedCodeHash {
t.Error("Used recovery code should be removed from list")
}
}
}
func sha256Hash(s string) string {
h := sha256.Sum256([]byte(s))
return hex.EncodeToString(h[:])
}
// TestTwoFactorEnableFlow tests the 2FA enable workflow
func TestTwoFactorEnableFlow_States(t *testing.T) {
tests := []struct {
name string
initialState bool // verified
action string
expectedState bool
}{
{
name: "fresh user - not verified",
initialState: false,
action: "none",
expectedState: false,
},
{
name: "user verifies 2FA",
initialState: false,
action: "verify",
expectedState: true,
},
{
name: "already verified - stays verified",
initialState: true,
action: "verify",
expectedState: true,
},
{
name: "user disables 2FA",
initialState: true,
action: "disable",
expectedState: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
state := tt.initialState
switch tt.action {
case "verify":
state = true
case "disable":
state = false
}
if state != tt.expectedState {
t.Errorf("Expected state=%v after action '%s', got state=%v",
tt.expectedState, tt.action, state)
}
})
}
}